diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 7c2ce12a7f4..ac12020f80b 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -2107,6 +2107,7 @@ def _apply_incremental_update_target( spec_metadata.gather_ids = self.gather_ids_cuda[:total_num_tokens] spec_metadata.num_accepted_draft_tokens = self.num_accepted_draft_tokens_cuda[: num_extend_requests] + spec_metadata.num_generations = num_extend_requests # Determine if we're using extend_ctx mode for linear tree decoding num_extend_ctx_requests = 0 diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index 0211f3e7c98..30c3d7f92cb 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -446,8 +446,13 @@ def forward(self, draft_model) batch_size = attn_metadata.num_seqs - num_contexts = attn_metadata.num_contexts - num_gens = batch_size - num_contexts + # Use spec_metadata.num_generations for the real gen/ctx split. + # attn_metadata.num_contexts and num_ctx_tokens may be inflated by + # extend_ctx mode (which treats gen draft tokens as context for attn). + num_gens = spec_metadata.num_generations + num_contexts = batch_size - num_gens + num_ctx_tokens = attn_metadata._seq_lens[:num_contexts].sum().item( + ) if num_contexts > 0 else 0 raw_logits = logits @@ -499,7 +504,7 @@ def forward(self, (runtime_draft_len + 1)).long() gather_ids_gen = (start_ids_gen + num_accepted_tokens[num_contexts:] - 1 + - attn_metadata.num_ctx_tokens) + num_ctx_tokens) gather_ids = torch.concat([ spec_metadata.gather_ids[:num_contexts], gather_ids_gen ], @@ -609,8 +614,8 @@ def sample_and_accept_draft_tokens( spec_metadata: Eagle3OneModelSpecMetadata, ): batch_size = attn_metadata.num_seqs - num_contexts = attn_metadata.num_contexts - num_gens = batch_size - num_contexts + num_gens = spec_metadata.num_generations + num_contexts = batch_size - num_gens draft_tokens = spec_metadata.draft_tokens.reshape( num_gens, spec_metadata.runtime_draft_len) @@ -653,7 +658,9 @@ def prepare_1st_drafter_inputs( spec_metadata: Eagle3OneModelSpecMetadata, draft_model: nn.Module, ): - num_contexts = attn_metadata.num_contexts + num_contexts = attn_metadata.num_seqs - spec_metadata.num_generations + num_ctx_tokens = attn_metadata._seq_lens[:num_contexts].sum().item( + ) if num_contexts > 0 else 0 num_tokens = input_ids.shape[0] # prepare hidden states @@ -665,7 +672,7 @@ def prepare_1st_drafter_inputs( # context input_ids_ctx = self._prepare_context_input_ids( - input_ids, attn_metadata.num_ctx_tokens, spec_metadata.gather_ids, + input_ids, num_ctx_tokens, spec_metadata.gather_ids, accepted_tokens, num_contexts) # generation diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index 081030f482c..2c7c33bee81 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -262,7 +262,13 @@ def extend_ctx(self, attention_backend: Type[AttentionBackend]): """ if self.use_one_engine(): - # 1-model has separate logic for handling draft tokens + # SM120/121 lacks XQA spec-dec cubins for non-MLA models, so + # one-model must fall back to treating draft tokens as context. + # The C++ attention layer will still use FMHA fallback for the + # draft model's multi-token queries (spec-dec mode remains on). + if get_sm_version() in (120, 121) and issubclass( + attention_backend, TrtllmAttention): + return True return False xqa_supported = get_sm_version() < 120 diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index 38b031e778e..2df0965d0cd 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -288,7 +288,6 @@ unittest/_torch/modules/test_fused_moe.py::test_fused_moe_triton_mxfp4[True-True unittest/_torch/modules/test_fused_moe.py::test_fused_moe_triton_mxfp4[True-False-False-2880-2880-128] SKIP (https://nvbugs/5996776) unittest/_torch/modules/test_fused_moe.py::test_fused_moe_triton_mxfp4[False-True-False-2880-2880-128] SKIP (https://nvbugs/5996776) unittest/_torch/modules/test_fused_moe.py::test_fused_moe_triton_mxfp4[False-False-False-2880-2880-128] SKIP (https://nvbugs/5996776) -full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[v2_kv_cache-cutlass-one_model-overlap_scheduler] SKIP (https://nvbugs/5945047) accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-pp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-low_precision_combine=False-torch_compile=False] SKIP (https://nvbugs/5945081) full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-low_precision_combine=False-torch_compile=False] SKIP (https://nvbugs/5948435) unittest/_torch/modeling -k "modeling_siglip" SKIP (https://nvbugs/5941242)