-
Notifications
You must be signed in to change notification settings - Fork 2.3k
[None][test] add unit test and e2e test for gpt_oss_20b MHA kernel #12796
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -3,12 +3,23 @@ | |||||||||||||||||||||||||||||||
| import shutil | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| import pytest | ||||||||||||||||||||||||||||||||
| from transformers import AutoTokenizer | ||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||
| from torch.profiler import ProfilerActivity | ||||||||||||||||||||||||||||||||
| from transformers import AutoTokenizer, GptOssConfig | ||||||||||||||||||||||||||||||||
| from utils.llm_data import llm_models_root | ||||||||||||||||||||||||||||||||
| from utils.util import skip_no_hopper | ||||||||||||||||||||||||||||||||
| from utils.util import skip_no_hopper, skip_pre_hopper | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| import tensorrt_llm | ||||||||||||||||||||||||||||||||
| from tensorrt_llm import LLM, SamplingParams | ||||||||||||||||||||||||||||||||
| from tensorrt_llm._torch.attention_backend.utils import get_attention_backend | ||||||||||||||||||||||||||||||||
| from tensorrt_llm._torch.metadata import KVCacheParams | ||||||||||||||||||||||||||||||||
| from tensorrt_llm._torch.model_config import ModelConfig | ||||||||||||||||||||||||||||||||
| from tensorrt_llm._torch.models.modeling_gpt_oss import GptOssForCausalLM | ||||||||||||||||||||||||||||||||
| from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager | ||||||||||||||||||||||||||||||||
| from tensorrt_llm.bindings.executor import \ | ||||||||||||||||||||||||||||||||
| KvCacheConfig as BindingsKvCacheConfig | ||||||||||||||||||||||||||||||||
| from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, MoeConfig | ||||||||||||||||||||||||||||||||
| from tensorrt_llm.mapping import Mapping | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| configs = """ | ||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||
|
|
@@ -89,3 +100,137 @@ def test_gpt_oss_trtllmgen(moe_backend): | |||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| sampling_params = SamplingParams(max_tokens=20) | ||||||||||||||||||||||||||||||||
| llm.generate(prompts, sampling_params) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| @skip_pre_hopper | ||||||||||||||||||||||||||||||||
| def test_gpt_oss_xqa_kernel_selection(): | ||||||||||||||||||||||||||||||||
| """NVBug 5720470: GPT-OSS-20B must use XQA kernel (not MMHA) in decode. | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| GPT-OSS-20B config: num_heads=64, num_kv_heads=8, head_dim=64, bfloat16. | ||||||||||||||||||||||||||||||||
| With batch_size=8 and sufficient decode history, the XQA heuristic | ||||||||||||||||||||||||||||||||
| (mayHavePerfGain) should select XQA over MMHA. | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| Verification: use torch.profiler to capture CUDA kernels and assert that | ||||||||||||||||||||||||||||||||
| XQA kernel (kernel_mha) is launched instead of MMHA | ||||||||||||||||||||||||||||||||
| (masked_multihead_attention_kernel). | ||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||
| config_dict = json.loads(configs) | ||||||||||||||||||||||||||||||||
| # Use fewer layers for faster test | ||||||||||||||||||||||||||||||||
| config_dict["num_hidden_layers"] = 1 | ||||||||||||||||||||||||||||||||
| gpt_oss_config = GptOssConfig.from_dict(config_dict) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| dtype = torch.bfloat16 | ||||||||||||||||||||||||||||||||
| device = torch.device("cuda") | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| model_config = ModelConfig(pretrained_config=gpt_oss_config, | ||||||||||||||||||||||||||||||||
| attn_backend="TRTLLM") | ||||||||||||||||||||||||||||||||
| with torch.no_grad(): | ||||||||||||||||||||||||||||||||
| model = GptOssForCausalLM(model_config).cuda() | ||||||||||||||||||||||||||||||||
| # Cast model weights to bfloat16 but keep float32 params (e.g. sinks). | ||||||||||||||||||||||||||||||||
| for name, param in model.named_parameters(): | ||||||||||||||||||||||||||||||||
| if param.dtype == torch.float32 and 'sinks' not in name: | ||||||||||||||||||||||||||||||||
| param.data = param.data.to(dtype) | ||||||||||||||||||||||||||||||||
| elif param.dtype not in (torch.float32, dtype): | ||||||||||||||||||||||||||||||||
| param.data = param.data.to(dtype) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # All-decode batch: 8 sequences in generation phase with past history. | ||||||||||||||||||||||||||||||||
| # This triggers the generation-phase attention where XQA/MMHA dispatch | ||||||||||||||||||||||||||||||||
| # happens. batch_size=8, past_kv >= 128 ensures the occupancy heuristic | ||||||||||||||||||||||||||||||||
| # selects XQA: num_kv_heads(8) * batch(8) * multi_block(>=1) * 4.0 >= SM_count. | ||||||||||||||||||||||||||||||||
| batch_size = 8 | ||||||||||||||||||||||||||||||||
| context_sequence_length = [] # no context (prefill) sequences | ||||||||||||||||||||||||||||||||
| sequence_length = [1] * batch_size # all decode, 1 new token each | ||||||||||||||||||||||||||||||||
| past_seen_tokens = [256] * batch_size # enough history for multi-block | ||||||||||||||||||||||||||||||||
| request_ids = list(range(batch_size)) | ||||||||||||||||||||||||||||||||
| token_nums = [p + s for p, s in zip(past_seen_tokens, sequence_length)] | ||||||||||||||||||||||||||||||||
|
Comment on lines
+140
to
+145
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: #!/bin/bash
set -euo pipefail
python - <<'PY'
import ast
from pathlib import Path
path = Path("tests/unittest/_torch/modeling/test_modeling_gpt_oss.py")
tree = ast.parse(path.read_text(encoding="utf-8"))
for node in ast.walk(tree):
if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == "zip":
has_strict = any(keyword.arg == "strict" for keyword in node.keywords)
print(f"Line {node.lineno}: zip(...), strict={has_strict}")
PYRepository: NVIDIA/TensorRT-LLM Length of output: 95 Add The Suggested fix- token_nums = [p + s for p, s in zip(past_seen_tokens, sequence_length)]
+ token_nums = [
+ p + s
+ for p, s in zip(past_seen_tokens, sequence_length, strict=True)
+ ]📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff (0.15.9)[warning] 145-145: Add explicit value for parameter (B905) 🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||
| prompt_lens = past_seen_tokens | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| num_blocks = 200 | ||||||||||||||||||||||||||||||||
| tokens_per_block = 128 | ||||||||||||||||||||||||||||||||
| head_dim = gpt_oss_config.head_dim | ||||||||||||||||||||||||||||||||
| num_layers = gpt_oss_config.num_hidden_layers | ||||||||||||||||||||||||||||||||
| num_kv_heads = gpt_oss_config.num_key_value_heads | ||||||||||||||||||||||||||||||||
| max_seq_len = num_blocks * tokens_per_block | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16 | ||||||||||||||||||||||||||||||||
| mapping = Mapping(world_size=1, tp_size=1, rank=0) | ||||||||||||||||||||||||||||||||
| kv_cache_config = BindingsKvCacheConfig(max_tokens=num_blocks * | ||||||||||||||||||||||||||||||||
| tokens_per_block) | ||||||||||||||||||||||||||||||||
| kv_cache_manager = KVCacheManager( | ||||||||||||||||||||||||||||||||
| kv_cache_config, | ||||||||||||||||||||||||||||||||
| tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF, | ||||||||||||||||||||||||||||||||
| num_layers=num_layers, | ||||||||||||||||||||||||||||||||
| num_kv_heads=num_kv_heads, | ||||||||||||||||||||||||||||||||
| head_dim=head_dim, | ||||||||||||||||||||||||||||||||
| tokens_per_block=tokens_per_block, | ||||||||||||||||||||||||||||||||
| max_seq_len=max_seq_len, | ||||||||||||||||||||||||||||||||
| max_batch_size=batch_size, | ||||||||||||||||||||||||||||||||
| mapping=mapping, | ||||||||||||||||||||||||||||||||
| dtype=kv_cache_dtype, | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
| kv_cache_manager.add_dummy_requests(request_ids, token_nums) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| metadata_cls = get_attention_backend(model_config.attn_backend).Metadata | ||||||||||||||||||||||||||||||||
| attn_metadata = metadata_cls( | ||||||||||||||||||||||||||||||||
| seq_lens=torch.tensor(sequence_length, dtype=torch.int32), | ||||||||||||||||||||||||||||||||
| num_contexts=len(context_sequence_length), | ||||||||||||||||||||||||||||||||
| kv_cache_params=KVCacheParams( | ||||||||||||||||||||||||||||||||
| use_cache=True, | ||||||||||||||||||||||||||||||||
| num_cached_tokens_per_seq=past_seen_tokens, | ||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||
| kv_cache_manager=kv_cache_manager, | ||||||||||||||||||||||||||||||||
| request_ids=request_ids, | ||||||||||||||||||||||||||||||||
| prompt_lens=prompt_lens, | ||||||||||||||||||||||||||||||||
| max_num_requests=batch_size, | ||||||||||||||||||||||||||||||||
| max_num_tokens=8192, | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # 1 token per decode sequence | ||||||||||||||||||||||||||||||||
| input_ids = torch.randint(0, | ||||||||||||||||||||||||||||||||
| gpt_oss_config.vocab_size, (batch_size, ), | ||||||||||||||||||||||||||||||||
| dtype=torch.int32, | ||||||||||||||||||||||||||||||||
| device=device) | ||||||||||||||||||||||||||||||||
| position_ids = torch.tensor(past_seen_tokens, | ||||||||||||||||||||||||||||||||
| dtype=torch.long, | ||||||||||||||||||||||||||||||||
| device=device).unsqueeze(0) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # Warm-up run (JIT compile XQA kernels, allocate buffers). | ||||||||||||||||||||||||||||||||
| with torch.inference_mode(): | ||||||||||||||||||||||||||||||||
| attn_metadata.prepare() | ||||||||||||||||||||||||||||||||
| model.forward(input_ids=input_ids, | ||||||||||||||||||||||||||||||||
| position_ids=position_ids, | ||||||||||||||||||||||||||||||||
| attn_metadata=attn_metadata) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # Profiled run: capture CUDA kernel names to verify XQA dispatch. | ||||||||||||||||||||||||||||||||
| kernel_names = [] | ||||||||||||||||||||||||||||||||
| with torch.inference_mode(), \ | ||||||||||||||||||||||||||||||||
| torch.profiler.profile(activities=[ProfilerActivity.CUDA]) as prof: | ||||||||||||||||||||||||||||||||
| attn_metadata.prepare() | ||||||||||||||||||||||||||||||||
| logits = model.forward(input_ids=input_ids, | ||||||||||||||||||||||||||||||||
| position_ids=position_ids, | ||||||||||||||||||||||||||||||||
| attn_metadata=attn_metadata) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| assert logits.shape[0] == batch_size, \ | ||||||||||||||||||||||||||||||||
| f"Expected {batch_size} logits, got {logits.shape[0]}" | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| kv_cache_manager.shutdown() | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # Collect CUDA kernel names from the profiler trace. | ||||||||||||||||||||||||||||||||
| kernel_names = [ | ||||||||||||||||||||||||||||||||
| evt.key for evt in prof.key_averages() | ||||||||||||||||||||||||||||||||
| if evt.device_type == torch.autograd.DeviceType.CUDA | ||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||
| all_kernels = " ".join(kernel_names) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # XQA kernel: "kernel_mha" (from decoderXQA JIT). | ||||||||||||||||||||||||||||||||
| # MMHA kernel: "masked_multihead_attention_kernel". | ||||||||||||||||||||||||||||||||
| has_xqa = any("kernel_mha" in k for k in kernel_names) | ||||||||||||||||||||||||||||||||
| has_mmha = any("masked_multihead_attention_kernel" in k | ||||||||||||||||||||||||||||||||
| for k in kernel_names) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| assert has_xqa, ( | ||||||||||||||||||||||||||||||||
| "GPT-OSS-20B decode did not launch XQA kernel (kernel_mha). " | ||||||||||||||||||||||||||||||||
| f"See NVBug 5720470. Captured CUDA kernels: {all_kernels}") | ||||||||||||||||||||||||||||||||
| assert not has_mmha, ( | ||||||||||||||||||||||||||||||||
| "GPT-OSS-20B decode launched MMHA kernel instead of XQA. " | ||||||||||||||||||||||||||||||||
| f"See NVBug 5720470. Captured CUDA kernels: {all_kernels}") | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add the required NVIDIA SPDX header before the import block.
This file was meaningfully modified in 2026 but still starts directly with imports, so the repository-mandated copyright/license header is missing.
As per coding guidelines,
All TensorRT-LLM Open Source Software code files should contain an NVIDIA copyright header with the year of latest meaningful modification.🤖 Prompt for AI Agents