[None][feat] Add llm.encode() fast path for encoder-only models #12801

Draft
tingyangk wants to merge 1 commit into NVIDIA:main from tingyangk:tingyangk/encoder-llmapi-optimize

Conversation

@tingyangk

Summary

Adds a dedicated llm.encode() API for encoder-only models (BERT, RoBERTa, reward models) that bypasses the decoder-oriented PyExecutor loop entirely.

Problem

The current LLM API routes encoder models through the same PyExecutor designed for autoregressive decoders, introducing significant per-batch CPU overhead from the scheduler, KV cache management, sampling, and the request state machine — none of which applies to encoders. Encoder models need a simple, direct path to the model's forward call, with batch inference executed in a single pass.

Solution

A new execution path (encoder_only=True) that creates a lightweight EncoderExecutor instead of the full PyExecutor. The encode() method tokenizes, packs, and runs a single forward pass directly through ModelEngine.encoder_forward(), returning EncoderOutput with logits. This new API demonstrates a 3.92× speedup for the BERT 110M model (textattack/bert-base-uncased-yelp-polarity) in eager mode with batch size 10.

encode()   mean: 5.17ms  (p50=5.16ms)
generate() mean: 20.25ms (p50=20.02ms)
Speedup: 3.92x

Usage

# New dedicated path
llm = LLM(model="bert-base-uncased-yelp-polarity", encoder_only=True)
outputs = llm.encode(["Hello world", "Test sentence"])
print(outputs[0].logits)  # [num_classes] tensor

# Old path still works unchanged (no encoder_only flag)
llm = LLM(model="bert-base-uncased-yelp-polarity", disable_overlap_scheduler=True)
outputs = llm.generate(prompts, SamplingParams(return_context_logits=True))
  • encoder_only=True must be explicitly set. Default (None) uses the old generate() path.
  • encoder_only=True creates only EncoderExecutor; False/None creates only PyExecutor. Mutually exclusive.
  • generate()/generate_async() raise RuntimeError when encoder_only=True. encode() is the only API.
  • Since llm.encode() reuses PyTorchModelEngine and its _forward_step() path, features like TorchCompileConfig are compatible.
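The tokenize–pack–forward flow described above can be illustrated with a small packing helper. This is a hypothetical sketch — the actual EncoderExecutor internals and helper names may differ:

```python
from typing import List, Tuple

def pack_batch(token_id_lists: List[List[int]]) -> Tuple[List[int], List[int]]:
    """Flatten a batch of variable-length token-id lists into one packed
    sequence plus per-request lengths, mirroring the packing that
    llm.encode() performs before a single encoder forward pass."""
    packed: List[int] = []
    seq_lens: List[int] = []
    for ids in token_id_lists:
        packed.extend(ids)
        seq_lens.append(len(ids))
    return packed, seq_lens

packed, seq_lens = pack_batch([[101, 7592, 102], [101, 3231, 6251, 102]])
print(packed)    # [101, 7592, 102, 101, 3231, 6251, 102]
print(seq_lens)  # [3, 4]
```

The flat packed layout plus per-sequence lengths is what lets the encoder process the whole batch in one forward call instead of iterating per request.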

Future Work

  • Encoder CUDA graph integration — capture the encoder model into a single CUDA graph
  • Triton backend update — add an encoder model example
  • Parallelism support (e.g. TP > 1) — expand the EncoderExecutor
  • Other minor optimizations (e.g. batch tokenization, caching AttentionMetadata, etc.)

Test Coverage

  • tests/unittest/llmapi/test_llm_encode.py — 11 new tests:
    • Basic: single string, batch, token IDs, mixed input types
    • Correctness: logits compared against HuggingFace BertForSequenceClassification
    • Health check: _check_health() returns True for encoder-only LLM
  • Existing tests unaffected (no encoder_only=True → old path):
    • tests/integration/defs/test_e2e.py::test_ptp_quickstart_bert
    • tests/unittest/llmapi/test_llm_pytorch.py::test_llm_reward_model

CC: @symphonylyh @amukkara @nvrohanv @schetlur-nv @juney-nvidia

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@svc-trtllm-gh-bot svc-trtllm-gh-bot added the Community want to contribute PRs initiated from Community label Apr 7, 2026
Signed-off-by: tingyangk <tingyangk@nvidia.com>
@tingyangk tingyangk force-pushed the tingyangk/encoder-llmapi-optimize branch from 04bcd64 to 83bc6b9 Compare April 7, 2026 08:45

@nvrohanv nvrohanv left a comment


Some comments on the tokenization piece and handling of empty batches, but overall looks good!


unbatched = not isinstance(inputs, list)
if not unbatched:
if isinstance(inputs[0], int):

Nit: How do we handle the case where an empty batch is passed in? Unless there's handling elsewhere, I think this would cause an IndexError. I'm guessing this logic exists elsewhere as well, so this might be a general question of how we want to handle it.
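A minimal guard along the lines the reviewer suggests (hypothetical function name; the PR's actual input handling may differ) would reject the empty batch before the `inputs[0]` access:

```python
def detect_input_kind(inputs):
    """Distinguish unbatched vs batched inputs, rejecting empty batches
    up front so the downstream inputs[0] access cannot raise IndexError.
    Sketch only; not the actual TRT-LLM implementation."""
    unbatched = not isinstance(inputs, list)
    if not unbatched:
        if len(inputs) == 0:
            raise ValueError("encode() received an empty batch")
        if isinstance(inputs[0], int):
            unbatched = True  # a bare list of token ids is one request
    return unbatched

print(detect_input_kind("Hello world"))   # True  (single string)
print(detect_input_kind([101, 7592]))     # True  (single token-id list)
print(detect_input_kind(["a", "b"]))      # False (batch of prompts)
```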

max_seq_len_batch = max(max_seq_len_batch, seq_len)
prompts.append(None)
elif "prompt" in inp:
token_ids, _ = self.input_processor(inp, sampling_params)

Is it faster to do the tokenization in a batch? I see that later we do the processing to turn it into a flat "packed" tensor. Especially for cases with larger batch sizes, I'm curious if we could cut down on tokenization overhead this way. Not sure about the structure of input_processor and whether it can do this well.
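A toy illustration of the per-prompt loop versus a single batched call. The StubTokenizer below is a stand-in, not the real input_processor; HuggingFace fast tokenizers do accept a list of strings in one call (e.g. `tokenizer(texts)["input_ids"]`), which amortizes the per-call Python overhead:

```python
from typing import List

class StubTokenizer:
    """Toy stand-in for a fast tokenizer; counts how many times the
    (relatively expensive) encode entry point is crossed."""

    def __init__(self):
        self.calls = 0

    def __call__(self, texts) -> List[List[int]]:
        # Accepts a string or a list of strings, like HF fast tokenizers.
        self.calls += 1
        if isinstance(texts, str):
            texts = [texts]
        return [[hash(tok) % 30522 for tok in s.split()] for s in texts]

tok = StubTokenizer()

# Current shape: one tokenizer call per prompt inside the batch loop.
per_prompt = [tok(p)[0] for p in ["a b", "c d e", "f"]]
assert tok.calls == 3

# Suggested shape: a single call over the whole batch.
tok.calls = 0
batched = tok(["a b", "c d e", "f"])
assert tok.calls == 1
assert batched == per_prompt
```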

encoder_only: Optional[bool] = Field(
default=None,
description=
"Set to True for encoder-only models (BERT, RoBERTa, reward models, "

Can we make this a bit clearer that it would also work for decoder-only models being run in "encoder-only" style?

status="prototype",
)

encoder_only: Optional[bool] = Field(

Similar to mm_encoder_only above, should this field be bool with default=False?

from .encoder_executor import EncoderExecutor

torch.cuda.set_per_process_memory_fraction(1.0)
checkpoint_loader = _construct_checkpoint_loader(llm_args.backend,

The checkpoint loader logic is common across create_py_executor and create_encoder_executor; it could be moved to a helper function.

Dict with 'logits' tensor and any other model outputs.
"""
model_inputs = self._prepare_encoder_inputs(inputs)
return self._forward_step(model_inputs, None, False)

Suggested change
return self._forward_step(model_inputs, None, False)
return self._forward_step(model_inputs, gather_ids=None, gather_context_logits=False)


I would suggest to enforce kwargs by adding *, after inputs here.
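The bare `*` the reviewer suggests makes every parameter after `inputs` keyword-only, so positional misuse fails fast at the call site. A sketch with illustrative names (not the actual TRT-LLM signature):

```python
class EncoderLLM:
    """Illustrative sketch of a keyword-only encode() signature;
    class and parameter names here are hypothetical."""

    def encode(self, inputs, *, sampling_params=None, use_tqdm=False):
        # Everything after `inputs` must be passed by keyword.
        return {"inputs": inputs, "use_tqdm": use_tqdm}

llm = EncoderLLM()
out = llm.encode(["Hello world"], use_tqdm=True)  # OK: keyword-only

try:
    llm.encode(["Hello world"], None, True)       # positional misuse
except TypeError:
    print("rejected: extra positional arguments")
```

Keyword-only parameters also let the signature evolve (reorder or add options) without silently changing the meaning of existing positional call sites.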

@schetlur-nv schetlur-nv requested a review from brb-nv April 7, 2026 22:12
@pcastonguay pcastonguay requested a review from Superjomn April 8, 2026 13:19
@pcastonguay

@Superjomn could you review, since it adds a new method to the LLM API? Thx.


@torch.inference_mode()
@with_model_extra_attrs(lambda self: self.model.extra_attrs)
@nvtx_range("encoder_forward")
def encoder_forward(self,

How does this function relate to _forward_step_mm_encoder_only? The purpose seems similar. Does it make sense to unify these?
