Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2107,6 +2107,7 @@ def _apply_incremental_update_target(
spec_metadata.gather_ids = self.gather_ids_cuda[:total_num_tokens]
spec_metadata.num_accepted_draft_tokens = self.num_accepted_draft_tokens_cuda[:
num_extend_requests]
spec_metadata.num_generations = num_extend_requests

# Determine if we're using extend_ctx mode for linear tree decoding
num_extend_ctx_requests = 0
Expand Down
21 changes: 14 additions & 7 deletions tensorrt_llm/_torch/speculative/eagle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,8 +446,13 @@ def forward(self,
draft_model)

batch_size = attn_metadata.num_seqs
num_contexts = attn_metadata.num_contexts
num_gens = batch_size - num_contexts
# Use spec_metadata.num_generations for the real gen/ctx split.
# attn_metadata.num_contexts and num_ctx_tokens may be inflated by
# extend_ctx mode (which treats gen draft tokens as context for attn).
num_gens = spec_metadata.num_generations
num_contexts = batch_size - num_gens
num_ctx_tokens = attn_metadata._seq_lens[:num_contexts].sum().item(
) if num_contexts > 0 else 0

raw_logits = logits

Expand Down Expand Up @@ -499,7 +504,7 @@ def forward(self,
(runtime_draft_len + 1)).long()
gather_ids_gen = (start_ids_gen +
num_accepted_tokens[num_contexts:] - 1 +
attn_metadata.num_ctx_tokens)
num_ctx_tokens)
gather_ids = torch.concat([
spec_metadata.gather_ids[:num_contexts], gather_ids_gen
],
Expand Down Expand Up @@ -609,8 +614,8 @@ def sample_and_accept_draft_tokens(
spec_metadata: Eagle3OneModelSpecMetadata,
):
batch_size = attn_metadata.num_seqs
num_contexts = attn_metadata.num_contexts
num_gens = batch_size - num_contexts
num_gens = spec_metadata.num_generations
num_contexts = batch_size - num_gens

draft_tokens = spec_metadata.draft_tokens.reshape(
num_gens, spec_metadata.runtime_draft_len)
Expand Down Expand Up @@ -653,7 +658,9 @@ def prepare_1st_drafter_inputs(
spec_metadata: Eagle3OneModelSpecMetadata,
draft_model: nn.Module,
):
num_contexts = attn_metadata.num_contexts
num_contexts = attn_metadata.num_seqs - spec_metadata.num_generations
num_ctx_tokens = attn_metadata._seq_lens[:num_contexts].sum().item(
) if num_contexts > 0 else 0
num_tokens = input_ids.shape[0]

# prepare hidden states
Expand All @@ -665,7 +672,7 @@ def prepare_1st_drafter_inputs(

# context
input_ids_ctx = self._prepare_context_input_ids(
input_ids, attn_metadata.num_ctx_tokens, spec_metadata.gather_ids,
input_ids, num_ctx_tokens, spec_metadata.gather_ids,
accepted_tokens, num_contexts)

# generation
Expand Down
8 changes: 7 additions & 1 deletion tensorrt_llm/_torch/speculative/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,13 @@ def extend_ctx(self, attention_backend: Type[AttentionBackend]):
"""

if self.use_one_engine():
# 1-model has separate logic for handling draft tokens
# SM120/121 lacks XQA spec-dec cubins for non-MLA models, so
# one-model must fall back to treating draft tokens as context.
# The C++ attention layer will still use FMHA fallback for the
# draft model's multi-token queries (spec-dec mode remains on).
if get_sm_version() in (120, 121) and issubclass(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we return True only for non-MLA models, to align with the comment above?

attention_backend, TrtllmAttention):
return True
return False

xqa_supported = get_sm_version() < 120
Expand Down
1 change: 0 additions & 1 deletion tests/integration/test_lists/waives.txt
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,6 @@ unittest/_torch/modules/test_fused_moe.py::test_fused_moe_triton_mxfp4[True-True
unittest/_torch/modules/test_fused_moe.py::test_fused_moe_triton_mxfp4[True-False-False-2880-2880-128] SKIP (https://nvbugs/5996776)
unittest/_torch/modules/test_fused_moe.py::test_fused_moe_triton_mxfp4[False-True-False-2880-2880-128] SKIP (https://nvbugs/5996776)
unittest/_torch/modules/test_fused_moe.py::test_fused_moe_triton_mxfp4[False-False-False-2880-2880-128] SKIP (https://nvbugs/5996776)
full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[v2_kv_cache-cutlass-one_model-overlap_scheduler] SKIP (https://nvbugs/5945047)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-pp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-low_precision_combine=False-torch_compile=False] SKIP (https://nvbugs/5945081)
full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-low_precision_combine=False-torch_compile=False] SKIP (https://nvbugs/5948435)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=False-v2_kv_cache=False] SKIP (https://nvbugs/5955765)
Expand Down
Loading