diff --git a/tensorrt_llm/_torch/pyexecutor/encoder_executor.py b/tensorrt_llm/_torch/pyexecutor/encoder_executor.py
new file mode 100644
index 00000000000..6a2c1df78a1
--- /dev/null
+++ b/tensorrt_llm/_torch/pyexecutor/encoder_executor.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2025-2026, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict
+
+import torch
+
+from tensorrt_llm.logger import logger
+
+
+class EncoderExecutor:
+    """Executor for encoder-only models.
+
+    Primary path: batch_forward(inputs) — synchronous batch execution.
+    Delegates to model_engine.encoder_forward() for all heavy lifting
+    (pre-allocated buffers, attention metadata, torch.compile).
+
+    This executor has no background thread, no scheduler, no sampler,
+    and no request queue. It runs entirely on the calling thread.
+    """
+
+    def __init__(self, model_engine, dist):
+        """Keep the engine and dist handles and log what is bypassed.
+
+        Args:
+            model_engine: Engine exposing encoder_forward(inputs).
+            dist: Distributed handle; stored but not used by this class.
+        """
+        self.model_engine = model_engine
+        self.dist = dist
+
+        logger.info(
+            "encoder_only mode enabled: using EncoderExecutor. "
+            "Scheduler, sampler, KV cache, and generation-related parameters "
+            "(disable_overlap_scheduler, max_tokens, temperature, etc.) "
+            "are bypassed. Use llm.encode() for inference.")
+
+    def batch_forward(self,
+                      inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]:
+        """Execute a pre-formed batch in one forward pass.
+
+        Args:
+            inputs: Dict with 'input_ids' ([total_tokens]) and 'seq_lens'
+                ([batch_size]) required. Optional model-specific kwargs
+                (e.g. token_type_ids) are passed through.
+
+        Returns:
+            Dict with 'logits' tensor and any other model outputs.
+        """
+        return self.model_engine.encoder_forward(inputs)
+
+    def shutdown(self):
+        """Release the model engine; there is no background thread to stop.
+
+        Idempotent: the reference is cleared rather than deleted, so a
+        repeated call does not raise AttributeError.
+        """
+        self.model_engine = None
diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index 7c2ce12a7f4..b2411aa88ec 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -3645,6 +3645,99 @@ def _prepare_inputs(
             num_accepted_tokens_device, req_id_to_old_request, resource_manager,
             maybe_graph)
 
+    def _prepare_encoder_inputs(self,
+                                inputs: Dict[str, Any]) -> Dict[str, Any]:
+        """Prepare model-ready inputs dict for encoder-only models.
+
+        Encoder equivalent of _prepare_tp_inputs + _preprocess_inputs.
+        Consumes raw inputs dict, copies to pre-allocated CUDA buffers,
+        sets up attention metadata, and returns model-ready dict.
+
+        Args:
+            inputs: Dict with required keys 'input_ids' ([total_tokens]) and
+                'seq_lens' ([batch_size]). Optional 'position_ids'
+                ([total_tokens]); when absent, packed positions
+                [0..n1-1, 0..n2-1, ...] are generated from 'seq_lens'.
+                Any additional keys (e.g. token_type_ids) are copied into
+                the returned dict unchanged. A caller-supplied
+                'inputs_embeds' is not forwarded: it is reset to None.
+
+        Returns:
+            Dict holding the pass-through keys plus 'attn_metadata',
+            'input_ids' and 'position_ids' (slices of the pre-allocated
+            buffers, the latter with a leading dimension of size 1), and
+            'inputs_embeds' set to None.
+
+        Raises:
+            ValueError: If the token count exceeds self.max_num_tokens.
+        """
+        token_ids = inputs['input_ids']
+        seq_lens = inputs['seq_lens']
+        position_ids = inputs.get('position_ids')
+        num_tokens = token_ids.shape[0]
+        batch_size = seq_lens.shape[0]
+
+        # Raise explicitly: an assert would be stripped under `python -O`.
+        if num_tokens > self.max_num_tokens:
+            raise ValueError(
+                f"num_tokens ({num_tokens}) exceeds max_num_tokens "
+                f"({self.max_num_tokens}). Reduce batch size or sequence "
+                f"lengths.")
+
+        # 1. Copy to pre-allocated CUDA buffers
+        self.input_ids_cuda[:num_tokens].copy_(token_ids, non_blocking=True)
+        if position_ids is None:
+            # Auto-generate packed position IDs: [0..n1-1, 0..n2-1, ...]
+            position_ids = torch.cat(
+                [torch.arange(s, dtype=torch.int32) for s in seq_lens.tolist()])
+        self.position_ids_cuda[:num_tokens].copy_(position_ids,
+                                                  non_blocking=True)
+
+        # 2. Set up attention metadata
+        attn_metadata = self._set_up_attn_metadata(kv_cache_manager=None)
+        attn_metadata.seq_lens = seq_lens
+        attn_metadata.num_contexts = batch_size
+        attn_metadata.max_seq_len = self.max_seq_len
+        attn_metadata.request_ids = list(range(batch_size))
+        attn_metadata.prepare()
+
+        # 3. Build model-ready dict.
+        # **inputs goes FIRST so that the explicit buffer keys override the
+        # raw tensors. Extra keys (seq_lens, token_type_ids, etc.) pass
+        # through to the model's **kwargs and are silently ignored if not
+        # in the model's forward() signature.
+        # 'inputs_embeds' is always cleared, even if the caller passed one.
+        model_inputs = {
+            **inputs,
+            'attn_metadata': attn_metadata,
+            'input_ids': self.input_ids_cuda[:num_tokens],
+            'position_ids': self.position_ids_cuda[:num_tokens].unsqueeze(0),
+            'inputs_embeds': None,
+        }
+
+        return model_inputs
+
+    @torch.inference_mode()
+    @with_model_extra_attrs(lambda self: self.model.extra_attrs)
+    @nvtx_range("encoder_forward")
+    def encoder_forward(self,
+                        inputs: Dict[str, Any]) -> Dict[str, Any]:
+        """Direct tensor-level forward for encoder-only models.
+
+        Bypasses ScheduledRequests/LlmRequest entirely. Takes a raw inputs
+        dict, prepares model-ready inputs via _prepare_encoder_inputs, and
+        calls _forward_step (which preserves torch.compile).
+
+        Args:
+            inputs: Dict with 'input_ids' and 'seq_lens' (required), plus
+                any model-specific kwargs (e.g. token_type_ids).
+
+        Returns:
+            Dict with 'logits' tensor and any other model outputs.
+        """
+        model_inputs = self._prepare_encoder_inputs(inputs)
+        return self._forward_step(model_inputs, None, False)
+
     @torch.inference_mode()
     @with_model_extra_attrs(lambda self: self.model.extra_attrs)
     def forward(self,
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
index e0aa739d869..a9d55f796fb 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -225,6 +225,50 @@ def get_guided_decoding_config(guided_decoding_backend: str,
     return guided_decoding_config
 
 
+def create_encoder_executor(
+    llm_args: TorchLlmArgs,
+    checkpoint_dir: Optional[str] = None,
+):
+    """Create an EncoderExecutor for encoder-only models.
+
+    Handles model loading and model_engine creation, then wraps in a
+    lightweight EncoderExecutor. Skips all decoder infrastructure
+    (KV cache, scheduler, sampler, drafter, speculative decoding).
+
+    Args:
+        llm_args: Configuration arguments.
+        checkpoint_dir: Path to model checkpoint.
+
+    Returns:
+        An EncoderExecutor instance ready for batch_forward() calls.
+    """
+    from .encoder_executor import EncoderExecutor
+
+    torch.cuda.set_per_process_memory_fraction(1.0)
+    checkpoint_loader = _construct_checkpoint_loader(llm_args.backend,
+                                                     llm_args.checkpoint_loader,
+                                                     llm_args.checkpoint_format)
+    llm_args = ModelLoader.load_config_and_apply_defaults(
+        checkpoint_dir, llm_args, checkpoint_loader)
+
+    mapping = _get_mapping(llm_args.parallel_config.to_mapping())
+    dist = Distributed.get(mapping)
+
+    model_engine = PyTorchModelEngine(
+        model_path=checkpoint_dir,
+        llm_args=llm_args,
+        mapping=mapping,
+        dist=dist,
+        spec_config=None,
+        checkpoint_loader=checkpoint_loader,
+    )
+
+    return EncoderExecutor(
+        model_engine=model_engine,
+        dist=dist,
+    )
+
+
 def create_py_executor(
     llm_args: TorchLlmArgs,
     checkpoint_dir: Optional[str] = None,
diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py
index 4f207cf4403..b832de6043b 100644
--- a/tensorrt_llm/llmapi/llm.py
+++ b/tensorrt_llm/llmapi/llm.py
@@ -11,6 +11,7 @@
 from pathlib import Path
 from typing import Any, List, Literal, Optional, Sequence, Tuple, Union, cast
 
+import torch
 import transformers
 from tqdm import tqdm
 from transformers import PreTrainedTokenizerBase
@@ -100,6 +101,23 @@ def _repr_fields(self):
         ]
 
 
+@dataclass
+class EncoderOutput:
+    """Output from an encoder-only model.
+
+    Attributes:
+        logits: Model output tensor. Shape depends on model:
+            - Classification: [num_classes]
+            - Per-token scoring: [seq_len, num_labels]
+            - Embeddings: [hidden_size]
+        prompt_token_ids: The tokenized input IDs.
+        prompt: The original text prompt, if provided as string.
+    """
+    logits: torch.Tensor
+    prompt_token_ids: List[int]
+    prompt: Optional[str] = None
+
+
 TRT_LLM_DOCSTRING = TRT_LLMARGS_EXPLICIT_DOCSTRING + """
 
     Attributes:
@@ -237,6 +255,8 @@ def __init__(self,
             # Due to the Executor can only accept a engine path, we need to save the engine to a directory
             self._engine_dir: Optional[Path] = None
             self._executor: Optional[GenerationExecutor] = None
+            self._encoder_only: bool = False
+            self._encoder_executor = None
             if self._on_trt_backend:
                 self._workspace = tempfile.TemporaryDirectory(
                     suffix="-llm-workspace", dir=self.args.workspace)
@@ -424,6 +444,11 @@ def generate_async(
             tensorrt_llm.llmapi.RequestOutput: The output data of the completion request to the LLM.
         """
 
+        if self._encoder_only:
+            raise RuntimeError(
+                "generate_async() is not available when encoder_only=True. "
+                "Use llm.encode() for encoder-only models.")
+
         # Check if the worker is shutting down
         if self._executor is None or self._executor.is_shutdown():
             raise RuntimeError("LLM is shutting down")
@@ -695,6 +720,152 @@ def preprocess(
             multimodal_params=multimodal_params,
         )
 
+    @set_api_status("prototype")
+    def encode(
+        self,
+        inputs: Union[PromptInputs, Sequence[PromptInputs]],
+        add_special_tokens: bool = True,
+        **model_kwargs,
+    ) -> Union[EncoderOutput, List[EncoderOutput]]:
+        """Encode inputs using an encoder-only model (PyTorch backend only).
+
+        Only available when encoder_only=True is set in the LLM constructor.
+
+        Args:
+            inputs: Text string(s), token ID list(s), or TextPrompt/TokensPrompt dict(s).
+                An empty list is treated as an empty batch.
+            add_special_tokens: Whether to add special tokens (e.g., [CLS]/[SEP])
+                during tokenization. Defaults to True.
+            **model_kwargs: Model-specific inputs passed through to the model's
+                forward(), e.g. token_type_ids (BERT). Keys that are set
+                internally (input_ids, position_ids, seq_lens, attn_metadata,
+                inputs_embeds, return_context_logits) are ignored with a
+                warning.
+
+        Returns:
+            EncoderOutput or List[EncoderOutput] with logits/embeddings.
+            A batch of prompts yields a list, empty for an empty batch.
+
+        Raises:
+            RuntimeError: If encoder_only mode is not enabled.
+            ValueError: If the batch exceeds the engine's max_batch_size,
+                max_num_tokens or max_seq_len.
+            TypeError: If an input is neither a text nor a token-ID prompt.
+        """
+        if not self._encoder_only:
+            raise RuntimeError(
+                "encode() requires encoder_only=True. "
+                "Set encoder_only=True in the LLM() constructor.")
+
+        # An empty batch has nothing to run. Returning early also keeps the
+        # inputs[0] probe below from raising IndexError.
+        if isinstance(inputs, list) and not inputs:
+            return []
+
+        unbatched = not isinstance(inputs, list)
+        if not unbatched:
+            if isinstance(inputs[0], int):
+                unbatched = True
+        if unbatched:
+            inputs = [inputs]
+
+        engine = self._encoder_executor.model_engine
+        max_seq_len = engine.max_seq_len
+        max_num_tokens = engine.max_num_tokens
+        max_batch_size = engine.batch_size
+
+        if len(inputs) > max_batch_size:
+            raise ValueError(
+                f"Batch size ({len(inputs)}) exceeds max_batch_size "
+                f"({max_batch_size}). Split inputs into smaller batches.")
+
+        # Tokenize each input (reuses existing input_processor)
+        token_ids_list = []
+        prompts = []
+        sampling_params = SamplingParams(
+            add_special_tokens=add_special_tokens)
+
+        total_tokens = 0
+        max_seq_len_batch = 0
+        for inp in inputs:
+            inp = prompt_inputs(inp)
+            if "prompt_token_ids" in inp:
+                token_ids_list.append(inp["prompt_token_ids"])
+                seq_len = len(inp["prompt_token_ids"])
+                total_tokens += seq_len
+                max_seq_len_batch = max(max_seq_len_batch, seq_len)
+                prompts.append(None)
+            elif "prompt" in inp:
+                token_ids, _ = self.input_processor(inp, sampling_params)
+                token_ids_list.append(token_ids)
+                seq_len = len(token_ids)
+                total_tokens += seq_len
+                max_seq_len_batch = max(max_seq_len_batch, seq_len)
+                prompts.append(inp["prompt"])
+            else:
+                raise TypeError(f"Unsupported input type: {type(inp)}")
+
+        # Validate inputs against model capacity
+        if total_tokens > max_num_tokens:
+            raise ValueError(
+                f"Total tokens ({total_tokens}) across the batch exceeds "
+                f"max_num_tokens ({max_num_tokens}). Reduce batch size or "
+                f"sequence lengths.")
+
+        if max_seq_len_batch > max_seq_len:
+            raise ValueError(
+                f"Max sequence length ({max_seq_len_batch}) exceeds "
+                f"max_seq_len ({max_seq_len}). Truncate the input or increase "
+                f"max_seq_len.")
+
+        # Pack into flat tensors
+        seq_lens = torch.tensor([len(t) for t in token_ids_list],
+                                dtype=torch.int32)
+        flat_token_ids = torch.tensor(
+            [tid for tids in token_ids_list for tid in tids],
+            dtype=torch.int32)
+
+        # Build inputs dict — common + model-specific kwargs.
+        # Filter keys that are set internally by _prepare_encoder_inputs or
+        # _forward_step to avoid "multiple values for keyword argument" errors.
+        _RESERVED_KEYS = {
+            'input_ids', 'position_ids', 'seq_lens', 'attn_metadata',
+            'inputs_embeds', 'return_context_logits',
+        }
+        dropped_keys = sorted(k for k in model_kwargs if k in _RESERVED_KEYS)
+        if dropped_keys:
+            # Dropping caller input silently would hide mistakes.
+            logger.warning(
+                f"encode() ignores reserved model_kwargs: {dropped_keys}")
+        filtered_kwargs = {
+            k: v
+            for k, v in model_kwargs.items() if k not in _RESERVED_KEYS
+        }
+        forward_inputs = {
+            'input_ids': flat_token_ids,
+            'seq_lens': seq_lens,
+            **filtered_kwargs,
+        }
+
+        # Single forward pass
+        outputs = self._encoder_executor.batch_forward(forward_inputs)
+
+        # Package as EncoderOutput.
+        # NOTE: logits[i] assumes batch-indexed output (e.g., BERT classification
+        # returns [batch_size, num_classes]). Per-token models that return packed
+        # [total_tokens, hidden_size] would need cumulative-sum slicing instead.
+        logits = outputs['logits']
+        results = []
+        for i in range(len(token_ids_list)):
+            results.append(
+                EncoderOutput(
+                    logits=logits[i] if logits.dim() > 1 else logits,
+                    prompt_token_ids=token_ids_list[i],
+                    prompt=prompts[i],
+                ))
+
+        return results[0] if unbatched else results
+
     @set_api_status("beta")
     def get_stats(self, timeout: Optional[float] = 2) -> List[dict]:
         '''Get iteration statistics from the runtime.
@@ -707,6 +878,10 @@ def get_stats(self, timeout: Optional[float] = 2) -> List[dict]:
             List[dict]: A list of runtime stats as dicts.
                 e.g., [{"cpuMemUsage": ..., "iter": 0, ...}, {"cpuMemUsage": ..., "iter": 1, ...}]
         '''
+        if self._encoder_only:
+            raise RuntimeError(
+                "get_stats() is not available when encoder_only=True. "
+                "Use llm.encode() for encoder-only models.")
         return self._executor.get_stats(timeout=timeout)
 
     @set_api_status("beta")
@@ -721,6 +896,10 @@ def get_stats_async(self, timeout: Optional[float] = 2) -> IterationResult:
         Returns:
             tensorrt_llm.executor.result.IterationResult: An async iterable object containing runtime stats.
         '''
+        if self._encoder_only:
+            raise RuntimeError(
+                "get_stats_async() is not available when encoder_only=True. "
+                "Use llm.encode() for encoder-only models.")
         return self._executor.aget_stats(timeout=timeout)
 
     @set_api_status("beta")
@@ -743,6 +922,10 @@ def get_kv_cache_events(self, timeout: Optional[float] = 2) -> List[dict]:
         Returns:
             List[dict]: A list of runtime events as dict.
         '''
+        if self._encoder_only:
+            raise RuntimeError(
+                "get_kv_cache_events() is not available when "
+                "encoder_only=True.")
         return self._executor.get_kv_events(timeout=timeout)
 
     @set_api_status("beta")
@@ -767,6 +950,10 @@ def get_kv_cache_events_async(self,
         Returns:
             tensorrt_llm.executor.result.IterationResult: An async iterable object containing runtime events.
         '''
+        if self._encoder_only:
+            raise RuntimeError(
+                "get_kv_cache_events_async() is not available when "
+                "encoder_only=True.")
         return self._executor.aget_kv_events(timeout=timeout)
 
     def _process_env_overrides(self,
@@ -963,6 +1150,10 @@ def shutdown(self) -> None:
             self._executor.shutdown()
             self._executor = None
 
+        if hasattr(self, "_encoder_executor") and self._encoder_executor is not None:
+            self._encoder_executor.shutdown()
+            self._encoder_executor = None
+
         if hasattr(self, 'mpi_session') and self.mpi_session is not None:
             self.mpi_session.shutdown()
             self.mpi_session = None
@@ -973,6 +1164,9 @@ def _check_health(self) -> bool:
         Returns:
             bool: True if the executor is running and not shutdown, False otherwise.
         """
+        if self._encoder_only:
+            return (hasattr(self, "_encoder_executor")
+                    and self._encoder_executor is not None)
         if hasattr(self, "_executor") and self._executor is not None:
             return not self._executor.is_shutdown()
 
@@ -1245,6 +1439,9 @@ def _collective_rpc(
         Returns:
             list[Any]: A list of results from each worker.
         """
+        if self._encoder_only:
+            raise RuntimeError(
+                "_collective_rpc() is not available when encoder_only=True.")
         if hasattr(self._executor, 'collective_rpc'):
             return self._executor.collective_rpc(method, args, kwargs,
                                                  non_block, unique_reply_rank,
@@ -1278,6 +1475,39 @@ def _build_model(self):
             **input_processor_kwargs)
         self._tokenizer = self.input_processor.tokenizer
 
+        # Resolve encoder_only mode (opt-in only)
+        self._encoder_only = (self.args.encoder_only is True)
+
+        if self._encoder_only:
+            # Create ONLY the EncoderExecutor — skip decoder infrastructure.
+            from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
+                create_encoder_executor
+            self._encoder_executor = create_encoder_executor(
+                llm_args=self.args,
+                checkpoint_dir=str(self._hf_model_dir)
+                if self._hf_model_dir else None,
+            )
+            logger.info(
+                "encoder_only=True: using EncoderExecutor. Only llm.encode() "
+                "is available. generate()/generate_async() are not supported.")
+            return  # Skip _executor creation
+
+        # Hint: if this looks like an encoder model, suggest encode()
+        if self.args.encoder_only is None and not self.args.mm_encoder_only:
+            from tensorrt_llm._torch.model_config import ModelConfig
+            architectures = getattr(
+                self._hf_model_config, 'architectures',
+                None) if self._hf_model_config else None
+            if architectures and not ModelConfig.is_generation_model(
+                    architectures):
+                logger.info(
+                    "Detected encoder-only model architecture "
+                    f"({architectures[0]}). Consider using "
+                    "LLM(model=..., encoder_only=True) with llm.encode() "
+                    "for optimized batch-forward inference that bypasses "
+                    "the decoder scheduler.")
+
+        # Create the standard executor for generate()/generate_async()
         # TODO: revisit gather_context_logits
         return_logits = self.args.gather_generation_logits
         self._executor = self._executor_cls.create(
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 9f001b4e5ae..1aa68932153 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -3550,6 +3550,16 @@ class TorchLlmArgs(BaseLlmArgs):
         status="prototype",
     )
 
+    encoder_only: Optional[bool] = Field(
+        default=None,
+        description=
+        "Set to True for encoder-only models (BERT, RoBERTa, reward models, "
+        "etc.) to enable the optimized batch-forward encode() path that "
+        "bypasses the decoder scheduler and autoregressive loop. When None, "
+        "proceed with the old generate() path.",
+        status="prototype",
+    )
+
     ray_worker_extension_cls: Optional[str] = Field(
         default=None,
         description="The full worker extension class name including module path. "
diff --git a/tests/unittest/llmapi/test_llm_encode.py b/tests/unittest/llmapi/test_llm_encode.py
new file mode 100644
index 00000000000..4b13e3ee78b
--- /dev/null
+++ b/tests/unittest/llmapi/test_llm_encode.py
@@ -0,0 +1,194 @@
+# Copyright (c) 2025-2026, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+
+from tensorrt_llm import LLM
+from tensorrt_llm.llmapi.llm import EncoderOutput
+
+# isort: off
+from .test_llm import get_model_path
+
+# isort: on
+
+BERT_MODEL_PATH = "bert/bert-base-uncased-yelp-polarity"
+
+PROMPTS = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+
+
+@pytest.fixture(scope="module")
+def bert_encode_llm():
+    """Create an LLM with encoder_only=True for BERT, shared across tests."""
+    model_dir = get_model_path(BERT_MODEL_PATH)
+    llm = LLM(model=model_dir, encoder_only=True)
+    yield llm
+    llm.shutdown()
+
+
+# --------------------------------------------------------------------------- #
+# Basic encode() functionality
+# --------------------------------------------------------------------------- #
+
+
+def test_encode_single_string(bert_encode_llm):
+    """encode() with a single string returns a single EncoderOutput."""
+    result = bert_encode_llm.encode("Hello, my name is")
+
+    assert isinstance(result, EncoderOutput)
+    assert isinstance(result.logits, torch.Tensor)
+    assert result.logits.dim() == 1  # [num_classes] for classification
+    assert result.logits.shape[0] == 2  # yelp-polarity has 2 classes
+    assert result.prompt == "Hello, my name is"
+    assert isinstance(result.prompt_token_ids, list)
+    assert len(result.prompt_token_ids) > 0
+
+
+def test_encode_batch(bert_encode_llm):
+    """encode() with a list of strings returns a list of EncoderOutput."""
+    results = bert_encode_llm.encode(PROMPTS)
+
+    assert isinstance(results, list)
+    assert len(results) == len(PROMPTS)
+    for i, result in enumerate(results):
+        assert isinstance(result, EncoderOutput)
+        assert result.logits.shape == (2,)  # 2 classes
+        assert result.prompt == PROMPTS[i]
+
+
+def test_encode_token_ids(bert_encode_llm):
+    """encode() accepts pre-tokenized token ID lists."""
+    token_ids = [101, 7592, 1010, 2026, 2171, 2003, 102]  # "[CLS] hello, my name is [SEP]"
+    result = bert_encode_llm.encode(token_ids)
+
+    assert isinstance(result, EncoderOutput)
+    assert result.logits.shape == (2,)
+    assert result.prompt is None  # no text prompt when passing token IDs
+    assert result.prompt_token_ids == token_ids
+
+
+def test_encode_mixed_batch(bert_encode_llm):
+    """encode() handles mixed input types in a batch."""
+    from tensorrt_llm.inputs import TextPrompt, TokensPrompt
+
+    inputs = [
+        "Hello world",
+        TextPrompt(prompt="Test sentence"),
+        TokensPrompt(prompt_token_ids=[101, 7592, 2088, 102]),
+    ]
+    results = bert_encode_llm.encode(inputs)
+
+    assert len(results) == 3
+    assert results[0].prompt == "Hello world"
+    assert results[1].prompt == "Test sentence"
+    assert results[2].prompt is None
+
+
+# --------------------------------------------------------------------------- #
+# Output correctness — compare with HuggingFace
+# --------------------------------------------------------------------------- #
+
+
+def test_encode_matches_huggingface(bert_encode_llm):
+    """encode() logits match HuggingFace BertForSequenceClassification."""
+    from transformers import (AutoModelForSequenceClassification,
+                              AutoTokenizer)
+
+    model_dir = get_model_path(BERT_MODEL_PATH)
+
+    # Get TRT-LLM results
+    results = bert_encode_llm.encode(PROMPTS)
+    tllm_logits = torch.stack([r.logits.cpu() for r in results])
+
+    # Get HuggingFace results
+    tokenizer = AutoTokenizer.from_pretrained(model_dir)
+    hf_model = AutoModelForSequenceClassification.from_pretrained(model_dir)
+    hf_model = hf_model.half().to(tllm_logits.device)
+
+    with torch.inference_mode():
+        inputs = tokenizer(PROMPTS, return_tensors="pt",
+                           padding="longest").to(hf_model.device)
+        hf_outputs = hf_model(**inputs)
+        hf_logits = hf_outputs.logits.float()
+
+    torch.testing.assert_close(tllm_logits, hf_logits, rtol=1.5e-2,
+                               atol=1.5e-2)
+
+
+# --------------------------------------------------------------------------- #
+# Cross-API guards
+# --------------------------------------------------------------------------- #
+
+
+def test_generate_raises_on_encoder_only(bert_encode_llm):
+    """generate() raises RuntimeError when encoder_only=True."""
+    with pytest.raises(RuntimeError, match="encoder_only=True"):
+        bert_encode_llm.generate(PROMPTS)
+
+
+def test_generate_async_raises_on_encoder_only(bert_encode_llm):
+    """generate_async() raises RuntimeError when encoder_only=True."""
+    with pytest.raises(RuntimeError, match="encoder_only=True"):
+        bert_encode_llm.generate_async("Hello")
+
+
+def test_encode_raises_without_encoder_only():
+    """encode() raises RuntimeError on a decoder model (encoder_only=False)."""
+    model_dir = get_model_path(BERT_MODEL_PATH)
+    with LLM(model=model_dir, encoder_only=False,
+             disable_overlap_scheduler=True) as llm:
+        with pytest.raises(RuntimeError, match="encoder_only=True"):
+            llm.encode("Hello")
+
+
+def test_get_stats_raises_on_encoder_only(bert_encode_llm):
+    """get_stats() raises RuntimeError when encoder_only=True."""
+    with pytest.raises(RuntimeError, match="encoder_only=True"):
+        bert_encode_llm.get_stats()
+
+
+# --------------------------------------------------------------------------- #
+# Input validation
+# --------------------------------------------------------------------------- #
+
+
+def test_encode_empty_string(bert_encode_llm):
+    """encode("") produces a valid classification result.
+
+    Tokenizing "" with add_special_tokens=True produces [CLS][SEP] (2 tokens),
+    so this is actually a valid input for BERT.
+    """
+    result = bert_encode_llm.encode("")
+    assert isinstance(result, EncoderOutput)
+    assert result.logits.shape == (2,)
+
+
+def test_encode_empty_batch(bert_encode_llm):
+    """encode([]) returns an empty list instead of raising IndexError."""
+    assert bert_encode_llm.encode([]) == []
+
+
+# --------------------------------------------------------------------------- #
+# Health check
+# --------------------------------------------------------------------------- #
+
+
+def test_check_health_encoder_only(bert_encode_llm):
+    """_check_health() returns True for a live encoder-only LLM."""
+    assert bert_encode_llm._check_health() is True