Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2107,6 +2107,7 @@ def _apply_incremental_update_target(
spec_metadata.gather_ids = self.gather_ids_cuda[:total_num_tokens]
spec_metadata.num_accepted_draft_tokens = self.num_accepted_draft_tokens_cuda[:
num_extend_requests]
spec_metadata.num_generations = num_extend_requests

# Determine if we're using extend_ctx mode for linear tree decoding
num_extend_ctx_requests = 0
Expand Down
21 changes: 14 additions & 7 deletions tensorrt_llm/_torch/speculative/eagle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,8 +446,13 @@ def forward(self,
draft_model)

batch_size = attn_metadata.num_seqs
num_contexts = attn_metadata.num_contexts
num_gens = batch_size - num_contexts
# Use spec_metadata.num_generations for the real gen/ctx split.
# attn_metadata.num_contexts and num_ctx_tokens may be inflated by
# extend_ctx mode (which treats gen draft tokens as context for attn).
num_gens = spec_metadata.num_generations
num_contexts = batch_size - num_gens
num_ctx_tokens = attn_metadata._seq_lens[:num_contexts].sum().item(
) if num_contexts > 0 else 0

raw_logits = logits

Expand Down Expand Up @@ -499,7 +504,7 @@ def forward(self,
(runtime_draft_len + 1)).long()
gather_ids_gen = (start_ids_gen +
num_accepted_tokens[num_contexts:] - 1 +
attn_metadata.num_ctx_tokens)
num_ctx_tokens)
gather_ids = torch.concat([
spec_metadata.gather_ids[:num_contexts], gather_ids_gen
],
Expand Down Expand Up @@ -609,8 +614,8 @@ def sample_and_accept_draft_tokens(
spec_metadata: Eagle3OneModelSpecMetadata,
):
batch_size = attn_metadata.num_seqs
num_contexts = attn_metadata.num_contexts
num_gens = batch_size - num_contexts
num_gens = spec_metadata.num_generations
num_contexts = batch_size - num_gens

draft_tokens = spec_metadata.draft_tokens.reshape(
num_gens, spec_metadata.runtime_draft_len)
Expand Down Expand Up @@ -653,7 +658,9 @@ def prepare_1st_drafter_inputs(
spec_metadata: Eagle3OneModelSpecMetadata,
draft_model: nn.Module,
):
num_contexts = attn_metadata.num_contexts
num_contexts = attn_metadata.num_seqs - spec_metadata.num_generations
num_ctx_tokens = attn_metadata._seq_lens[:num_contexts].sum().item(
) if num_contexts > 0 else 0
num_tokens = input_ids.shape[0]

# prepare hidden states
Expand All @@ -665,7 +672,7 @@ def prepare_1st_drafter_inputs(

# context
input_ids_ctx = self._prepare_context_input_ids(
input_ids, attn_metadata.num_ctx_tokens, spec_metadata.gather_ids,
input_ids, num_ctx_tokens, spec_metadata.gather_ids,
accepted_tokens, num_contexts)

# generation
Expand Down
8 changes: 7 additions & 1 deletion tensorrt_llm/_torch/speculative/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,13 @@ def extend_ctx(self, attention_backend: Type[AttentionBackend]):
"""

if self.use_one_engine():
# 1-model has separate logic for handling draft tokens
# SM120/121 lacks XQA spec-dec cubins for non-MLA models, so
# one-model must fall back to treating draft tokens as context.
# The C++ attention layer will still use FMHA fallback for the
# draft model's multi-token queries (spec-dec mode remains on).
if get_sm_version() in (120, 121) and issubclass(
attention_backend, TrtllmAttention):
return True
Comment on lines 264 to +271
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Gate this SM120/121 fallback on MLA capability.

The new override now forces extend_ctx() for every one-engine TrtllmAttention model on SM120/121, but the failure mode called out here is specifically non-MLA. As written, MLA models will also be pushed onto the extend_ctx/FMHA path, which broadens the workaround and can regress the fast path on Blackwell unnecessarily. Please pass model capability into this check instead of keying only on the backend class.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/speculative/interface.py` around lines 264 - 271, The
current SM120/121 fallback unconditionally forces extend_ctx() for any
one-engine TrtllmAttention; change the conditional in the method that contains
use_one_engine() so it also checks the model's MLA capability (e.g., require
that the model does NOT support MLA) before returning True. Concretely, update
the if that references get_sm_version() and issubclass(attention_backend,
TrtllmAttention) to additionally query the model capability (for example via a
model.supports_mla() or model_capabilities.mla flag passed into this scope) and
only trigger the fallback when MLA is unavailable; keep the rest of the logic
(SM version check and TrtllmAttention class guard) intact.

return False

xqa_supported = get_sm_version() < 120
Expand Down
1 change: 0 additions & 1 deletion tests/integration/test_lists/waives.txt
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,6 @@ unittest/_torch/modules/test_fused_moe.py::test_fused_moe_triton_mxfp4[True-True
unittest/_torch/modules/test_fused_moe.py::test_fused_moe_triton_mxfp4[True-False-False-2880-2880-128] SKIP (https://nvbugs/5996776)
unittest/_torch/modules/test_fused_moe.py::test_fused_moe_triton_mxfp4[False-True-False-2880-2880-128] SKIP (https://nvbugs/5996776)
unittest/_torch/modules/test_fused_moe.py::test_fused_moe_triton_mxfp4[False-False-False-2880-2880-128] SKIP (https://nvbugs/5996776)
full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[v2_kv_cache-cutlass-one_model-overlap_scheduler] SKIP (https://nvbugs/5945047)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-pp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-low_precision_combine=False-torch_compile=False] SKIP (https://nvbugs/5945081)
full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-low_precision_combine=False-torch_compile=False] SKIP (https://nvbugs/5948435)
unittest/_torch/modeling -k "modeling_siglip" SKIP (https://nvbugs/5941242)
Expand Down
Loading