[https://nvbugs/5945047][fix] Fix Eagle3 one-model hang on SM120 via extend_ctx #12795

Open
ziyixiong-nv wants to merge 1 commit into NVIDIA:main from ziyixiong-nv:dev-fxiong-bug5945047

Conversation

@ziyixiong-nv
Collaborator

@ziyixiong-nv ziyixiong-nv commented Apr 7, 2026

On SM120/121, XQA spec-dec cubins are not available for non-MLA models, causing Eagle3 one-model to hang. The previous fix (extend_ctx=True for one-model on SM120) correctly routes multi-token generation queries through FMHA instead of XQA, but broke the Eagle3 spec worker because extend_ctx inflates attn_metadata.num_contexts and num_ctx_tokens to include extended generation requests.

Fix the Eagle3 one-model spec worker to use spec_metadata.num_generations (the real generation count from the scheduler) instead of deriving it from attn_metadata.num_contexts. This decouples the spec worker's ctx/gen split from the attention layer's view, which may differ under extend_ctx mode.

Changes:

  • model_engine: Set spec_metadata.num_generations in the incremental update path (_apply_incremental_update_target)
  • eagle3: Use spec_metadata.num_generations for the real gen/ctx split in forward(), sample_and_accept_draft_tokens(), and prepare_1st_drafter_inputs()
  • eagle3: Compute real num_ctx_tokens from _seq_lens[:num_contexts] instead of using the inflated attn_metadata.num_ctx_tokens
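The split described in the changes above can be sketched as follows. The class and function names here are hypothetical stand-ins for the real TensorRT-LLM types, shown only to illustrate why the scheduler's generation count must be preferred over the attention metadata under extend_ctx:

```python
# Hypothetical sketch of the ctx/gen split fix (names are illustrative,
# not the actual TensorRT-LLM classes).
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class AttnMetadata:
    # Under extend_ctx, extended generation requests are counted as
    # context requests, so num_contexts here is "inflated".
    num_contexts: int
    _seq_lens: List[int]


@dataclass
class SpecMetadata:
    # Authoritative generation count, set by the scheduler.
    num_generations: int


def split_batch(attn_metadata: AttnMetadata, spec_metadata: SpecMetadata,
                num_requests: int) -> Tuple[int, int, int]:
    """Derive the real ctx/gen split from spec_metadata, not attn_metadata."""
    num_gens = spec_metadata.num_generations
    num_contexts = num_requests - num_gens
    # Real context token count: sum sequence lengths over the true context
    # requests only, instead of reading the inflated num_ctx_tokens.
    num_ctx_tokens = sum(attn_metadata._seq_lens[:num_contexts])
    return num_contexts, num_gens, num_ctx_tokens


# Example: 3 requests total; extend_ctx made attn_metadata count all 3 as
# "contexts", but the scheduler says 2 are really generation requests.
attn = AttnMetadata(num_contexts=3, _seq_lens=[128, 4, 4])
spec = SpecMetadata(num_generations=2)
num_ctx, num_gen, ctx_tokens = split_batch(attn, spec, num_requests=3)
```

Relying on `attn_metadata.num_contexts` in this scenario would report 3 context requests and include the draft-token sequences in the context token count, which is exactly the miscount that broke the spec worker.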

Summary by CodeRabbit

  • Bug Fixes

    • Corrected generation count metadata tracking in speculative decoding operations for improved accuracy.
    • Enhanced speculative decoding batch partitioning logic to use authoritative metadata sources.
    • Added optimized support for NVIDIA SM120/121 GPUs in one-engine speculative decoding mode.
  • Tests

    • Restored test coverage by removing outdated skip directive.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update the tava architecture diagram if there is a significant design change in the PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@ziyixiong-nv
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #42077 [ run ] triggered by Bot. Commit: f15aca6 Link to invocation

@coderabbitai
Contributor

coderabbitai bot commented Apr 7, 2026

📝 Walkthrough

Walkthrough

The changes improve speculative decoding by consolidating metadata tracking to use spec_metadata.num_generations as the authoritative context/generation split source, and add SM120/121 optimizations for one-engine speculative decoding with TrtllmAttention backend. A corresponding test skip is removed.

Changes

  • Speculative Decoding Metadata (tensorrt_llm/_torch/pyexecutor/model_engine.py, tensorrt_llm/_torch/speculative/eagle3.py): Updated incremental context updates to populate spec_metadata.num_generations from num_extend_requests; refactored EAGLE3 batch partitioning to derive the context/generation split from spec_metadata.num_generations instead of attn_metadata.num_contexts across three functions (batch_partitioning, sample_and_accept_draft_tokens, prepare_1st_drafter_inputs).
  • SM120/121 Speculative Decoding Optimization (tensorrt_llm/_torch/speculative/interface.py): Added a special-case early return in SpeculativeDecodingMode.extend_ctx() for one-engine speculative decoding on SM 120/121 with the TrtllmAttention backend, treating draft tokens as chunked context requests at the kernel level.
  • Test Skip Removal (tests/integration/test_lists/waives.txt): Removed the skip directive for the EAGLE3 4-GPU accuracy test on RTXPro6000D.
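The model_engine side of the change above can be sketched in a few lines. The function and attribute names are hypothetical simplifications of the real incremental-update path (`_apply_incremental_update_target`), shown only to illustrate where the authoritative count is recorded:

```python
# Hypothetical sketch: the incremental-update path records the scheduler's
# extend-request count on the spec metadata, so downstream spec-worker code
# never has to infer it from attention metadata inflated by extend_ctx.
class SpecMetadata:
    def __init__(self) -> None:
        self.num_generations = 0


def apply_incremental_update(spec_metadata: SpecMetadata,
                             num_extend_requests: int) -> None:
    # num_extend_requests comes from the scheduler and is the real number
    # of generation requests in the batch.
    spec_metadata.num_generations = num_extend_requests


sm = SpecMetadata()
apply_incremental_update(sm, num_extend_requests=2)
```

With the count stored here once, every consumer (forward, sampling, drafter-input preparation) reads the same authoritative value.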

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning — Docstring coverage is 20.00%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions that are missing them.

✅ Passed checks (2 passed)

  • Title check ✅ Passed — The title clearly identifies the bug ticket, change type, and primary objective: fixing the Eagle3 one-model hang on SM120 via extend_ctx mode.
  • Description check ✅ Passed — The PR description provides a clear problem statement, solution explanation, and specific code changes, but the formal template sections (Description and Test Coverage) remain largely empty.


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/_torch/speculative/interface.py`:
- Around line 264-271: The current SM120/121 fallback unconditionally forces
extend_ctx() for any one-engine TrtllmAttention; change the conditional in the
method that contains use_one_engine() so it also checks the model's MLA
capability (e.g., require that the model does NOT support MLA) before returning
True. Concretely, update the if that references get_sm_version() and
issubclass(attention_backend, TrtllmAttention) to additionally query the model
capability (for example via a model.supports_mla() or model_capabilities.mla
flag passed into this scope) and only trigger the fallback when MLA is
unavailable; keep the rest of the logic (SM version check and TrtllmAttention
class guard) intact.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: c977448d-fd88-4704-8816-5119a51e6fa6

📥 Commits

Reviewing files that changed from the base of the PR and between 576816d and f15aca6.

📒 Files selected for processing (4)
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/speculative/eagle3.py
  • tensorrt_llm/_torch/speculative/interface.py
  • tests/integration/test_lists/waives.txt
💤 Files with no reviewable changes (1)
  • tests/integration/test_lists/waives.txt

Comment on lines 264 to +271

    if self.use_one_engine():
        # 1-model has separate logic for handling draft tokens
        # SM120/121 lacks XQA spec-dec cubins for non-MLA models, so
        # one-model must fall back to treating draft tokens as context.
        # The C++ attention layer will still use FMHA fallback for the
        # draft model's multi-token queries (spec-dec mode remains on).
        if get_sm_version() in (120, 121) and issubclass(
                attention_backend, TrtllmAttention):
            return True
Contributor


⚠️ Potential issue | 🟠 Major

Gate this SM120/121 fallback on MLA capability.

The new override now forces extend_ctx() for every one-engine TrtllmAttention model on SM120/121, but the failure mode called out here is specifically non-MLA. As written, MLA models will also be pushed onto the extend_ctx/FMHA path, which broadens the workaround and can regress the fast path on Blackwell unnecessarily. Please pass model capability into this check instead of keying only on the backend class.
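The reviewer's suggested gating can be sketched as follows. This is a hypothetical illustration of the proposed condition, not the actual TensorRT-LLM code; in particular, the `supports_mla` flag is an assumed capability input that the review asks to be plumbed into this scope:

```python
# Hypothetical sketch of the MLA-gated fallback the reviewer requests.
# Only non-MLA models lack XQA spec-dec cubins on SM120/121, so MLA
# models should keep the fast path and skip the extend_ctx fallback.
def should_extend_ctx(sm_version: int, is_trtllm_attention: bool,
                      use_one_engine: bool, supports_mla: bool) -> bool:
    if use_one_engine:
        # SM version check and TrtllmAttention guard kept intact,
        # with the additional "model does NOT support MLA" condition.
        if (sm_version in (120, 121) and is_trtllm_attention
                and not supports_mla):
            return True
    return False
```

Under this shape, a non-MLA one-engine model on SM120 still takes the fallback, while an MLA model on the same hardware keeps the XQA spec-dec path.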


@tensorrt-cicd
Collaborator

PR_Github #42077 [ run ] completed with state SUCCESS. Commit: f15aca6
/LLM/main/L0_MergeRequest_PR pipeline #32915 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

…extend_ctx

On SM120/121, XQA spec-dec cubins are not available for non-MLA models,
causing Eagle3 one-model to hang. The previous fix (extend_ctx=True for
one-model on SM120) correctly routes multi-token generation queries
through FMHA instead of XQA, but broke the Eagle3 spec worker because
extend_ctx inflates attn_metadata.num_contexts and num_ctx_tokens to
include extended generation requests.

Fix the Eagle3 one-model spec worker to use spec_metadata.num_generations
(the real generation count from the scheduler) instead of deriving it from
attn_metadata.num_contexts. This decouples the spec worker's ctx/gen split
from the attention layer's view, which may differ under extend_ctx mode.

Changes:
- model_engine: Set spec_metadata.num_generations in the incremental
  update path (_apply_incremental_update_target)
- eagle3: Use spec_metadata.num_generations for the real gen/ctx split
  in forward(), sample_and_accept_draft_tokens(), and
  prepare_1st_drafter_inputs()
- eagle3: Compute real num_ctx_tokens from _seq_lens[:num_contexts]
  instead of using the inflated attn_metadata.num_ctx_tokens

Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
@ziyixiong-nv ziyixiong-nv force-pushed the dev-fxiong-bug5945047 branch from f15aca6 to 489df28 on April 7, 2026 08:52
@ziyixiong-nv
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #42117 [ run ] triggered by Bot. Commit: 489df28 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #42117 [ run ] completed with state SUCCESS. Commit: 489df28
/LLM/main/L0_MergeRequest_PR pipeline #32955 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@ziyixiong-nv
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #42303 [ run ] triggered by Bot. Commit: 489df28 Link to invocation
