From 489df28b4b39fb6747aabb8d6278b2ca40d6c669 Mon Sep 17 00:00:00 2001 From: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com> Date: Thu, 2 Apr 2026 22:09:11 -0700 Subject: [PATCH] [https://nvbugs/5945047][fix] Fix Eagle3 one-model hang on SM120 via extend_ctx On SM120/121, XQA spec-dec cubins are not available for non-MLA models, causing Eagle3 one-model to hang. The previous fix (extend_ctx=True for one-model on SM120) correctly routes multi-token generation queries through FMHA instead of XQA, but broke the Eagle3 spec worker because extend_ctx inflates attn_metadata.num_contexts and num_ctx_tokens to include extended generation requests. Fix the Eagle3 one-model spec worker to use spec_metadata.num_generations (the real generation count from the scheduler) instead of deriving it from attn_metadata.num_contexts. This decouples the spec worker's ctx/gen split from the attention layer's view, which may differ under extend_ctx mode. Changes: - model_engine: Set spec_metadata.num_generations in the incremental update path (_apply_incremental_update_target) - eagle3: Use spec_metadata.num_generations for the real gen/ctx split in forward(), sample_and_accept_draft_tokens(), and prepare_1st_drafter_inputs() - eagle3: Compute real num_ctx_tokens from _seq_lens[:num_contexts] instead of using the inflated attn_metadata.num_ctx_tokens Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com> --- .../_torch/pyexecutor/model_engine.py | 1 + tensorrt_llm/_torch/speculative/eagle3.py | 21 ++++++++++++------- tensorrt_llm/_torch/speculative/interface.py | 8 ++++++- tests/integration/test_lists/waives.txt | 1 - 4 files changed, 22 insertions(+), 9 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 7c2ce12a7f4..ac12020f80b 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -2107,6 +2107,7 @@ def _apply_incremental_update_target( spec_metadata.gather_ids = self.gather_ids_cuda[:total_num_tokens] spec_metadata.num_accepted_draft_tokens = self.num_accepted_draft_tokens_cuda[: num_extend_requests] + spec_metadata.num_generations = num_extend_requests # Determine if we're using extend_ctx mode for linear tree decoding num_extend_ctx_requests = 0 diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index 0211f3e7c98..30c3d7f92cb 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -446,8 +446,13 @@ def forward(self, draft_model) batch_size = attn_metadata.num_seqs - num_contexts = attn_metadata.num_contexts - num_gens = batch_size - num_contexts + # Use spec_metadata.num_generations for the real gen/ctx split. + # attn_metadata.num_contexts and num_ctx_tokens may be inflated by + # extend_ctx mode (which treats gen draft tokens as context for attn). + num_gens = spec_metadata.num_generations + num_contexts = batch_size - num_gens + num_ctx_tokens = attn_metadata._seq_lens[:num_contexts].sum().item( + ) if num_contexts > 0 else 0 raw_logits = logits @@ -499,7 +504,7 @@ def forward(self, (runtime_draft_len + 1)).long() gather_ids_gen = (start_ids_gen + num_accepted_tokens[num_contexts:] - 1 + - attn_metadata.num_ctx_tokens) + num_ctx_tokens) gather_ids = torch.concat([ spec_metadata.gather_ids[:num_contexts], gather_ids_gen ], @@ -609,8 +614,8 @@ def sample_and_accept_draft_tokens( spec_metadata: Eagle3OneModelSpecMetadata, ): batch_size = attn_metadata.num_seqs - num_contexts = attn_metadata.num_contexts - num_gens = batch_size - num_contexts + num_gens = spec_metadata.num_generations + num_contexts = batch_size - num_gens draft_tokens = spec_metadata.draft_tokens.reshape( num_gens, spec_metadata.runtime_draft_len) @@ -653,7 +658,9 @@ def prepare_1st_drafter_inputs( spec_metadata: Eagle3OneModelSpecMetadata, draft_model: nn.Module, ): - num_contexts = attn_metadata.num_contexts + num_contexts = attn_metadata.num_seqs - spec_metadata.num_generations + num_ctx_tokens = attn_metadata._seq_lens[:num_contexts].sum().item( + ) if num_contexts > 0 else 0 num_tokens = input_ids.shape[0] # prepare hidden states @@ -665,7 +672,7 @@ def prepare_1st_drafter_inputs( # context input_ids_ctx = self._prepare_context_input_ids( - input_ids, attn_metadata.num_ctx_tokens, spec_metadata.gather_ids, + input_ids, num_ctx_tokens, spec_metadata.gather_ids, accepted_tokens, num_contexts) # generation diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index 081030f482c..2c7c33bee81 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -262,7 +262,13 @@ def extend_ctx(self, attention_backend: Type[AttentionBackend]): """ if self.use_one_engine(): - # 1-model has separate logic for handling draft tokens + # SM120/121 lacks XQA spec-dec cubins for non-MLA models, so + # one-model must fall back to treating draft tokens as context. + # The C++ attention layer will still use FMHA fallback for the + # draft model's multi-token queries (spec-dec mode remains on). + if get_sm_version() in (120, 121) and issubclass( + attention_backend, TrtllmAttention): + return True return False xqa_supported = get_sm_version() < 120 diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index 38b031e778e..2df0965d0cd 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -288,7 +288,6 @@ unittest/_torch/modules/test_fused_moe.py::test_fused_moe_triton_mxfp4[True-True unittest/_torch/modules/test_fused_moe.py::test_fused_moe_triton_mxfp4[True-False-False-2880-2880-128] SKIP (https://nvbugs/5996776) unittest/_torch/modules/test_fused_moe.py::test_fused_moe_triton_mxfp4[False-True-False-2880-2880-128] SKIP (https://nvbugs/5996776) unittest/_torch/modules/test_fused_moe.py::test_fused_moe_triton_mxfp4[False-False-False-2880-2880-128] SKIP (https://nvbugs/5996776) -full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[v2_kv_cache-cutlass-one_model-overlap_scheduler] SKIP (https://nvbugs/5945047) accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-pp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-low_precision_combine=False-torch_compile=False] SKIP (https://nvbugs/5945081) full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-low_precision_combine=False-torch_compile=False] SKIP (https://nvbugs/5948435) unittest/_torch/modeling -k "modeling_siglip" SKIP (https://nvbugs/5941242)