-
Notifications
You must be signed in to change notification settings - Fork 2.3k
[None][feat] Add llm.encode() fast path for encoder-only models #12801
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| # Copyright (c) 2025-2026, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from typing import Any, Dict | ||
|
|
||
| import torch | ||
|
|
||
| from tensorrt_llm.logger import logger | ||
|
|
||
|
|
||
class EncoderExecutor:
    """Synchronous, single-threaded executor for encoder-only models.

    The sole execution entry point is ``batch_forward(inputs)``, which
    runs one pre-formed batch to completion on the calling thread. All
    heavy lifting (pre-allocated buffers, attention metadata,
    torch.compile) is delegated to ``model_engine.encoder_forward()``.

    Unlike the decoder-side executors, this class owns no background
    thread, no scheduler, no sampler, and no request queue.
    """

    def __init__(self, model_engine, dist):
        self.dist = dist
        self.model_engine = model_engine

        logger.info(
            "encoder_only mode enabled: using EncoderExecutor. "
            "Scheduler, sampler, KV cache, and generation-related parameters "
            "(disable_overlap_scheduler, max_tokens, temperature, etc.) "
            "are bypassed. Use llm.encode() for inference.")

    def batch_forward(self,
                      inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        """Run a single pre-formed batch through the encoder.

        Args:
            inputs: Dict with required keys 'input_ids' ([total_tokens])
                and 'seq_lens' ([batch_size]). Optional model-specific
                kwargs (token_type_ids, inputs_embeds, etc.) are passed
                through untouched.

        Returns:
            Dict with 'logits' tensor and any other model outputs.
        """
        engine = self.model_engine
        return engine.encoder_forward(inputs)

    def shutdown(self):
        """Release the model engine; there is no worker thread to join."""
        del self.model_engine
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -3645,6 +3645,84 @@ def _prepare_inputs( | |||||
| num_accepted_tokens_device, req_id_to_old_request, resource_manager, | ||||||
| maybe_graph) | ||||||
|
|
||||||
| def _prepare_encoder_inputs(self, | ||||||
| inputs: Dict[str, Any]) -> Dict[str, Any]: | ||||||
| """Prepare model-ready inputs dict for encoder-only models. | ||||||
|
|
||||||
| Encoder equivalent of _prepare_tp_inputs + _preprocess_inputs. | ||||||
| Consumes raw inputs dict, copies to pre-allocated CUDA buffers, | ||||||
| sets up attention metadata, and returns model-ready dict. | ||||||
|
|
||||||
| Args: | ||||||
| inputs: Dict with required keys 'input_ids' ([total_tokens]) and | ||||||
| 'seq_lens' ([batch_size]). Optional 'position_ids' | ||||||
| ([total_tokens]). Any additional keys (token_type_ids, | ||||||
| inputs_embeds, etc.) are passed through to the model's | ||||||
| forward() via **kwargs. | ||||||
| """ | ||||||
| token_ids = inputs['input_ids'] | ||||||
| seq_lens = inputs['seq_lens'] | ||||||
| position_ids = inputs.get('position_ids') | ||||||
| num_tokens = token_ids.shape[0] | ||||||
| batch_size = seq_lens.shape[0] | ||||||
|
|
||||||
| assert num_tokens <= self.max_num_tokens, ( | ||||||
| f"num_tokens ({num_tokens}) exceeds max_num_tokens " | ||||||
| f"({self.max_num_tokens}). Reduce batch size or sequence lengths.") | ||||||
|
|
||||||
| # 1. Copy to pre-allocated CUDA buffers | ||||||
| self.input_ids_cuda[:num_tokens].copy_(token_ids, non_blocking=True) | ||||||
| if position_ids is None: | ||||||
| # Auto-generate packed position IDs: [0..n1-1, 0..n2-1, ...] | ||||||
| position_ids = torch.cat( | ||||||
| [torch.arange(s, dtype=torch.int32) for s in seq_lens.tolist()]) | ||||||
| self.position_ids_cuda[:num_tokens].copy_(position_ids, | ||||||
| non_blocking=True) | ||||||
|
|
||||||
| # 2. Set up attention metadata | ||||||
| attn_metadata = self._set_up_attn_metadata(kv_cache_manager=None) | ||||||
| attn_metadata.seq_lens = seq_lens | ||||||
| attn_metadata.num_contexts = batch_size | ||||||
| attn_metadata.max_seq_len = self.max_seq_len | ||||||
| attn_metadata.request_ids = list(range(batch_size)) | ||||||
| attn_metadata.prepare() | ||||||
|
|
||||||
| # 3. Build model-ready dict. | ||||||
| # **inputs goes FIRST so that the explicit buffer keys override the | ||||||
| # raw tensors. Extra keys (seq_lens, token_type_ids, etc.) pass | ||||||
| # through to the model's **kwargs and are silently ignored if not | ||||||
| # in the model's forward() signature. | ||||||
| model_inputs = { | ||||||
| **inputs, | ||||||
| 'attn_metadata': attn_metadata, | ||||||
| 'input_ids': self.input_ids_cuda[:num_tokens], | ||||||
| 'position_ids': self.position_ids_cuda[:num_tokens].unsqueeze(0), | ||||||
| 'inputs_embeds': None, | ||||||
| } | ||||||
|
|
||||||
| return model_inputs | ||||||
|
|
||||||
@torch.inference_mode()
@with_model_extra_attrs(lambda self: self.model.extra_attrs)
@nvtx_range("encoder_forward")
def encoder_forward(self,
                    inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Direct tensor-level forward for encoder-only models.

    Bypasses ScheduledRequests/LlmRequest entirely. Takes a raw inputs
    dict, prepares model-ready inputs via _prepare_encoder_inputs, and
    calls _forward_step (which preserves torch.compile).

    Args:
        inputs: Dict with 'input_ids' and 'seq_lens' (required), plus
            any model-specific kwargs (token_type_ids, inputs_embeds, etc.).

    Returns:
        Dict with 'logits' tensor and any other model outputs.
    """
    model_inputs = self._prepare_encoder_inputs(inputs)
    # NOTE(review): the positional args to _forward_step appear to be
    # (inputs, gather_ids/context, flag) — confirm against its signature;
    # None/False presumably mean "no gather ids, no extra step behavior".
    return self._forward_step(model_inputs, None, False)
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I would suggest enforcing keyword-only arguments by adding |
||||||
|
|
||||||
| @torch.inference_mode() | ||||||
| @with_model_extra_attrs(lambda self: self.model.extra_attrs) | ||||||
| def forward(self, | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -225,6 +225,50 @@ def get_guided_decoding_config(guided_decoding_backend: str, | |
| return guided_decoding_config | ||
|
|
||
|
|
||
def create_encoder_executor(
    llm_args: TorchLlmArgs,
    checkpoint_dir: Optional[str] = None,
):
    """Build an EncoderExecutor for an encoder-only model.

    Loads the model configuration, constructs a PyTorchModelEngine, and
    wraps it in the lightweight EncoderExecutor. None of the decoder-side
    infrastructure (KV cache, scheduler, sampler, drafter, speculative
    decoding) is created.

    Args:
        llm_args: Configuration arguments.
        checkpoint_dir: Path to model checkpoint.

    Returns:
        An EncoderExecutor instance ready for batch_forward() calls.
    """
    from .encoder_executor import EncoderExecutor

    # Encoder-only runs skip KV-cache reservation, so the process may
    # use the full device memory budget.
    torch.cuda.set_per_process_memory_fraction(1.0)

    # NOTE(review): checkpoint-loader construction is common with the
    # decoder executor-creation path; consider factoring it out.
    ckpt_loader = _construct_checkpoint_loader(llm_args.backend,
                                               llm_args.checkpoint_loader,
                                               llm_args.checkpoint_format)
    llm_args = ModelLoader.load_config_and_apply_defaults(
        checkpoint_dir, llm_args, ckpt_loader)

    mapping = _get_mapping(llm_args.parallel_config.to_mapping())
    dist = Distributed.get(mapping)

    engine = PyTorchModelEngine(
        model_path=checkpoint_dir,
        llm_args=llm_args,
        mapping=mapping,
        dist=dist,
        spec_config=None,  # no speculative decoding for encoders
        checkpoint_loader=ckpt_loader,
    )

    return EncoderExecutor(model_engine=engine, dist=dist)
|
|
||
|
|
||
| def create_py_executor( | ||
| llm_args: TorchLlmArgs, | ||
| checkpoint_dir: Optional[str] = None, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
How does this function relate to
_forward_step_mm_encoder_only? The purpose seems similar. Does it make sense to unify these?