Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions tests/integration/defs/perf/pytorch_model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,29 @@ def get_model_yaml_config(model_label: str,
}
}
},
# GPT-OSS 20B (NVBug 5720470: MMHA vs XQA kernel regression)
{
'patterns': [
'gpt_oss_20b_fp4-bench-pytorch-float4',
],
'config': {
'cuda_graph_config': {
'max_batch_size': 512,
'enable_padding': True,
},
'enable_chunked_prefill': False,
'enable_attention_dp': False,
'disable_overlap_scheduler': False,
'kv_cache_config': {
'enable_block_reuse': False,
'free_gpu_memory_fraction': 0.9,
},
'moe_config': {
'backend': 'TRITON'
},
'print_iter_log': True,
}
},
# GPT-OSS 120B max throughput test
{
'patterns': [
Expand Down
4 changes: 4 additions & 0 deletions tests/integration/test_lists/qa/llm_perf_core.yml
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ llm_perf_core:
- perf/test_perf.py::test_perf[deepseek_v3_lite_fp8-bench-pytorch-float8-maxbs:512-maxnt:2048-kv_frac:0.85-input_output_len:3000,500-reqs:200]
- perf/test_perf.py::test_perf[deepseek_v3_lite_fp8-bench-pytorch-streaming-float8-input_output_len:128,128]
- perf/test_perf.py::test_perf[mistral_small_v3.1_24b-bench-pytorch-bfloat16-maxbs:4096-maxnt:20000-input_output_len:20000,2000-reqs:500-con:200] TIMEOUT(120)
# gpt_oss_20b_fp4 (NVBug 5720470: MMHA vs XQA kernel regression)
- perf/test_perf.py::test_perf[gpt_oss_20b_fp4-bench-pytorch-float4-maxbs:512-maxnt:8192-input_output_len:2000,200-con:64]
- perf/test_perf.py::test_perf[gpt_oss_20b_fp4-bench-pytorch-float4-maxbs:512-maxnt:8192-input_output_len:128,128]
- perf/test_perf.py::test_perf[gpt_oss_20b_fp4-bench-pytorch-float4-maxbs:512-maxnt:8192-input_output_len:2000,200-con:256]


# 5: H100, H20, H200, GB200, B200, B300, GB300, RTX6000-D, RTX6000-Server test cases
Expand Down
149 changes: 147 additions & 2 deletions tests/unittest/_torch/modeling/test_modeling_gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,23 @@
import shutil

import pytest
from transformers import AutoTokenizer
import torch
from torch.profiler import ProfilerActivity
from transformers import AutoTokenizer, GptOssConfig
from utils.llm_data import llm_models_root
from utils.util import skip_no_hopper
from utils.util import skip_no_hopper, skip_pre_hopper

import tensorrt_llm
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_gpt_oss import GptOssForCausalLM
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import \
KvCacheConfig as BindingsKvCacheConfig
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, MoeConfig
from tensorrt_llm.mapping import Mapping

Comment on lines +6 to 23
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Add the required NVIDIA SPDX header before the import block.

This file was meaningfully modified in 2026 but still starts directly with imports, so the repository-mandated copyright/license header is missing.

As per coding guidelines, All TensorRT-LLM Open Source Software code files should contain an NVIDIA copyright header with the year of latest meaningful modification.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unittest/_torch/modeling/test_modeling_gpt_oss.py` around lines 6 - 23,
Add the required NVIDIA SPDX copyright/header at the very top of the file
(before the first import, e.g., before the "import torch" line) using the
repository-mandated template and the latest meaningful modification year (2026);
ensure the header appears above the existing imports such as "import torch",
"from transformers import AutoTokenizer, GptOssConfig" and references to
"GptOssForCausalLM" so the file passes license checks.

configs = """
{
Expand Down Expand Up @@ -89,3 +100,137 @@ def test_gpt_oss_trtllmgen(moe_backend):

sampling_params = SamplingParams(max_tokens=20)
llm.generate(prompts, sampling_params)


@skip_pre_hopper
def test_gpt_oss_xqa_kernel_selection():
    """NVBug 5720470: GPT-OSS-20B must use XQA kernel (not MMHA) in decode.

    GPT-OSS-20B config: num_heads=64, num_kv_heads=8, head_dim=64, bfloat16.
    With batch_size=8 and sufficient decode history, the XQA heuristic
    (mayHavePerfGain) should select XQA over MMHA.

    Verification: use torch.profiler to capture CUDA kernels and assert that
    XQA kernel (kernel_mha) is launched instead of MMHA
    (masked_multihead_attention_kernel).
    """
    config_dict = json.loads(configs)
    # Use fewer layers for faster test
    config_dict["num_hidden_layers"] = 1
    gpt_oss_config = GptOssConfig.from_dict(config_dict)

    dtype = torch.bfloat16
    device = torch.device("cuda")

    model_config = ModelConfig(pretrained_config=gpt_oss_config,
                               attn_backend="TRTLLM")
    with torch.no_grad():
        model = GptOssForCausalLM(model_config).cuda()
        # Cast model weights to bfloat16 but keep float32 params (e.g. sinks).
        for name, param in model.named_parameters():
            if param.dtype == torch.float32 and 'sinks' not in name:
                param.data = param.data.to(dtype)
            elif param.dtype not in (torch.float32, dtype):
                param.data = param.data.to(dtype)

    # All-decode batch: 8 sequences in generation phase with past history.
    # This triggers the generation-phase attention where XQA/MMHA dispatch
    # happens. batch_size=8, past_kv >= 128 ensures the occupancy heuristic
    # selects XQA: num_kv_heads(8) * batch(8) * multi_block(>=1) * 4.0 >= SM_count.
    batch_size = 8
    context_sequence_length = []  # no context (prefill) sequences
    sequence_length = [1] * batch_size  # all decode, 1 new token each
    past_seen_tokens = [256] * batch_size  # enough history for multi-block
    request_ids = list(range(batch_size))
    # strict=True: fail loudly if the two per-sequence lists ever diverge in
    # length instead of silently truncating (Ruff B905).
    token_nums = [
        p + s
        for p, s in zip(past_seen_tokens, sequence_length, strict=True)
    ]
    prompt_lens = past_seen_tokens

    num_blocks = 200
    tokens_per_block = 128
    head_dim = gpt_oss_config.head_dim
    num_layers = gpt_oss_config.num_hidden_layers
    num_kv_heads = gpt_oss_config.num_key_value_heads
    max_seq_len = num_blocks * tokens_per_block

    kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
    mapping = Mapping(world_size=1, tp_size=1, rank=0)
    kv_cache_config = BindingsKvCacheConfig(max_tokens=num_blocks *
                                            tokens_per_block)
    kv_cache_manager = KVCacheManager(
        kv_cache_config,
        tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
        num_layers=num_layers,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        tokens_per_block=tokens_per_block,
        max_seq_len=max_seq_len,
        max_batch_size=batch_size,
        mapping=mapping,
        dtype=kv_cache_dtype,
    )
    kv_cache_manager.add_dummy_requests(request_ids, token_nums)

    metadata_cls = get_attention_backend(model_config.attn_backend).Metadata
    attn_metadata = metadata_cls(
        seq_lens=torch.tensor(sequence_length, dtype=torch.int32),
        num_contexts=len(context_sequence_length),
        kv_cache_params=KVCacheParams(
            use_cache=True,
            num_cached_tokens_per_seq=past_seen_tokens,
        ),
        kv_cache_manager=kv_cache_manager,
        request_ids=request_ids,
        prompt_lens=prompt_lens,
        max_num_requests=batch_size,
        max_num_tokens=8192,
    )

    # 1 token per decode sequence
    input_ids = torch.randint(0,
                              gpt_oss_config.vocab_size, (batch_size, ),
                              dtype=torch.int32,
                              device=device)
    position_ids = torch.tensor(past_seen_tokens,
                                dtype=torch.long,
                                device=device).unsqueeze(0)

    # Warm-up run (JIT compile XQA kernels, allocate buffers).
    with torch.inference_mode():
        attn_metadata.prepare()
        model.forward(input_ids=input_ids,
                      position_ids=position_ids,
                      attn_metadata=attn_metadata)

    # Profiled run: capture CUDA kernel names to verify XQA dispatch.
    with torch.inference_mode(), \
            torch.profiler.profile(activities=[ProfilerActivity.CUDA]) as prof:
        attn_metadata.prepare()
        logits = model.forward(input_ids=input_ids,
                               position_ids=position_ids,
                               attn_metadata=attn_metadata)

    assert logits.shape[0] == batch_size, \
        f"Expected {batch_size} logits, got {logits.shape[0]}"

    kv_cache_manager.shutdown()

    # Collect CUDA kernel names from the profiler trace.
    kernel_names = [
        evt.key for evt in prof.key_averages()
        if evt.device_type == torch.autograd.DeviceType.CUDA
    ]
    all_kernels = " ".join(kernel_names)

    # XQA kernel: "kernel_mha" (from decoderXQA JIT).
    # MMHA kernel: "masked_multihead_attention_kernel".
    has_xqa = any("kernel_mha" in k for k in kernel_names)
    has_mmha = any("masked_multihead_attention_kernel" in k
                   for k in kernel_names)

    assert has_xqa, (
        "GPT-OSS-20B decode did not launch XQA kernel (kernel_mha). "
        f"See NVBug 5720470. Captured CUDA kernels: {all_kernels}")
    assert not has_mmha, (
        "GPT-OSS-20B decode launched MMHA kernel instead of XQA. "
        f"See NVBug 5720470. Captured CUDA kernels: {all_kernels}")
Loading