diff --git a/ATTRIBUTIONS-Python.md b/ATTRIBUTIONS-Python.md index a5e5ab38773..e3218e845b1 100644 --- a/ATTRIBUTIONS-Python.md +++ b/ATTRIBUTIONS-Python.md @@ -5261,7 +5261,7 @@ For more information, please refer to - `Tracker`: https://github.com/tox-dev/py-filelock/issues -## flashinfer-python (0.6.6) +## flashinfer-python (0.6.7) ### Licenses License: `Apache-2.0` diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu b/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu index ad162658899..e1fd9bc7f08 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu +++ b/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu @@ -60,12 +60,30 @@ using tensorrt_llm::common::launchWithPdlWhenEnabled; __VA_ARGS__; \ break; \ } \ + case 18: \ + { \ + constexpr int TOP_K = 18; \ + __VA_ARGS__; \ + break; \ + } \ case 16: \ { \ constexpr int TOP_K = 16; \ __VA_ARGS__; \ break; \ } \ + case 14: \ + { \ + constexpr int TOP_K = 14; \ + __VA_ARGS__; \ + break; \ + } \ + case 12: \ + { \ + constexpr int TOP_K = 12; \ + __VA_ARGS__; \ + break; \ + } \ case 10: \ { \ constexpr int TOP_K = 10; \ diff --git a/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu b/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu index 21f68c71824..38a7a24b8e9 100644 --- a/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu +++ b/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu @@ -287,8 +287,9 @@ void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk int num_threads = NumDeepseekExperts; if (is_single_group) { - // Special case for Nemotron, which selects top 22 from 512 experts, and 1 group only. - if (num_experts == NumNemotronExperts && n_group == 1 && topk == MaxSupportedTopExperts) + // Nemotron models: 512 experts, 1 group, top_k up to 22. + // Variants use varying top_k (4..22) across layers. + if (num_experts == NumNemotronExperts && n_group == 1 && topk <= MaxSupportedTopExperts) { kernel_instance = &deepseek_v3_topk_kernel; diff --git a/requirements.txt b/requirements.txt index b76e28208bd..4ef84c4a512 100644 --- a/requirements.txt +++ b/requirements.txt @@ -54,7 +54,7 @@ ordered-set peft patchelf einops -flashinfer-python==0.6.6 +flashinfer-python @ https://github.com/flashinfer-ai/flashinfer/releases/download/nightly-v0.6.7-20260406/flashinfer_python-0.6.7.dev20260406-py3-none-any.whl opencv-python-headless xgrammar==0.1.32 llguidance==0.7.29 @@ -71,7 +71,7 @@ xdsl>=0.59.0 # Optional: required for MLIR-based elementwise fusion in AutoDeplo tiktoken blobfile openai-harmony==0.0.4 -nvidia-cutlass-dsl==4.3.4; python_version >= "3.10" +nvidia-cutlass-dsl>=4.4.2; python_version >= "3.10" plotly numexpr partial_json_parser diff --git a/security_scanning/pyproject.toml b/security_scanning/pyproject.toml index b0f9994e95d..ccecf4cddef 100644 --- a/security_scanning/pyproject.toml +++ b/security_scanning/pyproject.toml @@ -55,7 +55,7 @@ dependencies = [ "peft (>=0.18.1,<0.19.0)", "patchelf (>=0.17.2.4,<0.18.0.0)", "einops (>=0.8.2,<0.9.0)", - "flashinfer-python (==0.6.6)", + "flashinfer-python (==0.6.7)", "opencv-python-headless (>=4.13.0.92,<5.0.0.0)", "xgrammar (==0.1.32)", "llguidance (==0.7.29)", diff --git a/tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py b/tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py index 5be9a3d59e2..852d8cb4b5d 100644 --- a/tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py +++ b/tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py @@ -7,6 +7,7 @@ from tensorrt_llm._torch.utils import split +@register_mapper("HF", "NemotronHPuzzleForCausalLM") @register_mapper("HF", "NemotronHForCausalLM") class NemotronHHfWeightMapper(HfWeightMapper): diff --git a/tensorrt_llm/_torch/models/modeling_nemotron_h.py b/tensorrt_llm/_torch/models/modeling_nemotron_h.py index 623195da94a..7badef90714 100644 --- a/tensorrt_llm/_torch/models/modeling_nemotron_h.py +++ b/tensorrt_llm/_torch/models/modeling_nemotron_h.py @@ -55,6 +55,27 @@ class NemotronHConfig(PretrainedConfig): model_type = "nemotron_h" +class NemotronHPuzzleConfig(PretrainedConfig): + model_type = "nemotron_h_puzzle" + + +def _bc_getattr(bc, key, default=None): + """Get attribute from a block_config entry (dict or dataclass).""" + if isinstance(bc, dict): + return bc.get(key, default) + return getattr(bc, key, default) + + +def _get_layer_moe_param(config, layer_idx: int, param_name: str): + """Get per-layer MoE parameter, falling back to global config.""" + block_configs = getattr(config, 'block_configs', None) + if block_configs and layer_idx < len(block_configs): + val = _bc_getattr(block_configs[layer_idx], param_name) + if val is not None: + return val + return getattr(config, param_name, None) + + class MLPLayer(MLP): def __init__( @@ -152,22 +173,26 @@ def __init__( self.hidden_dim = config.hidden_size self.ffn_dim = config.intermediate_size self.layer_idx = layer_idx - self.moe_intermediate_size = (config.moe_intermediate_size[0] - if isinstance( - config.moe_intermediate_size, list) - else config.moe_intermediate_size) - self.use_latent_moe: bool = getattr(config, "moe_latent_size", - None) is not None - self.moe_hidden_size: int = (config.moe_latent_size - if self.use_latent_moe else + + # Per-layer MoE params (models with block_configs have varying params). + def _moe(name): + return _get_layer_moe_param(config, layer_idx, name) + + moe_intermediate = _moe('moe_intermediate_size') + self.moe_intermediate_size = (moe_intermediate[0] if isinstance( + moe_intermediate, list) else moe_intermediate) + + moe_latent = _moe('moe_latent_size') + self.use_latent_moe: bool = moe_latent is not None + self.moe_hidden_size: int = (moe_latent if self.use_latent_moe else config.hidden_size) self.mlp_bias = config.mlp_bias if hasattr(config, "mlp_bias") else False self.moe_n_group = config.n_group - self.num_experts = config.n_routed_experts + self.num_experts = _moe('n_routed_experts') self.hidden_size = config.hidden_size self.num_shared_experts = config.n_shared_experts - self.top_k = config.num_experts_per_tok + self.top_k = _moe('num_experts_per_tok') self.enable_attention_dp = model_config.mapping.enable_attention_dp self.routed_scaling_factor = config.routed_scaling_factor self.mapping = model_config.mapping @@ -177,7 +202,7 @@ def __init__( self.shared_experts = None else: shared_expert_intermediate_size = ( - config.moe_shared_expert_intermediate_size * + _moe('moe_shared_expert_intermediate_size') * config.n_shared_experts) self.shared_experts = MLP( @@ -703,6 +728,7 @@ def forward( return hidden_states +@register_auto_model("NemotronHPuzzleForCausalLM") @register_auto_model("NemotronHForCausalLM") class NemotronHForCausalLM(SpecDecOneEngineForCausalLM[NemotronHModel, NemotronHConfig]): @@ -720,6 +746,9 @@ def __init__( raise ValueError("layer_norm_epsilon or rms_norm_eps is not set") model_config.pretrained_config.rms_norm_eps = rms_epsilon + # Normalize per-layer block_configs into global config attributes. + self._normalize_puzzle_config(model_config.pretrained_config) + if (not model_config.mapping.enable_attention_dp and model_config.mapping.tp_size not in [1, 2, 4, 8]): raise ValueError("TP has to be either 1, 2, 4 or 8") @@ -776,6 +805,31 @@ def __init__( self.epilogue.extend(self.draft_model.mtp_layers) self.epilogue.append(self.spec_worker) + @staticmethod + def _normalize_puzzle_config(config): + """Set global MoE defaults from block_configs for models with per-layer MoE params.""" + block_configs = getattr(config, 'block_configs', None) + if not block_configs: + return + + def _is_moe(bc): + return _bc_getattr(bc, 'block_type') == 'moe' + + first_moe = next((bc for bc in block_configs if _is_moe(bc)), None) + if first_moe is None: + return + + # Prefer MTP MoE block as fallback (used for MTP layers beyond + # block_configs range), otherwise use first main-model MoE block. + mtp_configs = getattr(config, 'mtp_block_configs', None) or [] + fallback = next((bc for bc in mtp_configs if _is_moe(bc)), first_moe) + + for attr in ('n_routed_experts', 'moe_intermediate_size', + 'num_experts_per_tok', 'moe_latent_size', + 'moe_shared_expert_intermediate_size'): + if not hasattr(config, attr) or getattr(config, attr) is None: + setattr(config, attr, _bc_getattr(fallback, attr)) + def load_weights(self, weights: dict, weight_mapper: BaseWeightMapper): new_weights = weight_mapper.preprocess_weights(weights) super().load_weights(weights=new_weights, weight_mapper=weight_mapper) @@ -1074,3 +1128,4 @@ def forward( AutoConfig.register(NemotronHConfig.model_type, NemotronHConfig) +AutoConfig.register(NemotronHPuzzleConfig.model_type, NemotronHPuzzleConfig) diff --git a/tensorrt_llm/_torch/models/modeling_speculative.py b/tensorrt_llm/_torch/models/modeling_speculative.py index 7c103d9b25e..041fa02c0f3 100755 --- a/tensorrt_llm/_torch/models/modeling_speculative.py +++ b/tensorrt_llm/_torch/models/modeling_speculative.py @@ -804,7 +804,7 @@ def __init__( case "exaone_moe": from .modeling_exaone_moe import ExaoneMoeMTP mtp_layer = ExaoneMoeMTP - case "nemotron_h": + case "nemotron_h" | "nemotron_h_puzzle": from .modeling_nemotron_h import NemotronHMTP mtp_layer = NemotronHMTP case "qwen3_next": diff --git a/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py b/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py index e37d5db1081..53d18135c11 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py +++ b/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py @@ -224,7 +224,32 @@ def __init__( # Initialize or reuse workspace MnnvlMemory.initialize() - if self._WORKSPACE is None: + need_alloc = self._WORKSPACE is None + if not need_alloc: + assert self._WORKSPACE["max_num_tokens_per_rank"] == self.max_num_tokens_per_rank, ( + "reuse workspace with different max_num_tokens_per_rank" + ) + assert self._WORKSPACE["ep_rank"] == self.ep_rank, ( + "reuse workspace with different ep_rank" + ) + assert self._WORKSPACE["ep_size"] == self.ep_size, ( + "reuse workspace with different ep_size" + ) + assert self._WORKSPACE["eplb_stats_num_experts"] == self.eplb_stats_num_experts, ( + "reuse workspace with different eplb_stats_num_experts" + ) + + # Models with per-layer MoE params may request different workspace sizes across layers. + # Reallocate when a larger workspace is needed; reuse otherwise. + if self._WORKSPACE["workspace_size_per_rank"] < self.workspace_size_per_rank: + tllm_logger.info( + f"NVLinkOneSided: Reallocating workspace " + f"{self._WORKSPACE['workspace_size_per_rank']} -> " + f"{self.workspace_size_per_rank} bytes." + ) + need_alloc = True + + if need_alloc: tllm_logger.info( f"NVLinkOneSided: Allocating workspace with size {self.workspace_size_per_rank} bytes." f"ep_rank: {self.ep_rank}, ep_size: {self.ep_size}, top_k: {self.top_k}, max_num_tokens_per_rank: {self.max_num_tokens_per_rank}" @@ -248,26 +273,8 @@ def __init__( "workspace": workspace, "metainfo": metainfo, } - else: - assert self._WORKSPACE["workspace_size_per_rank"] == self.workspace_size_per_rank, ( - "reuse workspace with different workspace_size_per_rank" - ) - assert self._WORKSPACE["max_num_tokens_per_rank"] == self.max_num_tokens_per_rank, ( - "reuse workspace with different max_num_tokens_per_rank" - ) - assert self._WORKSPACE["ep_rank"] == self.ep_rank, ( - "reuse workspace with different ep_rank" - ) - assert self._WORKSPACE["ep_size"] == self.ep_size, ( - "reuse workspace with different ep_size" - ) - assert self._WORKSPACE["eplb_stats_num_experts"] == self.eplb_stats_num_experts, ( - "reuse workspace with different eplb_stats_num_experts" - ) - self.mnnvl_mem = self._WORKSPACE["mnnvl_mem"] - self.workspace = self._WORKSPACE["workspace"] - self.moe_a2a_metainfo = self._WORKSPACE["metainfo"] + # Read max_num_tokens_per_rank from the (possibly grown) workspace. self.max_num_tokens_per_rank = self._WORKSPACE["max_num_tokens_per_rank"] # Initialize dispatch state @@ -276,6 +283,21 @@ def __init__( # Invalid token expert ID (default to -1), the kernels in TRTLLM-gen is hard-code to support -1 only. self.invalid_token_expert_id: int = -1 + # Properties delegate to _WORKSPACE so all instances see the latest + # allocation (workspace may be reallocated when layers need more space). + + @property + def mnnvl_mem(self): + return self._WORKSPACE["mnnvl_mem"] + + @property + def workspace(self): + return self._WORKSPACE["workspace"] + + @property + def moe_a2a_metainfo(self): + return self._WORKSPACE["metainfo"] + @staticmethod def is_platform_supported() -> bool: """ diff --git a/tensorrt_llm/_torch/modules/fused_moe/routing.py b/tensorrt_llm/_torch/modules/fused_moe/routing.py index 69498c96cfc..47ed3057c88 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/routing.py +++ b/tensorrt_llm/_torch/modules/fused_moe/routing.py @@ -264,9 +264,8 @@ def noaux_tc(self, logits, e_score_correction_bias): "The configuration is not supported by the fused routing kernel. We have to use the original pytorch implementation." ) self.is_fused = False - elif (num_experts > 512 or (self.top_k > 8 and self.top_k != 22) - or (self.topk_group == 1 and self.top_k != 22)): - # We have special implementation for n_group == 1, top_k == 22 and num_experts == 512 for Nemotron Super v3. + elif num_experts > 512 or self.top_k > 22 or (self.top_k > 8 + and num_experts != 512): if self.is_fused: warnings.warn( "The configuration is not supported by the fused routing kernel. We have to use the original pytorch implementation." diff --git a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py index 3cbd88f4337..f17c049c4ca 100644 --- a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py +++ b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py @@ -150,13 +150,10 @@ def __init__( # Choose between flashinfer and native implementation. (default to flashinfer) self._mamba_ssm_cache_dtype = config.quant_config.mamba_ssm_cache_dtype - # TODO: Update head_dims and head_group_ratios once flashinfer is updated. + # TODO: Update head_dims once flashinfer is updated. + # Nemotron-v2-Nano (mamba_head_dim=80) is not supported by flashinfer yet. supported_head_dims = [64, 128] - supported_head_group_ratios = [1, 8, 16] - head_group_ratio = (self.tp_nheads // - self.tp_ngroups if self.tp_ngroups > 0 else 0) - self._use_flashinfer = (head_dim in supported_head_dims and - head_group_ratio in supported_head_group_ratios) + self._use_flashinfer = head_dim in supported_head_dims # Stochastic rounding requires FlashInfer and fp16 cache self._use_stochastic_rounding = ( config.quant_config.mamba_ssm_stochastic_rounding diff --git a/tensorrt_llm/_torch/modules/mamba/ssd_combined.py b/tensorrt_llm/_torch/modules/mamba/ssd_combined.py index a5916657c31..63ea5b91ee2 100644 --- a/tensorrt_llm/_torch/modules/mamba/ssd_combined.py +++ b/tensorrt_llm/_torch/modules/mamba/ssd_combined.py @@ -17,14 +17,138 @@ # limitations under the License. import torch +import torch.nn.functional as F from einops import rearrange +from tensorrt_llm._utils import is_sm_100f +from tensorrt_llm.logger import logger +from tensorrt_llm.math_utils import pad_up + from .ssd_bmm import _bmm_chunk_fwd from .ssd_chunk_scan import _chunk_scan_fwd from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd, chunk_state_varlen) from .ssd_state_passing import _state_passing_fwd +# FlashInfer fused SSD kernel cache (Blackwell SM100+ only). +_flashinfer_ssd_cache: dict = {} + + +def _get_flashinfer_ssd(chunk_size, nheads, headdim, dstate, ngroups): + """Get or compile a cached FlashInfer SSDCombined kernel instance.""" + key = (chunk_size, nheads, headdim, dstate, ngroups) + if key not in _flashinfer_ssd_cache: + from flashinfer.mamba import SSDCombined + _flashinfer_ssd_cache[key] = SSDCombined( + chunk_size=chunk_size, + nheads=nheads, + headdim=headdim, + dstate=dstate, + ngroups=ngroups, + io_dtype=torch.bfloat16, + state_dtype=torch.bfloat16, + has_d=True, + d_has_hdim=False, + has_initial_states=True, + has_varlen=True, + has_z=False, + seq_idx_dtype=torch.int32, + ) + logger.info_once("Using FlashInfer fused SSD kernel for Mamba2 prefill", + key="flashinfer_ssd_prefill") + return _flashinfer_ssd_cache[key] + + +def _mamba_chunk_scan_flashinfer_fwd( + x, + dt, + A, + B, + C, + chunk_size, + D=None, + dt_bias=None, + initial_states=None, + seq_idx=None, + chunk_indices=None, + chunk_offsets=None, + cu_seqlens=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), + out=None, + return_final_states=False, + state_dtype=None, +): + """FlashInfer fused SSD forward using a single CUTLASS persistent kernel.""" + _, seqlen, nheads, headdim = x.shape + _, _, ngroups, dstate = B.shape + io_dtype = torch.bfloat16 + + ssd = _get_flashinfer_ssd(chunk_size, nheads, headdim, dstate, ngroups) + num_seqs = cu_seqlens.shape[0] - 1 + + # Pad seqlen to chunk_size boundary — padded tokens use dt=-100 + # so softplus ≈ 0, contributing nothing to state or output. + pad_len = pad_up(seqlen, chunk_size) - seqlen + if pad_len > 0: + x = F.pad(x, (0, 0, 0, 0, 0, pad_len)) + B = F.pad(B, (0, 0, 0, 0, 0, pad_len)) + C = F.pad(C, (0, 0, 0, 0, 0, pad_len)) + dt = F.pad(dt, (0, 0, 0, pad_len), value=-100.0) + if seq_idx is not None: + seq_idx = F.pad(seq_idx, (0, pad_len), value=int(num_seqs - 1)) + + if x.dtype != io_dtype: + x = x.to(io_dtype) + B = B.to(io_dtype) + C = C.to(io_dtype) + dt = dt.to(io_dtype) + + D_bf16 = D.to(io_dtype) if D is not None and D.dtype != io_dtype else D + + if initial_states is not None: + fi_initial_states = (initial_states if initial_states.dtype == io_dtype + else initial_states.to(io_dtype)) + else: + fi_initial_states = x.new_zeros(num_seqs, + nheads, + headdim, + dstate, + dtype=io_dtype) + + if chunk_indices is None or chunk_offsets is None: + from .mamba2_metadata import cu_seqlens_to_chunk_indices_offsets_triton + chunk_indices, chunk_offsets = ( + cu_seqlens_to_chunk_indices_offsets_triton(cu_seqlens, + chunk_size, + total_seqlens=seqlen)) + + out_view, fstate = ssd.run( + x, + dt, + A, + B, + C, + D=D_bf16, + dt_bias=dt_bias, + dt_softplus=dt_softplus, + dt_limit=dt_limit, + initial_states=fi_initial_states, + seq_idx=seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + return_final_states=True, + ) + + if out is not None: + out.copy_(out_view[:, :seqlen]) + + if state_dtype is not None and fstate.dtype != state_dtype: + fstate = fstate.to(state_dtype) + + # Both final_states and varlen_states are per-sequence in FlashInfer. + return (fstate, fstate) if return_final_states else fstate + def is_int_pow_2(n): return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0 @@ -235,6 +359,30 @@ def mamba_chunk_scan_combined( else: assert (cu_seqlens is not None ), "cu_seqlens must be provided if return_varlen_states is True" + + # Dispatch to FlashInfer fused CUTLASS kernel on Blackwell (SM100+). + if (return_varlen_states and z is None and is_sm_100f()): + return _mamba_chunk_scan_flashinfer_fwd( + x, + dt, + A, + B, + C, + chunk_size, + D=D, + dt_bias=dt_bias, + initial_states=initial_states, + seq_idx=seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + cu_seqlens=cu_seqlens, + dt_softplus=dt_softplus, + dt_limit=dt_limit, + out=out, + return_final_states=return_final_states, + state_dtype=state_dtype, + ) + out_x, dt_out, dA_cumsum, states, final_states, *rest = ( _mamba_chunk_scan_combined_fwd( x, diff --git a/tensorrt_llm/_torch/visual_gen/jit_kernels/flash_attention/cute/cute_dsl_utils.py b/tensorrt_llm/_torch/visual_gen/jit_kernels/flash_attention/cute/cute_dsl_utils.py index 9d6ee345d00..edb8f692a03 100644 --- a/tensorrt_llm/_torch/visual_gen/jit_kernels/flash_attention/cute/cute_dsl_utils.py +++ b/tensorrt_llm/_torch/visual_gen/jit_kernels/flash_attention/cute/cute_dsl_utils.py @@ -108,20 +108,35 @@ def load_cubin_module_data_patched(cubin_data, filepath): return load_cubin_module_data_og(cubin_data) -def cute_compile_patched(*args, **kwargs): - """A patched version of cute.compile that dump the SASS to a file if CUTE_CUBIN_PATH is set.""" - cubin_path = os.getenv("CUTE_CUBIN_PATH", None) - if cubin_path is not None: - cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial( - load_cubin_module_data_patched, filepath=cubin_path - ) - output = cute_compile_og(*args, **kwargs) - if cubin_path is not None: - cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og - if extract is not None: - sass = extract(cubin_path, None) - pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass) - return output +class _CuteCompilePatched: + """Wrapper around cute.compile that optionally dumps SASS via CUTE_CUBIN_PATH. + + Preserves the CompileCallable subscript interface (cute.compile[opts](...)) + so that third-party CuTe DSL kernels (e.g. FlashInfer) keep working. + """ + + def __init__(self, original=None): + self._original = original or cute_compile_og + + def __getitem__(self, item): + return _CuteCompilePatched(self._original[item]) + + def __call__(self, *args, **kwargs): + cubin_path = os.getenv("CUTE_CUBIN_PATH", None) + if cubin_path is not None: + cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial( + load_cubin_module_data_patched, filepath=cubin_path + ) + output = self._original(*args, **kwargs) + if cubin_path is not None: + cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og + if extract is not None: + sass = extract(cubin_path, None) + pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass) + return output + + +cute_compile_patched = _CuteCompilePatched() def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False, enable_tvm_ffi=True): diff --git a/tensorrt_llm/llmapi/reasoning_parser.py b/tensorrt_llm/llmapi/reasoning_parser.py index ded97ba0099..5d3c134834b 100644 --- a/tensorrt_llm/llmapi/reasoning_parser.py +++ b/tensorrt_llm/llmapi/reasoning_parser.py @@ -182,6 +182,7 @@ def parse_delta(self, delta_text: str) -> ReasoningParserResult: "deepseek_v3": "deepseek-r1", "deepseek_v32": "deepseek-r1", "nemotron_h": "nano-v3", + "nemotron_h_puzzle": "nano-v3", } _QWEN3_MODEL_TYPES = frozenset({ diff --git a/tests/unittest/_torch/models/test_nemotron_h_puzzle.py b/tests/unittest/_torch/models/test_nemotron_h_puzzle.py new file mode 100644 index 00000000000..4ecbc6bdd9b --- /dev/null +++ b/tests/unittest/_torch/models/test_nemotron_h_puzzle.py @@ -0,0 +1,143 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for NemotronHPuzzle model support.""" + +from dataclasses import dataclass +from types import SimpleNamespace + +import pytest + +from tensorrt_llm._torch.models.modeling_nemotron_h import ( + NemotronHForCausalLM, + _get_layer_moe_param, +) + + +@dataclass +class _MambaBlock: + block_type: str = "mamba" + + +@dataclass +class _MoeBlock: + block_type: str = "moe" + moe_intermediate_size: int = 1280 + n_routed_experts: int = 512 + num_experts_per_tok: int = 4 + moe_latent_size: int = 1024 + moe_shared_expert_intermediate_size: int = 5376 + + +def _make_puzzle_config(use_dataclass=False): + """Minimal config mimicking the real puzzle model.""" + if use_dataclass: + bcs = [ + _MambaBlock(), + _MoeBlock(num_experts_per_tok=4), + _MambaBlock(), + _MoeBlock(moe_intermediate_size=2048, num_experts_per_tok=12), + ] + else: + bcs = [ + {"block_type": "mamba"}, + { + "block_type": "moe", + "moe_intermediate_size": 1280, + "n_routed_experts": 512, + "num_experts_per_tok": 4, + "moe_latent_size": 1024, + "moe_shared_expert_intermediate_size": 5376, + }, + {"block_type": "mamba"}, + { + "block_type": "moe", + "moe_intermediate_size": 2048, + "n_routed_experts": 512, + "num_experts_per_tok": 12, + "moe_latent_size": 1024, + "moe_shared_expert_intermediate_size": 5376, + }, + ] + return SimpleNamespace( + block_configs=bcs, + mtp_block_configs=[ + {"block_type": "attention"}, + { + "block_type": "moe", + "moe_intermediate_size": 2688, + "n_routed_experts": 512, + "num_experts_per_tok": 22, + "moe_latent_size": 1024, + "moe_shared_expert_intermediate_size": 5376, + }, + ], + ) + + +class TestPerLayerMoeParams: + """The key change: block_configs can be dicts or HF dataclass objects, + and per-layer values must differ while MTP falls back to globals.""" + + @pytest.mark.parametrize("use_dc", [False, True], ids=["dict", "dataclass"]) + def test_varying_params_per_layer(self, use_dc): + config = _make_puzzle_config(use_dataclass=use_dc) + NemotronHForCausalLM._normalize_puzzle_config(config) + + # MoE layer 1: top_k=4, intermediate=1280 + assert _get_layer_moe_param(config, 1, "num_experts_per_tok") == 4 + assert _get_layer_moe_param(config, 1, "moe_intermediate_size") == 1280 + # MoE layer 3: top_k=12, intermediate=2048 + assert _get_layer_moe_param(config, 3, "num_experts_per_tok") == 12 + assert _get_layer_moe_param(config, 3, "moe_intermediate_size") == 2048 + + @pytest.mark.parametrize("use_dc", [False, True], ids=["dict", "dataclass"]) + def test_mtp_layer_gets_global_defaults(self, use_dc): + """MTP layer_idx beyond block_configs range uses globals from mtp_block_configs.""" + config = _make_puzzle_config(use_dataclass=use_dc) + NemotronHForCausalLM._normalize_puzzle_config(config) + + mtp_idx = len(config.block_configs) # beyond range + assert _get_layer_moe_param(config, mtp_idx, "num_experts_per_tok") == 22 + assert _get_layer_moe_param(config, mtp_idx, "moe_intermediate_size") == 2688 + + @pytest.mark.parametrize("use_dc", [False, True], ids=["dict", "dataclass"]) + def test_normalize_sets_all_global_attrs(self, use_dc): + config = _make_puzzle_config(use_dataclass=use_dc) + NemotronHForCausalLM._normalize_puzzle_config(config) + + for attr in ( + "n_routed_experts", + "moe_intermediate_size", + "num_experts_per_tok", + "moe_latent_size", + "moe_shared_expert_intermediate_size", + ): + assert getattr(config, attr) is not None, f"{attr} not set" + + def test_normalize_preserves_existing_attrs(self): + config = _make_puzzle_config() + config.n_routed_experts = 999 + NemotronHForCausalLM._normalize_puzzle_config(config) + assert config.n_routed_experts == 999 + + def test_normalize_noop_without_block_configs(self): + config = SimpleNamespace() + NemotronHForCausalLM._normalize_puzzle_config(config) + assert not hasattr(config, "n_routed_experts") + + def test_standard_config_passthrough(self): + """Non-puzzle model: no block_configs, returns global directly.""" + config = SimpleNamespace(n_routed_experts=512) + assert _get_layer_moe_param(config, 0, "n_routed_experts") == 512