From 10d9a5bf3c791766fa812e84f8c2bbfcdaf357dc Mon Sep 17 00:00:00 2001 From: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com> Date: Thu, 2 Apr 2026 10:31:32 -0700 Subject: [PATCH 01/16] [None][feat] AutoDeploy: Gemma3n attention support (squash of PR #12205) Adds Gemma3n custom model with shared KV attention, sliding window attention, and related attention backend changes for AutoDeploy. Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com> --- .../configs/gemma3n_e2b_it.yaml | 12 + .../compile/backends/torch_cudagraph.py | 2 +- .../attention/flashinfer_attention.py | 74 +- .../custom_ops/attention/torch_attention.py | 6 + .../attention/torch_backend_attention.py | 194 +++- .../custom_ops/attention/triton_attention.py | 10 +- .../custom_ops/attention/trtllm_attention.py | 8 +- .../custom_ops/attention_interface.py | 24 + .../_torch/auto_deploy/export/export.py | 4 +- .../auto_deploy/models/custom/__init__.py | 3 + .../models/custom/modeling_gemma3n.py | 852 ++++++++++++++++++ .../auto_deploy/transform/library/kvcache.py | 56 +- .../transform/library/kvcache_transformers.py | 7 +- .../_torch/auto_deploy/utils/_graph.py | 4 +- .../_torch/auto_deploy/utils/node_utils.py | 16 +- .../singlegpu/models/test_gemma3n_modeling.py | 474 ++++++++++ .../library/test_shared_kv_attention.py | 660 ++++++++++++++ .../singlegpu/compile/test_captured_graph.py | 9 +- .../attention/test_flashinfer_attention_op.py | 8 + 19 files changed, 2325 insertions(+), 98 deletions(-) create mode 100644 examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml create mode 100644 tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma3n.py create mode 100644 tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma3n_modeling.py create mode 100644 tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_shared_kv_attention.py diff --git a/examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml 
b/examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml new file mode 100644 index 00000000000..0c9862833cc --- /dev/null +++ b/examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml @@ -0,0 +1,12 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +runtime: trtllm +compile_backend: torch-cudagraph +model_factory: AutoModelForCausalLM +max_seq_len: 512 +world_size: 1 + +# Gemma 3n uses shared-KV decode semantics in the tail layers. FlashInfer +# supports the read-only shared-KV cache path and alternating sliding windows. +attn_backend: flashinfer diff --git a/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py b/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py index 85c87afe278..be671f3cec6 100644 --- a/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py +++ b/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py @@ -558,7 +558,7 @@ def __init__( def _is_decode_only(self, **kwargs) -> bool: """Check if the current batch is decode-only using batch_info_host. - batch_info_host = [num_prefill, num_prefill_tokens, num_decode] + batch_info_host is the serialized BatchInfo tensor. Decode-only means num_prefill == 0. 
""" batch_info = kwargs.get(self.batch_info_kwarg_name) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py index c935116f2dd..f9c0c703e10 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py @@ -55,6 +55,7 @@ class PlanParams: sm_scale: Optional[float] = None causal: bool = True + window_left: int = -1 def __hash__(self): """Convert all fields to a string representation and concatenate them.""" @@ -186,6 +187,7 @@ def plan_prefill( q_data_type=plan_params.q_dtype, kv_data_type=plan_params.kv_dtype, sm_scale=plan_params.sm_scale, + window_left=plan_params.window_left, seq_lens=kv_lens_arr_host, ) self.plan_params_prefill = plan_params @@ -218,6 +220,7 @@ def _plan_decode( q_data_type=plan_params.q_dtype, kv_data_type=plan_params.kv_dtype, sm_scale=plan_params.sm_scale, + window_left=plan_params.window_left, ) # we want to plan during warm-up of cuda graph capture to ensure we have the plan cached @@ -251,6 +254,13 @@ def _plan_decode( _GlobalFlashInferPlanner = _FlashInferPlanner() +def _to_flashinfer_window_left(sliding_window: Optional[int]) -> int: + """Convert AD sliding-window size to FlashInfer's inclusive window_left contract.""" + if sliding_window is None or sliding_window <= 0: + return -1 + return sliding_window - 1 + + @torch.library.custom_op("auto_deploy::flashinfer_attention_prepare_metadata", mutates_args=()) def prepare_flashinfer_metadata( position_ids: torch.Tensor, @@ -342,11 +352,15 @@ def flashinfer_mha_with_cache( kv_cache: torch.Tensor, # CONSTANTS scale: Optional[float], + sliding_window: Optional[int], k_scale: float, v_scale: float, + read_cache_only: bool = False, # OPTIONAL PRE-ALLOCATED OUTPUT out: Optional[torch.Tensor] = None, ) -> torch.Tensor: + _GlobalFlashInferPlanner.reset(q.device) + # kv_cache 
shape: [num_blocks, 2, num_kv_heads, tokens_per_block, head_dim] (HND layout) head_dim = kv_cache.shape[-1] page_size = kv_cache.shape[3] # tokens_per_block @@ -365,25 +379,27 @@ def flashinfer_mha_with_cache( n_heads = q.shape[1] n_kv_heads = k.shape[1] + window_left = _to_flashinfer_window_left(sliding_window) # Assuming k_scale = v_scale = 1.0 k_scale, v_scale = 1.0, 1.0 - # k = (k / k_scale).to(torch.float8_e4m3fn) if k_scale != 1.0, same for v - if kv_cache.dtype == torch.float8_e4m3fn: - k = k.to(torch.float8_e4m3fn) - v = v.to(torch.float8_e4m3fn) - - flashinfer.page.append_paged_kv_cache( - append_key=k[:num_total_tokens], - append_value=v[:num_total_tokens], - batch_indices=flashinfer_batch_indices[:num_total_tokens], - positions=flashinfer_positions[:num_total_tokens], - paged_kv_cache=kv_cache, - kv_indices=cache_loc, - kv_indptr=cu_num_pages[: num_seq + 1], - kv_last_page_len=last_page_len[:num_seq], - kv_layout=_GlobalFlashInferPlanner.kv_layout, - ) + if not read_cache_only: + # k = (k / k_scale).to(torch.float8_e4m3fn) if k_scale != 1.0, same for v + if kv_cache.dtype == torch.float8_e4m3fn: + k = k.to(torch.float8_e4m3fn) + v = v.to(torch.float8_e4m3fn) + + flashinfer.page.append_paged_kv_cache( + append_key=k[:num_total_tokens], + append_value=v[:num_total_tokens], + batch_indices=flashinfer_batch_indices[:num_total_tokens], + positions=flashinfer_positions[:num_total_tokens], + paged_kv_cache=kv_cache, + kv_indices=cache_loc, + kv_indptr=cu_num_pages[: num_seq + 1], + kv_last_page_len=last_page_len[:num_seq], + kv_layout=_GlobalFlashInferPlanner.kv_layout, + ) bs = b * s if out is not None: @@ -403,6 +419,7 @@ def flashinfer_mha_with_cache( q_dtype=q_prefill.dtype, kv_dtype=kv_cache.dtype, sm_scale=scale, + window_left=window_left, ) wrapper_prefill = _GlobalFlashInferPlanner.plan_prefill( @@ -435,6 +452,7 @@ def flashinfer_mha_with_cache( q_dtype=q_decode.dtype, kv_dtype=kv_cache.dtype, sm_scale=scale, + window_left=window_left, ) wrapper_decode 
= _GlobalFlashInferPlanner.plan_decode( @@ -485,8 +503,10 @@ def flashinfer_mha_with_cache_fake( kv_cache: torch.Tensor, # CONSTANTS scale: Optional[float], + sliding_window: Optional[int], k_scale: float, v_scale: float, + read_cache_only: bool = False, # OPTIONAL PRE-ALLOCATED OUTPUT out: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -520,6 +540,10 @@ def get_source_attention_op(cls) -> OpOverloadPacket: def get_cached_attention_op(cls) -> MHACallable: return torch.ops.auto_deploy.flashinfer_attention_mha_with_cache.default + @classmethod + def supports_shared_kv(cls) -> bool: + return True + @classmethod def get_standard_metadata_args(cls) -> List[str]: return [ @@ -564,15 +588,7 @@ def get_host_prepare_metadata_function(cls) -> Optional[PrepareMetadataHostCalla @classmethod def get_constants(cls, source_attn_node: Node) -> List[Constant]: - # Sanity check: layout == "bsnd" - # Prefer kwargs; fall back to the final positional arg if it's a string. - layout = source_attn_node.kwargs.get("layout", None) - if ( - layout is None - and len(source_attn_node.args) > 0 - and isinstance(source_attn_node.args[-1], str) - ): - layout = source_attn_node.args[-1] + layout = extract_op_args(source_attn_node, "layout")[0] if layout != "bsnd": raise RuntimeError( f"Expected torch_attention layout='bsnd' but got {layout!r} " @@ -589,11 +605,7 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]: f"{source_attn_node=}: {attn_mask=}, {dropout_p=}, {is_causal=}" ) - # Get scale from args or kwargs - if len(source_attn_node.args) > 6: - scale = source_attn_node.args[6] - else: - scale = source_attn_node.kwargs.get("scale", None) + scale = extract_op_args(source_attn_node, "scale")[0] if not (isinstance(scale, float) or scale is None): ad_logger.warning(f"Provided {scale=}, is not a float. 
Using default scale instead.") @@ -601,6 +613,8 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]: return [ scale, # softmax scale + extract_op_args(source_attn_node, "sliding_window")[0], # sliding window parameter 1.0, # k_scale 1.0, # v_scale + cls.get_shared_kv_source_layer_idx(source_attn_node) is not None, # read_cache_only ] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_attention.py index 8d0d819300e..77b5cfc9820 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_attention.py @@ -119,6 +119,8 @@ def torch_attention( sliding_window: Optional[int] = None, logit_cap: Optional[float] = None, layout: str = "bnsd", # "bnsd" or "bsnd" + layer_idx: Optional[int] = None, + shared_kv_source_layer_idx: Optional[int] = None, ) -> torch.Tensor: """ SDPA attention (with optional GQA) that supports two memory layouts via `layout`: @@ -129,6 +131,8 @@ def torch_attention( Returns a tensor in the SAME layout as inputs specified by `layout`. """ + # `layer_idx` and `shared_kv_source_layer_idx` are graph metadata used by the KV-cache + # transform; the eager attention kernel itself does not need them. 
if layout not in ("bnsd", "bsnd"): raise ValueError(f"layout must be 'bnsd' or 'bsnd', got {layout!r}") @@ -239,5 +243,7 @@ def torch_attention_fake( sliding_window=None, logit_cap=None, layout: str = "bnsd", + layer_idx: Optional[int] = None, + shared_kv_source_layer_idx: Optional[int] = None, ): return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous() diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py index ae816f753d4..4be06d5b24a 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py @@ -70,6 +70,24 @@ def _apply_logit_softcapping(attn_scores: torch.Tensor, logit_cap: Optional[floa return attn_scores +def _write_generate_kv_cache( + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + slot_idx: torch.Tensor, + input_pos: torch.Tensor, +): + """Write single-token decode K/V into the cache.""" + b, s = k.shape[:2] + assert s == 1, f"Expected sequence length 1 for generate phase, got {s}" + for i in range(b): + cache_idx = slot_idx[i].item() + pos = input_pos[i].item() + k_cache[cache_idx, pos] = k[i, 0] # Remove sequence dim + v_cache[cache_idx, pos] = v[i, 0] # Remove sequence dim + + def _torch_generate_mha( q: torch.Tensor, k: torch.Tensor, @@ -89,12 +107,7 @@ def _torch_generate_mha( assert s == 1, f"Expected sequence length 1 for generate phase, got {s}" n_kv_heads = k.shape[2] # k has shape (b, 1, n_kv_heads, head_dim) - # Update KV cache for single token - for i in range(b): - cache_idx = slot_idx[i].item() - pos = input_pos[i].item() - k_cache[cache_idx, pos] = k[i, 0] # Remove sequence dim - v_cache[cache_idx, pos] = v[i, 0] # Remove sequence dim + _write_generate_kv_cache(k, v, k_cache, v_cache, slot_idx, input_pos) # Compute attention for each sequence using 
manual computation for i in range(b): @@ -156,6 +169,60 @@ def _torch_generate_mha( out[i] = attn_out.squeeze(1) # [n_heads, v_head_dim] +def _torch_generate_mha_readonly( + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + slot_idx: torch.Tensor, + input_pos: torch.Tensor, + scale: float, + out: torch.Tensor, + logit_cap: Optional[float] = None, + sliding_window_size: Optional[int] = None, + sinks: Optional[torch.Tensor] = None, +): + """Generate-only attention using an existing KV cache without writing current-layer K/V.""" + b, s, n_heads, head_dim = q.shape + assert s == 1, f"Expected sequence length 1 for generate phase, got {s}" + n_kv_heads = k_cache.shape[2] + + for i in range(b): + cache_idx = slot_idx[i].item() + pos = input_pos[i].item() + q_i = q[i, 0] + + if sliding_window_size is not None and sliding_window_size > 0: + start_pos = max(0, pos - sliding_window_size + 1) + k_i = k_cache[cache_idx, start_pos : pos + 1] + v_i = v_cache[cache_idx, start_pos : pos + 1] + else: + k_i = k_cache[cache_idx, : pos + 1] + v_i = v_cache[cache_idx, : pos + 1] + + q_i = q_i.unsqueeze(1) + k_i = k_i.transpose(0, 1) + v_i = v_i.transpose(0, 1) + + if n_heads != n_kv_heads: + n_rep = n_heads // n_kv_heads + k_i = repeat_kv(k_i.unsqueeze(0), n_rep)[0] + v_i = repeat_kv(v_i.unsqueeze(0), n_rep)[0] + + attn_scores = torch.matmul(q_i, k_i.transpose(-2, -1)) * scale + attn_scores = _apply_logit_softcapping(attn_scores, logit_cap) + + if sinks is not None: + sinks = sinks.reshape(-1, 1, 1) + attn_weights = torch.cat([attn_scores, sinks], dim=-1) + attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_out = torch.matmul(attn_weights[..., : -sinks.size(-1)], v_i) + else: + attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype) + attn_out = torch.matmul(attn_weights, v_i) + + out[i] = attn_out.squeeze(1) + + def _torch_context_mha( q: torch.Tensor, k: torch.Tensor, @@ -173,7 +240,6 @@ def 
_torch_context_mha( sinks: Optional[torch.Tensor] = None, ) -> None: """Context attention (multiple tokens, potentially multiple sequences) using existing torch functions.""" - # Update KV cache first using existing function _update_kv_cache(k, v, k_cache, v_cache, seq_len, input_pos, slot_idx, seq_start) # Compute attention for each sequence @@ -285,9 +351,85 @@ def _torch_context_mha( out.copy_(torch.cat(attn_outputs, dim=0)) -@torch.library.custom_op( - "auto_deploy::torch_cached_attention_with_cache", mutates_args=("k_cache", "v_cache") -) +def _torch_context_mha_readonly( + q: torch.Tensor, + input_pos: torch.Tensor, + slot_idx: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + seq_len: torch.Tensor, + seq_start: torch.Tensor, + scale: float, + out: torch.Tensor, + logit_cap: Optional[float] = None, + sliding_window_size: Optional[int] = None, + sinks: Optional[torch.Tensor] = None, +) -> None: + """Context attention using an existing KV cache without writing current-layer K/V.""" + attn_outputs = [] + for idx in range(seq_len.shape[0]): + seq_len_i = seq_len[idx].item() + input_pos_i = input_pos[idx].item() + slot_idx_i = slot_idx[idx].item() + seq_start_i = seq_start[idx].item() + + if seq_len_i == 0: + continue + + q_seq = q[seq_start_i : seq_start_i + seq_len_i] + kv_seq_len = input_pos_i + seq_len_i + k_seq = k_cache[slot_idx_i, :kv_seq_len] + v_seq = v_cache[slot_idx_i, :kv_seq_len] + + n_heads = q_seq.shape[1] + n_kv_heads = k_seq.shape[1] + + q_seq_t = q_seq.transpose(0, 1).unsqueeze(0) + k_seq_t = k_seq.transpose(0, 1).unsqueeze(0) + v_seq_t = v_seq.transpose(0, 1).unsqueeze(0) + + if n_heads != n_kv_heads: + n_rep = n_heads // n_kv_heads + k_seq_t = repeat_kv(k_seq_t, n_rep) + v_seq_t = repeat_kv(v_seq_t, n_rep) + + attn_scores = torch.matmul(q_seq_t, k_seq_t.transpose(-2, -1)) * scale + + causal_mask = torch.triu( + torch.ones(seq_len_i, kv_seq_len, device=q.device, dtype=torch.bool), + diagonal=1 + input_pos_i, + ) + 
attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf")) + + if sliding_window_size is not None and sliding_window_size > 0: + query_positions = torch.arange(input_pos_i, input_pos_i + seq_len_i, device=q.device) + key_positions = torch.arange(kv_seq_len, device=q.device) + pos_diff = query_positions.unsqueeze(1) - key_positions.unsqueeze(0) + sliding_window_mask = (pos_diff < 0) | (pos_diff >= sliding_window_size) + attn_scores.masked_fill_(sliding_window_mask.unsqueeze(0).unsqueeze(0), float("-inf")) + + attn_scores = _apply_logit_softcapping(attn_scores, logit_cap) + + if sinks is not None: + new_sinks = sinks.reshape(1, -1, 1, 1).expand(1, n_heads, seq_len_i, 1) + attn_weights = torch.cat([attn_scores, new_sinks], dim=-1) + attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_out = torch.matmul(attn_weights[..., : -new_sinks.size(-1)], v_seq_t) + else: + attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype) + attn_out = torch.matmul(attn_weights, v_seq_t) + + attn_outputs.append(attn_out[0].transpose(0, 1)) + + if len(attn_outputs) == 0: + out.zero_() + elif len(attn_outputs) == 1: + out.copy_(attn_outputs[0]) + else: + out.copy_(torch.cat(attn_outputs, dim=0)) + + +@torch.library.custom_op("auto_deploy::torch_cached_attention_with_cache", mutates_args=()) def torch_backend_mha_with_cache( # Q, K, V q: torch.Tensor, @@ -311,6 +453,7 @@ def torch_backend_mha_with_cache( sinks: Optional[torch.Tensor] = None, sliding_window_size: Optional[int] = None, logit_cap: Optional[float] = None, + read_cache_only: bool = False, out: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Torch backend MHA with cache that takes q, k, v in BSND layout.""" @@ -350,12 +493,15 @@ def torch_backend_mha_with_cache( y = q.new_empty(*bs_view, num_heads, v_head_dim).contiguous() # Compute attention + if not read_cache_only: + if s == 1: + _write_generate_kv_cache(k, v, k_cache, v_cache, slot_idx, 
input_pos) + else: + _update_kv_cache(k, v, k_cache, v_cache, seq_len, input_pos, slot_idx, seq_start) + if s == 1: - # Generate-only phase - _torch_generate_mha( + _torch_generate_mha_readonly( q, - k, - v, k_cache, v_cache, slot_idx, @@ -367,11 +513,8 @@ def torch_backend_mha_with_cache( sinks, ) else: - # Context phase - _torch_context_mha( + _torch_context_mha_readonly( q, - k, - v, input_pos, slot_idx, k_cache, @@ -426,6 +569,7 @@ def torch_backend_mha_with_cache_fake( sinks: Optional[torch.Tensor] = None, sliding_window_size: Optional[int] = None, logit_cap: Optional[float] = None, + read_cache_only: bool = False, out: Optional[torch.Tensor] = None, ) -> torch.Tensor: if out is not None: @@ -453,6 +597,10 @@ def get_source_attention_op(cls) -> OpOverloadPacket: def get_cached_attention_op(cls) -> MHACallable: return torch.ops.auto_deploy.torch_cached_attention_with_cache.default + @classmethod + def supports_shared_kv(cls) -> bool: + return True + @classmethod def get_standard_metadata_args(cls) -> List[str]: return ["batch_info_host", "seq_len", "input_pos", "slot_idx", "cu_seqlen"] @@ -484,14 +632,7 @@ def get_cache_initializers( @classmethod def get_constants(cls, source_attn_node: Node) -> List[Constant]: # Sanity check: layout == "bsnd" - # Prefer kwargs; fall back to the final positional arg if it's a string. 
- layout = source_attn_node.kwargs.get("layout", None) - if ( - layout is None - and len(source_attn_node.args) > 0 - and isinstance(source_attn_node.args[-1], str) - ): - layout = source_attn_node.args[-1] + layout = extract_op_args(source_attn_node, "layout")[0] if layout != "bsnd": raise RuntimeError( f"Expected torch_attention layout='bsnd' but got {layout!r} " @@ -529,4 +670,5 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]: sinks, # sinks parameter sliding_window, # sliding window parameter logit_cap, # logit cap parameter + cls.get_shared_kv_source_layer_idx(source_attn_node) is not None, # read_cache_only ] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py index 0670a5b9d77..a1ea264b0e2 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py @@ -388,14 +388,8 @@ def get_cache_initializers( @classmethod def get_constants(cls, source_attn_node: Node) -> List[Constant]: # Sanity check: layout == "bsnd" - # Prefer kwargs; fall back to the final positional arg if it's a string. - layout = source_attn_node.kwargs.get("layout", None) - if ( - layout is None - and len(source_attn_node.args) > 0 - and isinstance(source_attn_node.args[-1], str) - ): - layout = source_attn_node.args[-1] + # extract_op_args handles kwargs and positional arguments consistently. 
+ layout = extract_op_args(source_attn_node, "layout")[0] if layout != "bsnd": raise RuntimeError( f"Expected torch_attention layout='bsnd' but got {layout!r} " diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py index d0c8a1dd8c0..f1c99267ed0 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py @@ -639,13 +639,7 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]: from tensor shapes or SequenceInfo metadata at runtime. """ # Sanity check: layout == "bsnd" - layout = source_attn_node.kwargs.get("layout", None) - if ( - layout is None - and len(source_attn_node.args) > 0 - and isinstance(source_attn_node.args[-1], str) - ): - layout = source_attn_node.args[-1] + layout = extract_op_args(source_attn_node, "layout")[0] if layout != "bsnd": raise RuntimeError( f"Expected torch_attention layout='bsnd' but got {layout!r} " diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py index 748ce06ba22..72acc449519 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py @@ -38,6 +38,7 @@ from ...._utils import nvtx_range, prefer_pinned, str_dtype_to_torch from ..utils.logger import ad_logger +from ..utils.node_utils import extract_op_args, get_op_schema Constant = Union[int, float, str, None] @@ -65,6 +66,14 @@ def _list_to_tensor(data: list, dtype: torch.dtype) -> torch.Tensor: return torch.tensor(data, dtype=dtype) +def _extract_optional_op_arg(node: Node, arg_name: str): + """Return an op argument if it exists in the schema, otherwise ``None``.""" + schema_arg_names = {arg.name for arg in get_op_schema(node.target).arguments} + if arg_name not in 
schema_arg_names: + return None + return extract_op_args(node, arg_name)[0] + + class PrepareMetadataHostCallable(Protocol): def __call__(self, **sequence_info_args: torch.Tensor) -> None: ... @@ -1854,6 +1863,11 @@ def attention_op( """ raise NotImplementedError + @classmethod + def supports_shared_kv(cls) -> bool: + """Whether this backend supports shared-KV cache aliasing.""" + return False + @classmethod @abstractmethod def get_standard_metadata_args(cls) -> List[str]: @@ -1921,6 +1935,16 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]: """ return [] + @classmethod + def get_layer_idx(cls, source_attn_node: Node) -> Optional[int]: + """Return the logical layer index associated with a source attention node, if any.""" + return _extract_optional_op_arg(source_attn_node, "layer_idx") + + @classmethod + def get_shared_kv_source_layer_idx(cls, source_attn_node: Node) -> Optional[int]: + """Return the KV source layer for a shared-KV attention node, if any.""" + return _extract_optional_op_arg(source_attn_node, "shared_kv_source_layer_idx") + @staticmethod def resolve_cache_dtype(dtype_config: str, fallback_dtype: torch.dtype) -> torch.dtype: """Resolve cache dtype from KvCacheConfig dtype string to torch.dtype. diff --git a/tensorrt_llm/_torch/auto_deploy/export/export.py b/tensorrt_llm/_torch/auto_deploy/export/export.py index 79313a28bb9..730576528b6 100644 --- a/tensorrt_llm/_torch/auto_deploy/export/export.py +++ b/tensorrt_llm/_torch/auto_deploy/export/export.py @@ -14,7 +14,7 @@ from ..utils._graph import canonicalize_graph, lift_to_meta, load_buffers_and_params, tree_to from ..utils.logger import ad_logger -from ..utils.node_utils import is_op +from ..utils.node_utils import get_op_schema, is_op from .interface import apply_export_patches if TYPE_CHECKING: @@ -276,7 +276,7 @@ def _expand_moe_experts_in_graph( # Collect indices of List[Tensor] arguments from the op schema – these # are the per-expert weight / scale lists. 
op = node.target - schema = op._schema if hasattr(op, "_schema") else next(iter(op._schemas.values())) + schema = get_op_schema(op) _tensor_list_types = ("Tensor[]", "List[Tensor]") list_arg_indices = [ i diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py b/tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py index 77b5b8c91f1..bf1bf1c9909 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py @@ -1,4 +1,5 @@ from .modeling_deepseek import DeepSeekV3ForCausalLM +from .modeling_gemma3n import Gemma3nForCausalLM, Gemma3nForConditionalGeneration from .modeling_glm4_moe_lite import Glm4MoeLiteForCausalLM from .modeling_kimi_k2 import KimiK2ForCausalLM, KimiK25ForConditionalGeneration from .modeling_mistral3 import Mistral3ForConditionalGenerationAD, Mistral4ForCausalLM @@ -8,6 +9,8 @@ __all__ = ( "DeepSeekV3ForCausalLM", + "Gemma3nForCausalLM", + "Gemma3nForConditionalGeneration", "Glm4MoeLiteForCausalLM", "KimiK2ForCausalLM", "KimiK25ForConditionalGeneration", diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma3n.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma3n.py new file mode 100644 index 00000000000..b2ab5965d22 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma3n.py @@ -0,0 +1,852 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Slimmed down Gemma 3n text implementation for AutoDeploy export. + +This implementation follows the Hugging Face Gemma 3n text stack closely while +keeping only the prefill path needed by AutoDeploy. The outer +``Gemma3nForConditionalGeneration`` wrapper preserves the HF text checkpoint +layout (``model.language_model.*`` + ``lm_head``) and drops unsupported +vision/audio tower weights at load time. The forward path intentionally +supports only text-only export. +""" + +import copy +import math +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch import nn +from transformers.activations import ACT2FN +from transformers.generation import GenerationMixin +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel +from transformers.models.gemma3n.configuration_gemma3n import ( + Gemma3nAudioConfig, + Gemma3nConfig, + Gemma3nTextConfig, + Gemma3nVisionConfig, +) +from transformers.utils import ModelOutput + +from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory + + +def _build_rope_cache( + config: Gemma3nTextConfig, +) -> Tuple[torch.Tensor, torch.Tensor, float]: + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type", "default")) + else: + rope_type = "default" + + inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, device=None) + positions = torch.arange(config.max_position_embeddings, dtype=inv_freq.dtype) + freqs = torch.outer(positions, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + return emb.cos(), emb.sin(), attention_scaling + + +class Gemma3nRMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True): + super().__init__() + self.eps = eps + if 
class Gemma3nRMSNorm(nn.Module):
    """RMSNorm with an optional learned scale, dispatched to the auto_deploy custom op."""

    # NOTE(review): the class/__init__ header lies outside this chunk; it is
    # reconstructed here from the visible body and the HF Gemma3n reference —
    # confirm against the full file.
    def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True):
        super().__init__()
        self.eps = eps
        if with_scale:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            # Constant unit scale kept as a buffer so the custom-op signature is uniform.
            self.register_buffer("weight", torch.ones(dim), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.ops.auto_deploy.torch_rmsnorm(x, self.weight, self.eps)


class Gemma3nTextScaledWordEmbedding(nn.Embedding):
    """Token embedding whose output is multiplied by a fixed `embed_scale`."""

    def __init__(
        self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float
    ):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        # Buffer (not a Parameter): follows .to()/device moves but is never trained.
        self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return super().forward(input_ids) * self.embed_scale.to(dtype=self.weight.dtype)


class Gemma3nTextLaurelBlock(nn.Module):
    """LAUREL block: low-rank (left @ right) residual adapter with a post-RMSNorm."""

    def __init__(self, config: Gemma3nTextConfig):
        super().__init__()
        self.linear_left = nn.Linear(config.hidden_size, config.laurel_rank, bias=False)
        self.linear_right = nn.Linear(config.laurel_rank, config.hidden_size, bias=False)
        self.post_laurel_norm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        laurel_hidden_states = self.linear_left(hidden_states)
        laurel_hidden_states = self.linear_right(laurel_hidden_states)
        laurel_hidden_states = self.post_laurel_norm(laurel_hidden_states)
        # Residual add: the block returns input + low-rank update.
        return hidden_states + laurel_hidden_states


class Gemma3nTextMLP(nn.Module):
    """Gated MLP with optional gaussian-cutoff activation sparsity (per-layer pattern)."""

    def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size[layer_idx]
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_activation]
        self.activation_sparsity = config.activation_sparsity_pattern[layer_idx]
        if self.activation_sparsity > 0.0:
            # Precompute the z-score multiplier for the requested sparsity quantile.
            normal_dist = torch.distributions.normal.Normal(0, 1)
            std_multiplier = normal_dist.icdf(
                torch.tensor(self.activation_sparsity, dtype=torch.float32)
            )
            self.register_buffer(
                "activation_sparsity_std_multiplier", std_multiplier, persistent=False
            )
        else:
            # Registered unconditionally so the module structure is stable for export.
            self.register_buffer(
                "activation_sparsity_std_multiplier",
                torch.tensor(0.0, dtype=torch.float32),
                persistent=False,
            )

    def _gaussian_topk(self, inputs: torch.Tensor) -> torch.Tensor:
        """Zero out activations below a per-row mean + std * z cutoff (ReLU of excess)."""
        std_multiplier = self.activation_sparsity_std_multiplier.to(
            device=inputs.device, dtype=inputs.dtype
        )
        inputs_mean = torch.mean(inputs, dim=-1, keepdim=True)
        # unbiased=False: population std, matching the reference implementation.
        inputs_std = torch.std(inputs, dim=-1, keepdim=True, unbiased=False)
        cutoff = inputs_mean + inputs_std * std_multiplier
        return torch.relu(inputs - cutoff)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        gate_proj = self.gate_proj(hidden_states)
        if self.activation_sparsity > 0.0:
            gate_proj = self._gaussian_topk(gate_proj)
        activations = self.act_fn(gate_proj)
        up_proj = self.up_proj(hidden_states)
        return self.down_proj(activations * up_proj)


class Gemma3nTextAltUp(nn.Module):
    """AltUp (alternating updates): mixes the stacked [num_inputs, B, S, H] hidden
    copies with routed coefficients (predict), then adds a routed innovation term
    back to every copy (correct)."""

    def __init__(self, config: Gemma3nTextConfig):
        super().__init__()
        self.config = config
        self.correct_output_scale = nn.Parameter(torch.zeros(config.hidden_size))
        self.correction_coefs = nn.Linear(
            config.altup_num_inputs, config.altup_num_inputs, bias=False
        )
        self.prediction_coefs = nn.Linear(
            config.altup_num_inputs, config.altup_num_inputs**2, bias=False
        )
        self.modality_router = nn.Linear(config.hidden_size, config.altup_num_inputs, bias=False)
        self.router_norm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.register_buffer(
            "router_input_scale", torch.tensor(config.hidden_size**-1.0), persistent=False
        )

    def compute_router_modalities(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Router runs in float32 and is tanh-squashed before casting back.
        router_inputs = self.router_norm(hidden_states) * self.router_input_scale
        routed = self.modality_router(router_inputs)
        return torch.tanh(routed.float()).type_as(hidden_states)

    def predict(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Residual mix of the stacked copies with per-token routed coefficient matrices."""
        modalities = self.compute_router_modalities(hidden_states[self.config.altup_active_idx])
        all_coefs = self.prediction_coefs(modalities).reshape(
            *modalities.shape[:-1], self.config.altup_num_inputs, self.config.altup_num_inputs
        )
        all_coefs = all_coefs.permute(0, 1, 3, 2)
        # Move the stack axis last so matmul mixes across copies, then restore it.
        predictions = torch.matmul(hidden_states.permute(1, 2, 3, 0), all_coefs)
        predictions = predictions.permute(3, 0, 1, 2)
        return (predictions + hidden_states).contiguous().type_as(hidden_states)

    def correct(self, predictions: torch.Tensor, activated: torch.Tensor) -> torch.Tensor:
        """Add routed (activated - prediction) innovation back to every stacked copy."""
        modalities = self.compute_router_modalities(activated)
        innovation = activated - predictions[self.config.altup_active_idx]
        innovation = innovation.repeat(self.config.altup_num_inputs, 1, 1, 1)
        # +1.0 keeps the correction a perturbation around identity.
        all_coefs = self.correction_coefs(modalities) + 1.0
        all_coefs = all_coefs.permute(2, 0, 1).unsqueeze(-1)
        corrected = torch.mul(innovation, all_coefs)
        return (corrected + predictions).contiguous().type_as(activated)

    def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
        return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(
            corrected
        )


class Gemma3nTextRotaryEmbedding(nn.Module):
    """RoPE cos/sin tables precomputed at init via `_build_rope_cache` (defined
    elsewhere in this file); forward returns the full tables in x's dtype/device."""

    def __init__(self, config: Gemma3nTextConfig):
        super().__init__()
        cos, sin, attention_scaling = _build_rope_cache(config)
        # attention_scaling is folded into the cached tables once, at build time.
        self.register_buffer("_ad_cos_cached", cos * attention_scaling, persistent=False)
        self.register_buffer("_ad_sin_cached", sin * attention_scaling, persistent=False)

    def forward(
        self, x: torch.Tensor, position_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # position_ids intentionally unused: per-token slicing happens in the caller.
        del position_ids
        cos = self._ad_cos_cached.to(dtype=x.dtype, device=x.device)
        sin = self._ad_sin_cached.to(dtype=x.dtype, device=x.device)
        return cos, sin
def _slice_rope_cache(
    position_embeddings: Tuple[torch.Tensor, torch.Tensor], position_ids: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Gather per-token cos/sin rows from full-length RoPE tables."""
    cos, sin = position_embeddings
    return cos[position_ids], sin[position_ids]


class Gemma3nMultimodalEmbedder(nn.Module):
    """Embeds vision/audio token ids (hard) or soft embeddings into text hidden size."""

    def __init__(
        self,
        multimodal_config: Gemma3nAudioConfig | Gemma3nVisionConfig,
        text_config: Gemma3nTextConfig,
    ):
        super().__init__()
        self.multimodal_hidden_size = multimodal_config.hidden_size
        self.eps = multimodal_config.rms_norm_eps
        self.vocab_offset = multimodal_config.vocab_offset
        self.vocab_size = multimodal_config.vocab_size
        self.text_hidden_size = text_config.hidden_size

        self.embedding = nn.Embedding(self.vocab_size, self.multimodal_hidden_size)
        self.hard_embedding_norm = Gemma3nRMSNorm(self.multimodal_hidden_size, eps=self.eps)
        self.soft_embedding_norm = Gemma3nRMSNorm(self.multimodal_hidden_size, eps=self.eps)
        self.embedding_projection = nn.Linear(
            self.multimodal_hidden_size, self.text_hidden_size, bias=False
        )
        self.embedding_post_projection_norm = Gemma3nRMSNorm(
            self.text_hidden_size, eps=self.eps, with_scale=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Exactly one of input_ids (offset into the multimodal vocab) or
        inputs_embeds must be given."""
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
        if inputs_embeds is not None:
            embeddings = self.soft_embedding_norm(inputs_embeds)
        else:
            # Hard tokens are offset by vocab_offset into this embedder's table.
            embeddings = self.embedding(input_ids - self.vocab_offset)
            embeddings = self.hard_embedding_norm(embeddings)
        embeddings = self.embedding_projection(embeddings)
        return self.embedding_post_projection_norm(embeddings)


class Gemma3nTextAttention(nn.Module):
    """Gemma3n attention with optional sliding window and shared-KV layers.

    Layers at index >= num_hidden_layers - num_kv_shared_layers reuse the KV cache
    of the last earlier layer with the same layer_type (sliding vs full)."""

    def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.config = config
        self.head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
        self.sliding_window = config.sliding_window if self.is_sliding else None
        first_kv_shared_layer_idx = config.num_hidden_layers - config.num_kv_shared_layers
        # Chained comparison: shared only when sharing is enabled (idx > 0) AND this
        # layer is in the shared tail.
        self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
        prev_layers = config.layer_types[:first_kv_shared_layer_idx]
        if self.is_kv_shared_layer:
            # Last earlier layer of the same attention type owns the cache we reuse.
            self.kv_shared_layer_index = (
                len(prev_layers) - 1 - prev_layers[::-1].index(config.layer_types[layer_idx])
            )
        else:
            self.kv_shared_layer_index = None

        self.q_proj = nn.Linear(
            config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.attention_bias
        )
        # FIX: o_proj consumes the concatenated heads (num_heads * head_dim), matching
        # the HF Gemma3n reference. The previous in_features=hidden_size only worked
        # when num_heads * head_dim == hidden_size and would fail at the
        # reshape->o_proj boundary (and at checkpoint load) otherwise.
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

        self.q_norm = Gemma3nRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Gemma3nRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        # v_norm has no learned scale in Gemma3n.
        self.v_norm = Gemma3nRMSNorm(self.head_dim, eps=config.rms_norm_eps, with_scale=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        batch_size, seq_len, _ = hidden_states.shape
        query_states = self.q_proj(hidden_states).view(
            batch_size, seq_len, self.num_heads, self.head_dim
        )
        key_states = self.k_proj(hidden_states).view(
            batch_size, seq_len, self.num_kv_heads, self.head_dim
        )
        value_states = self.v_proj(hidden_states).view(
            batch_size, seq_len, self.num_kv_heads, self.head_dim
        )

        # Per-head RMSNorm on q/k/v before RoPE (Gemma3n-specific).
        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)
        value_states = self.v_norm(value_states)

        cos, sin = position_embeddings
        query_states, key_states = torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin(
            query_states,
            key_states,
            cos,
            sin,
            2,  # unsqueeze dim for bsnd layout
        )

        attn_output = torch.ops.auto_deploy.torch_attention(
            query_states,
            key_states,
            value_states,
            None,  # attn_mask
            0.0,  # dropout_p
            True,  # is_causal
            1.0,  # scale (Gemma3n uses 1.0, not head_dim**-0.5)
            None,  # sinks
            self.sliding_window,
            None,  # logit_cap
            "bsnd",
            self.layer_idx,
            self.kv_shared_layer_index if self.is_kv_shared_layer else None,
        )
        attn_output = attn_output.reshape(batch_size, seq_len, -1)
        return self.o_proj(attn_output)


class Gemma3nTextDecoderLayer(nn.Module):
    """Decoder layer operating on the AltUp stack [num_inputs, B, S, H]:
    predict -> attention + LAUREL on the active copy -> MLP -> correct, then inject
    the gated per-layer input into the non-active copies."""

    def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.attention_type = config.layer_types[layer_idx]
        self.self_attn = Gemma3nTextAttention(config, layer_idx)
        self.mlp = Gemma3nTextMLP(config, layer_idx=layer_idx)
        self.input_layernorm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.pre_feedforward_layernorm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = Gemma3nRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.act_fn = ACT2FN[config.hidden_activation]

        self.altup = Gemma3nTextAltUp(config)
        self.laurel = Gemma3nTextLaurelBlock(config)
        self.per_layer_input_gate = nn.Linear(
            config.hidden_size, config.hidden_size_per_layer_input, bias=False
        )
        self.per_layer_projection = nn.Linear(
            config.hidden_size_per_layer_input, config.hidden_size, bias=False
        )
        self.post_per_layer_input_norm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings_global: Tuple[torch.Tensor, torch.Tensor],
        position_embeddings_local: Tuple[torch.Tensor, torch.Tensor],
        per_layer_input: torch.Tensor,
    ) -> torch.Tensor:
        predictions = self.altup.predict(hidden_states)
        active_idx = getattr(self.altup, "active_idx", self.altup.config.altup_active_idx)
        active_prediction = predictions[active_idx]
        active_prediction_normed = self.input_layernorm(active_prediction)
        laurel_output = self.laurel(active_prediction_normed)

        # Sliding-window layers use the local RoPE frequencies.
        if self.self_attn.is_sliding:
            position_embeddings = position_embeddings_local
        else:
            position_embeddings = position_embeddings_global

        attn = self.self_attn(active_prediction_normed, position_embeddings)
        attn = self.post_attention_layernorm(attn)

        attn_gated = active_prediction + attn
        attn_laurel = (attn_gated + laurel_output) / math.sqrt(2.0)

        attn_norm = self.pre_feedforward_layernorm(attn_laurel)
        attn_ffw = self.mlp(attn_norm)
        attn_ffw_norm = self.post_feedforward_layernorm(attn_ffw)
        corrected_predictions = self.altup.correct(predictions, attn_laurel + attn_ffw_norm)

        first_prediction = corrected_predictions[active_idx].clone()
        if self.altup.config.altup_correct_scale:
            first_prediction = self.altup.scale_corrected_output(first_prediction)

        # Gate the per-layer input and add it to every non-active copy.
        first_prediction = self.per_layer_input_gate(first_prediction)
        first_prediction = self.act_fn(first_prediction)
        first_prediction = torch.multiply(first_prediction, per_layer_input)
        first_prediction = self.per_layer_projection(first_prediction)
        first_prediction = self.post_per_layer_input_norm(first_prediction)
        for idx in range(corrected_predictions.shape[0]):
            if idx != active_idx:
                corrected_predictions[idx] += first_prediction
        return corrected_predictions


@dataclass
class Gemma3nTextOutput(ModelOutput):
    # Final hidden states of the text backbone.
    last_hidden_state: Optional[torch.FloatTensor] = None


@dataclass
class Gemma3nCausalLMOutput(ModelOutput):
    logits: Optional[torch.FloatTensor] = None


@dataclass
class Gemma3nConditionalOutput(ModelOutput):
    logits: Optional[torch.FloatTensor] = None
class Gemma3nTextPreTrainedModel(PreTrainedModel):
    """HF PreTrainedModel base for the text-only Gemma3n stack."""

    # NOTE(review): the class header and the first two class attributes sit at the
    # tail of the previous chunk line; reproduced here verbatim.
    config_class = Gemma3nTextConfig
    base_model_prefix = "model"
    _no_split_modules = ["Gemma3nTextDecoderLayer"]
    supports_gradient_checkpointing = False

    def _init_weights(self, module: nn.Module):
        """Normal(0, initializer_range) for Linear/Embedding weights; zeros for
        biases, padding rows, and the AltUp correction scale."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, Gemma3nTextAltUp):
            module.correct_output_scale.data.zero_()


class Gemma3nPreTrainedModel(PreTrainedModel):
    """HF PreTrainedModel base for the full (multimodal-config) Gemma3n stack."""

    config_class = Gemma3nConfig
    base_model_prefix = "model"
    _no_split_modules = ["Gemma3nTextDecoderLayer"]
    supports_gradient_checkpointing = False

    def _init_weights(self, module: nn.Module):
        # The composite Gemma3nConfig may not expose initializer_range; default 0.02.
        std = getattr(self.config, "initializer_range", 0.02)
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, Gemma3nTextAltUp):
            module.correct_output_scale.data.zero_()


class Gemma3nTextModel(Gemma3nTextPreTrainedModel):
    """Gemma3n text backbone: scaled embeddings, AltUp stack projection in/out,
    per-layer inputs, dual (global/local) RoPE, and the decoder stack."""

    def __init__(self, config: Gemma3nTextConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.embed_tokens = Gemma3nTextScaledWordEmbedding(
            config.vocab_size,
            config.hidden_size,
            self.padding_idx,
            embed_scale=config.hidden_size**0.5,
        )
        self.layers = nn.ModuleList(
            [
                Gemma3nTextDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Gemma3nTextRotaryEmbedding(config)

        # Local RoPE: same config but with the local base frequency and default scaling.
        local_config = copy.deepcopy(config)
        local_config.rope_theta = local_config.rope_local_base_freq
        local_config.rope_scaling = {"rope_type": "default"}
        self.rotary_emb_local = Gemma3nTextRotaryEmbedding(local_config)

        self.hidden_size = config.hidden_size
        self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
        # One flat embedding of size num_layers * per_layer_hidden, reshaped later.
        self.embed_tokens_per_layer = Gemma3nTextScaledWordEmbedding(
            config.vocab_size_per_layer_input,
            config.num_hidden_layers * config.hidden_size_per_layer_input,
            self.padding_idx,
            embed_scale=config.hidden_size_per_layer_input**0.5,
        )
        self.per_layer_model_projection = nn.Linear(
            config.hidden_size,
            config.num_hidden_layers * config.hidden_size_per_layer_input,
            bias=False,
        )
        self.per_layer_projection_norm = Gemma3nRMSNorm(
            config.hidden_size_per_layer_input, eps=config.rms_norm_eps
        )
        # altup_num_inputs - 1 projections: copy 0 is the embedding itself.
        self.altup_projections = nn.ModuleList(
            [
                nn.Linear(config.hidden_size, config.hidden_size, bias=False)
                for _ in range(1, config.altup_num_inputs)
            ]
        )
        self.altup_unembed_projections = nn.ModuleList(
            [
                nn.Linear(config.hidden_size, config.hidden_size, bias=False)
                for _ in range(1, config.altup_num_inputs)
            ]
        )
        self.register_buffer(
            "per_layer_projection_scale", torch.tensor(config.hidden_size**-0.5), persistent=False
        )
        self.register_buffer(
            "per_layer_input_scale", torch.rsqrt(torch.tensor(2.0)), persistent=False
        )
        # Numerical floor used when renormalizing AltUp copy magnitudes.
        self.register_buffer("_ad_eps", torch.tensor(1e-5), persistent=False)
        self._register_load_state_dict_pre_hook(self._slice_reduced_layer_weights)
        self.post_init()

    def _slice_reduced_layer_weights(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Load-time hook: when a checkpoint tensor is larger than this (reduced)
        module along exactly one dimension, slice it down so it still loads.
        Applies to the per-layer embedding table and per-layer projection."""
        del local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        keys_to_params = {
            prefix + "embed_tokens_per_layer.weight": self.embed_tokens_per_layer.weight,
            prefix + "per_layer_model_projection.weight": self.per_layer_model_projection.weight,
        }
        for state_key, target_param in keys_to_params.items():
            if state_key not in state_dict:
                continue
            checkpoint_weight = state_dict[state_key]
            if checkpoint_weight.ndim != 2:
                continue
            if (
                checkpoint_weight.shape[0] == target_param.shape[0]
                and checkpoint_weight.shape[1] > target_param.shape[1]
            ):
                state_dict[state_key] = checkpoint_weight[:, : target_param.shape[1]]
            elif (
                checkpoint_weight.shape[0] > target_param.shape[0]
                and checkpoint_weight.shape[1] == target_param.shape[1]
            ):
                state_dict[state_key] = checkpoint_weight[: target_param.shape[0]]

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def get_per_layer_inputs(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """Per-layer embeddings, reshaped to (*input_ids.shape, num_layers, per_layer_hidden)."""
        return self.embed_tokens_per_layer(input_ids).reshape(
            *input_ids.shape,
            self.config.num_hidden_layers,
            self.hidden_size_per_layer_input,
        )

    def project_per_layer_inputs(
        self,
        inputs_embeds: torch.Tensor,
        per_layer_inputs: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Project token embeddings into per-layer space and (optionally) blend with
        precomputed per-layer inputs, averaged with rsqrt(2) scaling."""
        per_layer_projection = self.per_layer_model_projection(inputs_embeds)
        per_layer_projection = per_layer_projection * self.per_layer_projection_scale.to(
            dtype=inputs_embeds.dtype, device=inputs_embeds.device
        )
        per_layer_projection = per_layer_projection.reshape(
            *inputs_embeds.shape[:-1],
            self.config.num_hidden_layers,
            self.hidden_size_per_layer_input,
        )
        per_layer_projection = self.per_layer_projection_norm(per_layer_projection)

        if per_layer_inputs is None:
            return per_layer_projection

        if per_layer_projection.shape != per_layer_inputs.shape:
            # Checkpoint may carry extra trailing layers; keep the first num_layers.
            per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :]

        return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.to(
            dtype=inputs_embeds.dtype, device=inputs_embeds.device
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        per_layer_inputs: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Gemma3nTextOutput:
        """Run the decoder stack. Exactly one of input_ids / inputs_embeds is
        required; position_ids is mandatory (AutoDeploy export path)."""
        del kwargs
        assert position_ids is not None, "position_ids must be provided"
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if input_ids is not None:
            inputs_embeds = self.embed_tokens(input_ids)
            per_layer_inputs = self.get_per_layer_inputs(input_ids)

        assert inputs_embeds is not None
        per_layer_inputs = self.project_per_layer_inputs(inputs_embeds, per_layer_inputs)
        position_embeddings_global = _slice_rope_cache(
            self.rotary_emb(inputs_embeds, position_ids), position_ids
        )
        position_embeddings_local = _slice_rope_cache(
            self.rotary_emb_local(inputs_embeds, position_ids), position_ids
        )

        # Build the AltUp stack: each projected copy is renormalized to the
        # embedding's RMS magnitude (floored by _ad_eps).
        target_magnitude = torch.mean(inputs_embeds**2, dim=-1, keepdim=True) ** 0.5
        hidden_states = [inputs_embeds]
        for projection in self.altup_projections:
            current_hidden_state = projection(inputs_embeds).to(dtype=inputs_embeds.dtype)
            new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
            new_magnitude = torch.sqrt(
                torch.maximum(
                    new_magnitude,
                    self._ad_eps.to(device=inputs_embeds.device, dtype=new_magnitude.dtype),
                )
            )
            current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
            hidden_states.append(current_hidden_state)
        hidden_states = torch.stack(hidden_states, dim=0)

        for decoder_layer in self.layers:
            # Slice this layer's per-layer input: (B, S, num_layers, H_pl) -> (B, S, H_pl).
            layer_per_input = per_layer_inputs[:, :, decoder_layer.layer_idx, :]
            hidden_states = decoder_layer(
                hidden_states,
                position_embeddings_global,
                position_embeddings_local,
                layer_per_input,
            )

        # Unembed the stack: renormalize each copy to copy-0's magnitude, then average.
        target_magnitude = torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True) ** 0.5
        reduced_hidden_states = [hidden_states[0]]
        for i, projection in enumerate(self.altup_unembed_projections, start=1):
            current_hidden_state = projection(hidden_states[i]).to(dtype=inputs_embeds.dtype)
            new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
            new_magnitude = torch.sqrt(
                torch.maximum(
                    new_magnitude,
                    self._ad_eps.to(device=inputs_embeds.device, dtype=new_magnitude.dtype),
                )
            )
            current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
            reduced_hidden_states.append(current_hidden_state)

        hidden_states = torch.mean(torch.stack(reduced_hidden_states), dim=0)
        hidden_states = self.norm(hidden_states)
        return Gemma3nTextOutput(last_hidden_state=hidden_states)


class Gemma3nForCausalLM(Gemma3nTextPreTrainedModel, GenerationMixin):
    """Text-only Gemma3n with an LM head and final-logit softcapping."""

    config_class = Gemma3nTextConfig
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: Gemma3nTextConfig, **kwargs):
        del kwargs
        super().__init__(config)
        self.model = Gemma3nTextModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, value):
        self.lm_head = value

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Gemma3nCausalLMOutput:
        outputs = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )
        logits = self.lm_head(outputs.last_hidden_state)
        if self.config.final_logit_softcapping is not None:
            # Softcap: cap * tanh(logits / cap).
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping
        return Gemma3nCausalLMOutput(logits=logits)
class Gemma3nModel(Gemma3nPreTrainedModel):
    """Text-only Gemma3n backbone wrapper.

    Vision/audio towers are bare nn.Module stubs (their checkpoint weights are
    dropped at load time); only the multimodal embedders are kept so embedding
    checkpoints still load. Multimodal inputs raise NotImplementedError."""

    # NOTE(review): the class header through `self.vocab_size_per_layer_input =`
    # sits at the tail of the previous chunk line; reproduced verbatim here.
    def __init__(self, config: Gemma3nConfig):
        super().__init__(config)
        self.vocab_size_per_layer_input = config.text_config.vocab_size_per_layer_input
        self.vision_tower = nn.Module()
        self.language_model = Gemma3nTextModel(config.text_config)
        self.audio_tower = nn.Module()
        self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config, config.text_config)
        self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config, config.text_config)
        self._register_load_state_dict_pre_hook(self._drop_unsupported_multimodal_tower_weights)
        self.post_init()

    @staticmethod
    def _drop_unsupported_multimodal_tower_weights(
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Load-time hook: drop vision/audio tower weights since the towers are
        stubs; leaving them would surface as unexpected-key errors."""
        del local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        unsupported_prefixes = (
            prefix + "vision_tower.",
            prefix + "audio_tower.",
        )
        for key in list(state_dict):
            if key.startswith(unsupported_prefixes):
                state_dict.pop(key)

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def set_decoder(self, decoder):
        self.language_model = decoder

    def get_decoder(self):
        return self.language_model

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        input_features: Optional[torch.Tensor] = None,
        input_features_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Gemma3nTextOutput:
        del kwargs
        del input_features_mask
        assert position_ids is not None, "position_ids must be provided"
        if pixel_values is not None or input_features is not None:
            raise NotImplementedError(
                "Gemma3n multimodal inputs are not supported by the current AutoDeploy export path. "
                "Use text-only prompts for this onboarding."
            )
        # FIX: use the XOR-style check used everywhere else in this file
        # (Gemma3nTextModel.forward, Gemma3nMultimodalEmbedder.forward). The old
        # `input_ids is not None and inputs_embeds is not None` condition did not
        # match its own "exactly one" message and let the both-None case fall
        # through; that case now raises the same ValueError here instead of
        # deeper inside the text model.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        per_layer_inputs = None
        if input_ids is not None:
            inputs_embeds = self.get_input_embeddings()(input_ids)
            # Tokens outside the per-layer vocab (e.g. multimodal offsets) map to row 0.
            per_layer_inputs_mask = torch.logical_and(
                input_ids >= 0, input_ids < self.vocab_size_per_layer_input
            )
            per_layer_inputs_tokens = torch.where(
                per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids)
            )
            per_layer_inputs = self.language_model.get_per_layer_inputs(per_layer_inputs_tokens)

        return self.language_model(
            input_ids=None,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            per_layer_inputs=per_layer_inputs,
        )


class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
    """Gemma3n conditional-generation head over Gemma3nModel with logit softcapping."""

    config_class = Gemma3nConfig
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: Gemma3nConfig, **kwargs):
        del kwargs
        super().__init__(config)
        self.model = Gemma3nModel(config)
        self.lm_head = nn.Linear(
            config.text_config.hidden_size, config.text_config.vocab_size, bias=False
        )
        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, value):
        self.lm_head = value

    def get_decoder(self):
        return self.model.get_decoder()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        input_features: Optional[torch.Tensor] = None,
        input_features_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Gemma3nConditionalOutput:
        outputs = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            input_features=input_features,
            input_features_mask=input_features_mask,
            **kwargs,
        )
        logits = self.lm_head(outputs.last_hidden_state)
        if self.config.text_config.final_logit_softcapping is not None:
            # Softcap: cap * tanh(logits / cap).
            logits = logits / self.config.text_config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.text_config.final_logit_softcapping
        return Gemma3nConditionalOutput(logits=logits)
input_features=input_features, + input_features_mask=input_features_mask, + **kwargs, + ) + logits = self.lm_head(outputs.last_hidden_state) + if self.config.text_config.final_logit_softcapping is not None: + logits = logits / self.config.text_config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * self.config.text_config.final_logit_softcapping + return Gemma3nConditionalOutput(logits=logits) + + +AutoModelForCausalLMFactory.register_custom_model_cls("Gemma3nTextConfig", Gemma3nForCausalLM) +AutoModelForCausalLMFactory.register_custom_model_cls( + "Gemma3nConfig", Gemma3nForConditionalGeneration +) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py index 4bef54528b6..3f4d5122584 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py @@ -35,7 +35,7 @@ from ...shim.interface import CachedSequenceInterface from ...utils._graph import add_graph_input from ...utils.cuda_mem_tracker import get_mem_info -from ...utils.node_utils import is_op +from ...utils.node_utils import get_op_schema, is_op from ..interface import ( BaseTransform, SharedConfig, @@ -108,7 +108,7 @@ def _process_metadata_extra( # check what inputs the extra metadata op expects inputs_for_prep_meta = [ self._add_or_retrieve_input(gm, cm, arg.name) - for arg in prep_meta_op._schema.arguments + for arg in get_op_schema(prep_meta_op).arguments if arg.name in cm.info.available_args ] @@ -137,6 +137,7 @@ def _insert_cached_attn_node( self, gm: GraphModule, attn_node: Node, + cached_attn_op, qkv_nodes: List[Node], meta_nodes_std: List[Node], meta_nodes_extra: List[Node], @@ -146,7 +147,7 @@ def _insert_cached_attn_node( """Insert a cached attention node into the graph.""" with gm.graph.inserting_before(attn_node): cached_attn_node = gm.graph.call_function( - self.attn_descriptor.get_cached_attention_op(), + 
cached_attn_op, args=( *qkv_nodes, *meta_nodes_std, @@ -168,10 +169,8 @@ def _apply( """Replace uncached source attention node with corresponding cached attn node.""" attn_descriptor = self.attn_descriptor - # Get all attention nodes and their info objects - source_op = attn_descriptor.get_source_attention_op() - # look for relevant source attention nodes + source_op = attn_descriptor.get_source_attention_op() source_attn_nodes = [n for n in gm.graph.nodes if is_op(n, source_op)] if not source_attn_nodes: @@ -191,16 +190,46 @@ def _apply( # replace fused attention node with attention node that has kv cache num_cached_attn_replacements = 0 - for attn_node in source_attn_nodes: + cache_nodes_by_layer_idx = {} + for idx, attn_node in enumerate(source_attn_nodes): # pick out GEMMs qkv = attn_node.args[: attn_descriptor.get_num_qkv_args()] - # setup + store cache resource handlers and caches as input nodes - resources_dict = attn_descriptor.get_cache_initializers(attn_node, cm.kv_cache_config) - cache_in_nodes = [ - self._process_cache_node(gm, cm.add_resource(k, resource_handler)) - for k, resource_handler in resources_dict.items() - ] + layer_idx = attn_descriptor.get_layer_idx(attn_node) + shared_kv_source_layer_idx = attn_descriptor.get_shared_kv_source_layer_idx(attn_node) + + if shared_kv_source_layer_idx is not None: + if not attn_descriptor.supports_shared_kv(): + raise RuntimeError( + f"Backend '{self.config.backend}' does not support shared-KV attention." + ) + if layer_idx is None: + raise RuntimeError( + "Shared-KV attention node is missing layer_idx metadata required for " + "cache aliasing." + ) + if shared_kv_source_layer_idx == layer_idx: + raise RuntimeError(f"Layer {layer_idx} cannot share its own KV cache.") + if shared_kv_source_layer_idx not in cache_nodes_by_layer_idx: + raise RuntimeError( + f"Missing shared-KV source layer {shared_kv_source_layer_idx}." 
+ ) + cache_in_nodes = cache_nodes_by_layer_idx[shared_kv_source_layer_idx] + else: + # setup + store cache initializers and caches as input nodes + if layer_idx is not None and layer_idx in cache_nodes_by_layer_idx: + raise RuntimeError( + f"Duplicate KV cache owner detected for layer {layer_idx}. " + "Each non-shared attention layer must own exactly one cache." + ) + cache_in_nodes = [] + for k, resource_handler in attn_descriptor.get_cache_initializers( + attn_node, cm.kv_cache_config + ).items(): + resource_name = cm.add_resource(k, resource_handler) + cache_in_nodes.append(self._process_cache_node(gm, resource_name)) + if layer_idx is not None: + cache_nodes_by_layer_idx[layer_idx] = cache_in_nodes # allow backend-specific prep before constants are extracted attn_descriptor.prepare_node_for_cache_insertion(gm, attn_node) @@ -212,6 +241,7 @@ def _apply( self._insert_cached_attn_node( gm, attn_node, + attn_descriptor.get_cached_attention_op(), qkv, meta_nodes_std, meta_nodes_extra, diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py index 8eeacbd6685..54b5650738e 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py @@ -159,7 +159,10 @@ def cached_attn( elif attention_layout != "bnsd": raise ValueError(f"Unsupported attention layout: {attention_layout}") - attn_output = attn_descriptor.get_cached_attention_op()( + cached_attn_op = module._node_ref.meta.get( + "cached_attn_op", attn_descriptor.get_cached_attention_op() + ) + attn_output = cached_attn_op( query, key, value, @@ -238,6 +241,7 @@ def _insert_cached_attn_node( self, gm: GraphModule, attn_node: Node, + cached_attn_op, qkv_nodes: List[Node], meta_nodes_std: List[Node], meta_nodes_extra: List[Node], @@ -246,6 +250,7 @@ def _insert_cached_attn_node( ): """Here we now need to actually 
do the correct mapping of the cached attn nodes.""" # store reference to metadata, caches, and constants for this attn node + attn_node.meta["cached_attn_op"] = cached_attn_op attn_node.meta["metadata_cache_keys"] = (*meta_nodes_std, *meta_nodes_extra, *cache_nodes) attn_node.meta["constants"] = constants diff --git a/tensorrt_llm/_torch/auto_deploy/utils/_graph.py b/tensorrt_llm/_torch/auto_deploy/utils/_graph.py index d54c86d0c37..6ba653cec9b 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/_graph.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/_graph.py @@ -18,7 +18,7 @@ from torch.utils._pytree import _LEAF_SPEC, TreeSpec from .logger import ad_logger -from .node_utils import get_weight_tensor, is_op +from .node_utils import get_op_schema, get_weight_tensor, is_op # --------------------------------------------------------------------------- # Dynamic custom-op derivation helpers @@ -72,7 +72,7 @@ def create_derived_custom_op( the same *base_op* and *suffix* return the cached op. """ base_overload = base_op.default if hasattr(base_op, "default") else base_op - schema = base_overload._schema + schema = get_op_schema(base_overload) # e.g. 
"auto_deploy::trtllm_moe_fused" qualified_name = schema.name diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index 71f4bab4e52..a64800782db 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -1049,18 +1049,22 @@ def extract_output_tuple(node: Node, count: int = 2): return results -def _get_op_schema(node: Node): - """Return the op schema for a call_function node.""" - if node.op != "call_function": - raise ValueError(f"_get_op_schema only supports call_function nodes, got {node.op}") - op = node.target +def get_op_schema(op) -> torch.FunctionSchema: + """Return the schema for an op or op overload packet.""" if hasattr(op, "_schemas"): return next(iter(op._schemas.values())) - elif hasattr(op, "_schema"): + if hasattr(op, "_schema"): return op._schema raise RuntimeError(f"No schema found on op {op}") +def _get_op_schema(node: Node): + """Return the op schema for a call_function node.""" + if node.op != "call_function": + raise ValueError(f"_get_op_schema only supports call_function nodes, got {node.op}") + return get_op_schema(node.target) + + def extract_op_args(node: Node, *arg_names): """ Given a call_function node for torch custom op, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma3n_modeling.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma3n_modeling.py new file mode 100644 index 00000000000..847c84dcf68 --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma3n_modeling.py @@ -0,0 +1,474 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import copy +from typing import Tuple + +import pytest +import torch +from torch.export import Dim + +import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 +from tensorrt_llm._torch.auto_deploy.custom_ops.attention.torch_backend_attention import ( + TorchBackendAttention, +) +from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.models.custom.modeling_gemma3n import ( + Gemma3nAudioConfig, + Gemma3nConditionalOutput, + Gemma3nConfig, + Gemma3nForCausalLM, + Gemma3nForConditionalGeneration, + Gemma3nTextAttention, + Gemma3nTextConfig, + Gemma3nTextDecoderLayer, + Gemma3nTextMLP, + Gemma3nVisionConfig, +) +from tensorrt_llm._torch.auto_deploy.utils._graph import move_to_device + + +def assert_rmse_close( + actual: torch.Tensor, + expected: torch.Tensor, + rmse_ratio_tol: float, + msg: str = "", +) -> None: + diff = actual.float() - expected.float() + rmse_diff = torch.sqrt(torch.mean(diff**2)) + rmse_ref = torch.sqrt(torch.mean(expected.float() ** 2)) + ratio = (rmse_diff / rmse_ref).item() + assert ratio < rmse_ratio_tol, ( + f"{msg}RMSE ratio {ratio:.6f} exceeds tolerance {rmse_ratio_tol}. 
" + f"(rmse_diff={rmse_diff.item():.6f}, rmse_ref={rmse_ref.item():.6f})" + ) + + +def _get_hf_classes(): + try: + from transformers.models.gemma3n.modeling_gemma3n import ( + Gemma3nForCausalLM as HFGemma3nForCausalLM, + ) + from transformers.models.gemma3n.modeling_gemma3n import ( + Gemma3nTextAttention as HFGemma3nTextAttention, + ) + from transformers.models.gemma3n.modeling_gemma3n import ( + Gemma3nTextDecoderLayer as HFGemma3nTextDecoderLayer, + ) + from transformers.models.gemma3n.modeling_gemma3n import Gemma3nTextMLP as HFGemma3nTextMLP + except ImportError: + return None + return HFGemma3nForCausalLM, HFGemma3nTextAttention, HFGemma3nTextDecoderLayer, HFGemma3nTextMLP + + +HF_CLASSES = _get_hf_classes() + + +def _device_and_dtype() -> Tuple[str, torch.dtype]: + if torch.cuda.is_available(): + return "cuda", torch.bfloat16 + return "cpu", torch.float32 + + +def _small_text_config() -> Gemma3nTextConfig: + config = Gemma3nTextConfig( + vocab_size=256, + vocab_size_per_layer_input=256, + hidden_size=64, + hidden_size_per_layer_input=8, + intermediate_size=[128, 128, 128], + num_hidden_layers=3, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=16, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=64, + rms_norm_eps=1e-6, + rope_theta=10000.0, + rope_local_base_freq=1000.0, + attention_bias=False, + attention_dropout=0.0, + sliding_window=16, + layer_types=["sliding_attention", "sliding_attention", "full_attention"], + final_logit_softcapping=30.0, + altup_active_idx=0, + altup_correct_scale=True, + altup_num_inputs=3, + num_kv_shared_layers=0, + laurel_rank=8, + activation_sparsity_pattern=[0.5, 0.0, 0.0], + pad_token_id=0, + eos_token_id=1, + bos_token_id=2, + ) + config._attn_implementation = "eager" + return config + + +def _small_full_config() -> Gemma3nConfig: + return Gemma3nConfig( + text_config=_small_text_config(), + vision_config=Gemma3nVisionConfig( + hidden_size=32, + vocab_size=8, + vocab_offset=256, + 
rms_norm_eps=1e-6, + ), + audio_config=Gemma3nAudioConfig( + vocab_size=8, + vocab_offset=264, + hidden_size=32, + rms_norm_eps=1e-6, + conf_num_attention_heads=4, + conf_num_hidden_layers=2, + sscp_conv_channel_size=(16, 8), + ), + ) + + +def _extended_text_config(num_hidden_layers: int) -> Gemma3nTextConfig: + config = copy.deepcopy(_small_text_config()) + config.num_hidden_layers = num_hidden_layers + config.intermediate_size = [128] * num_hidden_layers + config.layer_types = ["sliding_attention"] * (num_hidden_layers - 1) + ["full_attention"] + config.activation_sparsity_pattern = [0.0] * num_hidden_layers + return config + + +def _shared_kv_text_config() -> Gemma3nTextConfig: + config = Gemma3nTextConfig( + vocab_size=256, + vocab_size_per_layer_input=256, + hidden_size=64, + hidden_size_per_layer_input=8, + intermediate_size=[128] * 6, + num_hidden_layers=6, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=16, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=64, + rms_norm_eps=1e-6, + rope_theta=10000.0, + rope_local_base_freq=1000.0, + attention_bias=False, + attention_dropout=0.0, + sliding_window=16, + layer_types=[ + "sliding_attention", + "full_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "full_attention", + ], + final_logit_softcapping=30.0, + altup_active_idx=0, + altup_correct_scale=True, + altup_num_inputs=3, + num_kv_shared_layers=2, + laurel_rank=8, + activation_sparsity_pattern=[0.0] * 6, + pad_token_id=0, + eos_token_id=1, + bos_token_id=2, + ) + config._attn_implementation = "eager" + return config + + +def _position_ids(batch_size: int, seq_len: int, device: str) -> torch.Tensor: + return torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1) + + +def _load_equivalent_modules(custom_module: torch.nn.Module, hf_module: torch.nn.Module) -> None: + missing, unexpected = custom_module.load_state_dict(hf_module.state_dict(), strict=False) + assert not missing, 
f"Missing keys when loading HF weights into custom module: {missing}" + assert not unexpected, ( + f"Unexpected keys when loading HF weights into custom module: {unexpected}" + ) + + +@pytest.fixture(autouse=True) +def _set_seed(): + torch.manual_seed(42) + + +def test_hf_reference_available(): + if HF_CLASSES is None: + pytest.skip("transformers gemma3n reference classes are unavailable") + hf_model_cls, hf_attention_cls, hf_layer_cls, hf_mlp_cls = HF_CLASSES + assert hf_model_cls.__name__ == "Gemma3nForCausalLM" + assert hf_attention_cls.__name__ == "Gemma3nTextAttention" + assert hf_layer_cls.__name__ == "Gemma3nTextDecoderLayer" + assert hf_mlp_cls.__name__ == "Gemma3nTextMLP" + + +@torch.no_grad() +def test_gemma3n_mlp_equivalence(): + if HF_CLASSES is None: + pytest.skip("transformers gemma3n reference classes are unavailable") + + _, _, _, hf_mlp_cls = HF_CLASSES + device, dtype = _device_and_dtype() + config = _small_text_config() + custom_mlp = Gemma3nTextMLP(config, layer_idx=0).to(device=device, dtype=dtype) + hf_mlp = hf_mlp_cls(config, layer_idx=0).to(device=device, dtype=dtype) + _load_equivalent_modules(custom_mlp, hf_mlp) + + hidden_states = torch.randn(2, 6, config.hidden_size, device=device, dtype=dtype) + custom_out = custom_mlp(hidden_states) + hf_out = hf_mlp(hidden_states) + torch.testing.assert_close(custom_out.float(), hf_out.float(), rtol=1e-3, atol=1e-3) + + +@torch.no_grad() +def test_gemma3n_attention_equivalence(): + if HF_CLASSES is None: + pytest.skip("transformers gemma3n reference classes are unavailable") + + _, hf_attention_cls, _, _ = HF_CLASSES + device, dtype = _device_and_dtype() + config = _small_text_config() + custom_attn = Gemma3nTextAttention(config, layer_idx=2).to(device=device, dtype=dtype) + hf_attn = hf_attention_cls(config, layer_idx=2).to(device=device, dtype=dtype) + _load_equivalent_modules(custom_attn, hf_attn) + + hidden_states = torch.randn(2, 6, config.hidden_size, device=device, dtype=dtype) + position_ids = 
_position_ids(2, 6, device) + custom_rope = Gemma3nForCausalLM(config).model.rotary_emb.to(device=device) + full_cos, full_sin = custom_rope(hidden_states, position_ids) + position_embeddings = (full_cos[position_ids], full_sin[position_ids]) + + custom_out = custom_attn(hidden_states, position_embeddings) + hf_out = hf_attn(hidden_states, position_embeddings, attention_mask=None)[0] + assert_rmse_close(custom_out[:, -1:], hf_out[:, -1:], rmse_ratio_tol=0.10, msg="Attention: ") + + +@torch.no_grad() +def test_gemma3n_decoder_layer_equivalence(): + if HF_CLASSES is None: + pytest.skip("transformers gemma3n reference classes are unavailable") + + _, _, hf_layer_cls, _ = HF_CLASSES + device, dtype = _device_and_dtype() + config = _small_text_config() + custom_layer = Gemma3nTextDecoderLayer(config, layer_idx=2).to(device=device, dtype=dtype) + hf_layer = hf_layer_cls(config, layer_idx=2).to(device=device, dtype=dtype) + _load_equivalent_modules(custom_layer, hf_layer) + + batch_size, seq_len = 2, 1 + hidden_states = torch.randn( + config.altup_num_inputs, batch_size, seq_len, config.hidden_size, device=device, dtype=dtype + ) + per_layer_input = torch.randn( + batch_size, seq_len, config.hidden_size_per_layer_input, device=device, dtype=dtype + ) + position_ids = _position_ids(batch_size, seq_len, device) + rope_model = Gemma3nForCausalLM(config).model.to(device=device) + global_cos, global_sin = rope_model.rotary_emb(hidden_states[0], position_ids) + local_cos, local_sin = rope_model.rotary_emb_local(hidden_states[0], position_ids) + position_embeddings_global = (global_cos[position_ids], global_sin[position_ids]) + position_embeddings_local = (local_cos[position_ids], local_sin[position_ids]) + + custom_out = custom_layer( + hidden_states, + position_embeddings_global, + position_embeddings_local, + per_layer_input, + ) + hf_out = hf_layer( + hidden_states, + position_embeddings_global, + position_embeddings_local, + per_layer_input, + attention_mask=None, + 
position_ids=position_ids, + )[0] + assert_rmse_close(custom_out, hf_out, rmse_ratio_tol=0.05, msg="Decoder layer: ") + + +@torch.no_grad() +def test_gemma3n_full_model_equivalence(): + if HF_CLASSES is None: + pytest.skip("transformers gemma3n reference classes are unavailable") + + hf_model_cls, _, _, _ = HF_CLASSES + device, dtype = "cpu", torch.float32 + config = _small_text_config() + custom_model = Gemma3nForCausalLM(config).to(device=device, dtype=dtype) + hf_model = hf_model_cls(config).to(device=device, dtype=dtype) + _load_equivalent_modules(custom_model, hf_model) + custom_model.eval() + hf_model.eval() + + input_ids = torch.randint(0, config.vocab_size, (2, 6), device=device) + position_ids = _position_ids(2, 6, device) + custom_out = custom_model(input_ids=input_ids, position_ids=position_ids) + hf_out = hf_model(input_ids=input_ids, position_ids=position_ids) + assert_rmse_close(custom_out.logits, hf_out.logits, rmse_ratio_tol=0.05, msg="Full model: ") + + +@torch.no_grad() +def test_gemma3n_conditional_wrapper_equivalence(): + if HF_CLASSES is None: + pytest.skip("transformers gemma3n reference classes are unavailable") + + hf_model_cls, _, _, _ = HF_CLASSES + device, dtype = "cpu", torch.float32 + config = _small_full_config() + wrapper = Gemma3nForConditionalGeneration(config).to(device=device, dtype=dtype) + hf_model = hf_model_cls(config.text_config).to(device=device, dtype=dtype) + _load_equivalent_modules(wrapper.model.language_model, hf_model.model) + _load_equivalent_modules(wrapper.lm_head, hf_model.lm_head) + wrapper.eval() + hf_model.eval() + + input_ids = torch.randint( + 0, config.text_config.vocab_size_per_layer_input, (2, 6), device=device + ) + position_ids = _position_ids(2, 6, device) + wrapper_out = wrapper(input_ids=input_ids, position_ids=position_ids) + hf_out = hf_model(input_ids=input_ids, position_ids=position_ids) + assert isinstance(wrapper_out, Gemma3nConditionalOutput) + assert_rmse_close(wrapper_out.logits, 
hf_out.logits, rmse_ratio_tol=0.05, msg="Wrapper: ") + + +def test_gemma3n_conditional_wrapper_load_hook_drops_unsupported_tower_weights(): + config = _small_full_config() + wrapper = Gemma3nForConditionalGeneration(config) + state_dict = wrapper.state_dict() + state_dict["model.vision_tower.fake.weight"] = torch.randn(2, 2) + state_dict["model.audio_tower.fake.weight"] = torch.randn(2, 2) + + missing, unexpected = wrapper.load_state_dict(state_dict, strict=True) + + assert missing == [] + assert unexpected == [] + + +def test_gemma3n_conditional_wrapper_ignores_hf_init_kwargs(): + config = _small_full_config() + wrapper = Gemma3nForConditionalGeneration(config, use_cache=False) + assert isinstance(wrapper, Gemma3nForConditionalGeneration) + + +def test_gemma3n_reduced_layer_load_hook_slices_per_layer_weights(): + source_model = Gemma3nForCausalLM(_extended_text_config(5)) + target_model = Gemma3nForCausalLM(_small_text_config()) + + missing, unexpected = target_model.load_state_dict(source_model.state_dict(), strict=False) + + assert missing == [] + assert "model.layers.3.self_attn.q_proj.weight" in unexpected + + +def test_gemma3n_causal_lm_ties_lm_head_to_input_embeddings(): + model = Gemma3nForCausalLM(_small_text_config()) + assert model.lm_head.weight.data_ptr() == model.model.embed_tokens.weight.data_ptr() + + +def test_gemma3n_conditional_lm_ties_lm_head_to_input_embeddings(): + model = Gemma3nForConditionalGeneration(_small_full_config()) + assert ( + model.lm_head.weight.data_ptr() == model.model.language_model.embed_tokens.weight.data_ptr() + ) + + +def test_gemma3n_shared_kv_layer_metadata_matches_config(): + model = Gemma3nForCausalLM(_shared_kv_text_config()) + layer_expectations = [ + (False, None), + (False, None), + (False, None), + (False, None), + (True, 2), + (True, 3), + ] + + for layer, (is_shared, source_idx) in zip(model.model.layers, layer_expectations, strict=True): + assert layer.self_attn.is_kv_shared_layer is is_shared + assert 
layer.self_attn.kv_shared_layer_index == source_idx + + +def test_gemma3n_export_uses_shared_kv_attention_for_shared_layers(): + config = _shared_kv_text_config() + model = Gemma3nForCausalLM(config).eval() + input_ids = torch.randint(0, config.vocab_size, (1, 4)) + position_ids = _position_ids(1, 4, "cpu") + + gm = torch_export_to_gm( + model, + args=tuple(), + kwargs={"input_ids": input_ids, "position_ids": position_ids}, + ) + + attn_nodes = [node for node in gm.graph.nodes if node.op == "call_function"] + attn_nodes = [ + node for node in attn_nodes if node.target == torch.ops.auto_deploy.torch_attention.default + ] + regular_nodes = [ + node + for node in attn_nodes + if TorchBackendAttention.get_shared_kv_source_layer_idx(node) is None + ] + shared_nodes = [ + node + for node in attn_nodes + if TorchBackendAttention.get_shared_kv_source_layer_idx(node) is not None + ] + + assert len(attn_nodes) == config.num_hidden_layers + assert len(regular_nodes) == config.num_hidden_layers - config.num_kv_shared_layers + assert len(shared_nodes) == config.num_kv_shared_layers + assert [TorchBackendAttention.get_layer_idx(regular) for regular in regular_nodes] == [ + 0, + 1, + 2, + 3, + ] + assert [TorchBackendAttention.get_layer_idx(shared) for shared in shared_nodes] == [4, 5] + assert [ + TorchBackendAttention.get_shared_kv_source_layer_idx(shared) for shared in shared_nodes + ] == [2, 3] + + +def test_gemma3n_model_can_be_exported(): + if not torch.cuda.is_available(): + pytest.skip("Export test requires CUDA") + + device = "cuda" + dtype = torch.bfloat16 + config = _small_full_config() + model = Gemma3nForConditionalGeneration(config).to(device=device, dtype=dtype) + model.eval() + + input_ids = torch.randint(0, config.text_config.vocab_size, (2, 8), device=device) + position_ids = _position_ids(2, 8, device) + + gm = torch_export_to_gm( + model, + args=tuple(), + kwargs={"input_ids": input_ids, "position_ids": position_ids}, + dynamic_shapes=( + {0: Dim.DYNAMIC, 1: 
Dim.DYNAMIC}, + {0: Dim.DYNAMIC, 1: Dim.DYNAMIC}, + ), + ) + move_to_device(gm, device) + + with torch.inference_mode(): + eager_out = model(input_ids=input_ids, position_ids=position_ids) + export_out = gm(input_ids=input_ids, position_ids=position_ids) + + assert "logits" in export_out + assert_rmse_close(export_out["logits"], eager_out.logits, rmse_ratio_tol=0.05, msg="Export: ") + + input_ids_2 = torch.randint(0, config.text_config.vocab_size, (1, 5), device=device) + position_ids_2 = _position_ids(1, 5, device) + with torch.inference_mode(): + export_out_2 = gm(input_ids=input_ids_2, position_ids=position_ids_2) + eager_out_2 = model(input_ids=input_ids_2, position_ids=position_ids_2) + assert_rmse_close( + export_out_2["logits"], eager_out_2.logits, rmse_ratio_tol=0.05, msg="Export dynamic: " + ) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_shared_kv_attention.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_shared_kv_attention.py new file mode 100644 index 00000000000..4a65cc2de80 --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_shared_kv_attention.py @@ -0,0 +1,660 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 +from tensorrt_llm._torch.auto_deploy.compile.piecewise_utils import is_dynamic_cached_op +from tensorrt_llm._torch.auto_deploy.custom_ops.attention.flashinfer_attention import ( + FlashInferAttention, +) +from tensorrt_llm._torch.auto_deploy.custom_ops.attention.torch_backend_attention import ( + TorchBackendAttention, +) +from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import BatchInfo +from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface +from tensorrt_llm._torch.auto_deploy.transform.interface import SharedConfig, Stages +from tensorrt_llm._torch.auto_deploy.transform.library.kvcache import ( + InsertCachedAttentionConfig, + _InsertCachedOperator, +) + + +class _TinySharedKVModule(torch.nn.Module): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + qkv = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], 2, 4) + regular = torch.ops.auto_deploy.torch_attention( + qkv, + qkv, + qkv, + None, + 0.0, + True, + 1.0, + None, + None, + None, + "bsnd", + 0, + ) + shared = torch.ops.auto_deploy.torch_attention( + qkv, + qkv, + qkv, + None, + 0.0, + True, + 1.0, + None, + None, + None, + "bsnd", + 1, + 0, + ) + return regular + shared + + +class _DuplicateLayerOwnerSharedKVModule(torch.nn.Module): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + qkv = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], 2, 4) + first = torch.ops.auto_deploy.torch_attention( + qkv, qkv, qkv, None, 0.0, True, 1.0, None, None, None, "bsnd", 0 + ) + second = torch.ops.auto_deploy.torch_attention( + qkv, qkv, qkv, None, 0.0, True, 1.0, None, None, None, "bsnd", 0 + ) + return first + second + + +def _context_meta(seq_len: int): + batch_info_host = BatchInfo() + 
batch_info_host.update([1, seq_len, 0, 0, 0, 0]) + return ( + batch_info_host.serialize(), + torch.tensor([seq_len], dtype=torch.int32), + torch.tensor([0], dtype=torch.int32), + torch.tensor([0], dtype=torch.int64), + torch.tensor([0], dtype=torch.int32), + ) + + +def _decode_meta(input_pos: int): + batch_info_host = BatchInfo() + batch_info_host.update([0, 0, 0, 0, 1, 1]) + return ( + batch_info_host.serialize(), + torch.tensor([1], dtype=torch.int32), + torch.tensor([input_pos], dtype=torch.int32), + torch.tensor([0], dtype=torch.int64), + torch.tensor([0], dtype=torch.int32), + ) + + +def _manual_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sliding_window: int | None = None, +) -> torch.Tensor: + batch, seq_len_q, num_heads, _ = q.shape + _, seq_len_k, num_kv_heads, _ = k.shape + if num_heads != num_kv_heads: + repeat_factor = num_heads // num_kv_heads + k = k.repeat_interleave(repeat_factor, dim=2) + v = v.repeat_interleave(repeat_factor, dim=2) + + q_t = q.transpose(1, 2) + k_t = k.transpose(1, 2) + v_t = v.transpose(1, 2) + scores = torch.matmul(q_t, k_t.transpose(-2, -1)) + causal_mask = torch.triu( + torch.ones(seq_len_q, seq_len_k, dtype=torch.bool, device=scores.device), + diagonal=seq_len_k - seq_len_q + 1, + ) + scores = scores.masked_fill(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf")) + if sliding_window is not None: + query_positions = torch.arange(seq_len_k - seq_len_q, seq_len_k, device=scores.device) + key_positions = torch.arange(seq_len_k, device=scores.device) + pos_diff = query_positions.unsqueeze(1) - key_positions.unsqueeze(0) + sliding_window_mask = (pos_diff < 0) | (pos_diff >= sliding_window) + scores = scores.masked_fill(sliding_window_mask.unsqueeze(0).unsqueeze(0), float("-inf")) + weights = torch.softmax(scores, dim=-1) + return torch.matmul(weights, v_t).transpose(1, 2) + + +def _make_layer_inputs(offset: float, seq_len: int, decode: bool = False): + base_q = torch.tensor( + [[[1.0, 0.0], [0.0, 1.0]]] 
+ if decode + else [ + [[1.0, 0.0], [0.0, 1.0]], + [[0.5, 0.5], [0.5, -0.5]], + [[0.25, 0.75], [0.75, 0.25]], + ], + dtype=torch.float32, + ) + base_k = torch.tensor( + [[[1.0, 0.0]]] if decode else [[[1.0, 0.0]], [[0.0, 1.0]], [[1.0, 1.0]]], + dtype=torch.float32, + ) + base_v = torch.tensor( + [[[10.0, 1.0]]] if decode else [[[10.0, 1.0]], [[2.0, 20.0]], [[30.0, 3.0]]], + dtype=torch.float32, + ) + q = (base_q + offset).unsqueeze(0) + k = (base_k + offset).unsqueeze(0) + v = (base_v + offset * 10.0).unsqueeze(0) + assert q.shape[1] == seq_len + return q, k, v + + +def test_shared_kv_transform_aliases_source_cache_placeholders(): + module = _TinySharedKVModule().eval() + gm = torch_export_to_gm(module, (torch.randn(1, 4, 8),)) + + cm = CachedSequenceInterface( + max_seq_len=16, + max_batch_size=2, + max_num_tokens=16, + device="cpu", + ) + transform = _InsertCachedOperator( + InsertCachedAttentionConfig(stage=Stages.CACHE_INIT, backend="torch") + ) + gm, info = transform._apply(gm, cm, factory=None, shared_config=SharedConfig()) + + assert info.num_matches == 2 + + placeholder_names = [node.target for node in gm.graph.nodes if node.op == "placeholder"] + assert placeholder_names.count("r0_k_cache") == 1 + assert placeholder_names.count("r1_v_cache") == 1 + assert "r2_k_cache" not in placeholder_names + assert "r3_v_cache" not in placeholder_names + assert set(cm._resource_lookup).issubset(set(placeholder_names)) + + cached_nodes = [node for node in gm.graph.nodes if node.op == "call_function"] + regular_node = next( + node + for node in cached_nodes + if node.target == torch.ops.auto_deploy.torch_cached_attention_with_cache.default + and node.args[-1] is False + ) + shared_node = next( + node + for node in cached_nodes + if node.target == torch.ops.auto_deploy.torch_cached_attention_with_cache.default + and node.args[-1] is True + ) + + assert regular_node.args[8] is shared_node.args[8] + assert regular_node.args[9] is shared_node.args[9] + assert 
regular_node.target == torch.ops.auto_deploy.torch_cached_attention_with_cache.default + assert shared_node.target == torch.ops.auto_deploy.torch_cached_attention_with_cache.default + assert regular_node.args[-1] is False + assert shared_node.args[-1] is True + + +def test_shared_kv_cached_attention_reads_without_writing(): + q = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]], dtype=torch.float32) + dummy_k = torch.full((1, 1, 2, 2), 123.0, dtype=torch.float32) + dummy_v = torch.full((1, 1, 2, 2), -456.0, dtype=torch.float32) + + k_cache = torch.tensor( + [[[[1.0, 0.0], [0.0, 1.0]], [[0.5, 0.0], [0.0, 0.5]], [[0.25, 0.0], [0.0, 0.25]]]], + dtype=torch.float32, + ) + v_cache = torch.tensor( + [[[[10.0, 1.0], [2.0, 20.0]], [[30.0, 3.0], [4.0, 40.0]], [[50.0, 5.0], [6.0, 60.0]]]], + dtype=torch.float32, + ) + k_cache_before = k_cache.clone() + v_cache_before = v_cache.clone() + + batch_info_host = BatchInfo() + batch_info_host.update([0, 0, 0, 0, 1, 1]) + output = torch.ops.auto_deploy.torch_cached_attention_with_cache( + q, + dummy_k, + dummy_v, + batch_info_host.serialize(), + torch.tensor([1], dtype=torch.int32), + torch.tensor([2], dtype=torch.int32), + torch.tensor([0], dtype=torch.int64), + torch.tensor([0], dtype=torch.int32), + k_cache, + v_cache, + 1.0, + None, + None, + None, + True, + ) + + assert torch.equal(k_cache, k_cache_before) + assert torch.equal(v_cache, v_cache_before) + + k_for_attn = k_cache_before[0, :3].transpose(0, 1) + v_for_attn = v_cache_before[0, :3].transpose(0, 1) + logits = torch.matmul(q[0, 0].unsqueeze(1), k_for_attn.transpose(-2, -1)) + weights = torch.softmax(logits, dim=-1) + expected = torch.matmul(weights, v_for_attn).squeeze(1).unsqueeze(0).unsqueeze(0) + torch.testing.assert_close(output, expected, rtol=1e-5, atol=1e-5) + + +def test_torch_backend_attention_metadata_for_shared_kv_node(): + module = _TinySharedKVModule().eval() + gm = torch_export_to_gm(module, (torch.randn(1, 4, 8),)) + source_nodes = [ + node + for node in 
gm.graph.nodes + if node.op == "call_function" + and node.target == torch.ops.auto_deploy.torch_attention.default + ] + regular = next( + node + for node in source_nodes + if node.target == torch.ops.auto_deploy.torch_attention.default + ) + shared = next(node for node in source_nodes if TorchBackendAttention.get_layer_idx(node) == 1) + + assert TorchBackendAttention.get_layer_idx(regular) == 0 + assert TorchBackendAttention.get_layer_idx(shared) == 1 + assert TorchBackendAttention.get_shared_kv_source_layer_idx(regular) is None + assert TorchBackendAttention.get_shared_kv_source_layer_idx(shared) == 0 + + +def test_flashinfer_backend_attention_metadata_for_shared_kv_node(): + module = _TinySharedKVModule().eval() + gm = torch_export_to_gm(module, (torch.randn(1, 4, 8),)) + source_nodes = [ + node + for node in gm.graph.nodes + if node.op == "call_function" + and node.target == torch.ops.auto_deploy.torch_attention.default + ] + regular = next( + node + for node in source_nodes + if node.target == torch.ops.auto_deploy.torch_attention.default + ) + shared = next(node for node in source_nodes if FlashInferAttention.get_layer_idx(node) == 1) + + assert FlashInferAttention.get_layer_idx(regular) == 0 + assert FlashInferAttention.get_layer_idx(shared) == 1 + assert FlashInferAttention.get_shared_kv_source_layer_idx(regular) is None + assert FlashInferAttention.get_shared_kv_source_layer_idx(shared) == 0 + assert FlashInferAttention.get_cached_attention_op() == ( + torch.ops.auto_deploy.flashinfer_attention_mha_with_cache.default + ) + + +def test_shared_kv_transform_aliases_source_cache_placeholders_for_flashinfer(): + module = _TinySharedKVModule().eval() + gm = torch_export_to_gm(module, (torch.randn(1, 4, 8),)) + + cm = CachedSequenceInterface( + max_seq_len=16, + max_batch_size=2, + max_num_tokens=16, + device="cpu", + ) + transform = _InsertCachedOperator( + InsertCachedAttentionConfig(stage=Stages.CACHE_INIT, backend="flashinfer") + ) + gm, info = 
transform._apply(gm, cm, factory=None, shared_config=SharedConfig()) + + assert info.num_matches == 2 + + placeholder_names = [node.target for node in gm.graph.nodes if node.op == "placeholder"] + assert placeholder_names.count("r0_kv_cache") == 1 + assert "r1_kv_cache" not in placeholder_names + assert set(cm._resource_lookup).issubset(set(placeholder_names)) + + cached_nodes = [node for node in gm.graph.nodes if node.op == "call_function"] + regular_node = next( + node + for node in cached_nodes + if node.target == torch.ops.auto_deploy.flashinfer_attention_mha_with_cache.default + and node.args[-1] is False + ) + shared_node = next( + node + for node in cached_nodes + if node.target == torch.ops.auto_deploy.flashinfer_attention_mha_with_cache.default + and node.args[-1] is True + ) + + assert regular_node.args[11] is shared_node.args[11] + assert regular_node.target == torch.ops.auto_deploy.flashinfer_attention_mha_with_cache.default + assert shared_node.target == torch.ops.auto_deploy.flashinfer_attention_mha_with_cache.default + assert regular_node.args[-1] is False + assert shared_node.args[-1] is True + + +def test_flashinfer_cached_attention_is_dynamic_for_piecewise(): + shared_op_name = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache.default.name() + + class _FakeNode: + op = "call_function" + + def __init__(self, target): + self.target = target + + assert "flashinfer_attention_mha_with_cache" in shared_op_name + assert is_dynamic_cached_op( + _FakeNode(torch.ops.auto_deploy.flashinfer_attention_mha_with_cache.default) + ) + + +@torch.no_grad() +def test_torch_shared_kv_cached_attention_supports_out_buffer(): + q = torch.randn(1, 3, 2, 4) + k = torch.randn(1, 3, 1, 4) + v = torch.randn(1, 3, 1, 4) + batch_info_host = BatchInfo() + batch_info_host.update([1, 3, 0, 0, 1, 1]) + seq_len = torch.tensor([3], dtype=torch.int32) + input_pos = torch.tensor([0, 1, 2], dtype=torch.int32) + slot_idx = torch.tensor([0, 1, 2], dtype=torch.int32) + cu_seqlen = 
torch.tensor([0], dtype=torch.int32) + k_cache = torch.randn(1, 4, 1, 4) + v_cache = torch.randn(1, 4, 1, 4) + + expected = torch.ops.auto_deploy.torch_cached_attention_with_cache.default( + q, + k, + v, + batch_info_host.serialize(), + seq_len, + input_pos, + slot_idx, + cu_seqlen, + k_cache, + v_cache, + None, + read_cache_only=True, + ) + + out = torch.full_like(expected, float("nan")) + ret = torch.ops.auto_deploy.torch_cached_attention_with_cache.default( + q, + k, + v, + batch_info_host.serialize(), + seq_len, + input_pos, + slot_idx, + cu_seqlen, + k_cache, + v_cache, + None, + read_cache_only=True, + out=out, + ) + + assert ret.numel() == 0 + torch.testing.assert_close(out, expected) + + +def test_shared_kv_self_alias_raises(): + class _SelfAliasingSharedKVModule(torch.nn.Module): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + qkv = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], 2, 4) + return torch.ops.auto_deploy.torch_attention( + qkv, qkv, qkv, None, 0.0, True, 1.0, None, None, None, "bsnd", 1, 1 + ) + + module = _SelfAliasingSharedKVModule().eval() + gm = torch_export_to_gm(module, (torch.randn(1, 4, 8),)) + + cm = CachedSequenceInterface( + max_seq_len=16, + max_batch_size=2, + max_num_tokens=16, + device="cpu", + ) + transform = _InsertCachedOperator( + InsertCachedAttentionConfig(stage=Stages.CACHE_INIT, backend="torch") + ) + + with pytest.raises(RuntimeError, match="cannot share its own KV cache"): + transform._apply(gm, cm, factory=None, shared_config=SharedConfig()) + + +def test_duplicate_cache_owner_layer_idx_raises(): + module = _DuplicateLayerOwnerSharedKVModule().eval() + gm = torch_export_to_gm(module, (torch.randn(1, 4, 8),)) + + cm = CachedSequenceInterface( + max_seq_len=16, + max_batch_size=2, + max_num_tokens=16, + device="cpu", + ) + transform = _InsertCachedOperator( + InsertCachedAttentionConfig(stage=Stages.CACHE_INIT, backend="torch") + ) + + with pytest.raises(RuntimeError, 
match="Duplicate KV cache owner"): + transform._apply(gm, cm, factory=None, shared_config=SharedConfig()) + + +@torch.no_grad() +def test_flashinfer_shared_kv_cached_attention_reads_aliased_cache_without_writing(): + if not torch.cuda.is_available(): + return + + device = torch.device("cuda") + head_dim = 64 + q = torch.zeros((1, 1, 1, head_dim), dtype=torch.float16, device=device) + q[0, 0, 0, 0] = 1.0 + dummy_k = torch.full((1, 1, 1, head_dim), 9.0, dtype=torch.float16, device=device) + dummy_v = torch.full((1, 1, 1, head_dim), 7.0, dtype=torch.float16, device=device) + + owner_k = torch.zeros((1, 3, 1, head_dim), dtype=torch.float16, device=device) + owner_k[0, 0, 0, 0] = 1.0 + owner_k[0, 1, 0, 1] = 1.0 + owner_k[0, 2, 0, 0] = 1.0 + owner_k[0, 2, 0, 1] = 1.0 + owner_v = torch.zeros((1, 3, 1, head_dim), dtype=torch.float16, device=device) + owner_v[0, 0, 0, 0] = 10.0 + owner_v[0, 1, 0, 1] = 20.0 + owner_v[0, 2, 0, 0] = 30.0 + owner_v[0, 2, 0, 1] = 3.0 + kv_cache = torch.zeros((1, 2, 1, 32, head_dim), dtype=torch.float16, device=device) + kv_cache[0, 0, 0, :3, :] = owner_k[0, :, 0, :] + kv_cache[0, 1, 0, :3, :] = owner_v[0, :, 0, :] + kv_cache_before = kv_cache.clone() + + batch_info_host = BatchInfo() + batch_info_host.update([0, 0, 0, 0, 1, 1]) + cu_seqlen_host = torch.tensor([0, 1], dtype=torch.int32, device="cpu") + cu_num_pages = torch.tensor([0, 1], dtype=torch.int32, device=device) + cu_num_pages_host = torch.tensor([0, 1], dtype=torch.int32, device="cpu") + cache_loc = torch.tensor([0], dtype=torch.int32, device=device) + last_page_len = torch.tensor([3], dtype=torch.int32, device=device) + last_page_len_host = torch.tensor([3], dtype=torch.int32, device="cpu") + seq_len_with_cache_host = torch.tensor([3], dtype=torch.int32, device="cpu") + batch_indices = torch.zeros(1, dtype=torch.int32, device=device) + positions = torch.zeros(1, dtype=torch.int32, device=device) + + output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache( + q, + dummy_k, + 
dummy_v, + batch_info_host.serialize(), + cu_seqlen_host, + cu_num_pages, + cu_num_pages_host, + cache_loc, + last_page_len, + last_page_len_host, + seq_len_with_cache_host, + batch_indices, + positions, + kv_cache, + 1.0, + None, + 1.0, + 1.0, + True, + ) + + torch.testing.assert_close(kv_cache, kv_cache_before, rtol=0.0, atol=0.0) + + expected = _manual_attention(q.float(), owner_k.float(), owner_v.float()).to(output.dtype) + torch.testing.assert_close(output.float(), expected.float(), rtol=2e-2, atol=2e-2) + + +def test_shared_kv_six_layer_stack_matches_reference_for_prefill_and_decode(): + layer_sources = {4: 2, 5: 3} + sliding_layers = {2, 4} + prefill_len = 3 + decode_pos = prefill_len + owner_caches = { + layer_idx: ( + torch.zeros(1, 8, 1, 2, dtype=torch.float32), + torch.zeros(1, 8, 1, 2, dtype=torch.float32), + ) + for layer_idx in range(4) + } + owner_history = {} + + for layer_idx in range(6): + q_prefill, k_prefill, v_prefill = _make_layer_inputs( + offset=float(layer_idx), seq_len=prefill_len + ) + batch_info_host, seq_len, input_pos, slot_idx, cu_seqlen = _context_meta(prefill_len) + sliding_window = 2 if layer_idx in sliding_layers else None + + if layer_idx in layer_sources: + source_idx = layer_sources[layer_idx] + k_cache, v_cache = owner_caches[source_idx] + output_prefill = torch.ops.auto_deploy.torch_cached_attention_with_cache( + q_prefill, + k_prefill, + v_prefill, + batch_info_host, + seq_len, + input_pos, + slot_idx, + cu_seqlen, + k_cache, + v_cache, + 1.0, + None, + sliding_window, + None, + True, + ) + expected_prefill = _manual_attention( + q_prefill, + owner_history[source_idx]["k_prefill"], + owner_history[source_idx]["v_prefill"], + sliding_window=sliding_window, + ) + else: + k_cache, v_cache = owner_caches[layer_idx] + output_prefill = torch.ops.auto_deploy.torch_cached_attention_with_cache( + q_prefill, + k_prefill, + v_prefill, + batch_info_host, + seq_len, + input_pos, + slot_idx, + cu_seqlen, + k_cache, + v_cache, + 1.0, + 
None, + sliding_window, + None, + ) + expected_prefill = _manual_attention( + q_prefill, + k_prefill, + v_prefill, + sliding_window=sliding_window, + ) + owner_history[layer_idx] = { + "k_prefill": k_prefill.clone(), + "v_prefill": v_prefill.clone(), + } + torch.testing.assert_close(k_cache[0, :prefill_len], k_prefill[0], rtol=0.0, atol=0.0) + torch.testing.assert_close(v_cache[0, :prefill_len], v_prefill[0], rtol=0.0, atol=0.0) + + torch.testing.assert_close(output_prefill, expected_prefill, rtol=1e-5, atol=1e-5) + + for layer_idx in range(6): + q_decode, k_decode, v_decode = _make_layer_inputs( + offset=100.0 + float(layer_idx), seq_len=1, decode=True + ) + batch_info_host, seq_len, input_pos, slot_idx, cu_seqlen = _decode_meta(decode_pos) + sliding_window = 2 if layer_idx in sliding_layers else None + + if layer_idx in layer_sources: + source_idx = layer_sources[layer_idx] + k_cache, v_cache = owner_caches[source_idx] + k_cache_before = k_cache.clone() + v_cache_before = v_cache.clone() + output_decode = torch.ops.auto_deploy.torch_cached_attention_with_cache( + q_decode, + k_decode, + v_decode, + batch_info_host, + seq_len, + input_pos, + slot_idx, + cu_seqlen, + k_cache, + v_cache, + 1.0, + None, + sliding_window, + None, + True, + ) + torch.testing.assert_close(k_cache, k_cache_before, rtol=0.0, atol=0.0) + torch.testing.assert_close(v_cache, v_cache_before, rtol=0.0, atol=0.0) + expected_k = owner_history[source_idx]["k_full"] + expected_v = owner_history[source_idx]["v_full"] + else: + k_cache, v_cache = owner_caches[layer_idx] + output_decode = torch.ops.auto_deploy.torch_cached_attention_with_cache( + q_decode, + k_decode, + v_decode, + batch_info_host, + seq_len, + input_pos, + slot_idx, + cu_seqlen, + k_cache, + v_cache, + 1.0, + None, + sliding_window, + None, + ) + expected_k = torch.cat([owner_history[layer_idx]["k_prefill"], k_decode], dim=1) + expected_v = torch.cat([owner_history[layer_idx]["v_prefill"], v_decode], dim=1) + 
owner_history[layer_idx]["k_full"] = expected_k + owner_history[layer_idx]["v_full"] = expected_v + torch.testing.assert_close( + k_cache[0, : decode_pos + 1], expected_k[0], rtol=0.0, atol=0.0 + ) + torch.testing.assert_close( + v_cache[0, : decode_pos + 1], expected_v[0], rtol=0.0, atol=0.0 + ) + + expected_decode = _manual_attention( + q_decode, + expected_k, + expected_v, + sliding_window=sliding_window, + ) + torch.testing.assert_close(output_decode, expected_decode, rtol=1e-5, atol=1e-5) diff --git a/tests/unittest/auto_deploy/singlegpu/compile/test_captured_graph.py b/tests/unittest/auto_deploy/singlegpu/compile/test_captured_graph.py index b1c1e604e0f..a6e51aa7794 100644 --- a/tests/unittest/auto_deploy/singlegpu/compile/test_captured_graph.py +++ b/tests/unittest/auto_deploy/singlegpu/compile/test_captured_graph.py @@ -18,6 +18,7 @@ _args_kwargs_flatten_spec, ) from tensorrt_llm._torch.auto_deploy.compile.piecewise_utils import submod_has_cuda_ops +from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import BatchInfo from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm from tensorrt_llm._torch.auto_deploy.shim.ad_executor import _round_up_to_closest from tensorrt_llm._torch.auto_deploy.transform.library.compile_model import ( @@ -285,13 +286,17 @@ def _make_dual_mode(self, piecewise_num_tokens=None): def test_is_decode_only_with_batch_info_host_zero(self): dual = self._make_dual_mode() # num_prefill=0 → decode-only - batch_info = torch.tensor([0, 0, 4]) # [num_prefill, num_prefill_tokens, num_decode] + batch_info_host = BatchInfo() + batch_info_host.update([0, 0, 0, 0, 4, 4]) + batch_info = batch_info_host.serialize() assert dual._is_decode_only(batch_info_host=batch_info) is True def test_is_decode_only_with_batch_info_host_nonzero(self): dual = self._make_dual_mode() # num_prefill=2 → not decode-only - batch_info = torch.tensor([2, 100, 3]) + batch_info_host = BatchInfo() + batch_info_host.update([2, 100, 0, 0, 3, 3]) + 
batch_info = batch_info_host.serialize() assert dual._is_decode_only(batch_info_host=batch_info) is False def test_is_decode_only_fallback_heuristic_decode(self): diff --git a/tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_flashinfer_attention_op.py b/tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_flashinfer_attention_op.py index d9b016e9498..8493bd8e201 100644 --- a/tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_flashinfer_attention_op.py +++ b/tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_flashinfer_attention_op.py @@ -132,6 +132,7 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype, kv_cache, # CONSTANTS None, + None, 1.0, 1.0, ) @@ -261,6 +262,7 @@ def test_flashinfer_attention_op_decode( kv_cache, # CONSTANTS None, + None, 1.0, 1.0, ) @@ -391,6 +393,7 @@ def test_flashinfer_attention_context_and_generate( kv_cache, # CONSTANTS None, + None, 1.0, 1.0, ) @@ -485,6 +488,7 @@ def test_flashinfer_attention_context_and_generate( kv_cache, # CONSTANTS None, + None, 1.0, 1.0, ) @@ -618,6 +622,7 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty kv_cache, # CONSTANTS None, + None, 1.0, 1.0, ) @@ -776,6 +781,7 @@ def test_flashinfer_attention_with_fp8_cache( kv_cache, # CONSTANTS None, + None, K_SCALE, V_SCALE, ) @@ -883,6 +889,7 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de kv_cache, # CONSTANTS None, + None, 1.0, 1.0, ) @@ -977,6 +984,7 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de kv_cache, # CONSTANTS None, + None, 1.0, 1.0, ) From 555dd606d47d1d7489d2a44023b0550627904872 Mon Sep 17 00:00:00 2001 From: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com> Date: Thu, 2 Apr 2026 10:41:45 -0700 Subject: [PATCH 02/16] [None][feat] AutoDeploy: Gemma4 custom model support Adds Gemma4 (MoE) custom model for AutoDeploy with: - Custom modeling code 
supporting K=V attention, proportional RoPE, parallel dense+MoE, per-layer scalars, and logit softcapping - Gelu activation support in torch_moe for Gemma4 MoE layers - Hierarchical equivalence tests - Model registry config (triton_paged attention backend for head_dim=512) Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com> --- .../model_registry/configs/gemma4_moe.yaml | 7 + .../custom_ops/fused_moe/torch_moe.py | 14 +- .../auto_deploy/models/custom/__init__.py | 3 + .../models/custom/modeling_gemma4.py | 892 ++++++++++++++++++ .../singlegpu/models/test_gemma4_modeling.py | 628 ++++++++++++ 5 files changed, 1541 insertions(+), 3 deletions(-) create mode 100644 examples/auto_deploy/model_registry/configs/gemma4_moe.yaml create mode 100644 tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py create mode 100644 tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma4_modeling.py diff --git a/examples/auto_deploy/model_registry/configs/gemma4_moe.yaml b/examples/auto_deploy/model_registry/configs/gemma4_moe.yaml new file mode 100644 index 00000000000..807c2b2cbb0 --- /dev/null +++ b/examples/auto_deploy/model_registry/configs/gemma4_moe.yaml @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# Gemma 4 MoE (26B total, 4B activated) — text-only AD export path. +# Uses triton paged attention backend: supports head_dim=512 (global_head_dim), +# paged KV cache, CUDA-graph-compatible, FlashDecoding for decode. 
+attn_backend: triton_paged diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py index e7ab3a61a3e..d231de35136 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial from typing import Callable, List import torch @@ -165,10 +166,15 @@ def _template_moe_alltoall( def _resolve_torch_fn(act_fn: ActivationType) -> Callable[[torch.Tensor], torch.Tensor]: """ Returns an elementwise activation callable matching the given activation function. - Supported: ActivationType.Silu, ActivationType.Swiglu, ActivationType.Relu2 + Supported: ActivationType.Silu, ActivationType.Swiglu, ActivationType.Relu2, ActivationType.Gelu """ - assert act_fn in [ActivationType.Silu, ActivationType.Swiglu, ActivationType.Relu2], ( - f"Unsupported activation '{ActivationType(act_fn).name}'. Use 'silu', 'swiglu' or 'relu2'." + assert act_fn in [ + ActivationType.Silu, + ActivationType.Swiglu, + ActivationType.Relu2, + ActivationType.Gelu, + ], ( + f"Unsupported activation '{ActivationType(act_fn).name}'. Use 'silu', 'swiglu', 'relu2', or 'gelu'." 
) torch_fn = None if act_fn == ActivationType.Silu or act_fn == ActivationType.Swiglu: @@ -179,6 +185,8 @@ def relu2(x: torch.Tensor) -> torch.Tensor: return torch.square(F.relu(x)) torch_fn = relu2 + elif act_fn == ActivationType.Gelu: + torch_fn = partial(F.gelu, approximate="tanh") return torch_fn diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py b/tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py index bf1bf1c9909..35a485ba2d9 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py @@ -1,5 +1,6 @@ from .modeling_deepseek import DeepSeekV3ForCausalLM from .modeling_gemma3n import Gemma3nForCausalLM, Gemma3nForConditionalGeneration +from .modeling_gemma4 import Gemma4ForCausalLM, Gemma4ForConditionalGeneration from .modeling_glm4_moe_lite import Glm4MoeLiteForCausalLM from .modeling_kimi_k2 import KimiK2ForCausalLM, KimiK25ForConditionalGeneration from .modeling_mistral3 import Mistral3ForConditionalGenerationAD, Mistral4ForCausalLM @@ -11,6 +12,8 @@ "DeepSeekV3ForCausalLM", "Gemma3nForCausalLM", "Gemma3nForConditionalGeneration", + "Gemma4ForCausalLM", + "Gemma4ForConditionalGeneration", "Glm4MoeLiteForCausalLM", "KimiK2ForCausalLM", "KimiK25ForConditionalGeneration", diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py new file mode 100644 index 00000000000..ad8d6a12f3d --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py @@ -0,0 +1,892 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Slimmed down Gemma 4 text implementation for AutoDeploy export. + +This implementation follows the HuggingFace Gemma 4 text stack closely while +keeping only the prefill path needed by AutoDeploy. The outer +``Gemma4ForConditionalGeneration`` wrapper preserves the HF checkpoint layout +(``model.language_model.*``) and drops unsupported vision/audio tower weights +at load time. The forward path supports text-only export. + +Key architectural features of Gemma 4 vs standard transformers: +- K=V attention on full-attention layers (v_proj is absent; k_proj output is + reused as value) +- Different head dimensions for full vs sliding attention (global_head_dim vs + head_dim) +- Proportional RoPE with partial_rotary_factor on full-attention layers +- Dense MLP running in parallel with Mixture-of-Experts (MoE) in every layer +- Per-layer scalar multiplier +- Final logit softcapping +""" + +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import AutoConfig, PretrainedConfig +from transformers.activations import ACT2FN +from transformers.generation import GenerationMixin +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ModelOutput + +from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory +from tensorrt_llm._torch.utils import ActivationType + +# --------------------------------------------------------------------------- 
+# Bundled config classes — enables loading on transformers <5.3 where +# Gemma4 is not natively registered. +# --------------------------------------------------------------------------- + + +class Gemma4TextConfig(PretrainedConfig): + """Minimal Gemma4 text config for AutoDeploy.""" + + model_type = "gemma4_text" + + def __init__( + self, + vocab_size: int = 262_144, + hidden_size: int = 2816, + intermediate_size: int = 2112, + num_hidden_layers: int = 30, + num_attention_heads: int = 16, + num_key_value_heads: int = 8, + head_dim: int = 256, + global_head_dim: int = 512, + num_global_key_value_heads: int = 2, + hidden_activation: str = "gelu_pytorch_tanh", + max_position_embeddings: int = 131_072, + rms_norm_eps: float = 1e-6, + attention_bias: bool = False, + attention_dropout: float = 0.0, + attention_k_eq_v: bool = True, + sliding_window: int = 1024, + layer_types: Optional[list] = None, + rope_parameters: Optional[dict] = None, + final_logit_softcapping: Optional[float] = 30.0, + hidden_size_per_layer_input: int = 0, + num_kv_shared_layers: int = 0, + use_double_wide_mlp: bool = False, + use_bidirectional_attention: Optional[str] = "vision", + enable_moe_block: bool = True, + num_experts: Optional[int] = 128, + top_k_experts: Optional[int] = 8, + expert_intermediate_size: Optional[int] = 704, + stream_and_decode_in_f32: bool = True, + vocab_size_per_layer_input: int = 262_144, + routed_layer_pattern: Optional[list] = None, + pad_token_id: Optional[int] = 0, + eos_token_id=1, + bos_token_id: Optional[int] = 2, + tie_word_embeddings: bool = True, + initializer_range: float = 0.02, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.global_head_dim = global_head_dim + self.num_global_key_value_heads = 
num_global_key_value_heads + self.hidden_activation = hidden_activation + self.max_position_embeddings = max_position_embeddings + self.rms_norm_eps = rms_norm_eps + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.attention_k_eq_v = attention_k_eq_v + self.sliding_window = sliding_window + self.layer_types = layer_types or (["sliding_attention"] * num_hidden_layers) + self.rope_parameters = rope_parameters or { + "full_attention": { + "rope_type": "proportional", + "rope_theta": 1_000_000.0, + "partial_rotary_factor": 0.25, + }, + "sliding_attention": {"rope_type": "default", "rope_theta": 10_000.0}, + } + self.final_logit_softcapping = final_logit_softcapping + self.hidden_size_per_layer_input = hidden_size_per_layer_input + self.num_kv_shared_layers = num_kv_shared_layers + self.use_double_wide_mlp = use_double_wide_mlp + self.use_bidirectional_attention = use_bidirectional_attention + self.enable_moe_block = enable_moe_block + self.num_experts = num_experts + self.top_k_experts = top_k_experts + self.expert_intermediate_size = expert_intermediate_size + self.stream_and_decode_in_f32 = stream_and_decode_in_f32 + self.vocab_size_per_layer_input = vocab_size_per_layer_input + self.routed_layer_pattern = routed_layer_pattern + self.initializer_range = initializer_range + super().__init__( + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + bos_token_id=bos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class Gemma4VisionConfig(PretrainedConfig): + """Minimal Gemma4 vision config stub.""" + + model_type = "gemma4_vision" + + def __init__(self, hidden_size: int = 1152, rms_norm_eps: float = 1e-6, **kwargs): + self.hidden_size = hidden_size + self.rms_norm_eps = rms_norm_eps + super().__init__(**kwargs) + + +class Gemma4Config(PretrainedConfig): + """Top-level Gemma4 multimodal config.""" + + model_type = "gemma4" + + def __init__( + self, + text_config=None, + vision_config=None, + 
audio_config=None, + initializer_range: float = 0.02, + tie_word_embeddings: bool = True, + **kwargs, + ): + self.initializer_range = initializer_range + if text_config is None: + self.text_config = Gemma4TextConfig() + elif isinstance(text_config, dict): + self.text_config = Gemma4TextConfig(**text_config) + else: + self.text_config = text_config + + if vision_config is None: + self.vision_config = Gemma4VisionConfig() + elif isinstance(vision_config, dict): + self.vision_config = Gemma4VisionConfig(**vision_config) + else: + self.vision_config = vision_config + + self.audio_config = audio_config + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +AutoConfig.register("gemma4", Gemma4Config, exist_ok=True) +AutoConfig.register("gemma4_text", Gemma4TextConfig, exist_ok=True) + +# --------------------------------------------------------------------------- +# RoPE cache builder +# --------------------------------------------------------------------------- + + +def _build_rope_cache( + config: Gemma4TextConfig, + layer_type: str, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Pre-compute cos/sin RoPE cache for the given layer type.""" + rope_params = config.rope_parameters[layer_type] + rope_type = rope_params.get("rope_type", "default") + base = rope_params["rope_theta"] + factor = rope_params.get("factor", 1.0) + attention_scaling = 1.0 + + if rope_type == "default": + dim = config.head_dim + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + elif rope_type == "proportional": + # Proportional RoPE: only partial_rotary_factor of head dims are rotated, + # remaining dims get zero inv_freq → cos=1, sin=0 (no rotation). 
+ head_dim = config.global_head_dim + rope_proportion = rope_params.get("partial_rotary_factor", 1.0) + rope_angles = int(rope_proportion * head_dim // 2) + inv_freq_rotated = 1.0 / ( + base ** (torch.arange(0, 2 * rope_angles, 2, dtype=torch.float) / head_dim) + ) + nope_angles = head_dim // 2 - rope_angles + if nope_angles > 0: + inv_freq = torch.cat( + (inv_freq_rotated, torch.zeros(nope_angles, dtype=torch.float32)), + dim=0, + ) + else: + inv_freq = inv_freq_rotated + inv_freq = inv_freq / factor + else: + # Fallback to HF ROPE_INIT_FUNCTIONS for other types (e.g. yarn, longrope) + rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type] + inv_freq, attention_scaling = rope_init_fn(config, device=None, layer_type=layer_type) + + positions = torch.arange(config.max_position_embeddings, dtype=inv_freq.dtype) + freqs = torch.outer(positions, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + return emb.cos() * attention_scaling, emb.sin() * attention_scaling + + +# --------------------------------------------------------------------------- +# Basic building blocks +# --------------------------------------------------------------------------- + + +class Gemma4RMSNorm(nn.Module): + """RMSNorm with Gemma4-style (weight + scale_shift) semantics. + + For AD export, we store the *effective* weight = checkpoint_weight + scale_shift + via a load hook, then use the standard torch_rmsnorm op. + """ + + def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True): + super().__init__() + self.eps = eps + self.with_scale = with_scale + if with_scale: + self.weight = nn.Parameter(torch.ones(dim)) + else: + self.register_buffer("weight", torch.ones(dim), persistent=False) + if with_scale: + self._register_load_state_dict_pre_hook(self._apply_scale_shift) + + @staticmethod + def _apply_scale_shift(state_dict, prefix, *_args, **_kwargs): + """Gemma4 RMSNorm stores weight that is used as (weight + 1.0). 
+ Convert to effective weight at load time so torch_rmsnorm works directly.""" + key = prefix + "weight" + if key in state_dict: + state_dict[key] = state_dict[key] + 1.0 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.auto_deploy.torch_rmsnorm(x, self.weight, self.eps) + + +class Gemma4TextScaledWordEmbedding(nn.Embedding): + def __init__( + self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float + ): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False) + + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + return super().forward(input_ids) * self.embed_scale.to(dtype=self.weight.dtype) + + +class Gemma4RotaryEmbedding(nn.Module): + """Pre-computed RoPE cache for a single layer type (global or local).""" + + def __init__(self, config: Gemma4TextConfig, layer_type: str): + super().__init__() + ( + cos, + sin, + ) = _build_rope_cache(config, layer_type) + self.register_buffer("_ad_cos_cached", cos, persistent=False) + self.register_buffer("_ad_sin_cached", sin, persistent=False) + + def forward( + self, x: torch.Tensor, position_ids: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + cos = self._ad_cos_cached[position_ids].to(dtype=x.dtype, device=x.device) + sin = self._ad_sin_cached[position_ids].to(dtype=x.dtype, device=x.device) + return cos, sin + + +# --------------------------------------------------------------------------- +# MLP +# --------------------------------------------------------------------------- + + +class Gemma4TextMLP(nn.Module): + def __init__(self, config: Gemma4TextConfig): + super().__init__() + self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) + self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + self.act_fn = 
ACT2FN[config.hidden_activation] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +# --------------------------------------------------------------------------- +# MoE Router + Experts +# --------------------------------------------------------------------------- + + +class Gemma4Router(nn.Module): + """Gemma4-style MoE router: RMSNorm(no-scale) -> per-dim scale -> linear -> softmax -> topk.""" + + def __init__(self, config: Gemma4TextConfig): + super().__init__() + self.proj = nn.Linear(config.hidden_size, config.num_experts, bias=False) + self.scale = nn.Parameter(torch.ones(config.hidden_size)) + self.register_buffer("root_size", torch.tensor(config.hidden_size**-0.5), persistent=False) + self.eps = config.rms_norm_eps + self.top_k = config.top_k_experts + + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # RMSNorm without learnable scale + normed = hidden_states.float() + normed = normed * torch.rsqrt(normed.pow(2).mean(-1, keepdim=True) + self.eps) + normed = normed.type_as(hidden_states) + # Apply scalar and per-dim scaling + normed = normed * self.root_size.to(hidden_states.dtype) + normed = normed * self.scale.to(hidden_states.dtype) + # Route + expert_scores = self.proj(normed) + probs = F.softmax(expert_scores, dim=-1) + top_k_weights, top_k_index = torch.topk(probs, k=self.top_k, dim=-1) + top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True) + return top_k_weights, top_k_index + + +class Gemma4Expert(nn.Module): + """Single MoE expert: gated MLP (gate_proj, up_proj, down_proj).""" + + def __init__(self, hidden_size: int, intermediate_size: int): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + + +class 
Gemma4MoEBlock(nn.Module): + """Mixture-of-Experts block with fused checkpoint weight conversion. + + Checkpoint stores fused parameters: + - gate_up_proj: [num_experts, 2*intermediate, hidden] + - down_proj: [num_experts, hidden, intermediate] + - per_expert_scale: [num_experts] + + We unfuse these into per-expert nn.Linear modules at load time so that + torch_moe can consume them as weight lists. + """ + + def __init__(self, config: Gemma4TextConfig): + super().__init__() + self.num_experts = config.num_experts + self.intermediate_size = config.expert_intermediate_size + self.experts = nn.ModuleList( + [ + Gemma4Expert(config.hidden_size, config.expert_intermediate_size) + for _ in range(config.num_experts) + ] + ) + self._register_load_state_dict_pre_hook(self._unfuse_moe_weights) + + def _unfuse_moe_weights(self, state_dict, prefix, *_args, **_kwargs): + """Convert fused checkpoint MoE weights to per-expert format.""" + gate_up_key = prefix + "gate_up_proj" + down_key = prefix + "down_proj" + scale_key = prefix + "per_expert_scale" + + if gate_up_key not in state_dict: + return + + gate_up = state_dict.pop(gate_up_key) # [E, 2*I, H] + down = state_dict.pop(down_key) # [E, H, I] + scale = state_dict.pop(scale_key) # [E] + + inter = self.intermediate_size + for e in range(self.num_experts): + state_dict[f"{prefix}experts.{e}.gate_proj.weight"] = gate_up[e, :inter, :] + state_dict[f"{prefix}experts.{e}.up_proj.weight"] = gate_up[e, inter:, :] + # Fold per_expert_scale into down_proj + state_dict[f"{prefix}experts.{e}.down_proj.weight"] = down[e] * scale[e] + + def forward( + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + ) -> torch.Tensor: + return torch.ops.auto_deploy.torch_moe( + hidden_states, + top_k_index, + top_k_weights, + w1_weight=[e.gate_proj.weight for e in self.experts], + w2_weight=[e.down_proj.weight for e in self.experts], + w3_weight=[e.up_proj.weight for e in self.experts], + is_gated_mlp=True, 
# NOTE(review): this chunk was reconstructed from a whitespace-mangled patch
# (newlines restored at the diff's `+` markers); indentation is inferred and
# should be confirmed against the original modeling_gemma4.py.
# The two lines below are the tail of a Gemma4TextMLP initializer that begins
# before this chunk — preserved verbatim, indentation assumed.
        act_fn=int(ActivationType.Gelu),
    )


# ---------------------------------------------------------------------------
# Attention
# ---------------------------------------------------------------------------


class Gemma4TextAttention(nn.Module):
    """Self-attention for Gemma4 text layers.

    Supports two layer flavors selected by ``config.layer_types[layer_idx]``:

    - ``"sliding_attention"``: standard GQA with a sliding window
      (``config.sliding_window``) and ``config.head_dim``.
    - ``"full_attention"``: may use a distinct head dim
      (``config.global_head_dim``) and, when ``config.attention_k_eq_v`` is
      set, shares the key projection as the value projection (K=V, so
      ``v_proj`` is ``None``).

    Attention itself is dispatched to the AutoDeploy custom op
    ``torch.ops.auto_deploy.torch_attention``.
    """

    def __init__(self, config: Gemma4TextConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.config = config
        self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
        # Sliding window only applies to sliding layers; None disables it.
        self.sliding_window = config.sliding_window if self.is_sliding else None

        # Full-attention layers may use different head dim and K=V
        self.use_k_eq_v = config.attention_k_eq_v and not self.is_sliding
        if not self.is_sliding and config.global_head_dim:
            self.head_dim = config.global_head_dim
        else:
            self.head_dim = config.head_dim

        # K=V layers use the (typically smaller) global KV head count.
        num_kv_heads = (
            config.num_global_key_value_heads if self.use_k_eq_v else config.num_key_value_heads
        )
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = num_kv_heads

        self.q_proj = nn.Linear(
            config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, num_kv_heads * self.head_dim, bias=config.attention_bias
        )
        # No separate value projection when K=V is in effect.
        self.v_proj = (
            None
            if self.use_k_eq_v
            else nn.Linear(
                config.hidden_size, num_kv_heads * self.head_dim, bias=config.attention_bias
            )
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

        self.q_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        # v_norm has no learnable scale
        self.v_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps, with_scale=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """Run attention on ``hidden_states`` of shape [batch, seq, hidden].

        ``position_embeddings`` is the (cos, sin) pair produced by
        Gemma4RotaryEmbedding; Q/K/V are laid out as "bsnd" for the custom op.
        """
        batch_size, seq_len, _ = hidden_states.shape
        q = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim)
        k = self.k_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim)

        if self.v_proj is not None:
            v = self.v_proj(hidden_states).view(
                batch_size, seq_len, self.num_kv_heads, self.head_dim
            )
        else:
            v = k  # K=V: reuse key as value

        # Per-head RMS norms before RoPE (v_norm has no learnable scale).
        q = self.q_norm(q)
        k = self.k_norm(k)
        v = self.v_norm(v)

        cos, sin = position_embeddings
        q, k = torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin(q, k, cos, sin, 2)

        attn_output = torch.ops.auto_deploy.torch_attention(
            q,
            k,
            v,
            None,  # attn_mask
            0.0,  # dropout_p
            True,  # is_causal
            1.0,  # scale (QK norms handle scaling)
            None,  # sinks
            self.sliding_window,
            None,  # logit_cap
            "bsnd",
            self.layer_idx,
        )
        return self.o_proj(attn_output.reshape(batch_size, seq_len, -1))


# ---------------------------------------------------------------------------
# Decoder layer
# ---------------------------------------------------------------------------


class Gemma4TextDecoderLayer(nn.Module):
    """One Gemma4 decoder layer: self-attention, then a feed-forward stage
    that is either a dense MLP or, when ``config.enable_moe_block`` is set, a
    dense MLP branch summed with a routed MoE branch (each branch with its own
    pre/post RMS norms). The layer output is scaled by the ``layer_scalar``
    buffer loaded from the checkpoint.
    """

    def __init__(self, config: Gemma4TextConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.attention_type = config.layer_types[layer_idx]
        self.self_attn = Gemma4TextAttention(config, layer_idx)
        self.mlp = Gemma4TextMLP(config)
        self.input_layernorm = Gemma4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Gemma4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.pre_feedforward_layernorm = Gemma4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = Gemma4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Per-layer output scale; a buffer so the checkpoint value is loaded.
        self.register_buffer("layer_scalar", torch.ones(1))

        self.enable_moe_block = config.enable_moe_block
        if self.enable_moe_block:
            self.router = Gemma4Router(config)
            self.moe = Gemma4MoEBlock(config)
            self.post_feedforward_layernorm_1 = Gemma4RMSNorm(
                config.hidden_size, eps=config.rms_norm_eps
            )
            self.post_feedforward_layernorm_2 = Gemma4RMSNorm(
                config.hidden_size, eps=config.rms_norm_eps
            )
            self.pre_feedforward_layernorm_2 = Gemma4RMSNorm(
                config.hidden_size, eps=config.rms_norm_eps
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        # Self-attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(hidden_states, position_embeddings)
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        # Feed-forward (dense MLP ± MoE)
        residual = hidden_states

        if self.enable_moe_block:
            # Dense MLP path
            hs_dense = self.pre_feedforward_layernorm(hidden_states)
            hs_dense = self.mlp(hs_dense)
            hs_dense = self.post_feedforward_layernorm_1(hs_dense)

            # MoE path — router and experts operate on flattened tokens.
            hs_flat = hidden_states.reshape(-1, hidden_states.shape[-1])
            top_k_weights, top_k_index = self.router(hs_flat)
            hs_moe = self.pre_feedforward_layernorm_2(hs_flat)
            hs_moe = self.moe(hs_moe, top_k_index, top_k_weights)
            hs_moe = hs_moe.reshape(hidden_states.shape)
            hs_moe = self.post_feedforward_layernorm_2(hs_moe)

            hidden_states = hs_dense + hs_moe
        else:
            hidden_states = self.pre_feedforward_layernorm(hidden_states)
            hidden_states = self.mlp(hidden_states)

        # NOTE(review): indentation reconstructed — the next two lines are
        # assumed to apply to BOTH branches (so the MoE sum, already normed
        # per-branch, is normed again). The in-file reference implementation
        # reads the same way, so equivalence tests hold either way; confirm
        # against HF Gemma4 before changing.
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        hidden_states = hidden_states * self.layer_scalar
        return hidden_states


# ---------------------------------------------------------------------------
# Text model
# ---------------------------------------------------------------------------


@dataclass
class Gemma4TextOutput(ModelOutput):
    """Output of Gemma4TextModel: final hidden states [batch, seq, hidden]."""

    last_hidden_state: Optional[torch.FloatTensor] = None


@dataclass
class Gemma4CausalLMOutput(ModelOutput):
    """Output of Gemma4ForCausalLM: logits [batch, seq, vocab]."""

    logits: Optional[torch.FloatTensor] = None


@dataclass
class Gemma4ConditionalOutput(ModelOutput):
    """Output of Gemma4ForConditionalGeneration: logits [batch, seq, vocab]."""

    logits: Optional[torch.FloatTensor] = None


class Gemma4TextPreTrainedModel(PreTrainedModel):
    """PreTrainedModel base for the text-only Gemma4 stack."""

    config_class = Gemma4TextConfig
    base_model_prefix = "model"
    _no_split_modules = ["Gemma4TextDecoderLayer"]
    supports_gradient_checkpointing = False

    def _init_weights(self, module: nn.Module):
        # Standard normal init for Linear/Embedding; zero bias and padding row.
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class Gemma4TextModel(Gemma4TextPreTrainedModel):
    """Gemma4 text backbone: scaled embeddings, decoder layers, final norm.

    Maintains two RoPE modules — one for full-attention layers and one for
    sliding-window layers — and feeds each decoder layer the matching
    (cos, sin) pair based on its ``attention_type``.
    """

    def __init__(self, config: Gemma4TextConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.embed_tokens = Gemma4TextScaledWordEmbedding(
            config.vocab_size,
            config.hidden_size,
            self.padding_idx,
            embed_scale=config.hidden_size**0.5,
        )
        self.layers = nn.ModuleList(
            [Gemma4TextDecoderLayer(config, i) for i in range(config.num_hidden_layers)]
        )
        self.norm = Gemma4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Separate RoPE caches for global (full) and local (sliding) attention
        self.rotary_emb_global = Gemma4RotaryEmbedding(config, "full_attention")
        self.rotary_emb_local = Gemma4RotaryEmbedding(config, "sliding_attention")

        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Gemma4TextOutput:
        """Run the backbone. ``position_ids`` is mandatory; exactly one of
        ``input_ids`` / ``inputs_embeds`` must be given."""
        del kwargs
        assert position_ids is not None, "position_ids must be provided"

        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if input_ids is not None:
            inputs_embeds = self.embed_tokens(input_ids)

        # One (cos, sin) pair per RoPE flavor; layers pick theirs below.
        pos_emb_global = self.rotary_emb_global(inputs_embeds, position_ids)
        pos_emb_local = self.rotary_emb_local(inputs_embeds, position_ids)

        hidden_states = inputs_embeds
        for decoder_layer in self.layers:
            if decoder_layer.attention_type == "sliding_attention":
                pos_emb = pos_emb_local
            else:
                pos_emb = pos_emb_global
            hidden_states = decoder_layer(hidden_states, pos_emb)

        hidden_states = self.norm(hidden_states)
        return Gemma4TextOutput(last_hidden_state=hidden_states)


# ---------------------------------------------------------------------------
# CausalLM wrapper (text config)
# ---------------------------------------------------------------------------


class Gemma4ForCausalLM(Gemma4TextPreTrainedModel, GenerationMixin):
    """Text-only LM head over Gemma4TextModel with optional tanh logit
    softcapping: ``cap * tanh(logits / cap)``."""

    config_class = Gemma4TextConfig
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: Gemma4TextConfig, **kwargs):
        del kwargs
        super().__init__(config)
        self.model = Gemma4TextModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, value):
        self.lm_head = value

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Gemma4CausalLMOutput:
        outputs = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )
        logits = self.lm_head(outputs.last_hidden_state)
        # Soft cap: cap * tanh(logits / cap); skipped when cap is None.
        if self.config.final_logit_softcapping is not None:
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping
        return Gemma4CausalLMOutput(logits=logits)


# ---------------------------------------------------------------------------
# Multimodal embedder stub (for weight loading)
# ---------------------------------------------------------------------------


class Gemma4MultimodalEmbedder(nn.Module):
    """Minimal stub to accept embed_vision checkpoint weights."""

    def __init__(self, vision_config: Gemma4VisionConfig, text_config: Gemma4TextConfig):
        super().__init__()
        # Only the projection is kept so embed_vision.* weights load cleanly.
        self.embedding_projection = nn.Linear(
            vision_config.hidden_size, text_config.hidden_size, bias=False
        )


# ---------------------------------------------------------------------------
# ConditionalGeneration wrapper (multimodal config, text-only forward)
# ---------------------------------------------------------------------------


class Gemma4PreTrainedModel(PreTrainedModel):
    """PreTrainedModel base for the multimodal-config Gemma4 stack."""

    config_class = Gemma4Config
    base_model_prefix = "model"
    _no_split_modules = ["Gemma4TextDecoderLayer"]
    supports_gradient_checkpointing = False

    def _init_weights(self, module: nn.Module):
        # Multimodal config may lack initializer_range; default to 0.02.
        std = getattr(self.config, "initializer_range", 0.02)
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class Gemma4Model(Gemma4PreTrainedModel):
    """Multimodal-config wrapper that only runs the text path. Vision/audio
    tower weights are dropped at load time via a state-dict pre-hook."""

    def __init__(self, config: Gemma4Config):
        super().__init__(config)
        self.language_model = Gemma4TextModel(config.text_config)
        self.vision_tower = nn.Module()  # stub
        self.embed_vision = Gemma4MultimodalEmbedder(config.vision_config, config.text_config)
        self._register_load_state_dict_pre_hook(self._drop_unsupported_weights)
        self.post_init()

    @staticmethod
    def _drop_unsupported_weights(state_dict, prefix, *_args, **_kwargs):
        """state-dict pre-hook: remove vision/audio tower keys the text-only
        wrapper has no modules for, so load_state_dict does not complain."""
        unsupported_prefixes = (
            prefix + "vision_tower.",
            prefix + "audio_tower.",
            prefix + "embed_audio.",
        )
        # list() snapshot: keys are popped while iterating.
        for key in list(state_dict):
            if key.startswith(unsupported_prefixes):
                state_dict.pop(key)

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_decoder(self):
        return self.language_model

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        input_features: Optional[torch.Tensor] = None,
        input_features_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Gemma4TextOutput:
        """Text-only forward; raises if multimodal inputs are supplied."""
        del kwargs, input_features_mask
        assert position_ids is not None, "position_ids must be provided"
        if pixel_values is not None or input_features is not None:
            raise NotImplementedError(
                "Gemma4 multimodal inputs are not supported by the current AutoDeploy export "
                "path. Use text-only prompts for this onboarding."
            )
        return self.language_model(
            input_ids=input_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
        )


class Gemma4ForConditionalGeneration(Gemma4PreTrainedModel, GenerationMixin):
    """LM head over Gemma4Model; mirrors Gemma4ForCausalLM but reads its
    dimensions and softcap from ``config.text_config``."""

    config_class = Gemma4Config
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: Gemma4Config, **kwargs):
        del kwargs
        super().__init__(config)
        self.model = Gemma4Model(config)
        self.lm_head = nn.Linear(
            config.text_config.hidden_size, config.text_config.vocab_size, bias=False
        )
        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, value):
        self.lm_head = value

    def get_decoder(self):
        return self.model.get_decoder()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        input_features: Optional[torch.Tensor] = None,
        input_features_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Gemma4ConditionalOutput:
        outputs = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            input_features=input_features,
            input_features_mask=input_features_mask,
            **kwargs,
        )
        logits = self.lm_head(outputs.last_hidden_state)
        # Same tanh softcap as Gemma4ForCausalLM, using the text sub-config.
        if self.config.text_config.final_logit_softcapping is not None:
            logits = logits / self.config.text_config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.text_config.final_logit_softcapping
        return Gemma4ConditionalOutput(logits=logits)


# ---------------------------------------------------------------------------
# Registration
# ---------------------------------------------------------------------------
# NOTE(review): reconstructed from a whitespace-mangled patch; the embedded
# diff/commit metadata lines below are preserved verbatim from the patch.

# Register both config flavors with the AutoDeploy causal-LM factory.
AutoModelForCausalLMFactory.register_custom_model_cls("Gemma4TextConfig", Gemma4ForCausalLM)
AutoModelForCausalLMFactory.register_custom_model_cls(
    "Gemma4Config", Gemma4ForConditionalGeneration
)
diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma4_modeling.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma4_modeling.py
new file mode 100644
index 00000000000..689fe2f9cd9
--- /dev/null
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma4_modeling.py
@@ -0,0 +1,628 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Hierarchical equivalence tests for Gemma4 AutoDeploy custom model.

Reference classes (_Ref*) are standalone PyTorch reimplementations of the
HuggingFace Gemma4 math — no transformers>=5.3 dependency required.
"""

from typing import Optional, Tuple

import pytest
import torch
import torch.nn.functional as F
from torch import nn
from torch.export import Dim
from transformers.activations import ACT2FN

import tensorrt_llm._torch.auto_deploy.custom_ops  # noqa: F401
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.models.custom.modeling_gemma4 import (
    Gemma4Config,
    Gemma4ForCausalLM,
    Gemma4ForConditionalGeneration,
    Gemma4MoEBlock,
    Gemma4RotaryEmbedding,
    Gemma4Router,
    Gemma4TextAttention,
    Gemma4TextConfig,
    Gemma4TextDecoderLayer,
    Gemma4TextMLP,
    Gemma4VisionConfig,
)

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def assert_rmse_close(
    actual: torch.Tensor,
    expected: torch.Tensor,
    rmse_ratio_tol: float,
    msg: str = "",
) -> None:
    """Assert that RMSE(actual - expected) / RMSE(expected) < rmse_ratio_tol."""
    diff = actual.float() - expected.float()
    rmse_diff = torch.sqrt(torch.mean(diff**2))
    rmse_ref = torch.sqrt(torch.mean(expected.float() ** 2))
    # NOTE(review): rmse_ref is zero for an all-zero reference, yielding an
    # inf/nan ratio; consider an eps guard if that case can occur.
    ratio = (rmse_diff / rmse_ref).item()
    assert ratio < rmse_ratio_tol, (
        f"{msg}RMSE ratio {ratio:.6f} exceeds tolerance {rmse_ratio_tol}. "
        f"(rmse_diff={rmse_diff.item():.6f}, rmse_ref={rmse_ref.item():.6f})"
    )


def _device_and_dtype() -> Tuple[str, torch.dtype]:
    """Prefer CUDA/bf16 when available; fall back to CPU/fp32."""
    if torch.cuda.is_available():
        return "cuda", torch.bfloat16
    return "cpu", torch.float32


def _small_text_config() -> Gemma4TextConfig:
    """Tiny Gemma4 text config exercising sliding + full(K=V) + MoE paths."""
    config = Gemma4TextConfig(
        vocab_size=256,
        hidden_size=64,
        intermediate_size=32,
        num_hidden_layers=3,
        num_attention_heads=4,
        num_key_value_heads=2,
        num_global_key_value_heads=1,
        head_dim=16,
        global_head_dim=32,
        hidden_activation="gelu_pytorch_tanh",
        max_position_embeddings=64,
        rms_norm_eps=1e-6,
        attention_bias=False,
        attention_dropout=0.0,
        attention_k_eq_v=True,
        sliding_window=16,
        layer_types=["sliding_attention", "sliding_attention", "full_attention"],
        enable_moe_block=True,
        num_experts=4,
        top_k_experts=2,
        expert_intermediate_size=16,
        final_logit_softcapping=30.0,
        hidden_size_per_layer_input=0,
        num_kv_shared_layers=0,
        use_double_wide_mlp=False,
        use_bidirectional_attention="vision",
        rope_parameters={
            "full_attention": {
                "rope_type": "proportional",
                "rope_theta": 1000000.0,
                "partial_rotary_factor": 0.25,
            },
            "sliding_attention": {
                "rope_type": "default",
                "rope_theta": 10000.0,
            },
        },
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=2,
        tie_word_embeddings=True,
    )
    config._attn_implementation = "eager"
    return config


def _position_ids(batch_size: int, seq_len: int, device: str) -> torch.Tensor:
    """[batch, seq] position ids 0..seq_len-1, broadcast over the batch."""
    return torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)


@pytest.fixture(autouse=True)
def _set_seed():
    # Deterministic inits/inputs for every test in this module.
    torch.manual_seed(42)


# ---------------------------------------------------------------------------
# Standalone HF-faithful reference implementations (pure PyTorch)
# These mirror the HuggingFace Gemma4 math exactly, using the same
# state_dict key names, so weights can be shared between AD and reference.
# ---------------------------------------------------------------------------


class _RefRMSNorm(nn.Module):
    """HF Gemma4RMSNorm: norm(x) * (weight + scale_shift)."""

    def __init__(
        self, dim: int, eps: float = 1e-6, scale_shift: float = 1.0, with_scale: bool = True
    ):
        super().__init__()
        self.eps = eps
        self.scale_shift = scale_shift
        self.with_scale = with_scale
        if with_scale:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.weight = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute in fp32, then cast back to the input dtype.
        normed = x.float() * torch.pow(x.float().pow(2).mean(-1, keepdim=True) + self.eps, -0.5)
        if self.weight is not None:
            normed = normed * (self.weight.float() + self.scale_shift)
        return normed.type_as(x)


def _ref_rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Standard RoPE helper: (-x2, x1) over the last-dim halves."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def _ref_apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, udim: int = 2):
    """Apply RoPE; ``udim`` is the head axis to unsqueeze cos/sin over."""
    cos = cos.unsqueeze(udim)
    sin = sin.unsqueeze(udim)
    return (x * cos) + (_ref_rotate_half(x) * sin)


def _ref_repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat KV heads n_rep times for GQA ([B, n, S, d] -> [B, n*n_rep, S, d])."""
    if n_rep == 1:
        return x
    b, n, s, d = x.shape
    return x[:, :, None, :, :].expand(b, n, n_rep, s, d).reshape(b, n * n_rep, s, d)


class _RefMLP(nn.Module):
    """HF Gemma4TextMLP reference: down(act(gate(x)) * up(x))."""

    def __init__(self, config: Gemma4TextConfig):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_activation]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class _RefAttention(nn.Module):
    """HF Gemma4TextAttention reference (eager, no cache, no shared-kv)."""

    def __init__(self, config: Gemma4TextConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
        self.use_k_eq_v = config.attention_k_eq_v and not self.is_sliding

        self.head_dim = (
            config.global_head_dim
            if (not self.is_sliding and config.global_head_dim)
            else config.head_dim
        )
        self.num_heads = config.num_attention_heads
        num_kv_heads = (
            config.num_global_key_value_heads if self.use_k_eq_v else config.num_key_value_heads
        )
        self.num_kv_heads = num_kv_heads
        self.num_kv_groups = self.num_heads // num_kv_heads
        # QK norms take the place of 1/sqrt(d) scaling (matches AD module).
        self.scaling = 1.0

        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, num_kv_heads * self.head_dim, bias=False)
        self.v_proj = (
            None
            if self.use_k_eq_v
            else nn.Linear(config.hidden_size, num_kv_heads * self.head_dim, bias=False)
        )
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
        self.q_norm = _RefRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = _RefRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.v_norm = _RefRMSNorm(self.head_dim, eps=config.rms_norm_eps, with_scale=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        B, S, _ = hidden_states.shape
        shape = (B, S, -1, self.head_dim)
        cos, sin = position_embeddings

        q = self.q_proj(hidden_states).view(shape)
        q = self.q_norm(q)
        q = _ref_apply_rotary(q, cos, sin, udim=2)
        q = q.transpose(1, 2)  # -> [B, num_heads, S, head_dim]

        k = self.k_proj(hidden_states).view(shape)
        v = self.v_proj(hidden_states).view(shape) if self.v_proj is not None else k
        k = self.k_norm(k)
        k = _ref_apply_rotary(k, cos, sin, udim=2)
        k = k.transpose(1, 2)
        v = self.v_norm(v)
        v = v.transpose(1, 2)

        # Eager attention with GQA repeat
        k = _ref_repeat_kv(k, self.num_kv_groups)
        v = _ref_repeat_kv(v, self.num_kv_groups)
        attn_w = torch.matmul(q, k.transpose(2, 3)) * self.scaling
        if attention_mask is not None:
            attn_w = attn_w + attention_mask
        attn_w = F.softmax(attn_w, dim=-1, dtype=torch.float32).to(q.dtype)
        out = torch.matmul(attn_w, v)
        out = out.transpose(1, 2).contiguous().reshape(B, S, -1)
        return self.o_proj(out)


class _RefRouter(nn.Module):
    """HF Gemma4Router reference."""

    def __init__(self, config: Gemma4TextConfig):
        super().__init__()
        self.proj = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.scale = nn.Parameter(torch.ones(config.hidden_size))
        # Inverse-sqrt(hidden) factor kept as a non-persistent buffer.
        self.register_buffer("root_size", torch.tensor(config.hidden_size**-0.5), persistent=False)
        self.eps = config.rms_norm_eps
        self.top_k = config.top_k_experts

    def forward(self, hidden_states: torch.Tensor):
        """Return renormalized top-k expert weights and indices per token."""
        normed = hidden_states.float()
        normed = normed * torch.rsqrt(normed.pow(2).mean(-1, keepdim=True) + self.eps)
        normed = normed.type_as(hidden_states)
        normed = (
            normed * self.root_size.to(hidden_states.dtype) * self.scale.to(hidden_states.dtype)
        )
        probs = F.softmax(self.proj(normed), dim=-1)
        topk_w, topk_i = torch.topk(probs, k=self.top_k, dim=-1)
        topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)
        return topk_w, topk_i


class _RefMoEBlock(nn.Module):
    """HF Gemma4MoEBlock reference with fused parameter layout."""

    def __init__(self, config: Gemma4TextConfig):
        super().__init__()
        self.num_experts = config.num_experts
        inter = config.expert_intermediate_size
        hidden = config.hidden_size
        # Fused [gate; up] per expert, plus a scalar per-expert output scale.
        self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, 2 * inter, hidden))
        self.down_proj = nn.Parameter(torch.zeros(self.num_experts, hidden, inter))
        self.per_expert_scale = nn.Parameter(torch.ones(self.num_experts))
        self.act_fn = ACT2FN[config.hidden_activation]

    def forward(self, hidden_states, top_k_index, top_k_weights):
        final = torch.zeros_like(hidden_states)
        # Routing bookkeeping carries no gradient.
        with torch.no_grad():
            expert_mask = F.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
            expert_hit = expert_mask.sum(dim=(-1, -2)).nonzero()
        # NOTE(review): indentation reconstructed — the expert loop is assumed
        # to sit OUTSIDE the no_grad block (matches the usual HF MoE pattern);
        # confirm against the original file.
        for eidx in expert_hit:
            eidx = eidx[0]
            top_k_pos, token_idx = torch.where(expert_mask[eidx])
            cur = hidden_states[token_idx]
            gate, up = F.linear(cur, self.gate_up_proj[eidx]).chunk(2, dim=-1)
            cur = self.act_fn(gate) * up
            cur = F.linear(cur, self.down_proj[eidx])
            cur = cur * self.per_expert_scale[eidx]
            cur = cur * top_k_weights[token_idx, top_k_pos, None]
            final.index_add_(0, token_idx, cur.to(final.dtype))
        return final


class _RefDecoderLayer(nn.Module):
    """HF Gemma4TextDecoderLayer reference (no cache/grad-ckpt)."""

    def __init__(self, config: Gemma4TextConfig, layer_idx: int):
        super().__init__()
        self.self_attn = _RefAttention(config, layer_idx)
        self.mlp = _RefMLP(config)
        self.input_layernorm = _RefRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = _RefRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.pre_feedforward_layernorm = _RefRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = _RefRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.register_buffer("layer_scalar", torch.ones(1))
        self.enable_moe_block = config.enable_moe_block
        if self.enable_moe_block:
            self.router = _RefRouter(config)
            self.moe = _RefMoEBlock(config)
            self.post_feedforward_layernorm_1 = _RefRMSNorm(
                config.hidden_size, eps=config.rms_norm_eps
            )
            self.post_feedforward_layernorm_2 = _RefRMSNorm(
                config.hidden_size, eps=config.rms_norm_eps
            )
            self.pre_feedforward_layernorm_2 = _RefRMSNorm(
                config.hidden_size, eps=config.rms_norm_eps
            )

    def forward(self, hidden_states, position_embeddings, attention_mask=None):
        residual = hidden_states
        h = self.input_layernorm(hidden_states)
        h = self.self_attn(h, position_embeddings, attention_mask=attention_mask)
        h = self.post_attention_layernorm(h)
        hidden_states = residual + h

        residual = hidden_states
        if self.enable_moe_block:
            # Dense branch
            h1 = self.pre_feedforward_layernorm(hidden_states)
            h1 = self.mlp(h1)
            h1 = self.post_feedforward_layernorm_1(h1)
            # MoE branch (flattened tokens)
            h_flat = hidden_states.reshape(-1, hidden_states.shape[-1])
            topk_w, topk_i = self.router(h_flat)
            h2 = self.pre_feedforward_layernorm_2(h_flat)
            h2 = self.moe(h2, topk_i, topk_w)
            h2 = h2.reshape(hidden_states.shape)
            h2 = self.post_feedforward_layernorm_2(h2)
            hidden_states = h1 + h2
        else:
            hidden_states = self.pre_feedforward_layernorm(hidden_states)
            hidden_states = self.mlp(hidden_states)

        # NOTE(review): indentation reconstructed — assumed to apply to both
        # branches, mirroring the AD decoder layer so equivalence holds.
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states
        hidden_states = hidden_states * self.layer_scalar
        return hidden_states


# ---------------------------------------------------------------------------
# Weight-transfer helpers
# ---------------------------------------------------------------------------


def _build_ref_rope(config: Gemma4TextConfig, layer_type: str, device, dtype):
    """Build reference cos/sin matching AD's Gemma4RotaryEmbedding."""
    rope = Gemma4RotaryEmbedding(config, layer_type).to(device)
    return rope


def _load_ref_into_ad(ad_module: nn.Module, ref_module: nn.Module):
    """Load reference state_dict into AD module (hooks handle weight conversion)."""
    missing, unexpected = ad_module.load_state_dict(ref_module.state_dict(), strict=False)
    # v_norm buffer (non-persistent) won't be in state_dict, that's expected
    allowed_missing = {"v_norm.weight"}
    real_missing = {k for k in missing if not any(k.endswith(s) for s in allowed_missing)}
    assert not real_missing, f"Unexpected missing keys: {real_missing}"
    assert not unexpected, f"Unexpected keys: {unexpected}"


# ---------------------------------------------------------------------------
# Tests — Block equivalence
# ---------------------------------------------------------------------------


def test_mlp_equivalence():
    """MLP block: identical math, should match exactly."""
    device, dtype = _device_and_dtype()
    config = _small_text_config()

    ref = _RefMLP(config).to(device=device, dtype=dtype).eval()
    ad = Gemma4TextMLP(config).to(device=device, dtype=dtype).eval()
    ad.load_state_dict(ref.state_dict())

    x = torch.randn(2, 8, config.hidden_size, device=device, dtype=dtype)
    with torch.no_grad():
        torch.testing.assert_close(ad(x), ref(x), rtol=1e-3, atol=1e-3)


def test_attention_sliding_equivalence():
    """Sliding attention (standard GQA) matches reference."""
    device, dtype = _device_and_dtype()
    config = _small_text_config()
    layer_idx = 0  # sliding

    ref = _RefAttention(config, layer_idx).to(device=device, dtype=dtype).eval()
    ad = Gemma4TextAttention(config, layer_idx).to(device=device, dtype=dtype).eval()
    _load_ref_into_ad(ad, ref)

    B, S = 2, 8
    x = torch.randn(B, S, config.hidden_size, device=device, dtype=dtype)
    pos_ids = _position_ids(B, S, device)
    rope = _build_ref_rope(config, "sliding_attention", device, dtype)
    cos, sin = rope(x, pos_ids)

    # Build causal mask for reference (AD uses is_causal=True internally)
    causal_mask = torch.triu(
        torch.full((S, S), float("-inf"), device=device, dtype=dtype), diagonal=1
    )
    causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)

    with torch.no_grad():
        ad_out = ad(x, (cos, sin))
        ref_out = ref(x, (cos, sin), attention_mask=causal_mask)
    assert_rmse_close(ad_out, ref_out, rmse_ratio_tol=0.10, msg="Sliding attention: ")


def test_attention_full_k_eq_v_equivalence():
    """Full attention with K=V and different head_dim matches reference."""
    device, dtype = _device_and_dtype()
    config = _small_text_config()
    layer_idx = 2  # full_attention

    ref = _RefAttention(config, layer_idx).to(device=device, dtype=dtype).eval()
    ad = Gemma4TextAttention(config, layer_idx).to(device=device, dtype=dtype).eval()
    _load_ref_into_ad(ad, ref)

    B, S = 2, 8
    x = torch.randn(B, S, config.hidden_size, device=device, dtype=dtype)
    pos_ids = _position_ids(B, S, device)
    rope = _build_ref_rope(config, "full_attention", device, dtype)
    cos, sin = rope(x, pos_ids)

    causal_mask = torch.triu(
        torch.full((S, S), float("-inf"), device=device, dtype=dtype), diagonal=1
    )
    causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)

    with torch.no_grad():
        ad_out = ad(x, (cos, sin))
        ref_out = ref(x, (cos, sin), attention_mask=causal_mask)
    assert_rmse_close(ad_out, ref_out, rmse_ratio_tol=0.10, msg="Full K=V attention: ")


def test_moe_block_equivalence():
    """MoE block (router + experts) matches reference with fused weight conversion."""
    device, dtype = _device_and_dtype()
    config = _small_text_config()

    ref_router = _RefRouter(config).to(device=device, dtype=dtype).eval()
    ref_moe = _RefMoEBlock(config).to(device=device, dtype=dtype).eval()
    # Initialize MoE fused params with random values (default is zeros → all-zero output)
    nn.init.normal_(ref_moe.gate_up_proj, std=0.02)
    nn.init.normal_(ref_moe.down_proj, std=0.02)
    nn.init.uniform_(ref_moe.per_expert_scale, 0.5, 1.5)

    ad_router = Gemma4Router(config).to(device=device, dtype=dtype).eval()
    ad_moe = Gemma4MoEBlock(config).to(device=device, dtype=dtype).eval()

    # Load router weights (same structure)
    ad_router.load_state_dict(ref_router.state_dict())
    # Load MoE weights (hook unfuses gate_up_proj + folds per_expert_scale)
    ad_moe.load_state_dict(ref_moe.state_dict(), strict=False)

    T = 16  # num tokens (flattened B*S)
    x = torch.randn(T, config.hidden_size, device=device, dtype=dtype)

    with torch.no_grad():
        ref_w, ref_i = ref_router(x)
        ad_w, ad_i = ad_router(x)
        # Router outputs should match exactly (same math, no custom ops)
        torch.testing.assert_close(ad_w, ref_w, rtol=1e-3, atol=1e-3)
        torch.testing.assert_close(ad_i, ref_i)

        ref_out = ref_moe(x, ref_i, ref_w)
        ad_out = ad_moe(x, ad_i, ad_w)

    assert_rmse_close(ad_out, ref_out, rmse_ratio_tol=0.02, msg="MoE block: ")


# ---------------------------------------------------------------------------
# Tests — Layer equivalence
# ---------------------------------------------------------------------------


def test_decoder_layer_equivalence():
    """Decoder layer (sliding + full) matches reference."""
    device, dtype = _device_and_dtype()
    config = _small_text_config()

    for layer_idx in [0, 2]:
        layer_type = config.layer_types[layer_idx]
        ref = _RefDecoderLayer(config, layer_idx).to(device=device, dtype=dtype).eval()
        ad = Gemma4TextDecoderLayer(config, layer_idx).to(device=device, dtype=dtype).eval()
        _load_ref_into_ad(ad, ref)

        B, S = 2, 8
        x = torch.randn(B, S, config.hidden_size, device=device, dtype=dtype)
        pos_ids = _position_ids(B, S, device)
        rope = _build_ref_rope(config, layer_type, device, dtype)
        cos, sin = rope(x, pos_ids)

        causal_mask = (
            torch.triu(torch.full((S, S), float("-inf"), device=device, dtype=dtype), diagonal=1)
            .unsqueeze(0)
            .unsqueeze(0)
        )

        with torch.no_grad():
            ad_out = ad(x, (cos, sin))
            ref_out = ref(x, (cos, sin), attention_mask=causal_mask)
        assert_rmse_close(
            ad_out, ref_out, rmse_ratio_tol=0.05, msg=f"Layer {layer_idx} ({layer_type}): "
        )


# ---------------------------------------------------------------------------
# Tests — Full model equivalence
# ---------------------------------------------------------------------------


def test_full_model_equivalence():
    """Full CausalLM logits match layer-by-layer reference with shared weights.

    We verify this by comparing two AD ForCausalLM models with identical weights.
    One is run normally; the other's output is verified through layer-by-layer
    reference comparison (already tested above). Here we confirm that the
    end-to-end model produces finite, deterministic logits with correct shape,
    and that two forward passes with the same input produce identical output.
    """
    device, dtype = _device_and_dtype()
    config = _small_text_config()

    ad = Gemma4ForCausalLM(config).to(device=device, dtype=dtype).eval()

    B, S = 2, 8
    input_ids = torch.randint(0, config.vocab_size, (B, S), device=device)
    pos_ids = _position_ids(B, S, device)

    with torch.no_grad():
        out1 = ad(input_ids=input_ids, position_ids=pos_ids)
        out2 = ad(input_ids=input_ids, position_ids=pos_ids)

    assert out1.logits.shape == (B, S, config.vocab_size)
    assert torch.isfinite(out1.logits).all()
    # Determinism: two identical passes must produce identical logits
    torch.testing.assert_close(out1.logits, out2.logits)


def test_conditional_generation_wrapper():
    """ConditionalGeneration wrapper loads and forwards correctly."""
    device, dtype = _device_and_dtype()
    config = Gemma4Config(
        text_config=_small_text_config(),
        vision_config=Gemma4VisionConfig(hidden_size=32),
    )
    model = Gemma4ForConditionalGeneration(config).to(device=device, dtype=dtype).eval()

    B, S = 2, 8
    input_ids = torch.randint(0, config.text_config.vocab_size, (B, S), device=device)
    pos_ids = _position_ids(B, S, device)

    with torch.no_grad():
        out = model(input_ids=input_ids, position_ids=pos_ids)
    assert out.logits is not None
    assert out.logits.shape == (B, S, config.text_config.vocab_size)
    assert torch.isfinite(out.logits).all()


# ---------------------------------------------------------------------------
# Tests — Export
# ---------------------------------------------------------------------------


def test_export():
    """Model can be exported with torch.export and produces correct output."""
    device = "cpu"
    dtype = torch.float32
    config = _small_text_config()
    config.enable_moe_block = False  # MoE expert dispatch uses data-dependent ops

    model = Gemma4ForCausalLM(config).to(device=device, dtype=dtype).eval()

    B, S = 2, 8
    input_ids = torch.randint(0, config.vocab_size, (B, S), device=device)
    pos_ids = _position_ids(B, S, device)

    batch_dim = Dim("batch", min=1, max=4)
    seq_dim = Dim("seq", min=1, max=64)
    dynamic_shapes = {
        "input_ids": {0: batch_dim, 1: seq_dim},
        "position_ids": {0: batch_dim, 1: seq_dim},
    }

    gm = torch_export_to_gm(
        model,
        args=(input_ids,),
        kwargs={"position_ids": pos_ids},
        dynamic_shapes=dynamic_shapes,
    )

    with torch.no_grad():
        exported_out = gm(input_ids, position_ids=pos_ids)

    logits = (
        exported_out[0]
        if isinstance(exported_out, tuple)
        else getattr(exported_out, "logits", exported_out)
    )
    assert torch.isfinite(logits).all(), "Export produced non-finite values"

    # Test different shape
    B2, S2 = 1, 4
    ids2 = torch.randint(0, config.vocab_size, (B2, S2), device=device)
    pos2 = _position_ids(B2, S2, device)
    with torch.no_grad():
        out2 = gm(ids2, position_ids=pos2)
    logits2 = out2[0] if isinstance(out2, tuple) else getattr(out2, "logits", out2)
    assert logits2.shape == (B2, S2, config.vocab_size)
    assert torch.isfinite(logits2).all()
From e7c06ff6ce216ed21ad5aedfef2d24594e4ca346 Mon Sep 17 00:00:00 2001
From: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Date: Thu, 2 Apr 2026 12:26:50 -0700
Subject: [None][fix] wire gemma4 custom tokenizer and factory

Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
---
 .../model_registry/configs/gemma4_moe.yaml    |  2 +
 .../models/custom/modeling_gemma4.py          | 95 ++++++++++++++++++-
 2 files changed, 93 insertions(+), 4 deletions(-)

diff --git a/examples/auto_deploy/model_registry/configs/gemma4_moe.yaml b/examples/auto_deploy/model_registry/configs/gemma4_moe.yaml
index 807c2b2cbb0..2f05ffe0c31 100644
--- a/examples/auto_deploy/model_registry/configs/gemma4_moe.yaml
+++
b/examples/auto_deploy/model_registry/configs/gemma4_moe.yaml @@ -4,4 +4,6 @@ # Gemma 4 MoE (26B total, 4B activated) — text-only AD export path. # Uses triton paged attention backend: supports head_dim=512 (global_head_dim), # paged KV cache, CUDA-graph-compatible, FlashDecoding for decode. +model_factory: Gemma4ForConditionalGeneration +tokenizer: google/gemma-4-31B-it attn_backend: triton_paged diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py index ad8d6a12f3d..ec847eda3f0 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py @@ -32,19 +32,23 @@ - Final logit softcapping """ +import json from dataclasses import dataclass -from typing import Optional, Tuple +from pathlib import Path +from typing import Any, Optional, Tuple import torch import torch.nn.functional as F +from tokenizers import Tokenizer from torch import nn -from transformers import AutoConfig, PretrainedConfig +from transformers import AutoConfig, PretrainedConfig, PreTrainedTokenizerFast from transformers.activations import ACT2FN from transformers.generation import GenerationMixin from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS from transformers.modeling_utils import PreTrainedModel -from transformers.utils import ModelOutput +from transformers.utils import ModelOutput, cached_file +from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactoryRegistry from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory from tensorrt_llm._torch.utils import ActivationType @@ -882,11 +886,94 @@ def forward( return Gemma4ConditionalOutput(logits=logits) +# --------------------------------------------------------------------------- +# Wrapper tokenizer for Gemma 4 +# +# The upstream HF checkpoint ships ``extra_special_tokens`` as a *list* in +# tokenizer_config.json, which 
is incompatible with transformers <5.3. +# This thin wrapper loads the tokenizer assets directly, bypassing the +# problematic codepath. +# --------------------------------------------------------------------------- + +_TOKENIZER_CONFIG_FILE = "tokenizer_config.json" +_CHAT_TEMPLATE_FILE = "chat_template.jinja" +_TOKENIZER_FILE = "tokenizer.json" + + +class ADGemma4Tokenizer(PreTrainedTokenizerFast): + """Wrapper that loads the upstream Gemma 4 tokenizer on current transformers.""" + + vocab_files_names = {"tokenizer_file": _TOKENIZER_FILE} + model_input_names = ["input_ids", "attention_mask"] + slow_tokenizer_class = None + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str | Path, + *inputs, + **kwargs, + ) -> "ADGemma4Tokenizer": + del inputs + for k in ("_from_auto", "_commit_hash", "trust_remote_code"): + kwargs.pop(k, None) + + config_path = cached_file(pretrained_model_name_or_path, _TOKENIZER_CONFIG_FILE, **kwargs) + assert config_path is not None + config = json.loads(Path(config_path).read_text()) + + tokenizer_file = cached_file(pretrained_model_name_or_path, _TOKENIZER_FILE, **kwargs) + assert tokenizer_file is not None + + # ``extra_special_tokens`` is a list in the upstream config; map it to + # the standard ``additional_special_tokens`` field. 
+ extra = config.get("extra_special_tokens", []) + if isinstance(extra, list): + additional = extra + else: + additional = list(extra.keys()) if isinstance(extra, dict) else [] + + tokenizer = cls( + tokenizer_object=Tokenizer.from_file(tokenizer_file), + name_or_path=str(pretrained_model_name_or_path), + bos_token=config.get("bos_token"), + eos_token=config.get("eos_token"), + unk_token=config.get("unk_token"), + pad_token=config.get("pad_token"), + additional_special_tokens=additional, + clean_up_tokenization_spaces=config.get("clean_up_tokenization_spaces", False), + model_max_length=config.get("model_max_length"), + padding_side=config.get("padding_side", "left"), + truncation_side=config.get("truncation_side", "left"), + ) + + template_path = cached_file( + pretrained_model_name_or_path, + _CHAT_TEMPLATE_FILE, + _raise_exceptions_for_missing_entries=False, + **kwargs, + ) + if template_path is not None: + tokenizer.chat_template = Path(template_path).read_text() + + return tokenizer + + +@ModelFactoryRegistry.register("Gemma4ForConditionalGeneration") +class Gemma4ForConditionalGenerationFactory(AutoModelForCausalLMFactory): + """Factory that wires the wrapper tokenizer for Gemma 4.""" + + def init_tokenizer(self) -> Optional[Any]: + if self.tokenizer is None: + return None + return ADGemma4Tokenizer.from_pretrained(self.tokenizer) + + # --------------------------------------------------------------------------- # Registration # --------------------------------------------------------------------------- AutoModelForCausalLMFactory.register_custom_model_cls("Gemma4TextConfig", Gemma4ForCausalLM) -AutoModelForCausalLMFactory.register_custom_model_cls( +Gemma4ForConditionalGenerationFactory.register_custom_model_cls( "Gemma4Config", Gemma4ForConditionalGeneration ) From aecd7c6e986cfd173ebbe3a1c923814189b1e000 Mon Sep 17 00:00:00 2001 From: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com> Date: Thu, 2 Apr 2026 12:32:26 -0700 Subject: 
[PATCH 04/16] [None][feat] add gemma4 registry entry and triton layout fix Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com> --- examples/auto_deploy/model_registry/models.yaml | 3 +++ .../custom_ops/attention/triton_paged_attention.py | 8 +------- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/examples/auto_deploy/model_registry/models.yaml b/examples/auto_deploy/model_registry/models.yaml index 6e65ebee22d..f860f754b05 100644 --- a/examples/auto_deploy/model_registry/models.yaml +++ b/examples/auto_deploy/model_registry/models.yaml @@ -308,6 +308,9 @@ models: yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml', 'multimodal.yaml'] - name: google/gemma-3n-E4B-it yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml', 'multimodal.yaml'] +# --- Gemma 4 (2026) - MoE with K=V attention --- +- name: google/gemma-4-31B-it + yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml', 'gemma4_moe.yaml'] # --- JetBrains Mellum (Apr 2025) - code specialist --- - name: JetBrains/Mellum-4b-sft-all yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml'] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py index b33e70405c2..208ca3d8afb 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py @@ -1151,13 +1151,7 @@ def get_cache_initializers( @classmethod def get_constants(cls, source_attn_node: Node) -> List[Constant]: - layout = source_attn_node.kwargs.get("layout", None) - if ( - layout is None - and len(source_attn_node.args) > 0 - and isinstance(source_attn_node.args[-1], str) - ): - layout = source_attn_node.args[-1] + (layout,) = extract_op_args(source_attn_node, "layout") if layout != "bsnd": raise RuntimeError( f"Expected torch_attention layout='bsnd' but got 
{layout!r} " From b11f2d451be28704db673c5f0ea903870c22b917 Mon Sep 17 00:00:00 2001 From: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com> Date: Thu, 2 Apr 2026 12:57:09 -0700 Subject: [PATCH 05/16] [None][fix] support gemma4 a4b e2e config and moe gelu Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com> --- .../model_registry/configs/gemma4_moe.yaml | 2 +- .../auto_deploy/model_registry/models.yaml | 2 +- .../custom_ops/fused_moe/trtllm_moe.py | 31 +++++++++++++------ 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/examples/auto_deploy/model_registry/configs/gemma4_moe.yaml b/examples/auto_deploy/model_registry/configs/gemma4_moe.yaml index 2f05ffe0c31..4ef2f113efe 100644 --- a/examples/auto_deploy/model_registry/configs/gemma4_moe.yaml +++ b/examples/auto_deploy/model_registry/configs/gemma4_moe.yaml @@ -5,5 +5,5 @@ # Uses triton paged attention backend: supports head_dim=512 (global_head_dim), # paged KV cache, CUDA-graph-compatible, FlashDecoding for decode. 
model_factory: Gemma4ForConditionalGeneration -tokenizer: google/gemma-4-31B-it +tokenizer: google/gemma-4-26B-A4B-it attn_backend: triton_paged diff --git a/examples/auto_deploy/model_registry/models.yaml b/examples/auto_deploy/model_registry/models.yaml index f860f754b05..7a55536f6d6 100644 --- a/examples/auto_deploy/model_registry/models.yaml +++ b/examples/auto_deploy/model_registry/models.yaml @@ -309,7 +309,7 @@ models: - name: google/gemma-3n-E4B-it yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml', 'multimodal.yaml'] # --- Gemma 4 (2026) - MoE with K=V attention --- -- name: google/gemma-4-31B-it +- name: google/gemma-4-26B-A4B-it yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml', 'gemma4_moe.yaml'] # --- JetBrains Mellum (Apr 2025) - code specialist --- - name: JetBrains/Mellum-4b-sft-all diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py index 39d1359dece..d108eb03abf 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py @@ -401,12 +401,15 @@ def trtllm_moe_fused( activation_type = ActivationType.Swiglu if is_gated_mlp: - # Gated MLP uses Silu: silu(x @ w1.T) * (x @ w3.T) + # Gated MLP accepts either SiLU/SwiGLU or GELU/GEGLU style gating. if act_fn in [ActivationType.Silu, ActivationType.Swiglu]: activation_type = ActivationType.Swiglu + elif act_fn in [ActivationType.Gelu, ActivationType.Geglu]: + activation_type = ActivationType.Geglu else: raise ValueError( - f"Unsupported activation '{ActivationType(act_fn).name}' for gated_mlp. Use 'silu'." + f"Unsupported activation '{ActivationType(act_fn).name}' for gated_mlp. " + "Use 'silu' or 'gelu'." 
) else: # For non-gated MLP with ReLU^2 @@ -466,14 +469,24 @@ def trtllm_moe_fused_fake( def _validate_mlp_style_and_act_fn(is_gated_mlp: bool, act_fn: int) -> None: - assert (is_gated_mlp and act_fn in [ActivationType.Silu, ActivationType.Swiglu]) or ( - not is_gated_mlp and act_fn in [ActivationType.Relu2, ActivationType.Silu] - ), ( + assert ( + is_gated_mlp + and act_fn + in [ActivationType.Silu, ActivationType.Swiglu, ActivationType.Gelu, ActivationType.Geglu] + ) or (not is_gated_mlp and act_fn in [ActivationType.Relu2, ActivationType.Silu]), ( f"Unsupported combination: is_gated_mlp='{is_gated_mlp}', act_fn='{act_fn}'. " - f"Supported combinations: gated mlp with silu or mlp with relu2 or silu." + f"Supported combinations: gated mlp with silu or gelu, or mlp with relu2 or silu." ) +def _normalize_trtllm_act_fn(act_fn: int) -> int: + if act_fn == ActivationType.Silu: + return ActivationType.Swiglu + if act_fn == ActivationType.Gelu: + return ActivationType.Geglu + return act_fn + + @torch.library.custom_op("auto_deploy::trtllm_quant_fp8_moe_fused", mutates_args=()) def trtllm_quant_fp8_moe_fused( x: torch.Tensor, @@ -521,7 +534,7 @@ def trtllm_quant_fp8_moe_fused( """ _validate_mlp_style_and_act_fn(is_gated_mlp, act_fn) - act_fn = ActivationType.Swiglu if act_fn == ActivationType.Silu else act_fn + act_fn = _normalize_trtllm_act_fn(act_fn) # Store original shape and flatten to 2D x_shape = x.shape @@ -663,7 +676,7 @@ def trtllm_quant_nvfp4_moe_fused( assert fc2_weight_blockscale_fp8.ndim == 3, "fc2_weight_blockscale_fp8 must be 3D" _validate_mlp_style_and_act_fn(is_gated_mlp, act_fn) - act_fn = ActivationType.Swiglu if act_fn == ActivationType.Silu else act_fn + act_fn = _normalize_trtllm_act_fn(act_fn) # quant_scales is described by this code: # https://github.com/NVIDIA/TensorRT-LLM/blob/c9771ebb997683c08b26bbba796a7fc6aff09d93/cpp/tensorrt_llm/thop/moeOp.cpp#L1015 @@ -798,7 +811,7 @@ def trtllm_quant_finegrained_fp8_moe_fused( Output tensor of shape (B, H) 
or (B, S, H) """ _validate_mlp_style_and_act_fn(is_gated_mlp, act_fn) - act_fn = ActivationType.Swiglu if act_fn == ActivationType.Silu else act_fn + act_fn = _normalize_trtllm_act_fn(act_fn) x_shape = x.shape x2d = x.view(-1, x_shape[-1]) From 7690df4763955ec81bc224879225660b84eee18b Mon Sep 17 00:00:00 2001 From: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com> Date: Thu, 2 Apr 2026 15:31:17 -0700 Subject: [PATCH 06/16] fix weight loading Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com> --- .../models/custom/modeling_gemma4.py | 64 +++++++++++++------ 1 file changed, 43 insertions(+), 21 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py index ec847eda3f0..f4b71ccce9a 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py @@ -391,27 +391,6 @@ def __init__(self, config: Gemma4TextConfig): for _ in range(config.num_experts) ] ) - self._register_load_state_dict_pre_hook(self._unfuse_moe_weights) - - def _unfuse_moe_weights(self, state_dict, prefix, *_args, **_kwargs): - """Convert fused checkpoint MoE weights to per-expert format.""" - gate_up_key = prefix + "gate_up_proj" - down_key = prefix + "down_proj" - scale_key = prefix + "per_expert_scale" - - if gate_up_key not in state_dict: - return - - gate_up = state_dict.pop(gate_up_key) # [E, 2*I, H] - down = state_dict.pop(down_key) # [E, H, I] - scale = state_dict.pop(scale_key) # [E] - - inter = self.intermediate_size - for e in range(self.num_experts): - state_dict[f"{prefix}experts.{e}.gate_proj.weight"] = gate_up[e, :inter, :] - state_dict[f"{prefix}experts.{e}.up_proj.weight"] = gate_up[e, inter:, :] - # Fold per_expert_scale into down_proj - state_dict[f"{prefix}experts.{e}.down_proj.weight"] = down[e] * scale[e] def forward( self, @@ -528,6 +507,8 @@ 
class Gemma4TextDecoderLayer(nn.Module): def __init__(self, config: Gemma4TextConfig, layer_idx: int): super().__init__() self.layer_idx = layer_idx + self.num_experts = config.num_experts + self.expert_intermediate_size = config.expert_intermediate_size self.attention_type = config.layer_types[layer_idx] self.self_attn = Gemma4TextAttention(config, layer_idx) self.mlp = Gemma4TextMLP(config) @@ -541,6 +522,7 @@ def __init__(self, config: Gemma4TextConfig, layer_idx: int): if self.enable_moe_block: self.router = Gemma4Router(config) self.moe = Gemma4MoEBlock(config) + self._register_load_state_dict_pre_hook(self._unfuse_moe_weights) self.post_feedforward_layernorm_1 = Gemma4RMSNorm( config.hidden_size, eps=config.rms_norm_eps ) @@ -551,6 +533,46 @@ def __init__(self, config: Gemma4TextConfig, layer_idx: int): config.hidden_size, eps=config.rms_norm_eps ) + def _unfuse_moe_weights(self, state_dict, prefix, *_args, **_kwargs): + """Convert layer-level fused Gemma4 MoE checkpoint weights to per-expert weights.""" + candidates = [ + ( + prefix + "experts.gate_up_proj", + prefix + "experts.down_proj", + prefix + "router.per_expert_scale", + ), + ( + prefix + "moe.gate_up_proj", + prefix + "moe.down_proj", + prefix + "moe.per_expert_scale", + ), + ] + + gate_up_key = down_key = scale_key = None + for gate_up_candidate, down_candidate, scale_candidate in candidates: + if ( + gate_up_candidate in state_dict + and down_candidate in state_dict + and scale_candidate in state_dict + ): + gate_up_key = gate_up_candidate + down_key = down_candidate + scale_key = scale_candidate + break + + if gate_up_key is None or down_key is None or scale_key is None: + return + + gate_up = state_dict.pop(gate_up_key) # [E, 2*I, H] + down = state_dict.pop(down_key) # [E, H, I] + scale = state_dict.pop(scale_key) # [E] + + inter = self.expert_intermediate_size + for e in range(self.num_experts): + state_dict[f"{prefix}moe.experts.{e}.gate_proj.weight"] = gate_up[e, :inter, :] + 
state_dict[f"{prefix}moe.experts.{e}.up_proj.weight"] = gate_up[e, inter:, :] + state_dict[f"{prefix}moe.experts.{e}.down_proj.weight"] = down[e] * scale[e] + def forward( self, hidden_states: torch.Tensor, From 4abc677cb2f460d4ed9b6642dfe7c2c7e63d72d1 Mon Sep 17 00:00:00 2001 From: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> Date: Thu, 2 Apr 2026 18:48:24 -0700 Subject: [PATCH 07/16] [None][fix] fix gemma4 RMSNorm scale_shift, add base model registry and tests - Remove incorrect +1.0 scale_shift from Gemma4RMSNorm. HF transformers 5.5.0 stores effective norm weights directly in the checkpoint; the previous implementation incorrectly added 1.0 at load time, causing compounding numerical drift across layers and garbled generation. - Add google/gemma-4-26B-A4B base model registry entry with gemma4_moe_base.yaml config. - Strengthen test_full_model_equivalence with end-to-end logits comparison against standalone reference model. - Add export functional equivalence assertion (pre-export vs post-export). - Update reference _RefRMSNorm to match corrected norm semantics. - Update MoE block test to manually unfuse weights (hook now on decoder layer, not MoE block). Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> --- .../configs/gemma4_moe_base.yaml | 9 ++ .../auto_deploy/model_registry/models.yaml | 2 + .../models/custom/modeling_gemma4.py | 16 +-- .../singlegpu/models/test_gemma4_modeling.py | 113 ++++++++++++++---- 4 files changed, 105 insertions(+), 35 deletions(-) create mode 100644 examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml diff --git a/examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml b/examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml new file mode 100644 index 00000000000..dd4a5f87fcb --- /dev/null +++ b/examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# Gemma 4 MoE base (26B total, 4B activated) — text-only AD export path. +# Uses triton paged attention backend: supports head_dim=512 (global_head_dim), +# paged KV cache, CUDA-graph-compatible, FlashDecoding for decode. +model_factory: Gemma4ForConditionalGeneration +tokenizer: google/gemma-4-26B-A4B +attn_backend: triton_paged diff --git a/examples/auto_deploy/model_registry/models.yaml b/examples/auto_deploy/model_registry/models.yaml index 7a55536f6d6..921e1223651 100644 --- a/examples/auto_deploy/model_registry/models.yaml +++ b/examples/auto_deploy/model_registry/models.yaml @@ -309,6 +309,8 @@ models: - name: google/gemma-3n-E4B-it yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml', 'multimodal.yaml'] # --- Gemma 4 (2026) - MoE with K=V attention --- +- name: google/gemma-4-26B-A4B + yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml', 'gemma4_moe_base.yaml'] - name: google/gemma-4-26B-A4B-it yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml', 'gemma4_moe.yaml'] # --- JetBrains Mellum (Apr 2025) - code specialist --- diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py index f4b71ccce9a..dad6947a4d8 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py @@ -250,10 +250,10 @@ def _build_rope_cache( class Gemma4RMSNorm(nn.Module): - """RMSNorm with Gemma4-style (weight + scale_shift) semantics. + """RMSNorm matching HF Gemma4 (transformers >= 5.5). - For AD export, we store the *effective* weight = checkpoint_weight + scale_shift - via a load hook, then use the standard torch_rmsnorm op. + The checkpoint stores effective weights directly — no ``+1.0`` offset. + Uses the ``torch_rmsnorm`` canonical op for AD transform compatibility. 
""" def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True): @@ -264,16 +264,6 @@ def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True): self.weight = nn.Parameter(torch.ones(dim)) else: self.register_buffer("weight", torch.ones(dim), persistent=False) - if with_scale: - self._register_load_state_dict_pre_hook(self._apply_scale_shift) - - @staticmethod - def _apply_scale_shift(state_dict, prefix, *_args, **_kwargs): - """Gemma4 RMSNorm stores weight that is used as (weight + 1.0). - Convert to effective weight at load time so torch_rmsnorm works directly.""" - key = prefix + "weight" - if key in state_dict: - state_dict[key] = state_dict[key] + 1.0 def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.ops.auto_deploy.torch_rmsnorm(x, self.weight, self.eps) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma4_modeling.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma4_modeling.py index 689fe2f9cd9..bb0b5566b2c 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma4_modeling.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma4_modeling.py @@ -124,14 +124,11 @@ def _set_seed(): class _RefRMSNorm(nn.Module): - """HF Gemma4RMSNorm: norm(x) * (weight + scale_shift).""" + """HF Gemma4RMSNorm (transformers>=5.5): norm(x) * weight.""" - def __init__( - self, dim: int, eps: float = 1e-6, scale_shift: float = 1.0, with_scale: bool = True - ): + def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True): super().__init__() self.eps = eps - self.scale_shift = scale_shift self.with_scale = with_scale if with_scale: self.weight = nn.Parameter(torch.ones(dim)) @@ -141,7 +138,7 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: normed = x.float() * torch.pow(x.float().pow(2).mean(-1, keepdim=True) + self.eps, -0.5) if self.weight is not None: - normed = normed * (self.weight.float() + self.scale_shift) + 
normed = normed * self.weight.float() return normed.type_as(x) @@ -467,8 +464,19 @@ def test_moe_block_equivalence(): # Load router weights (same structure) ad_router.load_state_dict(ref_router.state_dict()) - # Load MoE weights (hook unfuses gate_up_proj + folds per_expert_scale) - ad_moe.load_state_dict(ref_moe.state_dict(), strict=False) + # Manually unfuse ref MoE fused weights into per-expert format + # (The unfusing hook is on the decoder layer, not the MoE block) + ref_sd = ref_moe.state_dict() + gate_up = ref_sd["gate_up_proj"] # [E, 2*I, H] + down = ref_sd["down_proj"] # [E, H, I] + scale = ref_sd["per_expert_scale"] # [E] + inter = config.expert_intermediate_size + ad_sd = {} + for e in range(config.num_experts): + ad_sd[f"experts.{e}.gate_proj.weight"] = gate_up[e, :inter, :] + ad_sd[f"experts.{e}.up_proj.weight"] = gate_up[e, inter:, :] + ad_sd[f"experts.{e}.down_proj.weight"] = down[e] * scale[e] + ad_moe.load_state_dict(ad_sd) T = 16 # num tokens (flattened B*S) x = torch.randn(T, config.hidden_size, device=device, dtype=dtype) @@ -527,32 +535,90 @@ def test_decoder_layer_equivalence(): # --------------------------------------------------------------------------- +class _RefForCausalLM(nn.Module): + """Standalone reference CausalLM for full-model equivalence testing.""" + + def __init__(self, config: Gemma4TextConfig): + super().__init__() + self.config = config + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + self.embed_scale = config.hidden_size**0.5 + self.layers = nn.ModuleList( + [_RefDecoderLayer(config, i) for i in range(config.num_hidden_layers)] + ) + self.norm = _RefRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + # Tie weights like AD model (tie_word_embeddings=True) + if config.tie_word_embeddings: + self.lm_head.weight = self.embed_tokens.weight + + def forward(self, input_ids, position_ids): + hidden_states 
= self.embed_tokens(input_ids) * self.embed_scale + for i, layer in enumerate(self.layers): + layer_type = self.config.layer_types[i] + rope = _build_ref_rope( + self.config, layer_type, hidden_states.device, hidden_states.dtype + ) + cos, sin = rope(hidden_states, position_ids) + causal_mask = ( + torch.triu( + torch.full( + (hidden_states.shape[1], hidden_states.shape[1]), + float("-inf"), + device=hidden_states.device, + dtype=hidden_states.dtype, + ), + diagonal=1, + ) + .unsqueeze(0) + .unsqueeze(0) + ) + hidden_states = layer(hidden_states, (cos, sin), attention_mask=causal_mask) + hidden_states = self.norm(hidden_states) + logits = self.lm_head(hidden_states) + if self.config.final_logit_softcapping is not None: + logits = logits / self.config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * self.config.final_logit_softcapping + return logits + + +def _transfer_ref_to_ad_full_model(ad_model: Gemma4ForCausalLM, ref_model: _RefForCausalLM) -> None: + """Transfer weights from reference full model into AD ForCausalLM.""" + ref_sd = ref_model.state_dict() + ad_sd = {} + for k, v in ref_sd.items(): + if k.startswith("lm_head."): + ad_sd[k] = v + else: + ad_sd[f"model.{k}"] = v + missing, unexpected = ad_model.load_state_dict(ad_sd, strict=False) + # v_norm buffers are non-persistent, expected missing + real_missing = {m for m in missing if "v_norm" not in m} + assert not real_missing, f"Missing keys: {real_missing}" + assert not unexpected, f"Unexpected keys: {unexpected}" + + def test_full_model_equivalence(): - """Full CausalLM logits match layer-by-layer reference with shared weights. - - We verify this by comparing two AD ForCausalLM models with identical weights. - One is run normally; the other's output is verified through layer-by-layer - reference comparison (already tested above). 
Here we confirm that the - end-to-end model produces finite, deterministic logits with correct shape, - and that two forward passes with the same input produce identical output. - """ + """Full CausalLM logits match standalone reference with shared weights.""" device, dtype = _device_and_dtype() config = _small_text_config() + ref = _RefForCausalLM(config).to(device=device, dtype=dtype).eval() ad = Gemma4ForCausalLM(config).to(device=device, dtype=dtype).eval() + _transfer_ref_to_ad_full_model(ad, ref) B, S = 2, 8 input_ids = torch.randint(0, config.vocab_size, (B, S), device=device) pos_ids = _position_ids(B, S, device) with torch.no_grad(): - out1 = ad(input_ids=input_ids, position_ids=pos_ids) - out2 = ad(input_ids=input_ids, position_ids=pos_ids) + ref_logits = ref(input_ids, pos_ids) + ad_out = ad(input_ids=input_ids, position_ids=pos_ids) - assert out1.logits.shape == (B, S, config.vocab_size) - assert torch.isfinite(out1.logits).all() - # Determinism: two identical passes must produce identical logits - torch.testing.assert_close(out1.logits, out2.logits) + assert ad_out.logits.shape == (B, S, config.vocab_size) + assert torch.isfinite(ad_out.logits).all() + assert_rmse_close(ad_out.logits, ref_logits, rmse_ratio_tol=0.05, msg="Full model: ") def test_conditional_generation_wrapper(): @@ -608,6 +674,7 @@ def test_export(): ) with torch.no_grad(): + pre_export_out = model(input_ids=input_ids, position_ids=pos_ids) exported_out = gm(input_ids, position_ids=pos_ids) logits = ( @@ -616,6 +683,8 @@ def test_export(): else getattr(exported_out, "logits", exported_out) ) assert torch.isfinite(logits).all(), "Export produced non-finite values" + # Exported graph should produce identical output to the original model + torch.testing.assert_close(logits, pre_export_out.logits, rtol=1e-3, atol=1e-3) # Test different shape B2, S2 = 1, 4 From 077adf13974d6ab53cbe62cb3e15056b9e0a4f26 Mon Sep 17 00:00:00 2001 From: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> 
Date: Thu, 2 Apr 2026 21:39:26 -0700 Subject: [PATCH 08/16] [None][feat] update gemma4 configs with piecewise cudagraphs and chunked prefill Add piecewise CUDA graph compilation, expanded batch sizes, chunked prefill, and KV cache config to both gemma4_moe.yaml and gemma4_moe_base.yaml. Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> --- .../model_registry/configs/gemma4_moe.yaml | 12 ++++++++++++ .../model_registry/configs/gemma4_moe_base.yaml | 12 ++++++++++++ 2 files changed, 24 insertions(+) diff --git a/examples/auto_deploy/model_registry/configs/gemma4_moe.yaml b/examples/auto_deploy/model_registry/configs/gemma4_moe.yaml index 4ef2f113efe..6a6b5967c3e 100644 --- a/examples/auto_deploy/model_registry/configs/gemma4_moe.yaml +++ b/examples/auto_deploy/model_registry/configs/gemma4_moe.yaml @@ -7,3 +7,15 @@ model_factory: Gemma4ForConditionalGeneration tokenizer: google/gemma-4-26B-A4B-it attn_backend: triton_paged +compile_backend: torch-cudagraph +cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] +max_num_tokens: 8192 +max_batch_size: 512 +max_seq_len: 8192 +enable_chunked_prefill: true +kv_cache_config: + enable_block_reuse: false + free_gpu_memory_fraction: 0.8 +transforms: + compile_model: + piecewise_enabled: true diff --git a/examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml b/examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml index dd4a5f87fcb..abfc0a02078 100644 --- a/examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml +++ b/examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml @@ -7,3 +7,15 @@ model_factory: Gemma4ForConditionalGeneration tokenizer: google/gemma-4-26B-A4B attn_backend: triton_paged +compile_backend: torch-cudagraph +cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] +max_num_tokens: 8192 +max_batch_size: 512 +max_seq_len: 8192 +enable_chunked_prefill: true +kv_cache_config: + enable_block_reuse: false + free_gpu_memory_fraction: 
0.8 +transforms: + compile_model: + piecewise_enabled: true From b73c8ebdb139a749375de47492f46134507be46d Mon Sep 17 00:00:00 2001 From: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com> Date: Thu, 2 Apr 2026 22:33:00 -0700 Subject: [PATCH 09/16] add cookbook and update support matrix Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com> --- docs/source/models/supported-models.md | 2 + .../cookbooks/gemma_4_trtllm_cookbook.ipynb | 299 ++++++++++++++++++ 2 files changed, 301 insertions(+) create mode 100644 examples/auto_deploy/cookbooks/gemma_4_trtllm_cookbook.ipynb diff --git a/docs/source/models/supported-models.md b/docs/source/models/supported-models.md index 0492533a798..283eca924eb 100644 --- a/docs/source/models/supported-models.md +++ b/docs/source/models/supported-models.md @@ -13,6 +13,7 @@ The following is a table of supported models for the PyTorch backend: | `Exaone4ForCausalLM` | EXAONE 4.0 | `LGAI-EXAONE/EXAONE-4.0-32B` | | `ExaoneMoEForCausalLM` | K-EXAONE | `LGAI-EXAONE/K-EXAONE-236B-A23B` | | `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it` | +| `Gemma4ForConditionalGeneration` [^7]| Gemma 4 | `google/gemma-4-26B-A4B-it` | | `Glm4MoeForCausalLM` | GLM-4.5, GLM-4.6, GLM-4.7 | `THUDM/GLM-4-100B-A10B` | | `Glm4MoeLiteForCausalLM` [^6] | GLM-4.7-Flash | `zai-org/GLM-4.7-Flash` | | `GlmMoeDsaForCausalLM` | GLM-5 | `zai-org/GLM-5` | @@ -60,6 +61,7 @@ Note: Support for other models may vary. Features marked "N/A" are not applicabl [^4]: Overlap scheduler isn't supported when using EAGLE-3(Two Model Engine) for GPT-OSS. [^5]: Supported via the [AutoDeploy](../features/auto_deploy/auto-deploy.md) backend. See [AD config](../../../examples/auto_deploy/model_registry/configs/qwen3.5_moe_400b.yaml). [^6]: Supported via the [AutoDeploy](../features/auto_deploy/auto-deploy.md) backend. See [AD config](../../../examples/auto_deploy/model_registry/configs/glm-4.7-flash.yaml). 
+[^7]: Text-only support via the [AutoDeploy](../features/auto_deploy/auto-deploy.md) backend. See [AD config](../../../examples/auto_deploy/model_registry/configs/gemma4_moe.yaml). # Multimodal Feature Support Matrix (PyTorch Backend) diff --git a/examples/auto_deploy/cookbooks/gemma_4_trtllm_cookbook.ipynb b/examples/auto_deploy/cookbooks/gemma_4_trtllm_cookbook.ipynb new file mode 100644 index 00000000000..4286cf65a1b --- /dev/null +++ b/examples/auto_deploy/cookbooks/gemma_4_trtllm_cookbook.ipynb @@ -0,0 +1,299 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Deploying Gemma 4 MoE with TensorRT-LLM (AutoDeploy)\n", + "\n", + "This notebook walks you through serving **Gemma 4** (26B total, 4B activated MoE) with TensorRT-LLM using the **AutoDeploy** backend—same pattern as the Mistral Small 4 and GLM-4.7-Flash cookbooks in this folder.\n", + "\n", + "[TensorRT-LLM](https://nvidia.github.io/TensorRT-LLM/) is NVIDIA's open-source library for accelerating LLM inference on NVIDIA GPUs. AutoDeploy uses Hugging Face `transformers` modeling code and TensorRT-LLM graph transforms. 
See the [AutoDeploy guide](https://nvidia.github.io/TensorRT-LLM/torch/auto_deploy/auto-deploy.html).\n", + "\n", + "**Model resources:**\n", + "- [Gemma 4 collection (Hugging Face)](https://huggingface.co/collections/google/gemma-4)\n", + "- Instruction-tuned MoE: [`google/gemma-4-26B-A4B-it`](https://huggingface.co/google/gemma-4-26B-A4B-it)\n", + "- Base MoE (no chat template): [`google/gemma-4-26B-A4B`](https://huggingface.co/google/gemma-4-26B-A4B)\n", + "\n", + "**Bundled AutoDeploy YAML (this branch):**\n", + "- **Instruction:** `examples/auto_deploy/model_registry/configs/gemma4_moe.yaml` — text-only export path; `attn_backend: triton_paged` (head_dim 512 / paged KV, CUDA-graph friendly).\n", + "- **Base:** `examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml` — same stack for the base checkpoint.\n", + "\n", + "`trtllm-serve` takes **one** YAML path via `--extra_llm_api_options` (or `--config`). The bundled MoE YAMLs omit `world_size`; add it (or copy the YAML and edit) so it matches your GPU count when you use tensor parallel or multi-GPU loading.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prerequisites and environment\n", + "\n", + "Run TensorRT-LLM in a GPU container, for example:\n", + "\n", + "```shell\n", + "docker run --rm -it --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --gpus=all -p 8000:8000 nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc1\n", + "```\n", + "\n", + "Use a TensorRT-LLM checkout that includes Gemma 4 AutoDeploy support (model card, tokenizer, and any required bridges should match your branch).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# If pip is not available\n", + "!python -m ensurepip --default-pip" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install torch openai" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + 
"source": [ + "## Verify GPU\n", + "\n", + "Confirm CUDA and visible devices before starting the server.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Python: 3.12.3 (main, Jan 22 2026, 20:57:42) [GCC 13.3.0]\n", + "CUDA available: True\n", + "Num GPUs: 8\n", + "GPU[0]: NVIDIA H100 80GB HBM3\n", + "GPU[1]: NVIDIA H100 80GB HBM3\n", + "GPU[2]: NVIDIA H100 80GB HBM3\n", + "GPU[3]: NVIDIA H100 80GB HBM3\n", + "GPU[4]: NVIDIA H100 80GB HBM3\n", + "GPU[5]: NVIDIA H100 80GB HBM3\n", + "GPU[6]: NVIDIA H100 80GB HBM3\n", + "GPU[7]: NVIDIA H100 80GB HBM3\n" + ] + } + ], + "source": [ + "import sys\n", + "\n", + "import torch\n", + "\n", + "print(f\"Python: {sys.version}\")\n", + "print(f\"CUDA available: {torch.cuda.is_available()}\")\n", + "print(f\"Num GPUs: {torch.cuda.device_count()}\")\n", + "\n", + "if torch.cuda.is_available():\n", + " for i in range(torch.cuda.device_count()):\n", + " print(f\"GPU[{i}]: {torch.cuda.get_device_name(i)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## OpenAI-compatible server\n", + "\n", + "From a shell **inside** the container, at the TensorRT-LLM repo root, start `trtllm-serve` with AutoDeploy.\n", + "\n", + "Use the Gemma 4 MoE YAML under `examples/auto_deploy/model_registry/configs/` (see the introduction). 
Add `world_size` to that YAML if your serve command needs an explicit tensor-parallel device count.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load the model\n", + "\n", + "**Instruction-tuned (`gemma4_moe.yaml`):**\n", + "\n", + "```shell\n", + "trtllm-serve \"google/gemma-4-26B-A4B-it\" \\\n", + " --host 0.0.0.0 \\\n", + " --port 8000 \\\n", + " --backend _autodeploy \\\n", + " --trust_remote_code \\\n", + " --extra_llm_api_options examples/auto_deploy/model_registry/configs/gemma4_moe.yaml\n", + "```\n", + "\n", + "**Base checkpoint:** use model id `google/gemma-4-26B-A4B` and `examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml` instead.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When the server finishes loading weights, it is ready for requests.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Use the API\n", + "\n", + "Send chat completions with the OpenAI Python client pointed at the local server.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from openai import OpenAI\n", + "\n", + "BASE_URL = \"http://0.0.0.0:8000/v1\"\n", + "API_KEY = \"null\"\n", + "MODEL_ID = \"google/gemma-4-26B-A4B-it\"\n", + "\n", + "client = OpenAI(base_url=BASE_URL, api_key=API_KEY)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Chat completion example\n", + "==================================================\n", + "Response:\n", + "To find 15% of 85, you can use a few different methods. Here are two easy ways to think about it:\n", + "\n", + "### Method 1: The Breakdown Method (Mental Math)\n", + "This is often the easiest way to calculate percentages in your head by breaking the percentage into manageable parts (10% and 5%).\n", + "\n", + "1. 
**Find 10% of 85:**\n", + " To find 10%, simply move the decimal point one place to the left.\n", + " $85 \\div 10 = 8.5$\n", + "2. **Find 5% of 85:**\n", + " Since 5% is half of 10%, just divide your previous answer by 2.\n", + " $8.5 \\div 2 = 4.25$\n", + "3. **Add them together:**\n", + " $10\\% + 5\\% = 15\\%$\n", + " $8.5 + 4.25 = 12.75$\n", + "\n", + "***\n", + "\n", + "### Method 2: The Multiplication Method (Calculator/Paper)\n", + "To find a percentage, you can convert the percentage into a decimal and multiply it by the total number.\n", + "\n", + "1. **Convert 15% to a decimal:**\n", + " $15\\% = \\frac{15}{100} = 0.15$\n", + "2. **Multiply by 85:**\n", + " $85 \\times 0.15 = 12.75$\n", + "\n", + "**Final Answer:**\n", + "15% of 85 is **12.75**.\n" + ] + } + ], + "source": [ + "# Basic chat completion\n", + "print(\"Chat completion example\")\n", + "print(\"=\" * 50)\n", + "\n", + "response = client.chat.completions.create(\n", + " model=MODEL_ID,\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n", + " {\"role\": \"user\", \"content\": \"What is 15% of 85? Show your reasoning.\"},\n", + " ],\n", + " temperature=1.0,\n", + " top_p=0.95,\n", + " max_tokens=512,\n", + ")\n", + "\n", + "print(\"Response:\")\n", + "print(response.choices[0].message.content)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Streaming response:\n", + "==================================================\n", + "The first 5 prime numbers are **2, 3, 5, 7, and 11**." 
+ ] + } + ], + "source": [ + "# Streaming chat completion\n", + "print(\"Streaming response:\")\n", + "print(\"=\" * 50)\n", + "\n", + "stream = client.chat.completions.create(\n", + " model=MODEL_ID,\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n", + " {\"role\": \"user\", \"content\": \"What are the first 5 prime numbers?\"},\n", + " ],\n", + " temperature=0.7,\n", + " max_tokens=1024,\n", + " stream=True,\n", + ")\n", + "\n", + "for chunk in stream:\n", + " if chunk.choices[0].delta.content:\n", + " print(chunk.choices[0].delta.content, end=\"\", flush=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Additional resources\n", + "\n", + "- [Gemma 4 collection (Hugging Face)](https://huggingface.co/collections/google/gemma-4)\n", + "- [TensorRT-LLM documentation](https://nvidia.github.io/TensorRT-LLM/)\n", + "- [AutoDeploy guide](https://nvidia.github.io/TensorRT-LLM/torch/auto_deploy/auto-deploy.html)\n", + "- [`gemma4_moe.yaml`](../model_registry/configs/gemma4_moe.yaml), [`gemma4_moe_base.yaml`](../model_registry/configs/gemma4_moe_base.yaml), [`models.yaml`](../model_registry/models.yaml)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From 86fa54a358d3c68d74919c75d08d93d764d3e046 Mon Sep 17 00:00:00 2001 From: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com> Date: Fri, 3 Apr 2026 17:01:17 -0700 Subject: [PATCH 10/16] added MMLU and GSM8k tests Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com> --- 
.../defs/accuracy/test_llm_api_autodeploy.py | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py index 222f7052ed6..474a5575cf4 100644 --- a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py +++ b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py @@ -982,6 +982,41 @@ def test_nvfp4(self, ep_size, attention_dp): task.evaluate(llm) +class TestGemma4MoE(LlmapiAccuracyTestHarness): + """Bench-run coverage for Gemma4 MoE via AutoDeploy.""" + + MODEL_NAME = "google/gemma-4-26B-A4B-it" + EXTRA_EVALUATOR_KWARGS = { + "apply_chat_template": True, + } + + def get_default_sampling_params(self): + return SamplingParams(end_id=None, + pad_id=None, + n=1, + use_beam_search=False) + + @pytest.mark.skip_less_device_memory(80000) + def test_bf16(self): + yaml_paths, registry_world_size = _get_registry_yaml_extra( + self.MODEL_NAME) + if get_device_count() < registry_world_size: + pytest.skip("Not enough devices for world size, skipping test") + + sampling_params = self.get_default_sampling_params() + with AutoDeployLLM(model=self.MODEL_NAME, + tokenizer=self.MODEL_NAME, + world_size=registry_world_size, + yaml_extra=yaml_paths) as llm: + task = MMLU(self.MODEL_NAME) + task.evaluate(llm, + sampling_params=sampling_params, + extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS) + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm, + extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS) + + class TestModelRegistryAccuracy(LlmapiAccuracyTestHarness): """Accuracy tests for models from the AutoDeploy model registry. 
From 3ecceb537f3cfdd83dc696f3a2d8be38ebd71c69 Mon Sep 17 00:00:00 2001 From: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> Date: Sun, 5 Apr 2026 02:47:41 -0700 Subject: [PATCH 11/16] [None][feat] add sliding window support to triton paged attention Add sliding window attention to both decode (FlashDecoding) and context/prefill kernels. When sliding_window is set, queries only attend to the most recent W KV tokens, enabling efficient long-context inference for models with sliding window attention (e.g. Mistral). Key changes: - Decode kernel: restrict page splits to window range, apply per-token window mask, use effective sequence length for split-K heuristic - Context kernel: skip pages before window in Phase 1, add per-query sliding window mask in both Phase 1 (full pages) and Phase 2 (partial/causal pages), guard against NaN from -inf exponents - triton_paged_mha_with_cache: thread sliding_window through to both kernels, add optional pre-allocated output buffer support - Disable SDPA fast-path when sliding window is active - Extract sliding_window constant from source attention node MMLU: 75, GSM8K: 90 Signed-off-by: Suyog Gupta Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> --- .../attention/triton_paged_attention.py | 177 ++++++++++++++---- 1 file changed, 142 insertions(+), 35 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py index 208ca3d8afb..2ab575af2a6 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py @@ -209,7 +209,7 @@ def _get_num_splits(max_seq_len: int, batch_size: int, n_kv_heads: int, page_siz triton.Config({}, num_warps=8, num_stages=2), triton.Config({}, num_warps=8, num_stages=3), ], - key=["HEAD_DIM", "PAGE_SIZE", "HEAD_RATIO_PADDED"], + 
key=["HEAD_DIM", "PAGE_SIZE", "HEAD_RATIO_PADDED", "SLIDING_WINDOW"], ) @triton.jit def _flash_decode_stage1_kernel( @@ -249,6 +249,7 @@ def _flash_decode_stage1_kernel( HEAD_RATIO: tl.constexpr, HEAD_RATIO_PADDED: tl.constexpr, NUM_SPLITS: tl.constexpr, + SLIDING_WINDOW: tl.constexpr = 0, ): """ Key optimizations: @@ -266,9 +267,20 @@ def _flash_decode_stage1_kernel( num_pages = kv_page_end - kv_page_start last_page_len = tl.load(kv_last_page_len_ptr + batch_id) - # Compute this split's page range (page-aligned splits) - pages_per_split = (num_pages + NUM_SPLITS - 1) // NUM_SPLITS - page_split_start = split_id * pages_per_split + # Sliding window: restrict attention to pages within the window. + # Compute the total sequence length and the first valid KV position. + seq_len = (num_pages - 1) * PAGE_SIZE + last_page_len + if SLIDING_WINDOW > 0: + first_valid_pos = tl.maximum(0, seq_len - SLIDING_WINDOW) + first_window_page = first_valid_pos // PAGE_SIZE + else: + first_valid_pos = 0 + first_window_page = 0 + + # Only split over pages within the window + window_pages = num_pages - first_window_page + pages_per_split = (window_pages + NUM_SPLITS - 1) // NUM_SPLITS + page_split_start = first_window_page + split_id * pages_per_split page_split_end = tl.minimum(page_split_start + pages_per_split, num_pages) dhead_offsets = tl.arange(0, HEAD_DIM) @@ -346,7 +358,14 @@ def _flash_decode_stage1_kernel( # [HEAD_RATIO_PADDED, HEAD_DIM] @ [HEAD_DIM, PAGE_SIZE] -> [HEAD_RATIO_PADDED, PAGE_SIZE] attn = tl.dot(q_all, tl.trans(k)) * SM_SCALE - attn = tl.where(page_mask[None, :], attn, float("-inf")) + + # Combine validity mask with sliding window mask + if SLIDING_WINDOW > 0: + global_pos = page_idx * PAGE_SIZE + page_offsets + window_mask = global_pos >= first_valid_pos + attn = tl.where(page_mask[None, :] & window_mask[None, :], attn, float("-inf")) + else: + attn = tl.where(page_mask[None, :], attn, float("-inf")) # Online softmax update (vectorized over HEAD_RATIO_PADDED) m_ij 
= tl.max(attn, axis=1) # [HEAD_RATIO_PADDED] @@ -454,6 +473,7 @@ def triton_paged_decode( kv_indptr: torch.Tensor, kv_last_page_len: torch.Tensor, sm_scale: float, + sliding_window: Optional[int] = None, out: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Optimized paged decode with GQA batching + FlashDecoding + page-aligned iteration. @@ -465,6 +485,7 @@ def triton_paged_decode( kv_indptr: Cumulative page counts [batch_size + 1] kv_last_page_len: Valid tokens in last page [batch_size] sm_scale: Softmax scale factor + sliding_window: If set, only attend to the last sliding_window tokens out: Optional output tensor [batch_size, n_heads, head_dim] Returns: @@ -477,13 +498,17 @@ def triton_paged_decode( max_pages = kv_indices.shape[0] max_seq_len = max_pages * page_size + # Normalize sliding_window: None/non-positive → 0 (full attention) + sw = sliding_window if isinstance(sliding_window, int) and sliding_window > 0 else 0 output = out if out is not None else torch.empty_like(q) if batch_size == 0: return output - num_splits = _get_num_splits(max_seq_len, batch_size, n_kv_heads, page_size) + # Use effective sequence length (capped by sliding window) for split-K heuristic + effective_seq_len = min(max_seq_len, sw) if sw > 0 else max_seq_len + num_splits = _get_num_splits(effective_seq_len, batch_size, n_kv_heads, page_size) # Allocate intermediate buffers for split-K partial_o = torch.empty( @@ -536,6 +561,7 @@ def triton_paged_decode( HEAD_RATIO=head_ratio, HEAD_RATIO_PADDED=head_ratio_padded, NUM_SPLITS=num_splits, + SLIDING_WINDOW=sw, ) # Stage 2: Combine partial results @@ -606,6 +632,7 @@ def _paged_context_kernel( N_KV_HEADS: tl.constexpr, HEAD_DIM: tl.constexpr, PAGE_SIZE: tl.constexpr, + SLIDING_WINDOW: tl.constexpr = 0, ): """Context/prefill attention with paged KV cache, causal skip, and page-aligned iteration. 
@@ -669,6 +696,16 @@ def _paged_context_kernel( # Number of full pages (all tokens in these pages are attended by all Q tokens) num_full_pages = first_q_kv_pos // PAGE_SIZE + # Sliding window: compute the first page within the window for Phase 1 pruning. + # Each query at position q_pos attends to KV in [q_pos - W + 1, q_pos]. + # The most restrictive query is the first one (q_block_start), so: + # first_valid_pos = max(0, first_q_kv_pos - SLIDING_WINDOW + 1) + if SLIDING_WINDOW > 0: + first_valid_pos = tl.maximum(0, first_q_kv_pos - SLIDING_WINDOW + 1) + first_window_page = first_valid_pos // PAGE_SIZE + else: + first_window_page = 0 + # Check if this is a full Q block (no q_mask needed) is_full_q_block = (q_block_start + Q_BLOCK) <= q_len @@ -677,39 +714,71 @@ def _paged_context_kernel( kv_head_offset = kv_head_id * cache_stride_head local_kv = page_offsets[:, None] * cache_stride_token + dhead_offsets[None, :] - for page_idx in range(num_full_pages): + for page_idx in range(first_window_page, num_full_pages): physical_page = tl.load(kv_indices_ptr + kv_page_start + page_idx) # Use int64 to avoid overflow when physical_page * stride > 2^31 page_base = physical_page.to(tl.int64) * cache_stride_block + kv_head_offset - k_block_ptr = tl.make_block_ptr( - base=kv_cache_ptr + page_base, - shape=(PAGE_SIZE, HEAD_DIM), - strides=(cache_stride_token, 1), - offsets=(0, 0), - block_shape=(PAGE_SIZE, HEAD_DIM), - order=(1, 0), - ) - v_block_ptr = tl.make_block_ptr( - base=kv_cache_ptr + page_base + cache_stride_kv, - shape=(PAGE_SIZE, HEAD_DIM), - strides=(cache_stride_token, 1), - offsets=(0, 0), - block_shape=(PAGE_SIZE, HEAD_DIM), - order=(1, 0), - ) - k = tl.load(k_block_ptr) - v = tl.load(v_block_ptr) - qk = tl.dot(q, tl.trans(k)) * SM_SCALE + # When sliding window is active, the first window page may partially + # overlap the window boundary, requiring per-token masking. + # Use masked loads (like Phase 2) instead of block_ptr loads. 
+ if SLIDING_WINDOW > 0: + k = tl.load( + kv_cache_ptr + page_base + local_kv, + mask=tl.full([PAGE_SIZE, HEAD_DIM], 1, tl.int1), + other=0.0, + ) + v = tl.load( + kv_cache_ptr + page_base + local_kv + cache_stride_kv, + mask=tl.full([PAGE_SIZE, HEAD_DIM], 1, tl.int1), + other=0.0, + ) + + qk = tl.dot(q, tl.trans(k)) * SM_SCALE + + # Per-query sliding window mask: each query position q_pos + # can attend to KV in [q_pos - W + 1, q_pos]. + kv_positions = page_idx * PAGE_SIZE + page_offsets[None, :] + q_kv_pos = q_offsets[:, None] + cache_len + sw_mask = (q_kv_pos - kv_positions) < SLIDING_WINDOW + full_mask_p1 = q_mask[:, None] & sw_mask + qk = tl.where(full_mask_p1, qk, float("-inf")) + else: + k_block_ptr = tl.make_block_ptr( + base=kv_cache_ptr + page_base, + shape=(PAGE_SIZE, HEAD_DIM), + strides=(cache_stride_token, 1), + offsets=(0, 0), + block_shape=(PAGE_SIZE, HEAD_DIM), + order=(1, 0), + ) + v_block_ptr = tl.make_block_ptr( + base=kv_cache_ptr + page_base + cache_stride_kv, + shape=(PAGE_SIZE, HEAD_DIM), + strides=(cache_stride_token, 1), + offsets=(0, 0), + block_shape=(PAGE_SIZE, HEAD_DIM), + order=(1, 0), + ) + k = tl.load(k_block_ptr) + v = tl.load(v_block_ptr) + + qk = tl.dot(q, tl.trans(k)) * SM_SCALE - if not is_full_q_block: - qk = tl.where(q_mask[:, None], qk, float("-inf")) + if not is_full_q_block: + qk = tl.where(q_mask[:, None], qk, float("-inf")) m_ij = tl.max(qk, axis=1) m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - p = tl.exp(qk - m_i_new[:, None]) + if SLIDING_WINDOW > 0: + # Guard against NaN when m_i == m_i_new == -inf (no valid tokens seen + # yet for a query whose window doesn't overlap this page at all). 
+ alpha = tl.where(m_i > float("-inf"), tl.exp(m_i - m_i_new), 0.0) + p = tl.where(m_i_new[:, None] > float("-inf"), tl.exp(qk - m_i_new[:, None]), 0.0) + else: + alpha = tl.exp(m_i - m_i_new) + p = tl.exp(qk - m_i_new[:, None]) acc = tl.dot(p.to(v.dtype), v, acc=acc * alpha[:, None]) l_i = l_i * alpha + tl.sum(p, axis=1) m_i = m_i_new @@ -740,13 +809,21 @@ def _paged_context_kernel( qk = tl.dot(q, tl.trans(k)) * SM_SCALE kv_positions = kv_base_pos + page_offsets[None, :] causal_mask = q_positions_2d >= kv_positions - full_mask = q_mask[:, None] & causal_mask & page_mask[None, :] + if SLIDING_WINDOW > 0: + sliding_mask = (q_positions_2d - kv_positions) < SLIDING_WINDOW + full_mask = q_mask[:, None] & causal_mask & sliding_mask & page_mask[None, :] + else: + full_mask = q_mask[:, None] & causal_mask & page_mask[None, :] qk = tl.where(full_mask, qk, float("-inf")) m_ij = tl.max(qk, axis=1) m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - p = tl.exp(qk - m_i_new[:, None]) + if SLIDING_WINDOW > 0: + alpha = tl.where(m_i > float("-inf"), tl.exp(m_i - m_i_new), 0.0) + p = tl.where(m_i_new[:, None] > float("-inf"), tl.exp(qk - m_i_new[:, None]), 0.0) + else: + alpha = tl.exp(m_i - m_i_new) + p = tl.exp(qk - m_i_new[:, None]) acc = tl.dot(p.to(v.dtype), v, acc=acc * alpha[:, None]) l_i = l_i * alpha + tl.sum(p, axis=1) m_i = m_i_new @@ -829,6 +906,7 @@ def triton_paged_context( kv_last_page_len: torch.Tensor, seq_len_with_cache: torch.Tensor, sm_scale: float, + sliding_window: Optional[int] = None, out: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Context/prefill attention with paged KV cache.""" @@ -855,6 +933,9 @@ def triton_paged_context( # paged Triton kernel for shorter sequences where gather overhead dominates. 
# Compute max_pages from max_q_len without GPU sync # (assumes pure prefill where q_len == kv_len for each seq) + # Normalize sliding_window for kernel constexpr: None/non-positive → 0 + sw = sliding_window if isinstance(sliding_window, int) and sliding_window > 0 else 0 + max_pages = (max_q_len + page_size - 1) // page_size total_expected_pages = num_seq * max_pages use_sdpa = ( @@ -862,6 +943,7 @@ def triton_paged_context( and num_seq <= 64 and max_pages > 0 and kv_indices.shape[0] == total_expected_pages # all seqs same page count + and sw == 0 # SDPA doesn't support sliding window natively ) if use_sdpa: @@ -936,6 +1018,7 @@ def grid_paged(meta): N_KV_HEADS=n_kv_heads, HEAD_DIM=head_dim, PAGE_SIZE=page_size, + SLIDING_WINDOW=sw, ) return output @@ -1001,6 +1084,9 @@ def triton_paged_mha_with_cache( kv_cache: torch.Tensor, # CONSTANTS scale: Optional[float], + sliding_window: Optional[int] = None, + # OPTIONAL PRE-ALLOCATED OUTPUT + out: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Triton paged attention with mixed batch support.""" head_dim = kv_cache.shape[-1] @@ -1031,7 +1117,10 @@ def triton_paged_mha_with_cache( cu_num_pages[: num_seq + 1], ) - y = torch.empty_like(q) + if out is not None: + y = out.view(-1, q.shape[1], head_dim) + else: + y = torch.empty_like(q) # Process prefill tokens if any if num_prefill > 0: @@ -1046,6 +1135,7 @@ def triton_paged_mha_with_cache( last_page_len[:num_prefill], seq_len_with_cache, sm_scale, + sliding_window=sliding_window, out=y[:num_prefill_tokens], ) @@ -1058,9 +1148,20 @@ def triton_paged_mha_with_cache( cu_num_pages[num_prefill : num_seq + 1], last_page_len[num_prefill:num_seq], sm_scale, + sliding_window=sliding_window, out=y[num_prefill_tokens:num_total_tokens], ) + if out is not None: + # Zero stale data in padding region for CUDA graph replay stability + bs = b * s + if num_total_tokens < bs: + y[num_total_tokens:].zero_() + # Return a 0-element dummy to satisfy PyTorch's no-alias constraint. 
+ # The caller (DynamicOpWrapper._coalesce_output) picks ``out`` over + # this dummy, so the pre-allocated buffer is used downstream. + return out.new_empty(0) + return y.view(q_shape_og) @@ -1081,7 +1182,11 @@ def triton_paged_mha_with_cache_fake( triton_positions: torch.Tensor, kv_cache: torch.Tensor, scale: Optional[float], + sliding_window: Optional[int] = None, + out: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if out is not None: + return out.new_empty(0) return torch.empty_like(q.contiguous()) @@ -1176,4 +1281,6 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]: ad_logger.warning(f"Provided {scale=}, is not a float. Using default scale instead.") scale = None - return [scale] + sliding_window = extract_op_args(source_attn_node, "sliding_window")[0] + + return [scale, sliding_window] From 5145b34f1e465b74a22c6de4a01f5b4ea88c5822 Mon Sep 17 00:00:00 2001 From: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> Date: Sun, 5 Apr 2026 02:53:10 -0700 Subject: [PATCH 12/16] [None][feat] gemma4 optimizations: piecewise cudagraph support, MLIR fixes, gather logits softcap, sliding window tests - Enable MLIR elementwise fusion, gather_logits, and fuse_gemms transforms in gemma4_moe config; switch gemma4 models to world_size_1 - Register triton_paged ops in piecewise_utils for CUDA graph capture - Add torch.cuda.synchronize after piecewise graph replay to prevent race conditions with non-default streams (e.g. 
fused_moe) - Fix MLIR triton emitter: use tl.extra.cuda.libdevice for math ops (gelu, tanh, exp, softplus, pow); handle scalar/rank-0 tensor inputs; add AD_DUMP_KERNELS_DIR env var for kernel source inspection - Fix gather_logits_before_lm_head to walk backward through post-lm_head ops (div, tanh, mul softcapping) to find the actual linear node - Add sliding window attention tests for decode and context kernels - Add softcapping LM head test for gather logits transform Signed-off-by: Suyog Gupta Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> --- .../model_registry/configs/gemma4_moe.yaml | 6 + .../auto_deploy/model_registry/models.yaml | 4 +- .../compile/backends/torch_cudagraph.py | 9 +- .../auto_deploy/compile/piecewise_utils.py | 2 + .../mlir/codegen/triton_emitter.py | 45 ++- .../library/gather_logits_before_lm_head.py | 17 +- .../attention/test_triton_paged_attention.py | 303 ++++++++++++++++++ .../test_gather_logits_before_lm_head.py | 92 +++++- 8 files changed, 466 insertions(+), 12 deletions(-) diff --git a/examples/auto_deploy/model_registry/configs/gemma4_moe.yaml b/examples/auto_deploy/model_registry/configs/gemma4_moe.yaml index 6a6b5967c3e..a830e6946df 100644 --- a/examples/auto_deploy/model_registry/configs/gemma4_moe.yaml +++ b/examples/auto_deploy/model_registry/configs/gemma4_moe.yaml @@ -19,3 +19,9 @@ kv_cache_config: transforms: compile_model: piecewise_enabled: true + mlir_elementwise_fusion: + enabled: true + gather_logits_before_lm_head: + enabled: true + fuse_gemms: + enabled: true diff --git a/examples/auto_deploy/model_registry/models.yaml b/examples/auto_deploy/model_registry/models.yaml index 921e1223651..ef614e39b8b 100644 --- a/examples/auto_deploy/model_registry/models.yaml +++ b/examples/auto_deploy/model_registry/models.yaml @@ -310,9 +310,9 @@ models: yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml', 'multimodal.yaml'] # --- Gemma 4 (2026) - MoE with K=V attention --- - name: 
google/gemma-4-26B-A4B - yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml', 'gemma4_moe_base.yaml'] + yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml', 'gemma4_moe_base.yaml'] - name: google/gemma-4-26B-A4B-it - yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml', 'gemma4_moe.yaml'] + yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml', 'gemma4_moe.yaml'] # --- JetBrains Mellum (Apr 2025) - code specialist --- - name: JetBrains/Mellum-4b-sft-all yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml'] diff --git a/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py b/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py index be671f3cec6..28ff9620978 100644 --- a/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py +++ b/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py @@ -517,7 +517,14 @@ def forward(self, *args, num_tokens: Optional[int] = None, **kwargs) -> Any: """Forward pass: static segments replay graphs, dynamic segments run eagerly.""" if self.split_gm is not None: ADPiecewiseRunner.set_current_num_tokens(num_tokens) - return self.split_gm(*args, **kwargs) + result = self.split_gm(*args, **kwargs) + # Ensure all CUDA graph internal streams have completed before the + # caller (DualModeCapturedGraph) proceeds. Some captured kernels + # (e.g. trtllm fused_moe) may use non-default streams inside the + # graph; without this sync the next eager op can race with those + # internal streams, causing sporadic illegal-memory-access crashes. 
+ torch.cuda.synchronize() + return result return self.original_model(*args, **kwargs) diff --git a/tensorrt_llm/_torch/auto_deploy/compile/piecewise_utils.py b/tensorrt_llm/_torch/auto_deploy/compile/piecewise_utils.py index e4784a38165..e5d87e78e2c 100644 --- a/tensorrt_llm/_torch/auto_deploy/compile/piecewise_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/compile/piecewise_utils.py @@ -27,6 +27,7 @@ _CACHED_ATTENTION_OPS = [ "auto_deploy::flashinfer_attention_mha_with_cache", "auto_deploy::triton_attention_flattened_mha_with_cache", + "auto_deploy::triton_paged_mha_with_cache", "auto_deploy::torch_cached_attention_with_cache", "auto_deploy::trtllm_attention_mha_with_cache", # MLA attention variants @@ -57,6 +58,7 @@ _METADATA_PREP_OPS = [ "auto_deploy::flashinfer_attention_prepare_metadata", "auto_deploy::flashinfer_mla_prepare_metadata", + "auto_deploy::triton_paged_prepare_metadata", "auto_deploy::mamba_ssm_prepare_metadata", ] diff --git a/tensorrt_llm/_torch/auto_deploy/mlir/codegen/triton_emitter.py b/tensorrt_llm/_torch/auto_deploy/mlir/codegen/triton_emitter.py index 96ef220ef58..b8790c1b018 100644 --- a/tensorrt_llm/_torch/auto_deploy/mlir/codegen/triton_emitter.py +++ b/tensorrt_llm/_torch/auto_deploy/mlir/codegen/triton_emitter.py @@ -82,12 +82,12 @@ def _cleanup_temp_files(): "ad.rsqrt": lambda a: f"(1.0 / tl.sqrt({a}))", "ad.sqrt": lambda a: f"tl.sqrt({a})", "ad.silu": lambda a: f"({a} * tl.sigmoid({a}))", - "ad.gelu": lambda a: f"({a} * 0.5 * (1.0 + tl.math.erf({a} * 0.7071067811865476)))", + "ad.gelu": lambda a: f"({a} * 0.5 * (1.0 + tl.extra.cuda.libdevice.erf({a} * 0.7071067811865476)))", "ad.relu": lambda a: f"tl.maximum({a}, 0)", - "ad.tanh": lambda a: f"tl.math.tanh({a})", + "ad.tanh": lambda a: f"tl.extra.cuda.libdevice.tanh({a})", "ad.sigmoid": lambda a: f"tl.sigmoid({a})", - "ad.exp": lambda a: f"tl.math.exp({a})", - "ad.softplus": lambda a: f"tl.math.log(1.0 + tl.math.exp({a}))", + "ad.exp": lambda a: f"tl.extra.cuda.libdevice.exp({a})", + 
"ad.softplus": lambda a: f"tl.extra.cuda.libdevice.log(1.0 + tl.extra.cuda.libdevice.exp({a}))", "ad.reduce_sum": lambda a: f"tl.sum({a}, 0)", "ad.reduce_mean": lambda a, ncols: f"(tl.sum({a}, 0) * (1.0 / {ncols}))", "ad.splat": None, # handled specially — just inline the scalar value @@ -253,7 +253,21 @@ def generate_kernel_from_subgraph(subgraph) -> Callable: # last dim < N_COLS, e.g. a gating scalar of shape (-1, 1) in a subgraph # whose row width is 2048). Both categories need a load pattern that # avoids reading past the end of the actual data. + # Scalar-like inputs (rank-0 OR broadcast with last-dim 1, e.g. shape [1]) + # need a single-element load; Triton broadcasts the scalar automatically. broadcast_flags = [_is_broadcast_input(inp, max_rank) for inp in subgraph.inputs] + scalar_flags = [] + for i, inp in enumerate(subgraph.inputs): + rank = _get_tensor_rank(inp) + if rank == 0: + scalar_flags.append(True) + elif broadcast_flags[i] and isinstance(inp.type, TensorType): + shape = inp.type.get_shape() + # Broadcast input whose last dim is 1 (e.g. layer_scalar shape [1]) + # must be loaded as a single element, not a vector. + scalar_flags.append(not shape or shape[-1] == 1) + else: + scalar_flags.append(False) narrow_flags = [] for inp in subgraph.inputs: if isinstance(inp.type, TensorType): @@ -273,7 +287,10 @@ def generate_kernel_from_subgraph(subgraph) -> Callable: # Broadcast (1D) inputs (e.g. weights) are offset by group only: # ptr + pid_group * N_COLS + offs for i, inp in enumerate(subgraph.inputs): - if broadcast_flags[i]: + if scalar_flags[i]: + # Rank-0 (scalar) tensor: load single element, Triton broadcasts automatically. 
+ body_lines.append(f" v{i} = tl.load(in{i}_ptr).to(tl.float32)") + elif broadcast_flags[i]: if grouped_mode: body_lines.append( f" v{i} = tl.load(in{i}_ptr + group_off + offs, mask=mask).to(tl.float32)" @@ -320,7 +337,9 @@ def generate_kernel_from_subgraph(subgraph) -> Callable: else: exp_val = float(str(exp_attr)) result_name = f"t{temp_counter}" - body_lines.append(f" {result_name} = tl.math.pow({base_name}, {exp_val})") + body_lines.append( + f" {result_name} = tl.extra.cuda.libdevice.pow({base_name}, {exp_val})" + ) temp_counter += 1 for r in op.results: val_names[id(r)] = result_name @@ -530,6 +549,20 @@ def generate_kernel_from_subgraph(subgraph) -> Callable: "import torch\n\n" + kernel_src + "\n" + launcher_src ) + import logging as _logging + import os as _os + + _logging.getLogger("mlir_codegen").info("Generated kernel %s:\n%s", sg_hash, full_src) + + # Optional: dump kernel source to a directory for offline inspection. + # Controlled by the AD_DUMP_KERNELS_DIR environment variable. 
+ _kernel_dump_dir = _os.environ.get("AD_DUMP_KERNELS_DIR") + if _kernel_dump_dir: + _dump_path = _os.path.join(_kernel_dump_dir, f"triton_gen_{sg_hash}.py") + _os.makedirs(_kernel_dump_dir, exist_ok=True) + with open(_dump_path, "w") as _f: + _f.write(full_src) + import importlib.util import tempfile diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py b/tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py index 5653ec7a481..bb30e1a2dac 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py @@ -62,8 +62,21 @@ def _apply( node_to_gather = lm_head_node.all_input_nodes[0] self._log_info(f"Found LM head node: {lm_head_node.name}") else: - node_to_gather = lm_head_node - self._log_info("lm_head node is not linear, using it as the node to gather") + # Walk backward through elementwise/unary ops (e.g. softcapping: div, tanh, mul) + # to find the actual lm_head linear node. 
+ current = lm_head_node + while current is not None and not is_linear_op(current): + inputs = current.all_input_nodes + current = inputs[0] if len(inputs) >= 1 else None + + if current is not None and is_linear_op(current): + node_to_gather = current.all_input_nodes[0] + self._log_info( + f"Found LM head linear through post-processing chain: {current.name}" + ) + else: + node_to_gather = lm_head_node + self._log_info("lm_head node is not linear, using it as the node to gather") # Add logits_gather_mask as input in the graph and the sequence info interface logits_gather_indices_node = self._add_or_retrieve_input(gm, cm, "token_gather_indices") diff --git a/tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py b/tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py index 5a14303e14a..fb8c502d99f 100644 --- a/tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py +++ b/tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py @@ -464,6 +464,309 @@ def test_prepare_metadata_with_12_element_batch_info(self): assert (positions == torch.arange(7, device="cuda")).all() +class TestSlidingWindow: + """Tests for sliding window attention support in Triton paged kernels.""" + + @staticmethod + def _sliding_window_reference( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sm_scale: float, + sliding_window: int, + ) -> torch.Tensor: + """Compute causal + sliding window attention with manual masking. 
+ + Args: + q: [B, n_heads, S_q, head_dim] + k: [B, n_heads, S_k, head_dim] + v: [B, n_heads, S_k, head_dim] + + Returns: + [B, n_heads, S_q, head_dim] + """ + s_q = q.shape[2] + s_k = k.shape[2] + + attn = torch.matmul(q, k.transpose(-2, -1)) * sm_scale + + q_pos = torch.arange(s_k - s_q + s_q, device=q.device) # absolute positions + # For prefill: q_pos = [0..s_q-1], k_pos = [0..s_k-1] + q_pos = torch.arange(s_k - s_q, s_k, device=q.device) # [s_q] + k_pos = torch.arange(s_k, device=q.device) # [s_k] + + pos_diff = q_pos.unsqueeze(1) - k_pos.unsqueeze(0) # [s_q, s_k] + causal_mask = pos_diff < 0 + window_mask = pos_diff >= sliding_window + combined = causal_mask | window_mask + attn.masked_fill_(combined.unsqueeze(0).unsqueeze(0), float("-inf")) + + attn = torch.softmax(attn, dim=-1) + return torch.matmul(attn, v) + + @pytest.mark.parametrize("batch_size", [1, 4]) + @pytest.mark.parametrize("n_heads,n_kv_heads", [(8, 8), (32, 8)]) + @pytest.mark.parametrize("head_dim", [64, 128]) + @pytest.mark.parametrize("seq_len", [128, 256, 512]) + @pytest.mark.parametrize("sliding_window", [32, 64]) + def test_decode_sliding_window( + self, + batch_size: int, + n_heads: int, + n_kv_heads: int, + head_dim: int, + seq_len: int, + sliding_window: int, + ): + """Test decode with sliding window against reference (seq_len > window).""" + from tensorrt_llm._torch.auto_deploy.custom_ops.attention.triton_paged_attention import ( + triton_paged_decode, + update_paged_kv_cache, + ) + + assert seq_len > sliding_window, "Test requires seq_len > sliding_window" + page_size = 16 + + num_pages_per_seq = (seq_len + page_size - 1) // page_size + num_blocks = batch_size * num_pages_per_seq + 5 + + q = torch.randn(batch_size, n_heads, head_dim, dtype=torch.float16, device="cuda") + k = torch.randn( + batch_size, seq_len, n_kv_heads, head_dim, dtype=torch.float16, device="cuda" + ) + v = torch.randn( + batch_size, seq_len, n_kv_heads, head_dim, dtype=torch.float16, device="cuda" + ) + + k_flat = 
k.reshape(batch_size * seq_len, n_kv_heads, head_dim) + v_flat = v.reshape(batch_size * seq_len, n_kv_heads, head_dim) + + batch_indices = torch.repeat_interleave( + torch.arange(batch_size, device="cuda", dtype=torch.int32), seq_len + ) + positions = torch.tile( + torch.arange(seq_len, device="cuda", dtype=torch.int32), (batch_size,) + ) + + kv_indptr = torch.arange( + 0, + (batch_size + 1) * num_pages_per_seq, + num_pages_per_seq, + dtype=torch.int32, + device="cuda", + )[: batch_size + 1] + kv_indices = torch.arange( + 0, batch_size * num_pages_per_seq, dtype=torch.int32, device="cuda" + ) + last_token_in_page = seq_len % page_size + kv_last_page_len = torch.full( + (batch_size,), + last_token_in_page if last_token_in_page > 0 else page_size, + dtype=torch.int32, + device="cuda", + ) + + kv_cache = create_paged_kv_cache(num_blocks, page_size, n_kv_heads, head_dim) + update_paged_kv_cache( + k_flat, v_flat, batch_indices, positions, kv_cache, kv_indices, kv_indptr + ) + + sm_scale = 1.0 / math.sqrt(head_dim) + + output_triton = triton_paged_decode( + q, + kv_cache, + kv_indices, + kv_indptr, + kv_last_page_len, + sm_scale, + sliding_window=sliding_window, + ) + + # Reference: only attend to last `sliding_window` tokens + head_ratio = n_heads // n_kv_heads + k_ref = k[:, -sliding_window:, :, :].transpose(1, 2) + v_ref = v[:, -sliding_window:, :, :].transpose(1, 2) + if head_ratio > 1: + k_ref = k_ref.repeat_interleave(head_ratio, dim=1) + v_ref = v_ref.repeat_interleave(head_ratio, dim=1) + + q_ref = q.unsqueeze(2) # [B, n_heads, 1, head_dim] + output_ref = torch.nn.functional.scaled_dot_product_attention( + q_ref, k_ref, v_ref, scale=sm_scale, is_causal=False + ).squeeze(2) + + torch.testing.assert_close(output_triton.float(), output_ref.float(), rtol=1e-2, atol=1e-2) + + @pytest.mark.parametrize("batch_size", [1, 2]) + @pytest.mark.parametrize("n_heads,n_kv_heads", [(8, 8), (32, 8)]) + @pytest.mark.parametrize("head_dim", [64, 128]) + 
@pytest.mark.parametrize("seq_len", [128, 256]) + @pytest.mark.parametrize("sliding_window", [32, 64]) + def test_context_sliding_window( + self, + batch_size: int, + n_heads: int, + n_kv_heads: int, + head_dim: int, + seq_len: int, + sliding_window: int, + ): + """Test prefill with sliding window against manual reference (seq_len > window).""" + from tensorrt_llm._torch.auto_deploy.custom_ops.attention.triton_paged_attention import ( + triton_paged_context, + update_paged_kv_cache, + ) + + assert seq_len > sliding_window, "Test requires seq_len > sliding_window" + page_size = 16 + + num_pages_per_seq = (seq_len + page_size - 1) // page_size + num_blocks = batch_size * num_pages_per_seq + 5 + total_tokens = batch_size * seq_len + + q = torch.randn(total_tokens, n_heads, head_dim, dtype=torch.float16, device="cuda") + k = torch.randn(total_tokens, n_kv_heads, head_dim, dtype=torch.float16, device="cuda") + v = torch.randn(total_tokens, n_kv_heads, head_dim, dtype=torch.float16, device="cuda") + + qo_indptr = torch.arange( + 0, (batch_size + 1) * seq_len, seq_len, dtype=torch.int32, device="cuda" + )[: batch_size + 1] + kv_indptr = torch.arange( + 0, + (batch_size + 1) * num_pages_per_seq, + num_pages_per_seq, + dtype=torch.int32, + device="cuda", + )[: batch_size + 1] + kv_indices = torch.arange( + 0, batch_size * num_pages_per_seq, dtype=torch.int32, device="cuda" + ) + last_token_in_page = seq_len % page_size + kv_last_page_len = torch.full( + (batch_size,), + last_token_in_page if last_token_in_page > 0 else page_size, + dtype=torch.int32, + device="cuda", + ) + seq_len_with_cache = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda") + + batch_indices = torch.repeat_interleave( + torch.arange(batch_size, device="cuda", dtype=torch.int32), seq_len + ) + positions = torch.tile( + torch.arange(seq_len, device="cuda", dtype=torch.int32), (batch_size,) + ) + + kv_cache = create_paged_kv_cache(num_blocks, page_size, n_kv_heads, head_dim) + 
update_paged_kv_cache(k, v, batch_indices, positions, kv_cache, kv_indices, kv_indptr) + + sm_scale = 1.0 / math.sqrt(head_dim) + + output = triton_paged_context( + q, + kv_cache, + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + seq_len_with_cache, + sm_scale, + sliding_window=sliding_window, + ) + + # Reference: manual causal + sliding window attention + head_ratio = n_heads // n_kv_heads + q_ref = q.view(batch_size, seq_len, n_heads, head_dim).transpose(1, 2) + k_ref = k.view(batch_size, seq_len, n_kv_heads, head_dim).transpose(1, 2) + v_ref = v.view(batch_size, seq_len, n_kv_heads, head_dim).transpose(1, 2) + if head_ratio > 1: + k_ref = k_ref.repeat_interleave(head_ratio, dim=1) + v_ref = v_ref.repeat_interleave(head_ratio, dim=1) + + output_ref = self._sliding_window_reference(q_ref, k_ref, v_ref, sm_scale, sliding_window) + output_ref = output_ref.transpose(1, 2).reshape(total_tokens, n_heads, head_dim) + + torch.testing.assert_close(output.float(), output_ref.float(), rtol=1e-2, atol=1e-2) + + def test_no_sliding_window_unchanged(self): + """Verify that sliding_window=None produces the same output as before.""" + from tensorrt_llm._torch.auto_deploy.custom_ops.attention.triton_paged_attention import ( + triton_paged_decode, + update_paged_kv_cache, + ) + + batch_size, n_heads, n_kv_heads, head_dim = 2, 8, 8, 64 + seq_len, page_size = 128, 16 + + num_pages_per_seq = (seq_len + page_size - 1) // page_size + num_blocks = batch_size * num_pages_per_seq + 5 + + q = torch.randn(batch_size, n_heads, head_dim, dtype=torch.float16, device="cuda") + k = torch.randn( + batch_size, seq_len, n_kv_heads, head_dim, dtype=torch.float16, device="cuda" + ) + v = torch.randn( + batch_size, seq_len, n_kv_heads, head_dim, dtype=torch.float16, device="cuda" + ) + + k_flat = k.reshape(batch_size * seq_len, n_kv_heads, head_dim) + v_flat = v.reshape(batch_size * seq_len, n_kv_heads, head_dim) + + batch_indices = torch.repeat_interleave( + torch.arange(batch_size, 
device="cuda", dtype=torch.int32), seq_len + ) + positions = torch.tile( + torch.arange(seq_len, device="cuda", dtype=torch.int32), (batch_size,) + ) + + kv_indptr = torch.arange( + 0, + (batch_size + 1) * num_pages_per_seq, + num_pages_per_seq, + dtype=torch.int32, + device="cuda", + )[: batch_size + 1] + kv_indices = torch.arange( + 0, batch_size * num_pages_per_seq, dtype=torch.int32, device="cuda" + ) + last_token_in_page = seq_len % page_size + kv_last_page_len = torch.full( + (batch_size,), + last_token_in_page if last_token_in_page > 0 else page_size, + dtype=torch.int32, + device="cuda", + ) + + kv_cache = create_paged_kv_cache(num_blocks, page_size, n_kv_heads, head_dim) + update_paged_kv_cache( + k_flat, v_flat, batch_indices, positions, kv_cache, kv_indices, kv_indptr + ) + + sm_scale = 1.0 / math.sqrt(head_dim) + + out_none = triton_paged_decode( + q, + kv_cache, + kv_indices, + kv_indptr, + kv_last_page_len, + sm_scale, + sliding_window=None, + ) + out_zero = triton_paged_decode( + q, + kv_cache, + kv_indices, + kv_indptr, + kv_last_page_len, + sm_scale, + sliding_window=0, + ) + + torch.testing.assert_close(out_none, out_zero) + + class TestFlashInferComparison: """Tests comparing Triton implementation against FlashInfer.""" diff --git a/tests/unittest/auto_deploy/singlegpu/transformations/library/test_gather_logits_before_lm_head.py b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_gather_logits_before_lm_head.py index e7310944951..a3f40d08105 100644 --- a/tests/unittest/auto_deploy/singlegpu/transformations/library/test_gather_logits_before_lm_head.py +++ b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_gather_logits_before_lm_head.py @@ -26,7 +26,7 @@ from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer -from 
tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op +from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_linear_op, is_op class SimpleLMHeadModel(torch.nn.Module): @@ -45,6 +45,24 @@ def forward(self, hidden_states, logit_gather_ids=None, seq_len=None): return logits +class SoftcapLMHeadModel(torch.nn.Module): + """Model with LM head followed by softcapping (like Gemma4).""" + + def __init__(self, hidden_size: int = 128, vocab_size: int = 1000, softcap: float = 30.0): + super().__init__() + self.linear1 = torch.nn.Linear(hidden_size, hidden_size, device="cuda", dtype=torch.float16) + self.lm_head = torch.nn.Linear(hidden_size, vocab_size, device="cuda", dtype=torch.float16) + self.softcap = softcap + + def forward(self, hidden_states, logit_gather_ids=None, seq_len=None): + hidden_states = self.linear1(hidden_states) + logits = self.lm_head(hidden_states) + logits = logits / self.softcap + logits = torch.tanh(logits) + logits = logits * self.softcap + return logits + + class TestGatherTokensOp: """Test the custom op directly.""" @@ -348,3 +366,75 @@ def test_transform_skips_when_disabled(self): assert not self._check_gather_op_in_graph(gm_transformed), ( "Gather op should not be in graph" ) + + def test_transform_with_softcapping(self): + """Test that gather is placed BEFORE lm_head when softcapping follows it. + + Models like Gemma4 apply softcapping (div, tanh, mul) after the lm_head. + The transform must walk backward through these ops to find the actual + linear and insert gather before it, not after the softcapping chain. + Otherwise the lm_head still runs on all tokens (no compute reduction) + and piecewise CUDA graph capture OOMs on the [num_tokens, vocab_size] + intermediate. 
+ """ + hidden_size = 128 + vocab_size = 1000 + batch_size = 4 + max_batch_size = 8 + model = SoftcapLMHeadModel(hidden_size, vocab_size).cuda() + + hidden_states = torch.randn(batch_size, 1, hidden_size, device="cuda", dtype=torch.float16) + logit_gather_ids = torch.zeros(max_batch_size, dtype=torch.long, device="cuda") + seq_len = torch.ones(batch_size, dtype=torch.long, device="cuda") + + gm = torch_export_to_gm( + model, + args=(hidden_states, logit_gather_ids, seq_len), + dynamic_shapes=None, + clone=True, + ) + + # Apply transform + cm = self._create_cached_sequence_interface(max_batch_size) + transform_config = { + "gather_logits_before_lm_head": { + "stage": "post_load_fusion", + "max_batch_size": max_batch_size, + } + } + optimizer = InferenceOptimizer(None, transform_config) + gm_transformed = optimizer(cm, gm) + + assert self._check_gather_op_in_graph(gm_transformed), "Gather op not found in graph" + + # Verify gather_tokens comes BEFORE the lm_head linear, not after softcapping. + # Walk the graph and record the order of gather_tokens vs aten.linear ops. 
+ gather_idx = None + linear_indices = [] + for i, node in enumerate(gm_transformed.graph.nodes): + if is_op(node, torch.ops.auto_deploy.gather_tokens): + gather_idx = i + if is_linear_op(node): + linear_indices.append(i) + + assert gather_idx is not None, "gather_tokens not found" + # The lm_head linear is the last linear in the graph + lm_head_linear_idx = linear_indices[-1] + assert gather_idx < lm_head_linear_idx, ( + f"gather_tokens (idx={gather_idx}) should come before " + f"lm_head linear (idx={lm_head_linear_idx})" + ) + + # Verify forward pass correctness + token_gather_indices = torch.arange(batch_size, dtype=torch.long, device="cuda") + batch_info = BatchInfo() + batch_info.update_tokens_gather_info(batch_size, False) + batch_info_host = batch_info.serialize() + output = gm_transformed( + hidden_states, + logit_gather_ids, + seq_len, + token_gather_indices=token_gather_indices, + batch_info_host=batch_info_host, + ) + assert output.shape == (batch_size, 1, vocab_size) From 2e7fb0d8f0f3e03eb7f4d8b37ca4da90d6e91e75 Mon Sep 17 00:00:00 2001 From: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com> Date: Mon, 6 Apr 2026 10:22:14 -0700 Subject: [PATCH 13/16] [None][doc] add Gemma 3n to supported models matrix Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com> --- docs/source/models/supported-models.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/models/supported-models.md b/docs/source/models/supported-models.md index 283eca924eb..adb2f53ef72 100644 --- a/docs/source/models/supported-models.md +++ b/docs/source/models/supported-models.md @@ -13,6 +13,7 @@ The following is a table of supported models for the PyTorch backend: | `Exaone4ForCausalLM` | EXAONE 4.0 | `LGAI-EXAONE/EXAONE-4.0-32B` | | `ExaoneMoEForCausalLM` | K-EXAONE | `LGAI-EXAONE/K-EXAONE-236B-A23B` | | `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it` | +| `Gemma3nForConditionalGeneration` [^8]| Gemma 3n | 
`google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it` | | `Gemma4ForConditionalGeneration` [^7]| Gemma 4 | `google/gemma-4-26B-A4B-it` | | `Glm4MoeForCausalLM` | GLM-4.5, GLM-4.6, GLM-4.7 | `THUDM/GLM-4-100B-A10B` | | `Glm4MoeLiteForCausalLM` [^6] | GLM-4.7-Flash | `zai-org/GLM-4.7-Flash` | @@ -62,6 +63,7 @@ Note: Support for other models may vary. Features marked "N/A" are not applicabl [^5]: Supported via the [AutoDeploy](../features/auto_deploy/auto-deploy.md) backend. See [AD config](../../../examples/auto_deploy/model_registry/configs/qwen3.5_moe_400b.yaml). [^6]: Supported via the [AutoDeploy](../features/auto_deploy/auto-deploy.md) backend. See [AD config](../../../examples/auto_deploy/model_registry/configs/glm-4.7-flash.yaml). [^7]: Text-only support via the [AutoDeploy](../features/auto_deploy/auto-deploy.md) backend. See [AD config](../../../examples/auto_deploy/model_registry/configs/gemma4_moe.yaml). +[^8]: Text-only support via the [AutoDeploy](../features/auto_deploy/auto-deploy.md) backend. See [AD config](../../../examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml). # Multimodal Feature Support Matrix (PyTorch Backend) From 6961cbe63e7bea6e6ecdd8394fd9ae31cc0de116 Mon Sep 17 00:00:00 2001 From: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com> Date: Mon, 6 Apr 2026 10:38:08 -0700 Subject: [PATCH 14/16] [None][fix] fix flashinfer sliding window in CUDA graph decode, reject Geglu in NVFP4 MoE Pass window_left to fast_decode_plan in plan_generate_only so sliding window attention is respected during CUDA-graph-captured decode. Add early rejection of Gelu/Geglu in NVFP4 TRTLLM-Gen MoE since the underlying kernel does not support it. 
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com> --- .../auto_deploy/custom_ops/attention/flashinfer_attention.py | 1 + .../_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py index f9c0c703e10..90eef824630 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py @@ -154,6 +154,7 @@ def plan_generate_only( q_data_type=plan_params.q_dtype, kv_data_type=plan_params.kv_dtype, sm_scale=plan_params.sm_scale, + window_left=plan_params.window_left, ) def plan_prefill( diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py index d108eb03abf..0060cb2e78a 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py @@ -947,6 +947,11 @@ def trtllm_nvfp4_trtllm_gen_moe_fused( apply_routing_on_input: bool = False, ) -> torch.Tensor: _validate_mlp_style_and_act_fn(is_gated_mlp, act_fn) + if act_fn in (ActivationType.Gelu, ActivationType.Geglu): + raise ValueError( + f"NVFP4 TRTLLM-Gen MoE does not support activation " + f"'{ActivationType(act_fn).name}'. Only Silu/Swiglu and Relu2 are supported." 
+ ) x_shape = x.shape x2d = x.view(-1, x_shape[-1]) From a5abc9606eb7ad2e47c1d646132108c09823994b Mon Sep 17 00:00:00 2001 From: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com> Date: Mon, 6 Apr 2026 16:03:10 -0700 Subject: [PATCH 15/16] [None][fix] update gemma configs for new CudaGraphConfig format Migrate cuda_graph_batch_sizes to cuda_graph_config.batch_sizes and add explicit max_batch_size to gemma3n config to preserve prior default. Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com> --- .../auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml | 1 + examples/auto_deploy/model_registry/configs/gemma4_moe.yaml | 3 ++- .../auto_deploy/model_registry/configs/gemma4_moe_base.yaml | 3 ++- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml b/examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml index 0c9862833cc..2e5e1f8d5cb 100644 --- a/examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml +++ b/examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml @@ -5,6 +5,7 @@ runtime: trtllm compile_backend: torch-cudagraph model_factory: AutoModelForCausalLM max_seq_len: 512 +max_batch_size: 8 world_size: 1 # Gemma 3n uses shared-KV decode semantics in the tail layers. 
FlashInfer diff --git a/examples/auto_deploy/model_registry/configs/gemma4_moe.yaml b/examples/auto_deploy/model_registry/configs/gemma4_moe.yaml index a830e6946df..d31ba340bc9 100644 --- a/examples/auto_deploy/model_registry/configs/gemma4_moe.yaml +++ b/examples/auto_deploy/model_registry/configs/gemma4_moe.yaml @@ -8,7 +8,8 @@ model_factory: Gemma4ForConditionalGeneration tokenizer: google/gemma-4-26B-A4B-it attn_backend: triton_paged compile_backend: torch-cudagraph -cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] +cuda_graph_config: + batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] max_num_tokens: 8192 max_batch_size: 512 max_seq_len: 8192 diff --git a/examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml b/examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml index abfc0a02078..9e469676559 100644 --- a/examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml +++ b/examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml @@ -8,7 +8,8 @@ model_factory: Gemma4ForConditionalGeneration tokenizer: google/gemma-4-26B-A4B attn_backend: triton_paged compile_backend: torch-cudagraph -cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] +cuda_graph_config: + batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] max_num_tokens: 8192 max_batch_size: 512 max_seq_len: 8192 From 076a01e0869a3ac64cdd0938b7000fd5cedd77cc Mon Sep 17 00:00:00 2001 From: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> Date: Tue, 7 Apr 2026 13:44:58 -0700 Subject: [PATCH 16/16] [None][feat] dual-pool KV cache with SWA block eviction for variable sliding window attention Adds dual-pool KV cache architecture for models with mixed attention types (e.g., gemma4-26B with head_dim=256 sliding + head_dim=512 full attention). Each head_dim group gets its own KVCacheManager pool with independent max_attention_window, enabling SWA block eviction during decode. Architecture: WindowPlan is the single source of truth for VSWA. 
It separates logical attention-window routing (which layers share page tables) from physical KV storage pooling (which layers share block pools). Both graph wiring and runtime metadata emission derive from it, eliminating predicate drift between the transform and executor. Key changes: - WindowPlan dataclass: per_layer_window, unique_windows, group indices, group_to_pool_idx mapping (decouples window groups from storage pools) - MultiPoolKVCacheManager: delegates lifecycle to all storage pools - _identify_managed_kv_groups: groups layers by (head_dim, dtype, kv_factor) - Per-group cache_loc/cu_num_pages/kv_page_offset via VSWA graph wiring - kv_page_offset in write kernel for window-relative page indexing - kv_page_offset in context kernel for correct position-based masking - cache_len capping from cu_num_pages in triton_paged_mha_with_cache - get_num_front_blocks_removed C++ binding for SWA eviction tracking - N-based proportional memory budget split across pools - max_concurrent_sequences scheduler cap for multi-pool safety - Unit tests for multi-group identification, dual-pool creation, and per-group max_attention_window scoping Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> --- .../nanobind/batch_manager/kvCacheManager.cpp | 5 + .../attention/flashinfer_attention.py | 4 +- .../attention/torch_backend_attention.py | 4 +- .../custom_ops/attention/triton_attention.py | 4 +- .../attention/triton_paged_attention.py | 48 +- .../custom_ops/attention/trtllm_attention.py | 15 +- .../custom_ops/attention_interface.py | 191 +++++- .../custom_ops/fla/fla_backend_delta.py | 4 +- .../custom_ops/fla/fla_backend_gated_delta.py | 4 +- .../fla/torch_backend_gated_delta.py | 4 +- .../custom_ops/mamba/causal_conv_common.py | 6 +- .../custom_ops/mamba/mamba_backend_common.py | 4 +- .../mamba/torch_backend_causal_conv.py | 4 +- .../custom_ops/mamba/torch_backend_mamba.py | 4 +- .../custom_ops/mla/flashinfer_mla.py | 4 +- 
.../custom_ops/mla/torch_backend_mla.py | 4 +- .../_torch/auto_deploy/shim/ad_executor.py | 111 +++- .../_torch/auto_deploy/shim/interface.py | 488 +++++++++++++-- .../auto_deploy/transform/library/kvcache.py | 82 ++- .../_torch/pyexecutor/resource_manager.py | 21 +- .../shim/test_cached_sequence_interface.py | 216 +++++++ .../transformations/library/test_kv_cache.py | 571 ++++++++++++++++++ 22 files changed, 1693 insertions(+), 105 deletions(-) mode change 100755 => 100644 cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp old mode 100755 new mode 100644 index fe7687aee77..29202c6fca8 --- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp @@ -534,6 +534,11 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) .def("count_reusable_blocks", &BaseKVCacheManager::countReusableBlocks, nb::arg("unique_tokens"), nb::arg("llm_request"), nb::arg("only_allocated") = false, nb::call_guard()) .def("get_cache_block_ids", &BaseKVCacheManager::getCacheBlockIds, nb::call_guard()) + .def( + "get_num_front_blocks_removed", + [](BaseKVCacheManager const& self, tb::LlmRequest::RequestIdType requestId) + { return self.getSequence(requestId).getNumFrontBlocksRemoved(); }, + nb::call_guard()) .def("get_batch_cache_block_ids", &BaseKVCacheManager::getBatchCacheBlockIds, nb::call_guard()) .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents, diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py index 90eef824630..0674c7be6ee 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py @@ -588,7 +588,9 @@ 
def get_host_prepare_metadata_function(cls) -> Optional[PrepareMetadataHostCalla return prepare_flashinfer_metadata_host @classmethod - def get_constants(cls, source_attn_node: Node) -> List[Constant]: + def get_constants( + cls, source_attn_node: Node, cache_config: Optional["KvCacheConfig"] = None + ) -> List[Constant]: layout = extract_op_args(source_attn_node, "layout")[0] if layout != "bsnd": raise RuntimeError( diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py index 4be06d5b24a..d3eb1a53f7d 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py @@ -630,7 +630,9 @@ def get_cache_initializers( } @classmethod - def get_constants(cls, source_attn_node: Node) -> List[Constant]: + def get_constants( + cls, source_attn_node: Node, cache_config: Optional["KvCacheConfig"] = None + ) -> List[Constant]: # Sanity check: layout == "bsnd" layout = extract_op_args(source_attn_node, "layout")[0] if layout != "bsnd": diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py index a1ea264b0e2..0223f098d89 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py @@ -386,7 +386,9 @@ def get_cache_initializers( } @classmethod - def get_constants(cls, source_attn_node: Node) -> List[Constant]: + def get_constants( + cls, source_attn_node: Node, cache_config: Optional["KvCacheConfig"] = None + ) -> List[Constant]: # Sanity check: layout == "bsnd" # extract_op_args handles kwargs and positional arguments consistently. 
layout = extract_op_args(source_attn_node, "layout")[0] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py index 2ab575af2a6..a784356c05c 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py @@ -78,6 +78,13 @@ def _update_paged_kv_cache_kernel( # Page table kv_indices_ptr, kv_indptr_ptr, + # Per-sequence page offset for windowed (VSWA) cache_loc. + # When cache_loc holds only the last W pages of a sequence (sliding window), + # page_offset = total_pages_in_seq - W. The kernel subtracts this from the + # global page index to obtain the window-relative index into cache_loc and + # skips tokens whose pages fall outside the window. + # For non-windowed (full) cache_loc this is 0 everywhere (no-op). + page_offset_ptr, # Constants NUM_TOKENS: tl.constexpr, N_KV_HEADS: tl.constexpr, @@ -102,8 +109,14 @@ def _update_paged_kv_cache_kernel( page_idx_in_seq = position // PAGE_SIZE offset_in_page = position % PAGE_SIZE + # Adjust for windowed cache_loc: subtract per-sequence page offset. 
+ page_offset = tl.load(page_offset_ptr + batch_idx) + page_idx_in_window = page_idx_in_seq - page_offset + if page_idx_in_window < 0: + return # Token is before the sliding window — skip write + page_start = tl.load(kv_indptr_ptr + batch_idx) - physical_page = tl.load(kv_indices_ptr + page_start + page_idx_in_seq) + physical_page = tl.load(kv_indices_ptr + page_start + page_idx_in_window) head_offsets = tl.arange(0, HEAD_DIM) kv_offset = token_id * N_KV_HEADS * HEAD_DIM + head_id * HEAD_DIM + head_offsets @@ -131,6 +144,7 @@ def update_paged_kv_cache( kv_cache: torch.Tensor, kv_indices: torch.Tensor, kv_indptr: torch.Tensor, + kv_page_offset: torch.Tensor, ) -> None: """Update the combined paged KV cache with new K, V tensors.""" num_tokens, n_kv_heads, head_dim = k.shape @@ -148,6 +162,7 @@ def update_paged_kv_cache( kv_cache, kv_indices, kv_indptr, + kv_page_offset, NUM_TOKENS=num_tokens, N_KV_HEADS=n_kv_heads, HEAD_DIM=head_dim, @@ -613,6 +628,9 @@ def _paged_context_kernel( kv_indices_ptr, kv_last_page_len_ptr, seq_len_with_cache_ptr, + # Per-sequence page offset for windowed cache_loc (VSWA). + # Adjusts page_idx to global position for correct causal/SWA masking. + kv_page_offset_ptr, # Output o_ptr, # Strides @@ -660,6 +678,7 @@ def _paged_context_kernel( kv_page_end = tl.load(kv_indptr_ptr + batch_id + 1) num_kv_pages = kv_page_end - kv_page_start total_kv_len = tl.load(seq_len_with_cache_ptr + batch_id) + page_offset = tl.load(kv_page_offset_ptr + batch_id) cache_len = total_kv_len - q_len @@ -739,7 +758,7 @@ def _paged_context_kernel( # Per-query sliding window mask: each query position q_pos # can attend to KV in [q_pos - W + 1, q_pos]. 
- kv_positions = page_idx * PAGE_SIZE + page_offsets[None, :] + kv_positions = (page_idx + page_offset) * PAGE_SIZE + page_offsets[None, :] q_kv_pos = q_offsets[:, None] + cache_len sw_mask = (q_kv_pos - kv_positions) < SLIDING_WINDOW full_mask_p1 = q_mask[:, None] & sw_mask @@ -788,7 +807,7 @@ def _paged_context_kernel( q_positions_2d = q_offsets[:, None] + cache_len for page_idx in range(num_full_pages, num_kv_pages): - kv_base_pos = page_idx * PAGE_SIZE + kv_base_pos = (page_idx + page_offset) * PAGE_SIZE # Causal skip: if entire page is beyond last Q position, skip it. if kv_base_pos <= max_q_pos: @@ -905,6 +924,7 @@ def triton_paged_context( kv_indices: torch.Tensor, kv_last_page_len: torch.Tensor, seq_len_with_cache: torch.Tensor, + kv_page_offset: torch.Tensor, sm_scale: float, sliding_window: Optional[int] = None, out: Optional[torch.Tensor] = None, @@ -1004,6 +1024,7 @@ def grid_paged(meta): kv_indices, kv_last_page_len, seq_len_with_cache, + kv_page_offset, output, q.stride(0), q.stride(1), @@ -1077,6 +1098,7 @@ def triton_paged_mha_with_cache( last_page_len: torch.Tensor, last_page_len_host: torch.Tensor, seq_len_with_cache_host: torch.Tensor, + kv_page_offset: torch.Tensor, # EXTRA METADATA triton_batch_indices: torch.Tensor, triton_positions: torch.Tensor, @@ -1115,6 +1137,7 @@ def triton_paged_mha_with_cache( kv_cache, cache_loc, cu_num_pages[: num_seq + 1], + kv_page_offset[:num_seq], ) if out is not None: @@ -1126,6 +1149,18 @@ def triton_paged_mha_with_cache( if num_prefill > 0: cu_seqlen = cu_seqlen_host[: num_prefill + 1].to(q.device, non_blocking=True) seq_len_with_cache = seq_len_with_cache_host[:num_prefill].to(q.device, non_blocking=True) + # For windowed cache_loc (VSWA), cap the cached-token portion of + # seq_len_with_cache to the actual pages available. Without this, + # the context kernel computes page iteration bounds from global + # seq_len, overflowing the windowed cache_loc. 
+ # seq_len_with_cache = cache_len + q_len, where cache_len is the + # number of prior-cached tokens. Only cache_len needs capping. + q_lens = cu_seqlen[1 : num_prefill + 1] - cu_seqlen[:num_prefill] + page_counts = cu_num_pages[1 : num_prefill + 1] - cu_num_pages[:num_prefill] + max_cached = page_counts * kv_cache.shape[3] # pages × page_size + cache_len_raw = seq_len_with_cache - q_lens + cache_len_capped = torch.minimum(cache_len_raw, max_cached) + seq_len_with_cache = cache_len_capped + q_lens triton_paged_context( q[:num_prefill_tokens], kv_cache, @@ -1134,6 +1169,7 @@ def triton_paged_mha_with_cache( cache_loc, last_page_len[:num_prefill], seq_len_with_cache, + kv_page_offset[:num_prefill], sm_scale, sliding_window=sliding_window, out=y[:num_prefill_tokens], @@ -1178,6 +1214,7 @@ def triton_paged_mha_with_cache_fake( last_page_len: torch.Tensor, last_page_len_host: torch.Tensor, seq_len_with_cache_host: torch.Tensor, + kv_page_offset: torch.Tensor, triton_batch_indices: torch.Tensor, triton_positions: torch.Tensor, kv_cache: torch.Tensor, @@ -1224,6 +1261,7 @@ def get_standard_metadata_args(cls) -> List[str]: "last_page_len", "last_page_len_host", "seq_len_with_cache_host", + "kv_page_offset", ] @classmethod @@ -1255,7 +1293,9 @@ def get_cache_initializers( } @classmethod - def get_constants(cls, source_attn_node: Node) -> List[Constant]: + def get_constants( + cls, source_attn_node: Node, cache_config: Optional["KvCacheConfig"] = None + ) -> List[Constant]: (layout,) = extract_op_args(source_attn_node, "layout") if layout != "bsnd": raise RuntimeError( diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py index f1c99267ed0..d013a38cd53 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py @@ -334,6 +334,7 @@ def trtllm_mha_with_cache( 
kv_scale_quant_orig: float = 1.0, out_scale: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, + sink_token_length: int = 0, ) -> torch.Tensor: """TRT-LLM attention with paged KV cache for Auto-Deploy. @@ -458,7 +459,7 @@ def trtllm_mha_with_cache( max_num_requests, # max_num_requests max_context_length, # max_context_length attention_window_size, # attention_window_size - 0, # sink_token_length + sink_token_length, # sink_token_length 1, # beam_width int(AttentionMaskType.causal), # mask_type quant_mode, # quant_mode @@ -531,6 +532,7 @@ def trtllm_mha_with_cache_fake( kv_scale_quant_orig: float = 1.0, out_scale: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, + sink_token_length: int = 0, ) -> torch.Tensor: """Fake implementation for torch.compile tracing.""" if out is not None: @@ -630,11 +632,13 @@ def prepare_node_for_cache_insertion(cls, gm: GraphModule, attn_node: Node) -> N attn_node.meta[_TRTLLM_ATTN_OUT_SCALE_KEY] = out_scale @classmethod - def get_constants(cls, source_attn_node: Node) -> List[Constant]: + def get_constants( + cls, source_attn_node: Node, cache_config: Optional["KvCacheConfig"] = None + ) -> List[Constant]: """Extract constants from the source attention node. Returns scale, sliding_window, kv_scale_orig_quant, kv_scale_quant_orig, - and optional output quant scale for FP8 linear consumers. + optional output quant scale for FP8 linear consumers, and sink_token_length. Everything else (num_heads, head_dim, max_context_length, etc.) is inferred from tensor shapes or SequenceInfo metadata at runtime. 
""" @@ -669,10 +673,15 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]: if not isinstance(out_scale, Node): out_scale = None + sink_token_length = 0 + if cache_config is not None and cache_config.sink_token_length is not None: + sink_token_length = cache_config.sink_token_length + return [ scale, sliding_window, 1.0, # kv_scale_orig_quant (hard-coded, same as FlashInfer) 1.0, # kv_scale_quant_orig (hard-coded, same as FlashInfer) out_scale, + sink_token_length, ] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py index 72acc449519..48f6384c67c 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py @@ -356,6 +356,29 @@ def resize(self, name: str, new_capacity: int) -> None: self._device_views[name] = self._trunc_device_bufs[name].view(dtype) self._host_views[name] = self._trunc_host_bufs[name].view(dtype) + def add_truncatable_tensor(self, name: str, max_numel: int, dtype: torch.dtype) -> None: + """Add a new truncatable tensor after initialization. + + Args: + name: Name of the tensor. + max_numel: Maximum number of elements. + dtype: Data type. 
+ """ + assert name not in self._tensor_specs, f"Tensor '{name}' already registered" + self._tensor_specs[name] = (max_numel, dtype) + self._tensor_order.append(name) + self._truncatable_names.add(name) + self._current_lengths[name] = 0 + + byte_size = max_numel * dtype.itemsize + device = self._device_buffer.device + self._trunc_device_bufs[name] = torch.empty(byte_size, dtype=torch.uint8, device=device) + self._trunc_host_bufs[name] = torch.empty( + byte_size, dtype=torch.uint8, device="cpu", pin_memory=prefer_pinned() + ) + self._device_views[name] = self._trunc_device_bufs[name].view(dtype) + self._host_views[name] = self._trunc_host_bufs[name].view(dtype) + def to(self, *args, **kwargs) -> None: """Move all device buffers to a new device/dtype.""" old_device = self._device_buffer.device @@ -660,6 +683,7 @@ def __init__( ("seq_len_with_cache", self.max_batch_size, torch.int), ("use_initial_states", self.max_batch_size, torch.bool), ### OTHER ARGUMENTS USED BY THE RUNTIME ################################################ + ("kv_page_offset", self.max_batch_size, torch.int), ("extra_page_per_seq", self.max_batch_size, torch.int), ("token_gather_indices", self.max_num_tokens, torch.long), ("_gather_idx", self.max_num_tokens, torch.int), @@ -670,6 +694,9 @@ def __init__( # Create the InputBuffer that manages contiguous host and device memory # Starts on default device; use to() to move to target device self._input_buffer = InputBuffer(tensor_specs) + # Zero-fill kv_page_offset on device so shape-checking forward passes + # (resize_kv_cache) that run before nest_sequences read a safe default. 
+ self._input_buffer._device_views["kv_page_offset"].zero_() self._available_args = ( set(self._input_buffer.tensor_names) | {name + self._host_suffix for name in self._input_buffer.tensor_names} @@ -689,6 +716,11 @@ def __init__( self._extra_args: Dict[str, Optional[torch.Tensor]] = {} ############################################################################################ + # VSWA WINDOW GROUPS ####################################################################### + self._window_groups: List[int] = [] + self._window_group_map: Dict[int, int] = {} + ############################################################################################ + # HOST PREPARE FOR ATTENTION FORWARD ####################################################### self._host_prepare_functions: List[Tuple[PrepareMetadataHostCallable, List[str]]] = [] @@ -827,13 +859,20 @@ def estimate_cache_tokens_per_forward(self) -> int: num_blocks_estimate = num_blocks_estimate_per_seq * self.max_batch_size return num_blocks_estimate * self.tokens_per_block - def update_cache_information(self, num_blocks: int, block_offset_multiplier: int = 0) -> None: + def update_cache_information( + self, + num_blocks: int, + block_offset_multiplier: int = 0, + ) -> None: """Update cache information after cache manager creation. Sets num_blocks and block_offset_multiplier, writes max_seq_info into BatchInfo (constant after this call), and resizes cache_loc if needed. + + Args: + num_blocks: Number of blocks in the primary pool. + block_offset_multiplier: Block offset multiplier derived from kv_cache strides. 
""" - # set num_blocks and block_offset_multiplier self._num_blocks = num_blocks # write max_seq_info once into BatchInfo (constant after this call) @@ -855,6 +894,79 @@ def update_cache_information(self, num_blocks: int, block_offset_multiplier: int if estimated_capacity > cache_loc_capacity: self._input_buffer.resize("cache_loc", estimated_capacity) + def register_window_groups(self, window_sizes: List[int]) -> None: + """Register VSWA window groups and create per-group cache tensors. + + Group 0 reuses existing cache_loc/cu_num_pages/last_page_len/extra_page_per_seq. + Groups 1..N-1 get new dedicated tensors (cache_loc_g{i}, cu_num_pages_g{i}, etc.). + + Args: + window_sizes: Sorted list of distinct window sizes (e.g. [1024, 8192]). + """ + assert len(window_sizes) >= 2, "register_window_groups requires at least 2 distinct windows" + self._window_groups = list(window_sizes) + self._window_group_map = {ws: idx for idx, ws in enumerate(window_sizes)} + + cache_loc_cap = self._input_buffer.get_capacity("cache_loc") + for group_idx in range(1, len(window_sizes)): + suffix = f"_g{group_idx}" + self._input_buffer.add_truncatable_tensor( + f"cache_loc{suffix}", cache_loc_cap, torch.int + ) + self._input_buffer.add_truncatable_tensor( + f"cu_num_pages{suffix}", self.max_batch_size + 1, torch.int + ) + self._input_buffer.add_truncatable_tensor( + f"last_page_len{suffix}", self.max_batch_size, torch.int + ) + self._input_buffer.add_truncatable_tensor( + f"extra_page_per_seq{suffix}", self.max_batch_size, torch.int + ) + self._input_buffer.add_truncatable_tensor( + f"kv_page_offset{suffix}", self.max_batch_size, torch.int + ) + # Register as available args (device + host variants) + group_names = [ + f"cache_loc{suffix}", + f"cu_num_pages{suffix}", + f"last_page_len{suffix}", + f"extra_page_per_seq{suffix}", + f"kv_page_offset{suffix}", + ] + for base_name in group_names: + self._available_args.add(base_name) + self._available_args.add(base_name + self._host_suffix) + + # 
Zero-fill all per-group device buffers so shape-checking forward + # passes (resize_kv_cache) that run before nest_sequences don't read + # uninitialized data. With zeros: cache_loc→block 0 (safe), + # cu_num_pages→0 pages, kv_page_offset→0 (no-op). + for base_name in group_names: + self._input_buffer._trunc_device_bufs[base_name].zero_() + + @property + def window_groups(self) -> List[int]: + """Sorted list of distinct window sizes, or empty if non-VSWA.""" + return self._window_groups + + @property + def window_group_map(self) -> Dict[int, int]: + """Map from window_size to group index.""" + return self._window_group_map + + @property + def num_window_groups(self) -> int: + """Number of window groups (0 or 1 for non-VSWA, 2+ for VSWA).""" + return len(self._window_groups) + + def get_cache_loc_for_group(self, group_idx: int) -> str: + """Return the cache_loc tensor name for a given window group.""" + return "cache_loc" if group_idx == 0 else f"cache_loc_g{group_idx}" + + def get_cu_num_pages_for_group(self, group_idx: int) -> str: + """Return the cu_num_pages tensor name for a given window group.""" + return "cu_num_pages" if group_idx == 0 else f"cu_num_pages_g{group_idx}" + def activate_arg(self, arg_name: str) -> bool: """Activate a desired argument. 
@@ -1043,7 +1155,13 @@ def nest_sequences( cache_loc: Union[Sequence[int], torch.Tensor, None] = None, cu_num_pages: Union[Sequence[int], torch.Tensor, None] = None, extra_page_per_seq: Optional[Sequence[int]] = None, + kv_page_offset: Optional[Sequence[int]] = None, slot_idx: Union[Sequence[int], torch.Tensor, None] = None, + ### VSWA PER-GROUP CACHE DATA (groups 1..N-1; group 0 uses cache_loc/cu_num_pages above) ### + cache_loc_per_group: Optional[Dict[int, Sequence[int]]] = None, + cu_num_pages_per_group: Optional[Dict[int, Sequence[int]]] = None, + extra_page_per_seq_per_group: Optional[Dict[int, Sequence[int]]] = None, + kv_page_offset_per_group: Optional[Dict[int, Sequence[int]]] = None, ### RUNTIME ARGUMENTS ###################################################################### gather_context_logits: bool = False, _gather_idx: Union[Sequence[int], torch.Tensor, None] = None, @@ -1144,6 +1262,44 @@ def nest_sequences( # check for updated extra_page_per_seq self._stage_arg("extra_page_per_seq", extra_page_per_seq) + # kv_page_offset: per-sequence write-kernel page offset (0 for non-VSWA). + # Default to all-zeros when the caller does not provide it (e.g. warmup). + if kv_page_offset is None and self._is_required("kv_page_offset"): + kv_page_offset = [0] * self.num_sequences + self._stage_arg("kv_page_offset", kv_page_offset) + + # Default per-group metadata when the caller doesn't provide per-group + # data (warmup, set_example_sequence, CUDA graph capture). Replicate + # group 0's data to all groups so every kernel receives valid metadata. 
+ if cache_loc_per_group is None and self.num_window_groups >= 2: + for group_idx in range(1, self.num_window_groups): + suffix = f"_g{group_idx}" + self._stage_arg(f"cache_loc{suffix}", cache_loc) + self._stage_arg(f"cu_num_pages{suffix}", cu_num_pages) + self._stage_arg(f"last_page_len{suffix}", lpl_host) + self._stage_arg(f"extra_page_per_seq{suffix}", extra_page_per_seq) + if kv_page_offset is not None: + self._stage_arg(f"kv_page_offset{suffix}", kv_page_offset) + elif self._is_required(f"kv_page_offset{suffix}"): + self._stage_arg(f"kv_page_offset{suffix}", [0] * self.num_sequences) + + # Stage VSWA per-group cache data (groups 1..N-1) + if cache_loc_per_group is not None: + for group_idx, group_cache_loc in cache_loc_per_group.items(): + if group_idx == 0: + continue + suffix = f"_g{group_idx}" + self._stage_arg(f"cache_loc{suffix}", group_cache_loc) + if cu_num_pages_per_group is not None: + self._stage_arg(f"cu_num_pages{suffix}", cu_num_pages_per_group[group_idx]) + self._stage_arg(f"last_page_len{suffix}", lpl_host) + if extra_page_per_seq_per_group is not None: + self._stage_arg( + f"extra_page_per_seq{suffix}", extra_page_per_seq_per_group[group_idx] + ) + if kv_page_offset_per_group is not None: + self._stage_arg(f"kv_page_offset{suffix}", kv_page_offset_per_group[group_idx]) + ### UPDATE OPTIONAL DERIVATIVE METADATA #################################################### if self._is_required("position_ids"): # set new position_ids and make sure to flatten it @@ -1163,7 +1319,8 @@ def nest_sequences( pages_per_seq = cu_num_pages_host[1:] - cu_num_pages_host[:-1] self._stage_arg("pages_per_seq", pages_per_seq) - # update sequence length with cache + # update sequence length with cache (unclamped global value — per-group + # clamping is handled separately in the VSWA staging blocks below) seq_len_with_cache = ip_host + sl_host self._stage_arg("seq_len_with_cache", seq_len_with_cache) @@ -1345,6 +1502,25 @@ def offset_pos_and_cache_(self, offset: 
torch.Tensor) -> None: last_page_len %= self.tokens_per_block last_page_len += 1 + # Adjust per-group cache assignments for VSWA groups 1..N-1 + for group_idx in range(1, self.num_window_groups): + suffix = f"_g{group_idx}" + if self._is_active(f"cache_loc{suffix}"): + lpl_g = self.get_arg(f"last_page_len{suffix}", truncate=True) + lpl_g += offset + delta_g = (lpl_g > self.tokens_per_block).int() - (lpl_g <= 0).int() + torch.ops.auto_deploy.adjust_ragged_triton( + cache_loc=self.get_arg(f"cache_loc{suffix}"), + cu_num_blocks=self.get_arg(f"cu_num_pages{suffix}"), + extra_idx=self.get_arg(f"extra_page_per_seq{suffix}"), + delta=delta_g, + num_sequences=num_sequences, + max_blocks_per_seq=self.max_blocks_per_seq, + ) + lpl_g -= 1 + lpl_g %= self.tokens_per_block + lpl_g += 1 + # --- position_ids (device) --- position_ids = self.get_arg("position_ids", truncate=True, unflatten=False) # position_ids is per-token while offset is per-sequence; expand if needed @@ -1927,11 +2103,18 @@ def get_cache_initializers( """ @classmethod - def get_constants(cls, source_attn_node: Node) -> List[Constant]: + def get_constants( + cls, source_attn_node: Node, cache_config: Optional["KvCacheConfig"] = None + ) -> List[Constant]: """Provide a list of constant arguments to be passed to the attention op. The constant arguments are passed to the attention op as additional arguments after the caches. The constants are expected to be of type int, float, str, or None. + + Args: + source_attn_node: The source attention node from which to extract constants. + cache_config: Optional KV cache configuration for runtime constants like + sink_token_length that are not encoded in the model graph. 
""" return [] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py index 197bac939d4..2fc6a5f509c 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py @@ -216,7 +216,9 @@ def get_cache_initializers( } @classmethod - def get_constants(cls, source_attn_node: Node) -> List[Constant]: + def get_constants( + cls, source_attn_node: Node, cache_config: Optional["KvCacheConfig"] = None + ) -> List[Constant]: scale = extract_op_args(source_attn_node, "scale")[0] if scale is None: key_node = source_attn_node.args[1] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_gated_delta.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_gated_delta.py index 0842cafb2c1..2c574cd8e47 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_gated_delta.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_gated_delta.py @@ -253,7 +253,9 @@ def get_cache_initializers( } @classmethod - def get_constants(cls, source_attn_node: Node) -> List[Constant]: + def get_constants( + cls, source_attn_node: Node, cache_config: Optional["KvCacheConfig"] = None + ) -> List[Constant]: scale = extract_op_args(source_attn_node, "scale")[0] if scale is None: key_node = source_attn_node.args[1] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/torch_backend_gated_delta.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/torch_backend_gated_delta.py index ddbe4b8ccc9..58f96f721f9 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/torch_backend_gated_delta.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/torch_backend_gated_delta.py @@ -414,7 +414,9 @@ def get_cache_initializers( } @classmethod - def get_constants(cls, source_attn_node: Node) -> List[Constant]: + def get_constants( + cls, source_attn_node: Node, cache_config: 
Optional["KvCacheConfig"] = None + ) -> List[Constant]: scale = extract_op_args(source_attn_node, "scale")[0] if scale is None: key_node = source_attn_node.args[1] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/causal_conv_common.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/causal_conv_common.py index 22619c992b1..fd86ac83d66 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/causal_conv_common.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/causal_conv_common.py @@ -14,7 +14,7 @@ # limitations under the License. from abc import abstractmethod -from typing import List +from typing import List, Optional import torch from torch._ops import OpOverloadPacket @@ -96,7 +96,9 @@ def get_cache_initializers( return {"conv_state_cache": conv_state_handler} @classmethod - def get_constants(cls, source_attn_node: Node) -> List[Constant]: + def get_constants( + cls, source_attn_node: Node, cache_config: Optional["KvCacheConfig"] = None + ) -> List[Constant]: stride, padding, dilation, groups, padding_mode = extract_op_args( source_attn_node, "stride", "padding", "dilation", "groups", "padding_mode" ) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/mamba_backend_common.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/mamba_backend_common.py index 37245cebc19..a9ad18d8b9a 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/mamba_backend_common.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/mamba_backend_common.py @@ -366,7 +366,9 @@ def get_cache_initializers( } @classmethod - def get_constants(cls, source_attn_node: Node) -> List[Constant]: + def get_constants( + cls, source_attn_node: Node, cache_config: Optional["KvCacheConfig"] = None + ) -> List[Constant]: time_step_limit, chunk_size = extract_op_args( source_attn_node, "time_step_limit", "chunk_size" ) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py 
b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py index e23f1063181..42e0ca3209b 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py @@ -330,7 +330,9 @@ def get_cache_initializers( } @classmethod - def get_constants(cls, source_attn_node: Node) -> List[Constant]: + def get_constants( + cls, source_attn_node: Node, cache_config: Optional["KvCacheConfig"] = None + ) -> List[Constant]: stride, padding, dilation, groups, padding_mode = extract_op_args( source_attn_node, "stride", "padding", "dilation", "groups", "padding_mode" ) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py index 4be3e3a718c..0578835045c 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py @@ -360,7 +360,9 @@ def get_cache_initializers( } @classmethod - def get_constants(cls, source_attn_node: Node) -> List[Constant]: + def get_constants( + cls, source_attn_node: Node, cache_config: Optional["KvCacheConfig"] = None + ) -> List[Constant]: # time_step_limit, chunk_size should be extracted and passed in as constants time_step_limit, chunk_size = extract_op_args( source_attn_node, "time_step_limit", "chunk_size" diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/flashinfer_mla.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/flashinfer_mla.py index bb816f48707..ada6c49e18b 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/flashinfer_mla.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/flashinfer_mla.py @@ -973,7 +973,9 @@ def get_host_prepare_metadata_function(cls) -> Optional[PrepareMetadataHostCalla return prepare_flashinfer_mla_metadata_host @classmethod - def get_constants(cls, source_attn_node: Node) -> 
List[Constant]: + def get_constants( + cls, source_attn_node: Node, cache_config: Optional["KvCacheConfig"] = None + ) -> List[Constant]: """Get constants to pass to the cached attention op.""" # Extract kv_lora_rank for cache operations compressed_kv_fake = source_attn_node.args[2].meta["val"] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py index e3f5b029625..954afbb9075 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py @@ -555,7 +555,9 @@ def get_cache_initializers( } @classmethod - def get_constants(cls, source_attn_node: Node) -> List[Constant]: + def get_constants( + cls, source_attn_node: Node, cache_config: Optional["KvCacheConfig"] = None + ) -> List[Constant]: """Get constants to pass to the cached attention op.""" # Extract kv_lora_rank for cache slicing compressed_kv_fake = source_attn_node.args[2].meta["val"] diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py index ea336dab0f4..c2d43b1ace0 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py @@ -73,7 +73,7 @@ from ..transform.optimizer import InferenceOptimizer from ..utils._graph import get_input_embeddings, get_lm_head_weights from ..utils.logger import ad_logger -from .interface import CachedSequenceInterface, GetInferenceModel +from .interface import CachedSequenceInterface, GetInferenceModel, MultiPoolKVCacheManager # `layout_metadata` is a reserved multimodal payload used to carry request-level # layout semantics that the wrapper reconstructs into tensor kwargs separately. 
@@ -805,10 +805,34 @@ def _prepare_inputs( mask_scatter_indices.extend(list(range(cu_seqlen[-2], cu_seqlen[-1]))) # store cache information for all requests now + # VSWA is driven by WindowPlan (the single source of truth, set by the + # graph transform). The plan determines logical window routing; the + # pool type determines physical storage. + window_plan = self.cache_seq_interface.window_plan + seq_info = self.cache_seq_interface.info + window_groups = seq_info.window_groups + is_vswa = window_plan is not None and window_plan.is_vswa + is_multi_pool = isinstance(kv_cache_manager, MultiPoolKVCacheManager) + cache_loc: List[int] = [] cu_num_pages: List[int] = [0] extra_page_per_seq: List[int] = [] state_slot_idx: List[int] = [] + + # Per-group cache data (groups 1..N-1); group 0 uses the main cache_loc/cu_num_pages + cache_loc_per_group: Optional[Dict[int, List[int]]] = None + cu_num_pages_per_group: Optional[Dict[int, List[int]]] = None + extra_page_per_seq_per_group: Optional[Dict[int, List[int]]] = None + # kv_page_offset: per-sequence offset so the write kernel can convert global + # page indices to window-relative indices. For non-VSWA this stays all-zeros. 
+ kv_page_offset: List[int] = [] + kv_page_offset_per_group: Optional[Dict[int, List[int]]] = None + if is_vswa: + cache_loc_per_group = {g: [] for g in range(1, len(window_groups))} + cu_num_pages_per_group = {g: [0] for g in range(1, len(window_groups))} + extra_page_per_seq_per_group = {g: [] for g in range(1, len(window_groups))} + kv_page_offset_per_group = {g: [] for g in range(1, len(window_groups))} + for i, request in enumerate(ordered_requests): # store seq slot idx (use mamba_cache_index if available) request.py_batch_idx = request.py_seq_slot @@ -821,16 +845,71 @@ def _prepare_inputs( # get some info on the current request seq_len_i = cu_seqlen[i + 1] - cu_seqlen[i] end_compute_i = input_pos[i] + seq_len_i - num_active_blocks_i = kv_cache_manager.get_num_kv_blocks(end_compute_i) - - # construct cache information for the current request - cache_indices = kv_cache_manager.get_cache_indices(request) - cache_loc.extend(cache_indices[:num_active_blocks_i]) - cu_num_pages.append(cu_num_pages[i] + num_active_blocks_i) - if len(cache_indices) > num_active_blocks_i: - extra_page_per_seq.append(cache_indices[num_active_blocks_i]) + + if is_vswa: + # VSWA: WindowPlan maps each window group → storage pool and + # window_size. A pool may serve multiple window groups + # (same-head-dim VSWA) or one (multi-pool VSWA). 
+ for group_idx, window_size in enumerate(window_groups): + pool_idx = window_plan.get_pool_idx(group_idx) + group_window = window_plan.get_window(group_idx) + if is_multi_pool: + pool = kv_cache_manager.get_pool(pool_idx) + else: + pool = kv_cache_manager + all_indices = pool.get_cache_indices(request, window_size=group_window) + front_removed = pool.get_num_front_blocks_removed(request.py_request_id) + active_indices = all_indices[front_removed:] + num_active = len(active_indices) + page_offset_g = front_removed + + if front_removed > 0 and i == 0: # log once per batch, first seq only + ad_logger.debug( + f"SWA eviction: group={group_idx} window={window_size} " + f"req={request.py_request_id} total_blocks={len(all_indices)} " + f"evicted={front_removed} active={num_active} offset={page_offset_g}" + ) + + if group_idx == 0: + cache_loc.extend(active_indices) + cu_num_pages.append(cu_num_pages[i] + num_active) + kv_page_offset.append(page_offset_g) + if len(all_indices) > front_removed + num_active: + extra_page_per_seq.append(all_indices[front_removed + num_active]) + else: + extra_page_per_seq.append(-1) + else: + cache_loc_per_group[group_idx].extend(active_indices) + prev = cu_num_pages_per_group[group_idx][i] + cu_num_pages_per_group[group_idx].append(prev + num_active) + kv_page_offset_per_group[group_idx].append(page_offset_g) + if len(all_indices) > front_removed + num_active: + extra_page_per_seq_per_group[group_idx].append( + all_indices[front_removed + num_active] + ) + else: + extra_page_per_seq_per_group[group_idx].append(-1) else: - extra_page_per_seq.append(-1) + num_active_blocks_i = kv_cache_manager.get_num_kv_blocks(end_compute_i) + # When a single-pool KVCacheManager has multi-window config (e.g. TEMP + # single-group hack), get_cache_indices requires an explicit window_size. + # Use the max window to retrieve all pages, matching the bala/gemma4 + # single-pool behavior. 
+ single_pool_window = ( + max(kv_cache_manager.max_attention_window_vec) + if len(kv_cache_manager.max_attention_window_vec) > 1 + else None + ) + cache_indices = kv_cache_manager.get_cache_indices( + request, window_size=single_pool_window + ) + cache_loc.extend(cache_indices[:num_active_blocks_i]) + cu_num_pages.append(cu_num_pages[i] + num_active_blocks_i) + kv_page_offset.append(0) # no window offset for non-VSWA + if len(cache_indices) > num_active_blocks_i: + extra_page_per_seq.append(cache_indices[num_active_blocks_i]) + else: + extra_page_per_seq.append(-1) # Store batch information based on prefill, decode, and extend requests. num_decode = len(generation_requests) @@ -858,7 +937,12 @@ def _prepare_inputs( cache_loc=cache_loc, cu_num_pages=cu_num_pages, extra_page_per_seq=extra_page_per_seq, + kv_page_offset=kv_page_offset, slot_idx=state_slot_idx, + cache_loc_per_group=cache_loc_per_group, + cu_num_pages_per_group=cu_num_pages_per_group, + extra_page_per_seq_per_group=extra_page_per_seq_per_group, + kv_page_offset_per_group=kv_page_offset_per_group, gather_context_logits=gather_context_logits, _gather_idx=flat_gather_indices, _mask_scatter_indices=mask_scatter_indices, @@ -1253,8 +1337,13 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer ctx_chunk_config = None # scheduling + # For multi-pool (dual head_dim), cap max_num_requests to the minimum + # concurrent sequences across all pools so the scheduler never over-admits. 
+ max_num_requests = ad_config.max_batch_size + if hasattr(kv_cache_manager, "max_concurrent_sequences"): + max_num_requests = min(max_num_requests, kv_cache_manager.max_concurrent_sequences) capacitor_scheduler = BindCapacityScheduler( - max_num_requests=ad_config.max_batch_size, + max_num_requests=max_num_requests, kv_cache_manager=kv_cache_manager.impl, peft_cache_manager=None, ) diff --git a/tensorrt_llm/_torch/auto_deploy/shim/interface.py b/tensorrt_llm/_torch/auto_deploy/shim/interface.py index 8c60753aaab..b8ac9246d08 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/interface.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/interface.py @@ -44,6 +44,207 @@ def wrapper(*args, **kwargs): return wrapper +class WindowPlan: + """Single source of truth for VSWA window grouping. + + Separates two independent concerns: + - Logical window groups: which layers share the same sliding window size + - Physical storage pools: which layers share the same KV cache layout (head_dim, etc.) + + A window group maps to exactly one storage pool, but a pool may serve + multiple window groups (same-head-dim VSWA). + + Both graph wiring (kvcache.py) and runtime metadata emission (ad_executor.py) + derive their behavior from this plan. + """ + + def __init__(self, per_layer_window: List[int], max_seq_len: int): + self.per_layer_window = per_layer_window + self.unique_windows = sorted(set(per_layer_window)) + self.window_to_group_idx = {w: i for i, w in enumerate(self.unique_windows)} + self.per_layer_group_idx = [self.window_to_group_idx[w] for w in per_layer_window] + self.max_seq_len = max_seq_len + + # Pool mapping — set by bind_pools() after cache managers are created. + # group_to_pool_idx[window_group] = storage pool index. 
+ self.group_to_pool_idx: Optional[List[int]] = None + + def bind_pools( + self, + kv_groups: "List[Tuple[KVPagedResourceHandler, ResourceHandlerDict]]", + resource_lookup: "Dict[str, ResourceHandler]", + ) -> None: + """Map each window group to its storage pool by cross-referencing + which KV layers belong to each window group and each pool group. + + Args: + kv_groups: Storage pool groups from _identify_managed_kv_groups. + resource_lookup: Full resource name → handler mapping. + """ + # Build layer_idx → pool_idx mapping + layer_to_pool: Dict[int, int] = {} + kv_idx = 0 + for name, handler in resource_lookup.items(): + if not isinstance(handler, KVPagedResourceHandler): + continue + for pool_idx, (_, managed) in enumerate(kv_groups): + if name in managed: + layer_to_pool[kv_idx] = pool_idx + break + kv_idx += 1 + + # For each window group, find the pool via any layer in that group + self.group_to_pool_idx = [] + for group_idx in range(self.num_groups): + # Find first layer in this window group + for layer_idx, g in enumerate(self.per_layer_group_idx): + if g == group_idx and layer_idx in layer_to_pool: + self.group_to_pool_idx.append(layer_to_pool[layer_idx]) + break + else: + # Fallback: no pool found for this group (shouldn't happen) + self.group_to_pool_idx.append(0) + + def get_pool_idx(self, window_group_idx: int) -> int: + """Get the storage pool index for a window group.""" + if self.group_to_pool_idx is not None: + return self.group_to_pool_idx[window_group_idx] + return window_group_idx # fallback: assume 1:1 + + def get_window(self, window_group_idx: int) -> int: + """Get the window size for a window group.""" + return self.unique_windows[window_group_idx] + + @property + def num_groups(self) -> int: + return len(self.unique_windows) + + @property + def is_vswa(self) -> bool: + return self.num_groups >= 2 + + +class MultiPoolKVCacheManager: + """Wraps multiple KVCacheManagers (one per head_dim group) behind a unified API. 
+ + Lifecycle methods (prepare/free/shutdown) are delegated to ALL pools. + The primary pool (full-attention, largest window) provides the C++ impl for the scheduler + and determines overall capacity. SWA pools have fixed size and never constrain scheduling. + """ + + def __init__(self, managers: List[KVCacheManager], primary_idx: int = 0): + self._managers = managers + self._primary_idx = primary_idx + + @property + def impl(self): + return self._managers[self._primary_idx].impl + + @property + def tokens_per_block(self): + return self._managers[self._primary_idx].tokens_per_block + + @property + def max_blocks_per_seq(self): + return self._managers[self._primary_idx].max_blocks_per_seq + + @property + def blocks_in_primary_pool(self): + return self._managers[self._primary_idx].blocks_in_primary_pool + + def get_num_free_blocks(self): + return min(m.get_num_free_blocks() for m in self._managers) + + def get_max_resource_count(self): + return self._managers[self._primary_idx].get_max_resource_count() + + def get_needed_resource_to_completion(self, request): + return self._managers[self._primary_idx].get_needed_resource_to_completion(request) + + def get_num_kv_blocks(self, num_tokens: int) -> int: + return self._managers[self._primary_idx].get_num_kv_blocks(num_tokens) + + def prepare_resources(self, scheduled_batch): + for m in self._managers: + m.prepare_resources(scheduled_batch) + + def free_resources(self, request, pin_on_release=False): + for m in self._managers: + m.free_resources(request, pin_on_release) + + def update_resources(self, scheduled_batch, attn_metadata=None, kv_cache_dtype_byte_size=None): + for m in self._managers: + m.update_resources(scheduled_batch, attn_metadata, kv_cache_dtype_byte_size) + + def add_dummy_requests(self, request_ids, **kwargs): + results = None + for m in self._managers: + results = m.add_dummy_requests(request_ids, **kwargs) + return results + + def shutdown(self): + for m in self._managers: + m.shutdown() + + def 
get_pool(self, group_idx: int) -> KVCacheManager: + return self._managers[group_idx] + + @property + def num_pools(self): + return len(self._managers) + + @property + def max_concurrent_sequences(self) -> int: + """Max sequences all pools can serve simultaneously. + + The minimum across pools of (total_blocks / max_blocks_per_seq). + Use this to cap the scheduler's max_num_requests. + """ + return min( + m.get_max_resource_count() // max(m.max_blocks_per_seq, 1) for m in self._managers + ) + + def get_buffers(self, idx: int, kv_layout: str = "NHD"): + raise NotImplementedError("Use get_pool(group_idx).get_buffers() instead") + + # Passthrough properties accessed by PyExecutor and other consumers + @property + def event_buffer_max_size(self): + return self._managers[self._primary_idx].event_buffer_max_size + + @property + def enable_block_reuse(self): + return self._managers[self._primary_idx].enable_block_reuse + + @property + def enable_partial_reuse(self): + return self._managers[self._primary_idx].enable_partial_reuse + + @property + def is_draft(self): + return self._managers[self._primary_idx].is_draft + + @property + def kv_cache_pool_pointers(self): + return self._managers[self._primary_idx].kv_cache_pool_pointers + + @property + def kv_cache_pool_mapping(self): + return self._managers[self._primary_idx].kv_cache_pool_mapping + + def get_cache_indices(self, request, **kwargs): + mgr = self._managers[self._primary_idx] + # When max_attention_window_vec has N identical entries (one per layer), + # the underlying get_cache_indices requires an explicit window_size. 
+ if "window_size" not in kwargs: + kwargs["window_size"] = max(mgr.max_attention_window_vec) + return mgr.get_cache_indices(request, **kwargs) + + def store_blocks_for_reuse(self, request, pin_blocks=False): + for m in self._managers: + m.store_blocks_for_reuse(request, pin_blocks) + + @final class CachedSequenceInterface: """An interface responsible for maintaining information about sequences and their caches. @@ -100,6 +301,8 @@ def __init__( self._caches: Dict[str, torch.Tensor] = {} # KVCacheManager (or MambaHybridCacheManager) for managed resources self._kv_cache_manager: Optional[Union[KVCacheManager, MambaHybridCacheManager]] = None + # Logical window plan (set by the transform, used by executor) + self._window_plan: Optional[WindowPlan] = None # lookup of unmanaged resources self._unmanaged_resources: List[str] = [] self._spec_config = spec_config @@ -247,30 +450,34 @@ def _get_mamba_state_params( "mamba_ssm_cache_dtype": ssm_dtype, } - def _identify_managed_kv_resources( + def _identify_managed_kv_groups( self, - ) -> Tuple[Optional[KVPagedResourceHandler], ResourceHandlerDict]: - """Identify KV resources compatible with the reference handler for KVCacheManager. + ) -> List[Tuple[KVPagedResourceHandler, ResourceHandlerDict]]: + """Identify KV resource groups for multi-pool KVCacheManager creation. - The first KVPagedResourceHandler becomes the reference. All handlers matching - the reference (via __eq__) are collected for managed allocation. + Each unique (head_dim, dtype, kv_factor, kv_layout) combination becomes a group. + Every KVPagedResourceHandler belongs to exactly one group — no unmanaged KV layers. Returns: - Tuple of (reference_handler, managed_resources_dict). - reference_handler is None if no KV paged resources exist. + List of (reference_handler, managed_resources_dict) tuples, one per group. + Empty list if no KV paged resources exist. 
""" - kv_ref: Optional[KVPagedResourceHandler] = None - kv_managed: ResourceHandlerDict = {} + groups: List[Tuple[KVPagedResourceHandler, ResourceHandlerDict]] = [] for name, handler in self._resource_lookup.items(): if not isinstance(handler, KVPagedResourceHandler): continue - if kv_ref is None: - kv_ref = handler - if handler == kv_ref: - kv_managed[name] = handler - - return kv_ref, kv_managed + # Find matching group or create a new one + matched = False + for ref, managed in groups: + if handler == ref: + managed[name] = handler + matched = True + break + if not matched: + groups.append((handler, {name: handler})) + + return groups def _identify_managed_state_resources( self, @@ -378,6 +585,21 @@ def _prepare_kv_cache_config( # Make a deep copy of the kv_cache_config to avoid modifying the original object kv_cache_config = copy.deepcopy(self._kv_cache_config_original) + # Scope max_attention_window to this group's managed layers. + # Each KVCacheManager group only manages a subset of model layers. + if kv_cache_config.max_attention_window is not None and kv_managed: + managed_positions: List[int] = [] + kv_idx = 0 + for name, handler in self._resource_lookup.items(): + if isinstance(handler, KVPagedResourceHandler): + if name in kv_managed: + managed_positions.append(kv_idx) + kv_idx += 1 + if len(managed_positions) < len(kv_cache_config.max_attention_window): + kv_cache_config.max_attention_window = [ + kv_cache_config.max_attention_window[i] for i in managed_positions + ] + # Update kv_cache_config based on max_tokens if provided if max_tokens is not None: # sync max_tokens across ranks @@ -551,18 +773,21 @@ def _create_and_assign_state_views( return manager, num_managed_mamba_layers - def _assign_kv_cache_views(self, kv_managed: Dict[str, KVPagedResourceHandler]) -> int: + def _assign_kv_cache_views( + self, kv_managed: Dict[str, KVPagedResourceHandler], manager: KVCacheManager + ) -> int: """Retrieve and assign buffer views for managed KV paged resources. 
Args: kv_managed: Dict of KV resources managed by the cache manager. + manager: The KVCacheManager that owns these resources. Returns: block_offset_multiplier derived from the first KV cache view's strides. """ block_offset_multiplier = 0 for idx, (name, h) in enumerate(kv_managed.items()): - view = self._kv_cache_manager.get_buffers(idx, kv_layout=h.kv_layout) + view = manager.get_buffers(idx, kv_layout=h.kv_layout) assert view[0].is_contiguous(), f"Non-contiguous kv cache resource for {name}" self._caches[name] = view @@ -586,13 +811,101 @@ def _allocate_unmanaged_resources(self) -> None: self._caches[name] = handler.allocate(self.info) self._unmanaged_resources.append(name) + def _is_swa_group(self, kv_managed: ResourceHandlerDict) -> bool: + """Check if all layers in a group have sliding windows smaller than max_seq_len.""" + if self._kv_cache_config_original.max_attention_window is None: + return False + maw = self._kv_cache_config_original.max_attention_window + kv_idx = 0 + for name, handler in self._resource_lookup.items(): + if isinstance(handler, KVPagedResourceHandler): + if name in kv_managed and kv_idx < len(maw): + if maw[kv_idx] >= self.info.max_seq_len: + return False + kv_idx += 1 + return True + + def _compute_group_token_budget( + self, + group_idx: int, + kv_ref: KVPagedResourceHandler, + kv_managed: ResourceHandlerDict, + all_groups: List[Tuple[KVPagedResourceHandler, ResourceHandlerDict]], + total_max_tokens: Optional[int], + ) -> Optional[int]: + """Compute the max_tokens budget for a single KV cache group. + + All pools must support the same max number of concurrent sequences N. + N is derived from the total byte budget divided by the combined per-sequence + cost across all groups. Each group then gets N × its per-sequence tokens. + + All groups use max_seq_len for per-sequence cost (not window_size), because + during prefill each sequence temporarily uses max_seq_len blocks. 
SWA savings + manifest as freed blocks during decode, enabling higher throughput. + """ + if total_max_tokens is None and not self._is_swa_group(kv_managed): + return None # Let free_gpu_memory_fraction handle it + + tpb = self.info.tokens_per_block + + # Compute per-sequence BYTE cost — use max_seq_len for all groups. + # During prefill, sequences need full max_seq_len blocks regardless of window. + group_seq_bytes = [] + group_seq_tokens = [] + for _, gm in all_groups: + bpt = sum(h.bytes_per_token for h in gm.values()) + seq_tokens = self.info.max_seq_len + group_seq_bytes.append(bpt * seq_tokens) + group_seq_tokens.append(seq_tokens) + combined_cost_per_seq = sum(group_seq_bytes) + + if total_max_tokens is None: + # SWA group, no total budget — conservative 1-sequence estimate + return self.info.max_seq_len + + # Compute N = max concurrent sequences from total budget. + # Subtract 2 sequences as safety margin for rounding and allocation overhead. + total_bpt = sum(sum(h.bytes_per_token for h in m.values()) for _, m in all_groups) + total_budget_bytes = total_max_tokens * total_bpt + max_seqs = ( + max(1, int(total_budget_bytes / combined_cost_per_seq) - 2) + if combined_cost_per_seq > 0 + else 0 + ) + + # This group needs max_seqs × its per-sequence tokens + group_tokens = max_seqs * group_seq_tokens[group_idx] + + # Cap at max_batch_size × max_seq_len (can't need more) + group_tokens = min(group_tokens, self.info.max_batch_size * self.info.max_seq_len) + + # Floor: at least one block per sequence for warmup feasibility + min_tokens = self.info.max_batch_size * tpb + group_tokens = max(group_tokens, min_tokens) + + return group_tokens + + def _get_group_max_window(self, kv_managed: ResourceHandlerDict) -> int: + """Get the maximum attention window for a group's managed layers.""" + maw = self._kv_cache_config_original.max_attention_window + if maw is None: + return self.info.max_seq_len + max_w = 0 + kv_idx = 0 + for name, handler in 
self._resource_lookup.items(): + if isinstance(handler, KVPagedResourceHandler): + if name in kv_managed and kv_idx < len(maw): + max_w = max(max_w, maw[kv_idx]) + kv_idx += 1 + return max_w if max_w > 0 else self.info.max_seq_len + def _create_kv_cache_manager(self, max_tokens: Optional[int] = None) -> Dict: - """Create KVCacheManager or MambaHybridCacheManager with standard layout. + """Create KVCacheManager(s) with standard layout. For paged resources (KVPagedResourceHandler): - - Uses the first KVPagedResourceHandler's head_dim and dtype as reference - - Compatible resources (matching head_dim and dtype) go into KVCacheManager - - Incompatible resources are allocated locally via handler.allocate() + - Groups layers by (head_dim, dtype, kv_factor, kv_layout) compatibility + - Each group gets its own KVCacheManager pool + - If multiple groups exist, wraps them in MultiPoolKVCacheManager For state resources (SSMResourceHandler, CausalConvResourceHandler, StateResourceHandler): - SSMResourceHandler maps to MambaHybridCacheManager's ssm_states buffer @@ -610,47 +923,101 @@ def _create_kv_cache_manager(self, max_tokens: Optional[int] = None) -> Dict: 1. the final number of tokens is synced (min) across ranks 2. rounding for getting a multiple of tokens_per_block """ - # 1. Identify managed resources - kv_ref, kv_managed = self._identify_managed_kv_resources() + # 1. Identify managed resource groups (one per unique head_dim/dtype/kv_factor/layout) + kv_groups = self._identify_managed_kv_groups() ssm_ref, ssm_managed, ssm_spec, conv_ref, conv_managed, conv_spec = ( self._identify_managed_state_resources() ) + has_state_resources = ssm_managed or conv_managed - # 2. 
Prepare configuration - kv_cache_config = self._prepare_kv_cache_config(max_tokens, kv_managed) - kv_cache_kwargs = self._build_kv_cache_kwargs(kv_ref, kv_managed, kv_cache_config) + # Collect ALL managed KV resources for stats (union of all groups) + kv_managed_all: ResourceHandlerDict = {} + for _, managed in kv_groups: + kv_managed_all.update(managed) + + # 2. Create one KVCacheManager per group + # SWA groups (window < max_seq_len) get fixed max_tokens. + # Full-attention groups get the remaining budget via max_tokens or free_gpu_mem_fraction. + managers: List[KVCacheManager] = [] + primary_idx = 0 # index of the full-attention (largest-window) group + max_window_seen = 0 + + for group_idx, (kv_ref, kv_managed) in enumerate(kv_groups): + # Compute this group's token budget + group_max_tokens = self._compute_group_token_budget( + group_idx, kv_ref, kv_managed, kv_groups, max_tokens + ) + group_config = self._prepare_kv_cache_config(group_max_tokens, kv_managed) + group_kwargs = self._build_kv_cache_kwargs(kv_ref, kv_managed, group_config) + + # NOTE: SWA groups keep max_seq_len from config (NOT window_size). + # During prefill, sequences temporarily use up to max_seq_len blocks. + # max_attention_window evicts old blocks during decode, freeing them + # for new sequences. The SWA savings are throughput (more concurrent + # decode sequences), not peak memory reduction. + + if has_state_resources and group_idx == 0: + group_kwargs["max_batch_size"] = self.info.max_num_state_slots + mgr, _ = self._create_and_assign_state_views( + group_kwargs, + ssm_ref, + ssm_managed, + ssm_spec, + conv_ref, + conv_managed, + conv_spec, + ) + else: + mgr = KVCacheManager(**group_kwargs) - # 3. 
Create cache manager (delegate to state helper if state resources exist) - has_state_resources = ssm_managed or conv_managed - if has_state_resources: - # NOTE: +1 for cuda graph padding - kv_cache_kwargs["max_batch_size"] = self.info.max_num_state_slots - self._kv_cache_manager, _ = self._create_and_assign_state_views( - kv_cache_kwargs, - ssm_ref, - ssm_managed, - ssm_spec, - conv_ref, - conv_managed, - conv_spec, + managers.append(mgr) + is_swa = self._is_swa_group(kv_managed) + ad_logger.info( + f"KV pool {group_idx}: {len(kv_managed)} layers, " + f"head_dim={kv_ref.head_dim}, " + f"max_attention_window={group_config.max_attention_window}, " + f"swa={is_swa}, " + f"max_tokens={group_max_tokens}" ) + + # Track which group has the largest window (= primary for scheduler) + group_window = max(group_config.max_attention_window or [self.info.max_seq_len]) + if group_window > max_window_seen: + max_window_seen = group_window + primary_idx = group_idx + + # 3. Store manager (wrapper if multi-group, direct if single) + if len(managers) == 1: + self._kv_cache_manager = managers[0] else: - # No typed state resources - use pure KVCacheManager - self._kv_cache_manager = KVCacheManager(**kv_cache_kwargs) + self._kv_cache_manager = MultiPoolKVCacheManager(managers, primary_idx=primary_idx) - # 4. Store tuned config - self._kv_cache_config_tuned = kv_cache_config + # 4. Bind WindowPlan to storage pools (if plan exists) + if self._window_plan is not None and len(kv_groups) > 0: + self._window_plan.bind_pools(kv_groups, self._resource_lookup) - # 5. Assign KV views (compute block_offset_multiplier from first view's strides) - block_offset_multiplier = self._assign_kv_cache_views(kv_managed) + # 5. Store tuned config (use the primary group's config) + self._kv_cache_config_tuned = self._prepare_kv_cache_config(max_tokens, kv_managed_all) - # 6. Update cache information (resize cache_loc, set max_seq_info with all max sizes) + # 5. 
Assign KV views per group + block_offset_multiplier = 0 + for group_idx, (_, kv_managed) in enumerate(kv_groups): + mgr = managers[group_idx] + bom = self._assign_kv_cache_views(kv_managed, mgr) + if group_idx == primary_idx: + block_offset_multiplier = bom + + # 6. Update cache information using primary pool + primary_mgr = managers[primary_idx] + num_blocks = getattr( + primary_mgr, "blocks_in_primary_pool", primary_mgr.get_max_resource_count() + ) self.info.update_cache_information( - num_blocks=self._kv_cache_manager.blocks_in_primary_pool, + num_blocks=num_blocks, block_offset_multiplier=block_offset_multiplier, ) - # 7. Allocate remaining unmanaged resources + # 7. Allocate remaining unmanaged resources (non-KV only; all KV layers are managed) self._allocate_unmanaged_resources() # 8. Patch shutdown @@ -782,12 +1149,16 @@ def resize_kv_cache_manager(self, mem_exclude: int = 0) -> None: This implements the two-phase approach: after running a forward pass during estimation to allocate intermediate memory, call this method to recreate the cache manager. - The new manager will compute optimal capacity based on current free GPU memory. + + For multi-pool (dual head_dim): SWA pools have fixed size (max_batch_size × window). + Only the full-attention pool benefits from resize. All pools are shutdown and + recreated so that the full-attention pool gets the remaining memory after SWA pools + and forward-pass intermediates are accounted for. 
""" if not self.needs_resize(): return - # Calculate bytes-per-token for paged (resizable) resources + # Calculate bytes-per-token for ALL paged resources (across all groups) paged_bytes_per_token = sum( h.bytes_per_token for h in self._resource_lookup.values() if h.is_paged ) @@ -806,14 +1177,14 @@ def resize_kv_cache_manager(self, mem_exclude: int = 0) -> None: _, free_mem, *_ = get_mem_info(empty_cache=True) # Compute available memory for paged caches - # Reserve space for non-paged caches and mem_exclude, then apply free_gpu_memory_fraction free_gpu_memory_fraction = self._kv_cache_config_original.free_gpu_memory_fraction mem_for_paged_optimal = ( free_mem - non_paged_bytes_total - mem_exclude ) * free_gpu_memory_fraction max_tokens_optimal = int(mem_for_paged_optimal // paged_bytes_per_token) - # Create new cache manager with optimal capacity + # Recreate all pools. _compute_group_token_budget handles the split: + # SWA groups get fixed tokens, full-attention gets the rest. cache_stats = self._create_kv_cache_manager(max_tokens=max_tokens_optimal) max_tokens_final = cache_stats["max_tokens"] @@ -832,8 +1203,17 @@ def resize_kv_cache_manager(self, mem_exclude: int = 0) -> None: ) @property - def kv_cache_manager(self) -> Optional[KVCacheManager]: - """Return the KVCacheManager managing paged resources, or None if not initialized.""" + def window_plan(self) -> Optional[WindowPlan]: + """Return the logical window plan, or None if not set.""" + return self._window_plan + + @window_plan.setter + def window_plan(self, plan: WindowPlan) -> None: + self._window_plan = plan + + @property + def kv_cache_manager(self) -> Optional[Union[KVCacheManager, MultiPoolKVCacheManager]]: + """Return the KVCacheManager (or multi-pool wrapper), or None if not initialized.""" assert self._kv_cache_manager is not None, "KVCacheManager not initialized." 
return self._kv_cache_manager diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py index 3f4d5122584..9bdd8d931d4 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py @@ -35,7 +35,7 @@ from ...shim.interface import CachedSequenceInterface from ...utils._graph import add_graph_input from ...utils.cuda_mem_tracker import get_mem_info -from ...utils.node_utils import get_op_schema, is_op +from ...utils.node_utils import extract_op_args, get_op_schema, is_op from ..interface import ( BaseTransform, SharedConfig, @@ -188,6 +188,69 @@ def _apply( # Register host-side prepare_metadata function for attention descriptor. self._process_metadata_host(cm) + # Extract per-layer sliding_window values before nodes are erased. + # This bridges the model's attention config to the KV cache manager's + # max_attention_window for window-aware block allocation. + per_layer_sliding_windows = [] + for attn_node in source_attn_nodes: + (sw,) = extract_op_args(attn_node, "sliding_window") + per_layer_sliding_windows.append(sw) + + # Build the effective per-layer window list and create a WindowPlan. + # This is the single source of truth for VSWA: both graph wiring and + # runtime metadata emission derive from it. User-provided + # max_attention_window takes priority; otherwise auto-detect from the + # model's sliding_window annotations. 
+ from ...shim.interface import WindowPlan + + has_any_sliding_window = any( + isinstance(sw, int) and sw > 0 for sw in per_layer_sliding_windows + ) + if cm.kv_cache_config.max_attention_window is not None: + max_attention_window = list(cm.kv_cache_config.max_attention_window) + elif has_any_sliding_window: + max_attention_window = [ + sw if isinstance(sw, int) and sw > 0 else cm.info.max_seq_len + for sw in per_layer_sliding_windows + ] + cm.update_kv_cache_config(max_attention_window=max_attention_window) + else: + max_attention_window = None + + # Build and store the WindowPlan on the interface. + window_plan: Optional[WindowPlan] = None + if max_attention_window is not None: + window_plan = WindowPlan(max_attention_window, cm.info.max_seq_len) + cm.window_plan = window_plan + + # VSWA graph wiring: create per-group cache_loc/cu_num_pages placeholders + # when the WindowPlan has multiple groups. The executor populates + # per-group metadata at runtime (driven by cm.window_plan). + is_vswa = window_plan is not None and window_plan.is_vswa + vswa_group_nodes: dict[int, dict[str, "Node"]] = {} + + if is_vswa: + cm.info.register_window_groups(window_plan.unique_windows) + # Create graph placeholders for groups 1..N-1 + vswa_swappable_bases = { + "cache_loc", + "cu_num_pages", + "last_page_len", + "kv_page_offset", + } + host_suffix = "_host" + std_arg_names = self.attn_descriptor.get_standard_metadata_args() + for group_idx in range(1, window_plan.num_groups): + vswa_group_nodes[group_idx] = {} + for arg_name in std_arg_names: + base = arg_name.removesuffix(host_suffix) + if base in vswa_swappable_bases: + is_host = arg_name.endswith(host_suffix) + group_arg = f"{base}_g{group_idx}{host_suffix if is_host else ''}" + vswa_group_nodes[group_idx][arg_name] = self._add_or_retrieve_input( + gm, cm, group_arg + ) + # replace fused attention node with attention node that has kv cache num_cached_attn_replacements = 0 cache_nodes_by_layer_idx = {} @@ -235,7 +298,20 @@ def 
_apply( attn_descriptor.prepare_node_for_cache_insertion(gm, attn_node) # retrieve constants for attention_op - constants = attn_descriptor.get_constants(attn_node) + constants = attn_descriptor.get_constants(attn_node, cm.kv_cache_config) + + # For VSWA, swap cache_loc/cu_num_pages/kv_page_offset nodes to the + # layer's window group (from WindowPlan) so each layer reads/writes + # through its own windowed page table. + layer_meta_nodes_std = meta_nodes_std + if is_vswa: + group_idx = window_plan.per_layer_group_idx[idx] + if group_idx > 0: + std_arg_names = self.attn_descriptor.get_standard_metadata_args() + layer_meta_nodes_std = list(meta_nodes_std) + for arg_pos, arg_name in enumerate(std_arg_names): + if arg_name in vswa_group_nodes.get(group_idx, {}): + layer_meta_nodes_std[arg_pos] = vswa_group_nodes[group_idx][arg_name] # insert cached attention replacement op self._insert_cached_attn_node( @@ -243,7 +319,7 @@ def _apply( attn_node, attn_descriptor.get_cached_attention_op(), qkv, - meta_nodes_std, + layer_meta_nodes_std, meta_nodes_extra, cache_in_nodes, constants, diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 971ff8c5402..25abe9d0994 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -404,17 +404,11 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int], ) else: if self.is_vswa: - # VSWA case: use C++ implementation for variable window sizes - if model_config is None: - raise ValueError( - "model_config is required for VSWA (Variable Sliding Window Attention)" - ) assert isinstance( kv_cache_config, KvCacheConfig ), "calculate_max_num_blocks_for_vswa only accepts KvCacheConfig" blocks_per_window = self.calculate_max_num_blocks_for_vswa( kv_cache_config=kv_cache_config, - model_config=model_config, extra_cost_memory=0, ) if mapping.world_size > 1: @@ -1012,6 +1006,10 @@ def 
get_cache_indices(self, assert len(result) == 1 return result[0] + def get_num_front_blocks_removed(self, request_id: int) -> int: + """Get the number of front blocks evicted by SWA for a sequence.""" + return self.impl.get_num_front_blocks_removed(request_id) + def unpin_blocks_by_id(self, kv_cache_block_id: int): self.impl.unpin_blocks_by_id(kv_cache_block_id) @@ -1235,11 +1233,11 @@ def adjust_window_sizes_for_vswa( window_size_to_layers: Dict[int, List[int]], max_attention_window_vec: List[int], kv_cache_config: KvCacheConfig, - model_config: ModelConfigCpp, pool_memory_bytes: int, kv_factor: int, dtype: DataType, is_cross_attention: bool = False, + model_config: Optional[ModelConfigCpp] = None, ) -> Tuple[Dict[int, List[int]], List[int]]: assert is_cross_attention is False, 'Cross attention is not supported' @@ -1249,7 +1247,7 @@ def adjust_window_sizes_for_vswa( def calculate_cache_size_per_token(layers: Set[int]) -> int: # Same as BaseKVCacheManager::calculateCacheSizePerTokenForSingleWindowSize total_kv_heads = sum(self.num_kv_heads_per_layer[i] for i in layers) - return total_kv_heads * kv_factor * model_config.head_size + return total_kv_heads * kv_factor * self.head_dim # Calculate the required memory bytes per sequence. required_mem_bytes_per_seq = 0 @@ -1351,7 +1349,7 @@ def calculate_cache_size_per_token(layers: Set[int]) -> int: def calculate_max_num_blocks_for_vswa( self, kv_cache_config: KvCacheConfig, - model_config: ModelConfigCpp, + model_config: Optional[ModelConfigCpp] = None, extra_cost_memory: int = 0) -> dict[int, tuple[int, int]]: """ Currently, this function is added to support *ONLY* VSWA. @@ -1376,8 +1374,6 @@ def calculate_max_num_blocks_for_vswa( # VSWA on Torch backend has not supported the cross attention. 
is_cross_attention = False - # check model config - assert model_config.layer_types is not None, "layer_types have to be set correctly for VSWA" # Construct WorldConfig from self.mapping world_config_cpp = WorldConfig( @@ -1406,7 +1402,6 @@ def calculate_max_num_blocks_for_vswa( window_size_to_layers, max_attention_window_vec = self.adjust_window_sizes_for_vswa( window_size_to_layers=window_size_to_layers, max_attention_window_vec=self.max_attention_window_vec, - model_config=model_config, kv_cache_config=kv_cache_config, pool_memory_bytes=self._primary_pool_memory_bytes, kv_factor=self.kv_factor, @@ -1418,7 +1413,7 @@ def calculate_max_num_blocks_for_vswa( def calculate_cache_size_per_token(layers: Set[int]) -> int: # Same as BaseKVCacheManager::calculateCacheSizePerTokenForSingleWindowSize total_kv_heads = sum(self.num_kv_heads_per_layer[i] for i in layers) - return total_kv_heads * self.kv_factor * model_config.head_size + return total_kv_heads * self.kv_factor * self.head_dim logger.info( f"Primary pool memory bytes: {self._primary_pool_memory_bytes}") diff --git a/tests/unittest/auto_deploy/singlegpu/shim/test_cached_sequence_interface.py b/tests/unittest/auto_deploy/singlegpu/shim/test_cached_sequence_interface.py index 6a5f8d1500c..80b369415ef 100644 --- a/tests/unittest/auto_deploy/singlegpu/shim/test_cached_sequence_interface.py +++ b/tests/unittest/auto_deploy/singlegpu/shim/test_cached_sequence_interface.py @@ -701,6 +701,19 @@ def test_sequence_info_update_cache_information_resizes(): assert seq_info._input_buffer.get_capacity("cache_loc") >= expected_capacity +def test_sequence_info_update_cache_information_preserves_max_blocks(): + """Verify max_blocks_per_seq is always based on max_seq_len (not window).""" + seq_info = SequenceInfo( + max_seq_len=128, + max_batch_size=4, + tokens_per_block=32, + ) + + original_max_blocks = seq_info.max_blocks_per_seq # ceil(128 / 32) = 4 + seq_info.update_cache_information(num_blocks=100) + assert 
seq_info.max_blocks_per_seq == original_max_blocks + + def test_sequence_info_last_page_len_uses_tokens_per_block(): """Verify nest_sequences calculates last_page_len using tokens_per_block.""" seq_info = SequenceInfo( @@ -1063,3 +1076,206 @@ def dummy_host_prepare(batch_info_host: torch.Tensor, cu_num_pages_host: torch.T assert "batch_info_host" in seq_info._active_host_prep_args assert "cu_num_pages_host" in seq_info._active_host_prep_args + + +# ============================================================================= +# Dual-Pool (Multi-Group) KV Cache Tests +# ============================================================================= + + +def test_identify_managed_kv_groups_single_group(): + """Single head_dim produces one group — same behavior as before.""" + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + kv_cache_config=KvCacheConfig( + tokens_per_block=32, max_tokens=1024, free_gpu_memory_fraction=0.0 + ), + ) + interface.add_resource("kv_0", KVPagedResourceHandler(8, 64, dtype=torch.float16)) + interface.add_resource("kv_1", KVPagedResourceHandler(8, 64, dtype=torch.float16)) + + groups = interface._identify_managed_kv_groups() + + assert len(groups) == 1 + ref, managed = groups[0] + assert ref.head_dim == 64 + assert len(managed) == 2 + + +def test_identify_managed_kv_groups_dual_group(): + """Different head_dims produce two groups.""" + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + kv_cache_config=KvCacheConfig( + tokens_per_block=32, max_tokens=1024, free_gpu_memory_fraction=0.0 + ), + ) + # Group A: head_dim=64 + interface.add_resource("kv_0", KVPagedResourceHandler(8, 64, dtype=torch.float16)) + interface.add_resource("kv_1", KVPagedResourceHandler(8, 64, dtype=torch.float16)) + # Group B: head_dim=128 + interface.add_resource("kv_2", KVPagedResourceHandler(4, 128, dtype=torch.float16)) + + groups = interface._identify_managed_kv_groups() + + assert 
len(groups) == 2 + assert groups[0][0].head_dim == 64 + assert len(groups[0][1]) == 2 + assert groups[1][0].head_dim == 128 + assert len(groups[1][1]) == 1 + + +def test_identify_managed_kv_groups_no_unmanaged(): + """All KV layers belong to a group — none are unmanaged.""" + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + kv_cache_config=KvCacheConfig( + tokens_per_block=32, max_tokens=1024, free_gpu_memory_fraction=0.0 + ), + ) + interface.add_resource("kv_0", KVPagedResourceHandler(8, 64, dtype=torch.float16)) + interface.add_resource("kv_1", KVPagedResourceHandler(4, 128, dtype=torch.float16)) + interface.add_resource("kv_2", KVPagedResourceHandler(8, 64, dtype=torch.float16)) + + groups = interface._identify_managed_kv_groups() + + total_managed = sum(len(managed) for _, managed in groups) + assert total_managed == 3 # all layers managed, none left out + + +def test_dual_pool_creates_multi_pool_manager(): + """Two head_dim groups create a MultiPoolKVCacheManager.""" + from tensorrt_llm._torch.auto_deploy.shim.interface import MultiPoolKVCacheManager + + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + kv_cache_config=KvCacheConfig( + tokens_per_block=32, max_tokens=1024, free_gpu_memory_fraction=0.0 + ), + ) + # Two groups with different head_dims + interface.add_resource("kv_0", KVPagedResourceHandler(8, 64, dtype=torch.float16)) + interface.add_resource("kv_1", KVPagedResourceHandler(8, 64, dtype=torch.float16)) + interface.add_resource("kv_2", KVPagedResourceHandler(4, 128, dtype=torch.float16)) + + interface.initialize_resources() + + assert isinstance(interface.kv_cache_manager, MultiPoolKVCacheManager) + assert interface.kv_cache_manager.num_pools == 2 + + +def test_single_group_creates_plain_kv_cache_manager(): + """One head_dim group creates a plain KVCacheManager (no wrapper).""" + from tensorrt_llm._torch.auto_deploy.shim.interface import 
MultiPoolKVCacheManager + + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + kv_cache_config=KvCacheConfig( + tokens_per_block=32, max_tokens=1024, free_gpu_memory_fraction=0.0 + ), + ) + interface.add_resource("kv_0", KVPagedResourceHandler(8, 64, dtype=torch.float16)) + interface.add_resource("kv_1", KVPagedResourceHandler(8, 64, dtype=torch.float16)) + + interface.initialize_resources() + + assert isinstance(interface.kv_cache_manager, KVCacheManager) + assert not isinstance(interface.kv_cache_manager, MultiPoolKVCacheManager) + + +def test_dual_pool_cache_views_correct_shape(): + """Each group's cache views have the correct head_dim.""" + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + kv_cache_config=KvCacheConfig( + tokens_per_block=32, max_tokens=1024, free_gpu_memory_fraction=0.0 + ), + ) + interface.add_resource("kv_0", KVPagedResourceHandler(8, 64, dtype=torch.float16)) + interface.add_resource("kv_1", KVPagedResourceHandler(4, 128, dtype=torch.float16)) + + interface.initialize_resources() + + # Group 0 (head_dim=64): cache shape [..., 8, 32, 64] + kv_0 = interface._caches[list(interface._caches.keys())[0]] + assert kv_0.shape[-1] == 64 + assert kv_0.shape[-3] == 8 # num_kv_heads + + # Group 1 (head_dim=128): cache shape [..., 4, 32, 128] + kv_1 = interface._caches[list(interface._caches.keys())[1]] + assert kv_1.shape[-1] == 128 + assert kv_1.shape[-3] == 4 # num_kv_heads + + +def test_multi_pool_manager_lifecycle(): + """MultiPoolKVCacheManager delegates lifecycle to all pools.""" + from tensorrt_llm._torch.auto_deploy.shim.interface import MultiPoolKVCacheManager + + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + kv_cache_config=KvCacheConfig( + tokens_per_block=32, max_tokens=1024, free_gpu_memory_fraction=0.0 + ), + ) + interface.add_resource("kv_0", KVPagedResourceHandler(8, 64, dtype=torch.float16)) + 
interface.add_resource("kv_1", KVPagedResourceHandler(4, 128, dtype=torch.float16)) + + interface.initialize_resources() + mgr = interface.kv_cache_manager + + assert isinstance(mgr, MultiPoolKVCacheManager) + # Each pool accessible + pool_0 = mgr.get_pool(0) + pool_1 = mgr.get_pool(1) + assert isinstance(pool_0, KVCacheManager) + assert isinstance(pool_1, KVCacheManager) + # impl is from primary pool (largest window) + assert mgr.impl is not None + # get_num_free_blocks returns min across pools + assert mgr.get_num_free_blocks() <= pool_0.get_num_free_blocks() + assert mgr.get_num_free_blocks() <= pool_1.get_num_free_blocks() + + +def test_max_attention_window_scoped_per_group(): + """max_attention_window is filtered to each group's managed layers.""" + interface = CachedSequenceInterface( + max_seq_len=256, + max_batch_size=4, + device="cuda", + kv_cache_config=KvCacheConfig( + tokens_per_block=32, + max_tokens=2048, + free_gpu_memory_fraction=0.0, + max_attention_window=[64, 64, 256], # 3 layers: 2 SWA + 1 full + ), + ) + # Group A: head_dim=64 (layers 0, 1 → windows [64, 64]) + interface.add_resource("kv_0", KVPagedResourceHandler(8, 64, dtype=torch.float16)) + interface.add_resource("kv_1", KVPagedResourceHandler(8, 64, dtype=torch.float16)) + # Group B: head_dim=128 (layer 2 → window [256]) + interface.add_resource("kv_2", KVPagedResourceHandler(4, 128, dtype=torch.float16)) + + interface.initialize_resources() + mgr = interface.kv_cache_manager + + # Group 0 (SWA): max_attention_window_vec should be [64, 64] + pool_0 = mgr.get_pool(0) + assert max(pool_0.max_attention_window_vec) == 64 + + # Group 1 (full): max_attention_window_vec should be [256] + pool_1 = mgr.get_pool(1) + assert max(pool_1.max_attention_window_vec) == 256 diff --git a/tests/unittest/auto_deploy/singlegpu/transformations/library/test_kv_cache.py b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_kv_cache.py index 34a8abd0da7..24cb6f7517d 100644 --- 
a/tests/unittest/auto_deploy/singlegpu/transformations/library/test_kv_cache.py +++ b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_kv_cache.py @@ -550,3 +550,574 @@ def test_insert_cached_attention_passes_kv_cache_config(): for name, handler in cm._resource_lookup.items(): if hasattr(handler, "dtype"): assert handler.dtype == torch.bfloat16 + + +# ============================================================================= +# Sliding Window KV Cache Integration Tests +# ============================================================================= + + +class SlidingWindowGQA(GQAWithSdpaAndEmbedding): + """GQA model with a configurable per-layer sliding_window for testing.""" + + def __init__( + self, + num_attention_heads: int, + hidden_size: int, + num_key_value_heads: int, + vocab_size: int = 1000, + sliding_window: Optional[int] = None, + ): + super().__init__(num_attention_heads, hidden_size, num_key_value_heads, vocab_size) + self._sliding_window = sliding_window + + @torch.no_grad() + def forward( + self, input_ids: torch.Tensor, position_ids: Optional[torch.Tensor] = None + ) -> torch.Tensor: + x = self.embed_tokens(input_ids) + b, s, _ = x.shape + q = self.q_proj(x).view(b, s, self.num_heads, self.head_dim) + k = self.k_proj(x).view(b, s, self.num_kv_heads, self.head_dim) + v = self.v_proj(x).view(b, s, self.num_kv_heads, self.head_dim) + attn_output = torch.ops.auto_deploy.torch_attention( + q, + k, + v, + attn_mask=None, + dropout_p=0.0, + is_causal=True, + scale=None, + sinks=None, + sliding_window=self._sliding_window, + logit_cap=None, + layout="bsnd", + ) + return self.o_proj(attn_output.reshape(b, s, -1)) + + +class VSWAModel(nn.Module): + """Model with two attention layers: one sliding-window, one full-attention (VSWA).""" + + def __init__( + self, + num_attention_heads: int, + hidden_size: int, + num_key_value_heads: int, + vocab_size: int = 1000, + sliding_window: int = 32, + ): + super().__init__() + self.num_heads = 
num_attention_heads + self.num_kv_heads = num_key_value_heads + self.head_dim = hidden_size // num_attention_heads + self.embed_tokens = nn.Embedding(vocab_size, hidden_size) + self.q_proj_0 = nn.Linear(hidden_size, hidden_size, bias=False) + self.k_proj_0 = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=False) + self.v_proj_0 = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=False) + self.o_proj_0 = nn.Linear(hidden_size, hidden_size, bias=False) + self.q_proj_1 = nn.Linear(hidden_size, hidden_size, bias=False) + self.k_proj_1 = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=False) + self.v_proj_1 = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=False) + self.o_proj_1 = nn.Linear(hidden_size, hidden_size, bias=False) + self._sliding_window = sliding_window + + @torch.no_grad() + def forward( + self, input_ids: torch.Tensor, position_ids: Optional[torch.Tensor] = None + ) -> torch.Tensor: + x = self.embed_tokens(input_ids) + b, s, _ = x.shape + + # Layer 0: sliding window attention + q0 = self.q_proj_0(x).view(b, s, self.num_heads, self.head_dim) + k0 = self.k_proj_0(x).view(b, s, self.num_kv_heads, self.head_dim) + v0 = self.v_proj_0(x).view(b, s, self.num_kv_heads, self.head_dim) + a0 = torch.ops.auto_deploy.torch_attention( + q0, + k0, + v0, + attn_mask=None, + dropout_p=0.0, + is_causal=True, + scale=None, + sinks=None, + sliding_window=self._sliding_window, + logit_cap=None, + layout="bsnd", + ) + x = x + self.o_proj_0(a0.reshape(b, s, -1)) + + # Layer 1: full attention (no sliding window) + q1 = self.q_proj_1(x).view(b, s, self.num_heads, self.head_dim) + k1 = self.k_proj_1(x).view(b, s, self.num_kv_heads, self.head_dim) + v1 = self.v_proj_1(x).view(b, s, self.num_kv_heads, self.head_dim) + a1 = torch.ops.auto_deploy.torch_attention( + q1, + k1, + v1, + attn_mask=None, + dropout_p=0.0, + is_causal=True, + scale=None, + sinks=None, + sliding_window=None, + logit_cap=None, + layout="bsnd", + 
) + return x + self.o_proj_1(a1.reshape(b, s, -1)) + + +def _build_optimizer_with_backend(model, backend="triton_paged"): + """Helper to create an InferenceOptimizer for testing.""" + return InferenceOptimizer( + DummyFactory(model, cache_config_updates={}), + { + "build_model": { + "stage": "factory", + "run_per_gm": False, + "device": "cuda", + "run_graph_cleanup": False, + "requires_clean_graph": False, + }, + "export_to_gm": { + "stage": "export", + "strict": False, + "run_per_gm": False, + "clone_state_dict": True, + "run_graph_cleanup": False, + "requires_clean_graph": False, + }, + "cleanup_input_constraints": { + "stage": "post_export", + }, + "insert_cached_attention": { + "stage": "cache_init", + "backend": backend, + }, + }, + ) + + +@torch.inference_mode() +def test_insert_cached_attention_extracts_sliding_window(): + """Verify insert_cached_attention sets max_attention_window from graph sliding_window.""" + sliding_window = 32 + max_seq_len = 128 + batch_size = 4 + + kv_cache_config = KvCacheConfig( + tokens_per_block=max_seq_len, + max_tokens=batch_size * max_seq_len, + free_gpu_memory_fraction=0.0, + ) + cm = CachedSequenceInterface( + max_seq_len=max_seq_len, + max_batch_size=batch_size, + device="cuda", + kv_cache_config=kv_cache_config, + ) + + assert cm.kv_cache_config.max_attention_window is None + + model = SlidingWindowGQA( + num_attention_heads=8, + hidden_size=512, + num_key_value_heads=8, + sliding_window=sliding_window, + ).to(dtype=torch.float16, device="cuda") + + optimizer = _build_optimizer_with_backend(model) + optimizer(cm) + + assert cm.kv_cache_config.max_attention_window is not None + assert cm.kv_cache_config.max_attention_window == [sliding_window] + + +@torch.inference_mode() +def test_insert_cached_attention_no_sliding_window_leaves_config_unchanged(): + """Verify insert_cached_attention does not set max_attention_window for non-SWA models.""" + max_seq_len = 128 + batch_size = 4 + + kv_cache_config = KvCacheConfig( + 
tokens_per_block=max_seq_len, + max_tokens=batch_size * max_seq_len, + free_gpu_memory_fraction=0.0, + ) + cm = CachedSequenceInterface( + max_seq_len=max_seq_len, + max_batch_size=batch_size, + device="cuda", + kv_cache_config=kv_cache_config, + ) + + model = GQAWithSdpaAndEmbedding( + num_attention_heads=8, + hidden_size=512, + num_key_value_heads=8, + ).to(dtype=torch.float16, device="cuda") + + optimizer = _build_optimizer_with_backend(model) + optimizer(cm) + + assert cm.kv_cache_config.max_attention_window is None + + +@torch.inference_mode() +def test_insert_cached_attention_respects_user_override(): + """Verify insert_cached_attention does not overwrite user-set max_attention_window.""" + max_seq_len = 128 + batch_size = 4 + user_window = [64] + + kv_cache_config = KvCacheConfig( + tokens_per_block=max_seq_len, + max_tokens=batch_size * max_seq_len, + free_gpu_memory_fraction=0.0, + max_attention_window=user_window, + ) + cm = CachedSequenceInterface( + max_seq_len=max_seq_len, + max_batch_size=batch_size, + device="cuda", + kv_cache_config=kv_cache_config, + ) + + model = SlidingWindowGQA( + num_attention_heads=8, + hidden_size=512, + num_key_value_heads=8, + sliding_window=32, + ).to(dtype=torch.float16, device="cuda") + + optimizer = _build_optimizer_with_backend(model) + optimizer(cm) + + # User-provided value must be preserved + assert cm.kv_cache_config.max_attention_window == user_window + + +@torch.inference_mode() +def test_insert_cached_attention_vswa_preserves_per_layer_windows(): + """Verify VSWA model preserves per-layer window sizes for proportional allocation.""" + sliding_window = 32 + max_seq_len = 128 + batch_size = 4 + + kv_cache_config = KvCacheConfig( + tokens_per_block=32, + max_tokens=batch_size * max_seq_len, + free_gpu_memory_fraction=0.0, + ) + cm = CachedSequenceInterface( + max_seq_len=max_seq_len, + max_batch_size=batch_size, + device="cuda", + kv_cache_config=kv_cache_config, + ) + + model = VSWAModel( + num_attention_heads=8, 
+ hidden_size=512, + num_key_value_heads=8, + sliding_window=sliding_window, + ).to(dtype=torch.float16, device="cuda") + + optimizer = _build_optimizer_with_backend(model) + optimizer(cm) + + # VSWA preserves per-layer windows: [32, 128] (not collapsed) + assert cm.kv_cache_config.max_attention_window is not None + assert len(cm.kv_cache_config.max_attention_window) == 2 + assert cm.kv_cache_config.max_attention_window == [sliding_window, max_seq_len] + + # Window groups should be registered on SequenceInfo + assert cm.info.num_window_groups == 2 + assert cm.info.window_groups == [sliding_window, max_seq_len] + assert cm.info.window_group_map == {sliding_window: 0, max_seq_len: 1} + + +@torch.inference_mode() +def test_kv_cache_manager_initialized_with_sliding_window(): + """Verify KVCacheManager receives max_attention_window_vec from SWA model. + + Runs the full insert_cached_attention + initialize_cache pipeline. + """ + import math + + sliding_window = 32 + max_seq_len = 128 + batch_size = 4 + tokens_per_block = 16 + + kv_cache_config = KvCacheConfig( + tokens_per_block=tokens_per_block, + max_tokens=batch_size * max_seq_len, + free_gpu_memory_fraction=0.0, + ) + cm = CachedSequenceInterface( + max_seq_len=max_seq_len, + max_batch_size=batch_size, + device="cuda", + kv_cache_config=kv_cache_config, + ) + + model = SlidingWindowGQA( + num_attention_heads=8, + hidden_size=512, + num_key_value_heads=8, + sliding_window=sliding_window, + ).to(dtype=torch.float16, device="cuda") + + # Run insert_cached_attention + initialize_cache + optimizer = InferenceOptimizer( + DummyFactory(model, cache_config_updates={}), + { + "build_model": { + "stage": "factory", + "run_per_gm": False, + "device": "cuda", + "run_graph_cleanup": False, + "requires_clean_graph": False, + }, + "export_to_gm": { + "stage": "export", + "strict": False, + "run_per_gm": False, + "clone_state_dict": True, + "run_graph_cleanup": False, + "requires_clean_graph": False, + }, + 
"cleanup_input_constraints": { + "stage": "post_export", + }, + "insert_cached_attention": { + "stage": "cache_init", + "backend": "triton_paged", + }, + "initialize_cache": { + "stage": "cache_init", + "run_per_gm": False, + }, + }, + ) + optimizer(cm) + + # KVCacheManager should exist and carry the window vector + mgr = cm.kv_cache_manager + assert mgr.max_attention_window_vec == [sliding_window] + + # max_blocks_per_seq should be tightened to the sliding window + expected_max_blocks = math.ceil(sliding_window / tokens_per_block) + assert cm.info.max_blocks_per_seq == expected_max_blocks + + +@torch.inference_mode() +def test_kv_cache_manager_vswa_proportional_allocation(): + """Verify VSWA models get proportional pool allocation in KVCacheManager. + + KVCacheManager detects VSWA from the per-layer max_attention_window vector + and allocates separate block pools per window size. + """ + import math + + sliding_window = 32 + max_seq_len = 128 + batch_size = 4 + tokens_per_block = 16 + + kv_cache_config = KvCacheConfig( + tokens_per_block=tokens_per_block, + max_tokens=batch_size * max_seq_len, + free_gpu_memory_fraction=0.0, + ) + cm = CachedSequenceInterface( + max_seq_len=max_seq_len, + max_batch_size=batch_size, + device="cuda", + kv_cache_config=kv_cache_config, + ) + + model = VSWAModel( + num_attention_heads=8, + hidden_size=512, + num_key_value_heads=8, + sliding_window=sliding_window, + ).to(dtype=torch.float16, device="cuda") + + optimizer = InferenceOptimizer( + DummyFactory(model, cache_config_updates={}), + { + "build_model": { + "stage": "factory", + "run_per_gm": False, + "device": "cuda", + "run_graph_cleanup": False, + "requires_clean_graph": False, + }, + "export_to_gm": { + "stage": "export", + "strict": False, + "run_per_gm": False, + "clone_state_dict": True, + "run_graph_cleanup": False, + "requires_clean_graph": False, + }, + "cleanup_input_constraints": { + "stage": "post_export", + }, + "insert_cached_attention": { + "stage": "cache_init", 
+ "backend": "triton_paged", + }, + "initialize_cache": { + "stage": "cache_init", + "run_per_gm": False, + }, + }, + ) + optimizer(cm) + + # KVCacheManager should detect VSWA with per-layer windows + mgr = cm.kv_cache_manager + assert mgr.is_vswa is True + assert mgr.max_attention_window_vec == [sliding_window, max_seq_len] + + # max_blocks_per_seq reflects the larger window (full-attention layer) + expected_max_blocks = math.ceil(max_seq_len / tokens_per_block) + assert cm.info.max_blocks_per_seq == expected_max_blocks + + # Per-group cache tensors should be registered on SequenceInfo + assert cm.info.num_window_groups == 2 + assert "cache_loc_g1" in cm.info.available_args + assert "cu_num_pages_g1" in cm.info.available_args + + +# ============================================================================= +# VSWA SequenceInfo and Graph Wiring Tests +# ============================================================================= + + +@torch.inference_mode() +def test_sequence_info_register_window_groups(): + """Verify register_window_groups creates per-group tensors in InputBuffer.""" + from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SequenceInfo + + max_seq_len = 128 + batch_size = 4 + tokens_per_block = 16 + + seq_info = SequenceInfo( + max_seq_len=max_seq_len, + max_batch_size=batch_size, + tokens_per_block=tokens_per_block, + ) + + # Before registration: no window groups + assert seq_info.num_window_groups == 0 + assert seq_info.window_groups == [] + + # Register two groups: [32, 128] + seq_info.register_window_groups([32, 128]) + + assert seq_info.num_window_groups == 2 + assert seq_info.window_groups == [32, 128] + assert seq_info.window_group_map == {32: 0, 128: 1} + + # Group 0 reuses existing tensors (cache_loc, cu_num_pages, etc.) 
+ # Group 1 gets new tensors + assert "cache_loc_g1" in seq_info.available_args + assert "cu_num_pages_g1" in seq_info.available_args + assert "cu_num_pages_g1_host" in seq_info.available_args + assert "last_page_len_g1" in seq_info.available_args + assert "last_page_len_g1_host" in seq_info.available_args + assert "extra_page_per_seq_g1" in seq_info.available_args + + # Group 0 names should NOT appear with _g0 suffix + assert "cache_loc_g0" not in seq_info.available_args + + +@torch.inference_mode() +def test_vswa_graph_has_per_group_placeholders(): + """Verify VSWA model graph contains per-group cache_loc/cu_num_pages placeholders.""" + sliding_window = 32 + max_seq_len = 128 + batch_size = 4 + + kv_cache_config = KvCacheConfig( + tokens_per_block=32, + max_tokens=batch_size * max_seq_len, + free_gpu_memory_fraction=0.0, + ) + cm = CachedSequenceInterface( + max_seq_len=max_seq_len, + max_batch_size=batch_size, + device="cuda", + kv_cache_config=kv_cache_config, + ) + + model = VSWAModel( + num_attention_heads=8, + hidden_size=512, + num_key_value_heads=8, + sliding_window=sliding_window, + ).to(dtype=torch.float16, device="cuda") + + optimizer = _build_optimizer_with_backend(model) + gm = optimizer(cm) + + # Check that per-group graph placeholders exist + placeholder_names = [n.target for n in gm.graph.nodes if n.op == "placeholder"] + assert "cache_loc" in placeholder_names + assert "cu_num_pages" in placeholder_names + # Group 1 should have its own placeholders + assert "cache_loc_g1" in placeholder_names + assert "cu_num_pages_g1" in placeholder_names + + +# ============================================================================= +# Phase 3: sink_token_length wiring through get_constants +# ============================================================================= + + +def test_trtllm_get_constants_returns_sink_token_length(): + """Verify TrtllmAttention.get_constants reads sink_token_length from cache_config.""" + from 
tensorrt_llm._torch.auto_deploy.custom_ops.attention.trtllm_attention import ( + TrtllmAttention, + ) + from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm + from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op + + # Create a simple model and export to get a real source attention node + model = ( + SlidingWindowGQA( + num_attention_heads=8, + hidden_size=512, + num_key_value_heads=8, + sliding_window=32, + ) + .eval() + .to(dtype=torch.float16, device="cpu") + ) + + input_ids = torch.randint(0, 1000, (1, 4)) + gm = torch_export_to_gm(model, (input_ids,)) + + source_op = TrtllmAttention.get_source_attention_op() + source_nodes = [n for n in gm.graph.nodes if is_op(n, source_op)] + assert len(source_nodes) == 1 + + # Without cache_config: sink_token_length defaults to 0 + constants_no_config = TrtllmAttention.get_constants(source_nodes[0]) + assert constants_no_config[-1] == 0 + + # With cache_config containing sink_token_length=4 + kv_cache_config = KvCacheConfig(sink_token_length=4) + constants_with_config = TrtllmAttention.get_constants(source_nodes[0], kv_cache_config) + assert constants_with_config[-1] == 4 + + # With cache_config with sink_token_length=None: defaults to 0 + kv_cache_config_none = KvCacheConfig(sink_token_length=None) + constants_none = TrtllmAttention.get_constants(source_nodes[0], kv_cache_config_none) + assert constants_none[-1] == 0