Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/encoder_executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) 2025-2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict

import torch

from tensorrt_llm.logger import logger


class EncoderExecutor:
    """Executor for encoder-only models.

    Primary path: batch_forward(inputs) — synchronous batch execution.
    Delegates to model_engine.encoder_forward() for all heavy lifting
    (pre-allocated buffers, attention metadata, torch.compile).

    This executor has no background thread, no scheduler, no sampler,
    and no request queue. It runs entirely on the calling thread.
    """

    def __init__(self, model_engine, dist):
        # model_engine: engine exposing encoder_forward(); set to None by
        # shutdown() to mark this executor as released.
        self.model_engine = model_engine
        # dist: distributed communicator; stored but not used directly here.
        self.dist = dist

        logger.info(
            "encoder_only mode enabled: using EncoderExecutor. "
            "Scheduler, sampler, KV cache, and generation-related parameters "
            "(disable_overlap_scheduler, max_tokens, temperature, etc.) "
            "are bypassed. Use llm.encode() for inference.")

    def batch_forward(self,
                      inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        """Execute a pre-formed batch in one forward pass.

        Args:
            inputs: Dict with 'input_ids' ([total_tokens]) and 'seq_lens'
                ([batch_size]) required. Optional model-specific kwargs
                (token_type_ids, inputs_embeds, etc.) are passed through.

        Returns:
            Dict with 'logits' tensor and any other model outputs.

        Raises:
            RuntimeError: if called after shutdown().
        """
        # Guard gives a clear error instead of an opaque AttributeError
        # when the executor is used after shutdown().
        if self.model_engine is None:
            raise RuntimeError(
                "EncoderExecutor.batch_forward() called after shutdown().")
        return self.model_engine.encoder_forward(inputs)

    def shutdown(self):
        """No background thread to stop — just release model engine resources.

        Idempotent: safe to call more than once. (Using `del` here would
        raise AttributeError on a second call.)
        """
        self.model_engine = None
78 changes: 78 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3645,6 +3645,84 @@ def _prepare_inputs(
num_accepted_tokens_device, req_id_to_old_request, resource_manager,
maybe_graph)

def _prepare_encoder_inputs(self,
                            inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Prepare model-ready inputs dict for encoder-only models.

    Encoder equivalent of _prepare_tp_inputs + _preprocess_inputs.
    Consumes raw inputs dict, copies to pre-allocated CUDA buffers,
    sets up attention metadata, and returns model-ready dict.

    Args:
        inputs: Dict with required keys 'input_ids' ([total_tokens]) and
            'seq_lens' ([batch_size]). Optional 'position_ids'
            ([total_tokens]). Any additional keys (token_type_ids,
            inputs_embeds, etc.) are passed through to the model's
            forward() via **kwargs.

    Returns:
        Dict containing CUDA buffer views for 'input_ids'/'position_ids',
        the prepared 'attn_metadata', 'inputs_embeds' forced to None, and
        all remaining pass-through keys from `inputs`.
    """
    token_ids = inputs['input_ids']
    seq_lens = inputs['seq_lens']
    position_ids = inputs.get('position_ids')
    # NOTE(review): assumes token_ids and seq_lens are 1-D tensors
    # (packed layout) — confirm against llm.encode() callers.
    num_tokens = token_ids.shape[0]
    batch_size = seq_lens.shape[0]

    assert num_tokens <= self.max_num_tokens, (
        f"num_tokens ({num_tokens}) exceeds max_num_tokens "
        f"({self.max_num_tokens}). Reduce batch size or sequence lengths.")

    # 1. Copy to pre-allocated CUDA buffers
    self.input_ids_cuda[:num_tokens].copy_(token_ids, non_blocking=True)
    if position_ids is None:
        # Auto-generate packed position IDs: [0..n1-1, 0..n2-1, ...]
        # Built on CPU, then copied into the CUDA buffer below.
        position_ids = torch.cat(
            [torch.arange(s, dtype=torch.int32) for s in seq_lens.tolist()])
    self.position_ids_cuda[:num_tokens].copy_(position_ids,
                                              non_blocking=True)

    # 2. Set up attention metadata
    # kv_cache_manager=None: encoder-only execution keeps no KV cache.
    attn_metadata = self._set_up_attn_metadata(kv_cache_manager=None)
    attn_metadata.seq_lens = seq_lens
    attn_metadata.num_contexts = batch_size
    attn_metadata.max_seq_len = self.max_seq_len
    attn_metadata.request_ids = list(range(batch_size))
    attn_metadata.prepare()

    # 3. Build model-ready dict.
    # **inputs goes FIRST so that the explicit buffer keys override the
    # raw tensors. Extra keys (seq_lens, token_type_ids, etc.) pass
    # through to the model's **kwargs and are silently ignored if not
    # in the model's forward() signature.
    model_inputs = {
        **inputs,
        'attn_metadata': attn_metadata,
        'input_ids': self.input_ids_cuda[:num_tokens],
        'position_ids': self.position_ids_cuda[:num_tokens].unsqueeze(0),
        'inputs_embeds': None,
    }

    return model_inputs

@torch.inference_mode()
@with_model_extra_attrs(lambda self: self.model.extra_attrs)
@nvtx_range("encoder_forward")
def encoder_forward(self,
                    inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Direct tensor-level forward for encoder-only models.

    Bypasses ScheduledRequests/LlmRequest entirely. Takes a raw inputs
    dict, prepares model-ready inputs via _prepare_encoder_inputs, and
    calls _forward_step (which preserves torch.compile).

    Args:
        inputs: Dict with 'input_ids' and 'seq_lens' (required), plus
            any model-specific kwargs (token_type_ids, inputs_embeds, etc.).

    Returns:
        Dict with 'logits' tensor and any other model outputs.
    """
    model_inputs = self._prepare_encoder_inputs(inputs)
    # Explicit keywords instead of bare positional `None, False`:
    # encoders need no gather indices and no context-logits gathering.
    return self._forward_step(model_inputs,
                              gather_ids=None,
                              gather_context_logits=False)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return self._forward_step(model_inputs, None, False)
return self._forward_step(model_inputs, gather_ids=None, gather_context_logits=False)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest to enforce kwargs by adding *, after inputs here.


@torch.inference_mode()
@with_model_extra_attrs(lambda self: self.model.extra_attrs)
def forward(self,
Expand Down
44 changes: 44 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,50 @@ def get_guided_decoding_config(guided_decoding_backend: str,
return guided_decoding_config


def create_encoder_executor(
    llm_args: TorchLlmArgs,
    checkpoint_dir: Optional[str] = None,
):
    """Create an EncoderExecutor for encoder-only models.

    Handles model loading and model-engine creation, then wraps the engine
    in a lightweight EncoderExecutor. Skips all decoder infrastructure
    (KV cache, scheduler, sampler, drafter, speculative decoding).

    Args:
        llm_args: Configuration arguments.
        checkpoint_dir: Path to model checkpoint.

    Returns:
        An EncoderExecutor instance ready for batch_forward() calls.
    """
    # Deferred import: EncoderExecutor is only needed in encoder-only mode.
    from .encoder_executor import EncoderExecutor

    # Grant the full per-process CUDA memory budget — presumably safe since
    # no KV-cache pool reserves memory in this mode; confirm on review.
    torch.cuda.set_per_process_memory_fraction(1.0)

    loader = _construct_checkpoint_loader(llm_args.backend,
                                          llm_args.checkpoint_loader,
                                          llm_args.checkpoint_format)
    resolved_args = ModelLoader.load_config_and_apply_defaults(
        checkpoint_dir, llm_args, loader)

    mapping = _get_mapping(resolved_args.parallel_config.to_mapping())
    dist = Distributed.get(mapping)

    engine = PyTorchModelEngine(
        model_path=checkpoint_dir,
        llm_args=resolved_args,
        mapping=mapping,
        dist=dist,
        spec_config=None,
        checkpoint_loader=loader,
    )

    return EncoderExecutor(
        model_engine=engine,
        dist=dist,
    )


def create_py_executor(
llm_args: TorchLlmArgs,
checkpoint_dir: Optional[str] = None,
Expand Down
Loading