diff --git a/docs/source/models/supported-models.md b/docs/source/models/supported-models.md
index 0492533a798..adb2f53ef72 100644
--- a/docs/source/models/supported-models.md
+++ b/docs/source/models/supported-models.md
@@ -13,6 +13,8 @@ The following is a table of supported models for the PyTorch backend:
 | `Exaone4ForCausalLM` | EXAONE 4.0 | `LGAI-EXAONE/EXAONE-4.0-32B` |
 | `ExaoneMoEForCausalLM` | K-EXAONE | `LGAI-EXAONE/K-EXAONE-236B-A23B` |
 | `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it` |
+| `Gemma3nForConditionalGeneration` [^8]| Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it` |
+| `Gemma4ForConditionalGeneration` [^7]| Gemma 4 | `google/gemma-4-26B-A4B-it` |
 | `Glm4MoeForCausalLM` | GLM-4.5, GLM-4.6, GLM-4.7 | `THUDM/GLM-4-100B-A10B` |
 | `Glm4MoeLiteForCausalLM` [^6] | GLM-4.7-Flash | `zai-org/GLM-4.7-Flash` |
 | `GlmMoeDsaForCausalLM` | GLM-5 | `zai-org/GLM-5` |
@@ -60,6 +62,8 @@ Note: Support for other models may vary. Features marked "N/A" are not applicabl
 [^4]: Overlap scheduler isn't supported when using EAGLE-3(Two Model Engine) for GPT-OSS.
 [^5]: Supported via the [AutoDeploy](../features/auto_deploy/auto-deploy.md) backend. See [AD config](../../../examples/auto_deploy/model_registry/configs/qwen3.5_moe_400b.yaml).
 [^6]: Supported via the [AutoDeploy](../features/auto_deploy/auto-deploy.md) backend. See [AD config](../../../examples/auto_deploy/model_registry/configs/glm-4.7-flash.yaml).
+[^7]: Text-only support via the [AutoDeploy](../features/auto_deploy/auto-deploy.md) backend. See [AD config](../../../examples/auto_deploy/model_registry/configs/gemma4_moe.yaml).
+[^8]: Text-only support via the [AutoDeploy](../features/auto_deploy/auto-deploy.md) backend. See [AD config](../../../examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml).
# Multimodal Feature Support Matrix (PyTorch Backend) diff --git a/examples/auto_deploy/cookbooks/gemma_4_trtllm_cookbook.ipynb b/examples/auto_deploy/cookbooks/gemma_4_trtllm_cookbook.ipynb new file mode 100644 index 00000000000..4286cf65a1b --- /dev/null +++ b/examples/auto_deploy/cookbooks/gemma_4_trtllm_cookbook.ipynb @@ -0,0 +1,299 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Deploying Gemma 4 MoE with TensorRT-LLM (AutoDeploy)\n", + "\n", + "This notebook walks you through serving **Gemma 4** (26B total, 4B activated MoE) with TensorRT-LLM using the **AutoDeploy** backend—same pattern as the Mistral Small 4 and GLM-4.7-Flash cookbooks in this folder.\n", + "\n", + "[TensorRT-LLM](https://nvidia.github.io/TensorRT-LLM/) is NVIDIA's open-source library for accelerating LLM inference on NVIDIA GPUs. AutoDeploy uses Hugging Face `transformers` modeling code and TensorRT-LLM graph transforms. See the [AutoDeploy guide](https://nvidia.github.io/TensorRT-LLM/torch/auto_deploy/auto-deploy.html).\n", + "\n", + "**Model resources:**\n", + "- [Gemma 4 collection (Hugging Face)](https://huggingface.co/collections/google/gemma-4)\n", + "- Instruction-tuned MoE: [`google/gemma-4-26B-A4B-it`](https://huggingface.co/google/gemma-4-26B-A4B-it)\n", + "- Base MoE (no chat template): [`google/gemma-4-26B-A4B`](https://huggingface.co/google/gemma-4-26B-A4B)\n", + "\n", + "**Bundled AutoDeploy YAML (this branch):**\n", + "- **Instruction:** `examples/auto_deploy/model_registry/configs/gemma4_moe.yaml` — text-only export path; `attn_backend: triton_paged` (head_dim 512 / paged KV, CUDA-graph friendly).\n", + "- **Base:** `examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml` — same stack for the base checkpoint.\n", + "\n", + "`trtllm-serve` takes **one** YAML path via `--extra_llm_api_options` (or `--config`). 
The bundled MoE YAMLs omit `world_size`; add it (or copy the YAML and edit) so it matches your GPU count when you use tensor parallel or multi-GPU loading.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prerequisites and environment\n", + "\n", + "Run TensorRT-LLM in a GPU container, for example:\n", + "\n", + "```shell\n", + "docker run --rm -it --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --gpus=all -p 8000:8000 nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc1\n", + "```\n", + "\n", + "Use a TensorRT-LLM checkout that includes Gemma 4 AutoDeploy support (model card, tokenizer, and any required bridges should match your branch).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# If pip is not available\n", + "!python -m ensurepip --default-pip" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install torch openai" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Verify GPU\n", + "\n", + "Confirm CUDA and visible devices before starting the server.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Python: 3.12.3 (main, Jan 22 2026, 20:57:42) [GCC 13.3.0]\n", + "CUDA available: True\n", + "Num GPUs: 8\n", + "GPU[0]: NVIDIA H100 80GB HBM3\n", + "GPU[1]: NVIDIA H100 80GB HBM3\n", + "GPU[2]: NVIDIA H100 80GB HBM3\n", + "GPU[3]: NVIDIA H100 80GB HBM3\n", + "GPU[4]: NVIDIA H100 80GB HBM3\n", + "GPU[5]: NVIDIA H100 80GB HBM3\n", + "GPU[6]: NVIDIA H100 80GB HBM3\n", + "GPU[7]: NVIDIA H100 80GB HBM3\n" + ] + } + ], + "source": [ + "import sys\n", + "\n", + "import torch\n", + "\n", + "print(f\"Python: {sys.version}\")\n", + "print(f\"CUDA available: {torch.cuda.is_available()}\")\n", + "print(f\"Num GPUs: {torch.cuda.device_count()}\")\n", + "\n", + "if 
torch.cuda.is_available():\n", + " for i in range(torch.cuda.device_count()):\n", + " print(f\"GPU[{i}]: {torch.cuda.get_device_name(i)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## OpenAI-compatible server\n", + "\n", + "From a shell **inside** the container, at the TensorRT-LLM repo root, start `trtllm-serve` with AutoDeploy.\n", + "\n", + "Use the Gemma 4 MoE YAML under `examples/auto_deploy/model_registry/configs/` (see the introduction). Add `world_size` to that YAML if your serve command needs an explicit tensor-parallel device count.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load the model\n", + "\n", + "**Instruction-tuned (`gemma4_moe.yaml`):**\n", + "\n", + "```shell\n", + "trtllm-serve \"google/gemma-4-26B-A4B-it\" \\\n", + " --host 0.0.0.0 \\\n", + " --port 8000 \\\n", + " --backend _autodeploy \\\n", + " --trust_remote_code \\\n", + " --extra_llm_api_options examples/auto_deploy/model_registry/configs/gemma4_moe.yaml\n", + "```\n", + "\n", + "**Base checkpoint:** use model id `google/gemma-4-26B-A4B` and `examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml` instead.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When the server finishes loading weights, it is ready for requests.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Use the API\n", + "\n", + "Send chat completions with the OpenAI Python client pointed at the local server.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from openai import OpenAI\n", + "\n", + "BASE_URL = \"http://0.0.0.0:8000/v1\"\n", + "API_KEY = \"null\"\n", + "MODEL_ID = \"google/gemma-4-26B-A4B-it\"\n", + "\n", + "client = OpenAI(base_url=BASE_URL, api_key=API_KEY)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": 
"stream", + "text": [ + "Chat completion example\n", + "==================================================\n", + "Response:\n", + "To find 15% of 85, you can use a few different methods. Here are two easy ways to think about it:\n", + "\n", + "### Method 1: The Breakdown Method (Mental Math)\n", + "This is often the easiest way to calculate percentages in your head by breaking the percentage into manageable parts (10% and 5%).\n", + "\n", + "1. **Find 10% of 85:**\n", + " To find 10%, simply move the decimal point one place to the left.\n", + " $85 \\div 10 = 8.5$\n", + "2. **Find 5% of 85:**\n", + " Since 5% is half of 10%, just divide your previous answer by 2.\n", + " $8.5 \\div 2 = 4.25$\n", + "3. **Add them together:**\n", + " $10\\% + 5\\% = 15\\%$\n", + " $8.5 + 4.25 = 12.75$\n", + "\n", + "***\n", + "\n", + "### Method 2: The Multiplication Method (Calculator/Paper)\n", + "To find a percentage, you can convert the percentage into a decimal and multiply it by the total number.\n", + "\n", + "1. **Convert 15% to a decimal:**\n", + " $15\\% = \\frac{15}{100} = 0.15$\n", + "2. **Multiply by 85:**\n", + " $85 \\times 0.15 = 12.75$\n", + "\n", + "**Final Answer:**\n", + "15% of 85 is **12.75**.\n" + ] + } + ], + "source": [ + "# Basic chat completion\n", + "print(\"Chat completion example\")\n", + "print(\"=\" * 50)\n", + "\n", + "response = client.chat.completions.create(\n", + " model=MODEL_ID,\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n", + " {\"role\": \"user\", \"content\": \"What is 15% of 85? 
Show your reasoning.\"},\n", + " ],\n", + " temperature=1.0,\n", + " top_p=0.95,\n", + " max_tokens=512,\n", + ")\n", + "\n", + "print(\"Response:\")\n", + "print(response.choices[0].message.content)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Streaming response:\n", + "==================================================\n", + "The first 5 prime numbers are **2, 3, 5, 7, and 11**." + ] + } + ], + "source": [ + "# Streaming chat completion\n", + "print(\"Streaming response:\")\n", + "print(\"=\" * 50)\n", + "\n", + "stream = client.chat.completions.create(\n", + " model=MODEL_ID,\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n", + " {\"role\": \"user\", \"content\": \"What are the first 5 prime numbers?\"},\n", + " ],\n", + " temperature=0.7,\n", + " max_tokens=1024,\n", + " stream=True,\n", + ")\n", + "\n", + "for chunk in stream:\n", + " if chunk.choices[0].delta.content:\n", + " print(chunk.choices[0].delta.content, end=\"\", flush=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Additional resources\n", + "\n", + "- [Gemma 4 collection (Hugging Face)](https://huggingface.co/collections/google/gemma-4)\n", + "- [TensorRT-LLM documentation](https://nvidia.github.io/TensorRT-LLM/)\n", + "- [AutoDeploy guide](https://nvidia.github.io/TensorRT-LLM/torch/auto_deploy/auto-deploy.html)\n", + "- [`gemma4_moe.yaml`](../model_registry/configs/gemma4_moe.yaml), [`gemma4_moe_base.yaml`](../model_registry/configs/gemma4_moe_base.yaml), [`models.yaml`](../model_registry/models.yaml)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + 
"nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml b/examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml new file mode 100644 index 00000000000..2e5e1f8d5cb --- /dev/null +++ b/examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml @@ -0,0 +1,13 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +runtime: trtllm +compile_backend: torch-cudagraph +model_factory: AutoModelForCausalLM +max_seq_len: 512 +max_batch_size: 8 +world_size: 1 + +# Gemma 3n uses shared-KV decode semantics in the tail layers. FlashInfer +# supports the read-only shared-KV cache path and alternating sliding windows. +attn_backend: flashinfer diff --git a/examples/auto_deploy/model_registry/configs/gemma4_moe.yaml b/examples/auto_deploy/model_registry/configs/gemma4_moe.yaml new file mode 100644 index 00000000000..d31ba340bc9 --- /dev/null +++ b/examples/auto_deploy/model_registry/configs/gemma4_moe.yaml @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# Gemma 4 MoE (26B total, 4B activated) — text-only AD export path. +# Uses triton paged attention backend: supports head_dim=512 (global_head_dim), +# paged KV cache, CUDA-graph-compatible, FlashDecoding for decode. 
+model_factory: Gemma4ForConditionalGeneration
+tokenizer: google/gemma-4-26B-A4B-it
+attn_backend: triton_paged
+compile_backend: torch-cudagraph
+cuda_graph_config:
+  batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
+max_num_tokens: 8192
+max_batch_size: 512
+max_seq_len: 8192
+enable_chunked_prefill: true
+kv_cache_config:
+  enable_block_reuse: false
+  free_gpu_memory_fraction: 0.8
+transforms:
+  compile_model:
+    piecewise_enabled: true
+  mlir_elementwise_fusion:
+    enabled: true
+  gather_logits_before_lm_head:
+    enabled: true
+  fuse_gemms:
+    enabled: true
diff --git a/examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml b/examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml
new file mode 100644
index 00000000000..9e469676559
--- /dev/null
+++ b/examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml
@@ -0,0 +1,22 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+# Gemma 4 MoE base (26B total, 4B activated) — text-only AD export path.
+# Uses triton paged attention backend: supports head_dim=512 (global_head_dim),
+# paged KV cache, CUDA-graph-compatible, FlashDecoding for decode.
+model_factory: Gemma4ForConditionalGeneration +tokenizer: google/gemma-4-26B-A4B +attn_backend: triton_paged +compile_backend: torch-cudagraph +cuda_graph_config: + batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] +max_num_tokens: 8192 +max_batch_size: 512 +max_seq_len: 8192 +enable_chunked_prefill: true +kv_cache_config: + enable_block_reuse: false + free_gpu_memory_fraction: 0.8 +transforms: + compile_model: + piecewise_enabled: true diff --git a/examples/auto_deploy/model_registry/models.yaml b/examples/auto_deploy/model_registry/models.yaml index 6e65ebee22d..ef614e39b8b 100644 --- a/examples/auto_deploy/model_registry/models.yaml +++ b/examples/auto_deploy/model_registry/models.yaml @@ -308,6 +308,11 @@ models: yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml', 'multimodal.yaml'] - name: google/gemma-3n-E4B-it yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml', 'multimodal.yaml'] +# --- Gemma 4 (2026) - MoE with K=V attention --- +- name: google/gemma-4-26B-A4B + yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml', 'gemma4_moe_base.yaml'] +- name: google/gemma-4-26B-A4B-it + yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml', 'gemma4_moe.yaml'] # --- JetBrains Mellum (Apr 2025) - code specialist --- - name: JetBrains/Mellum-4b-sft-all yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml'] diff --git a/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py b/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py index 85c87afe278..28ff9620978 100644 --- a/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py +++ b/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py @@ -517,7 +517,14 @@ def forward(self, *args, num_tokens: Optional[int] = None, **kwargs) -> Any: """Forward pass: static segments replay graphs, dynamic segments run eagerly.""" if self.split_gm is not None: ADPiecewiseRunner.set_current_num_tokens(num_tokens) - return self.split_gm(*args, **kwargs) + 
result = self.split_gm(*args, **kwargs) + # Ensure all CUDA graph internal streams have completed before the + # caller (DualModeCapturedGraph) proceeds. Some captured kernels + # (e.g. trtllm fused_moe) may use non-default streams inside the + # graph; without this sync the next eager op can race with those + # internal streams, causing sporadic illegal-memory-access crashes. + torch.cuda.synchronize() + return result return self.original_model(*args, **kwargs) @@ -558,7 +565,7 @@ def __init__( def _is_decode_only(self, **kwargs) -> bool: """Check if the current batch is decode-only using batch_info_host. - batch_info_host = [num_prefill, num_prefill_tokens, num_decode] + batch_info_host is the serialized BatchInfo tensor. Decode-only means num_prefill == 0. """ batch_info = kwargs.get(self.batch_info_kwarg_name) diff --git a/tensorrt_llm/_torch/auto_deploy/compile/piecewise_utils.py b/tensorrt_llm/_torch/auto_deploy/compile/piecewise_utils.py index e4784a38165..e5d87e78e2c 100644 --- a/tensorrt_llm/_torch/auto_deploy/compile/piecewise_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/compile/piecewise_utils.py @@ -27,6 +27,7 @@ _CACHED_ATTENTION_OPS = [ "auto_deploy::flashinfer_attention_mha_with_cache", "auto_deploy::triton_attention_flattened_mha_with_cache", + "auto_deploy::triton_paged_mha_with_cache", "auto_deploy::torch_cached_attention_with_cache", "auto_deploy::trtllm_attention_mha_with_cache", # MLA attention variants @@ -57,6 +58,7 @@ _METADATA_PREP_OPS = [ "auto_deploy::flashinfer_attention_prepare_metadata", "auto_deploy::flashinfer_mla_prepare_metadata", + "auto_deploy::triton_paged_prepare_metadata", "auto_deploy::mamba_ssm_prepare_metadata", ] diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index 3e6f80b2f52..168646926ad 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -45,6 +45,9 @@ transforms: 
match_attention_layout: stage: pattern_matcher attn_layout: bsnd + inject_custom_attention_mask: + stage: pattern_matcher + backend: torch_attention match_rope_pattern: stage: pattern_matcher match_rope_layout: diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py index c935116f2dd..89f40a99ea2 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py @@ -55,6 +55,7 @@ class PlanParams: sm_scale: Optional[float] = None causal: bool = True + window_left: int = -1 def __hash__(self): """Convert all fields to a string representation and concatenate them.""" @@ -153,6 +154,7 @@ def plan_generate_only( q_data_type=plan_params.q_dtype, kv_data_type=plan_params.kv_dtype, sm_scale=plan_params.sm_scale, + window_left=plan_params.window_left, ) def plan_prefill( @@ -186,6 +188,7 @@ def plan_prefill( q_data_type=plan_params.q_dtype, kv_data_type=plan_params.kv_dtype, sm_scale=plan_params.sm_scale, + window_left=plan_params.window_left, seq_lens=kv_lens_arr_host, ) self.plan_params_prefill = plan_params @@ -218,6 +221,7 @@ def _plan_decode( q_data_type=plan_params.q_dtype, kv_data_type=plan_params.kv_dtype, sm_scale=plan_params.sm_scale, + window_left=plan_params.window_left, ) # we want to plan during warm-up of cuda graph capture to ensure we have the plan cached @@ -251,6 +255,13 @@ def _plan_decode( _GlobalFlashInferPlanner = _FlashInferPlanner() +def _to_flashinfer_window_left(sliding_window: Optional[int]) -> int: + """Convert AD sliding-window size to FlashInfer's inclusive window_left contract.""" + if sliding_window is None or sliding_window <= 0: + return -1 + return sliding_window - 1 + + @torch.library.custom_op("auto_deploy::flashinfer_attention_prepare_metadata", mutates_args=()) def prepare_flashinfer_metadata( position_ids: 
torch.Tensor, @@ -342,11 +353,15 @@ def flashinfer_mha_with_cache( kv_cache: torch.Tensor, # CONSTANTS scale: Optional[float], + sliding_window: Optional[int], k_scale: float, v_scale: float, + read_cache_only: bool = False, # OPTIONAL PRE-ALLOCATED OUTPUT out: Optional[torch.Tensor] = None, ) -> torch.Tensor: + _GlobalFlashInferPlanner.reset(q.device) + # kv_cache shape: [num_blocks, 2, num_kv_heads, tokens_per_block, head_dim] (HND layout) head_dim = kv_cache.shape[-1] page_size = kv_cache.shape[3] # tokens_per_block @@ -365,25 +380,27 @@ def flashinfer_mha_with_cache( n_heads = q.shape[1] n_kv_heads = k.shape[1] + window_left = _to_flashinfer_window_left(sliding_window) # Assuming k_scale = v_scale = 1.0 k_scale, v_scale = 1.0, 1.0 - # k = (k / k_scale).to(torch.float8_e4m3fn) if k_scale != 1.0, same for v - if kv_cache.dtype == torch.float8_e4m3fn: - k = k.to(torch.float8_e4m3fn) - v = v.to(torch.float8_e4m3fn) - - flashinfer.page.append_paged_kv_cache( - append_key=k[:num_total_tokens], - append_value=v[:num_total_tokens], - batch_indices=flashinfer_batch_indices[:num_total_tokens], - positions=flashinfer_positions[:num_total_tokens], - paged_kv_cache=kv_cache, - kv_indices=cache_loc, - kv_indptr=cu_num_pages[: num_seq + 1], - kv_last_page_len=last_page_len[:num_seq], - kv_layout=_GlobalFlashInferPlanner.kv_layout, - ) + if not read_cache_only: + # k = (k / k_scale).to(torch.float8_e4m3fn) if k_scale != 1.0, same for v + if kv_cache.dtype == torch.float8_e4m3fn: + k = k.to(torch.float8_e4m3fn) + v = v.to(torch.float8_e4m3fn) + + flashinfer.page.append_paged_kv_cache( + append_key=k[:num_total_tokens], + append_value=v[:num_total_tokens], + batch_indices=flashinfer_batch_indices[:num_total_tokens], + positions=flashinfer_positions[:num_total_tokens], + paged_kv_cache=kv_cache, + kv_indices=cache_loc, + kv_indptr=cu_num_pages[: num_seq + 1], + kv_last_page_len=last_page_len[:num_seq], + kv_layout=_GlobalFlashInferPlanner.kv_layout, + ) bs = b * s if out is not 
None: @@ -403,6 +420,7 @@ def flashinfer_mha_with_cache( q_dtype=q_prefill.dtype, kv_dtype=kv_cache.dtype, sm_scale=scale, + window_left=window_left, ) wrapper_prefill = _GlobalFlashInferPlanner.plan_prefill( @@ -435,6 +453,7 @@ def flashinfer_mha_with_cache( q_dtype=q_decode.dtype, kv_dtype=kv_cache.dtype, sm_scale=scale, + window_left=window_left, ) wrapper_decode = _GlobalFlashInferPlanner.plan_decode( @@ -485,8 +504,10 @@ def flashinfer_mha_with_cache_fake( kv_cache: torch.Tensor, # CONSTANTS scale: Optional[float], + sliding_window: Optional[int], k_scale: float, v_scale: float, + read_cache_only: bool = False, # OPTIONAL PRE-ALLOCATED OUTPUT out: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -520,6 +541,10 @@ def get_source_attention_op(cls) -> OpOverloadPacket: def get_cached_attention_op(cls) -> MHACallable: return torch.ops.auto_deploy.flashinfer_attention_mha_with_cache.default + @classmethod + def supports_shared_kv(cls) -> bool: + return True + @classmethod def get_standard_metadata_args(cls) -> List[str]: return [ @@ -565,14 +590,7 @@ def get_host_prepare_metadata_function(cls) -> Optional[PrepareMetadataHostCalla @classmethod def get_constants(cls, source_attn_node: Node) -> List[Constant]: # Sanity check: layout == "bsnd" - # Prefer kwargs; fall back to the final positional arg if it's a string. 
- layout = source_attn_node.kwargs.get("layout", None) - if ( - layout is None - and len(source_attn_node.args) > 0 - and isinstance(source_attn_node.args[-1], str) - ): - layout = source_attn_node.args[-1] + layout = extract_op_args(source_attn_node, "layout")[0] if layout != "bsnd": raise RuntimeError( f"Expected torch_attention layout='bsnd' but got {layout!r} " @@ -589,11 +607,7 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]: f"{source_attn_node=}: {attn_mask=}, {dropout_p=}, {is_causal=}" ) - # Get scale from args or kwargs - if len(source_attn_node.args) > 6: - scale = source_attn_node.args[6] - else: - scale = source_attn_node.kwargs.get("scale", None) + scale = extract_op_args(source_attn_node, "scale")[0] if not (isinstance(scale, float) or scale is None): ad_logger.warning(f"Provided {scale=}, is not a float. Using default scale instead.") @@ -601,6 +615,8 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]: return [ scale, # softmax scale + extract_op_args(source_attn_node, "sliding_window")[0], # sliding window parameter 1.0, # k_scale 1.0, # v_scale + cls.get_shared_kv_source_layer_idx(source_attn_node) is not None, # read_cache_only ] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_attention.py index 8d0d819300e..77b5cfc9820 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_attention.py @@ -119,6 +119,8 @@ def torch_attention( sliding_window: Optional[int] = None, logit_cap: Optional[float] = None, layout: str = "bnsd", # "bnsd" or "bsnd" + layer_idx: Optional[int] = None, + shared_kv_source_layer_idx: Optional[int] = None, ) -> torch.Tensor: """ SDPA attention (with optional GQA) that supports two memory layouts via `layout`: @@ -129,6 +131,8 @@ def torch_attention( Returns a tensor in the SAME layout as inputs 
specified by `layout`. """ + # `layer_idx` and `shared_kv_source_layer_idx` are graph metadata used by the KV-cache + # transform; the eager attention kernel itself does not need them. if layout not in ("bnsd", "bsnd"): raise ValueError(f"layout must be 'bnsd' or 'bsnd', got {layout!r}") @@ -239,5 +243,7 @@ def torch_attention_fake( sliding_window=None, logit_cap=None, layout: str = "bnsd", + layer_idx: Optional[int] = None, + shared_kv_source_layer_idx: Optional[int] = None, ): return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous() diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py index ae816f753d4..babb37f2875 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py @@ -16,7 +16,7 @@ """Torch backend attention using pure PyTorch reference implementations.""" import math -from typing import List, Optional +from typing import Dict, List, Optional import torch from torch._ops import OpOverloadPacket @@ -70,6 +70,24 @@ def _apply_logit_softcapping(attn_scores: torch.Tensor, logit_cap: Optional[floa return attn_scores +def _write_generate_kv_cache( + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + slot_idx: torch.Tensor, + input_pos: torch.Tensor, +): + """Write single-token decode K/V into the cache.""" + b, s = k.shape[:2] + assert s == 1, f"Expected sequence length 1 for generate phase, got {s}" + for i in range(b): + cache_idx = slot_idx[i].item() + pos = input_pos[i].item() + k_cache[cache_idx, pos] = k[i, 0] # Remove sequence dim + v_cache[cache_idx, pos] = v[i, 0] # Remove sequence dim + + def _torch_generate_mha( q: torch.Tensor, k: torch.Tensor, @@ -89,12 +107,7 @@ def _torch_generate_mha( assert s == 1, f"Expected sequence length 1 for generate phase, got {s}" 
n_kv_heads = k.shape[2] # k has shape (b, 1, n_kv_heads, head_dim) - # Update KV cache for single token - for i in range(b): - cache_idx = slot_idx[i].item() - pos = input_pos[i].item() - k_cache[cache_idx, pos] = k[i, 0] # Remove sequence dim - v_cache[cache_idx, pos] = v[i, 0] # Remove sequence dim + _write_generate_kv_cache(k, v, k_cache, v_cache, slot_idx, input_pos) # Compute attention for each sequence using manual computation for i in range(b): @@ -156,6 +169,60 @@ def _torch_generate_mha( out[i] = attn_out.squeeze(1) # [n_heads, v_head_dim] +def _torch_generate_mha_readonly( + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + slot_idx: torch.Tensor, + input_pos: torch.Tensor, + scale: float, + out: torch.Tensor, + logit_cap: Optional[float] = None, + sliding_window_size: Optional[int] = None, + sinks: Optional[torch.Tensor] = None, +): + """Generate-only attention using an existing KV cache without writing current-layer K/V.""" + b, s, n_heads, head_dim = q.shape + assert s == 1, f"Expected sequence length 1 for generate phase, got {s}" + n_kv_heads = k_cache.shape[2] + + for i in range(b): + cache_idx = slot_idx[i].item() + pos = input_pos[i].item() + q_i = q[i, 0] + + if sliding_window_size is not None and sliding_window_size > 0: + start_pos = max(0, pos - sliding_window_size + 1) + k_i = k_cache[cache_idx, start_pos : pos + 1] + v_i = v_cache[cache_idx, start_pos : pos + 1] + else: + k_i = k_cache[cache_idx, : pos + 1] + v_i = v_cache[cache_idx, : pos + 1] + + q_i = q_i.unsqueeze(1) + k_i = k_i.transpose(0, 1) + v_i = v_i.transpose(0, 1) + + if n_heads != n_kv_heads: + n_rep = n_heads // n_kv_heads + k_i = repeat_kv(k_i.unsqueeze(0), n_rep)[0] + v_i = repeat_kv(v_i.unsqueeze(0), n_rep)[0] + + attn_scores = torch.matmul(q_i, k_i.transpose(-2, -1)) * scale + attn_scores = _apply_logit_softcapping(attn_scores, logit_cap) + + if sinks is not None: + sinks = sinks.reshape(-1, 1, 1) + attn_weights = torch.cat([attn_scores, sinks], dim=-1) + 
attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_out = torch.matmul(attn_weights[..., : -sinks.size(-1)], v_i) + else: + attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype) + attn_out = torch.matmul(attn_weights, v_i) + + out[i] = attn_out.squeeze(1) + + def _torch_context_mha( q: torch.Tensor, k: torch.Tensor, @@ -168,12 +235,12 @@ def _torch_context_mha( seq_start: torch.Tensor, scale: float, out: torch.Tensor, + custom_attn_mask: Optional[torch.Tensor] = None, logit_cap: Optional[float] = None, sliding_window_size: Optional[int] = None, sinks: Optional[torch.Tensor] = None, ) -> None: """Context attention (multiple tokens, potentially multiple sequences) using existing torch functions.""" - # Update KV cache first using existing function _update_kv_cache(k, v, k_cache, v_cache, seq_len, input_pos, slot_idx, seq_start) # Compute attention for each sequence @@ -240,14 +307,23 @@ def _torch_context_mha( ) # [seq_len_i, kv_seq_len] # Sliding window mask: allow attention only if 0 <= pos_diff < sliding_window_size - sliding_window_mask = pos_diff >= sliding_window_size + sliding_window_mask = (pos_diff < 0) | (pos_diff >= sliding_window_size) # Combine causal and sliding window masks combined_mask = causal_mask | sliding_window_mask else: combined_mask = causal_mask - attn_scores.masked_fill_(combined_mask.unsqueeze(0).unsqueeze(0), float("-inf")) + if custom_attn_mask is not None: + custom_mask = ~custom_attn_mask[idx, :, :seq_len_i, :kv_seq_len] + if sliding_window_size is not None and sliding_window_size > 0: + combined_mask = sliding_window_mask.unsqueeze(0) | custom_mask + else: + combined_mask = custom_mask + else: + combined_mask = combined_mask.unsqueeze(0) + + attn_scores.masked_fill_(combined_mask.unsqueeze(0), float("-inf")) # Apply logit softcapping if enabled attn_scores = _apply_logit_softcapping(attn_scores, logit_cap) @@ -285,9 +361,94 @@ def _torch_context_mha( 
out.copy_(torch.cat(attn_outputs, dim=0)) -@torch.library.custom_op( - "auto_deploy::torch_cached_attention_with_cache", mutates_args=("k_cache", "v_cache") -) +def _torch_context_mha_readonly( + q: torch.Tensor, + input_pos: torch.Tensor, + slot_idx: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + seq_len: torch.Tensor, + seq_start: torch.Tensor, + scale: float, + out: torch.Tensor, + custom_attn_mask: Optional[torch.Tensor] = None, + logit_cap: Optional[float] = None, + sliding_window_size: Optional[int] = None, + sinks: Optional[torch.Tensor] = None, +) -> None: + """Context attention using an existing KV cache without writing current-layer K/V.""" + attn_outputs = [] + for idx in range(seq_len.shape[0]): + seq_len_i = seq_len[idx].item() + input_pos_i = input_pos[idx].item() + slot_idx_i = slot_idx[idx].item() + seq_start_i = seq_start[idx].item() + + if seq_len_i == 0: + continue + + q_seq = q[seq_start_i : seq_start_i + seq_len_i] + kv_seq_len = input_pos_i + seq_len_i + k_seq = k_cache[slot_idx_i, :kv_seq_len] + v_seq = v_cache[slot_idx_i, :kv_seq_len] + + n_heads = q_seq.shape[1] + n_kv_heads = k_seq.shape[1] + + q_seq_t = q_seq.transpose(0, 1).unsqueeze(0) + k_seq_t = k_seq.transpose(0, 1).unsqueeze(0) + v_seq_t = v_seq.transpose(0, 1).unsqueeze(0) + + if n_heads != n_kv_heads: + n_rep = n_heads // n_kv_heads + k_seq_t = repeat_kv(k_seq_t, n_rep) + v_seq_t = repeat_kv(v_seq_t, n_rep) + + attn_scores = torch.matmul(q_seq_t, k_seq_t.transpose(-2, -1)) * scale + + causal_mask = torch.triu( + torch.ones(seq_len_i, kv_seq_len, device=q.device, dtype=torch.bool), + diagonal=1 + input_pos_i, + ) + combined_mask = causal_mask + + if sliding_window_size is not None and sliding_window_size > 0: + query_positions = torch.arange(input_pos_i, input_pos_i + seq_len_i, device=q.device) + key_positions = torch.arange(kv_seq_len, device=q.device) + pos_diff = query_positions.unsqueeze(1) - key_positions.unsqueeze(0) + sliding_window_mask = (pos_diff < 0) 
| (pos_diff >= sliding_window_size) + combined_mask = combined_mask | sliding_window_mask + + if custom_attn_mask is not None: + custom_mask = ~custom_attn_mask[idx, :, :seq_len_i, :kv_seq_len] + combined_mask = combined_mask.unsqueeze(0) | custom_mask + else: + combined_mask = combined_mask.unsqueeze(0) + + attn_scores.masked_fill_(combined_mask.unsqueeze(0), float("-inf")) + + attn_scores = _apply_logit_softcapping(attn_scores, logit_cap) + + if sinks is not None: + new_sinks = sinks.reshape(1, -1, 1, 1).expand(1, n_heads, seq_len_i, 1) + attn_weights = torch.cat([attn_scores, new_sinks], dim=-1) + attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_out = torch.matmul(attn_weights[..., : -new_sinks.size(-1)], v_seq_t) + else: + attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype) + attn_out = torch.matmul(attn_weights, v_seq_t) + + attn_outputs.append(attn_out[0].transpose(0, 1)) + + if len(attn_outputs) == 0: + out.zero_() + elif len(attn_outputs) == 1: + out.copy_(attn_outputs[0]) + else: + out.copy_(torch.cat(attn_outputs, dim=0)) + + +@torch.library.custom_op("auto_deploy::torch_cached_attention_with_cache", mutates_args=()) def torch_backend_mha_with_cache( # Q, K, V q: torch.Tensor, @@ -306,11 +467,15 @@ def torch_backend_mha_with_cache( v_cache: torch.Tensor, # BUFFERS # - # CONSTANTS - scale: Optional[float], + # CONSTANTS must come before dynamic tensor inputs. The KV-cache transform + # appends constants positionally and forwards dynamic inputs as kwargs. 
+ scale: Optional[float] = None, sinks: Optional[torch.Tensor] = None, sliding_window_size: Optional[int] = None, logit_cap: Optional[float] = None, + read_cache_only: bool = False, + # DYNAMIC INPUTS + custom_attn_mask: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Torch backend MHA with cache that takes q, k, v in BSND layout.""" @@ -350,12 +515,15 @@ def torch_backend_mha_with_cache( y = q.new_empty(*bs_view, num_heads, v_head_dim).contiguous() # Compute attention + if not read_cache_only: + if s == 1: + _write_generate_kv_cache(k, v, k_cache, v_cache, slot_idx, input_pos) + else: + _update_kv_cache(k, v, k_cache, v_cache, seq_len, input_pos, slot_idx, seq_start) + if s == 1: - # Generate-only phase - _torch_generate_mha( + _torch_generate_mha_readonly( q, - k, - v, k_cache, v_cache, slot_idx, @@ -367,11 +535,8 @@ def torch_backend_mha_with_cache( sinks, ) else: - # Context phase - _torch_context_mha( + _torch_context_mha_readonly( q, - k, - v, input_pos, slot_idx, k_cache, @@ -380,6 +545,7 @@ def torch_backend_mha_with_cache( seq_start, scale, y, + custom_attn_mask, logit_cap, sliding_window_size, sinks, @@ -419,13 +585,15 @@ def torch_backend_mha_with_cache_fake( # CACHES k_cache: torch.Tensor, v_cache: torch.Tensor, + custom_attn_mask: Optional[torch.Tensor] = None, # BUFFERS # # CONSTANTS - scale: Optional[float], + scale: Optional[float] = None, sinks: Optional[torch.Tensor] = None, sliding_window_size: Optional[int] = None, logit_cap: Optional[float] = None, + read_cache_only: bool = False, out: Optional[torch.Tensor] = None, ) -> torch.Tensor: if out is not None: @@ -453,6 +621,10 @@ def get_source_attention_op(cls) -> OpOverloadPacket: def get_cached_attention_op(cls) -> MHACallable: return torch.ops.auto_deploy.torch_cached_attention_with_cache.default + @classmethod + def supports_shared_kv(cls) -> bool: + return True + @classmethod def get_standard_metadata_args(cls) -> List[str]: return ["batch_info_host", 
"seq_len", "input_pos", "slot_idx", "cu_seqlen"] @@ -481,17 +653,14 @@ def get_cache_initializers( ), } + @classmethod + def get_dynamic_inputs(cls, source_attn_node: Node) -> Dict[str, Optional[Node]]: + return {"custom_attn_mask": extract_op_args(source_attn_node, "attn_mask")[0]} + @classmethod def get_constants(cls, source_attn_node: Node) -> List[Constant]: # Sanity check: layout == "bsnd" - # Prefer kwargs; fall back to the final positional arg if it's a string. - layout = source_attn_node.kwargs.get("layout", None) - if ( - layout is None - and len(source_attn_node.args) > 0 - and isinstance(source_attn_node.args[-1], str) - ): - layout = source_attn_node.args[-1] + layout = extract_op_args(source_attn_node, "layout")[0] if layout != "bsnd": raise RuntimeError( f"Expected torch_attention layout='bsnd' but got {layout!r} " @@ -502,10 +671,9 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]: attn_mask, dropout_p, is_causal = extract_op_args( source_attn_node, "attn_mask", "dropout_p", "is_causal" ) - if attn_mask is not None or dropout_p != 0.0 or not is_causal: + if dropout_p != 0.0: ad_logger.debug( - "Unsupported attention arguments for " - f"{source_attn_node=}: {attn_mask=}, {dropout_p=}, {is_causal=}" + f"Unsupported attention arguments for {source_attn_node=}: {dropout_p=}" ) # Get scale from args or kwargs @@ -529,4 +697,5 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]: sinks, # sinks parameter sliding_window, # sliding window parameter logit_cap, # logit cap parameter + cls.get_shared_kv_source_layer_idx(source_attn_node) is not None, # read_cache_only ] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py index 0670a5b9d77..992e219c65b 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py @@ 
-388,14 +388,7 @@ def get_cache_initializers( @classmethod def get_constants(cls, source_attn_node: Node) -> List[Constant]: # Sanity check: layout == "bsnd" - # Prefer kwargs; fall back to the final positional arg if it's a string. - layout = source_attn_node.kwargs.get("layout", None) - if ( - layout is None - and len(source_attn_node.args) > 0 - and isinstance(source_attn_node.args[-1], str) - ): - layout = source_attn_node.args[-1] + layout = extract_op_args(source_attn_node, "layout")[0] if layout != "bsnd": raise RuntimeError( f"Expected torch_attention layout='bsnd' but got {layout!r} " diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py index b33e70405c2..d0c2441eb12 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py @@ -21,7 +21,7 @@ """ import math -from typing import List, Literal, Optional, Tuple +from typing import Dict, List, Literal, Optional, Tuple import flashinfer import torch @@ -209,7 +209,7 @@ def _get_num_splits(max_seq_len: int, batch_size: int, n_kv_heads: int, page_siz triton.Config({}, num_warps=8, num_stages=2), triton.Config({}, num_warps=8, num_stages=3), ], - key=["HEAD_DIM", "PAGE_SIZE", "HEAD_RATIO_PADDED"], + key=["HEAD_DIM", "PAGE_SIZE", "HEAD_RATIO_PADDED", "SLIDING_WINDOW"], ) @triton.jit def _flash_decode_stage1_kernel( @@ -249,6 +249,7 @@ def _flash_decode_stage1_kernel( HEAD_RATIO: tl.constexpr, HEAD_RATIO_PADDED: tl.constexpr, NUM_SPLITS: tl.constexpr, + SLIDING_WINDOW: tl.constexpr = 0, ): """ Key optimizations: @@ -266,9 +267,20 @@ def _flash_decode_stage1_kernel( num_pages = kv_page_end - kv_page_start last_page_len = tl.load(kv_last_page_len_ptr + batch_id) - # Compute this split's page range (page-aligned splits) - pages_per_split = (num_pages + NUM_SPLITS - 1) // NUM_SPLITS - 
page_split_start = split_id * pages_per_split + # Sliding window: restrict attention to pages within the window. + # Compute the total sequence length and the first valid KV position. + seq_len = (num_pages - 1) * PAGE_SIZE + last_page_len + if SLIDING_WINDOW > 0: + first_valid_pos = tl.maximum(0, seq_len - SLIDING_WINDOW) + first_window_page = first_valid_pos // PAGE_SIZE + else: + first_valid_pos = 0 + first_window_page = 0 + + # Only split over pages within the window. + window_pages = num_pages - first_window_page + pages_per_split = (window_pages + NUM_SPLITS - 1) // NUM_SPLITS + page_split_start = first_window_page + split_id * pages_per_split page_split_end = tl.minimum(page_split_start + pages_per_split, num_pages) dhead_offsets = tl.arange(0, HEAD_DIM) @@ -346,7 +358,13 @@ def _flash_decode_stage1_kernel( # [HEAD_RATIO_PADDED, HEAD_DIM] @ [HEAD_DIM, PAGE_SIZE] -> [HEAD_RATIO_PADDED, PAGE_SIZE] attn = tl.dot(q_all, tl.trans(k)) * SM_SCALE - attn = tl.where(page_mask[None, :], attn, float("-inf")) + + if SLIDING_WINDOW > 0: + global_pos = page_idx * PAGE_SIZE + page_offsets + window_mask = global_pos >= first_valid_pos + attn = tl.where(page_mask[None, :] & window_mask[None, :], attn, float("-inf")) + else: + attn = tl.where(page_mask[None, :], attn, float("-inf")) # Online softmax update (vectorized over HEAD_RATIO_PADDED) m_ij = tl.max(attn, axis=1) # [HEAD_RATIO_PADDED] @@ -454,6 +472,7 @@ def triton_paged_decode( kv_indptr: torch.Tensor, kv_last_page_len: torch.Tensor, sm_scale: float, + sliding_window: Optional[int] = None, out: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Optimized paged decode with GQA batching + FlashDecoding + page-aligned iteration. 
@@ -465,6 +484,7 @@ def triton_paged_decode( kv_indptr: Cumulative page counts [batch_size + 1] kv_last_page_len: Valid tokens in last page [batch_size] sm_scale: Softmax scale factor + sliding_window: If set, only attend to the last sliding_window tokens out: Optional output tensor [batch_size, n_heads, head_dim] Returns: @@ -477,13 +497,15 @@ def triton_paged_decode( max_pages = kv_indices.shape[0] max_seq_len = max_pages * page_size + sw = sliding_window if isinstance(sliding_window, int) and sliding_window > 0 else 0 output = out if out is not None else torch.empty_like(q) if batch_size == 0: return output - num_splits = _get_num_splits(max_seq_len, batch_size, n_kv_heads, page_size) + effective_seq_len = min(max_seq_len, sw) if sw > 0 else max_seq_len + num_splits = _get_num_splits(effective_seq_len, batch_size, n_kv_heads, page_size) # Allocate intermediate buffers for split-K partial_o = torch.empty( @@ -536,6 +558,7 @@ def triton_paged_decode( HEAD_RATIO=head_ratio, HEAD_RATIO_PADDED=head_ratio_padded, NUM_SPLITS=num_splits, + SLIDING_WINDOW=sw, ) # Stage 2: Combine partial results @@ -606,6 +629,7 @@ def _paged_context_kernel( N_KV_HEADS: tl.constexpr, HEAD_DIM: tl.constexpr, PAGE_SIZE: tl.constexpr, + SLIDING_WINDOW: tl.constexpr = 0, ): """Context/prefill attention with paged KV cache, causal skip, and page-aligned iteration. 
@@ -669,6 +693,12 @@ def _paged_context_kernel( # Number of full pages (all tokens in these pages are attended by all Q tokens) num_full_pages = first_q_kv_pos // PAGE_SIZE + if SLIDING_WINDOW > 0: + first_valid_pos = tl.maximum(0, first_q_kv_pos - SLIDING_WINDOW + 1) + first_window_page = first_valid_pos // PAGE_SIZE + else: + first_window_page = 0 + # Check if this is a full Q block (no q_mask needed) is_full_q_block = (q_block_start + Q_BLOCK) <= q_len @@ -677,39 +707,64 @@ def _paged_context_kernel( kv_head_offset = kv_head_id * cache_stride_head local_kv = page_offsets[:, None] * cache_stride_token + dhead_offsets[None, :] - for page_idx in range(num_full_pages): + for page_idx in range(first_window_page, num_full_pages): physical_page = tl.load(kv_indices_ptr + kv_page_start + page_idx) # Use int64 to avoid overflow when physical_page * stride > 2^31 page_base = physical_page.to(tl.int64) * cache_stride_block + kv_head_offset - k_block_ptr = tl.make_block_ptr( - base=kv_cache_ptr + page_base, - shape=(PAGE_SIZE, HEAD_DIM), - strides=(cache_stride_token, 1), - offsets=(0, 0), - block_shape=(PAGE_SIZE, HEAD_DIM), - order=(1, 0), - ) - v_block_ptr = tl.make_block_ptr( - base=kv_cache_ptr + page_base + cache_stride_kv, - shape=(PAGE_SIZE, HEAD_DIM), - strides=(cache_stride_token, 1), - offsets=(0, 0), - block_shape=(PAGE_SIZE, HEAD_DIM), - order=(1, 0), - ) - k = tl.load(k_block_ptr) - v = tl.load(v_block_ptr) - qk = tl.dot(q, tl.trans(k)) * SM_SCALE + if SLIDING_WINDOW > 0: + k = tl.load( + kv_cache_ptr + page_base + local_kv, + mask=tl.full([PAGE_SIZE, HEAD_DIM], 1, tl.int1), + other=0.0, + ) + v = tl.load( + kv_cache_ptr + page_base + local_kv + cache_stride_kv, + mask=tl.full([PAGE_SIZE, HEAD_DIM], 1, tl.int1), + other=0.0, + ) + + qk = tl.dot(q, tl.trans(k)) * SM_SCALE + + kv_positions = page_idx * PAGE_SIZE + page_offsets[None, :] + q_kv_pos = q_offsets[:, None] + cache_len + sw_mask = (q_kv_pos - kv_positions) < SLIDING_WINDOW + full_mask_p1 = q_mask[:, 
None] & sw_mask + qk = tl.where(full_mask_p1, qk, float("-inf")) + else: + k_block_ptr = tl.make_block_ptr( + base=kv_cache_ptr + page_base, + shape=(PAGE_SIZE, HEAD_DIM), + strides=(cache_stride_token, 1), + offsets=(0, 0), + block_shape=(PAGE_SIZE, HEAD_DIM), + order=(1, 0), + ) + v_block_ptr = tl.make_block_ptr( + base=kv_cache_ptr + page_base + cache_stride_kv, + shape=(PAGE_SIZE, HEAD_DIM), + strides=(cache_stride_token, 1), + offsets=(0, 0), + block_shape=(PAGE_SIZE, HEAD_DIM), + order=(1, 0), + ) + k = tl.load(k_block_ptr) + v = tl.load(v_block_ptr) - if not is_full_q_block: - qk = tl.where(q_mask[:, None], qk, float("-inf")) + qk = tl.dot(q, tl.trans(k)) * SM_SCALE + + if not is_full_q_block: + qk = tl.where(q_mask[:, None], qk, float("-inf")) m_ij = tl.max(qk, axis=1) m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - p = tl.exp(qk - m_i_new[:, None]) + if SLIDING_WINDOW > 0: + alpha = tl.where(m_i > float("-inf"), tl.exp(m_i - m_i_new), 0.0) + p = tl.where(m_i_new[:, None] > float("-inf"), tl.exp(qk - m_i_new[:, None]), 0.0) + else: + alpha = tl.exp(m_i - m_i_new) + p = tl.exp(qk - m_i_new[:, None]) acc = tl.dot(p.to(v.dtype), v, acc=acc * alpha[:, None]) l_i = l_i * alpha + tl.sum(p, axis=1) m_i = m_i_new @@ -740,13 +795,21 @@ def _paged_context_kernel( qk = tl.dot(q, tl.trans(k)) * SM_SCALE kv_positions = kv_base_pos + page_offsets[None, :] causal_mask = q_positions_2d >= kv_positions - full_mask = q_mask[:, None] & causal_mask & page_mask[None, :] + if SLIDING_WINDOW > 0: + sliding_mask = (q_positions_2d - kv_positions) < SLIDING_WINDOW + full_mask = q_mask[:, None] & causal_mask & sliding_mask & page_mask[None, :] + else: + full_mask = q_mask[:, None] & causal_mask & page_mask[None, :] qk = tl.where(full_mask, qk, float("-inf")) m_ij = tl.max(qk, axis=1) m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - p = tl.exp(qk - m_i_new[:, None]) + if SLIDING_WINDOW > 0: + alpha = tl.where(m_i > float("-inf"), tl.exp(m_i - 
m_i_new), 0.0) + p = tl.where(m_i_new[:, None] > float("-inf"), tl.exp(qk - m_i_new[:, None]), 0.0) + else: + alpha = tl.exp(m_i - m_i_new) + p = tl.exp(qk - m_i_new[:, None]) acc = tl.dot(p.to(v.dtype), v, acc=acc * alpha[:, None]) l_i = l_i * alpha + tl.sum(p, axis=1) m_i = m_i_new @@ -761,6 +824,153 @@ def _paged_context_kernel( tl.store(o_ptr + o_store_offsets, o, mask=q_load_mask) +@triton.autotune( + configs=[ + triton.Config({"Q_BLOCK": 64}, num_stages=2, num_warps=2), + triton.Config({"Q_BLOCK": 64}, num_stages=2, num_warps=4), + triton.Config({"Q_BLOCK": 64}, num_stages=4, num_warps=4), + triton.Config({"Q_BLOCK": 128}, num_stages=2, num_warps=4), + triton.Config({"Q_BLOCK": 128}, num_stages=2, num_warps=8), + triton.Config({"Q_BLOCK": 128}, num_stages=3, num_warps=8), + ], + key=["HEAD_DIM", "PAGE_SIZE"], +) +@triton.jit +def _paged_context_masked_kernel( + # Inputs + q_ptr, + kv_cache_ptr, + custom_mask_ptr, + # Metadata + qo_indptr_ptr, + kv_indptr_ptr, + kv_indices_ptr, + seq_len_with_cache_ptr, + # Output + o_ptr, + # Strides + q_stride_token: tl.constexpr, + q_stride_head: tl.constexpr, + o_stride_token: tl.constexpr, + o_stride_head: tl.constexpr, + cache_stride_block: tl.constexpr, + cache_stride_kv: tl.constexpr, + cache_stride_head: tl.constexpr, + cache_stride_token: tl.constexpr, + mask_stride_batch: tl.constexpr, + mask_stride_head: tl.constexpr, + mask_stride_q: tl.constexpr, + mask_stride_k: tl.constexpr, + # Autotuned + Q_BLOCK: tl.constexpr, + # Constants + SM_SCALE: tl.constexpr, + N_HEADS: tl.constexpr, + N_KV_HEADS: tl.constexpr, + HEAD_DIM: tl.constexpr, + PAGE_SIZE: tl.constexpr, + SLIDING_WINDOW: tl.constexpr = 0, +): + """Context/prefill attention with paged KV cache and a backend-provided bool allow-mask.""" + batch_id = tl.program_id(axis=0) + head_id = tl.program_id(axis=1) + q_block_id = tl.program_id(axis=2) + + HEAD_RATIO: tl.constexpr = N_HEADS // N_KV_HEADS + kv_head_id = head_id // HEAD_RATIO + + q_start = 
tl.load(qo_indptr_ptr + batch_id) + q_end = tl.load(qo_indptr_ptr + batch_id + 1) + q_len = q_end - q_start + + kv_page_start = tl.load(kv_indptr_ptr + batch_id) + kv_page_end = tl.load(kv_indptr_ptr + batch_id + 1) + num_kv_pages = kv_page_end - kv_page_start + total_kv_len = tl.load(seq_len_with_cache_ptr + batch_id) + + q_block_start = q_block_id * Q_BLOCK + q_offsets = q_block_start + tl.arange(0, Q_BLOCK) + q_mask = q_offsets < q_len + if tl.sum(q_mask.to(tl.int32)) == 0: + return + + dhead_offsets = tl.arange(0, HEAD_DIM) + q_load_offsets = ( + (q_start + q_offsets[:, None]) * q_stride_token + + head_id * q_stride_head + + dhead_offsets[None, :] + ) + q_load_mask = q_mask[:, None] + q = tl.load(q_ptr + q_load_offsets, mask=q_load_mask, other=0.0) + + acc = tl.zeros([Q_BLOCK, HEAD_DIM], dtype=tl.float32) + m_i = tl.zeros([Q_BLOCK], dtype=tl.float32) - float("inf") + l_i = tl.zeros([Q_BLOCK], dtype=tl.float32) + + page_offsets = tl.arange(0, PAGE_SIZE) + kv_head_offset = kv_head_id * cache_stride_head + local_kv = page_offsets[:, None] * cache_stride_token + dhead_offsets[None, :] + mask_batch_offset = batch_id * mask_stride_batch + # The mask is broadcast over heads (head dim == 1), so the head offset is always 0. 
+ mask_head_offset = 0 * mask_stride_head + + for page_idx in range(num_kv_pages): + kv_base_pos = page_idx * PAGE_SIZE + physical_page = tl.load(kv_indices_ptr + kv_page_start + page_idx) + valid_tokens = tl.minimum(PAGE_SIZE, total_kv_len - kv_base_pos) + page_mask = page_offsets < valid_tokens + + page_base = physical_page.to(tl.int64) * cache_stride_block + kv_head_offset + page_mask_2d = page_mask[:, None] + k = tl.load(kv_cache_ptr + page_base + local_kv, mask=page_mask_2d, other=0.0) + v = tl.load( + kv_cache_ptr + page_base + local_kv + cache_stride_kv, + mask=page_mask_2d, + other=0.0, + ) + + qk = tl.dot(q, tl.trans(k)) * SM_SCALE + kv_positions = kv_base_pos + page_offsets[None, :] + mask_offsets = ( + mask_batch_offset + + mask_head_offset + + q_offsets[:, None] * mask_stride_q + + kv_positions * mask_stride_k + ) + valid_mask = q_mask[:, None] & page_mask[None, :] + custom_mask = tl.load(custom_mask_ptr + mask_offsets, mask=valid_mask, other=0) + if SLIDING_WINDOW > 0: + query_positions = q_offsets[:, None] + sliding_mask = (query_positions >= kv_positions) & ( + (query_positions - kv_positions) < SLIDING_WINDOW + ) + full_mask = valid_mask & sliding_mask & (custom_mask != 0) + else: + full_mask = valid_mask & (custom_mask != 0) + qk = tl.where(full_mask, qk, float("-inf")) + + m_ij = tl.max(qk, axis=1) + m_i_new = tl.maximum(m_i, m_ij) + if SLIDING_WINDOW > 0: + alpha = tl.where(m_i > float("-inf"), tl.exp(m_i - m_i_new), 0.0) + p = tl.where(m_i_new[:, None] > float("-inf"), tl.exp(qk - m_i_new[:, None]), 0.0) + else: + alpha = tl.exp(m_i - m_i_new) + p = tl.exp(qk - m_i_new[:, None]) + acc = tl.dot(p.to(v.dtype), v, acc=acc * alpha[:, None]) + l_i = l_i * alpha + tl.sum(p, axis=1) + m_i = m_i_new + + l_i = tl.where(l_i == 0.0, 1.0, l_i) + o = acc / l_i[:, None] + o_store_offsets = ( + (q_start + q_offsets[:, None]) * o_stride_token + + head_id * o_stride_head + + dhead_offsets[None, :] + ) + tl.store(o_ptr + o_store_offsets, o, mask=q_load_mask) + + 
@triton.jit def _fast_gather_sdpa_kernel( kv_cache_ptr, @@ -781,17 +991,10 @@ def _fast_gather_sdpa_kernel( PAGE_SIZE: tl.constexpr, HEAD_DIM: tl.constexpr, ): - """Gather scattered pages into separate K, V buffers in SDPA layout. - - Grid: (total_pages, N_KV_HEADS) - Each program copies one page for one KV head into contiguous K and V - outputs shaped [num_seq, n_kv_heads, max_kv_len, head_dim]. - No precomputed mapping needed — seq_id and local_page computed from global index. - """ + """Gather scattered pages into separate K, V buffers in SDPA layout.""" page_global_idx = tl.program_id(0) kv_head_id = tl.program_id(1) - # Compute seq_id and local_page from global page index seq_id = page_global_idx // MAX_PAGES local_page = page_global_idx % MAX_PAGES @@ -800,14 +1003,12 @@ def _fast_gather_sdpa_kernel( token_offsets = tl.arange(0, PAGE_SIZE) head_offsets = tl.arange(0, HEAD_DIM) - # Source: kv_cache[physical_page, 0/1, kv_head_id, :, :] src_base = physical_page.to(tl.int64) * cache_stride_block + kv_head_id * cache_stride_head src_offsets = token_offsets[:, None] * cache_stride_token + head_offsets[None, :] k_data = tl.load(kv_cache_ptr + src_base + src_offsets) v_data = tl.load(kv_cache_ptr + src_base + cache_stride_kv + src_offsets) - # Destination: out_k/v[seq_id, kv_head_id, local_page*PAGE_SIZE + :, :] local_token_start = local_page * PAGE_SIZE dst_base = ( seq_id * out_stride_seq @@ -829,6 +1030,7 @@ def triton_paged_context( kv_last_page_len: torch.Tensor, seq_len_with_cache: torch.Tensor, sm_scale: float, + sliding_window: Optional[int] = None, out: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Context/prefill attention with paged KV cache.""" @@ -851,10 +1053,15 @@ def triton_paged_context( q_lens = qo_indptr[1:] - qo_indptr[:-1] max_q_len = int(q_lens.max().item()) + sw = sliding_window if isinstance(sliding_window, int) and sliding_window > 0 else 0 + # Adaptive dispatch: gather + cuDNN SDPA for seq>=512 (outperforms paged kernel), # paged 
Triton kernel for shorter sequences where gather overhead dominates. # Compute max_pages from max_q_len without GPU sync # (assumes pure prefill where q_len == kv_len for each seq) + # Normalize sliding_window for kernel constexpr: None/non-positive → 0 + sw = sliding_window if isinstance(sliding_window, int) and sliding_window > 0 else 0 + max_pages = (max_q_len + page_size - 1) // page_size total_expected_pages = num_seq * max_pages use_sdpa = ( @@ -862,6 +1069,7 @@ def triton_paged_context( and num_seq <= 64 and max_pages > 0 and kv_indices.shape[0] == total_expected_pages # all seqs same page count + and sw == 0 # SDPA doesn't support sliding window natively ) if use_sdpa: @@ -936,11 +1144,90 @@ def grid_paged(meta): N_KV_HEADS=n_kv_heads, HEAD_DIM=head_dim, PAGE_SIZE=page_size, + SLIDING_WINDOW=sw, ) return output +def triton_paged_context_with_custom_mask( + q: torch.Tensor, + kv_cache: torch.Tensor, + qo_indptr: torch.Tensor, + kv_indptr: torch.Tensor, + kv_indices: torch.Tensor, + seq_len_with_cache: torch.Tensor, + custom_attn_mask: torch.Tensor, + sm_scale: float, + sliding_window: Optional[int] = None, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Context/prefill attention with paged KV cache and a backend-provided bool allow-mask. + + The mask must be broadcastable over heads, i.e. shape ``[B, 1, S_q, S_k]``. + Per-head masks (``N > 1`` in the head dimension) are **not** supported. + """ + output = out if out is not None else torch.empty_like(q) + num_seq = qo_indptr.shape[0] - 1 + total_tokens, n_heads, head_dim = q.shape + _, _, n_kv_heads, page_size, _ = kv_cache.shape + + if num_seq == 0 or total_tokens == 0: + return output + + assert custom_attn_mask.shape[1] == 1, ( + f"Per-head masks are not supported; expected mask head dim == 1, " + f"got {custom_attn_mask.shape[1]}. 
Mask shape: {custom_attn_mask.shape}" + ) + + if num_seq == 1: + max_q_len = total_tokens + else: + q_lens = qo_indptr[1:] - qo_indptr[:-1] + max_q_len = int(q_lens.max().item()) + + if not custom_attn_mask.is_contiguous() or custom_attn_mask.dtype != torch.uint8: + custom_attn_mask = custom_attn_mask.contiguous().to(dtype=torch.uint8) + + sw = sliding_window if isinstance(sliding_window, int) and sliding_window > 0 else 0 + + def grid_masked(meta): + q_block = meta["Q_BLOCK"] + num_q_blocks = (max_q_len + q_block - 1) // q_block + return (num_seq, n_heads, num_q_blocks) + + _paged_context_masked_kernel[grid_masked]( + q, + kv_cache, + custom_attn_mask, + qo_indptr, + kv_indptr, + kv_indices, + seq_len_with_cache, + output, + q.stride(0), + q.stride(1), + output.stride(0), + output.stride(1), + kv_cache.stride(0), + kv_cache.stride(1), + kv_cache.stride(2), + kv_cache.stride(3), + custom_attn_mask.stride(0), + custom_attn_mask.stride(1), + custom_attn_mask.stride(2), + custom_attn_mask.stride(3), + SM_SCALE=sm_scale, + N_HEADS=n_heads, + N_KV_HEADS=n_kv_heads, + HEAD_DIM=head_dim, + PAGE_SIZE=page_size, + SLIDING_WINDOW=sw, + ) + + return output + + @torch.library.custom_op("auto_deploy::triton_paged_prepare_metadata", mutates_args=()) def prepare_triton_paged_metadata( position_ids: torch.Tensor, @@ -999,8 +1286,14 @@ def triton_paged_mha_with_cache( triton_positions: torch.Tensor, # CACHES - combined KV cache kv_cache: torch.Tensor, - # CONSTANTS - scale: Optional[float], + # CONSTANTS must come before dynamic tensor inputs. The KV-cache transform + # appends constants positionally and forwards dynamic inputs as kwargs. 
+ scale: Optional[float] = None, + sliding_window: Optional[int] = None, + # DYNAMIC INPUTS + custom_attn_mask: Optional[torch.Tensor] = None, + # OPTIONAL PRE-ALLOCATED OUTPUT + out: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Triton paged attention with mixed batch support.""" head_dim = kv_cache.shape[-1] @@ -1031,23 +1324,41 @@ def triton_paged_mha_with_cache( cu_num_pages[: num_seq + 1], ) - y = torch.empty_like(q) + if out is not None: + y = out.view(-1, q.shape[1], head_dim) + else: + y = torch.empty_like(q) # Process prefill tokens if any if num_prefill > 0: cu_seqlen = cu_seqlen_host[: num_prefill + 1].to(q.device, non_blocking=True) seq_len_with_cache = seq_len_with_cache_host[:num_prefill].to(q.device, non_blocking=True) - triton_paged_context( - q[:num_prefill_tokens], - kv_cache, - cu_seqlen, - cu_num_pages[: num_prefill + 1], - cache_loc, - last_page_len[:num_prefill], - seq_len_with_cache, - sm_scale, - out=y[:num_prefill_tokens], - ) + if custom_attn_mask is None: + triton_paged_context( + q[:num_prefill_tokens], + kv_cache, + cu_seqlen, + cu_num_pages[: num_prefill + 1], + cache_loc, + last_page_len[:num_prefill], + seq_len_with_cache, + sm_scale, + sliding_window=sliding_window, + out=y[:num_prefill_tokens], + ) + else: + triton_paged_context_with_custom_mask( + q[:num_prefill_tokens], + kv_cache, + cu_seqlen, + cu_num_pages[: num_prefill + 1], + cache_loc, + seq_len_with_cache, + custom_attn_mask[:num_prefill], + sm_scale, + sliding_window=sliding_window, + out=y[:num_prefill_tokens], + ) # Process decode tokens if any if num_decode > 0: @@ -1058,9 +1369,20 @@ def triton_paged_mha_with_cache( cu_num_pages[num_prefill : num_seq + 1], last_page_len[num_prefill:num_seq], sm_scale, + sliding_window=sliding_window, out=y[num_prefill_tokens:num_total_tokens], ) + if out is not None: + # Zero stale data in padding region for CUDA graph replay stability + bs = b * s + if num_total_tokens < bs: + y[num_total_tokens:].zero_() + # Return a 
0-element dummy to satisfy PyTorch's no-alias constraint. + # The caller (DynamicOpWrapper._coalesce_output) picks ``out`` over + # this dummy, so the pre-allocated buffer is used downstream. + return out.new_empty(0) + return y.view(q_shape_og) @@ -1080,8 +1402,13 @@ def triton_paged_mha_with_cache_fake( triton_batch_indices: torch.Tensor, triton_positions: torch.Tensor, kv_cache: torch.Tensor, - scale: Optional[float], + scale: Optional[float] = None, + sliding_window: Optional[int] = None, + custom_attn_mask: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if out is not None: + return out.new_empty(0) return torch.empty_like(q.contiguous()) @@ -1149,37 +1476,31 @@ def get_cache_initializers( ) } + @classmethod + def get_dynamic_inputs(cls, source_attn_node: Node) -> Dict[str, Optional[Node]]: + return {"custom_attn_mask": extract_op_args(source_attn_node, "attn_mask")[0]} + @classmethod def get_constants(cls, source_attn_node: Node) -> List[Constant]: - layout = source_attn_node.kwargs.get("layout", None) - if ( - layout is None - and len(source_attn_node.args) > 0 - and isinstance(source_attn_node.args[-1], str) - ): - layout = source_attn_node.args[-1] + layout, scale, attn_mask, dropout_p, is_causal = extract_op_args( + source_attn_node, "layout", "scale", "attn_mask", "dropout_p", "is_causal" + ) + if layout != "bsnd": raise RuntimeError( f"Expected torch_attention layout='bsnd' but got {layout!r} " f"for node: {source_attn_node.format_node()}" ) - attn_mask, dropout_p, is_causal = extract_op_args( - source_attn_node, "attn_mask", "dropout_p", "is_causal" - ) - if attn_mask is not None or dropout_p != 0.0 or not is_causal: + if dropout_p != 0.0: ad_logger.debug( - "Unsupported attention arguments for " - f"{source_attn_node=}: {attn_mask=}, {dropout_p=}, {is_causal=}" + f"Unsupported attention arguments for {source_attn_node=}: {dropout_p=}" ) - if len(source_attn_node.args) > 6: - scale = source_attn_node.args[6] - 
else: - scale = source_attn_node.kwargs.get("scale", None) - if not (isinstance(scale, float) or scale is None): ad_logger.warning(f"Provided {scale=}, is not a float. Using default scale instead.") scale = None - return [scale] + sliding_window = extract_op_args(source_attn_node, "sliding_window")[0] + + return [scale, sliding_window] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py index d0c8a1dd8c0..f1c99267ed0 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py @@ -639,13 +639,7 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]: from tensor shapes or SequenceInfo metadata at runtime. """ # Sanity check: layout == "bsnd" - layout = source_attn_node.kwargs.get("layout", None) - if ( - layout is None - and len(source_attn_node.args) > 0 - and isinstance(source_attn_node.args[-1], str) - ): - layout = source_attn_node.args[-1] + layout = extract_op_args(source_attn_node, "layout")[0] if layout != "bsnd": raise RuntimeError( f"Expected torch_attention layout='bsnd' but got {layout!r} " diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py index 748ce06ba22..9df8b090cbc 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py @@ -26,7 +26,19 @@ import math from abc import ABC, abstractmethod -from typing import Dict, List, Literal, Optional, Protocol, Sequence, Set, Tuple, Type, Union +from typing import ( + Callable, + Dict, + List, + Literal, + Optional, + Protocol, + Sequence, + Set, + Tuple, + Type, + Union, +) import numpy as np import torch @@ -38,6 +50,7 @@ from ...._utils import nvtx_range, prefer_pinned, str_dtype_to_torch from 
..utils.logger import ad_logger +from ..utils.node_utils import extract_op_args, get_op_schema Constant = Union[int, float, str, None] @@ -65,6 +78,14 @@ def _list_to_tensor(data: list, dtype: torch.dtype) -> torch.Tensor: return torch.tensor(data, dtype=dtype) +def _extract_optional_op_arg(node: Node, arg_name: str): + """Return an op argument if it exists in the schema, otherwise ``None``.""" + schema_arg_names = {arg.name for arg in get_op_schema(node.target).arguments} + if arg_name not in schema_arg_names: + return None + return extract_op_args(node, arg_name)[0] + + class PrepareMetadataHostCallable(Protocol): def __call__(self, **sequence_info_args: torch.Tensor) -> None: ... @@ -678,6 +699,14 @@ def __init__( # EXTRA TENSOR FIELDS ###################################################################### self._extra_args: Dict[str, Optional[torch.Tensor]] = {} + # Default factories for extra args: callables that accept this SequenceInfo instance + # and return a default value. These are called whenever a key is absent from + # ``extra_args`` in ``nest_sequences`` so that initialization-time forward passes + # (resize_kv_cache, cuda-graph warmup) always receive valid inputs. + # Model-specific transforms (e.g. attention mask providers) can register factories here. + self._default_extra_arg_factories: Dict[ + str, Callable[["SequenceInfo"], Optional[torch.Tensor]] + ] = {} ############################################################################################ # HOST PREPARE FOR ATTENTION FORWARD ####################################################### @@ -861,6 +890,21 @@ def activate_arg(self, arg_name: str) -> bool: return True return False + def register_default_extra_arg( + self, + name: str, + factory: Callable[["SequenceInfo"], Optional[torch.Tensor]], + ) -> None: + """Register a callable default factory for an extra argument. + + ``factory`` receives this ``SequenceInfo`` instance and returns the + default value for ``name``. 
It is invoked at the start of every + ``nest_sequences`` call so that initialization-time forward passes + (e.g. ``resize_kv_cache``, CUDA-graph warmup) always receive a valid + tensor for ``name`` even when no per-request data is provided. + """ + self._default_extra_arg_factories[name] = factory + def to(self, *args, **kwargs) -> None: # Move the InputBuffer (which recreates views automatically) self._input_buffer.to(*args, **kwargs) @@ -1118,6 +1162,10 @@ def nest_sequences( ### UPDATE EXTRA INPUTS #################################################################### self._extra_args = {} + # Seed with defaults first (callable factories receive ``self`` so they can + # access current shape info via unflatten()), then let per-request values override. + for key, factory in self._default_extra_arg_factories.items(): + self._store_extra_arg(key, factory(self)) for key, value in extra_args.items(): self._store_extra_arg(key, value) @@ -1842,6 +1890,7 @@ def attention_op( *meta_extra,# metadata about the sequences as returned by the prepare_metadata op *caches, # contains layer-specific caches per provided cache initializers *constants, # basic arguments (int, float, str, None) added as CONSTANTS in the graph + **dynamic, # optional dynamic tensor kwargs forwarded from the source attention node ) -> torch.Tensor: ... ``` @@ -1854,6 +1903,11 @@ def attention_op( """ raise NotImplementedError + @classmethod + def supports_shared_kv(cls) -> bool: + """Whether this backend supports shared-KV cache aliasing.""" + return False + @classmethod @abstractmethod def get_standard_metadata_args(cls) -> List[str]: @@ -1916,11 +1970,34 @@ def get_cache_initializers( def get_constants(cls, source_attn_node: Node) -> List[Constant]: """Provide a list of constant arguments to be passed to the attention op. - The constant arguments are passed to the attention op as additional arguments after the - caches. The constants are expected to be of type int, float, str, or None. 
+ The constant arguments are passed to the attention op as positional arguments after the + caches. Dynamic inputs from ``get_dynamic_inputs`` are passed separately as kwargs. + Cached attention op signatures should keep these constant parameters before any dynamic + tensor inputs so the transform's mixed calling convention binds correctly. + The constants are expected to be of type int, float, str, or None. """ return [] + @classmethod + def get_layer_idx(cls, source_attn_node: Node) -> Optional[int]: + """Return the logical layer index associated with a source attention node, if any.""" + return _extract_optional_op_arg(source_attn_node, "layer_idx") + + @classmethod + def get_shared_kv_source_layer_idx(cls, source_attn_node: Node) -> Optional[int]: + """Return the KV source layer for a shared-KV attention node, if any.""" + return _extract_optional_op_arg(source_attn_node, "shared_kv_source_layer_idx") + + @classmethod + def get_dynamic_inputs(cls, source_attn_node: Node) -> Dict[str, Optional[Node]]: + """Provide backend-owned dynamic tensor inputs forwarded to the cached attention op. + + Returns a mapping from keyword argument name to the corresponding FX node + (or ``None``). These are passed as **kwargs** to the cached attention op, + so custom op signatures should place them after trailing constant parameters. + """ + return {} + @staticmethod def resolve_cache_dtype(dtype_config: str, fallback_dtype: torch.dtype) -> torch.dtype: """Resolve cache dtype from KvCacheConfig dtype string to torch.dtype. diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py index e7ab3a61a3e..d231de35136 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from functools import partial from typing import Callable, List import torch @@ -165,10 +166,15 @@ def _template_moe_alltoall( def _resolve_torch_fn(act_fn: ActivationType) -> Callable[[torch.Tensor], torch.Tensor]: """ Returns an elementwise activation callable matching the given activation function. - Supported: ActivationType.Silu, ActivationType.Swiglu, ActivationType.Relu2 + Supported: ActivationType.Silu, ActivationType.Swiglu, ActivationType.Relu2, ActivationType.Gelu """ - assert act_fn in [ActivationType.Silu, ActivationType.Swiglu, ActivationType.Relu2], ( - f"Unsupported activation '{ActivationType(act_fn).name}'. Use 'silu', 'swiglu' or 'relu2'." + assert act_fn in [ + ActivationType.Silu, + ActivationType.Swiglu, + ActivationType.Relu2, + ActivationType.Gelu, + ], ( + f"Unsupported activation '{ActivationType(act_fn).name}'. Use 'silu', 'swiglu', 'relu2', or 'gelu'." ) torch_fn = None if act_fn == ActivationType.Silu or act_fn == ActivationType.Swiglu: @@ -179,6 +185,8 @@ def relu2(x: torch.Tensor) -> torch.Tensor: return torch.square(F.relu(x)) torch_fn = relu2 + elif act_fn == ActivationType.Gelu: + torch_fn = partial(F.gelu, approximate="tanh") return torch_fn diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py index 39d1359dece..0060cb2e78a 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py @@ -401,12 +401,15 @@ def trtllm_moe_fused( activation_type = ActivationType.Swiglu if is_gated_mlp: - # Gated MLP uses Silu: silu(x @ w1.T) * (x @ w3.T) + # Gated MLP accepts either SiLU/SwiGLU or GELU/GEGLU style gating. 
if act_fn in [ActivationType.Silu, ActivationType.Swiglu]: activation_type = ActivationType.Swiglu + elif act_fn in [ActivationType.Gelu, ActivationType.Geglu]: + activation_type = ActivationType.Geglu else: raise ValueError( - f"Unsupported activation '{ActivationType(act_fn).name}' for gated_mlp. Use 'silu'." + f"Unsupported activation '{ActivationType(act_fn).name}' for gated_mlp. " + "Use 'silu' or 'gelu'." ) else: # For non-gated MLP with ReLU^2 @@ -466,14 +469,24 @@ def trtllm_moe_fused_fake( def _validate_mlp_style_and_act_fn(is_gated_mlp: bool, act_fn: int) -> None: - assert (is_gated_mlp and act_fn in [ActivationType.Silu, ActivationType.Swiglu]) or ( - not is_gated_mlp and act_fn in [ActivationType.Relu2, ActivationType.Silu] - ), ( + assert ( + is_gated_mlp + and act_fn + in [ActivationType.Silu, ActivationType.Swiglu, ActivationType.Gelu, ActivationType.Geglu] + ) or (not is_gated_mlp and act_fn in [ActivationType.Relu2, ActivationType.Silu]), ( f"Unsupported combination: is_gated_mlp='{is_gated_mlp}', act_fn='{act_fn}'. " - f"Supported combinations: gated mlp with silu or mlp with relu2 or silu." + f"Supported combinations: gated mlp with silu or gelu, or mlp with relu2 or silu." 
) +def _normalize_trtllm_act_fn(act_fn: int) -> int: + if act_fn == ActivationType.Silu: + return ActivationType.Swiglu + if act_fn == ActivationType.Gelu: + return ActivationType.Geglu + return act_fn + + @torch.library.custom_op("auto_deploy::trtllm_quant_fp8_moe_fused", mutates_args=()) def trtllm_quant_fp8_moe_fused( x: torch.Tensor, @@ -521,7 +534,7 @@ def trtllm_quant_fp8_moe_fused( """ _validate_mlp_style_and_act_fn(is_gated_mlp, act_fn) - act_fn = ActivationType.Swiglu if act_fn == ActivationType.Silu else act_fn + act_fn = _normalize_trtllm_act_fn(act_fn) # Store original shape and flatten to 2D x_shape = x.shape @@ -663,7 +676,7 @@ def trtllm_quant_nvfp4_moe_fused( assert fc2_weight_blockscale_fp8.ndim == 3, "fc2_weight_blockscale_fp8 must be 3D" _validate_mlp_style_and_act_fn(is_gated_mlp, act_fn) - act_fn = ActivationType.Swiglu if act_fn == ActivationType.Silu else act_fn + act_fn = _normalize_trtllm_act_fn(act_fn) # quant_scales is described by this code: # https://github.com/NVIDIA/TensorRT-LLM/blob/c9771ebb997683c08b26bbba796a7fc6aff09d93/cpp/tensorrt_llm/thop/moeOp.cpp#L1015 @@ -798,7 +811,7 @@ def trtllm_quant_finegrained_fp8_moe_fused( Output tensor of shape (B, H) or (B, S, H) """ _validate_mlp_style_and_act_fn(is_gated_mlp, act_fn) - act_fn = ActivationType.Swiglu if act_fn == ActivationType.Silu else act_fn + act_fn = _normalize_trtllm_act_fn(act_fn) x_shape = x.shape x2d = x.view(-1, x_shape[-1]) @@ -934,6 +947,11 @@ def trtllm_nvfp4_trtllm_gen_moe_fused( apply_routing_on_input: bool = False, ) -> torch.Tensor: _validate_mlp_style_and_act_fn(is_gated_mlp, act_fn) + if act_fn in (ActivationType.Gelu, ActivationType.Geglu): + raise ValueError( + f"NVFP4 TRTLLM-Gen MoE does not support activation " + f"'{ActivationType(act_fn).name}'. Only Silu/Swiglu and Relu2 are supported." 
+ ) x_shape = x.shape x2d = x.view(-1, x_shape[-1]) diff --git a/tensorrt_llm/_torch/auto_deploy/export/export.py b/tensorrt_llm/_torch/auto_deploy/export/export.py index 79313a28bb9..730576528b6 100644 --- a/tensorrt_llm/_torch/auto_deploy/export/export.py +++ b/tensorrt_llm/_torch/auto_deploy/export/export.py @@ -14,7 +14,7 @@ from ..utils._graph import canonicalize_graph, lift_to_meta, load_buffers_and_params, tree_to from ..utils.logger import ad_logger -from ..utils.node_utils import is_op +from ..utils.node_utils import get_op_schema, is_op from .interface import apply_export_patches if TYPE_CHECKING: @@ -276,7 +276,7 @@ def _expand_moe_experts_in_graph( # Collect indices of List[Tensor] arguments from the op schema – these # are the per-expert weight / scale lists. op = node.target - schema = op._schema if hasattr(op, "_schema") else next(iter(op._schemas.values())) + schema = get_op_schema(op) _tensor_list_types = ("Tensor[]", "List[Tensor]") list_arg_indices = [ i diff --git a/tensorrt_llm/_torch/auto_deploy/llm.py b/tensorrt_llm/_torch/auto_deploy/llm.py index 905efe1135f..b4d6e7bb5f4 100644 --- a/tensorrt_llm/_torch/auto_deploy/llm.py +++ b/tensorrt_llm/_torch/auto_deploy/llm.py @@ -1,4 +1,5 @@ import types +from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import torch @@ -151,6 +152,11 @@ def _build_model(self): # _autodeploy backend. super()._build_model() + # AutoDeploy resolves HF checkpoints through the factory prefetch path rather than the + # shared model loader. Preserve that resolved snapshot path for downstream utilities such + # as multimodal lm-eval, which read config.json from ``_hf_model_dir``. 
+ self._hf_model_dir = Path(self.factory.model) + # now correct input processor assert isinstance(self.input_processor, DefaultInputProcessor) assert self.tokenizer is None or isinstance(self.tokenizer, TransformersTokenizer) diff --git a/tensorrt_llm/_torch/auto_deploy/mlir/codegen/triton_emitter.py b/tensorrt_llm/_torch/auto_deploy/mlir/codegen/triton_emitter.py index 96ef220ef58..b8790c1b018 100644 --- a/tensorrt_llm/_torch/auto_deploy/mlir/codegen/triton_emitter.py +++ b/tensorrt_llm/_torch/auto_deploy/mlir/codegen/triton_emitter.py @@ -82,12 +82,12 @@ def _cleanup_temp_files(): "ad.rsqrt": lambda a: f"(1.0 / tl.sqrt({a}))", "ad.sqrt": lambda a: f"tl.sqrt({a})", "ad.silu": lambda a: f"({a} * tl.sigmoid({a}))", - "ad.gelu": lambda a: f"({a} * 0.5 * (1.0 + tl.math.erf({a} * 0.7071067811865476)))", + "ad.gelu": lambda a: f"({a} * 0.5 * (1.0 + tl.extra.cuda.libdevice.erf({a} * 0.7071067811865476)))", "ad.relu": lambda a: f"tl.maximum({a}, 0)", - "ad.tanh": lambda a: f"tl.math.tanh({a})", + "ad.tanh": lambda a: f"tl.extra.cuda.libdevice.tanh({a})", "ad.sigmoid": lambda a: f"tl.sigmoid({a})", - "ad.exp": lambda a: f"tl.math.exp({a})", - "ad.softplus": lambda a: f"tl.math.log(1.0 + tl.math.exp({a}))", + "ad.exp": lambda a: f"tl.extra.cuda.libdevice.exp({a})", + "ad.softplus": lambda a: f"tl.extra.cuda.libdevice.log(1.0 + tl.extra.cuda.libdevice.exp({a}))", "ad.reduce_sum": lambda a: f"tl.sum({a}, 0)", "ad.reduce_mean": lambda a, ncols: f"(tl.sum({a}, 0) * (1.0 / {ncols}))", "ad.splat": None, # handled specially — just inline the scalar value @@ -253,7 +253,21 @@ def generate_kernel_from_subgraph(subgraph) -> Callable: # last dim < N_COLS, e.g. a gating scalar of shape (-1, 1) in a subgraph # whose row width is 2048). Both categories need a load pattern that # avoids reading past the end of the actual data. + # Scalar-like inputs (rank-0 OR broadcast with last-dim 1, e.g. shape [1]) + # need a single-element load; Triton broadcasts the scalar automatically. 
broadcast_flags = [_is_broadcast_input(inp, max_rank) for inp in subgraph.inputs] + scalar_flags = [] + for i, inp in enumerate(subgraph.inputs): + rank = _get_tensor_rank(inp) + if rank == 0: + scalar_flags.append(True) + elif broadcast_flags[i] and isinstance(inp.type, TensorType): + shape = inp.type.get_shape() + # Broadcast input whose last dim is 1 (e.g. layer_scalar shape [1]) + # must be loaded as a single element, not a vector. + scalar_flags.append(not shape or shape[-1] == 1) + else: + scalar_flags.append(False) narrow_flags = [] for inp in subgraph.inputs: if isinstance(inp.type, TensorType): @@ -273,7 +287,10 @@ def generate_kernel_from_subgraph(subgraph) -> Callable: # Broadcast (1D) inputs (e.g. weights) are offset by group only: # ptr + pid_group * N_COLS + offs for i, inp in enumerate(subgraph.inputs): - if broadcast_flags[i]: + if scalar_flags[i]: + # Scalar-like input (rank-0 or last dim 1): load a single element; Triton broadcasts it. + body_lines.append(f" v{i} = tl.load(in{i}_ptr).to(tl.float32)") + elif broadcast_flags[i]: if grouped_mode: body_lines.append( f" v{i} = tl.load(in{i}_ptr + group_off + offs, mask=mask).to(tl.float32)" @@ -320,7 +337,9 @@ def generate_kernel_from_subgraph(subgraph) -> Callable: else: exp_val = float(str(exp_attr)) result_name = f"t{temp_counter}" - body_lines.append(f" {result_name} = tl.math.pow({base_name}, {exp_val})") + body_lines.append( + f" {result_name} = tl.extra.cuda.libdevice.pow({base_name}, {exp_val})" + ) temp_counter += 1 for r in op.results: val_names[id(r)] = result_name @@ -530,6 +549,20 @@ def generate_kernel_from_subgraph(subgraph) -> Callable: "import torch\n\n" + kernel_src + "\n" + launcher_src ) + import logging as _logging + import os as _os + + _logging.getLogger("mlir_codegen").info("Generated kernel %s:\n%s", sg_hash, full_src) + + # Optional: dump kernel source to a directory for offline inspection. + # Controlled by the AD_DUMP_KERNELS_DIR environment variable. 
+ _kernel_dump_dir = _os.environ.get("AD_DUMP_KERNELS_DIR") + if _kernel_dump_dir: + _dump_path = _os.path.join(_kernel_dump_dir, f"triton_gen_{sg_hash}.py") + _os.makedirs(_kernel_dump_dir, exist_ok=True) + with open(_dump_path, "w") as _f: + _f.write(full_src) + import importlib.util import tempfile diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py b/tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py index 77b5b8c91f1..35a485ba2d9 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py @@ -1,4 +1,6 @@ from .modeling_deepseek import DeepSeekV3ForCausalLM +from .modeling_gemma3n import Gemma3nForCausalLM, Gemma3nForConditionalGeneration +from .modeling_gemma4 import Gemma4ForCausalLM, Gemma4ForConditionalGeneration from .modeling_glm4_moe_lite import Glm4MoeLiteForCausalLM from .modeling_kimi_k2 import KimiK2ForCausalLM, KimiK25ForConditionalGeneration from .modeling_mistral3 import Mistral3ForConditionalGenerationAD, Mistral4ForCausalLM @@ -8,6 +10,10 @@ __all__ = ( "DeepSeekV3ForCausalLM", + "Gemma3nForCausalLM", + "Gemma3nForConditionalGeneration", + "Gemma4ForCausalLM", + "Gemma4ForConditionalGeneration", "Glm4MoeLiteForCausalLM", "KimiK2ForCausalLM", "KimiK25ForConditionalGeneration", diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma3n.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma3n.py new file mode 100644 index 00000000000..b2ab5965d22 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma3n.py @@ -0,0 +1,852 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Slimmed down Gemma 3n text implementation for AutoDeploy export. + +This implementation follows the Hugging Face Gemma 3n text stack closely while +keeping only the prefill path needed by AutoDeploy. The outer +``Gemma3nForConditionalGeneration`` wrapper preserves the HF text checkpoint +layout (``model.language_model.*`` + ``lm_head``) and drops unsupported +vision/audio tower weights at load time. The forward path intentionally +supports only text-only export. +""" + +import copy +import math +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch import nn +from transformers.activations import ACT2FN +from transformers.generation import GenerationMixin +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel +from transformers.models.gemma3n.configuration_gemma3n import ( + Gemma3nAudioConfig, + Gemma3nConfig, + Gemma3nTextConfig, + Gemma3nVisionConfig, +) +from transformers.utils import ModelOutput + +from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory + + +def _build_rope_cache( + config: Gemma3nTextConfig, +) -> Tuple[torch.Tensor, torch.Tensor, float]: + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type", "default")) + else: + rope_type = "default" + + inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, device=None) + positions = torch.arange(config.max_position_embeddings, 
dtype=inv_freq.dtype) + freqs = torch.outer(positions, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + return emb.cos(), emb.sin(), attention_scaling + + +class Gemma3nRMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True): + super().__init__() + self.eps = eps + if with_scale: + self.weight = nn.Parameter(torch.ones(dim)) + else: + self.register_buffer("weight", torch.ones(dim), persistent=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.auto_deploy.torch_rmsnorm(x, self.weight, self.eps) + + +class Gemma3nTextScaledWordEmbedding(nn.Embedding): + def __init__( + self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float + ): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False) + + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + return super().forward(input_ids) * self.embed_scale.to(dtype=self.weight.dtype) + + +class Gemma3nTextLaurelBlock(nn.Module): + def __init__(self, config: Gemma3nTextConfig): + super().__init__() + self.linear_left = nn.Linear(config.hidden_size, config.laurel_rank, bias=False) + self.linear_right = nn.Linear(config.laurel_rank, config.hidden_size, bias=False) + self.post_laurel_norm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + laurel_hidden_states = self.linear_left(hidden_states) + laurel_hidden_states = self.linear_right(laurel_hidden_states) + laurel_hidden_states = self.post_laurel_norm(laurel_hidden_states) + return hidden_states + laurel_hidden_states + + +class Gemma3nTextMLP(nn.Module): + def __init__(self, config: Gemma3nTextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size[layer_idx] + self.gate_proj = nn.Linear(self.hidden_size, 
self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_activation] + self.activation_sparsity = config.activation_sparsity_pattern[layer_idx] + if self.activation_sparsity > 0.0: + normal_dist = torch.distributions.normal.Normal(0, 1) + std_multiplier = normal_dist.icdf( + torch.tensor(self.activation_sparsity, dtype=torch.float32) + ) + self.register_buffer( + "activation_sparsity_std_multiplier", std_multiplier, persistent=False + ) + else: + self.register_buffer( + "activation_sparsity_std_multiplier", + torch.tensor(0.0, dtype=torch.float32), + persistent=False, + ) + + def _gaussian_topk(self, inputs: torch.Tensor) -> torch.Tensor: + std_multiplier = self.activation_sparsity_std_multiplier.to( + device=inputs.device, dtype=inputs.dtype + ) + inputs_mean = torch.mean(inputs, dim=-1, keepdim=True) + inputs_std = torch.std(inputs, dim=-1, keepdim=True, unbiased=False) + cutoff = inputs_mean + inputs_std * std_multiplier + return torch.relu(inputs - cutoff) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + gate_proj = self.gate_proj(hidden_states) + if self.activation_sparsity > 0.0: + gate_proj = self._gaussian_topk(gate_proj) + activations = self.act_fn(gate_proj) + up_proj = self.up_proj(hidden_states) + return self.down_proj(activations * up_proj) + + +class Gemma3nTextAltUp(nn.Module): + def __init__(self, config: Gemma3nTextConfig): + super().__init__() + self.config = config + self.correct_output_scale = nn.Parameter(torch.zeros(config.hidden_size)) + self.correction_coefs = nn.Linear( + config.altup_num_inputs, config.altup_num_inputs, bias=False + ) + self.prediction_coefs = nn.Linear( + config.altup_num_inputs, config.altup_num_inputs**2, bias=False + ) + self.modality_router = nn.Linear(config.hidden_size, config.altup_num_inputs, bias=False) + 
self.router_norm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.register_buffer( + "router_input_scale", torch.tensor(config.hidden_size**-1.0), persistent=False + ) + + def compute_router_modalities(self, hidden_states: torch.Tensor) -> torch.Tensor: + router_inputs = self.router_norm(hidden_states) * self.router_input_scale + routed = self.modality_router(router_inputs) + return torch.tanh(routed.float()).type_as(hidden_states) + + def predict(self, hidden_states: torch.Tensor) -> torch.Tensor: + modalities = self.compute_router_modalities(hidden_states[self.config.altup_active_idx]) + all_coefs = self.prediction_coefs(modalities).reshape( + *modalities.shape[:-1], self.config.altup_num_inputs, self.config.altup_num_inputs + ) + all_coefs = all_coefs.permute(0, 1, 3, 2) + predictions = torch.matmul(hidden_states.permute(1, 2, 3, 0), all_coefs) + predictions = predictions.permute(3, 0, 1, 2) + return (predictions + hidden_states).contiguous().type_as(hidden_states) + + def correct(self, predictions: torch.Tensor, activated: torch.Tensor) -> torch.Tensor: + modalities = self.compute_router_modalities(activated) + innovation = activated - predictions[self.config.altup_active_idx] + innovation = innovation.repeat(self.config.altup_num_inputs, 1, 1, 1) + all_coefs = self.correction_coefs(modalities) + 1.0 + all_coefs = all_coefs.permute(2, 0, 1).unsqueeze(-1) + corrected = torch.mul(innovation, all_coefs) + return (corrected + predictions).contiguous().type_as(activated) + + def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor: + return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as( + corrected + ) + + +class Gemma3nTextRotaryEmbedding(nn.Module): + def __init__(self, config: Gemma3nTextConfig): + super().__init__() + cos, sin, attention_scaling = _build_rope_cache(config) + self.register_buffer("_ad_cos_cached", cos * attention_scaling, persistent=False) + 
self.register_buffer("_ad_sin_cached", sin * attention_scaling, persistent=False) + + def forward( + self, x: torch.Tensor, position_ids: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + del position_ids + cos = self._ad_cos_cached.to(dtype=x.dtype, device=x.device) + sin = self._ad_sin_cached.to(dtype=x.dtype, device=x.device) + return cos, sin + + +def _slice_rope_cache( + position_embeddings: Tuple[torch.Tensor, torch.Tensor], position_ids: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + cos, sin = position_embeddings + return cos[position_ids], sin[position_ids] + + +class Gemma3nMultimodalEmbedder(nn.Module): + def __init__( + self, + multimodal_config: Gemma3nAudioConfig | Gemma3nVisionConfig, + text_config: Gemma3nTextConfig, + ): + super().__init__() + self.multimodal_hidden_size = multimodal_config.hidden_size + self.eps = multimodal_config.rms_norm_eps + self.vocab_offset = multimodal_config.vocab_offset + self.vocab_size = multimodal_config.vocab_size + self.text_hidden_size = text_config.hidden_size + + self.embedding = nn.Embedding(self.vocab_size, self.multimodal_hidden_size) + self.hard_embedding_norm = Gemma3nRMSNorm(self.multimodal_hidden_size, eps=self.eps) + self.soft_embedding_norm = Gemma3nRMSNorm(self.multimodal_hidden_size, eps=self.eps) + self.embedding_projection = nn.Linear( + self.multimodal_hidden_size, self.text_hidden_size, bias=False + ) + self.embedding_post_projection_norm = Gemma3nRMSNorm( + self.text_hidden_size, eps=self.eps, with_scale=False + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if (input_ids is None) == (inputs_embeds is None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + if inputs_embeds is not None: + embeddings = self.soft_embedding_norm(inputs_embeds) + else: + embeddings = self.embedding(input_ids - self.vocab_offset) + embeddings = 
self.hard_embedding_norm(embeddings) + embeddings = self.embedding_projection(embeddings) + return self.embedding_post_projection_norm(embeddings) + + +class Gemma3nTextAttention(nn.Module): + def __init__(self, config: Gemma3nTextConfig, layer_idx: int): + super().__init__() + self.layer_idx = layer_idx + self.config = config + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.is_sliding = config.layer_types[layer_idx] == "sliding_attention" + self.sliding_window = config.sliding_window if self.is_sliding else None + first_kv_shared_layer_idx = config.num_hidden_layers - config.num_kv_shared_layers + self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0 + prev_layers = config.layer_types[:first_kv_shared_layer_idx] + if self.is_kv_shared_layer: + self.kv_shared_layer_index = ( + len(prev_layers) - 1 - prev_layers[::-1].index(config.layer_types[layer_idx]) + ) + else: + self.kv_shared_layer_index = None + + self.q_proj = nn.Linear( + config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) + + self.q_norm = Gemma3nRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = Gemma3nRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.v_norm = Gemma3nRMSNorm(self.head_dim, eps=config.rms_norm_eps, with_scale=False) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + batch_size, seq_len, _ = hidden_states.shape + query_states = self.q_proj(hidden_states).view( + 
batch_size, seq_len, self.num_heads, self.head_dim + ) + key_states = self.k_proj(hidden_states).view( + batch_size, seq_len, self.num_kv_heads, self.head_dim + ) + value_states = self.v_proj(hidden_states).view( + batch_size, seq_len, self.num_kv_heads, self.head_dim + ) + + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + value_states = self.v_norm(value_states) + + cos, sin = position_embeddings + query_states, key_states = torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin( + query_states, + key_states, + cos, + sin, + 2, + ) + + attn_output = torch.ops.auto_deploy.torch_attention( + query_states, + key_states, + value_states, + None, + 0.0, + True, + 1.0, + None, + self.sliding_window, + None, + "bsnd", + self.layer_idx, + self.kv_shared_layer_index if self.is_kv_shared_layer else None, + ) + attn_output = attn_output.reshape(batch_size, seq_len, -1) + return self.o_proj(attn_output) + + +class Gemma3nTextDecoderLayer(nn.Module): + def __init__(self, config: Gemma3nTextConfig, layer_idx: int): + super().__init__() + self.layer_idx = layer_idx + self.attention_type = config.layer_types[layer_idx] + self.self_attn = Gemma3nTextAttention(config, layer_idx) + self.mlp = Gemma3nTextMLP(config, layer_idx=layer_idx) + self.input_layernorm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_feedforward_layernorm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_feedforward_layernorm = Gemma3nRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.act_fn = ACT2FN[config.hidden_activation] + + self.altup = Gemma3nTextAltUp(config) + self.laurel = Gemma3nTextLaurelBlock(config) + self.per_layer_input_gate = nn.Linear( + config.hidden_size, config.hidden_size_per_layer_input, bias=False + ) + self.per_layer_projection = nn.Linear( + config.hidden_size_per_layer_input, config.hidden_size, 
bias=False + ) + self.post_per_layer_input_norm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings_global: Tuple[torch.Tensor, torch.Tensor], + position_embeddings_local: Tuple[torch.Tensor, torch.Tensor], + per_layer_input: torch.Tensor, + ) -> torch.Tensor: + predictions = self.altup.predict(hidden_states) + active_idx = getattr(self.altup, "active_idx", self.altup.config.altup_active_idx) + active_prediction = predictions[active_idx] + active_prediction_normed = self.input_layernorm(active_prediction) + laurel_output = self.laurel(active_prediction_normed) + + if self.self_attn.is_sliding: + position_embeddings = position_embeddings_local + else: + position_embeddings = position_embeddings_global + + attn = self.self_attn(active_prediction_normed, position_embeddings) + attn = self.post_attention_layernorm(attn) + + attn_gated = active_prediction + attn + attn_laurel = (attn_gated + laurel_output) / math.sqrt(2.0) + + attn_norm = self.pre_feedforward_layernorm(attn_laurel) + attn_ffw = self.mlp(attn_norm) + attn_ffw_norm = self.post_feedforward_layernorm(attn_ffw) + corrected_predictions = self.altup.correct(predictions, attn_laurel + attn_ffw_norm) + + first_prediction = corrected_predictions[active_idx].clone() + if self.altup.config.altup_correct_scale: + first_prediction = self.altup.scale_corrected_output(first_prediction) + + first_prediction = self.per_layer_input_gate(first_prediction) + first_prediction = self.act_fn(first_prediction) + first_prediction = torch.multiply(first_prediction, per_layer_input) + first_prediction = self.per_layer_projection(first_prediction) + first_prediction = self.post_per_layer_input_norm(first_prediction) + for idx in range(corrected_predictions.shape[0]): + if idx != active_idx: + corrected_predictions[idx] += first_prediction + return corrected_predictions + + +@dataclass +class Gemma3nTextOutput(ModelOutput): + last_hidden_state: 
Optional[torch.FloatTensor] = None + + +@dataclass +class Gemma3nCausalLMOutput(ModelOutput): + logits: Optional[torch.FloatTensor] = None + + +@dataclass +class Gemma3nConditionalOutput(ModelOutput): + logits: Optional[torch.FloatTensor] = None + + +class Gemma3nTextPreTrainedModel(PreTrainedModel): + config_class = Gemma3nTextConfig + base_model_prefix = "model" + _no_split_modules = ["Gemma3nTextDecoderLayer"] + supports_gradient_checkpointing = False + + def _init_weights(self, module: nn.Module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Gemma3nTextAltUp): + module.correct_output_scale.data.zero_() + + +class Gemma3nPreTrainedModel(PreTrainedModel): + config_class = Gemma3nConfig + base_model_prefix = "model" + _no_split_modules = ["Gemma3nTextDecoderLayer"] + supports_gradient_checkpointing = False + + def _init_weights(self, module: nn.Module): + std = getattr(self.config, "initializer_range", 0.02) + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Gemma3nTextAltUp): + module.correct_output_scale.data.zero_() + + +class Gemma3nTextModel(Gemma3nTextPreTrainedModel): + def __init__(self, config: Gemma3nTextConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.embed_tokens = Gemma3nTextScaledWordEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + embed_scale=config.hidden_size**0.5, + ) + 
self.layers = nn.ModuleList( + [ + Gemma3nTextDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Gemma3nTextRotaryEmbedding(config) + + local_config = copy.deepcopy(config) + local_config.rope_theta = local_config.rope_local_base_freq + local_config.rope_scaling = {"rope_type": "default"} + self.rotary_emb_local = Gemma3nTextRotaryEmbedding(local_config) + + self.hidden_size = config.hidden_size + self.hidden_size_per_layer_input = config.hidden_size_per_layer_input + self.embed_tokens_per_layer = Gemma3nTextScaledWordEmbedding( + config.vocab_size_per_layer_input, + config.num_hidden_layers * config.hidden_size_per_layer_input, + self.padding_idx, + embed_scale=config.hidden_size_per_layer_input**0.5, + ) + self.per_layer_model_projection = nn.Linear( + config.hidden_size, + config.num_hidden_layers * config.hidden_size_per_layer_input, + bias=False, + ) + self.per_layer_projection_norm = Gemma3nRMSNorm( + config.hidden_size_per_layer_input, eps=config.rms_norm_eps + ) + self.altup_projections = nn.ModuleList( + [ + nn.Linear(config.hidden_size, config.hidden_size, bias=False) + for _ in range(1, config.altup_num_inputs) + ] + ) + self.altup_unembed_projections = nn.ModuleList( + [ + nn.Linear(config.hidden_size, config.hidden_size, bias=False) + for _ in range(1, config.altup_num_inputs) + ] + ) + self.register_buffer( + "per_layer_projection_scale", torch.tensor(config.hidden_size**-0.5), persistent=False + ) + self.register_buffer( + "per_layer_input_scale", torch.rsqrt(torch.tensor(2.0)), persistent=False + ) + self.register_buffer("_ad_eps", torch.tensor(1e-5), persistent=False) + self._register_load_state_dict_pre_hook(self._slice_reduced_layer_weights) + self.post_init() + + def _slice_reduced_layer_weights( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + del 
local_metadata, strict, missing_keys, unexpected_keys, error_msgs + keys_to_params = { + prefix + "embed_tokens_per_layer.weight": self.embed_tokens_per_layer.weight, + prefix + "per_layer_model_projection.weight": self.per_layer_model_projection.weight, + } + for state_key, target_param in keys_to_params.items(): + if state_key not in state_dict: + continue + checkpoint_weight = state_dict[state_key] + if checkpoint_weight.ndim != 2: + continue + if ( + checkpoint_weight.shape[0] == target_param.shape[0] + and checkpoint_weight.shape[1] > target_param.shape[1] + ): + state_dict[state_key] = checkpoint_weight[:, : target_param.shape[1]] + elif ( + checkpoint_weight.shape[0] > target_param.shape[0] + and checkpoint_weight.shape[1] == target_param.shape[1] + ): + state_dict[state_key] = checkpoint_weight[: target_param.shape[0]] + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def get_per_layer_inputs(self, input_ids: torch.LongTensor) -> torch.Tensor: + return self.embed_tokens_per_layer(input_ids).reshape( + *input_ids.shape, + self.config.num_hidden_layers, + self.hidden_size_per_layer_input, + ) + + def project_per_layer_inputs( + self, + inputs_embeds: torch.Tensor, + per_layer_inputs: Optional[torch.Tensor], + ) -> torch.Tensor: + per_layer_projection = self.per_layer_model_projection(inputs_embeds) + per_layer_projection = per_layer_projection * self.per_layer_projection_scale.to( + dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + per_layer_projection = per_layer_projection.reshape( + *inputs_embeds.shape[:-1], + self.config.num_hidden_layers, + self.hidden_size_per_layer_input, + ) + per_layer_projection = self.per_layer_projection_norm(per_layer_projection) + + if per_layer_inputs is None: + return per_layer_projection + + if per_layer_projection.shape != per_layer_inputs.shape: + per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :] + 
+ return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.to( + dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + per_layer_inputs: Optional[torch.Tensor] = None, + **kwargs, + ) -> Gemma3nTextOutput: + del kwargs + assert position_ids is not None, "position_ids must be provided" + if (input_ids is None) == (inputs_embeds is None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if input_ids is not None: + inputs_embeds = self.embed_tokens(input_ids) + per_layer_inputs = self.get_per_layer_inputs(input_ids) + + assert inputs_embeds is not None + per_layer_inputs = self.project_per_layer_inputs(inputs_embeds, per_layer_inputs) + position_embeddings_global = _slice_rope_cache( + self.rotary_emb(inputs_embeds, position_ids), position_ids + ) + position_embeddings_local = _slice_rope_cache( + self.rotary_emb_local(inputs_embeds, position_ids), position_ids + ) + + target_magnitude = torch.mean(inputs_embeds**2, dim=-1, keepdim=True) ** 0.5 + hidden_states = [inputs_embeds] + for projection in self.altup_projections: + current_hidden_state = projection(inputs_embeds).to(dtype=inputs_embeds.dtype) + new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) + new_magnitude = torch.sqrt( + torch.maximum( + new_magnitude, + self._ad_eps.to(device=inputs_embeds.device, dtype=new_magnitude.dtype), + ) + ) + current_hidden_state = current_hidden_state * target_magnitude / new_magnitude + hidden_states.append(current_hidden_state) + hidden_states = torch.stack(hidden_states, dim=0) + + for decoder_layer in self.layers: + layer_per_input = per_layer_inputs[:, :, decoder_layer.layer_idx, :] + hidden_states = decoder_layer( + hidden_states, + position_embeddings_global, + position_embeddings_local, + layer_per_input, + ) + + 
target_magnitude = torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True) ** 0.5 + reduced_hidden_states = [hidden_states[0]] + for i, projection in enumerate(self.altup_unembed_projections, start=1): + current_hidden_state = projection(hidden_states[i]).to(dtype=inputs_embeds.dtype) + new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) + new_magnitude = torch.sqrt( + torch.maximum( + new_magnitude, + self._ad_eps.to(device=inputs_embeds.device, dtype=new_magnitude.dtype), + ) + ) + current_hidden_state = current_hidden_state * target_magnitude / new_magnitude + reduced_hidden_states.append(current_hidden_state) + + hidden_states = torch.mean(torch.stack(reduced_hidden_states), dim=0) + hidden_states = self.norm(hidden_states) + return Gemma3nTextOutput(last_hidden_state=hidden_states) + + +class Gemma3nForCausalLM(Gemma3nTextPreTrainedModel, GenerationMixin): + config_class = Gemma3nTextConfig + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: Gemma3nTextConfig, **kwargs): + del kwargs + super().__init__(config) + self.model = Gemma3nTextModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, value): + self.lm_head = value + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Gemma3nCausalLMOutput: + outputs = self.model( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + logits = self.lm_head(outputs.last_hidden_state) + if self.config.final_logit_softcapping is not None: + logits = 
logits / self.config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * self.config.final_logit_softcapping + return Gemma3nCausalLMOutput(logits=logits) + + +class Gemma3nModel(Gemma3nPreTrainedModel): + def __init__(self, config: Gemma3nConfig): + super().__init__(config) + self.vocab_size_per_layer_input = config.text_config.vocab_size_per_layer_input + self.vision_tower = nn.Module() + self.language_model = Gemma3nTextModel(config.text_config) + self.audio_tower = nn.Module() + self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config, config.text_config) + self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config, config.text_config) + self._register_load_state_dict_pre_hook(self._drop_unsupported_multimodal_tower_weights) + self.post_init() + + @staticmethod + def _drop_unsupported_multimodal_tower_weights( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + del local_metadata, strict, missing_keys, unexpected_keys, error_msgs + unsupported_prefixes = ( + prefix + "vision_tower.", + prefix + "audio_tower.", + ) + for key in list(state_dict): + if key.startswith(unsupported_prefixes): + state_dict.pop(key) + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.language_model = decoder + + def get_decoder(self): + return self.language_model + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.Tensor] = None, + input_features: Optional[torch.Tensor] = None, + input_features_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Gemma3nTextOutput: + del kwargs + del input_features_mask + assert position_ids is not None, "position_ids must be provided" + if 
pixel_values is not None or input_features is not None: + raise NotImplementedError( + "Gemma3n multimodal inputs are not supported by the current AutoDeploy export path. " + "Use text-only prompts for this onboarding." + ) + if (input_ids is None) == (inputs_embeds is None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + per_layer_inputs = None + if input_ids is not None: + inputs_embeds = self.get_input_embeddings()(input_ids) + per_layer_inputs_mask = torch.logical_and( + input_ids >= 0, input_ids < self.vocab_size_per_layer_input + ) + per_layer_inputs_tokens = torch.where( + per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids) + ) + per_layer_inputs = self.language_model.get_per_layer_inputs(per_layer_inputs_tokens) + + return self.language_model( + input_ids=None, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + per_layer_inputs=per_layer_inputs, + ) + + +class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin): + config_class = Gemma3nConfig + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: Gemma3nConfig, **kwargs): + del kwargs + super().__init__(config) + self.model = Gemma3nModel(config) + self.lm_head = nn.Linear( + config.text_config.hidden_size, config.text_config.vocab_size, bias=False + ) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, value): + self.lm_head = value + + def get_decoder(self): + return self.model.get_decoder() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.Tensor] = None, + input_features: Optional[torch.Tensor] = None, + input_features_mask: 
Optional[torch.Tensor] = None, + **kwargs, + ) -> Gemma3nConditionalOutput: + outputs = self.model( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + input_features=input_features, + input_features_mask=input_features_mask, + **kwargs, + ) + logits = self.lm_head(outputs.last_hidden_state) + if self.config.text_config.final_logit_softcapping is not None: + logits = logits / self.config.text_config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * self.config.text_config.final_logit_softcapping + return Gemma3nConditionalOutput(logits=logits) + + +AutoModelForCausalLMFactory.register_custom_model_cls("Gemma3nTextConfig", Gemma3nForCausalLM) +AutoModelForCausalLMFactory.register_custom_model_cls( + "Gemma3nConfig", Gemma3nForConditionalGeneration +) diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py new file mode 100644 index 00000000000..f74123db066 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py @@ -0,0 +1,2566 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Slimmed down Gemma 4 text implementation for AutoDeploy export. 
+ +This implementation follows the HuggingFace Gemma 4 text stack closely while +keeping only the prefill path needed by AutoDeploy. The outer +``Gemma4ForConditionalGeneration`` wrapper preserves the HF checkpoint layout +(``model.language_model.*``). The text model is exported, while the outer +wrapper remains eager so it can run Gemma4's multimodal vision merge path +before delegating to the exported language model. + +Key architectural features of Gemma 4 vs standard transformers: +- K=V attention on full-attention layers (v_proj is absent; k_proj output is + reused as value) +- Different head dimensions for full vs sliding attention (global_head_dim vs + head_dim) +- Proportional RoPE with partial_rotary_factor on full-attention layers +- Dense MLP running in parallel with Mixture-of-Experts (MoE) in every layer +- Per-layer scalar multiplier +- Final logit softcapping +""" + +import json +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Any, List, Optional, Sequence, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image +from tokenizers import Tokenizer +from torch import nn +from torch.fx import GraphModule +from transformers import AutoConfig, PretrainedConfig, PreTrainedTokenizerFast +from transformers.activations import ACT2FN +from transformers.generation import GenerationMixin +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ModelOutput, cached_file + +from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactoryRegistry +from tensorrt_llm._torch.auto_deploy.models.hf import ( + AutoModelForCausalLMFactory, + AutoModelForImageTextToTextFactory, +) +from tensorrt_llm._torch.utils import ActivationType +from tensorrt_llm.inputs.content_format import ContentFormat +from tensorrt_llm.inputs.registry import ( + MULTIMODAL_PLACEHOLDER_REGISTRY, + 
MultimodalPlaceholderMetadata, +) + +# --------------------------------------------------------------------------- +# Bundled config classes: these enable loading on transformers <5.3, where +# Gemma4 is not natively registered. +# --------------------------------------------------------------------------- + + +class Gemma4TextConfig(PretrainedConfig): + """Minimal Gemma4 text config for AutoDeploy.""" + + model_type = "gemma4_text" + + def __init__( + self, + vocab_size: int = 262_144, + hidden_size: int = 2816, + intermediate_size: int = 2112, + num_hidden_layers: int = 30, + num_attention_heads: int = 16, + num_key_value_heads: int = 8, + head_dim: int = 256, + global_head_dim: int = 512, + num_global_key_value_heads: int = 2, + hidden_activation: str = "gelu_pytorch_tanh", + max_position_embeddings: int = 131_072, + rms_norm_eps: float = 1e-6, + attention_bias: bool = False, + attention_dropout: float = 0.0, + attention_k_eq_v: bool = True, + sliding_window: int = 1024, + layer_types: Optional[list] = None, + rope_parameters: Optional[dict] = None, + final_logit_softcapping: Optional[float] = 30.0, + hidden_size_per_layer_input: int = 0, + num_kv_shared_layers: int = 0, + use_double_wide_mlp: bool = False, + use_bidirectional_attention: Optional[str] = "vision", + enable_moe_block: bool = True, + num_experts: Optional[int] = 128, + top_k_experts: Optional[int] = 8, + expert_intermediate_size: Optional[int] = 704, + stream_and_decode_in_f32: bool = True, + vocab_size_per_layer_input: int = 262_144, + routed_layer_pattern: Optional[list] = None, + pad_token_id: Optional[int] = 0, + eos_token_id=1, + bos_token_id: Optional[int] = 2, + tie_word_embeddings: bool = True, + initializer_range: float = 0.02, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + 
self.head_dim = head_dim + self.global_head_dim = global_head_dim + self.num_global_key_value_heads = num_global_key_value_heads + self.hidden_activation = hidden_activation + self.max_position_embeddings = max_position_embeddings + self.rms_norm_eps = rms_norm_eps + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.attention_k_eq_v = attention_k_eq_v + self.sliding_window = sliding_window + self.layer_types = layer_types or (["sliding_attention"] * num_hidden_layers) + self.rope_parameters = rope_parameters or { + "full_attention": { + "rope_type": "proportional", + "rope_theta": 1_000_000.0, + "partial_rotary_factor": 0.25, + }, + "sliding_attention": {"rope_type": "default", "rope_theta": 10_000.0}, + } + self.final_logit_softcapping = final_logit_softcapping + self.hidden_size_per_layer_input = hidden_size_per_layer_input + self.num_kv_shared_layers = num_kv_shared_layers + self.use_double_wide_mlp = use_double_wide_mlp + self.use_bidirectional_attention = use_bidirectional_attention + self.enable_moe_block = enable_moe_block + self.num_experts = num_experts + self.top_k_experts = top_k_experts + self.expert_intermediate_size = expert_intermediate_size + self.stream_and_decode_in_f32 = stream_and_decode_in_f32 + self.vocab_size_per_layer_input = vocab_size_per_layer_input + self.routed_layer_pattern = routed_layer_pattern + self.initializer_range = initializer_range + super().__init__( + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + bos_token_id=bos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class Gemma4VisionConfig(PretrainedConfig): + """Gemma4 vision config.""" + + model_type = "gemma4_vision" + + def __init__( + self, + hidden_size: int = 768, + intermediate_size: int = 3072, + num_hidden_layers: int = 16, + num_attention_heads: int = 12, + num_key_value_heads: int = 12, + head_dim: int = 64, + hidden_activation: str = "gelu_pytorch_tanh", + rms_norm_eps: float = 1e-6, + 
max_position_embeddings: int = 131_072, + attention_bias: bool = False, + attention_dropout: float = 0.0, + rope_parameters: Optional[dict] = None, + pooling_kernel_size: int = 3, + patch_size: int = 16, + position_embedding_size: int = 10 * 1024, + standardize: bool = False, + initializer_range: float = 0.02, + **kwargs, + ): + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.hidden_activation = hidden_activation + self.rms_norm_eps = rms_norm_eps + self.max_position_embeddings = max_position_embeddings + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.rope_parameters = rope_parameters or {"rope_type": "default", "rope_theta": 100.0} + self.pooling_kernel_size = pooling_kernel_size + self.patch_size = patch_size + self.position_embedding_size = position_embedding_size + self.standardize = standardize + self.initializer_range = initializer_range + super().__init__(**kwargs) + + +class Gemma4Config(PretrainedConfig): + """Top-level Gemma4 multimodal config.""" + + model_type = "gemma4" + + def __init__( + self, + text_config=None, + vision_config=None, + audio_config=None, + initializer_range: float = 0.02, + boi_token_id: int = 255_999, + eoi_token_id: int = 258_882, + image_token_id: int = 258_880, + video_token_id: int = 258_884, + audio_token_id: int = 258_881, + tie_word_embeddings: bool = True, + **kwargs, + ): + self.initializer_range = initializer_range + self.boi_token_id = boi_token_id + self.eoi_token_id = eoi_token_id + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.audio_token_id = audio_token_id + if text_config is None: + self.text_config = Gemma4TextConfig() + elif isinstance(text_config, dict): + self.text_config = Gemma4TextConfig(**text_config) + else: + self.text_config 
= text_config + + if vision_config is None: + self.vision_config = Gemma4VisionConfig() + elif isinstance(vision_config, dict): + self.vision_config = Gemma4VisionConfig(**vision_config) + else: + self.vision_config = vision_config + + self.audio_config = audio_config + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +AutoConfig.register("gemma4", Gemma4Config, exist_ok=True) +AutoConfig.register("gemma4_text", Gemma4TextConfig, exist_ok=True) +AutoConfig.register("gemma4_vision", Gemma4VisionConfig, exist_ok=True) + +# --------------------------------------------------------------------------- +# RoPE cache builder +# --------------------------------------------------------------------------- + + +def _build_rope_cache( + config: Gemma4TextConfig, + layer_type: str, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Pre-compute cos/sin RoPE cache for the given layer type.""" + rope_params = config.rope_parameters[layer_type] + rope_type = rope_params.get("rope_type", "default") + base = rope_params["rope_theta"] + factor = rope_params.get("factor", 1.0) + attention_scaling = 1.0 + + if rope_type == "default": + dim = config.head_dim + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + elif rope_type == "proportional": + # Proportional RoPE: only partial_rotary_factor of head dims are rotated, + # remaining dims get zero inv_freq → cos=1, sin=0 (no rotation). 
+ head_dim = config.global_head_dim + rope_proportion = rope_params.get("partial_rotary_factor", 1.0) + rope_angles = int(rope_proportion * head_dim // 2) + inv_freq_rotated = 1.0 / ( + base ** (torch.arange(0, 2 * rope_angles, 2, dtype=torch.float) / head_dim) + ) + nope_angles = head_dim // 2 - rope_angles + if nope_angles > 0: + inv_freq = torch.cat( + (inv_freq_rotated, torch.zeros(nope_angles, dtype=torch.float32)), + dim=0, + ) + else: + inv_freq = inv_freq_rotated + inv_freq = inv_freq / factor + else: + # Fallback to HF ROPE_INIT_FUNCTIONS for other types (e.g. yarn, longrope) + rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type] + inv_freq, attention_scaling = rope_init_fn(config, device=None, layer_type=layer_type) + + positions = torch.arange(config.max_position_embeddings, dtype=inv_freq.dtype) + freqs = torch.outer(positions, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + return emb.cos() * attention_scaling, emb.sin() * attention_scaling + + +# --------------------------------------------------------------------------- +# Basic building blocks +# --------------------------------------------------------------------------- + + +class Gemma4RMSNorm(nn.Module): + """RMSNorm matching HF Gemma4 (transformers >= 5.5). + + The checkpoint stores effective weights directly — no ``+1.0`` offset. + Uses the ``torch_rmsnorm`` canonical op for AD transform compatibility. 
+ """ + + def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True): + super().__init__() + self.eps = eps + self.with_scale = with_scale + if with_scale: + self.weight = nn.Parameter(torch.ones(dim)) + else: + self.register_buffer("weight", torch.ones(dim), persistent=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.auto_deploy.torch_rmsnorm(x, self.weight, self.eps) + + +class Gemma4TextScaledWordEmbedding(nn.Embedding): + def __init__( + self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float + ): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False) + + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + return super().forward(input_ids) * self.embed_scale.to(dtype=self.weight.dtype) + + +class Gemma4ClippableLinear(nn.Module): + """Wrapper matching the upstream Gemma4 ``*.linear.weight`` checkpoint layout.""" + + def __init__(self, in_features: int, out_features: int, bias: bool = False): + super().__init__() + self.linear = nn.Linear(in_features, out_features, bias=bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.linear(hidden_states) + + +class Gemma4RotaryEmbedding(nn.Module): + """Pre-computed RoPE cache for a single layer type (global or local).""" + + def __init__(self, config: Gemma4TextConfig, layer_type: str): + super().__init__() + ( + cos, + sin, + ) = _build_rope_cache(config, layer_type) + self.register_buffer("_ad_cos_cached", cos, persistent=False) + self.register_buffer("_ad_sin_cached", sin, persistent=False) + + def forward( + self, x: torch.Tensor, position_ids: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + cos = self._ad_cos_cached[position_ids].to(dtype=x.dtype, device=x.device) + sin = self._ad_sin_cached[position_ids].to(dtype=x.dtype, device=x.device) + return cos, sin + + +# 
--------------------------------------------------------------------------- +# Vision tower +# --------------------------------------------------------------------------- + + +def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + batch_size, num_key_value_heads, seq_len, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch_size, num_key_value_heads, n_rep, seq_len, head_dim + ) + return hidden_states.reshape(batch_size, num_key_value_heads * n_rep, seq_len, head_dim) + + +def _rotate_half(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def _apply_rotary_pos_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + unsqueeze_dim: int = 1, +) -> torch.Tensor: + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + return (x * cos) + (_rotate_half(x) * sin) + + +def _apply_multidimensional_rope( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: torch.Tensor, + unsqueeze_dim: int = 2, +) -> torch.Tensor: + ndim = position_ids.shape[-1] + num_channels = x.shape[-1] + num_rotated_channels_per_dim = 2 * (num_channels // (2 * ndim)) + if num_rotated_channels_per_dim <= 0: + raise ValueError( + f"Invalid Gemma4 vision RoPE configuration: num_channels={num_channels}, ndim={ndim}" + ) + + split_sizes = [num_rotated_channels_per_dim] * ndim + x_parts = torch.split(x, split_sizes, dim=-1) + cos_parts = torch.split(cos, split_sizes, dim=-1) + sin_parts = torch.split(sin, split_sizes, dim=-1) + outputs = [ + _apply_rotary_pos_emb( + x=x_parts[idx], + cos=cos_parts[idx], + sin=sin_parts[idx], + unsqueeze_dim=unsqueeze_dim, + ) + for idx in range(ndim) + ] + return torch.cat(outputs, dim=-1) + + +class Gemma4VisionPatchEmbedder(nn.Module): + def __init__(self, config: Gemma4VisionConfig): + super().__init__() + self.hidden_size 
= config.hidden_size + self.patch_size = config.patch_size + self.position_embedding_size = config.position_embedding_size + self.input_proj = nn.Linear(3 * self.patch_size**2, self.hidden_size, bias=False) + self.position_embedding_table = nn.Parameter( + torch.ones(2, self.position_embedding_size, self.hidden_size) + ) + + def _position_embeddings( + self, pixel_position_ids: torch.Tensor, padding_positions: torch.Tensor + ) -> torch.Tensor: + clamped_positions = pixel_position_ids.clamp(min=0) + position_embeddings = ( + self.position_embedding_table[0][clamped_positions[..., 0]] + + self.position_embedding_table[1][clamped_positions[..., 1]] + ) + return torch.where(padding_positions.unsqueeze(-1), 0.0, position_embeddings) + + def forward( + self, + pixel_values: torch.Tensor, + pixel_position_ids: torch.Tensor, + padding_positions: torch.Tensor, + ) -> torch.Tensor: + pixel_values = 2 * (pixel_values - 0.5) + hidden_states = self.input_proj(pixel_values.to(self.input_proj.weight.dtype)) + position_embeddings = self._position_embeddings(pixel_position_ids, padding_positions) + return hidden_states + position_embeddings + + +class Gemma4VisionPooler(nn.Module): + def __init__(self, config: Gemma4VisionConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.root_hidden_size = self.hidden_size**0.5 + + def _avg_pool_by_positions( + self, hidden_states: torch.Tensor, pixel_position_ids: torch.Tensor, length: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + input_seq_len = hidden_states.shape[1] + kernel_size = int((input_seq_len // length) ** 0.5) + if kernel_size**2 * length != input_seq_len: + raise ValueError( + f"Cannot pool {hidden_states.shape} to {length}: incompatible kernel size" + ) + + clamped_positions = pixel_position_ids.clamp(min=0) + max_x = clamped_positions[..., 0].max(dim=-1, keepdim=True)[0] + 1 + kernel_indices = torch.div(clamped_positions, kernel_size, rounding_mode="floor") + kernel_indices = kernel_indices[..., 0] + 
(max_x // kernel_size) * kernel_indices[..., 1] + weights = F.one_hot(kernel_indices.long(), length).float() / (kernel_size**2) + output = weights.transpose(1, 2) @ hidden_states.float() + mask = torch.logical_not((weights == 0).all(dim=1)) + return output.to(hidden_states.dtype), mask + + def forward( + self, + hidden_states: torch.Tensor, + pixel_position_ids: torch.Tensor, + padding_positions: torch.Tensor, + output_length: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if output_length is None: + output_length = hidden_states.shape[1] + if output_length > hidden_states.shape[1]: + raise ValueError("Gemma4 vision pooler cannot increase the number of tokens") + + hidden_states = hidden_states.masked_fill(padding_positions.unsqueeze(-1), 0.0) + if hidden_states.shape[1] != output_length: + hidden_states, padding_positions = self._avg_pool_by_positions( + hidden_states, pixel_position_ids, output_length + ) + hidden_states *= self.root_hidden_size + return hidden_states, padding_positions + + +class Gemma4VisionMLP(nn.Module): + def __init__(self, config: Gemma4VisionConfig): + super().__init__() + self.gate_proj = Gemma4ClippableLinear( + config.hidden_size, config.intermediate_size, bias=False + ) + self.up_proj = Gemma4ClippableLinear( + config.hidden_size, config.intermediate_size, bias=False + ) + self.down_proj = Gemma4ClippableLinear( + config.intermediate_size, config.hidden_size, bias=False + ) + self.act_fn = ACT2FN[config.hidden_activation] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.down_proj( + self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states) + ) + + +class Gemma4VisionRotaryEmbedding(nn.Module): + def __init__(self, config: Gemma4VisionConfig): + super().__init__() + rope_theta = config.rope_parameters["rope_theta"] + spatial_dim = config.head_dim // 2 + inv_freq = 1.0 / ( + rope_theta ** (torch.arange(0, spatial_dim, 2, dtype=torch.float32) / spatial_dim) + ) + 
self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward( + self, hidden_states: torch.Tensor, position_ids: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + inv_freq_expanded = ( + self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + ) + all_cos = [] + all_sin = [] + for dim_idx in range(2): + dim_position_ids = position_ids[:, None, :, dim_idx].float().to(hidden_states.device) + freqs = (inv_freq_expanded.to(hidden_states.device) @ dim_position_ids).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + all_cos.append(emb.cos()) + all_sin.append(emb.sin()) + cos = torch.cat(all_cos, dim=-1).to(dtype=hidden_states.dtype, device=hidden_states.device) + sin = torch.cat(all_sin, dim=-1).to(dtype=hidden_states.dtype, device=hidden_states.device) + return cos, sin + + +class Gemma4VisionAttention(nn.Module): + def __init__(self, config: Gemma4VisionConfig, layer_idx: int): + super().__init__() + del layer_idx + self.head_dim = config.head_dim + self.num_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.attention_dropout = config.attention_dropout + self.q_proj = Gemma4ClippableLinear( + config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = Gemma4ClippableLinear( + config.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = Gemma4ClippableLinear( + config.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = Gemma4ClippableLinear( + config.hidden_size, config.hidden_size, bias=config.attention_bias + ) + self.q_norm = Gemma4RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps) + self.k_norm = Gemma4RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps) + self.v_norm = Gemma4RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps, with_scale=False) + + def forward( + 
self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + del position_ids + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + cos, sin = position_embeddings + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)) + query_states = _apply_multidimensional_rope( + query_states, cos, sin, torch.zeros_like(cos[..., :2]) + ) + query_states = query_states.transpose(1, 2) + + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)) + key_states = _apply_multidimensional_rope( + key_states, cos, sin, torch.zeros_like(cos[..., :2]) + ) + key_states = key_states.transpose(1, 2) + + value_states = self.v_norm(self.v_proj(hidden_states).view(hidden_shape)) + value_states = value_states.transpose(1, 2) + + key_states = _repeat_kv(key_states, self.num_key_value_groups) + value_states = _repeat_kv(value_states, self.num_key_value_groups) + + attn_output = F.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + scale=1.0, + ) + attn_output = attn_output.transpose(1, 2).contiguous().reshape(*input_shape, -1) + return self.o_proj(attn_output), None + + +class Gemma4VisionEncoderLayer(nn.Module): + def __init__(self, config: Gemma4VisionConfig, layer_idx: int): + super().__init__() + self.self_attn = Gemma4VisionAttention(config=config, layer_idx=layer_idx) + self.mlp = Gemma4VisionMLP(config) + self.input_layernorm = Gemma4RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Gemma4RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_feedforward_layernorm = Gemma4RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_feedforward_layernorm = 
Gemma4RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + position_ids: torch.LongTensor, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + ) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + return residual + hidden_states + + +class Gemma4VisionEncoder(nn.Module): + def __init__(self, config: Gemma4VisionConfig): + super().__init__() + self.rotary_emb = Gemma4VisionRotaryEmbedding(config) + self.layers = nn.ModuleList( + [ + Gemma4VisionEncoderLayer(config=config, layer_idx=i) + for i in range(config.num_hidden_layers) + ] + ) + + def forward( + self, + inputs_embeds: torch.Tensor, + attention_mask: torch.Tensor, + pixel_position_ids: torch.LongTensor, + ) -> ModelOutput: + # Full bidirectional attention over valid patches only. 
+ valid = attention_mask.to(torch.bool) + attention_mask_4d = valid[:, None, :, None] & valid[:, None, None, :] + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, pixel_position_ids) + for layer in self.layers: + hidden_states = layer( + hidden_states, + attention_mask=attention_mask_4d, + position_embeddings=position_embeddings, + position_ids=pixel_position_ids, + ) + return ModelOutput(last_hidden_state=hidden_states) + + +class Gemma4VisionModel(nn.Module): + def __init__(self, config: Gemma4VisionConfig): + super().__init__() + self.config = config + self.patch_embedder = Gemma4VisionPatchEmbedder(config) + self.encoder = Gemma4VisionEncoder(config) + self.pooler = Gemma4VisionPooler(config) + + if self.config.standardize: + self.register_buffer("std_bias", torch.empty(self.config.hidden_size)) + self.register_buffer("std_scale", torch.empty(self.config.hidden_size)) + + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_position_ids: torch.LongTensor, + ) -> ModelOutput: + pooling_kernel_size = self.config.pooling_kernel_size + output_length = pixel_values.shape[-2] // (pooling_kernel_size * pooling_kernel_size) + padding_positions = (pixel_position_ids == -1).all(dim=-1) + inputs_embeds = self.patch_embedder(pixel_values, pixel_position_ids, padding_positions) + output = self.encoder( + inputs_embeds=inputs_embeds, + attention_mask=~padding_positions, + pixel_position_ids=pixel_position_ids, + ) + hidden_states, pooler_mask = self.pooler( + hidden_states=output.last_hidden_state, + pixel_position_ids=pixel_position_ids, + padding_positions=padding_positions, + output_length=output_length, + ) + hidden_states = hidden_states[pooler_mask] + if self.config.standardize: + hidden_states = (hidden_states - self.std_bias) * self.std_scale + return ModelOutput(last_hidden_state=hidden_states) + + +# --------------------------------------------------------------------------- +# MLP +# 
--------------------------------------------------------------------------- + + +class Gemma4TextMLP(nn.Module): + def __init__(self, config: Gemma4TextConfig): + super().__init__() + self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) + self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_activation] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +# --------------------------------------------------------------------------- +# MoE Router + Experts +# --------------------------------------------------------------------------- + + +class Gemma4Router(nn.Module): + """Gemma4-style MoE router. + + Pipeline: RMSNorm (no learnable scale) -> scalar hidden_size**-0.5 scale -> per-dim scale + -> linear -> softmax -> top-k -> renormalize the top-k weights to sum to 1. + """ + + def __init__(self, config: Gemma4TextConfig): + super().__init__() + self.proj = nn.Linear(config.hidden_size, config.num_experts, bias=False) + self.scale = nn.Parameter(torch.ones(config.hidden_size)) + self.register_buffer("root_size", torch.tensor(config.hidden_size**-0.5), persistent=False) + self.eps = config.rms_norm_eps + self.top_k = config.top_k_experts + + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # RMSNorm without learnable scale + normed = hidden_states.float() + normed = normed * torch.rsqrt(normed.pow(2).mean(-1, keepdim=True) + self.eps) + normed = normed.type_as(hidden_states) + # Apply scalar and per-dim scaling + normed = normed * self.root_size.to(hidden_states.dtype) + normed = normed * self.scale.to(hidden_states.dtype) + # Route: softmax over all experts, keep the top-k, then renormalize the kept weights + expert_scores = self.proj(normed) + probs = F.softmax(expert_scores, dim=-1) + top_k_weights, top_k_index = torch.topk(probs, k=self.top_k, dim=-1) + top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True) + return top_k_weights, 
top_k_index + + +class Gemma4Expert(nn.Module): + """Single MoE expert: gated MLP (gate_proj, up_proj, down_proj).""" + + def __init__(self, hidden_size: int, intermediate_size: int): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + + +class Gemma4MoEBlock(nn.Module): + """Mixture-of-Experts block with fused checkpoint weight conversion. + + Checkpoint stores fused parameters: + - gate_up_proj: [num_experts, 2*intermediate, hidden] + - down_proj: [num_experts, hidden, intermediate] + - per_expert_scale: [num_experts] + + We unfuse these into per-expert nn.Linear modules at load time so that + torch_moe can consume them as weight lists. + """ + + def __init__(self, config: Gemma4TextConfig): + super().__init__() + self.num_experts = config.num_experts + self.intermediate_size = config.expert_intermediate_size + self.experts = nn.ModuleList( + [ + Gemma4Expert(config.hidden_size, config.expert_intermediate_size) + for _ in range(config.num_experts) + ] + ) + + def forward( + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + ) -> torch.Tensor: + return torch.ops.auto_deploy.torch_moe( + hidden_states, + top_k_index, + top_k_weights, + w1_weight=[e.gate_proj.weight for e in self.experts], + w2_weight=[e.down_proj.weight for e in self.experts], + w3_weight=[e.up_proj.weight for e in self.experts], + is_gated_mlp=True, + act_fn=int(ActivationType.Gelu), + ) + + +# --------------------------------------------------------------------------- +# Attention +# --------------------------------------------------------------------------- + + +class Gemma4TextAttention(nn.Module): + def __init__(self, config: Gemma4TextConfig, layer_idx: int): + super().__init__() + self.layer_idx = layer_idx + self.config = config + self.is_sliding = 
config.layer_types[layer_idx] == "sliding_attention" + self.sliding_window = config.sliding_window if self.is_sliding else None + + # Full-attention layers may use different head dim and K=V + self.use_k_eq_v = config.attention_k_eq_v and not self.is_sliding + if not self.is_sliding and config.global_head_dim: + self.head_dim = config.global_head_dim + else: + self.head_dim = config.head_dim + + num_kv_heads = ( + config.num_global_key_value_heads if self.use_k_eq_v else config.num_key_value_heads + ) + self.num_heads = config.num_attention_heads + self.num_kv_heads = num_kv_heads + + self.q_proj = nn.Linear( + config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, num_kv_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = ( + None + if self.use_k_eq_v + else nn.Linear( + config.hidden_size, num_kv_heads * self.head_dim, bias=config.attention_bias + ) + ) + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + self.q_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps) + # v_norm has no learnable scale + self.v_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps, with_scale=False) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + batch_size, seq_len, _ = hidden_states.shape + q = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim) + k = self.k_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim) + + if self.v_proj is not None: + v = self.v_proj(hidden_states).view( + batch_size, seq_len, self.num_kv_heads, self.head_dim + ) + else: + v = k # K=V: reuse key as value + + q = self.q_norm(q) + k = self.k_norm(k) + v = self.v_norm(v) + + cos, sin = position_embeddings + q, k = 
torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin(q, k, cos, sin, 2) + + attn_output = torch.ops.auto_deploy.torch_attention( + q, + k, + v, + None, # attn_mask + 0.0, # dropout_p + True, # is_causal + 1.0, # scale (QK norms handle scaling) + None, # sinks + self.sliding_window, + None, # logit_cap + "bsnd", + self.layer_idx, + ) + return self.o_proj(attn_output.reshape(batch_size, seq_len, -1)) + + +# --------------------------------------------------------------------------- +# Decoder layer +# --------------------------------------------------------------------------- + + +class Gemma4TextDecoderLayer(nn.Module): + def __init__(self, config: Gemma4TextConfig, layer_idx: int): + super().__init__() + self.layer_idx = layer_idx + self.num_experts = config.num_experts + self.expert_intermediate_size = config.expert_intermediate_size + self.attention_type = config.layer_types[layer_idx] + self.self_attn = Gemma4TextAttention(config, layer_idx) + self.mlp = Gemma4TextMLP(config) + self.input_layernorm = Gemma4RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Gemma4RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_feedforward_layernorm = Gemma4RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_feedforward_layernorm = Gemma4RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.register_buffer("layer_scalar", torch.ones(1)) + + self.enable_moe_block = config.enable_moe_block + if self.enable_moe_block: + self.router = Gemma4Router(config) + self.moe = Gemma4MoEBlock(config) + self._register_load_state_dict_pre_hook(self._unfuse_moe_weights) + self.post_feedforward_layernorm_1 = Gemma4RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_feedforward_layernorm_2 = Gemma4RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.pre_feedforward_layernorm_2 = Gemma4RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def _unfuse_moe_weights(self, state_dict, prefix, 
*_args, **_kwargs): + """Convert layer-level fused Gemma4 MoE checkpoint weights to per-expert weights.""" + candidates = [ + ( + prefix + "experts.gate_up_proj", + prefix + "experts.down_proj", + prefix + "router.per_expert_scale", + ), + ( + prefix + "moe.gate_up_proj", + prefix + "moe.down_proj", + prefix + "moe.per_expert_scale", + ), + ] + + gate_up_key = down_key = scale_key = None + for gate_up_candidate, down_candidate, scale_candidate in candidates: + if ( + gate_up_candidate in state_dict + and down_candidate in state_dict + and scale_candidate in state_dict + ): + gate_up_key = gate_up_candidate + down_key = down_candidate + scale_key = scale_candidate + break + + if gate_up_key is None or down_key is None or scale_key is None: + return + + gate_up = state_dict.pop(gate_up_key) # [E, 2*I, H] + down = state_dict.pop(down_key) # [E, H, I] + scale = state_dict.pop(scale_key) # [E] + + inter = self.expert_intermediate_size + for e in range(self.num_experts): + state_dict[f"{prefix}moe.experts.{e}.gate_proj.weight"] = gate_up[e, :inter, :] + state_dict[f"{prefix}moe.experts.{e}.up_proj.weight"] = gate_up[e, inter:, :] + state_dict[f"{prefix}moe.experts.{e}.down_proj.weight"] = down[e] * scale[e] + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + # Self-attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn(hidden_states, position_embeddings) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + + # Feed-forward (dense MLP ± MoE) + residual = hidden_states + + if self.enable_moe_block: + # Dense MLP path + hs_dense = self.pre_feedforward_layernorm(hidden_states) + hs_dense = self.mlp(hs_dense) + hs_dense = self.post_feedforward_layernorm_1(hs_dense) + + # MoE path + hs_flat = hidden_states.reshape(-1, hidden_states.shape[-1]) + top_k_weights, top_k_index 
= self.router(hs_flat) + hs_moe = self.pre_feedforward_layernorm_2(hs_flat) + hs_moe = self.moe(hs_moe, top_k_index, top_k_weights) + hs_moe = hs_moe.reshape(hidden_states.shape) + hs_moe = self.post_feedforward_layernorm_2(hs_moe) + + hidden_states = hs_dense + hs_moe + else: + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + hidden_states + + hidden_states = hidden_states * self.layer_scalar + return hidden_states + + +# --------------------------------------------------------------------------- +# Text model +# --------------------------------------------------------------------------- + + +@dataclass +class Gemma4TextOutput(ModelOutput): + last_hidden_state: Optional[torch.FloatTensor] = None + + +@dataclass +class Gemma4CausalLMOutput(ModelOutput): + logits: Optional[torch.FloatTensor] = None + + +@dataclass +class Gemma4ConditionalOutput(ModelOutput): + logits: Optional[torch.FloatTensor] = None + + +class Gemma4TextPreTrainedModel(PreTrainedModel): + config_class = Gemma4TextConfig + base_model_prefix = "model" + _no_split_modules = ["Gemma4TextDecoderLayer"] + supports_gradient_checkpointing = False + + def _init_weights(self, module: nn.Module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class Gemma4TextModel(Gemma4TextPreTrainedModel): + def __init__(self, config: Gemma4TextConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.embed_tokens = Gemma4TextScaledWordEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + 
embed_scale=config.hidden_size**0.5, + ) + self.layers = nn.ModuleList( + [Gemma4TextDecoderLayer(config, i) for i in range(config.num_hidden_layers)] + ) + self.norm = Gemma4RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # Separate RoPE caches for global (full) and local (sliding) attention + self.rotary_emb_global = Gemma4RotaryEmbedding(config, "full_attention") + self.rotary_emb_local = Gemma4RotaryEmbedding(config, "sliding_attention") + + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Gemma4TextOutput: + del kwargs + assert position_ids is not None, "position_ids must be provided" + + if (input_ids is None) == (inputs_embeds is None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if input_ids is not None: + inputs_embeds = self.embed_tokens(input_ids) + + pos_emb_global = self.rotary_emb_global(inputs_embeds, position_ids) + pos_emb_local = self.rotary_emb_local(inputs_embeds, position_ids) + + hidden_states = inputs_embeds + for decoder_layer in self.layers: + if decoder_layer.attention_type == "sliding_attention": + pos_emb = pos_emb_local + else: + pos_emb = pos_emb_global + hidden_states = decoder_layer(hidden_states, pos_emb) + + hidden_states = self.norm(hidden_states) + return Gemma4TextOutput(last_hidden_state=hidden_states) + + +# --------------------------------------------------------------------------- +# CausalLM wrapper (text config) +# --------------------------------------------------------------------------- + + +class Gemma4ForCausalLM(Gemma4TextPreTrainedModel, GenerationMixin): + config_class = Gemma4TextConfig + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: Gemma4TextConfig, 
**kwargs): + del kwargs + super().__init__(config) + self.padding_idx = config.pad_token_id + self.embed_tokens = Gemma4TextScaledWordEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + embed_scale=config.hidden_size**0.5, + ) + self.layers = nn.ModuleList( + [Gemma4TextDecoderLayer(config, i) for i in range(config.num_hidden_layers)] + ) + self.norm = Gemma4RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # Keep the text backbone modules directly on the CausalLM wrapper so + # checkpoint keys match the HF layout: `language_model.layers.*` + # instead of `language_model.model.layers.*`. + self.rotary_emb_global = Gemma4RotaryEmbedding(config, "full_attention") + self.rotary_emb_local = Gemma4RotaryEmbedding(config, "sliding_attention") + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.register_load_state_dict_post_hook(self._retie_lm_head_weight) + self.post_init() + if getattr(config, "tie_word_embeddings", True): + self.lm_head.weight = self.embed_tokens.weight + + @staticmethod + def _retie_lm_head_weight(module, incompatible_keys): + del incompatible_keys + if not hasattr(module, "config") or not hasattr(module, "lm_head"): + return + if getattr(module.config, "tie_word_embeddings", True): + module.lm_head.weight = module.embed_tokens.weight + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, value): + self.lm_head = value + + def get_decoder(self): + return self + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Gemma4CausalLMOutput: + del kwargs + assert position_ids is not None, "position_ids must be provided" + + if (input_ids is None) == (inputs_embeds is None): + raise 
ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if input_ids is not None: + inputs_embeds = self.embed_tokens(input_ids) + + assert inputs_embeds is not None + pos_emb_global = self.rotary_emb_global(inputs_embeds, position_ids) + pos_emb_local = self.rotary_emb_local(inputs_embeds, position_ids) + + hidden_states = inputs_embeds + for decoder_layer in self.layers: + if decoder_layer.attention_type == "sliding_attention": + pos_emb = pos_emb_local + else: + pos_emb = pos_emb_global + hidden_states = decoder_layer(hidden_states, pos_emb) + + hidden_states = self.norm(hidden_states) + logits = self.lm_head(hidden_states) + if self.config.final_logit_softcapping is not None: + logits = logits / self.config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * self.config.final_logit_softcapping + return Gemma4CausalLMOutput(logits=logits) + + +# --------------------------------------------------------------------------- +# Multimodal embedder stub (for weight loading) +# --------------------------------------------------------------------------- + + +class Gemma4MultimodalEmbedder(nn.Module): + """Projects multimodal hidden states into language-model space.""" + + def __init__(self, vision_config: Gemma4VisionConfig, text_config: Gemma4TextConfig): + super().__init__() + self.eps = vision_config.rms_norm_eps + self.embedding_projection = nn.Linear( + vision_config.hidden_size, text_config.hidden_size, bias=False + ) + self.embedding_pre_projection_norm = Gemma4RMSNorm( + vision_config.hidden_size, eps=self.eps, with_scale=False + ) + + def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor: + hidden_states = self.embedding_pre_projection_norm(inputs_embeds) + return self.embedding_projection(hidden_states) + + +# --------------------------------------------------------------------------- +# ConditionalGeneration wrapper (multimodal config, text-only forward) +# 
--------------------------------------------------------------------------- + + +class Gemma4PreTrainedModel(PreTrainedModel): + config_class = Gemma4Config + base_model_prefix = "model" + _no_split_modules = ["Gemma4TextDecoderLayer"] + supports_gradient_checkpointing = False + + def _init_weights(self, module: nn.Module): + std = getattr(self.config, "initializer_range", 0.02) + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class Gemma4Model(Gemma4PreTrainedModel): + def __init__(self, config: Gemma4Config): + super().__init__(config) + self.language_model = Gemma4ForCausalLM(config.text_config) + self.vision_tower = ( + Gemma4VisionModel(config.vision_config) if config.vision_config is not None else None + ) + # Guard the embedder the same way as the vision tower: its constructor reads + # vision_config.rms_norm_eps and would fail when vision_config is None. + self.embed_vision = ( + Gemma4MultimodalEmbedder(config.vision_config, config.text_config) + if config.vision_config is not None + else None + ) + self._register_load_state_dict_pre_hook(self._remap_and_drop_weights) + self.post_init() + + @staticmethod + def _remap_and_drop_weights(state_dict, prefix, *_args, **_kwargs): + unsupported_prefixes = ( + prefix + "audio_tower.", + prefix + "embed_audio.", + ) + for key in list(state_dict): + if key.startswith(unsupported_prefixes): + state_dict.pop(key) + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_position_ids: Optional[torch.LongTensor] = None, + ) -> ModelOutput: + if self.vision_tower is None: + raise ValueError("Gemma4 vision_tower is not initialized") + vision_outputs = self.vision_tower( + pixel_values=pixel_values, + pixel_position_ids=image_position_ids, + ) + last_hidden_state = vision_outputs.last_hidden_state + return ModelOutput( + last_hidden_state=last_hidden_state, + pooler_output=self.embed_vision(inputs_embeds=last_hidden_state), + ) + + def 
get_placeholder_mask( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.BoolTensor: + if input_ids is not None: + return input_ids == self.config.image_token_id + if inputs_embeds is None: + raise ValueError("Either input_ids or inputs_embeds must be provided") + image_embedding = self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + return (inputs_embeds == image_embedding).all(-1) + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_decoder(self): + return self.language_model + + def _split_image_features_by_item( + self, + image_features: torch.Tensor, + image_position_ids: Optional[torch.LongTensor], + ) -> List[torch.Tensor]: + if image_position_ids is None: + raise ValueError("image_position_ids is required to split Gemma4 image features") + + pooling_kernel_size = self.config.vision_config.pooling_kernel_size + image_feature_slices: List[torch.Tensor] = [] + feature_offset = 0 + + for item_position_ids in image_position_ids: + valid_positions = item_position_ids[(item_position_ids != -1).all(dim=-1)] + if valid_positions.numel() == 0: + num_feature_tokens = 0 + else: + max_x = int(valid_positions[:, 0].max().item()) + 1 + max_y = int(valid_positions[:, 1].max().item()) + 1 + pooled_width = (max_x + pooling_kernel_size - 1) // pooling_kernel_size + pooled_height = (max_y + pooling_kernel_size - 1) // pooling_kernel_size + num_feature_tokens = pooled_width * pooled_height + + image_feature_slices.append( + image_features[feature_offset : feature_offset + num_feature_tokens] + ) + feature_offset += num_feature_tokens + + if feature_offset != image_features.shape[0]: + raise ValueError( + "Gemma4 image feature splitting mismatch: " + f"consumed={feature_offset}, 
actual={image_features.shape[0]}" + ) + + return image_feature_slices + + def _select_request_chunk_image_features( + self, + req_input_pos: int, + req_seq_len: int, + req_mm_positions: Sequence[int], + req_mm_lengths: Sequence[int], + req_special_offsets: Sequence[int], + req_image_feature_slices: Sequence[torch.Tensor], + ) -> torch.Tensor: + chunk_end = req_input_pos + req_seq_len + mm_cumulative_offset = 0 + chunks: List[torch.Tensor] = [] + hidden_size = self.config.text_config.hidden_size + special_offsets_set = set(int(x) for x in req_special_offsets) + + for item_embeds, mm_start, mm_len in zip( + req_image_feature_slices, req_mm_positions, req_mm_lengths + ): + item_mm_offset = mm_cumulative_offset + item_mm_len = int(mm_len) + item_abs_start = int(mm_start) + item_abs_end = item_abs_start + item_mm_len + overlap_start = max(req_input_pos, item_abs_start) + overlap_end = min(chunk_end, item_abs_end) + + local_to_feature_idx: List[Optional[int]] = [] + feature_idx = 0 + for rel in range(item_mm_len): + if item_mm_offset + rel in special_offsets_set: + local_to_feature_idx.append(None) + else: + local_to_feature_idx.append(feature_idx) + feature_idx += 1 + + if feature_idx != item_embeds.shape[0]: + raise ValueError( + "Gemma4 multimodal embedding length mismatch for image item: " + f"expected={feature_idx}, actual={item_embeds.shape[0]}, " + f"mm_len={item_mm_len}, item_start={item_abs_start}, " + f"special_offsets={sorted(special_offsets_set)}" + ) + + if overlap_start < overlap_end: + selected_indices = [ + local_to_feature_idx[rel] + for rel in range(overlap_start - item_abs_start, overlap_end - item_abs_start) + if local_to_feature_idx[rel] is not None + ] + if selected_indices: + chunks.append(item_embeds[selected_indices]) + + mm_cumulative_offset += item_mm_len + + if chunks: + return torch.cat(chunks, dim=0) + + device = req_image_feature_slices[0].device + dtype = req_image_feature_slices[0].dtype + return torch.empty(0, hidden_size, device=device, 
dtype=dtype) + + def _build_chunked_image_features( + self, + image_features: torch.Tensor, + image_position_ids: Optional[torch.LongTensor], + batch_info_host: torch.Tensor, + cu_seqlen: torch.Tensor, + input_pos: torch.Tensor, + mm_item_cu_seqlen: torch.Tensor, + mm_token_positions: torch.Tensor, + mm_token_lengths: torch.Tensor, + mm_special_offsets_cu_seqlen: Optional[torch.Tensor], + mm_special_offsets: Optional[torch.Tensor], + ) -> torch.Tensor: + num_prefill_seqs = int(batch_info_host[0].item()) + seq_len = cu_seqlen[1:] - cu_seqlen[:-1] + image_feature_slices = self._split_image_features_by_item( + image_features, image_position_ids + ) + img_idx = 0 + chunks: List[torch.Tensor] = [] + + for req_idx in range(num_prefill_seqs): + item_start = int(mm_item_cu_seqlen[req_idx].item()) + item_end = int(mm_item_cu_seqlen[req_idx + 1].item()) + req_mm_positions = mm_token_positions[item_start:item_end].tolist() + req_mm_lengths = mm_token_lengths[item_start:item_end].tolist() + req_num_images = item_end - item_start + req_image_feature_slices = image_feature_slices[img_idx : img_idx + req_num_images] + img_idx += req_num_images + if req_num_images == 0: + # Text-only request in a mixed batch: skip it, since the per-request selector indexes slice 0 for device/dtype. + continue + + req_special_offsets: List[int] = [] + if mm_special_offsets_cu_seqlen is not None and mm_special_offsets is not None: + special_start = int(mm_special_offsets_cu_seqlen[req_idx].item()) + special_end = int(mm_special_offsets_cu_seqlen[req_idx + 1].item()) + req_special_offsets = mm_special_offsets[special_start:special_end].tolist() + + req_chunk_features = self._select_request_chunk_image_features( + req_input_pos=int(input_pos[req_idx].item()), + req_seq_len=int(seq_len[req_idx].item()), + req_mm_positions=req_mm_positions, + req_mm_lengths=req_mm_lengths, + req_special_offsets=req_special_offsets, + req_image_feature_slices=req_image_feature_slices, + ) + chunks.append(req_chunk_features) + + if chunks: + return torch.cat(chunks, dim=0) + + return image_features[:0] + + def forward( + self, + input_ids: Optional[torch.LongTensor] = 
None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + image_position_ids: Optional[torch.LongTensor] = None, + batch_info_host: Optional[torch.Tensor] = None, + cu_seqlen: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + mm_item_cu_seqlen: Optional[torch.Tensor] = None, + mm_token_positions: Optional[torch.Tensor] = None, + mm_token_lengths: Optional[torch.Tensor] = None, + mm_special_offsets_cu_seqlen: Optional[torch.Tensor] = None, + mm_special_offsets: Optional[torch.Tensor] = None, + **kwargs, + ) -> Gemma4CausalLMOutput: + assert position_ids is not None, "position_ids must be provided" + if (input_ids is None) == (inputs_embeds is None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + image_mask = self.get_placeholder_mask(input_ids=input_ids) + llm_input_ids = input_ids.clone() + llm_input_ids = torch.where( + image_mask, + torch.full_like(llm_input_ids, self.config.text_config.pad_token_id), + llm_input_ids, + ) + inputs_embeds = self.get_input_embeddings()(llm_input_ids) + else: + image_mask = self.get_placeholder_mask(inputs_embeds=inputs_embeds) + + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values, + image_position_ids=image_position_ids, + ).pooler_output + has_chunk_mm_layout = ( + batch_info_host is not None + and cu_seqlen is not None + and input_pos is not None + and mm_item_cu_seqlen is not None + and mm_token_positions is not None + and mm_token_lengths is not None + and mm_item_cu_seqlen.numel() > 0 + and int(mm_item_cu_seqlen[-1].item()) > 0 + and mm_token_positions.numel() > 0 + and mm_token_lengths.numel() > 0 + ) + if has_chunk_mm_layout: + image_features = self._build_chunked_image_features( + image_features=image_features, + image_position_ids=image_position_ids, + 
batch_info_host=batch_info_host, + cu_seqlen=cu_seqlen, + input_pos=input_pos, + mm_item_cu_seqlen=mm_item_cu_seqlen, + mm_token_positions=mm_token_positions, + mm_token_lengths=mm_token_lengths, + mm_special_offsets_cu_seqlen=mm_special_offsets_cu_seqlen, + mm_special_offsets=mm_special_offsets, + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + expanded_image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds) + # _dbg.debug( + # "[Gemma4Model EMBED INJECTION] image_mask sum=%d | image_mask shape=%s | " + # "image_features shape=%s | image_features mean=%.6f std=%.6f | " + # "inputs_embeds mean=%.6f std=%.6f (before scatter) | " + # "pixel_values shape=%s mean=%.6f", + # image_mask.sum().item(), + # tuple(image_mask.shape), + # tuple(image_features.shape), + # image_features.mean().item(), + # image_features.std().item(), + # inputs_embeds.mean().item(), + # inputs_embeds.std().item(), + # tuple(pixel_values.shape), + # pixel_values.mean().item(), + # ) + if inputs_embeds[expanded_image_mask].numel() != image_features.numel(): + placeholder_token_count = int(image_mask.sum().item()) + feature_token_count = int(image_features.shape[0]) + raise ValueError( + "Image features and image placeholder tokens do not match: " + f"placeholder_tokens={placeholder_token_count}, " + f"feature_tokens={feature_token_count}, " + f"pixel_values_shape={tuple(pixel_values.shape)}, " + f"image_position_ids_shape=" + f"{tuple(image_position_ids.shape) if image_position_ids is not None else None}, " + f"input_pos={input_pos.tolist() if input_pos is not None else None}, " + f"cu_seqlen={cu_seqlen.tolist() if cu_seqlen is not None else None}" + ) + inputs_embeds = inputs_embeds.masked_scatter(expanded_image_mask, image_features) + # _dbg.debug( + # "[Gemma4Model EMBED INJECTION] inputs_embeds mean=%.6f std=%.6f (after scatter)", + # inputs_embeds.mean().item(), + # inputs_embeds.std().item(), + # ) + + language_model_kwargs = dict(kwargs) + if 
batch_info_host is not None: + language_model_kwargs["batch_info_host"] = batch_info_host + if cu_seqlen is not None: + language_model_kwargs["cu_seqlen"] = cu_seqlen + if input_pos is not None: + language_model_kwargs["input_pos"] = input_pos + + return Gemma4ForConditionalGeneration._call_language_model( + self.language_model, + input_ids=None, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **language_model_kwargs, + ) + + +class Gemma4ForConditionalGeneration(Gemma4PreTrainedModel, GenerationMixin): + config_class = Gemma4Config + _tied_weights_keys = ["model.language_model.lm_head.weight"] + + def __init__(self, config: Gemma4Config, **kwargs): + del kwargs + super().__init__(config) + self.model = Gemma4Model(config) + self._register_load_state_dict_pre_hook(self._remap_lm_head_weight) + self.post_init() + + @staticmethod + def _remap_lm_head_weight(state_dict, prefix, *_args, **_kwargs): + """Remap lm_head into language_model so the export info exports it.""" + old_key = prefix + "lm_head.weight" + new_key = prefix + "model.language_model.lm_head.weight" + if old_key in state_dict and new_key not in state_dict: + state_dict[new_key] = state_dict.pop(old_key) + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.model.language_model.lm_head + + def set_output_embeddings(self, value): + self.model.language_model.lm_head = value + + def get_decoder(self): + return self.model.get_decoder() + + @staticmethod + def _call_language_model( + language_model: nn.Module, + input_ids: Optional[torch.LongTensor], + position_ids: Optional[torch.LongTensor], + inputs_embeds: Optional[torch.Tensor], + **kwargs, + ): + """Call eager modules and exported FX graphs using their expected input structure.""" + if not isinstance(language_model, GraphModule): + model_kwargs = dict(kwargs) + 
model_kwargs["position_ids"] = position_ids + if inputs_embeds is not None: + model_kwargs["inputs_embeds"] = inputs_embeds + else: + model_kwargs["input_ids"] = input_ids + return language_model(**model_kwargs) + + available_args = { + "input_ids": input_ids, + "position_ids": position_ids, + "inputs_embeds": inputs_embeds, + **kwargs, + } + placeholder_names = [ + node.target for node in language_model.graph.nodes if node.op == "placeholder" + ] + in_spec = getattr(language_model, "_in_spec", None) + if in_spec is not None and in_spec.type is tuple and in_spec.num_children == 2: + pos_spec = in_spec.child(0) + kw_spec = in_spec.child(1) + num_positional = pos_spec.num_children if pos_spec.type is tuple else 0 + positional_names = placeholder_names[:num_positional] + keyword_names = list(kw_spec.context) if kw_spec.type is dict else [] + + positional_args = [available_args.get(name) for name in positional_names] + keyword_args = {name: available_args.get(name) for name in keyword_names} + return language_model(*positional_args, **keyword_args) + + positional_args = [available_args.get(name) for name in placeholder_names] + return language_model(*positional_args) + + @staticmethod + def _blob_ids_from_spans( + kv_len: int, + mm_positions: torch.Tensor, + mm_lengths: torch.Tensor, + device: torch.device, + ) -> torch.Tensor: + """Build per-position blob IDs for a single sequence from span metadata. + + Spans use absolute request-local coordinates, so this works correctly + for any chunk window during chunked prefill. + + Returns a 1D ``[kv_len]`` tensor where text positions are 0 and media + positions have blob IDs 1, 2, ... 
+ """ + blob_ids = torch.zeros(kv_len, dtype=torch.int64, device=device) + for i in range(mm_positions.shape[0]): + start = int(mm_positions[i].item()) + length = int(mm_lengths[i].item()) + end = min(start + length, kv_len) + if start < kv_len: + blob_ids[start:end] = i + 1 + return blob_ids + + @staticmethod + def _build_attention_mask( + batch_info_host: torch.Tensor, + cu_seqlen: torch.Tensor, + input_pos: torch.Tensor, + mm_positions: torch.Tensor, + mm_lengths: torch.Tensor, + mm_cu_seqlen: torch.Tensor, + ) -> torch.Tensor: + """Build per-sequence attention masks from span metadata + batch geometry. + + Returns a ``[num_prefill, 1, max_q, max_kv]`` bool mask that is causal + for text tokens and bidirectional within contiguous media blobs. + """ + num_prefill = int(batch_info_host[0].item()) + device = mm_positions.device + + masks = [] + max_q = 0 + max_kv = 0 + + for i in range(num_prefill): + q_start = int(input_pos[i].item()) + q_len = int(cu_seqlen[i + 1].item()) - int(cu_seqlen[i].item()) + kv_len = q_start + q_len + + span_start = int(mm_cu_seqlen[i].item()) + span_end = int(mm_cu_seqlen[i + 1].item()) + seq_positions = mm_positions[span_start:span_end] + seq_lengths = mm_lengths[span_start:span_end] + + blob_ids = Gemma4ForConditionalGeneration._blob_ids_from_spans( + kv_len, seq_positions, seq_lengths, device + ) + + q_blob = blob_ids[q_start : q_start + q_len].unsqueeze(1) # [Q, 1] + kv_blob = blob_ids.unsqueeze(0) # [1, KV] + bidirectional = (q_blob == kv_blob) & (q_blob != 0) # [Q, KV] + + q_pos = torch.arange(q_start, q_start + q_len, device=device).unsqueeze(1) + kv_pos = torch.arange(kv_len, device=device).unsqueeze(0) + causal = kv_pos <= q_pos # [Q, KV] + + mask = (causal | bidirectional).unsqueeze(0) # [1, Q, KV] + masks.append(mask) + max_q = max(max_q, q_len) + max_kv = max(max_kv, kv_len) + + padded = [] + for mask in masks: + _, q, kv = mask.shape + padded.append(F.pad(mask, (0, max_kv - kv, 0, max_q - q), value=False)) + return 
torch.stack(padded, dim=0) # [num_prefill, 1, max_q, max_kv] + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Gemma4ConditionalOutput: + # Build attention mask from span metadata (mm_token_positions/lengths) + # provided by _store_prefill_multimodal_metadata in the AD executor. + # Pass None during decode / text-only / warmup so the attention backend + # uses its fast causal kernel instead of the per-sequence fallback. + kwargs.pop("token_type_ids", None) + + batch_info_host = kwargs.pop("batch_info_host", None) + mm_positions = kwargs.pop("mm_token_positions", None) + mm_lengths = kwargs.pop("mm_token_lengths", None) + mm_cu_seqlen = kwargs.pop("mm_item_cu_seqlen", None) + mm_special_offsets_cu_seqlen = kwargs.pop("mm_special_offsets_cu_seqlen", None) + mm_special_offsets = kwargs.pop("mm_special_offsets", None) + + for key in ( + "mm_item_types", + "mm_chunk_flat_start", + "mm_chunk_count", + ): + kwargs.pop(key, None) + + has_media = ( + mm_positions is not None and mm_positions.numel() > 0 and batch_info_host is not None + ) + + cu_seqlen = kwargs.pop("cu_seqlen", None) + if cu_seqlen is None: + cu_seqlen = kwargs.pop("cu_seqlen_host", None) + input_pos = kwargs.pop("input_pos", None) + seq_len_with_cache = kwargs.pop("seq_len_with_cache", None) + if seq_len_with_cache is None: + seq_len_with_cache = kwargs.pop("seq_len_with_cache_host", None) + seq_len = kwargs.pop("seq_len", None) + + if has_media: + if input_pos is None: + if seq_len is None and cu_seqlen is not None: + seq_len = cu_seqlen[1:] - cu_seqlen[:-1] + if seq_len_with_cache is not None and seq_len is not None: + input_pos = seq_len_with_cache.to(seq_len.device) - seq_len + _built_mask = self._build_attention_mask( + batch_info_host, + cu_seqlen, + input_pos, + mm_positions, + mm_lengths, + mm_cu_seqlen, + ) + kwargs["custom_attn_mask"] = _built_mask + else: + 
kwargs["custom_attn_mask"] = None + + model_kwargs = dict(kwargs) + if batch_info_host is not None: + model_kwargs["batch_info_host"] = batch_info_host + if cu_seqlen is not None: + model_kwargs["cu_seqlen"] = cu_seqlen + if input_pos is not None: + model_kwargs["input_pos"] = input_pos + if seq_len_with_cache is not None: + model_kwargs["seq_len_with_cache"] = seq_len_with_cache + if mm_cu_seqlen is not None: + model_kwargs["mm_item_cu_seqlen"] = mm_cu_seqlen + if mm_positions is not None: + model_kwargs["mm_token_positions"] = mm_positions + if mm_lengths is not None: + model_kwargs["mm_token_lengths"] = mm_lengths + if mm_special_offsets_cu_seqlen is not None: + model_kwargs["mm_special_offsets_cu_seqlen"] = mm_special_offsets_cu_seqlen + if mm_special_offsets is not None: + model_kwargs["mm_special_offsets"] = mm_special_offsets + + outputs = self.model( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **model_kwargs, + ) + return Gemma4ConditionalOutput(logits=outputs.logits) + + +# --------------------------------------------------------------------------- +# Wrapper tokenizer for Gemma 4 +# +# The upstream HF checkpoint ships ``extra_special_tokens`` as a *list* in +# tokenizer_config.json, which is incompatible with transformers <5.3. +# This thin wrapper loads the tokenizer assets directly, bypassing the +# problematic codepath. 
+# --------------------------------------------------------------------------- + +_TOKENIZER_CONFIG_FILE = "tokenizer_config.json" +_PROCESSOR_CONFIG_FILE = "processor_config.json" +_CHAT_TEMPLATE_FILE = "chat_template.jinja" +_TOKENIZER_FILE = "tokenizer.json" +_SUPPORTED_GEMMA4_SOFT_TOKENS = (70, 140, 280, 560, 1120) + + +def get_aspect_ratio_preserving_size( + height: int, + width: int, + patch_size: int, + max_patches: int, + pooling_kernel_size: int, +) -> Tuple[int, int]: + """Resize within the Gemma4 patch budget while preserving aspect ratio.""" + total_px = height * width + target_px = max_patches * (patch_size**2) + factor = (target_px / total_px) ** 0.5 + ideal_height = factor * height + ideal_width = factor * width + side_multiple = pooling_kernel_size * patch_size + + target_height = int(ideal_height // side_multiple) * side_multiple + target_width = int(ideal_width // side_multiple) * side_multiple + + if target_height == 0 and target_width == 0: + raise ValueError( + "Attempting to resize to a 0 x 0 image. " + f"Resized height should be divisible by `pooling_kernel_size * patch_size`={side_multiple}." + ) + + max_side_length = (max_patches // pooling_kernel_size**2) * side_multiple + if target_height == 0: + target_height = side_multiple + target_width = min((width // height) * side_multiple, max_side_length) + elif target_width == 0: + target_width = side_multiple + target_height = min((height // width) * side_multiple, max_side_length) + + if target_height * target_width > target_px: + raise ValueError( + f"Resizing [{height}x{width}] to [{target_height}x{target_width}] " + f"exceeds {max_patches} patches with patch_size {patch_size}." 
+ ) + + return target_height, target_width + + +class ADGemma4Tokenizer(PreTrainedTokenizerFast): + """Wrapper that loads the upstream Gemma 4 tokenizer on current transformers.""" + + vocab_files_names = {"tokenizer_file": _TOKENIZER_FILE} + model_input_names = ["input_ids", "attention_mask"] + slow_tokenizer_class = None + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str | Path, + *inputs, + **kwargs, + ) -> "ADGemma4Tokenizer": + del inputs + for k in ("_from_auto", "_commit_hash", "trust_remote_code"): + kwargs.pop(k, None) + + config_path = cached_file(pretrained_model_name_or_path, _TOKENIZER_CONFIG_FILE, **kwargs) + assert config_path is not None + config = json.loads(Path(config_path).read_text()) + + tokenizer_file = cached_file(pretrained_model_name_or_path, _TOKENIZER_FILE, **kwargs) + assert tokenizer_file is not None + + # ``extra_special_tokens`` is a list in the upstream config; map it to + # the standard ``additional_special_tokens`` field. 
+ extra = config.get("extra_special_tokens", []) + if isinstance(extra, list): + additional = extra + else: + additional = list(extra.keys()) if isinstance(extra, dict) else [] + + tokenizer = cls( + tokenizer_object=Tokenizer.from_file(tokenizer_file), + name_or_path=str(pretrained_model_name_or_path), + bos_token=config.get("bos_token"), + eos_token=config.get("eos_token"), + unk_token=config.get("unk_token"), + pad_token=config.get("pad_token"), + additional_special_tokens=additional, + clean_up_tokenization_spaces=config.get("clean_up_tokenization_spaces", False), + model_max_length=config.get("model_max_length"), + padding_side=config.get("padding_side", "left"), + truncation_side=config.get("truncation_side", "left"), + ) + + tokenizer.image_token = config.get("image_token", "<|image|>") + tokenizer.boi_token = config.get("boi_token", "<|image>") + tokenizer.eoi_token = config.get("eoi_token", "<image|>") + tokenizer.image_token_id = tokenizer.convert_tokens_to_ids(tokenizer.image_token) + tokenizer.boi_token_id = tokenizer.convert_tokens_to_ids(tokenizer.boi_token) + tokenizer.eoi_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eoi_token) + + template_path = cached_file( + pretrained_model_name_or_path, + _CHAT_TEMPLATE_FILE, + _raise_exceptions_for_missing_entries=False, + **kwargs, + ) + if template_path is not None: + tokenizer.chat_template = Path(template_path).read_text() + + return tokenizer + + +class ADGemma4ImageProcessor: + """Minimal Gemma4 image processor compatible with the local transformers version.""" + + def __init__( + self, + *, + patch_size: int = 16, + max_soft_tokens: int = 280, + pooling_kernel_size: int = 3, + do_convert_rgb: bool = True, + do_resize: bool = True, + do_rescale: bool = True, + rescale_factor: float = 1 / 255, + do_normalize: bool = False, + image_mean: Optional[List[float]] = None, + image_std: Optional[List[float]] = None, + resample: int = Image.BICUBIC, + ) -> None: + if max_soft_tokens not in 
_SUPPORTED_GEMMA4_SOFT_TOKENS: + raise ValueError( + f"`max_soft_tokens` must be one of {_SUPPORTED_GEMMA4_SOFT_TOKENS}, got {max_soft_tokens}." + ) + self.patch_size = patch_size + self.max_soft_tokens = max_soft_tokens + self.pooling_kernel_size = pooling_kernel_size + self.do_convert_rgb = do_convert_rgb + self.do_resize = do_resize + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean or [0.0, 0.0, 0.0] + self.image_std = image_std or [1.0, 1.0, 1.0] + self.resample = resample + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str | Path, + **kwargs, + ) -> "ADGemma4ImageProcessor": + for key in ("_from_auto", "_commit_hash", "trust_remote_code"): + kwargs.pop(key, None) + + config_path = cached_file(pretrained_model_name_or_path, _PROCESSOR_CONFIG_FILE, **kwargs) + assert config_path is not None + processor_config = json.loads(Path(config_path).read_text()) + image_config = processor_config.get("image_processor", {}) + allowed_keys = { + "patch_size", + "max_soft_tokens", + "pooling_kernel_size", + "do_convert_rgb", + "do_resize", + "do_rescale", + "rescale_factor", + "do_normalize", + "image_mean", + "image_std", + "resample", + } + filtered_config = {key: value for key, value in image_config.items() if key in allowed_keys} + return cls(**filtered_config) + + @staticmethod + def fetch_images(images): + return images + + @staticmethod + def _to_tensor(image, do_convert_rgb: bool) -> torch.Tensor: + if isinstance(image, (str, Path)): + image = Image.open(image) + + if isinstance(image, Image.Image): + if do_convert_rgb: + image = image.convert("RGB") + array = np.array(image, copy=True) + tensor = torch.from_numpy(array) + if tensor.ndim == 2: + tensor = tensor.unsqueeze(-1) + return tensor.permute(2, 0, 1).contiguous().to(torch.float32) + + if torch.is_tensor(image): + tensor = image.detach().cpu() + if tensor.ndim != 3: + raise ValueError(f"Expected 
a 3D image tensor, got shape {tuple(tensor.shape)}") + if tensor.shape[0] in (1, 3): + return tensor.to(torch.float32) + if tensor.shape[-1] in (1, 3): + return tensor.permute(2, 0, 1).contiguous().to(torch.float32) + raise ValueError(f"Unsupported tensor image shape {tuple(tensor.shape)}") + + array = np.asarray(image) + if array.ndim == 2: + array = array[..., None] + if array.ndim != 3: + raise ValueError(f"Unsupported image with shape {array.shape}") + tensor = torch.from_numpy(array) + if tensor.shape[0] not in (1, 3): + tensor = tensor.permute(2, 0, 1) + return tensor.contiguous().to(torch.float32) + + @staticmethod + def _convert_image_to_patches(image: torch.Tensor, patch_size: int) -> torch.Tensor: + channels, image_height, image_width = image.shape + num_patches_height = image_height // patch_size + num_patches_width = image_width // patch_size + patched = image.reshape( + channels, + num_patches_height, + patch_size, + num_patches_width, + patch_size, + ) + patched = patched.permute(1, 3, 2, 4, 0) + return patched.reshape(num_patches_height * num_patches_width, -1) + + @staticmethod + def _pad_along_first_dim( + image: torch.Tensor, positions: torch.Tensor, target_length: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + current_length = image.shape[0] + padding_length = target_length - current_length + if padding_length <= 0: + return image, positions + image_padding = torch.zeros( + (padding_length, image.shape[1]), dtype=image.dtype, device=image.device + ) + pos_padding = torch.full( + (padding_length, 2), -1, dtype=positions.dtype, device=positions.device + ) + return torch.cat([image, image_padding], dim=0), torch.cat([positions, pos_padding], dim=0) + + def _aspect_ratio_preserving_resize(self, image: torch.Tensor) -> torch.Tensor: + height, width = image.shape[-2], image.shape[-1] + max_patches = self.max_soft_tokens * self.pooling_kernel_size**2 + target_height, target_width = get_aspect_ratio_preserving_size( + height=height, + width=width, + 
patch_size=self.patch_size, + max_patches=max_patches, + pooling_kernel_size=self.pooling_kernel_size, + ) + if target_height == height and target_width == width: + return image + return F.interpolate( + image.unsqueeze(0), + size=(target_height, target_width), + mode="bicubic", + align_corners=False, + antialias=True, + ).squeeze(0) + + def __call__( + self, + images, + *, + do_convert_rgb: Optional[bool] = None, + do_resize: Optional[bool] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[List[float]] = None, + image_std: Optional[List[float]] = None, + return_tensors: Optional[str] = None, + **_kwargs, + ) -> dict[str, Any]: + del return_tensors + do_convert_rgb = self.do_convert_rgb if do_convert_rgb is None else do_convert_rgb + do_resize = self.do_resize if do_resize is None else do_resize + do_rescale = self.do_rescale if do_rescale is None else do_rescale + rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor + do_normalize = self.do_normalize if do_normalize is None else do_normalize + image_mean = self.image_mean if image_mean is None else image_mean + image_std = self.image_std if image_std is None else image_std + + if not isinstance(images, list): + images = [images] + + pixel_values = [] + position_ids = [] + num_soft_tokens_per_image = [] + mean = torch.tensor(image_mean, dtype=torch.float32).view(-1, 1, 1) + std = torch.tensor(image_std, dtype=torch.float32).view(-1, 1, 1) + target_patches = self.max_soft_tokens * self.pooling_kernel_size**2 + + for image in images: + tensor = self._to_tensor(image, do_convert_rgb=do_convert_rgb) + if do_resize: + tensor = self._aspect_ratio_preserving_resize(tensor) + if do_rescale: + tensor = tensor * rescale_factor + if do_normalize: + tensor = (tensor - mean) / std + + patches = self._convert_image_to_patches(tensor, self.patch_size) + num_soft_tokens_per_image.append(patches.shape[0] 
// self.pooling_kernel_size**2) + + patch_height = tensor.shape[-2] // self.patch_size + patch_width = tensor.shape[-1] // self.patch_size + grid_y, grid_x = torch.meshgrid( + torch.arange(patch_height, dtype=torch.int64), + torch.arange(patch_width, dtype=torch.int64), + indexing="ij", + ) + positions = torch.stack([grid_x, grid_y], dim=-1).reshape(patches.shape[0], 2) + + pixel_values.append(patches) + position_ids.append(positions) + pixel_values_padded = [] + position_ids_padded = [] + for patches, positions in zip(pixel_values, position_ids): + padded_patches, padded_positions = self._pad_along_first_dim( + patches, positions, target_patches + ) + pixel_values_padded.append(padded_patches) + position_ids_padded.append(padded_positions) + + return { + "pixel_values": torch.stack(pixel_values_padded, dim=0), + "image_position_ids": torch.stack(position_ids_padded, dim=0), + "num_soft_tokens_per_image": num_soft_tokens_per_image, + } + + +class ADGemma4Processor: + """Minimal Gemma4 multimodal processor for image-text requests.""" + + def __init__( + self, + *, + tokenizer: ADGemma4Tokenizer, + image_processor: ADGemma4ImageProcessor, + image_seq_length: int = 280, + ) -> None: + self.tokenizer = tokenizer + self.image_processor = image_processor + self.image_seq_length = image_seq_length + self.image_token = tokenizer.image_token + self.boi_token = tokenizer.boi_token + self.eoi_token = tokenizer.eoi_token + self.image_token_id = tokenizer.image_token_id + self.boi_token_id = tokenizer.boi_token_id + self.eoi_token_id = tokenizer.eoi_token_id + self.chat_template = getattr(tokenizer, "chat_template", None) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str | Path, + **kwargs, + ) -> "ADGemma4Processor": + tokenizer = ADGemma4Tokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs) + image_processor = ADGemma4ImageProcessor.from_pretrained( + pretrained_model_name_or_path, **kwargs + ) + config_path = 
cached_file(pretrained_model_name_or_path, _PROCESSOR_CONFIG_FILE, **kwargs) + assert config_path is not None + processor_config = json.loads(Path(config_path).read_text()) + return cls( + tokenizer=tokenizer, + image_processor=image_processor, + image_seq_length=processor_config.get( + "image_seq_length", image_processor.max_soft_tokens + ), + ) + + @staticmethod + def _ensure_text_list(text) -> List[str]: + if text is None: + return [] + if isinstance(text, str): + return [text] + return list(text) + + @staticmethod + def _normalize_batched_images(images) -> List[List[Any]]: + if images is None: + return [] + if not isinstance(images, list): + return [[images]] + if not images: + return [] + if isinstance(images[0], list): + return [list(batch) for batch in images] + return [list(images)] + + def _expand_image_placeholders( + self, text: List[str], batched_images: List[List[Any]], image_inputs: dict[str, Any] + ) -> List[str]: + num_soft_tokens = image_inputs.pop("num_soft_tokens_per_image") + if not text: + text = [" ".join([self.image_token] * len(images)) for images in batched_images] + if len(text) != len(batched_images): + raise ValueError( + f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})." 
+ ) + + replacements = [ + f"{self.boi_token}{self.image_token * num_tokens}{self.eoi_token}" + for num_tokens in num_soft_tokens + ] + replacements_iter = iter(replacements) + pattern = re.escape(self.image_token) + return [re.sub(pattern, lambda _match: next(replacements_iter), prompt) for prompt in text] + + def _build_token_type_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + token_type_ids = torch.zeros_like(input_ids) + next_blob_id = 1 + for batch_idx in range(input_ids.shape[0]): + in_blob = False + current_blob_id = 0 + for token_idx, token in enumerate(input_ids[batch_idx].tolist()): + if token == self.boi_token_id: + in_blob = True + current_blob_id = next_blob_id + next_blob_id += 1 + token_type_ids[batch_idx, token_idx] = current_blob_id + elif token == self.eoi_token_id and in_blob: + token_type_ids[batch_idx, token_idx] = current_blob_id + in_blob = False + current_blob_id = 0 + elif in_blob: + token_type_ids[batch_idx, token_idx] = current_blob_id + return token_type_ids + + @staticmethod + def _render_messages(messages) -> Tuple[List[str], List[Any]]: + """Extract text + images from a single conversation. + + Returns ``(rendered_text, images)`` where ``rendered_text`` has an + ``<|image|>`` placeholder for each image. 
+ """ + parts: List[str] = [] + images: List[Any] = [] + for message in messages: + content = message.get("content", "") + if isinstance(content, str): + parts.append(content) + continue + for item in content: + item_type = item.get("type") + if item_type == "text": + parts.append(item.get("text", "")) + elif item_type == "image": + parts.append("<|image|>") + images.append(item.get("image")) + return " ".join(part for part in parts if part), images + + def apply_chat_template( + self, + messages, + *, + tokenize: bool = False, + return_dict: bool = False, + return_tensors: Optional[str] = None, + add_generation_prompt: bool = True, + **kwargs, + ): + is_batched = messages and isinstance(messages[0], list) + batched_messages = messages if is_batched else [messages] + + rendered_texts: List[str] = [] + batched_images: List[List[Any]] = [] + has_chat_template = bool(self.chat_template) + + for conversation in batched_messages: + if has_chat_template: + # Use the Jinja chat template for proper turn formatting. + # Strip image content items so the template only sees text; + # we insert image placeholders ourselves afterwards. + text_only_conv = [] + conv_images: List[Any] = [] + for message in conversation: + content = message.get("content", "") + if isinstance(content, str): + text_only_conv.append(message) + continue + text_parts: List[str] = [] + for item in content: + if item.get("type") == "text": + text_parts.append(item.get("text", "")) + elif item.get("type") == "image": + text_parts.append(self.image_token) + conv_images.append(item.get("image")) + text_only_conv.append({**message, "content": " ".join(text_parts)}) + rendered_texts.append( + self.tokenizer.apply_chat_template( + text_only_conv, + chat_template=self.chat_template, + add_generation_prompt=add_generation_prompt, + tokenize=False, + ) + ) + batched_images.append(conv_images) + else: + # No chat template — render messages directly. 
+ text, conv_images = self._render_messages(conversation) + rendered_texts.append(text) + batched_images.append(conv_images) + + if not tokenize: + return rendered_texts if is_batched else rendered_texts[0] + + result = self( + text=rendered_texts, + images=batched_images, + return_dict=True, + return_tensors=return_tensors, + **kwargs, + ) + if return_dict: + return result + return result["input_ids"] + + def __call__( + self, + *, + images=None, + text=None, + return_dict: bool = True, + return_tensors: Optional[str] = None, + return_attention_mask: bool = False, + **kwargs, + ): + del return_dict + batched_images = self._normalize_batched_images(images) + flat_images = [image for batch in batched_images for image in batch] + text_list = self._ensure_text_list(text) + + image_inputs = {} + if flat_images: + image_inputs = self.image_processor( + flat_images, + return_tensors=return_tensors, + **kwargs, + ) + text_list = self._expand_image_placeholders(text_list, batched_images, image_inputs) + + tokenizer_kwargs = dict(kwargs) + tokenizer_kwargs.pop("do_rescale", None) + tokenizer_kwargs.pop("do_convert_rgb", None) + tokenizer_kwargs.pop("rescale_factor", None) + tokenizer_kwargs.pop("do_resize", None) + tokenizer_kwargs.pop("do_normalize", None) + tokenizer_kwargs["return_tensors"] = return_tensors + tokenizer_kwargs["return_attention_mask"] = return_attention_mask + text_inputs = self.tokenizer(text=text_list, **tokenizer_kwargs) + text_inputs["token_type_ids"] = self._build_token_type_ids(text_inputs["input_ids"]) + return {**text_inputs, **image_inputs} + + +class Gemma4ADInputProcessor: + """Input processor that ensures ``multimodal_input`` is set. + + For multimodal requests, ``multimodal_input`` is computed with image token + positions and lengths so the AD executor can stage span metadata + (``mm_token_positions``, ``mm_token_lengths``, ``mm_item_cu_seqlen``) for + the eager wrapper to build per-sequence attention masks. 
+ """ + + def __init__(self, base, image_token_id: int, boi_token_id: int, eoi_token_id: int): + self.base = base + self.image_token_id = image_token_id + self.boi_token_id = boi_token_id + self.eoi_token_id = eoi_token_id + + def __getattr__(self, name): + return getattr(self.base, name) + + def get_num_tokens_per_image(self, *, image: Image.Image, **kwargs) -> int: + processor = getattr(self, "processor", None) + if processor is None: + raise AttributeError( + "Gemma4ADInputProcessor requires a processor for image token sizing." + ) + + image_inputs = processor.image_processor([image], **kwargs) + num_soft_tokens = int(image_inputs["num_soft_tokens_per_image"][0].item()) + return num_soft_tokens + 2 # include BOI + EOI + + def get_vocab_size(self) -> Optional[int]: + tokenizer = getattr(self, "tokenizer", None) + if tokenizer is not None and hasattr(tokenizer, "vocab_size"): + return int(tokenizer.vocab_size) + wrapped_tokenizer = getattr(tokenizer, "tokenizer", None) + if wrapped_tokenizer is not None and hasattr(wrapped_tokenizer, "vocab_size"): + return int(wrapped_tokenizer.vocab_size) + processor = getattr(self, "processor", None) + processor_tokenizer = getattr(processor, "tokenizer", None) + if processor_tokenizer is not None and hasattr(processor_tokenizer, "vocab_size"): + return int(processor_tokenizer.vocab_size) + return None + + def get_mm_token_ids(self) -> torch.Tensor: + return torch.tensor([self.image_token_id], dtype=torch.int32) + + def get_mm_special_token_ids(self) -> torch.Tensor: + return torch.tensor(sorted({self.boi_token_id, self.eoi_token_id}), dtype=torch.int32) + + def _find_image_spans(self, token_ids: List[int]) -> Tuple[List[int], List[int]]: + """Find start positions and lengths of each image blob (boi…eoi) span.""" + positions: List[int] = [] + lengths: List[int] = [] + i = 0 + while i < len(token_ids): + if token_ids[i] == self.boi_token_id: + start = i + # Scan to the matching eoi token + j = i + 1 + while j < len(token_ids) 
and token_ids[j] != self.eoi_token_id: + j += 1 + end = j + 1 if j < len(token_ids) else j # include eoi + positions.append(start) + lengths.append(end - start) + i = end + else: + i += 1 + return positions, lengths + + def __call__(self, inputs, sampling_params): + token_ids, extra = self.base(inputs, sampling_params) + if extra is None: + extra = {} + + # Remove token_type_ids if the base processor added it; the mask is now + # built from span metadata in the eager wrapper. + mm_data = extra.get("multimodal_data") + if mm_data is not None: + mm_data.pop("token_type_ids", None) + + # Compute multimodal_input so the executor knows where image spans are. + if "multimodal_input" not in extra: + positions, lengths = self._find_image_spans(token_ids) + if positions: + from tensorrt_llm.inputs.multimodal import MultimodalInput + + # Dummy hashes: KV-cache reuse for images is not yet supported. + dummy_hashes = [[0] * 8 for _ in positions] + extra["multimodal_input"] = MultimodalInput.from_components( + mm_hashes=dummy_hashes, + mm_positions=positions, + mm_lengths=lengths, + ) + multimodal_data = extra.get("multimodal_data") or {} # key may be present but None + special_offsets: List[int] = [] + mm_offset = 0 + for length in lengths: + special_offsets.extend([mm_offset, mm_offset + length - 1]) + mm_offset += length + multimodal_data["layout_metadata"] = { + "special_token_offsets": torch.tensor(special_offsets, dtype=torch.int32), + "item_types": torch.zeros(len(positions), dtype=torch.int32), + } + extra["multimodal_data"] = multimodal_data + + return token_ids, extra + + +@ModelFactoryRegistry.register("Gemma4ForConditionalGeneration") +class Gemma4ForConditionalGenerationFactory(AutoModelForImageTextToTextFactory): + """Factory for Gemma 4 VLM with custom attention mask support.""" + + def init_tokenizer(self) -> Optional[Any]: + if self.tokenizer is None: + return None + return ADGemma4Tokenizer.from_pretrained(self.tokenizer) + + def init_processor(self) -> Optional[Any]: + """Return the local 
Gemma4 multimodal processor.""" + if self.tokenizer is None: + return None + return ADGemma4Processor.from_pretrained(self.tokenizer) + + def init_input_processor(self, base): + processor = self.init_processor() + image_token_id = getattr(processor, "image_token_id", 258_880) + boi_token_id = getattr(processor, "boi_token_id", 255_999) + eoi_token_id = getattr(processor, "eoi_token_id", 258_882) + return Gemma4ADInputProcessor( + base, + image_token_id=image_token_id, + boi_token_id=boi_token_id, + eoi_token_id=eoi_token_id, + ) + + +# --------------------------------------------------------------------------- +# Registration +# --------------------------------------------------------------------------- + +AutoModelForCausalLMFactory.register_custom_model_cls("Gemma4TextConfig", Gemma4ForCausalLM) +Gemma4ForConditionalGenerationFactory.register_custom_model_cls( + "Gemma4Config", Gemma4ForConditionalGeneration +) +MULTIMODAL_PLACEHOLDER_REGISTRY.set_placeholder_metadata( + "gemma4", + MultimodalPlaceholderMetadata( + placeholder_map={"image": "<|image|>"}, + content_format=ContentFormat.STRING, + ), +) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/__init__.py b/tensorrt_llm/_torch/auto_deploy/transform/__init__.py index 79658227043..7c655e2a49c 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/__init__.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/__init__.py @@ -1,4 +1,7 @@ """AutoDeploy's modular graph transform + inference optimizer pipeline.""" -from . import library # ensure all transforms are registered +from . 
import ( + attention_mask_providers, # noqa: F401 + library, # ensure all transforms are registered +) from .interface import * diff --git a/tensorrt_llm/_torch/auto_deploy/transform/attention_mask_provider.py b/tensorrt_llm/_torch/auto_deploy/transform/attention_mask_provider.py new file mode 100644 index 00000000000..a07e501f555 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/attention_mask_provider.py @@ -0,0 +1,151 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Transform-time registry for backend-native attention mask providers.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, Optional, Tuple + +from torch.fx import GraphModule, Node + +from ..models.factory import ModelFactory +from ..shim.interface import CachedSequenceInterface +from ..utils._graph import _NO_VAL, add_graph_input +from .interface import SharedConfig + +AttentionMaskProviderFn = Callable[["AttentionMaskProviderContext", Node], Optional[Node]] + + +def infer_model_type(factory: Optional[ModelFactory]) -> Optional[str]: + """Best-effort inference of the source model type for provider lookup.""" + if factory is None: + return None + + model_type = getattr(factory, "model_type", None) + if isinstance(model_type, str): + return model_type + + get_model_type = getattr(factory, "get_model_type", None) + if callable(get_model_type): + inferred = get_model_type() + if isinstance(inferred, str): + return inferred + + get_model_config = getattr(factory, "_get_model_config", None) + if callable(get_model_config): + try: + model_config, _unused_kwargs = get_model_config() + except Exception: + return None + inferred = getattr(model_config, "model_type", None) + if isinstance(inferred, str): + return inferred + + return None + + +@dataclass +class AttentionMaskProviderContext: + """Context object shared across provider invocations within one transform pass.""" + + gm: GraphModule + cm: Optional[CachedSequenceInterface] + factory: Optional[ModelFactory] + shared_config: SharedConfig + model_type: str + backend: str + cache: Dict[str, Node] = field(default_factory=dict) + + def add_or_retrieve_input( + self, + name: str, + *, + activate_arg: bool = True, + val: Any = _NO_VAL, + ) -> Node: + """Add or retrieve a graph placeholder for a provider input.""" + input_nodes = self.gm.graph.find_nodes(op="placeholder", target=name) + if len(input_nodes) == 1: + return input_nodes[0] + if 
len(input_nodes) > 1: + raise ValueError(f"Expected exactly one input node for {name=}, got {input_nodes=}") + + if activate_arg: + if self.cm is None: + raise ValueError( + f"Cannot activate managed arg {name!r} without CachedSequenceInterface." + ) + self.cm.info.activate_arg(name) + + return add_graph_input(self.gm, name=name, val=val) + + def register_default_extra_arg( + self, + name: str, + factory: Callable, + ) -> None: + """Register a callable default factory for ``name`` in the SequenceInfo. + + ``factory`` receives the ``SequenceInfo`` and returns the default value + for ``name``. It is called at the start of every ``nest_sequences`` + so that initialization-time forward passes (e.g. ``resize_kv_cache``) + always receive a valid tensor for ``name`` even when no per-request data + is available. + + Has no effect when ``cm`` is ``None`` (e.g. during DemoLLM export). + """ + if self.cm is None: + return + self.cm.info.register_default_extra_arg(name, factory) + + def get_or_create_cached_node(self, key: str, builder: Callable[[], Node]) -> Node: + """Memoize provider-created nodes so shared masks are built once per forward.""" + if key not in self.cache: + self.cache[key] = builder() + return self.cache[key] + + +class AttentionMaskProviderRegistry: + """Registry for backend-native attention mask providers.""" + + _registry: Dict[Tuple[str, str], AttentionMaskProviderFn] = {} + + @classmethod + def register( + cls, model_type: str, backend: str + ) -> Callable[[AttentionMaskProviderFn], AttentionMaskProviderFn]: + """Register a provider for a specific ``(model_type, backend)`` pair.""" + + def decorator(provider: AttentionMaskProviderFn) -> AttentionMaskProviderFn: + cls._registry[(model_type, backend)] = provider + return provider + + return decorator + + @classmethod + def get( + cls, model_type: Optional[str], backend: Optional[str] + ) -> Optional[AttentionMaskProviderFn]: + """Return the provider registered for ``(model_type, backend)``.""" + if 
model_type is None or backend is None: + return None + return cls._registry.get((model_type, backend)) + + @classmethod + def has(cls, model_type: str, backend: str) -> bool: + """Return whether a provider exists for ``(model_type, backend)``.""" + return (model_type, backend) in cls._registry diff --git a/tensorrt_llm/_torch/auto_deploy/transform/attention_mask_providers.py b/tensorrt_llm/_torch/auto_deploy/transform/attention_mask_providers.py new file mode 100644 index 00000000000..deb5547fe34 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/attention_mask_providers.py @@ -0,0 +1,50 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Model-specific attention mask providers. + +Providers registered here add an optional ``custom_attn_mask`` graph input and +wire it to each ``torch_attention`` node. The actual mask tensor is computed +**outside** the graph (e.g. in the VLM wrapper ``forward()``) and passed in at +runtime. During warmup, text-only, and decode steps the wrapper passes +``None``, so the attention backend uses its fast causal kernel. +""" + +from __future__ import annotations + +from .attention_mask_provider import AttentionMaskProviderRegistry + + +def _add_custom_attn_mask_input(ctx, source_attn_node): + """Add ``custom_attn_mask`` as an optional graph input (default ``None``). 
+ + No mask computation nodes are inserted into the graph. The mask is + computed outside the graph by the model wrapper and supplied at runtime. + """ + return ctx.add_or_retrieve_input( + "custom_attn_mask", + activate_arg=False, + val=None, + ) + + +@AttentionMaskProviderRegistry.register("gemma4", "triton_paged") +def _gemma4_triton_paged_mask_provider(ctx, source_attn_node): + return _add_custom_attn_mask_input(ctx, source_attn_node) + + +@AttentionMaskProviderRegistry.register("gemma4", "torch_attention") +def _gemma4_torch_attention_mask_provider(ctx, source_attn_node): + return _add_custom_attn_mask_input(ctx, source_attn_node) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py b/tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py index 5653ec7a481..bb30e1a2dac 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py @@ -62,8 +62,21 @@ def _apply( node_to_gather = lm_head_node.all_input_nodes[0] self._log_info(f"Found LM head node: {lm_head_node.name}") else: - node_to_gather = lm_head_node - self._log_info("lm_head node is not linear, using it as the node to gather") + # Walk backward through elementwise/unary ops (e.g. softcapping: div, tanh, mul) + # to find the actual lm_head linear node. 
+ current = lm_head_node + while current is not None and not is_linear_op(current): + inputs = current.all_input_nodes + current = inputs[0] if len(inputs) >= 1 else None + + if current is not None and is_linear_op(current): + node_to_gather = current.all_input_nodes[0] + self._log_info( + f"Found LM head linear through post-processing chain: {current.name}" + ) + else: + node_to_gather = lm_head_node + self._log_info("lm_head node is not linear, using it as the node to gather") # Add logits_gather_mask as input in the graph and the sequence info interface logits_gather_indices_node = self._add_or_retrieve_input(gm, cm, "token_gather_indices") diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/inject_custom_attention_mask.py b/tensorrt_llm/_torch/auto_deploy/transform/library/inject_custom_attention_mask.py new file mode 100644 index 00000000000..803f03a3a0d --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/inject_custom_attention_mask.py @@ -0,0 +1,142 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Inject backend-native custom attention masks into ``torch_attention`` nodes.""" + +from typing import Optional, Tuple, Type + +import torch +from pydantic import Field +from torch.fx import GraphModule, Node + +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface +from ...utils.node_utils import extract_op_args, is_op +from ..attention_mask_provider import ( + AttentionMaskProviderContext, + AttentionMaskProviderRegistry, + infer_model_type, +) +from ..interface import ( + BaseTransform, + SharedConfig, + Stages, + TransformConfig, + TransformInfo, + TransformRegistry, +) + + +class InjectCustomAttentionMaskConfig(TransformConfig): + """Configuration for injecting backend-native custom attention masks.""" + + stage: Stages = Field(default=Stages.PATTERN_MATCHER) + backend: str = Field( + default="torch_attention", + description="Backend key used to resolve the attention mask provider.", + ) + model_type: Optional[str] = Field( + default=None, + description="Optional explicit model_type override used for provider lookup.", + ) + override_existing_mask: bool = Field( + default=False, + description="Whether to override an attention node that already has an attn_mask input.", + ) + + +@TransformRegistry.register("inject_custom_attention_mask") +class InjectCustomAttentionMask(BaseTransform): + """Inject backend-native masks into ``torch_attention`` calls.""" + + config: InjectCustomAttentionMaskConfig + + @classmethod + def get_config_class(cls) -> Type[TransformConfig]: + return InjectCustomAttentionMaskConfig + + @staticmethod + def _get_attn_mask_arg(node: Node): + return extract_op_args(node, "attn_mask")[0] + + @staticmethod + def _set_attn_mask_arg(node: Node, attn_mask: Node) -> None: + from ...utils.node_utils import _get_op_schema + + schema = _get_op_schema(node) + pos = {a.name: i for i, a in enumerate(schema.arguments)} + idx = pos["attn_mask"] + if len(node.args) > idx: + node.update_arg(idx, attn_mask) + 
else: + kwargs = dict(node.kwargs) + kwargs["attn_mask"] = attn_mask + node.kwargs = kwargs + + def _apply( + self, + gm: GraphModule, + cm: Optional[CachedSequenceInterface], + factory: Optional[ModelFactory], + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + attn_nodes = [n for n in gm.graph.nodes if is_op(n, torch.ops.auto_deploy.torch_attention)] + if not attn_nodes: + return gm, TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + ) + + model_type = self.config.model_type or infer_model_type(factory) + provider = AttentionMaskProviderRegistry.get(model_type, self.config.backend) + if provider is None: + return gm, TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + ) + + assert model_type is not None + ctx = AttentionMaskProviderContext( + gm=gm, + cm=cm, + factory=factory, + shared_config=shared_config, + model_type=model_type, + backend=self.config.backend, + ) + + num_matches = 0 + for attn_node in attn_nodes: + if ( + not self.config.override_existing_mask + and self._get_attn_mask_arg(attn_node) is not None + ): + continue + + with gm.graph.inserting_before(attn_node): + attn_mask = provider(ctx, attn_node) + + if attn_mask is None: + continue + + self._set_attn_mask_arg(attn_node, attn_mask) + num_matches += 1 + + if num_matches == 0: + return gm, TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + ) + + return gm, TransformInfo( + skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=False + ) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py index 4bef54528b6..4c669d9492a 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py @@ -18,7 +18,7 @@ import inspect import operator -from typing import List, Optional, Tuple, Type +from typing import 
Dict, List, Optional, Tuple, Type import torch import torch.nn as nn @@ -35,7 +35,7 @@ from ...shim.interface import CachedSequenceInterface from ...utils._graph import add_graph_input from ...utils.cuda_mem_tracker import get_mem_info -from ...utils.node_utils import is_op +from ...utils.node_utils import get_op_schema, is_op from ..interface import ( BaseTransform, SharedConfig, @@ -108,7 +108,7 @@ def _process_metadata_extra( # check what inputs the extra metadata op expects inputs_for_prep_meta = [ self._add_or_retrieve_input(gm, cm, arg.name) - for arg in prep_meta_op._schema.arguments + for arg in get_op_schema(prep_meta_op).arguments if arg.name in cm.info.available_args ] @@ -137,16 +137,18 @@ def _insert_cached_attn_node( self, gm: GraphModule, attn_node: Node, + cached_attn_op, qkv_nodes: List[Node], meta_nodes_std: List[Node], meta_nodes_extra: List[Node], cache_nodes: List[Node], + dynamic_kwargs: Dict[str, Optional[Node]], constants: List[Constant], ): """Insert a cached attention node into the graph.""" with gm.graph.inserting_before(attn_node): cached_attn_node = gm.graph.call_function( - self.attn_descriptor.get_cached_attention_op(), + cached_attn_op, args=( *qkv_nodes, *meta_nodes_std, @@ -154,6 +156,7 @@ def _insert_cached_attn_node( *cache_nodes, *constants, ), + kwargs=dynamic_kwargs, ) attn_node.replace_all_uses_with(cached_attn_node) gm.graph.erase_node(attn_node) @@ -168,10 +171,8 @@ def _apply( """Replace uncached source attention node with corresponding cached attn node.""" attn_descriptor = self.attn_descriptor - # Get all attention nodes and their info objects - source_op = attn_descriptor.get_source_attention_op() - # look for relevant source attention nodes + source_op = attn_descriptor.get_source_attention_op() source_attn_nodes = [n for n in gm.graph.nodes if is_op(n, source_op)] if not source_attn_nodes: @@ -191,31 +192,64 @@ def _apply( # replace fused attention node with attention node that has kv cache 
num_cached_attn_replacements = 0 - for attn_node in source_attn_nodes: + cache_nodes_by_layer_idx = {} + for idx, attn_node in enumerate(source_attn_nodes): # pick out GEMMs qkv = attn_node.args[: attn_descriptor.get_num_qkv_args()] - # setup + store cache resource handlers and caches as input nodes - resources_dict = attn_descriptor.get_cache_initializers(attn_node, cm.kv_cache_config) - cache_in_nodes = [ - self._process_cache_node(gm, cm.add_resource(k, resource_handler)) - for k, resource_handler in resources_dict.items() - ] + layer_idx = attn_descriptor.get_layer_idx(attn_node) + shared_kv_source_layer_idx = attn_descriptor.get_shared_kv_source_layer_idx(attn_node) + + if shared_kv_source_layer_idx is not None: + if not attn_descriptor.supports_shared_kv(): + raise RuntimeError( + f"Backend '{self.config.backend}' does not support shared-KV attention." + ) + if layer_idx is None: + raise RuntimeError( + "Shared-KV attention node is missing layer_idx metadata required for " + "cache aliasing." + ) + if shared_kv_source_layer_idx == layer_idx: + raise RuntimeError(f"Layer {layer_idx} cannot share its own KV cache.") + if shared_kv_source_layer_idx not in cache_nodes_by_layer_idx: + raise RuntimeError( + f"Missing shared-KV source layer {shared_kv_source_layer_idx}." + ) + cache_in_nodes = cache_nodes_by_layer_idx[shared_kv_source_layer_idx] + else: + # setup + store cache initializers and caches as input nodes + if layer_idx is not None and layer_idx in cache_nodes_by_layer_idx: + raise RuntimeError( + f"Duplicate KV cache owner detected for layer {layer_idx}. " + "Each non-shared attention layer must own exactly one cache." 
+ ) + cache_in_nodes = [] + for k, resource_handler in attn_descriptor.get_cache_initializers( + attn_node, cm.kv_cache_config + ).items(): + resource_name = cm.add_resource(k, resource_handler) + cache_in_nodes.append(self._process_cache_node(gm, resource_name)) + if layer_idx is not None: + cache_nodes_by_layer_idx[layer_idx] = cache_in_nodes # allow backend-specific prep before constants are extracted attn_descriptor.prepare_node_for_cache_insertion(gm, attn_node) # retrieve constants for attention_op + dynamic_inputs = attn_descriptor.get_dynamic_inputs(attn_node) constants = attn_descriptor.get_constants(attn_node) # insert cached attention replacement op self._insert_cached_attn_node( gm, attn_node, + attn_descriptor.get_cached_attention_op(), qkv, meta_nodes_std, meta_nodes_extra, cache_in_nodes, + dynamic_inputs, constants, ) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py index 8eeacbd6685..54b5650738e 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py @@ -159,7 +159,10 @@ def cached_attn( elif attention_layout != "bnsd": raise ValueError(f"Unsupported attention layout: {attention_layout}") - attn_output = attn_descriptor.get_cached_attention_op()( + cached_attn_op = module._node_ref.meta.get( + "cached_attn_op", attn_descriptor.get_cached_attention_op() + ) + attn_output = cached_attn_op( query, key, value, @@ -238,6 +241,7 @@ def _insert_cached_attn_node( self, gm: GraphModule, attn_node: Node, + cached_attn_op, qkv_nodes: List[Node], meta_nodes_std: List[Node], meta_nodes_extra: List[Node], @@ -246,6 +250,7 @@ def _insert_cached_attn_node( ): """Here we now need to actually do the correct mapping of the cached attn nodes.""" # store reference to metadata, caches, and constants for this attn node + 
attn_node.meta["cached_attn_op"] = cached_attn_op attn_node.meta["metadata_cache_keys"] = (*meta_nodes_std, *meta_nodes_extra, *cache_nodes) attn_node.meta["constants"] = constants diff --git a/tensorrt_llm/_torch/auto_deploy/utils/_graph.py b/tensorrt_llm/_torch/auto_deploy/utils/_graph.py index d54c86d0c37..6ba653cec9b 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/_graph.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/_graph.py @@ -18,7 +18,7 @@ from torch.utils._pytree import _LEAF_SPEC, TreeSpec from .logger import ad_logger -from .node_utils import get_weight_tensor, is_op +from .node_utils import get_op_schema, get_weight_tensor, is_op # --------------------------------------------------------------------------- # Dynamic custom-op derivation helpers @@ -72,7 +72,7 @@ def create_derived_custom_op( the same *base_op* and *suffix* return the cached op. """ base_overload = base_op.default if hasattr(base_op, "default") else base_op - schema = base_overload._schema + schema = get_op_schema(base_overload) # e.g. 
"auto_deploy::trtllm_moe_fused" qualified_name = schema.name diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index 71f4bab4e52..a64800782db 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -1049,18 +1049,22 @@ def extract_output_tuple(node: Node, count: int = 2): return results -def _get_op_schema(node: Node): - """Return the op schema for a call_function node.""" - if node.op != "call_function": - raise ValueError(f"_get_op_schema only supports call_function nodes, got {node.op}") - op = node.target +def get_op_schema(op) -> torch.FunctionSchema: + """Return the schema for an op or op overload packet.""" if hasattr(op, "_schemas"): return next(iter(op._schemas.values())) - elif hasattr(op, "_schema"): + if hasattr(op, "_schema"): return op._schema raise RuntimeError(f"No schema found on op {op}") +def _get_op_schema(node: Node): + """Return the op schema for a call_function node.""" + if node.op != "call_function": + raise ValueError(f"_get_op_schema only supports call_function nodes, got {node.op}") + return get_op_schema(node.target) + + def extract_op_args(node: Node, *arg_names): """ Given a call_function node for torch custom op, diff --git a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py index 222f7052ed6..8f9189a713b 100644 --- a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py +++ b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py @@ -982,6 +982,44 @@ def test_nvfp4(self, ep_size, attention_dp): task.evaluate(llm) +class TestGemma4MoE(LlmapiAccuracyTestHarness): + """Bench-run coverage for Gemma4 MoE via AutoDeploy.""" + + MODEL_NAME = "google/gemma-4-26B-A4B-it" + EXTRA_EVALUATOR_KWARGS = { + "apply_chat_template": True, + } + + def get_default_sampling_params(self): + return SamplingParams(end_id=None, + pad_id=None, 
+ n=1, + use_beam_search=False) + + @pytest.mark.skip_less_device_memory(80000) + def test_bf16(self): + yaml_paths, registry_world_size = _get_registry_yaml_extra( + self.MODEL_NAME) + if get_device_count() < registry_world_size: + pytest.skip("Not enough devices for world size, skipping test") + + sampling_params = self.get_default_sampling_params() + with AutoDeployLLM(model=self.MODEL_NAME, + tokenizer=self.MODEL_NAME, + world_size=registry_world_size, + yaml_extra=yaml_paths) as llm: + #task = MMLU(self.MODEL_NAME) + #task.evaluate(llm, + # sampling_params=sampling_params, + # extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS) + #task = GSM8K(self.MODEL_NAME) + #task.evaluate(llm, + # extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS) + task = MMMU(self.MODEL_NAME) + task.evaluate(llm, sampling_params=sampling_params, + extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS) + + class TestModelRegistryAccuracy(LlmapiAccuracyTestHarness): """Accuracy tests for models from the AutoDeploy model registry. diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma3n_modeling.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma3n_modeling.py new file mode 100644 index 00000000000..847c84dcf68 --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma3n_modeling.py @@ -0,0 +1,474 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import copy +from typing import Tuple + +import pytest +import torch +from torch.export import Dim + +import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 +from tensorrt_llm._torch.auto_deploy.custom_ops.attention.torch_backend_attention import ( + TorchBackendAttention, +) +from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.models.custom.modeling_gemma3n import ( + Gemma3nAudioConfig, + Gemma3nConditionalOutput, + Gemma3nConfig, + Gemma3nForCausalLM, + Gemma3nForConditionalGeneration, + Gemma3nTextAttention, + Gemma3nTextConfig, + Gemma3nTextDecoderLayer, + Gemma3nTextMLP, + Gemma3nVisionConfig, +) +from tensorrt_llm._torch.auto_deploy.utils._graph import move_to_device + + +def assert_rmse_close( + actual: torch.Tensor, + expected: torch.Tensor, + rmse_ratio_tol: float, + msg: str = "", +) -> None: + diff = actual.float() - expected.float() + rmse_diff = torch.sqrt(torch.mean(diff**2)) + rmse_ref = torch.sqrt(torch.mean(expected.float() ** 2)) + ratio = (rmse_diff / rmse_ref).item() + assert ratio < rmse_ratio_tol, ( + f"{msg}RMSE ratio {ratio:.6f} exceeds tolerance {rmse_ratio_tol}. 
" + f"(rmse_diff={rmse_diff.item():.6f}, rmse_ref={rmse_ref.item():.6f})" + ) + + +def _get_hf_classes(): + try: + from transformers.models.gemma3n.modeling_gemma3n import ( + Gemma3nForCausalLM as HFGemma3nForCausalLM, + ) + from transformers.models.gemma3n.modeling_gemma3n import ( + Gemma3nTextAttention as HFGemma3nTextAttention, + ) + from transformers.models.gemma3n.modeling_gemma3n import ( + Gemma3nTextDecoderLayer as HFGemma3nTextDecoderLayer, + ) + from transformers.models.gemma3n.modeling_gemma3n import Gemma3nTextMLP as HFGemma3nTextMLP + except ImportError: + return None + return HFGemma3nForCausalLM, HFGemma3nTextAttention, HFGemma3nTextDecoderLayer, HFGemma3nTextMLP + + +HF_CLASSES = _get_hf_classes() + + +def _device_and_dtype() -> Tuple[str, torch.dtype]: + if torch.cuda.is_available(): + return "cuda", torch.bfloat16 + return "cpu", torch.float32 + + +def _small_text_config() -> Gemma3nTextConfig: + config = Gemma3nTextConfig( + vocab_size=256, + vocab_size_per_layer_input=256, + hidden_size=64, + hidden_size_per_layer_input=8, + intermediate_size=[128, 128, 128], + num_hidden_layers=3, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=16, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=64, + rms_norm_eps=1e-6, + rope_theta=10000.0, + rope_local_base_freq=1000.0, + attention_bias=False, + attention_dropout=0.0, + sliding_window=16, + layer_types=["sliding_attention", "sliding_attention", "full_attention"], + final_logit_softcapping=30.0, + altup_active_idx=0, + altup_correct_scale=True, + altup_num_inputs=3, + num_kv_shared_layers=0, + laurel_rank=8, + activation_sparsity_pattern=[0.5, 0.0, 0.0], + pad_token_id=0, + eos_token_id=1, + bos_token_id=2, + ) + config._attn_implementation = "eager" + return config + + +def _small_full_config() -> Gemma3nConfig: + return Gemma3nConfig( + text_config=_small_text_config(), + vision_config=Gemma3nVisionConfig( + hidden_size=32, + vocab_size=8, + vocab_offset=256, + 
rms_norm_eps=1e-6, + ), + audio_config=Gemma3nAudioConfig( + vocab_size=8, + vocab_offset=264, + hidden_size=32, + rms_norm_eps=1e-6, + conf_num_attention_heads=4, + conf_num_hidden_layers=2, + sscp_conv_channel_size=(16, 8), + ), + ) + + +def _extended_text_config(num_hidden_layers: int) -> Gemma3nTextConfig: + config = copy.deepcopy(_small_text_config()) + config.num_hidden_layers = num_hidden_layers + config.intermediate_size = [128] * num_hidden_layers + config.layer_types = ["sliding_attention"] * (num_hidden_layers - 1) + ["full_attention"] + config.activation_sparsity_pattern = [0.0] * num_hidden_layers + return config + + +def _shared_kv_text_config() -> Gemma3nTextConfig: + config = Gemma3nTextConfig( + vocab_size=256, + vocab_size_per_layer_input=256, + hidden_size=64, + hidden_size_per_layer_input=8, + intermediate_size=[128] * 6, + num_hidden_layers=6, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=16, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=64, + rms_norm_eps=1e-6, + rope_theta=10000.0, + rope_local_base_freq=1000.0, + attention_bias=False, + attention_dropout=0.0, + sliding_window=16, + layer_types=[ + "sliding_attention", + "full_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "full_attention", + ], + final_logit_softcapping=30.0, + altup_active_idx=0, + altup_correct_scale=True, + altup_num_inputs=3, + num_kv_shared_layers=2, + laurel_rank=8, + activation_sparsity_pattern=[0.0] * 6, + pad_token_id=0, + eos_token_id=1, + bos_token_id=2, + ) + config._attn_implementation = "eager" + return config + + +def _position_ids(batch_size: int, seq_len: int, device: str) -> torch.Tensor: + return torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1) + + +def _load_equivalent_modules(custom_module: torch.nn.Module, hf_module: torch.nn.Module) -> None: + missing, unexpected = custom_module.load_state_dict(hf_module.state_dict(), strict=False) + assert not missing, 
f"Missing keys when loading HF weights into custom module: {missing}" + assert not unexpected, ( + f"Unexpected keys when loading HF weights into custom module: {unexpected}" + ) + + +@pytest.fixture(autouse=True) +def _set_seed(): + torch.manual_seed(42) + + +def test_hf_reference_available(): + if HF_CLASSES is None: + pytest.skip("transformers gemma3n reference classes are unavailable") + hf_model_cls, hf_attention_cls, hf_layer_cls, hf_mlp_cls = HF_CLASSES + assert hf_model_cls.__name__ == "Gemma3nForCausalLM" + assert hf_attention_cls.__name__ == "Gemma3nTextAttention" + assert hf_layer_cls.__name__ == "Gemma3nTextDecoderLayer" + assert hf_mlp_cls.__name__ == "Gemma3nTextMLP" + + +@torch.no_grad() +def test_gemma3n_mlp_equivalence(): + if HF_CLASSES is None: + pytest.skip("transformers gemma3n reference classes are unavailable") + + _, _, _, hf_mlp_cls = HF_CLASSES + device, dtype = _device_and_dtype() + config = _small_text_config() + custom_mlp = Gemma3nTextMLP(config, layer_idx=0).to(device=device, dtype=dtype) + hf_mlp = hf_mlp_cls(config, layer_idx=0).to(device=device, dtype=dtype) + _load_equivalent_modules(custom_mlp, hf_mlp) + + hidden_states = torch.randn(2, 6, config.hidden_size, device=device, dtype=dtype) + custom_out = custom_mlp(hidden_states) + hf_out = hf_mlp(hidden_states) + torch.testing.assert_close(custom_out.float(), hf_out.float(), rtol=1e-3, atol=1e-3) + + +@torch.no_grad() +def test_gemma3n_attention_equivalence(): + if HF_CLASSES is None: + pytest.skip("transformers gemma3n reference classes are unavailable") + + _, hf_attention_cls, _, _ = HF_CLASSES + device, dtype = _device_and_dtype() + config = _small_text_config() + custom_attn = Gemma3nTextAttention(config, layer_idx=2).to(device=device, dtype=dtype) + hf_attn = hf_attention_cls(config, layer_idx=2).to(device=device, dtype=dtype) + _load_equivalent_modules(custom_attn, hf_attn) + + hidden_states = torch.randn(2, 6, config.hidden_size, device=device, dtype=dtype) + position_ids = 
_position_ids(2, 6, device) + custom_rope = Gemma3nForCausalLM(config).model.rotary_emb.to(device=device) + full_cos, full_sin = custom_rope(hidden_states, position_ids) + position_embeddings = (full_cos[position_ids], full_sin[position_ids]) + + custom_out = custom_attn(hidden_states, position_embeddings) + hf_out = hf_attn(hidden_states, position_embeddings, attention_mask=None)[0] + assert_rmse_close(custom_out[:, -1:], hf_out[:, -1:], rmse_ratio_tol=0.10, msg="Attention: ") + + +@torch.no_grad() +def test_gemma3n_decoder_layer_equivalence(): + if HF_CLASSES is None: + pytest.skip("transformers gemma3n reference classes are unavailable") + + _, _, hf_layer_cls, _ = HF_CLASSES + device, dtype = _device_and_dtype() + config = _small_text_config() + custom_layer = Gemma3nTextDecoderLayer(config, layer_idx=2).to(device=device, dtype=dtype) + hf_layer = hf_layer_cls(config, layer_idx=2).to(device=device, dtype=dtype) + _load_equivalent_modules(custom_layer, hf_layer) + + batch_size, seq_len = 2, 1 + hidden_states = torch.randn( + config.altup_num_inputs, batch_size, seq_len, config.hidden_size, device=device, dtype=dtype + ) + per_layer_input = torch.randn( + batch_size, seq_len, config.hidden_size_per_layer_input, device=device, dtype=dtype + ) + position_ids = _position_ids(batch_size, seq_len, device) + rope_model = Gemma3nForCausalLM(config).model.to(device=device) + global_cos, global_sin = rope_model.rotary_emb(hidden_states[0], position_ids) + local_cos, local_sin = rope_model.rotary_emb_local(hidden_states[0], position_ids) + position_embeddings_global = (global_cos[position_ids], global_sin[position_ids]) + position_embeddings_local = (local_cos[position_ids], local_sin[position_ids]) + + custom_out = custom_layer( + hidden_states, + position_embeddings_global, + position_embeddings_local, + per_layer_input, + ) + hf_out = hf_layer( + hidden_states, + position_embeddings_global, + position_embeddings_local, + per_layer_input, + attention_mask=None, + 
position_ids=position_ids, + )[0] + assert_rmse_close(custom_out, hf_out, rmse_ratio_tol=0.05, msg="Decoder layer: ") + + +@torch.no_grad() +def test_gemma3n_full_model_equivalence(): + if HF_CLASSES is None: + pytest.skip("transformers gemma3n reference classes are unavailable") + + hf_model_cls, _, _, _ = HF_CLASSES + device, dtype = "cpu", torch.float32 + config = _small_text_config() + custom_model = Gemma3nForCausalLM(config).to(device=device, dtype=dtype) + hf_model = hf_model_cls(config).to(device=device, dtype=dtype) + _load_equivalent_modules(custom_model, hf_model) + custom_model.eval() + hf_model.eval() + + input_ids = torch.randint(0, config.vocab_size, (2, 6), device=device) + position_ids = _position_ids(2, 6, device) + custom_out = custom_model(input_ids=input_ids, position_ids=position_ids) + hf_out = hf_model(input_ids=input_ids, position_ids=position_ids) + assert_rmse_close(custom_out.logits, hf_out.logits, rmse_ratio_tol=0.05, msg="Full model: ") + + +@torch.no_grad() +def test_gemma3n_conditional_wrapper_equivalence(): + if HF_CLASSES is None: + pytest.skip("transformers gemma3n reference classes are unavailable") + + hf_model_cls, _, _, _ = HF_CLASSES + device, dtype = "cpu", torch.float32 + config = _small_full_config() + wrapper = Gemma3nForConditionalGeneration(config).to(device=device, dtype=dtype) + hf_model = hf_model_cls(config.text_config).to(device=device, dtype=dtype) + _load_equivalent_modules(wrapper.model.language_model, hf_model.model) + _load_equivalent_modules(wrapper.lm_head, hf_model.lm_head) + wrapper.eval() + hf_model.eval() + + input_ids = torch.randint( + 0, config.text_config.vocab_size_per_layer_input, (2, 6), device=device + ) + position_ids = _position_ids(2, 6, device) + wrapper_out = wrapper(input_ids=input_ids, position_ids=position_ids) + hf_out = hf_model(input_ids=input_ids, position_ids=position_ids) + assert isinstance(wrapper_out, Gemma3nConditionalOutput) + assert_rmse_close(wrapper_out.logits, 
hf_out.logits, rmse_ratio_tol=0.05, msg="Wrapper: ") + + +def test_gemma3n_conditional_wrapper_load_hook_drops_unsupported_tower_weights(): + config = _small_full_config() + wrapper = Gemma3nForConditionalGeneration(config) + state_dict = wrapper.state_dict() + state_dict["model.vision_tower.fake.weight"] = torch.randn(2, 2) + state_dict["model.audio_tower.fake.weight"] = torch.randn(2, 2) + + missing, unexpected = wrapper.load_state_dict(state_dict, strict=True) + + assert missing == [] + assert unexpected == [] + + +def test_gemma3n_conditional_wrapper_ignores_hf_init_kwargs(): + config = _small_full_config() + wrapper = Gemma3nForConditionalGeneration(config, use_cache=False) + assert isinstance(wrapper, Gemma3nForConditionalGeneration) + + +def test_gemma3n_reduced_layer_load_hook_slices_per_layer_weights(): + source_model = Gemma3nForCausalLM(_extended_text_config(5)) + target_model = Gemma3nForCausalLM(_small_text_config()) + + missing, unexpected = target_model.load_state_dict(source_model.state_dict(), strict=False) + + assert missing == [] + assert "model.layers.3.self_attn.q_proj.weight" in unexpected + + +def test_gemma3n_causal_lm_ties_lm_head_to_input_embeddings(): + model = Gemma3nForCausalLM(_small_text_config()) + assert model.lm_head.weight.data_ptr() == model.model.embed_tokens.weight.data_ptr() + + +def test_gemma3n_conditional_lm_ties_lm_head_to_input_embeddings(): + model = Gemma3nForConditionalGeneration(_small_full_config()) + assert ( + model.lm_head.weight.data_ptr() == model.model.language_model.embed_tokens.weight.data_ptr() + ) + + +def test_gemma3n_shared_kv_layer_metadata_matches_config(): + model = Gemma3nForCausalLM(_shared_kv_text_config()) + layer_expectations = [ + (False, None), + (False, None), + (False, None), + (False, None), + (True, 2), + (True, 3), + ] + + for layer, (is_shared, source_idx) in zip(model.model.layers, layer_expectations, strict=True): + assert layer.self_attn.is_kv_shared_layer is is_shared + assert 
layer.self_attn.kv_shared_layer_index == source_idx + + +def test_gemma3n_export_uses_shared_kv_attention_for_shared_layers(): + config = _shared_kv_text_config() + model = Gemma3nForCausalLM(config).eval() + input_ids = torch.randint(0, config.vocab_size, (1, 4)) + position_ids = _position_ids(1, 4, "cpu") + + gm = torch_export_to_gm( + model, + args=tuple(), + kwargs={"input_ids": input_ids, "position_ids": position_ids}, + ) + + attn_nodes = [node for node in gm.graph.nodes if node.op == "call_function"] + attn_nodes = [ + node for node in attn_nodes if node.target == torch.ops.auto_deploy.torch_attention.default + ] + regular_nodes = [ + node + for node in attn_nodes + if TorchBackendAttention.get_shared_kv_source_layer_idx(node) is None + ] + shared_nodes = [ + node + for node in attn_nodes + if TorchBackendAttention.get_shared_kv_source_layer_idx(node) is not None + ] + + assert len(attn_nodes) == config.num_hidden_layers + assert len(regular_nodes) == config.num_hidden_layers - config.num_kv_shared_layers + assert len(shared_nodes) == config.num_kv_shared_layers + assert [TorchBackendAttention.get_layer_idx(regular) for regular in regular_nodes] == [ + 0, + 1, + 2, + 3, + ] + assert [TorchBackendAttention.get_layer_idx(shared) for shared in shared_nodes] == [4, 5] + assert [ + TorchBackendAttention.get_shared_kv_source_layer_idx(shared) for shared in shared_nodes + ] == [2, 3] + + +def test_gemma3n_model_can_be_exported(): + if not torch.cuda.is_available(): + pytest.skip("Export test requires CUDA") + + device = "cuda" + dtype = torch.bfloat16 + config = _small_full_config() + model = Gemma3nForConditionalGeneration(config).to(device=device, dtype=dtype) + model.eval() + + input_ids = torch.randint(0, config.text_config.vocab_size, (2, 8), device=device) + position_ids = _position_ids(2, 8, device) + + gm = torch_export_to_gm( + model, + args=tuple(), + kwargs={"input_ids": input_ids, "position_ids": position_ids}, + dynamic_shapes=( + {0: Dim.DYNAMIC, 1: 
Dim.DYNAMIC}, + {0: Dim.DYNAMIC, 1: Dim.DYNAMIC}, + ), + ) + move_to_device(gm, device) + + with torch.inference_mode(): + eager_out = model(input_ids=input_ids, position_ids=position_ids) + export_out = gm(input_ids=input_ids, position_ids=position_ids) + + assert "logits" in export_out + assert_rmse_close(export_out["logits"], eager_out.logits, rmse_ratio_tol=0.05, msg="Export: ") + + input_ids_2 = torch.randint(0, config.text_config.vocab_size, (1, 5), device=device) + position_ids_2 = _position_ids(1, 5, device) + with torch.inference_mode(): + export_out_2 = gm(input_ids=input_ids_2, position_ids=position_ids_2) + eager_out_2 = model(input_ids=input_ids_2, position_ids=position_ids_2) + assert_rmse_close( + export_out_2["logits"], eager_out_2.logits, rmse_ratio_tol=0.05, msg="Export dynamic: " + ) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma4_modeling.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma4_modeling.py new file mode 100644 index 00000000000..769c174eece --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma4_modeling.py @@ -0,0 +1,1418 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Hierarchical equivalence tests for Gemma4 AutoDeploy custom model. + +Reference classes (_Ref*) are standalone PyTorch reimplementations of the +HuggingFace Gemma4 math — no transformers>=5.3 dependency required. 
+""" + +from types import SimpleNamespace +from typing import Optional, Tuple + +import pytest +import torch +import torch.nn.functional as F +from torch import nn +from torch.export import Dim +from transformers.activations import ACT2FN + +import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 +from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.models.custom.modeling_gemma4 import ( + ADGemma4ImageProcessor, + Gemma4ADInputProcessor, + Gemma4Config, + Gemma4ForCausalLM, + Gemma4ForConditionalGeneration, + Gemma4MoEBlock, + Gemma4MultimodalEmbedder, + Gemma4RotaryEmbedding, + Gemma4Router, + Gemma4TextAttention, + Gemma4TextConfig, + Gemma4TextDecoderLayer, + Gemma4TextMLP, + Gemma4VisionAttention, + Gemma4VisionConfig, + Gemma4VisionEncoder, + Gemma4VisionEncoderLayer, + Gemma4VisionMLP, + Gemma4VisionModel, + Gemma4VisionPatchEmbedder, + Gemma4VisionPooler, + Gemma4VisionRotaryEmbedding, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def assert_rmse_close( + actual: torch.Tensor, + expected: torch.Tensor, + rmse_ratio_tol: float, + msg: str = "", +) -> None: + diff = actual.float() - expected.float() + rmse_diff = torch.sqrt(torch.mean(diff**2)) + rmse_ref = torch.sqrt(torch.mean(expected.float() ** 2)) + ratio = (rmse_diff / rmse_ref).item() + assert ratio < rmse_ratio_tol, ( + f"{msg}RMSE ratio {ratio:.6f} exceeds tolerance {rmse_ratio_tol}. 
" + f"(rmse_diff={rmse_diff.item():.6f}, rmse_ref={rmse_ref.item():.6f})" + ) + + +def _device_and_dtype() -> Tuple[str, torch.dtype]: + if torch.cuda.is_available(): + return "cuda", torch.bfloat16 + return "cpu", torch.float32 + + +def _small_text_config() -> Gemma4TextConfig: + config = Gemma4TextConfig( + vocab_size=256, + hidden_size=64, + intermediate_size=32, + num_hidden_layers=3, + num_attention_heads=4, + num_key_value_heads=2, + num_global_key_value_heads=1, + head_dim=16, + global_head_dim=32, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=64, + rms_norm_eps=1e-6, + attention_bias=False, + attention_dropout=0.0, + attention_k_eq_v=True, + sliding_window=16, + layer_types=["sliding_attention", "sliding_attention", "full_attention"], + enable_moe_block=True, + num_experts=4, + top_k_experts=2, + expert_intermediate_size=16, + final_logit_softcapping=30.0, + hidden_size_per_layer_input=0, + num_kv_shared_layers=0, + use_double_wide_mlp=False, + use_bidirectional_attention="vision", + rope_parameters={ + "full_attention": { + "rope_type": "proportional", + "rope_theta": 1000000.0, + "partial_rotary_factor": 0.25, + }, + "sliding_attention": { + "rope_type": "default", + "rope_theta": 10000.0, + }, + }, + pad_token_id=0, + eos_token_id=1, + bos_token_id=2, + tie_word_embeddings=True, + ) + config._attn_implementation = "eager" + return config + + +def _position_ids(batch_size: int, seq_len: int, device: str) -> torch.Tensor: + return torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1) + + +@pytest.fixture(autouse=True) +def _set_seed(): + torch.manual_seed(42) + + +# --------------------------------------------------------------------------- +# Standalone HF-faithful reference implementations (pure PyTorch) +# These mirror the HuggingFace Gemma4 math exactly, using the same +# state_dict key names, so weights can be shared between AD and reference. 
+# --------------------------------------------------------------------------- + + +class _RefRMSNorm(nn.Module): + """HF Gemma4RMSNorm (transformers>=5.5): norm(x) * weight.""" + + def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True): + super().__init__() + self.eps = eps + self.with_scale = with_scale + if with_scale: + self.weight = nn.Parameter(torch.ones(dim)) + else: + self.weight = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + normed = x.float() * torch.pow(x.float().pow(2).mean(-1, keepdim=True) + self.eps, -0.5) + if self.weight is not None: + normed = normed * self.weight.float() + return normed.type_as(x) + + +def _ref_rotate_half(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def _ref_apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, udim: int = 2): + cos = cos.unsqueeze(udim) + sin = sin.unsqueeze(udim) + return (x * cos) + (_ref_rotate_half(x) * sin) + + +def _ref_repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + if n_rep == 1: + return x + b, n, s, d = x.shape + return x[:, :, None, :, :].expand(b, n, n_rep, s, d).reshape(b, n * n_rep, s, d) + + +class _RefMLP(nn.Module): + def __init__(self, config: Gemma4TextConfig): + super().__init__() + self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) + self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_activation] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class _RefAttention(nn.Module): + """HF Gemma4TextAttention reference (eager, no cache, no shared-kv).""" + + def __init__(self, config: Gemma4TextConfig, layer_idx: int): + super().__init__() + self.layer_idx = layer_idx 
+ self.is_sliding = config.layer_types[layer_idx] == "sliding_attention" + self.use_k_eq_v = config.attention_k_eq_v and not self.is_sliding + + self.head_dim = ( + config.global_head_dim + if (not self.is_sliding and config.global_head_dim) + else config.head_dim + ) + self.num_heads = config.num_attention_heads + num_kv_heads = ( + config.num_global_key_value_heads if self.use_k_eq_v else config.num_key_value_heads + ) + self.num_kv_heads = num_kv_heads + self.num_kv_groups = self.num_heads // num_kv_heads + self.scaling = 1.0 + + self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(config.hidden_size, num_kv_heads * self.head_dim, bias=False) + self.v_proj = ( + None + if self.use_k_eq_v + else nn.Linear(config.hidden_size, num_kv_heads * self.head_dim, bias=False) + ) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False) + self.q_norm = _RefRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = _RefRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.v_norm = _RefRMSNorm(self.head_dim, eps=config.rms_norm_eps, with_scale=False) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + B, S, _ = hidden_states.shape + shape = (B, S, -1, self.head_dim) + cos, sin = position_embeddings + + q = self.q_proj(hidden_states).view(shape) + q = self.q_norm(q) + q = _ref_apply_rotary(q, cos, sin, udim=2) + q = q.transpose(1, 2) # -> [B, num_heads, S, head_dim] + + k = self.k_proj(hidden_states).view(shape) + v = self.v_proj(hidden_states).view(shape) if self.v_proj is not None else k + k = self.k_norm(k) + k = _ref_apply_rotary(k, cos, sin, udim=2) + k = k.transpose(1, 2) + v = self.v_norm(v) + v = v.transpose(1, 2) + + # Eager attention with GQA repeat + k = _ref_repeat_kv(k, self.num_kv_groups) + v = _ref_repeat_kv(v, self.num_kv_groups) + 
attn_w = torch.matmul(q, k.transpose(2, 3)) * self.scaling + if attention_mask is not None: + attn_w = attn_w + attention_mask + attn_w = F.softmax(attn_w, dim=-1, dtype=torch.float32).to(q.dtype) + out = torch.matmul(attn_w, v) + out = out.transpose(1, 2).contiguous().reshape(B, S, -1) + return self.o_proj(out) + + +class _RefRouter(nn.Module): + """HF Gemma4Router reference.""" + + def __init__(self, config: Gemma4TextConfig): + super().__init__() + self.proj = nn.Linear(config.hidden_size, config.num_experts, bias=False) + self.scale = nn.Parameter(torch.ones(config.hidden_size)) + self.register_buffer("root_size", torch.tensor(config.hidden_size**-0.5), persistent=False) + self.eps = config.rms_norm_eps + self.top_k = config.top_k_experts + + def forward(self, hidden_states: torch.Tensor): + normed = hidden_states.float() + normed = normed * torch.rsqrt(normed.pow(2).mean(-1, keepdim=True) + self.eps) + normed = normed.type_as(hidden_states) + normed = ( + normed * self.root_size.to(hidden_states.dtype) * self.scale.to(hidden_states.dtype) + ) + probs = F.softmax(self.proj(normed), dim=-1) + topk_w, topk_i = torch.topk(probs, k=self.top_k, dim=-1) + topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True) + return topk_w, topk_i + + +class _RefMoEBlock(nn.Module): + """HF Gemma4MoEBlock reference with fused parameter layout.""" + + def __init__(self, config: Gemma4TextConfig): + super().__init__() + self.num_experts = config.num_experts + inter = config.expert_intermediate_size + hidden = config.hidden_size + self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, 2 * inter, hidden)) + self.down_proj = nn.Parameter(torch.zeros(self.num_experts, hidden, inter)) + self.per_expert_scale = nn.Parameter(torch.ones(self.num_experts)) + self.act_fn = ACT2FN[config.hidden_activation] + + def forward(self, hidden_states, top_k_index, top_k_weights): + final = torch.zeros_like(hidden_states) + with torch.no_grad(): + expert_mask = F.one_hot(top_k_index, 
num_classes=self.num_experts).permute(2, 1, 0) + expert_hit = expert_mask.sum(dim=(-1, -2)).nonzero() + for eidx in expert_hit: + eidx = eidx[0] + top_k_pos, token_idx = torch.where(expert_mask[eidx]) + cur = hidden_states[token_idx] + gate, up = F.linear(cur, self.gate_up_proj[eidx]).chunk(2, dim=-1) + cur = self.act_fn(gate) * up + cur = F.linear(cur, self.down_proj[eidx]) + cur = cur * self.per_expert_scale[eidx] + cur = cur * top_k_weights[token_idx, top_k_pos, None] + final.index_add_(0, token_idx, cur.to(final.dtype)) + return final + + +class _RefDecoderLayer(nn.Module): + """HF Gemma4TextDecoderLayer reference (no cache/grad-ckpt).""" + + def __init__(self, config: Gemma4TextConfig, layer_idx: int): + super().__init__() + self.self_attn = _RefAttention(config, layer_idx) + self.mlp = _RefMLP(config) + self.input_layernorm = _RefRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = _RefRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_feedforward_layernorm = _RefRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_feedforward_layernorm = _RefRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.register_buffer("layer_scalar", torch.ones(1)) + self.enable_moe_block = config.enable_moe_block + if self.enable_moe_block: + self.router = _RefRouter(config) + self.moe = _RefMoEBlock(config) + self.post_feedforward_layernorm_1 = _RefRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_feedforward_layernorm_2 = _RefRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.pre_feedforward_layernorm_2 = _RefRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward(self, hidden_states, position_embeddings, attention_mask=None): + residual = hidden_states + h = self.input_layernorm(hidden_states) + h = self.self_attn(h, position_embeddings, attention_mask=attention_mask) + h = self.post_attention_layernorm(h) + hidden_states = residual + h + + residual = 
hidden_states + if self.enable_moe_block: + h1 = self.pre_feedforward_layernorm(hidden_states) + h1 = self.mlp(h1) + h1 = self.post_feedforward_layernorm_1(h1) + h_flat = hidden_states.reshape(-1, hidden_states.shape[-1]) + topk_w, topk_i = self.router(h_flat) + h2 = self.pre_feedforward_layernorm_2(h_flat) + h2 = self.moe(h2, topk_i, topk_w) + h2 = h2.reshape(hidden_states.shape) + h2 = self.post_feedforward_layernorm_2(h2) + hidden_states = h1 + h2 + else: + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + hidden_states + hidden_states = hidden_states * self.layer_scalar + return hidden_states + + +# --------------------------------------------------------------------------- +# Weight-transfer helpers +# --------------------------------------------------------------------------- + + +def _build_ref_rope(config: Gemma4TextConfig, layer_type: str, device, dtype): + """Build reference cos/sin matching AD's Gemma4RotaryEmbedding.""" + rope = Gemma4RotaryEmbedding(config, layer_type).to(device) + return rope + + +def _load_ref_into_ad(ad_module: nn.Module, ref_module: nn.Module): + """Load reference state_dict into AD module (hooks handle weight conversion).""" + missing, unexpected = ad_module.load_state_dict(ref_module.state_dict(), strict=False) + # v_norm buffer (non-persistent) won't be in state_dict, that's expected + allowed_missing = {"v_norm.weight"} + real_missing = {k for k in missing if not any(k.endswith(s) for s in allowed_missing)} + assert not real_missing, f"Unexpected missing keys: {real_missing}" + assert not unexpected, f"Unexpected keys: {unexpected}" + + +# --------------------------------------------------------------------------- +# Tests — Block equivalence +# --------------------------------------------------------------------------- + + +def test_mlp_equivalence(): + """MLP block: identical 
math, should match exactly.""" + device, dtype = _device_and_dtype() + config = _small_text_config() + + ref = _RefMLP(config).to(device=device, dtype=dtype).eval() + ad = Gemma4TextMLP(config).to(device=device, dtype=dtype).eval() + ad.load_state_dict(ref.state_dict()) + + x = torch.randn(2, 8, config.hidden_size, device=device, dtype=dtype) + with torch.no_grad(): + torch.testing.assert_close(ad(x), ref(x), rtol=1e-3, atol=1e-3) + + +def test_attention_sliding_equivalence(): + """Sliding attention (standard GQA) matches reference.""" + device, dtype = _device_and_dtype() + config = _small_text_config() + layer_idx = 0 # sliding + + ref = _RefAttention(config, layer_idx).to(device=device, dtype=dtype).eval() + ad = Gemma4TextAttention(config, layer_idx).to(device=device, dtype=dtype).eval() + _load_ref_into_ad(ad, ref) + + B, S = 2, 8 + x = torch.randn(B, S, config.hidden_size, device=device, dtype=dtype) + pos_ids = _position_ids(B, S, device) + rope = _build_ref_rope(config, "sliding_attention", device, dtype) + cos, sin = rope(x, pos_ids) + + # Build causal mask for reference (AD uses is_causal=True internally) + causal_mask = torch.triu( + torch.full((S, S), float("-inf"), device=device, dtype=dtype), diagonal=1 + ) + causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) + + with torch.no_grad(): + ad_out = ad(x, (cos, sin)) + ref_out = ref(x, (cos, sin), attention_mask=causal_mask) + assert_rmse_close(ad_out, ref_out, rmse_ratio_tol=0.10, msg="Sliding attention: ") + + +def test_attention_full_k_eq_v_equivalence(): + """Full attention with K=V and different head_dim matches reference.""" + device, dtype = _device_and_dtype() + config = _small_text_config() + layer_idx = 2 # full_attention + + ref = _RefAttention(config, layer_idx).to(device=device, dtype=dtype).eval() + ad = Gemma4TextAttention(config, layer_idx).to(device=device, dtype=dtype).eval() + _load_ref_into_ad(ad, ref) + + B, S = 2, 8 + x = torch.randn(B, S, config.hidden_size, device=device, 
dtype=dtype) + pos_ids = _position_ids(B, S, device) + rope = _build_ref_rope(config, "full_attention", device, dtype) + cos, sin = rope(x, pos_ids) + + causal_mask = torch.triu( + torch.full((S, S), float("-inf"), device=device, dtype=dtype), diagonal=1 + ) + causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) + + with torch.no_grad(): + ad_out = ad(x, (cos, sin)) + ref_out = ref(x, (cos, sin), attention_mask=causal_mask) + assert_rmse_close(ad_out, ref_out, rmse_ratio_tol=0.10, msg="Full K=V attention: ") + + +def test_moe_block_equivalence(): + """MoE block (router + experts) matches reference with fused weight conversion.""" + device, dtype = _device_and_dtype() + config = _small_text_config() + + ref_router = _RefRouter(config).to(device=device, dtype=dtype).eval() + ref_moe = _RefMoEBlock(config).to(device=device, dtype=dtype).eval() + # Initialize MoE fused params with random values (default is zeros → all-zero output) + nn.init.normal_(ref_moe.gate_up_proj, std=0.02) + nn.init.normal_(ref_moe.down_proj, std=0.02) + nn.init.uniform_(ref_moe.per_expert_scale, 0.5, 1.5) + + ad_router = Gemma4Router(config).to(device=device, dtype=dtype).eval() + ad_moe = Gemma4MoEBlock(config).to(device=device, dtype=dtype).eval() + + # Load router weights (same structure) + ad_router.load_state_dict(ref_router.state_dict()) + # Manually unfuse ref MoE fused weights into per-expert format + # (The unfusing hook is on the decoder layer, not the MoE block) + ref_sd = ref_moe.state_dict() + gate_up = ref_sd["gate_up_proj"] # [E, 2*I, H] + down = ref_sd["down_proj"] # [E, H, I] + scale = ref_sd["per_expert_scale"] # [E] + inter = config.expert_intermediate_size + ad_sd = {} + for e in range(config.num_experts): + ad_sd[f"experts.{e}.gate_proj.weight"] = gate_up[e, :inter, :] + ad_sd[f"experts.{e}.up_proj.weight"] = gate_up[e, inter:, :] + ad_sd[f"experts.{e}.down_proj.weight"] = down[e] * scale[e] + ad_moe.load_state_dict(ad_sd) + + T = 16 # num tokens (flattened B*S) + x = 
torch.randn(T, config.hidden_size, device=device, dtype=dtype) + + with torch.no_grad(): + ref_w, ref_i = ref_router(x) + ad_w, ad_i = ad_router(x) + # Router outputs should match exactly (same math, no custom ops) + torch.testing.assert_close(ad_w, ref_w, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(ad_i, ref_i) + + ref_out = ref_moe(x, ref_i, ref_w) + ad_out = ad_moe(x, ad_i, ad_w) + + assert_rmse_close(ad_out, ref_out, rmse_ratio_tol=0.02, msg="MoE block: ") + + +# --------------------------------------------------------------------------- +# Tests — Layer equivalence +# --------------------------------------------------------------------------- + + +def test_decoder_layer_equivalence(): + """Decoder layer (sliding + full) matches reference.""" + device, dtype = _device_and_dtype() + config = _small_text_config() + + for layer_idx in [0, 2]: + layer_type = config.layer_types[layer_idx] + ref = _RefDecoderLayer(config, layer_idx).to(device=device, dtype=dtype).eval() + ad = Gemma4TextDecoderLayer(config, layer_idx).to(device=device, dtype=dtype).eval() + _load_ref_into_ad(ad, ref) + + B, S = 2, 8 + x = torch.randn(B, S, config.hidden_size, device=device, dtype=dtype) + pos_ids = _position_ids(B, S, device) + rope = _build_ref_rope(config, layer_type, device, dtype) + cos, sin = rope(x, pos_ids) + + causal_mask = ( + torch.triu(torch.full((S, S), float("-inf"), device=device, dtype=dtype), diagonal=1) + .unsqueeze(0) + .unsqueeze(0) + ) + + with torch.no_grad(): + ad_out = ad(x, (cos, sin)) + ref_out = ref(x, (cos, sin), attention_mask=causal_mask) + assert_rmse_close( + ad_out, ref_out, rmse_ratio_tol=0.05, msg=f"Layer {layer_idx} ({layer_type}): " + ) + + +# --------------------------------------------------------------------------- +# Tests — Full model equivalence +# --------------------------------------------------------------------------- + + +class _RefForCausalLM(nn.Module): + """Standalone reference CausalLM for full-model equivalence testing.""" + 
+ def __init__(self, config: Gemma4TextConfig): + super().__init__() + self.config = config + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + self.embed_scale = config.hidden_size**0.5 + self.layers = nn.ModuleList( + [_RefDecoderLayer(config, i) for i in range(config.num_hidden_layers)] + ) + self.norm = _RefRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + # Tie weights like AD model (tie_word_embeddings=True) + if config.tie_word_embeddings: + self.lm_head.weight = self.embed_tokens.weight + + def forward(self, input_ids, position_ids): + hidden_states = self.embed_tokens(input_ids) * self.embed_scale + for i, layer in enumerate(self.layers): + layer_type = self.config.layer_types[i] + rope = _build_ref_rope( + self.config, layer_type, hidden_states.device, hidden_states.dtype + ) + cos, sin = rope(hidden_states, position_ids) + causal_mask = ( + torch.triu( + torch.full( + (hidden_states.shape[1], hidden_states.shape[1]), + float("-inf"), + device=hidden_states.device, + dtype=hidden_states.dtype, + ), + diagonal=1, + ) + .unsqueeze(0) + .unsqueeze(0) + ) + hidden_states = layer(hidden_states, (cos, sin), attention_mask=causal_mask) + hidden_states = self.norm(hidden_states) + logits = self.lm_head(hidden_states) + if self.config.final_logit_softcapping is not None: + logits = logits / self.config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * self.config.final_logit_softcapping + return logits + + +def _transfer_ref_to_ad_full_model(ad_model: Gemma4ForCausalLM, ref_model: _RefForCausalLM) -> None: + """Transfer weights from reference full model into AD ForCausalLM. + + The ref uses fused MoE weights (gate_up_proj, down_proj, per_expert_scale) + while the AD model uses per-expert weights. 
The AD decoder layer's + _unfuse_moe_weights pre-hook handles this conversion automatically when + the fused keys are present in the state_dict passed to load_state_dict. + """ + ref_sd = ref_model.state_dict() + # AD ForCausalLM has flat keys (layers.0..., embed_tokens..., lm_head...) + # matching the ref layout — no prefix remapping needed. + missing, unexpected = ad_model.load_state_dict(ref_sd, strict=False) + # v_norm buffers are non-persistent, expected missing + real_missing = {m for m in missing if "v_norm" not in m} + assert not real_missing, f"Missing keys: {real_missing}" + assert not unexpected, f"Unexpected keys: {unexpected}" + + +def test_full_model_equivalence(): + """Full CausalLM logits match standalone reference with shared weights.""" + device, dtype = _device_and_dtype() + config = _small_text_config() + + ref = _RefForCausalLM(config).to(device=device, dtype=dtype).eval() + ad = Gemma4ForCausalLM(config).to(device=device, dtype=dtype).eval() + _transfer_ref_to_ad_full_model(ad, ref) + + B, S = 2, 8 + input_ids = torch.randint(0, config.vocab_size, (B, S), device=device) + pos_ids = _position_ids(B, S, device) + + with torch.no_grad(): + ref_logits = ref(input_ids, pos_ids) + ad_out = ad(input_ids=input_ids, position_ids=pos_ids) + + assert ad_out.logits.shape == (B, S, config.vocab_size) + assert torch.isfinite(ad_out.logits).all() + assert_rmse_close(ad_out.logits, ref_logits, rmse_ratio_tol=0.05, msg="Full model: ") + + +def test_conditional_generation_wrapper(): + """ConditionalGeneration wrapper loads and forwards correctly.""" + device, dtype = _device_and_dtype() + config = Gemma4Config( + text_config=_small_text_config(), + vision_config=Gemma4VisionConfig(hidden_size=32), + ) + model = Gemma4ForConditionalGeneration(config).to(device=device, dtype=dtype).eval() + + B, S = 2, 8 + input_ids = torch.randint(0, config.text_config.vocab_size, (B, S), device=device) + pos_ids = _position_ids(B, S, device) + + with torch.no_grad(): + out = 
model(input_ids=input_ids, position_ids=pos_ids) + assert out.logits is not None + assert out.logits.shape == (B, S, config.text_config.vocab_size) + assert torch.isfinite(out.logits).all() + + +# --------------------------------------------------------------------------- +# Tests — Export +# --------------------------------------------------------------------------- + + +def test_export(): + """Model can be exported with torch.export and produces correct output.""" + device = "cpu" + dtype = torch.float32 + config = _small_text_config() + config.enable_moe_block = False # MoE expert dispatch uses data-dependent ops + + model = Gemma4ForCausalLM(config).to(device=device, dtype=dtype).eval() + + B, S = 2, 8 + input_ids = torch.randint(0, config.vocab_size, (B, S), device=device) + pos_ids = _position_ids(B, S, device) + + batch_dim = Dim("batch", min=1, max=4) + seq_dim = Dim("seq", min=1, max=64) + dynamic_shapes = { + "input_ids": {0: batch_dim, 1: seq_dim}, + "position_ids": {0: batch_dim, 1: seq_dim}, + } + + gm = torch_export_to_gm( + model, + args=(input_ids,), + kwargs={"position_ids": pos_ids}, + dynamic_shapes=dynamic_shapes, + ) + + with torch.no_grad(): + pre_export_out = model(input_ids=input_ids, position_ids=pos_ids) + exported_out = gm(input_ids, position_ids=pos_ids) + + logits = ( + exported_out[0] + if isinstance(exported_out, tuple) + else getattr(exported_out, "logits", exported_out) + ) + assert torch.isfinite(logits).all(), "Export produced non-finite values" + # Exported graph should produce identical output to the original model + torch.testing.assert_close(logits, pre_export_out.logits, rtol=1e-3, atol=1e-3) + + # Test different shape + B2, S2 = 1, 4 + ids2 = torch.randint(0, config.vocab_size, (B2, S2), device=device) + pos2 = _position_ids(B2, S2, device) + with torch.no_grad(): + out2 = gm(ids2, position_ids=pos2) + logits2 = out2[0] if isinstance(out2, tuple) else getattr(out2, "logits", out2) + assert logits2.shape == (B2, S2, 
config.vocab_size) + assert torch.isfinite(logits2).all() + + +# --------------------------------------------------------------------------- +# Vision tower — helpers and small config +# --------------------------------------------------------------------------- + + +def _small_vision_config() -> Gemma4VisionConfig: + return Gemma4VisionConfig( + hidden_size=64, + intermediate_size=128, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=4, + head_dim=16, + hidden_activation="gelu_pytorch_tanh", + rms_norm_eps=1e-6, + max_position_embeddings=256, + attention_bias=False, + attention_dropout=0.0, + rope_parameters={"rope_type": "default", "rope_theta": 100.0}, + pooling_kernel_size=3, + patch_size=4, + position_embedding_size=64, + standardize=False, + ) + + +def _make_vision_inputs( + config: Gemma4VisionConfig, batch_size: int, num_patches: int, device: str, dtype: torch.dtype +): + """Create synthetic pixel_values and pixel_position_ids for the vision tower. + + num_patches should be divisible by pooling_kernel_size^2. 
+ """ + patch_dim = 3 * config.patch_size**2 + pixel_values = torch.rand(batch_size, num_patches, patch_dim, device=device, dtype=dtype) + # 2D position ids: (batch, num_patches, 2) with (x, y) grid coordinates + grid_side = int(num_patches**0.5) + pos_x = torch.arange(grid_side, device=device).unsqueeze(1).expand(grid_side, grid_side) + pos_y = torch.arange(grid_side, device=device).unsqueeze(0).expand(grid_side, grid_side) + positions = torch.stack([pos_x.flatten(), pos_y.flatten()], dim=-1) # [num_patches, 2] + pixel_position_ids = positions.unsqueeze(0).expand(batch_size, -1, -1).long() + return pixel_values, pixel_position_ids + + +# --------------------------------------------------------------------------- +# Vision tower — standalone HF-faithful reference implementations +# --------------------------------------------------------------------------- + + +class _RefVisionMLP(nn.Module): + """HF Gemma4VisionMLP reference — plain nn.Linear (no ClippableLinear wrapper).""" + + def __init__(self, config: Gemma4VisionConfig): + super().__init__() + self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) + self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_activation] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class _RefVisionRotaryEmbedding(nn.Module): + """HF Gemma4VisionRotaryEmbedding reference.""" + + def __init__(self, config: Gemma4VisionConfig): + super().__init__() + rope_theta = config.rope_parameters["rope_theta"] + spatial_dim = config.head_dim // 2 + inv_freq = 1.0 / ( + rope_theta ** (torch.arange(0, spatial_dim, 2, dtype=torch.float32) / spatial_dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward( + self, hidden_states: torch.Tensor, 
position_ids: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + inv_freq_expanded = ( + self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + ) + all_cos = [] + all_sin = [] + for dim_idx in range(2): + dim_pos = position_ids[:, None, :, dim_idx].float().to(hidden_states.device) + freqs = (inv_freq_expanded.to(hidden_states.device) @ dim_pos).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + all_cos.append(emb.cos()) + all_sin.append(emb.sin()) + cos = torch.cat(all_cos, dim=-1).to(dtype=hidden_states.dtype, device=hidden_states.device) + sin = torch.cat(all_sin, dim=-1).to(dtype=hidden_states.dtype, device=hidden_states.device) + return cos, sin + + +def _ref_vision_apply_multidimensional_rope( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: torch.Tensor, + unsqueeze_dim: int = 2, +) -> torch.Tensor: + """HF apply_multidimensional_rope reference.""" + ndim = position_ids.shape[-1] + num_channels = x.shape[-1] + num_rotated_per_dim = 2 * (num_channels // (2 * ndim)) + split_sizes = [num_rotated_per_dim] * ndim + x_parts = torch.split(x, split_sizes, dim=-1) + cos_parts = torch.split(cos, split_sizes, dim=-1) + sin_parts = torch.split(sin, split_sizes, dim=-1) + outputs = [] + for idx in range(ndim): + c = cos_parts[idx].unsqueeze(unsqueeze_dim) + s = sin_parts[idx].unsqueeze(unsqueeze_dim) + outputs.append((x_parts[idx] * c) + (_ref_rotate_half(x_parts[idx]) * s)) + return torch.cat(outputs, dim=-1) + + +class _RefVisionAttention(nn.Module): + """HF Gemma4VisionAttention reference (eager, no cache).""" + + def __init__(self, config: Gemma4VisionConfig): + super().__init__() + self.head_dim = config.head_dim + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.num_kv_groups = self.num_heads // self.num_kv_heads + self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(config.hidden_size, 
self.num_kv_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + self.q_norm = _RefRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = _RefRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.v_norm = _RefRMSNorm(self.head_dim, eps=config.rms_norm_eps, with_scale=False) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + B, S, _ = hidden_states.shape + shape = (B, S, -1, self.head_dim) + cos, sin = position_embeddings + + q = self.q_norm(self.q_proj(hidden_states).view(shape)) + q = _ref_vision_apply_multidimensional_rope(q, cos, sin, torch.zeros_like(cos[..., :2])) + q = q.transpose(1, 2) + + k = self.k_norm(self.k_proj(hidden_states).view(shape)) + k = _ref_vision_apply_multidimensional_rope(k, cos, sin, torch.zeros_like(cos[..., :2])) + k = k.transpose(1, 2) + + v = self.v_norm(self.v_proj(hidden_states).view(shape)) + v = v.transpose(1, 2) + + # GQA repeat + k = _ref_repeat_kv(k, self.num_kv_groups) + v = _ref_repeat_kv(v, self.num_kv_groups) + + attn_w = torch.matmul(q, k.transpose(2, 3)) # scaling=1.0 for vision + if attention_mask is not None: + invalid = torch.finfo(attn_w.dtype).min + attn_w = attn_w.masked_fill(attention_mask.logical_not(), invalid) + attn_w = F.softmax(attn_w, dim=-1, dtype=torch.float32).to(q.dtype) + out = torch.matmul(attn_w, v) + out = out.transpose(1, 2).contiguous().reshape(B, S, -1) + return self.o_proj(out), attn_w + + +class _RefVisionEncoderLayer(nn.Module): + """HF Gemma4VisionEncoderLayer reference.""" + + def __init__(self, config: Gemma4VisionConfig, layer_idx: int): + super().__init__() + del layer_idx + self.self_attn = _RefVisionAttention(config) + 
self.mlp = _RefVisionMLP(config) + self.input_layernorm = _RefRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = _RefRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_feedforward_layernorm = _RefRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_feedforward_layernorm = _RefRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + ) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + return residual + hidden_states + + +class _RefVisionEncoder(nn.Module): + """HF Gemma4VisionEncoder reference.""" + + def __init__(self, config: Gemma4VisionConfig): + super().__init__() + self.rotary_emb = _RefVisionRotaryEmbedding(config) + self.layers = nn.ModuleList( + [_RefVisionEncoderLayer(config, i) for i in range(config.num_hidden_layers)] + ) + + def forward( + self, + inputs_embeds: torch.Tensor, + attention_mask: torch.Tensor, + pixel_position_ids: torch.LongTensor, + ): + valid = attention_mask.to(torch.bool) + attention_mask_4d = valid[:, None, :, None] & valid[:, None, None, :] + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, pixel_position_ids) + for layer in self.layers: + hidden_states = layer( + hidden_states, + 
attention_mask=attention_mask_4d, + position_embeddings=position_embeddings, + position_ids=pixel_position_ids, + ) + return hidden_states + + +class _RefVisionPatchEmbedder(nn.Module): + """HF Gemma4VisionPatchEmbedder reference.""" + + def __init__(self, config: Gemma4VisionConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.patch_size = config.patch_size + self.position_embedding_size = config.position_embedding_size + self.input_proj = nn.Linear(3 * self.patch_size**2, self.hidden_size, bias=False) + self.position_embedding_table = nn.Parameter( + torch.ones(2, self.position_embedding_size, self.hidden_size) + ) + + def _position_embeddings( + self, pixel_position_ids: torch.Tensor, padding_positions: torch.Tensor + ) -> torch.Tensor: + clamped = pixel_position_ids.clamp(min=0) + one_hot = F.one_hot(clamped, num_classes=self.position_embedding_size) + one_hot = one_hot.permute(0, 2, 1, 3).to(self.position_embedding_table) + pos_emb = one_hot @ self.position_embedding_table + pos_emb = pos_emb.sum(dim=1) + return torch.where(padding_positions.unsqueeze(-1), 0.0, pos_emb) + + def forward(self, pixel_values, pixel_position_ids, padding_positions): + pixel_values = 2 * (pixel_values - 0.5) + hidden_states = self.input_proj(pixel_values.to(self.input_proj.weight.dtype)) + pos_emb = self._position_embeddings(pixel_position_ids, padding_positions) + return hidden_states + pos_emb + + +class _RefVisionPooler(nn.Module): + """HF Gemma4VisionPooler reference.""" + + def __init__(self, config: Gemma4VisionConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.root_hidden_size = self.hidden_size**0.5 + + def _avg_pool_by_positions(self, hidden_states, pixel_position_ids, length): + input_seq_len = hidden_states.shape[1] + kernel_size = int((input_seq_len // length) ** 0.5) + clamped = pixel_position_ids.clamp(min=0) + max_x = clamped[..., 0].max(dim=-1, keepdim=True)[0] + 1 + kernel_indices = torch.div(clamped, kernel_size, 
rounding_mode="floor") + kernel_indices = kernel_indices[..., 0] + (max_x // kernel_size) * kernel_indices[..., 1] + weights = F.one_hot(kernel_indices.long(), length).float() / (kernel_size**2) + output = weights.transpose(1, 2) @ hidden_states.float() + mask = torch.logical_not((weights == 0).all(dim=1)) + return output.to(hidden_states.dtype), mask + + def forward(self, hidden_states, pixel_position_ids, padding_positions, output_length=None): + if output_length is None: + output_length = hidden_states.shape[1] + hidden_states = hidden_states.masked_fill(padding_positions.unsqueeze(-1), 0.0) + if hidden_states.shape[1] != output_length: + hidden_states, padding_positions = self._avg_pool_by_positions( + hidden_states, pixel_position_ids, output_length + ) + hidden_states *= self.root_hidden_size + return hidden_states, padding_positions + + +class _RefVisionModel(nn.Module): + """HF Gemma4VisionModel reference (full pipeline).""" + + def __init__(self, config: Gemma4VisionConfig): + super().__init__() + self.config = config + self.patch_embedder = _RefVisionPatchEmbedder(config) + self.encoder = _RefVisionEncoder(config) + self.pooler = _RefVisionPooler(config) + + def forward(self, pixel_values, pixel_position_ids): + pooling_kernel_size = self.config.pooling_kernel_size + output_length = pixel_values.shape[-2] // (pooling_kernel_size * pooling_kernel_size) + padding_positions = (pixel_position_ids == -1).all(dim=-1) + inputs_embeds = self.patch_embedder(pixel_values, pixel_position_ids, padding_positions) + hidden_states = self.encoder( + inputs_embeds=inputs_embeds, + attention_mask=~padding_positions, + pixel_position_ids=pixel_position_ids, + ) + hidden_states, pooler_mask = self.pooler( + hidden_states=hidden_states, + pixel_position_ids=pixel_position_ids, + padding_positions=padding_positions, + output_length=output_length, + ) + return hidden_states[pooler_mask] + + +class _RefMultimodalEmbedder(nn.Module): + """HF Gemma4MultimodalEmbedder reference.""" 
+ + def __init__(self, vision_config: Gemma4VisionConfig, text_config: Gemma4TextConfig): + super().__init__() + self.eps = vision_config.rms_norm_eps + self.embedding_projection = nn.Linear( + vision_config.hidden_size, text_config.hidden_size, bias=False + ) + self.embedding_pre_projection_norm = _RefRMSNorm( + vision_config.hidden_size, eps=self.eps, with_scale=False + ) + + def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor: + return self.embedding_projection(self.embedding_pre_projection_norm(inputs_embeds)) + + +# --------------------------------------------------------------------------- +# Vision tower — weight transfer helpers +# --------------------------------------------------------------------------- + + +def _transfer_vision_mlp_weights(ad_mlp: Gemma4VisionMLP, ref_mlp: _RefVisionMLP): + """Transfer weights from ref MLP (plain nn.Linear) to AD MLP (ClippableLinear).""" + ad_mlp.gate_proj.linear.weight.data.copy_(ref_mlp.gate_proj.weight.data) + ad_mlp.up_proj.linear.weight.data.copy_(ref_mlp.up_proj.weight.data) + ad_mlp.down_proj.linear.weight.data.copy_(ref_mlp.down_proj.weight.data) + + +def _transfer_vision_attn_weights(ad_attn: Gemma4VisionAttention, ref_attn: _RefVisionAttention): + """Transfer weights from ref attention to AD attention (ClippableLinear + canonical norms).""" + ad_attn.q_proj.linear.weight.data.copy_(ref_attn.q_proj.weight.data) + ad_attn.k_proj.linear.weight.data.copy_(ref_attn.k_proj.weight.data) + ad_attn.v_proj.linear.weight.data.copy_(ref_attn.v_proj.weight.data) + ad_attn.o_proj.linear.weight.data.copy_(ref_attn.o_proj.weight.data) + ad_attn.q_norm.weight.data.copy_(ref_attn.q_norm.weight.data) + ad_attn.k_norm.weight.data.copy_(ref_attn.k_norm.weight.data) + # v_norm has no learnable scale (with_scale=False), but AD uses a buffer + # The ref also uses with_scale=False so no weight to copy + + +def _transfer_vision_encoder_layer_weights( + ad_layer: Gemma4VisionEncoderLayer, ref_layer: _RefVisionEncoderLayer 
+): + """Transfer all weights for a single encoder layer.""" + _transfer_vision_attn_weights(ad_layer.self_attn, ref_layer.self_attn) + _transfer_vision_mlp_weights(ad_layer.mlp, ref_layer.mlp) + for norm_name in [ + "input_layernorm", + "post_attention_layernorm", + "pre_feedforward_layernorm", + "post_feedforward_layernorm", + ]: + getattr(ad_layer, norm_name).weight.data.copy_(getattr(ref_layer, norm_name).weight.data) + + +def _transfer_vision_encoder_weights( + ad_encoder: Gemma4VisionEncoder, ref_encoder: _RefVisionEncoder +): + """Transfer encoder weights including RoPE and all layers.""" + ad_encoder.rotary_emb.inv_freq.data.copy_(ref_encoder.rotary_emb.inv_freq.data) + for ad_layer, ref_layer in zip(ad_encoder.layers, ref_encoder.layers): + _transfer_vision_encoder_layer_weights(ad_layer, ref_layer) + + +def _transfer_vision_patch_embedder_weights( + ad_pe: Gemma4VisionPatchEmbedder, ref_pe: _RefVisionPatchEmbedder +): + """Transfer patch embedder weights.""" + ad_pe.input_proj.weight.data.copy_(ref_pe.input_proj.weight.data) + ad_pe.position_embedding_table.data.copy_(ref_pe.position_embedding_table.data) + + +def _transfer_vision_model_weights(ad_model: Gemma4VisionModel, ref_model: _RefVisionModel): + """Transfer all vision model weights.""" + _transfer_vision_patch_embedder_weights(ad_model.patch_embedder, ref_model.patch_embedder) + _transfer_vision_encoder_weights(ad_model.encoder, ref_model.encoder) + # Pooler has no learnable parameters + + +def _transfer_multimodal_embedder_weights( + ad_emb: Gemma4MultimodalEmbedder, ref_emb: _RefMultimodalEmbedder +): + """Transfer multimodal embedder weights.""" + ad_emb.embedding_projection.weight.data.copy_(ref_emb.embedding_projection.weight.data) + # Pre-projection norm has with_scale=False on both sides, no weight to copy + + +# --------------------------------------------------------------------------- +# Tests — Vision tower block equivalence +# 
--------------------------------------------------------------------------- + + +def test_vision_mlp_equivalence(): + """Vision MLP: identical gated-GELU (GeGLU) math, should match exactly.""" + device, dtype = _device_and_dtype() + config = _small_vision_config() + + ref = _RefVisionMLP(config).to(device=device, dtype=dtype).eval() + ad = Gemma4VisionMLP(config).to(device=device, dtype=dtype).eval() + _transfer_vision_mlp_weights(ad, ref) + + x = torch.randn(2, 9, config.hidden_size, device=device, dtype=dtype) + with torch.no_grad(): + torch.testing.assert_close(ad(x), ref(x), rtol=1e-3, atol=1e-3) + + +def test_vision_rotary_embedding_equivalence(): + """Vision RoPE: multidimensional cos/sin should match reference exactly.""" + device, dtype = _device_and_dtype() + config = _small_vision_config() + + ref_rope = _RefVisionRotaryEmbedding(config).to(device=device) + ad_rope = Gemma4VisionRotaryEmbedding(config).to(device=device) + + B, S = 2, 9 + hidden = torch.randn(B, S, config.hidden_size, device=device, dtype=dtype) + # 2D position ids + grid_side = 3 + pos_x = torch.arange(grid_side, device=device).unsqueeze(1).expand(grid_side, grid_side) + pos_y = torch.arange(grid_side, device=device).unsqueeze(0).expand(grid_side, grid_side) + pos_ids = torch.stack([pos_x.flatten(), pos_y.flatten()], dim=-1).unsqueeze(0).expand(B, -1, -1) + + with torch.no_grad(): + ref_cos, ref_sin = ref_rope(hidden, pos_ids) + ad_cos, ad_sin = ad_rope(hidden, pos_ids) + torch.testing.assert_close(ad_cos, ref_cos, rtol=1e-5, atol=1e-5) + torch.testing.assert_close(ad_sin, ref_sin, rtol=1e-5, atol=1e-5) + + +def test_vision_patch_embedder_equivalence(): + """Patch embedder: linear projection + position embeddings.""" + device, dtype = _device_and_dtype() + config = _small_vision_config() + + ref = _RefVisionPatchEmbedder(config).to(device=device, dtype=dtype).eval() + ad = Gemma4VisionPatchEmbedder(config).to(device=device, dtype=dtype).eval() + _transfer_vision_patch_embedder_weights(ad, ref) + + 
B, num_patches = 2, 9 + pixel_values, pixel_position_ids = _make_vision_inputs(config, B, num_patches, device, dtype) + padding_positions = (pixel_position_ids == -1).all(dim=-1) + + with torch.no_grad(): + ref_out = ref(pixel_values, pixel_position_ids, padding_positions) + ad_out = ad(pixel_values, pixel_position_ids, padding_positions) + torch.testing.assert_close(ad_out, ref_out, rtol=1e-3, atol=1e-3) + + +def test_image_processor_pads_to_fixed_patch_budget(): + """Image processor should pad every request to the configured patch budget.""" + config = _small_vision_config() + processor = ADGemma4ImageProcessor( + patch_size=config.patch_size, + max_soft_tokens=280, + pooling_kernel_size=config.pooling_kernel_size, + do_resize=False, + do_rescale=False, + do_normalize=False, + ) + + image_small = torch.zeros(3, 12, 12, dtype=torch.float32) + image_large = torch.zeros(3, 12, 24, dtype=torch.float32) + + outputs = processor([image_small, image_large]) + + target_patches = 280 * config.pooling_kernel_size**2 + assert outputs["pixel_values"].shape == (2, target_patches, 3 * config.patch_size**2) + assert outputs["image_position_ids"].shape == (2, target_patches, 2) + assert outputs["num_soft_tokens_per_image"] == [1, 2] + assert torch.all(outputs["image_position_ids"][0, 9:] == -1) + assert torch.all(outputs["image_position_ids"][1, 18:] == -1) + assert torch.all(outputs["image_position_ids"][1, :18] >= 0) + + +def test_ad_input_processor_emits_layout_metadata_for_boi_eoi_spans(): + class _DummyBaseProcessor: + def __init__(self): + self.processor = SimpleNamespace( + image_processor=lambda images, **kwargs: { + "num_soft_tokens_per_image": torch.tensor([260], dtype=torch.int32) + } + ) + self.tokenizer = SimpleNamespace(vocab_size=1024) + + def __call__(self, inputs, sampling_params): + del inputs, sampling_params + return [7, 255999, 258880, 258880, 258882, 9], { + "multimodal_data": { + "token_type_ids": torch.tensor([0, 1, 1, 1, 1, 0], dtype=torch.int32) + } + } 
+ + processor = Gemma4ADInputProcessor( + _DummyBaseProcessor(), + image_token_id=258880, + boi_token_id=255999, + eoi_token_id=258882, + ) + + token_ids, extra = processor(inputs={}, sampling_params=None) + + assert token_ids == [7, 255999, 258880, 258880, 258882, 9] + assert processor.get_num_tokens_per_image(image=torch.zeros(3, 8, 8)) == 262 + torch.testing.assert_close( + processor.get_mm_token_ids(), torch.tensor([258880], dtype=torch.int32) + ) + torch.testing.assert_close( + processor.get_mm_special_token_ids(), torch.tensor([255999, 258882], dtype=torch.int32) + ) + + multimodal_input = extra["multimodal_input"] + assert multimodal_input.multimodal_positions == [1] + assert multimodal_input.multimodal_lengths == [4] + + multimodal_data = extra["multimodal_data"] + assert "token_type_ids" not in multimodal_data + torch.testing.assert_close( + multimodal_data["layout_metadata"]["special_token_offsets"], + torch.tensor([0, 3], dtype=torch.int32), + ) + torch.testing.assert_close( + multimodal_data["layout_metadata"]["item_types"], + torch.tensor([0], dtype=torch.int32), + ) + + +def test_vision_pooler_equivalence(): + """Vision pooler: avg pooling + scaling.""" + device, dtype = _device_and_dtype() + config = _small_vision_config() + + ref = _RefVisionPooler(config).to(device=device, dtype=dtype).eval() + ad = Gemma4VisionPooler(config).to(device=device, dtype=dtype).eval() + # Pooler has no learnable params + + B, num_patches = 2, 9 # 9 patches → pool to 1 with kernel_size=3 + hidden = torch.randn(B, num_patches, config.hidden_size, device=device, dtype=dtype) + _, pixel_position_ids = _make_vision_inputs(config, B, num_patches, device, dtype) + padding_positions = torch.zeros(B, num_patches, device=device, dtype=torch.bool) + output_length = num_patches // (config.pooling_kernel_size**2) + + with torch.no_grad(): + ref_h, ref_mask = ref(hidden, pixel_position_ids, padding_positions, output_length) + ad_h, ad_mask = ad(hidden, pixel_position_ids, 
padding_positions, output_length) + torch.testing.assert_close(ad_h, ref_h, rtol=1e-3, atol=1e-3) + assert (ad_mask == ref_mask).all() + + +def test_vision_attention_equivalence(): + """Vision attention: bidirectional, multidimensional RoPE, scaling=1.0.""" + device, dtype = _device_and_dtype() + config = _small_vision_config() + + ref = _RefVisionAttention(config).to(device=device, dtype=dtype).eval() + ad = Gemma4VisionAttention(config, layer_idx=0).to(device=device, dtype=dtype).eval() + _transfer_vision_attn_weights(ad, ref) + + B, S = 2, 9 + x = torch.randn(B, S, config.hidden_size, device=device, dtype=dtype) + # 2D position ids for vision + grid_side = 3 + pos_x = torch.arange(grid_side, device=device).unsqueeze(1).expand(grid_side, grid_side) + pos_y = torch.arange(grid_side, device=device).unsqueeze(0).expand(grid_side, grid_side) + pos_ids = torch.stack([pos_x.flatten(), pos_y.flatten()], dim=-1).unsqueeze(0).expand(B, -1, -1) + + rope = Gemma4VisionRotaryEmbedding(config).to(device=device) + with torch.no_grad(): + cos, sin = rope(x, pos_ids) + + # Bidirectional mask (all True) + attn_mask = torch.ones(B, 1, S, S, device=device, dtype=torch.bool) + + with torch.no_grad(): + ad_out, _ = ad(x, (cos, sin), attention_mask=attn_mask, position_ids=pos_ids) + ref_out, _ = ref(x, (cos, sin), attention_mask=attn_mask, position_ids=pos_ids) + assert_rmse_close(ad_out, ref_out, rmse_ratio_tol=0.10, msg="Vision attention: ") + + +# --------------------------------------------------------------------------- +# Tests — Vision tower layer equivalence +# --------------------------------------------------------------------------- + + +def test_vision_encoder_layer_equivalence(): + """Vision encoder layer matches reference.""" + device, dtype = _device_and_dtype() + config = _small_vision_config() + + ref = _RefVisionEncoderLayer(config, layer_idx=0).to(device=device, dtype=dtype).eval() + ad = Gemma4VisionEncoderLayer(config, layer_idx=0).to(device=device, 
dtype=dtype).eval() + _transfer_vision_encoder_layer_weights(ad, ref) + + B, S = 2, 9 + x = torch.randn(B, S, config.hidden_size, device=device, dtype=dtype) + grid_side = 3 + pos_x = torch.arange(grid_side, device=device).unsqueeze(1).expand(grid_side, grid_side) + pos_y = torch.arange(grid_side, device=device).unsqueeze(0).expand(grid_side, grid_side) + pos_ids = torch.stack([pos_x.flatten(), pos_y.flatten()], dim=-1).unsqueeze(0).expand(B, -1, -1) + + rope = Gemma4VisionRotaryEmbedding(config).to(device=device) + with torch.no_grad(): + cos, sin = rope(x, pos_ids) + + attn_mask = torch.ones(B, 1, S, S, device=device, dtype=torch.bool) + + with torch.no_grad(): + ad_out = ad(x, (cos, sin), attention_mask=attn_mask, position_ids=pos_ids) + ref_out = ref(x, (cos, sin), attention_mask=attn_mask, position_ids=pos_ids) + assert_rmse_close(ad_out, ref_out, rmse_ratio_tol=0.05, msg="Vision encoder layer: ") + + +# --------------------------------------------------------------------------- +# Tests — Vision tower full model equivalence +# --------------------------------------------------------------------------- + + +def test_vision_encoder_equivalence(): + """Full vision encoder (all layers + RoPE) matches reference.""" + device, dtype = _device_and_dtype() + config = _small_vision_config() + + ref = _RefVisionEncoder(config).to(device=device, dtype=dtype).eval() + ad = Gemma4VisionEncoder(config).to(device=device, dtype=dtype).eval() + _transfer_vision_encoder_weights(ad, ref) + + B, num_patches = 2, 9 + x = torch.randn(B, num_patches, config.hidden_size, device=device, dtype=dtype) + grid_side = 3 + pos_x = torch.arange(grid_side, device=device).unsqueeze(1).expand(grid_side, grid_side) + pos_y = torch.arange(grid_side, device=device).unsqueeze(0).expand(grid_side, grid_side) + pos_ids = torch.stack([pos_x.flatten(), pos_y.flatten()], dim=-1).unsqueeze(0).expand(B, -1, -1) + attn_mask = torch.ones(B, num_patches, device=device, dtype=torch.bool) + + with 
torch.no_grad(): + ad_out = ad(inputs_embeds=x, attention_mask=attn_mask, pixel_position_ids=pos_ids) + ref_out = ref(inputs_embeds=x, attention_mask=attn_mask, pixel_position_ids=pos_ids) + # ad_out is ModelOutput, ref_out is tensor + ad_hidden = ad_out.last_hidden_state + assert_rmse_close(ad_hidden, ref_out, rmse_ratio_tol=0.05, msg="Vision encoder: ") + + +def test_vision_model_equivalence(): + """Full vision model (embedder + encoder + pooler) matches reference.""" + device, dtype = _device_and_dtype() + config = _small_vision_config() + + ref = _RefVisionModel(config).to(device=device, dtype=dtype).eval() + ad = Gemma4VisionModel(config).to(device=device, dtype=dtype).eval() + _transfer_vision_model_weights(ad, ref) + + B, num_patches = 2, 9 + pixel_values, pixel_position_ids = _make_vision_inputs(config, B, num_patches, device, dtype) + + with torch.no_grad(): + ad_out = ad(pixel_values=pixel_values, pixel_position_ids=pixel_position_ids) + ref_out = ref(pixel_values=pixel_values, pixel_position_ids=pixel_position_ids) + ad_hidden = ad_out.last_hidden_state + assert_rmse_close(ad_hidden, ref_out, rmse_ratio_tol=0.05, msg="Vision model: ") + + +def test_multimodal_embedder_equivalence(): + """Multimodal embedder (norm + projection) matches reference.""" + device, dtype = _device_and_dtype() + vision_config = _small_vision_config() + text_config = _small_text_config() + + ref = _RefMultimodalEmbedder(vision_config, text_config).to(device=device, dtype=dtype).eval() + ad = Gemma4MultimodalEmbedder(vision_config, text_config).to(device=device, dtype=dtype).eval() + _transfer_multimodal_embedder_weights(ad, ref) + + x = torch.randn(2, 4, vision_config.hidden_size, device=device, dtype=dtype) + with torch.no_grad(): + ad_out = ad(x) + ref_out = ref(x) + torch.testing.assert_close(ad_out, ref_out, rtol=1e-3, atol=1e-3) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_shared_kv_attention.py 
b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_shared_kv_attention.py new file mode 100644 index 00000000000..4a65cc2de80 --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_shared_kv_attention.py @@ -0,0 +1,660 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 +from tensorrt_llm._torch.auto_deploy.compile.piecewise_utils import is_dynamic_cached_op +from tensorrt_llm._torch.auto_deploy.custom_ops.attention.flashinfer_attention import ( + FlashInferAttention, +) +from tensorrt_llm._torch.auto_deploy.custom_ops.attention.torch_backend_attention import ( + TorchBackendAttention, +) +from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import BatchInfo +from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface +from tensorrt_llm._torch.auto_deploy.transform.interface import SharedConfig, Stages +from tensorrt_llm._torch.auto_deploy.transform.library.kvcache import ( + InsertCachedAttentionConfig, + _InsertCachedOperator, +) + + +class _TinySharedKVModule(torch.nn.Module): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + qkv = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], 2, 4) + regular = torch.ops.auto_deploy.torch_attention( + qkv, + qkv, + qkv, + None, + 0.0, + True, + 1.0, + None, + None, + None, + "bsnd", + 0, + ) + shared = torch.ops.auto_deploy.torch_attention( + qkv, + qkv, + qkv, + None, + 0.0, + True, + 1.0, + None, + None, + None, + "bsnd", + 1, + 0, + ) + return regular + shared + + +class _DuplicateLayerOwnerSharedKVModule(torch.nn.Module): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + qkv = 
hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], 2, 4) + first = torch.ops.auto_deploy.torch_attention( + qkv, qkv, qkv, None, 0.0, True, 1.0, None, None, None, "bsnd", 0 + ) + second = torch.ops.auto_deploy.torch_attention( + qkv, qkv, qkv, None, 0.0, True, 1.0, None, None, None, "bsnd", 0 + ) + return first + second + + +def _context_meta(seq_len: int): + batch_info_host = BatchInfo() + batch_info_host.update([1, seq_len, 0, 0, 0, 0]) + return ( + batch_info_host.serialize(), + torch.tensor([seq_len], dtype=torch.int32), + torch.tensor([0], dtype=torch.int32), + torch.tensor([0], dtype=torch.int64), + torch.tensor([0], dtype=torch.int32), + ) + + +def _decode_meta(input_pos: int): + batch_info_host = BatchInfo() + batch_info_host.update([0, 0, 0, 0, 1, 1]) + return ( + batch_info_host.serialize(), + torch.tensor([1], dtype=torch.int32), + torch.tensor([input_pos], dtype=torch.int32), + torch.tensor([0], dtype=torch.int64), + torch.tensor([0], dtype=torch.int32), + ) + + +def _manual_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sliding_window: int | None = None, +) -> torch.Tensor: + batch, seq_len_q, num_heads, _ = q.shape + _, seq_len_k, num_kv_heads, _ = k.shape + if num_heads != num_kv_heads: + repeat_factor = num_heads // num_kv_heads + k = k.repeat_interleave(repeat_factor, dim=2) + v = v.repeat_interleave(repeat_factor, dim=2) + + q_t = q.transpose(1, 2) + k_t = k.transpose(1, 2) + v_t = v.transpose(1, 2) + scores = torch.matmul(q_t, k_t.transpose(-2, -1)) + causal_mask = torch.triu( + torch.ones(seq_len_q, seq_len_k, dtype=torch.bool, device=scores.device), + diagonal=seq_len_k - seq_len_q + 1, + ) + scores = scores.masked_fill(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf")) + if sliding_window is not None: + query_positions = torch.arange(seq_len_k - seq_len_q, seq_len_k, device=scores.device) + key_positions = torch.arange(seq_len_k, device=scores.device) + pos_diff = query_positions.unsqueeze(1) - 
key_positions.unsqueeze(0) + sliding_window_mask = (pos_diff < 0) | (pos_diff >= sliding_window) + scores = scores.masked_fill(sliding_window_mask.unsqueeze(0).unsqueeze(0), float("-inf")) + weights = torch.softmax(scores, dim=-1) + return torch.matmul(weights, v_t).transpose(1, 2) + + +def _make_layer_inputs(offset: float, seq_len: int, decode: bool = False): + base_q = torch.tensor( + [[[1.0, 0.0], [0.0, 1.0]]] + if decode + else [ + [[1.0, 0.0], [0.0, 1.0]], + [[0.5, 0.5], [0.5, -0.5]], + [[0.25, 0.75], [0.75, 0.25]], + ], + dtype=torch.float32, + ) + base_k = torch.tensor( + [[[1.0, 0.0]]] if decode else [[[1.0, 0.0]], [[0.0, 1.0]], [[1.0, 1.0]]], + dtype=torch.float32, + ) + base_v = torch.tensor( + [[[10.0, 1.0]]] if decode else [[[10.0, 1.0]], [[2.0, 20.0]], [[30.0, 3.0]]], + dtype=torch.float32, + ) + q = (base_q + offset).unsqueeze(0) + k = (base_k + offset).unsqueeze(0) + v = (base_v + offset * 10.0).unsqueeze(0) + assert q.shape[1] == seq_len + return q, k, v + + +def test_shared_kv_transform_aliases_source_cache_placeholders(): + module = _TinySharedKVModule().eval() + gm = torch_export_to_gm(module, (torch.randn(1, 4, 8),)) + + cm = CachedSequenceInterface( + max_seq_len=16, + max_batch_size=2, + max_num_tokens=16, + device="cpu", + ) + transform = _InsertCachedOperator( + InsertCachedAttentionConfig(stage=Stages.CACHE_INIT, backend="torch") + ) + gm, info = transform._apply(gm, cm, factory=None, shared_config=SharedConfig()) + + assert info.num_matches == 2 + + placeholder_names = [node.target for node in gm.graph.nodes if node.op == "placeholder"] + assert placeholder_names.count("r0_k_cache") == 1 + assert placeholder_names.count("r1_v_cache") == 1 + assert "r2_k_cache" not in placeholder_names + assert "r3_v_cache" not in placeholder_names + assert set(cm._resource_lookup).issubset(set(placeholder_names)) + + cached_nodes = [node for node in gm.graph.nodes if node.op == "call_function"] + regular_node = next( + node + for node in cached_nodes + if 
node.target == torch.ops.auto_deploy.torch_cached_attention_with_cache.default + and node.args[-1] is False + ) + shared_node = next( + node + for node in cached_nodes + if node.target == torch.ops.auto_deploy.torch_cached_attention_with_cache.default + and node.args[-1] is True + ) + + assert regular_node.args[8] is shared_node.args[8] + assert regular_node.args[9] is shared_node.args[9] + assert regular_node.target == torch.ops.auto_deploy.torch_cached_attention_with_cache.default + assert shared_node.target == torch.ops.auto_deploy.torch_cached_attention_with_cache.default + assert regular_node.args[-1] is False + assert shared_node.args[-1] is True + + +def test_shared_kv_cached_attention_reads_without_writing(): + q = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]], dtype=torch.float32) + dummy_k = torch.full((1, 1, 2, 2), 123.0, dtype=torch.float32) + dummy_v = torch.full((1, 1, 2, 2), -456.0, dtype=torch.float32) + + k_cache = torch.tensor( + [[[[1.0, 0.0], [0.0, 1.0]], [[0.5, 0.0], [0.0, 0.5]], [[0.25, 0.0], [0.0, 0.25]]]], + dtype=torch.float32, + ) + v_cache = torch.tensor( + [[[[10.0, 1.0], [2.0, 20.0]], [[30.0, 3.0], [4.0, 40.0]], [[50.0, 5.0], [6.0, 60.0]]]], + dtype=torch.float32, + ) + k_cache_before = k_cache.clone() + v_cache_before = v_cache.clone() + + batch_info_host = BatchInfo() + batch_info_host.update([0, 0, 0, 0, 1, 1]) + output = torch.ops.auto_deploy.torch_cached_attention_with_cache( + q, + dummy_k, + dummy_v, + batch_info_host.serialize(), + torch.tensor([1], dtype=torch.int32), + torch.tensor([2], dtype=torch.int32), + torch.tensor([0], dtype=torch.int64), + torch.tensor([0], dtype=torch.int32), + k_cache, + v_cache, + 1.0, + None, + None, + None, + True, + ) + + assert torch.equal(k_cache, k_cache_before) + assert torch.equal(v_cache, v_cache_before) + + k_for_attn = k_cache_before[0, :3].transpose(0, 1) + v_for_attn = v_cache_before[0, :3].transpose(0, 1) + logits = torch.matmul(q[0, 0].unsqueeze(1), k_for_attn.transpose(-2, -1)) + weights 
= torch.softmax(logits, dim=-1) + expected = torch.matmul(weights, v_for_attn).squeeze(1).unsqueeze(0).unsqueeze(0) + torch.testing.assert_close(output, expected, rtol=1e-5, atol=1e-5) + + +def test_torch_backend_attention_metadata_for_shared_kv_node(): + module = _TinySharedKVModule().eval() + gm = torch_export_to_gm(module, (torch.randn(1, 4, 8),)) + source_nodes = [ + node + for node in gm.graph.nodes + if node.op == "call_function" + and node.target == torch.ops.auto_deploy.torch_attention.default + ] + regular = next( + node + for node in source_nodes + if node.target == torch.ops.auto_deploy.torch_attention.default + ) + shared = next(node for node in source_nodes if TorchBackendAttention.get_layer_idx(node) == 1) + + assert TorchBackendAttention.get_layer_idx(regular) == 0 + assert TorchBackendAttention.get_layer_idx(shared) == 1 + assert TorchBackendAttention.get_shared_kv_source_layer_idx(regular) is None + assert TorchBackendAttention.get_shared_kv_source_layer_idx(shared) == 0 + + +def test_flashinfer_backend_attention_metadata_for_shared_kv_node(): + module = _TinySharedKVModule().eval() + gm = torch_export_to_gm(module, (torch.randn(1, 4, 8),)) + source_nodes = [ + node + for node in gm.graph.nodes + if node.op == "call_function" + and node.target == torch.ops.auto_deploy.torch_attention.default + ] + regular = next( + node + for node in source_nodes + if node.target == torch.ops.auto_deploy.torch_attention.default + ) + shared = next(node for node in source_nodes if FlashInferAttention.get_layer_idx(node) == 1) + + assert FlashInferAttention.get_layer_idx(regular) == 0 + assert FlashInferAttention.get_layer_idx(shared) == 1 + assert FlashInferAttention.get_shared_kv_source_layer_idx(regular) is None + assert FlashInferAttention.get_shared_kv_source_layer_idx(shared) == 0 + assert FlashInferAttention.get_cached_attention_op() == ( + torch.ops.auto_deploy.flashinfer_attention_mha_with_cache.default + ) + + +def 
test_shared_kv_transform_aliases_source_cache_placeholders_for_flashinfer(): + module = _TinySharedKVModule().eval() + gm = torch_export_to_gm(module, (torch.randn(1, 4, 8),)) + + cm = CachedSequenceInterface( + max_seq_len=16, + max_batch_size=2, + max_num_tokens=16, + device="cpu", + ) + transform = _InsertCachedOperator( + InsertCachedAttentionConfig(stage=Stages.CACHE_INIT, backend="flashinfer") + ) + gm, info = transform._apply(gm, cm, factory=None, shared_config=SharedConfig()) + + assert info.num_matches == 2 + + placeholder_names = [node.target for node in gm.graph.nodes if node.op == "placeholder"] + assert placeholder_names.count("r0_kv_cache") == 1 + assert "r1_kv_cache" not in placeholder_names + assert set(cm._resource_lookup).issubset(set(placeholder_names)) + + cached_nodes = [node for node in gm.graph.nodes if node.op == "call_function"] + regular_node = next( + node + for node in cached_nodes + if node.target == torch.ops.auto_deploy.flashinfer_attention_mha_with_cache.default + and node.args[-1] is False + ) + shared_node = next( + node + for node in cached_nodes + if node.target == torch.ops.auto_deploy.flashinfer_attention_mha_with_cache.default + and node.args[-1] is True + ) + + assert regular_node.args[11] is shared_node.args[11] + assert regular_node.target == torch.ops.auto_deploy.flashinfer_attention_mha_with_cache.default + assert shared_node.target == torch.ops.auto_deploy.flashinfer_attention_mha_with_cache.default + assert regular_node.args[-1] is False + assert shared_node.args[-1] is True + + +def test_flashinfer_cached_attention_is_dynamic_for_piecewise(): + shared_op_name = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache.default.name() + + class _FakeNode: + op = "call_function" + + def __init__(self, target): + self.target = target + + assert "flashinfer_attention_mha_with_cache" in shared_op_name + assert is_dynamic_cached_op( + _FakeNode(torch.ops.auto_deploy.flashinfer_attention_mha_with_cache.default) + ) + + 
+@torch.no_grad() +def test_torch_shared_kv_cached_attention_supports_out_buffer(): + q = torch.randn(1, 3, 2, 4) + k = torch.randn(1, 3, 1, 4) + v = torch.randn(1, 3, 1, 4) + batch_info_host = BatchInfo() + batch_info_host.update([1, 3, 0, 0, 1, 1]) + seq_len = torch.tensor([3], dtype=torch.int32) + input_pos = torch.tensor([0, 1, 2], dtype=torch.int32) + slot_idx = torch.tensor([0, 1, 2], dtype=torch.int32) + cu_seqlen = torch.tensor([0], dtype=torch.int32) + k_cache = torch.randn(1, 4, 1, 4) + v_cache = torch.randn(1, 4, 1, 4) + + expected = torch.ops.auto_deploy.torch_cached_attention_with_cache.default( + q, + k, + v, + batch_info_host.serialize(), + seq_len, + input_pos, + slot_idx, + cu_seqlen, + k_cache, + v_cache, + None, + read_cache_only=True, + ) + + out = torch.full_like(expected, float("nan")) + ret = torch.ops.auto_deploy.torch_cached_attention_with_cache.default( + q, + k, + v, + batch_info_host.serialize(), + seq_len, + input_pos, + slot_idx, + cu_seqlen, + k_cache, + v_cache, + None, + read_cache_only=True, + out=out, + ) + + assert ret.numel() == 0 + torch.testing.assert_close(out, expected) + + +def test_shared_kv_self_alias_raises(): + class _SelfAliasingSharedKVModule(torch.nn.Module): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + qkv = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], 2, 4) + return torch.ops.auto_deploy.torch_attention( + qkv, qkv, qkv, None, 0.0, True, 1.0, None, None, None, "bsnd", 1, 1 + ) + + module = _SelfAliasingSharedKVModule().eval() + gm = torch_export_to_gm(module, (torch.randn(1, 4, 8),)) + + cm = CachedSequenceInterface( + max_seq_len=16, + max_batch_size=2, + max_num_tokens=16, + device="cpu", + ) + transform = _InsertCachedOperator( + InsertCachedAttentionConfig(stage=Stages.CACHE_INIT, backend="torch") + ) + + with pytest.raises(RuntimeError, match="cannot share its own KV cache"): + transform._apply(gm, cm, factory=None, shared_config=SharedConfig()) + + +def 
test_duplicate_cache_owner_layer_idx_raises(): + module = _DuplicateLayerOwnerSharedKVModule().eval() + gm = torch_export_to_gm(module, (torch.randn(1, 4, 8),)) + + cm = CachedSequenceInterface( + max_seq_len=16, + max_batch_size=2, + max_num_tokens=16, + device="cpu", + ) + transform = _InsertCachedOperator( + InsertCachedAttentionConfig(stage=Stages.CACHE_INIT, backend="torch") + ) + + with pytest.raises(RuntimeError, match="Duplicate KV cache owner"): + transform._apply(gm, cm, factory=None, shared_config=SharedConfig()) + + +@torch.no_grad() +def test_flashinfer_shared_kv_cached_attention_reads_aliased_cache_without_writing(): + if not torch.cuda.is_available(): + pytest.skip("CUDA is required for the FlashInfer cached attention op") + + device = torch.device("cuda") + head_dim = 64 + q = torch.zeros((1, 1, 1, head_dim), dtype=torch.float16, device=device) + q[0, 0, 0, 0] = 1.0 + dummy_k = torch.full((1, 1, 1, head_dim), 9.0, dtype=torch.float16, device=device) + dummy_v = torch.full((1, 1, 1, head_dim), 7.0, dtype=torch.float16, device=device) + + owner_k = torch.zeros((1, 3, 1, head_dim), dtype=torch.float16, device=device) + owner_k[0, 0, 0, 0] = 1.0 + owner_k[0, 1, 0, 1] = 1.0 + owner_k[0, 2, 0, 0] = 1.0 + owner_k[0, 2, 0, 1] = 1.0 + owner_v = torch.zeros((1, 3, 1, head_dim), dtype=torch.float16, device=device) + owner_v[0, 0, 0, 0] = 10.0 + owner_v[0, 1, 0, 1] = 20.0 + owner_v[0, 2, 0, 0] = 30.0 + owner_v[0, 2, 0, 1] = 3.0 + kv_cache = torch.zeros((1, 2, 1, 32, head_dim), dtype=torch.float16, device=device) + kv_cache[0, 0, 0, :3, :] = owner_k[0, :, 0, :] + kv_cache[0, 1, 0, :3, :] = owner_v[0, :, 0, :] + kv_cache_before = kv_cache.clone() + + batch_info_host = BatchInfo() + batch_info_host.update([0, 0, 0, 0, 1, 1]) + cu_seqlen_host = torch.tensor([0, 1], dtype=torch.int32, device="cpu") + cu_num_pages = torch.tensor([0, 1], dtype=torch.int32, device=device) + cu_num_pages_host = torch.tensor([0, 1], dtype=torch.int32, device="cpu") + cache_loc = torch.tensor([0], dtype=torch.int32, device=device) + last_page_len = 
torch.tensor([3], dtype=torch.int32, device=device) + last_page_len_host = torch.tensor([3], dtype=torch.int32, device="cpu") + seq_len_with_cache_host = torch.tensor([3], dtype=torch.int32, device="cpu") + batch_indices = torch.zeros(1, dtype=torch.int32, device=device) + positions = torch.zeros(1, dtype=torch.int32, device=device) + + output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache( + q, + dummy_k, + dummy_v, + batch_info_host.serialize(), + cu_seqlen_host, + cu_num_pages, + cu_num_pages_host, + cache_loc, + last_page_len, + last_page_len_host, + seq_len_with_cache_host, + batch_indices, + positions, + kv_cache, + 1.0, + None, + 1.0, + 1.0, + True, + ) + + torch.testing.assert_close(kv_cache, kv_cache_before, rtol=0.0, atol=0.0) + + expected = _manual_attention(q.float(), owner_k.float(), owner_v.float()).to(output.dtype) + torch.testing.assert_close(output.float(), expected.float(), rtol=2e-2, atol=2e-2) + + +def test_shared_kv_six_layer_stack_matches_reference_for_prefill_and_decode(): + layer_sources = {4: 2, 5: 3} + sliding_layers = {2, 4} + prefill_len = 3 + decode_pos = prefill_len + owner_caches = { + layer_idx: ( + torch.zeros(1, 8, 1, 2, dtype=torch.float32), + torch.zeros(1, 8, 1, 2, dtype=torch.float32), + ) + for layer_idx in range(4) + } + owner_history = {} + + for layer_idx in range(6): + q_prefill, k_prefill, v_prefill = _make_layer_inputs( + offset=float(layer_idx), seq_len=prefill_len + ) + batch_info_host, seq_len, input_pos, slot_idx, cu_seqlen = _context_meta(prefill_len) + sliding_window = 2 if layer_idx in sliding_layers else None + + if layer_idx in layer_sources: + source_idx = layer_sources[layer_idx] + k_cache, v_cache = owner_caches[source_idx] + output_prefill = torch.ops.auto_deploy.torch_cached_attention_with_cache( + q_prefill, + k_prefill, + v_prefill, + batch_info_host, + seq_len, + input_pos, + slot_idx, + cu_seqlen, + k_cache, + v_cache, + 1.0, + None, + sliding_window, + None, + True, + ) + expected_prefill 
= _manual_attention( + q_prefill, + owner_history[source_idx]["k_prefill"], + owner_history[source_idx]["v_prefill"], + sliding_window=sliding_window, + ) + else: + k_cache, v_cache = owner_caches[layer_idx] + output_prefill = torch.ops.auto_deploy.torch_cached_attention_with_cache( + q_prefill, + k_prefill, + v_prefill, + batch_info_host, + seq_len, + input_pos, + slot_idx, + cu_seqlen, + k_cache, + v_cache, + 1.0, + None, + sliding_window, + None, + ) + expected_prefill = _manual_attention( + q_prefill, + k_prefill, + v_prefill, + sliding_window=sliding_window, + ) + owner_history[layer_idx] = { + "k_prefill": k_prefill.clone(), + "v_prefill": v_prefill.clone(), + } + torch.testing.assert_close(k_cache[0, :prefill_len], k_prefill[0], rtol=0.0, atol=0.0) + torch.testing.assert_close(v_cache[0, :prefill_len], v_prefill[0], rtol=0.0, atol=0.0) + + torch.testing.assert_close(output_prefill, expected_prefill, rtol=1e-5, atol=1e-5) + + for layer_idx in range(6): + q_decode, k_decode, v_decode = _make_layer_inputs( + offset=100.0 + float(layer_idx), seq_len=1, decode=True + ) + batch_info_host, seq_len, input_pos, slot_idx, cu_seqlen = _decode_meta(decode_pos) + sliding_window = 2 if layer_idx in sliding_layers else None + + if layer_idx in layer_sources: + source_idx = layer_sources[layer_idx] + k_cache, v_cache = owner_caches[source_idx] + k_cache_before = k_cache.clone() + v_cache_before = v_cache.clone() + output_decode = torch.ops.auto_deploy.torch_cached_attention_with_cache( + q_decode, + k_decode, + v_decode, + batch_info_host, + seq_len, + input_pos, + slot_idx, + cu_seqlen, + k_cache, + v_cache, + 1.0, + None, + sliding_window, + None, + True, + ) + torch.testing.assert_close(k_cache, k_cache_before, rtol=0.0, atol=0.0) + torch.testing.assert_close(v_cache, v_cache_before, rtol=0.0, atol=0.0) + expected_k = owner_history[source_idx]["k_full"] + expected_v = owner_history[source_idx]["v_full"] + else: + k_cache, v_cache = owner_caches[layer_idx] + output_decode 
= torch.ops.auto_deploy.torch_cached_attention_with_cache( + q_decode, + k_decode, + v_decode, + batch_info_host, + seq_len, + input_pos, + slot_idx, + cu_seqlen, + k_cache, + v_cache, + 1.0, + None, + sliding_window, + None, + ) + expected_k = torch.cat([owner_history[layer_idx]["k_prefill"], k_decode], dim=1) + expected_v = torch.cat([owner_history[layer_idx]["v_prefill"], v_decode], dim=1) + owner_history[layer_idx]["k_full"] = expected_k + owner_history[layer_idx]["v_full"] = expected_v + torch.testing.assert_close( + k_cache[0, : decode_pos + 1], expected_k[0], rtol=0.0, atol=0.0 + ) + torch.testing.assert_close( + v_cache[0, : decode_pos + 1], expected_v[0], rtol=0.0, atol=0.0 + ) + + expected_decode = _manual_attention( + q_decode, + expected_k, + expected_v, + sliding_window=sliding_window, + ) + torch.testing.assert_close(output_decode, expected_decode, rtol=1e-5, atol=1e-5) diff --git a/tests/unittest/auto_deploy/_utils_test/torch_attention_reference.py b/tests/unittest/auto_deploy/_utils_test/torch_attention_reference.py index f8b2fc74f0d..293b56eea7a 100644 --- a/tests/unittest/auto_deploy/_utils_test/torch_attention_reference.py +++ b/tests/unittest/auto_deploy/_utils_test/torch_attention_reference.py @@ -67,6 +67,7 @@ def basic_mha_with_cache(q, k, v, k_cache, v_cache, input_positions, scale=None) seq_start, k_cache, v_cache, + None, scale, ) @@ -108,6 +109,7 @@ def flattened_mha_with_cache( seq_start, k_cache, v_cache, + None, scale, ) @@ -162,6 +164,7 @@ def decode_with_prefilled_cache(q, k_ref, v_ref, k_cache, v_cache, prefill_lengt k_cache, v_cache, None, + None, ) # Return in flattened format to match flashinfer backend behavior [batch, seq=1, n_heads * head_dim] @@ -198,6 +201,7 @@ def mha_with_features( seq_start, k_cache, v_cache, + None, scale, None, # sinks sliding_window_size, diff --git a/tests/unittest/auto_deploy/singlegpu/compile/test_captured_graph.py b/tests/unittest/auto_deploy/singlegpu/compile/test_captured_graph.py index 
b1c1e604e0f..a6e51aa7794 100644 --- a/tests/unittest/auto_deploy/singlegpu/compile/test_captured_graph.py +++ b/tests/unittest/auto_deploy/singlegpu/compile/test_captured_graph.py @@ -18,6 +18,7 @@ _args_kwargs_flatten_spec, ) from tensorrt_llm._torch.auto_deploy.compile.piecewise_utils import submod_has_cuda_ops +from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import BatchInfo from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm from tensorrt_llm._torch.auto_deploy.shim.ad_executor import _round_up_to_closest from tensorrt_llm._torch.auto_deploy.transform.library.compile_model import ( @@ -285,13 +286,17 @@ def _make_dual_mode(self, piecewise_num_tokens=None): def test_is_decode_only_with_batch_info_host_zero(self): dual = self._make_dual_mode() # num_prefill=0 → decode-only - batch_info = torch.tensor([0, 0, 4]) # [num_prefill, num_prefill_tokens, num_decode] + batch_info_host = BatchInfo() + batch_info_host.update([0, 0, 0, 0, 4, 4]) + batch_info = batch_info_host.serialize() assert dual._is_decode_only(batch_info_host=batch_info) is True def test_is_decode_only_with_batch_info_host_nonzero(self): dual = self._make_dual_mode() # num_prefill=2 → not decode-only - batch_info = torch.tensor([2, 100, 3]) + batch_info_host = BatchInfo() + batch_info_host.update([2, 100, 0, 0, 3, 3]) + batch_info = batch_info_host.serialize() assert dual._is_decode_only(batch_info_host=batch_info) is False def test_is_decode_only_fallback_heuristic_decode(self): diff --git a/tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_flashinfer_attention_op.py b/tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_flashinfer_attention_op.py index d9b016e9498..8493bd8e201 100644 --- a/tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_flashinfer_attention_op.py +++ b/tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_flashinfer_attention_op.py @@ -132,6 +132,7 @@ def 
test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype, kv_cache, # CONSTANTS None, + None, 1.0, 1.0, ) @@ -261,6 +262,7 @@ def test_flashinfer_attention_op_decode( kv_cache, # CONSTANTS None, + None, 1.0, 1.0, ) @@ -391,6 +393,7 @@ def test_flashinfer_attention_context_and_generate( kv_cache, # CONSTANTS None, + None, 1.0, 1.0, ) @@ -485,6 +488,7 @@ def test_flashinfer_attention_context_and_generate( kv_cache, # CONSTANTS None, + None, 1.0, 1.0, ) @@ -618,6 +622,7 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty kv_cache, # CONSTANTS None, + None, 1.0, 1.0, ) @@ -776,6 +781,7 @@ def test_flashinfer_attention_with_fp8_cache( kv_cache, # CONSTANTS None, + None, K_SCALE, V_SCALE, ) @@ -883,6 +889,7 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de kv_cache, # CONSTANTS None, + None, 1.0, 1.0, ) @@ -977,6 +984,7 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de kv_cache, # CONSTANTS None, + None, 1.0, 1.0, ) diff --git a/tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_torch_attention_op.py b/tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_torch_attention_op.py index 599dc26cfc8..20bf69b5284 100644 --- a/tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_torch_attention_op.py +++ b/tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_torch_attention_op.py @@ -10,6 +10,136 @@ from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import BatchInfo +@torch.inference_mode() +def test_torch_backend_attention_custom_bool_mask_context(): + device = "cuda" + dtype = torch.float16 + batch_size, seq_len, num_heads, head_dim = 1, 5, 2, 8 + scale = 1.0 / math.sqrt(head_dim) + + q = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) + k = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) + v = torch.randn(batch_size, seq_len, 
num_heads, head_dim, device=device, dtype=dtype) + k_cache = torch.zeros(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) + v_cache = torch.zeros(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) + + token_type_ids = torch.tensor([[0, 1, 1, 2, 2]], device=device, dtype=torch.int64) + non_text = token_type_ids != 0 + prev = torch.cat( + [ + torch.zeros(batch_size, 1, device=device, dtype=token_type_ids.dtype), + token_type_ids[:, :-1], + ], + dim=1, + ) + blob_starts = non_text & (token_type_ids != prev) + blob_ids = torch.cumsum(blob_starts.to(torch.int64), dim=1) + token_blob_ids = torch.where(non_text, blob_ids, torch.zeros_like(blob_ids)) + media_mask = (token_blob_ids.unsqueeze(2) == token_blob_ids.unsqueeze(1)) & ( + token_blob_ids.unsqueeze(2) != 0 + ) + positions = torch.arange(seq_len, device=device) + attn_mask = (positions.unsqueeze(0) <= positions.unsqueeze(1)).unsqueeze(0) | media_mask + attn_mask = attn_mask.unsqueeze(1) + + batch_info = BatchInfo() + batch_info.update([batch_size, batch_size * seq_len, 0, 0, 0, 0]) + seq_len_tensor = torch.tensor([seq_len], device=device, dtype=torch.int32) + input_positions = torch.tensor([0], device=device, dtype=torch.int32) + slot_idx = torch.tensor([0], device=device, dtype=torch.int32) + seq_start = torch.tensor([0], device=device, dtype=torch.int32) + + expected = torch.ops.auto_deploy.torch_attention( + q, k, v, attn_mask=attn_mask, is_causal=False, scale=scale, layout="bsnd" + ) + actual = torch.ops.auto_deploy.torch_cached_attention_with_cache.default( + q, + k, + v, + batch_info.serialize(), + seq_len_tensor, + input_positions, + slot_idx, + seq_start, + k_cache, + v_cache, + attn_mask, + scale, + ) + + torch.testing.assert_close(actual, expected, atol=5e-2, rtol=5e-2) + + +@torch.inference_mode() +def test_torch_backend_attention_custom_bool_mask_with_sliding_window_context(): + device = "cuda" + dtype = torch.float16 + batch_size, seq_len, num_heads, head_dim = 1, 
6, 2, 8 + scale = 1.0 / math.sqrt(head_dim) + sliding_window_size = 3 + + q = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) + k = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) + v = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) + k_cache = torch.zeros(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) + v_cache = torch.zeros(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) + + token_type_ids = torch.tensor([[0, 1, 1, 1, 2, 2]], device=device, dtype=torch.int64) + non_text = token_type_ids != 0 + prev = torch.cat( + [ + torch.zeros(batch_size, 1, device=device, dtype=token_type_ids.dtype), + token_type_ids[:, :-1], + ], + dim=1, + ) + blob_starts = non_text & (token_type_ids != prev) + blob_ids = torch.cumsum(blob_starts.to(torch.int64), dim=1) + token_blob_ids = torch.where(non_text, blob_ids, torch.zeros_like(blob_ids)) + media_mask = (token_blob_ids.unsqueeze(2) == token_blob_ids.unsqueeze(1)) & ( + token_blob_ids.unsqueeze(2) != 0 + ) + positions = torch.arange(seq_len, device=device) + attn_mask = (positions.unsqueeze(0) <= positions.unsqueeze(1)).unsqueeze(0) | media_mask + attn_mask = attn_mask.unsqueeze(1) + + batch_info = BatchInfo() + batch_info.update([batch_size, batch_size * seq_len, 0, 0, 0, 0]) + seq_len_tensor = torch.tensor([seq_len], device=device, dtype=torch.int32) + input_positions = torch.tensor([0], device=device, dtype=torch.int32) + slot_idx = torch.tensor([0], device=device, dtype=torch.int32) + seq_start = torch.tensor([0], device=device, dtype=torch.int32) + + expected = torch.ops.auto_deploy.torch_attention( + q, + k, + v, + attn_mask=attn_mask, + is_causal=False, + scale=scale, + sliding_window=sliding_window_size, + layout="bsnd", + ) + actual = torch.ops.auto_deploy.torch_cached_attention_with_cache.default( + q, + k, + v, + batch_info.serialize(), + seq_len_tensor, + input_positions, + 
slot_idx, + seq_start, + k_cache, + v_cache, + attn_mask, + scale, + None, + sliding_window_size, + ) + + torch.testing.assert_close(actual, expected, atol=5e-2, rtol=5e-2) + + def numpy_attention_reference( q, k, @@ -301,7 +431,9 @@ def _run_attention( scale, sinks, sliding_window_size, - logit_cap, # Updated parameter order + logit_cap, + # DYNAMIC INPUTS + custom_attn_mask=None, ) def test_basic_functionality(self): @@ -321,6 +453,42 @@ def test_basic_functionality(self): # Verify output is not NaN or Inf assert torch.isfinite(output).all(), "Output contains NaN or Inf values" + def test_accepts_positional_constants_with_keyword_custom_mask(self): + """Regression test for transform-emitted call style.""" + batch_size, seq_len, n_heads, n_kv_heads, d_head, max_seq_len = 2, 4, 8, 4, 32, 128 + data = self._create_test_data(batch_size, seq_len, n_heads, n_kv_heads, d_head, max_seq_len) + custom_attn_mask = torch.ones( + batch_size, + 1, + seq_len, + seq_len, + dtype=torch.bool, + device=self.device, + ) + + output = torch.ops.auto_deploy.torch_cached_attention_with_cache( + data["q"], + data["k"], + data["v"], + data["batch_info_host"], + data["seq_len"], + data["input_pos"], + data["cache_loc"], + data["seq_start"], + data["k_cache"], + data["v_cache"], + None, + None, + None, + None, + False, + custom_attn_mask=custom_attn_mask, + ) + + expected_shape = data["q"].shape[:2] + (n_heads * d_head,) + assert output.shape == expected_shape + assert torch.isfinite(output).all(), "Output contains NaN or Inf values" + @pytest.mark.parametrize("logit_cap", [None, 5.0]) @pytest.mark.parametrize("sliding_window_size", [None, 3]) @pytest.mark.parametrize("sinks", [None, 1.0]) diff --git a/tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py b/tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py index 5a14303e14a..6b2eb8aea25 100644 --- 
a/tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py +++ b/tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py @@ -23,6 +23,8 @@ import pytest import torch +import tensorrt_llm._torch.auto_deploy # noqa: F401 + # Skip all tests if CUDA is not available pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -264,6 +266,157 @@ def test_context_kernel_vs_pytorch_reference( torch.testing.assert_close(output.float(), output_ref.float(), rtol=1e-2, atol=1e-2) + def test_context_with_custom_bool_mask_matches_torch_attention(self): + from tensorrt_llm._torch.auto_deploy.custom_ops.attention.triton_paged_attention import ( + triton_paged_context_with_custom_mask, + update_paged_kv_cache, + ) + + batch_size, seq_len = 1, 5 + n_heads, n_kv_heads, head_dim = 8, 8, 64 + page_size = 16 + num_pages_per_seq = 1 + num_blocks = 4 + total_tokens = batch_size * seq_len + sm_scale = 1.0 / math.sqrt(head_dim) + + q = torch.randn(total_tokens, n_heads, head_dim, dtype=torch.float16, device="cuda") + k = torch.randn(total_tokens, n_kv_heads, head_dim, dtype=torch.float16, device="cuda") + v = torch.randn(total_tokens, n_kv_heads, head_dim, dtype=torch.float16, device="cuda") + + qo_indptr = torch.tensor([0, seq_len], dtype=torch.int32, device="cuda") + kv_indptr = torch.tensor([0, num_pages_per_seq], dtype=torch.int32, device="cuda") + kv_indices = torch.tensor([0], dtype=torch.int32, device="cuda") + seq_len_with_cache = torch.tensor([seq_len], dtype=torch.int32, device="cuda") + + batch_indices = torch.zeros(total_tokens, dtype=torch.int32, device="cuda") + positions = torch.arange(seq_len, dtype=torch.int32, device="cuda") + + kv_cache = create_paged_kv_cache(num_blocks, page_size, n_kv_heads, head_dim) + update_paged_kv_cache(k, v, batch_indices, positions, kv_cache, kv_indices, kv_indptr) + + token_type_ids = torch.tensor([[0, 1, 1, 2, 2]], device="cuda", 
dtype=torch.int64) + non_text = token_type_ids != 0 + prev = torch.cat( + [ + torch.zeros(batch_size, 1, device="cuda", dtype=token_type_ids.dtype), + token_type_ids[:, :-1], + ], + dim=1, + ) + blob_starts = non_text & (token_type_ids != prev) + blob_ids = torch.cumsum(blob_starts.to(torch.int64), dim=1) + token_blob_ids = torch.where(non_text, blob_ids, torch.zeros_like(blob_ids)) + media_mask = (token_blob_ids.unsqueeze(2) == token_blob_ids.unsqueeze(1)) & ( + token_blob_ids.unsqueeze(2) != 0 + ) + pos = torch.arange(seq_len, device="cuda") + custom_attn_mask = ( + (pos.unsqueeze(0) <= pos.unsqueeze(1)).unsqueeze(0) | media_mask + ).unsqueeze(1) + + output = triton_paged_context_with_custom_mask( + q, + kv_cache, + qo_indptr, + kv_indptr, + kv_indices, + seq_len_with_cache, + custom_attn_mask, + sm_scale, + ) + + output_ref = torch.ops.auto_deploy.torch_attention( + q.view(batch_size, seq_len, n_heads, head_dim), + k.view(batch_size, seq_len, n_kv_heads, head_dim), + v.view(batch_size, seq_len, n_kv_heads, head_dim), + attn_mask=custom_attn_mask, + is_causal=False, + scale=sm_scale, + layout="bsnd", + ).reshape(total_tokens, n_heads, head_dim) + + torch.testing.assert_close(output.float(), output_ref.float(), rtol=1e-2, atol=1e-2) + + def test_context_with_custom_bool_mask_and_sliding_window_matches_torch_attention(self): + from tensorrt_llm._torch.auto_deploy.custom_ops.attention.triton_paged_attention import ( + triton_paged_context_with_custom_mask, + update_paged_kv_cache, + ) + + batch_size, seq_len = 1, 18 + n_heads, n_kv_heads, head_dim = 8, 8, 64 + page_size = 16 + num_pages_per_seq = (seq_len + page_size - 1) // page_size + num_blocks = 4 + total_tokens = batch_size * seq_len + sm_scale = 1.0 / math.sqrt(head_dim) + sliding_window = 3 + + q = torch.randn(total_tokens, n_heads, head_dim, dtype=torch.float16, device="cuda") + k = torch.randn(total_tokens, n_kv_heads, head_dim, dtype=torch.float16, device="cuda") + v = torch.randn(total_tokens, n_kv_heads, 
head_dim, dtype=torch.float16, device="cuda") + + qo_indptr = torch.tensor([0, seq_len], dtype=torch.int32, device="cuda") + kv_indptr = torch.tensor([0, num_pages_per_seq], dtype=torch.int32, device="cuda") + kv_indices = torch.arange(num_pages_per_seq, dtype=torch.int32, device="cuda") + seq_len_with_cache = torch.tensor([seq_len], dtype=torch.int32, device="cuda") + + batch_indices = torch.zeros(total_tokens, dtype=torch.int32, device="cuda") + positions = torch.arange(seq_len, dtype=torch.int32, device="cuda") + + kv_cache = create_paged_kv_cache(num_blocks, page_size, n_kv_heads, head_dim) + update_paged_kv_cache(k, v, batch_indices, positions, kv_cache, kv_indices, kv_indptr) + + token_type_ids = torch.tensor( + [[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2]], + device="cuda", + dtype=torch.int64, + ) + non_text = token_type_ids != 0 + prev = torch.cat( + [ + torch.zeros(batch_size, 1, device="cuda", dtype=token_type_ids.dtype), + token_type_ids[:, :-1], + ], + dim=1, + ) + blob_starts = non_text & (token_type_ids != prev) + blob_ids = torch.cumsum(blob_starts.to(torch.int64), dim=1) + token_blob_ids = torch.where(non_text, blob_ids, torch.zeros_like(blob_ids)) + media_mask = (token_blob_ids.unsqueeze(2) == token_blob_ids.unsqueeze(1)) & ( + token_blob_ids.unsqueeze(2) != 0 + ) + pos = torch.arange(seq_len, device="cuda") + custom_attn_mask = ( + (pos.unsqueeze(0) <= pos.unsqueeze(1)).unsqueeze(0) | media_mask + ).unsqueeze(1) + + output = triton_paged_context_with_custom_mask( + q, + kv_cache, + qo_indptr, + kv_indptr, + kv_indices, + seq_len_with_cache, + custom_attn_mask, + sm_scale, + sliding_window=sliding_window, + ) + + output_ref = torch.ops.auto_deploy.torch_attention( + q.view(batch_size, seq_len, n_heads, head_dim), + k.view(batch_size, seq_len, n_kv_heads, head_dim), + v.view(batch_size, seq_len, n_kv_heads, head_dim), + attn_mask=custom_attn_mask, + is_causal=False, + scale=sm_scale, + sliding_window=sliding_window, + layout="bsnd", + 
).reshape(total_tokens, n_heads, head_dim) + + torch.testing.assert_close(output.float(), output_ref.float(), rtol=1e-2, atol=1e-2) + class TestCacheUpdate: """Tests for the KV cache update kernel.""" @@ -416,6 +569,7 @@ def test_batch_info_12_element_format(self): batch_indices, positions, kv_cache, + custom_attn_mask=None, scale=None, ) @@ -442,6 +596,73 @@ def test_batch_info_with_extend_requests(self): assert num_prefill_tokens == 96 # 32 + 64 assert num_decode == 3 + def test_accepts_positional_constants_with_keyword_custom_mask(self): + """Regression test for transform-emitted call style. + + The KV-cache transform passes scale/sliding_window positionally and + custom_attn_mask by keyword. This should bind correctly. + """ + from tensorrt_llm._torch.auto_deploy.custom_ops.attention.triton_paged_attention import ( + prepare_triton_paged_metadata, + triton_paged_mha_with_cache, + ) + + n_heads, n_kv_heads, head_dim, page_size = 8, 2, 128, 16 + seq_len = 8 + num_pages = (seq_len + page_size - 1) // page_size + num_blocks = num_pages + 4 + + q = torch.randn(1, seq_len, n_heads, head_dim, dtype=torch.float16, device="cuda") + k = torch.randn(1, seq_len, n_kv_heads, head_dim, dtype=torch.float16, device="cuda") + v = torch.randn(1, seq_len, n_kv_heads, head_dim, dtype=torch.float16, device="cuda") + + batch_info_host = self._make_batch_info( + num_prefill=1, num_prefill_tokens=seq_len, num_decode=0 + ) + cu_seqlen_host = torch.tensor([0, seq_len], dtype=torch.int32) + cu_num_pages = torch.tensor([0, num_pages], dtype=torch.int32, device="cuda") + cu_num_pages_host = cu_num_pages.cpu() + cache_loc = torch.arange(num_pages, dtype=torch.int32, device="cuda") + last_page_len = torch.tensor([seq_len], dtype=torch.int32, device="cuda") + last_page_len_host = last_page_len.cpu() + seq_len_with_cache_host = torch.tensor([seq_len], dtype=torch.int32) + kv_cache = torch.zeros( + num_blocks, 2, n_kv_heads, page_size, head_dim, dtype=torch.float16, device="cuda" + ) + + 
position_ids = torch.arange(seq_len, device="cuda") + batch_indices, positions = prepare_triton_paged_metadata( + position_ids, + batch_info_host, + cu_seqlen_host.to("cuda", non_blocking=True), + seq_len_with_cache_host.to("cuda", non_blocking=True), + ) + custom_attn_mask = torch.ones(1, 1, seq_len, seq_len, dtype=torch.bool, device="cuda") + + output = triton_paged_mha_with_cache( + q, + k, + v, + batch_info_host, + cu_seqlen_host, + cu_num_pages, + cu_num_pages_host, + cache_loc, + last_page_len, + last_page_len_host, + seq_len_with_cache_host, + batch_indices, + positions, + kv_cache, + 1.0, + page_size, + custom_attn_mask=custom_attn_mask, + ) + + assert output.shape == q.shape + assert not torch.isnan(output).any(), "Output contains NaN" + assert not torch.isinf(output).any(), "Output contains Inf" + def test_prepare_metadata_with_12_element_batch_info(self): """Test prepare_triton_paged_metadata with 12-element batch_info_host.""" from tensorrt_llm._torch.auto_deploy.custom_ops.attention.triton_paged_attention import ( @@ -464,6 +685,309 @@ def test_prepare_metadata_with_12_element_batch_info(self): assert (positions == torch.arange(7, device="cuda")).all() +class TestSlidingWindow: + """Tests for sliding window attention support in Triton paged kernels.""" + + @staticmethod + def _sliding_window_reference( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sm_scale: float, + sliding_window: int, + ) -> torch.Tensor: + """Compute causal + sliding window attention with manual masking. 
+ + Args: + q: [B, n_heads, S_q, head_dim] + k: [B, n_heads, S_k, head_dim] + v: [B, n_heads, S_k, head_dim] + + Returns: + [B, n_heads, S_q, head_dim] + """ + s_q = q.shape[2] + s_k = k.shape[2] + + attn = torch.matmul(q, k.transpose(-2, -1)) * sm_scale + + # Absolute positions: the s_q queries are the last s_q of the s_k key positions. + # For prefill: q_pos = [0..s_q-1], k_pos = [0..s_k-1] + q_pos = torch.arange(s_k - s_q, s_k, device=q.device) # [s_q] + k_pos = torch.arange(s_k, device=q.device) # [s_k] + + pos_diff = q_pos.unsqueeze(1) - k_pos.unsqueeze(0) # [s_q, s_k] + causal_mask = pos_diff < 0 + window_mask = pos_diff >= sliding_window + combined = causal_mask | window_mask + attn.masked_fill_(combined.unsqueeze(0).unsqueeze(0), float("-inf")) + + attn = torch.softmax(attn, dim=-1) + return torch.matmul(attn, v) + + @pytest.mark.parametrize("batch_size", [1, 4]) + @pytest.mark.parametrize("n_heads,n_kv_heads", [(8, 8), (32, 8)]) + @pytest.mark.parametrize("head_dim", [64, 128]) + @pytest.mark.parametrize("seq_len", [128, 256, 512]) + @pytest.mark.parametrize("sliding_window", [32, 64]) + def test_decode_sliding_window( + self, + batch_size: int, + n_heads: int, + n_kv_heads: int, + head_dim: int, + seq_len: int, + sliding_window: int, + ): + """Test decode with sliding window against reference (seq_len > window).""" + from tensorrt_llm._torch.auto_deploy.custom_ops.attention.triton_paged_attention import ( + triton_paged_decode, + update_paged_kv_cache, + ) + + assert seq_len > sliding_window, "Test requires seq_len > sliding_window" + page_size = 16 + + num_pages_per_seq = (seq_len + page_size - 1) // page_size + num_blocks = batch_size * num_pages_per_seq + 5 + + q = torch.randn(batch_size, n_heads, head_dim, dtype=torch.float16, device="cuda") + k = torch.randn( + batch_size, seq_len, n_kv_heads, head_dim, dtype=torch.float16, device="cuda" + ) + v = torch.randn( + batch_size, seq_len, n_kv_heads, head_dim, dtype=torch.float16, device="cuda" + ) + + k_flat = 
k.reshape(batch_size * seq_len, n_kv_heads, head_dim) + v_flat = v.reshape(batch_size * seq_len, n_kv_heads, head_dim) + + batch_indices = torch.repeat_interleave( + torch.arange(batch_size, device="cuda", dtype=torch.int32), seq_len + ) + positions = torch.tile( + torch.arange(seq_len, device="cuda", dtype=torch.int32), (batch_size,) + ) + + kv_indptr = torch.arange( + 0, + (batch_size + 1) * num_pages_per_seq, + num_pages_per_seq, + dtype=torch.int32, + device="cuda", + )[: batch_size + 1] + kv_indices = torch.arange( + 0, batch_size * num_pages_per_seq, dtype=torch.int32, device="cuda" + ) + last_token_in_page = seq_len % page_size + kv_last_page_len = torch.full( + (batch_size,), + last_token_in_page if last_token_in_page > 0 else page_size, + dtype=torch.int32, + device="cuda", + ) + + kv_cache = create_paged_kv_cache(num_blocks, page_size, n_kv_heads, head_dim) + update_paged_kv_cache( + k_flat, v_flat, batch_indices, positions, kv_cache, kv_indices, kv_indptr + ) + + sm_scale = 1.0 / math.sqrt(head_dim) + + output_triton = triton_paged_decode( + q, + kv_cache, + kv_indices, + kv_indptr, + kv_last_page_len, + sm_scale, + sliding_window=sliding_window, + ) + + # Reference: only attend to last `sliding_window` tokens + head_ratio = n_heads // n_kv_heads + k_ref = k[:, -sliding_window:, :, :].transpose(1, 2) + v_ref = v[:, -sliding_window:, :, :].transpose(1, 2) + if head_ratio > 1: + k_ref = k_ref.repeat_interleave(head_ratio, dim=1) + v_ref = v_ref.repeat_interleave(head_ratio, dim=1) + + q_ref = q.unsqueeze(2) # [B, n_heads, 1, head_dim] + output_ref = torch.nn.functional.scaled_dot_product_attention( + q_ref, k_ref, v_ref, scale=sm_scale, is_causal=False + ).squeeze(2) + + torch.testing.assert_close(output_triton.float(), output_ref.float(), rtol=1e-2, atol=1e-2) + + @pytest.mark.parametrize("batch_size", [1, 2]) + @pytest.mark.parametrize("n_heads,n_kv_heads", [(8, 8), (32, 8)]) + @pytest.mark.parametrize("head_dim", [64, 128]) + 
@pytest.mark.parametrize("seq_len", [128, 256]) + @pytest.mark.parametrize("sliding_window", [32, 64]) + def test_context_sliding_window( + self, + batch_size: int, + n_heads: int, + n_kv_heads: int, + head_dim: int, + seq_len: int, + sliding_window: int, + ): + """Test prefill with sliding window against manual reference (seq_len > window).""" + from tensorrt_llm._torch.auto_deploy.custom_ops.attention.triton_paged_attention import ( + triton_paged_context, + update_paged_kv_cache, + ) + + assert seq_len > sliding_window, "Test requires seq_len > sliding_window" + page_size = 16 + + num_pages_per_seq = (seq_len + page_size - 1) // page_size + num_blocks = batch_size * num_pages_per_seq + 5 + total_tokens = batch_size * seq_len + + q = torch.randn(total_tokens, n_heads, head_dim, dtype=torch.float16, device="cuda") + k = torch.randn(total_tokens, n_kv_heads, head_dim, dtype=torch.float16, device="cuda") + v = torch.randn(total_tokens, n_kv_heads, head_dim, dtype=torch.float16, device="cuda") + + qo_indptr = torch.arange( + 0, (batch_size + 1) * seq_len, seq_len, dtype=torch.int32, device="cuda" + )[: batch_size + 1] + kv_indptr = torch.arange( + 0, + (batch_size + 1) * num_pages_per_seq, + num_pages_per_seq, + dtype=torch.int32, + device="cuda", + )[: batch_size + 1] + kv_indices = torch.arange( + 0, batch_size * num_pages_per_seq, dtype=torch.int32, device="cuda" + ) + last_token_in_page = seq_len % page_size + kv_last_page_len = torch.full( + (batch_size,), + last_token_in_page if last_token_in_page > 0 else page_size, + dtype=torch.int32, + device="cuda", + ) + seq_len_with_cache = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda") + + batch_indices = torch.repeat_interleave( + torch.arange(batch_size, device="cuda", dtype=torch.int32), seq_len + ) + positions = torch.tile( + torch.arange(seq_len, device="cuda", dtype=torch.int32), (batch_size,) + ) + + kv_cache = create_paged_kv_cache(num_blocks, page_size, n_kv_heads, head_dim) + 
update_paged_kv_cache(k, v, batch_indices, positions, kv_cache, kv_indices, kv_indptr) + + sm_scale = 1.0 / math.sqrt(head_dim) + + output = triton_paged_context( + q, + kv_cache, + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + seq_len_with_cache, + sm_scale, + sliding_window=sliding_window, + ) + + # Reference: manual causal + sliding window attention + head_ratio = n_heads // n_kv_heads + q_ref = q.view(batch_size, seq_len, n_heads, head_dim).transpose(1, 2) + k_ref = k.view(batch_size, seq_len, n_kv_heads, head_dim).transpose(1, 2) + v_ref = v.view(batch_size, seq_len, n_kv_heads, head_dim).transpose(1, 2) + if head_ratio > 1: + k_ref = k_ref.repeat_interleave(head_ratio, dim=1) + v_ref = v_ref.repeat_interleave(head_ratio, dim=1) + + output_ref = self._sliding_window_reference(q_ref, k_ref, v_ref, sm_scale, sliding_window) + output_ref = output_ref.transpose(1, 2).reshape(total_tokens, n_heads, head_dim) + + torch.testing.assert_close(output.float(), output_ref.float(), rtol=1e-2, atol=1e-2) + + def test_no_sliding_window_unchanged(self): + """Verify that sliding_window=None and sliding_window=0 produce identical decode output.""" + from tensorrt_llm._torch.auto_deploy.custom_ops.attention.triton_paged_attention import ( + triton_paged_decode, + update_paged_kv_cache, + ) + + batch_size, n_heads, n_kv_heads, head_dim = 2, 8, 8, 64 + seq_len, page_size = 128, 16 + + num_pages_per_seq = (seq_len + page_size - 1) // page_size + num_blocks = batch_size * num_pages_per_seq + 5 + + q = torch.randn(batch_size, n_heads, head_dim, dtype=torch.float16, device="cuda") + k = torch.randn( + batch_size, seq_len, n_kv_heads, head_dim, dtype=torch.float16, device="cuda" + ) + v = torch.randn( + batch_size, seq_len, n_kv_heads, head_dim, dtype=torch.float16, device="cuda" + ) + + k_flat = k.reshape(batch_size * seq_len, n_kv_heads, head_dim) + v_flat = v.reshape(batch_size * seq_len, n_kv_heads, head_dim) + + batch_indices = torch.repeat_interleave( + torch.arange(batch_size, 
device="cuda", dtype=torch.int32), seq_len + ) + positions = torch.tile( + torch.arange(seq_len, device="cuda", dtype=torch.int32), (batch_size,) + ) + + kv_indptr = torch.arange( + 0, + (batch_size + 1) * num_pages_per_seq, + num_pages_per_seq, + dtype=torch.int32, + device="cuda", + )[: batch_size + 1] + kv_indices = torch.arange( + 0, batch_size * num_pages_per_seq, dtype=torch.int32, device="cuda" + ) + last_token_in_page = seq_len % page_size + kv_last_page_len = torch.full( + (batch_size,), + last_token_in_page if last_token_in_page > 0 else page_size, + dtype=torch.int32, + device="cuda", + ) + + kv_cache = create_paged_kv_cache(num_blocks, page_size, n_kv_heads, head_dim) + update_paged_kv_cache( + k_flat, v_flat, batch_indices, positions, kv_cache, kv_indices, kv_indptr + ) + + sm_scale = 1.0 / math.sqrt(head_dim) + + out_none = triton_paged_decode( + q, + kv_cache, + kv_indices, + kv_indptr, + kv_last_page_len, + sm_scale, + sliding_window=None, + ) + out_zero = triton_paged_decode( + q, + kv_cache, + kv_indices, + kv_indptr, + kv_last_page_len, + sm_scale, + sliding_window=0, + ) + + torch.testing.assert_close(out_none, out_zero) + + class TestFlashInferComparison: """Tests comparing Triton implementation against FlashInfer.""" diff --git a/tests/unittest/auto_deploy/singlegpu/transformations/library/test_gather_logits_before_lm_head.py b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_gather_logits_before_lm_head.py index e7310944951..a3f40d08105 100644 --- a/tests/unittest/auto_deploy/singlegpu/transformations/library/test_gather_logits_before_lm_head.py +++ b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_gather_logits_before_lm_head.py @@ -26,7 +26,7 @@ from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer -from 
tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op +from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_linear_op, is_op class SimpleLMHeadModel(torch.nn.Module): @@ -45,6 +45,24 @@ def forward(self, hidden_states, logit_gather_ids=None, seq_len=None): return logits +class SoftcapLMHeadModel(torch.nn.Module): + """Model with LM head followed by softcapping (like Gemma4).""" + + def __init__(self, hidden_size: int = 128, vocab_size: int = 1000, softcap: float = 30.0): + super().__init__() + self.linear1 = torch.nn.Linear(hidden_size, hidden_size, device="cuda", dtype=torch.float16) + self.lm_head = torch.nn.Linear(hidden_size, vocab_size, device="cuda", dtype=torch.float16) + self.softcap = softcap + + def forward(self, hidden_states, logit_gather_ids=None, seq_len=None): + hidden_states = self.linear1(hidden_states) + logits = self.lm_head(hidden_states) + logits = logits / self.softcap + logits = torch.tanh(logits) + logits = logits * self.softcap + return logits + + class TestGatherTokensOp: """Test the custom op directly.""" @@ -348,3 +366,75 @@ def test_transform_skips_when_disabled(self): assert not self._check_gather_op_in_graph(gm_transformed), ( "Gather op should not be in graph" ) + + def test_transform_with_softcapping(self): + """Test that gather is placed BEFORE lm_head when softcapping follows it. + + Models like Gemma4 apply softcapping (div, tanh, mul) after the lm_head. + The transform must walk backward through these ops to find the actual + linear and insert gather before it, not after the softcapping chain. + Otherwise the lm_head still runs on all tokens (no compute reduction) + and piecewise CUDA graph capture OOMs on the [num_tokens, vocab_size] + intermediate. 
+ """ + hidden_size = 128 + vocab_size = 1000 + batch_size = 4 + max_batch_size = 8 + model = SoftcapLMHeadModel(hidden_size, vocab_size).cuda() + + hidden_states = torch.randn(batch_size, 1, hidden_size, device="cuda", dtype=torch.float16) + logit_gather_ids = torch.zeros(max_batch_size, dtype=torch.long, device="cuda") + seq_len = torch.ones(batch_size, dtype=torch.long, device="cuda") + + gm = torch_export_to_gm( + model, + args=(hidden_states, logit_gather_ids, seq_len), + dynamic_shapes=None, + clone=True, + ) + + # Apply transform + cm = self._create_cached_sequence_interface(max_batch_size) + transform_config = { + "gather_logits_before_lm_head": { + "stage": "post_load_fusion", + "max_batch_size": max_batch_size, + } + } + optimizer = InferenceOptimizer(None, transform_config) + gm_transformed = optimizer(cm, gm) + + assert self._check_gather_op_in_graph(gm_transformed), "Gather op not found in graph" + + # Verify gather_tokens comes BEFORE the lm_head linear, not after softcapping. + # Walk the graph and record the order of gather_tokens vs aten.linear ops. 
+ gather_idx = None + linear_indices = [] + for i, node in enumerate(gm_transformed.graph.nodes): + if is_op(node, torch.ops.auto_deploy.gather_tokens): + gather_idx = i + if is_linear_op(node): + linear_indices.append(i) + + assert gather_idx is not None, "gather_tokens not found" + # The lm_head linear is the last linear in the graph + lm_head_linear_idx = linear_indices[-1] + assert gather_idx < lm_head_linear_idx, ( + f"gather_tokens (idx={gather_idx}) should come before " + f"lm_head linear (idx={lm_head_linear_idx})" + ) + + # Verify forward pass correctness + token_gather_indices = torch.arange(batch_size, dtype=torch.long, device="cuda") + batch_info = BatchInfo() + batch_info.update_tokens_gather_info(batch_size, False) + batch_info_host = batch_info.serialize() + output = gm_transformed( + hidden_states, + logit_gather_ids, + seq_len, + token_gather_indices=token_gather_indices, + batch_info_host=batch_info_host, + ) + assert output.shape == (batch_size, 1, vocab_size) diff --git a/tests/unittest/auto_deploy/singlegpu/transformations/library/test_inject_custom_attention_mask.py b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_inject_custom_attention_mask.py new file mode 100644 index 00000000000..c3b8ca44274 --- /dev/null +++ b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_inject_custom_attention_mask.py @@ -0,0 +1,248 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for custom attention mask injection into torch_attention nodes.""" + +from types import SimpleNamespace + +import torch + +import tensorrt_llm._torch.auto_deploy.custom_ops.attention.torch_attention # noqa: F401 +from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.transform.attention_mask_provider import ( + AttentionMaskProviderRegistry, +) +from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer +from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op + + +class DualAttentionModel(torch.nn.Module): + """Minimal model with two torch_attention calls sharing one custom mask.""" + + def __init__(self, hidden_size: int = 8, num_heads: int = 2): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + + self.q_proj_1 = torch.nn.Linear(hidden_size, hidden_size, bias=False) + self.k_proj_1 = torch.nn.Linear(hidden_size, hidden_size, bias=False) + self.v_proj_1 = torch.nn.Linear(hidden_size, hidden_size, bias=False) + self.q_proj_2 = torch.nn.Linear(hidden_size, hidden_size, bias=False) + self.k_proj_2 = torch.nn.Linear(hidden_size, hidden_size, bias=False) + self.v_proj_2 = torch.nn.Linear(hidden_size, hidden_size, bias=False) + + def _embed(self, input_ids: torch.Tensor) -> torch.Tensor: + return input_ids.unsqueeze(-1).expand(-1, -1, self.hidden_size).to(torch.float32) + + def _run_attention(self, x: torch.Tensor, attn_mask: torch.Tensor | None) -> torch.Tensor: + batch_size, seq_len, _ = x.shape + + def _project(q_proj, k_proj, v_proj): + q = q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim) + k = k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim) + v = v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim) + return 
torch.ops.auto_deploy.torch_attention( + q, k, v, attn_mask=attn_mask, is_causal=False, layout="bsnd" + ) + + return _project(self.q_proj_1, self.k_proj_1, self.v_proj_1) + _project( + self.q_proj_2, self.k_proj_2, self.v_proj_2 + ) + + def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor: + del position_ids + return self._run_attention(self._embed(input_ids), attn_mask=None) + + def forward_with_mask(self, input_ids: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor: + return self._run_attention(self._embed(input_ids), attn_mask=attn_mask) + + +class DummyFactory: + def _get_model_config(self): + return SimpleNamespace(model_type="unit_test_mask_model"), {} + + +class Gemma4Factory: + def _get_model_config(self): + return SimpleNamespace(model_type="gemma4"), {} + + +def _build_segment_mask(segment_ids: torch.Tensor) -> torch.Tensor: + same_segment = segment_ids.unsqueeze(2) == segment_ids.unsqueeze(1) + return same_segment.unsqueeze(1) + + +def _build_token_type_mask(token_type_ids: torch.Tensor) -> torch.Tensor: + non_text = token_type_ids != 0 + prev = torch.cat( + [ + torch.zeros(token_type_ids.shape[0], 1, dtype=token_type_ids.dtype), + token_type_ids[:, :-1], + ], + dim=1, + ) + blob_starts = non_text & (token_type_ids != prev) + blob_ids = torch.cumsum(blob_starts.to(torch.int64), dim=1) + token_blob_ids = torch.where(non_text, blob_ids, torch.zeros_like(blob_ids)) + media_mask = (token_blob_ids.unsqueeze(2) == token_blob_ids.unsqueeze(1)) & ( + token_blob_ids.unsqueeze(2) != 0 + ) + positions = torch.arange(token_type_ids.shape[1]) + causal_mask = positions.unsqueeze(0) <= positions.unsqueeze(1) + return (causal_mask.unsqueeze(0) | media_mask).unsqueeze(1) + + +_provider_build_counts = {"mask": 0} + + +@AttentionMaskProviderRegistry.register("unit_test_mask_model", "torch_attention") +def _segment_mask_provider(ctx, source_attn_node): + del source_attn_node + + def _builder(): + _provider_build_counts["mask"] += 1 + 
segment_ids = ctx.add_or_retrieve_input( + "segment_ids", + activate_arg=False, + val=torch.zeros(2, 4, dtype=torch.int64), + ) + seg_q = ctx.gm.graph.call_function(torch.ops.aten.unsqueeze.default, args=(segment_ids, 2)) + seg_k = ctx.gm.graph.call_function(torch.ops.aten.unsqueeze.default, args=(segment_ids, 1)) + same_segment = ctx.gm.graph.call_function(torch.ops.aten.eq.Tensor, args=(seg_q, seg_k)) + return ctx.gm.graph.call_function(torch.ops.aten.unsqueeze.default, args=(same_segment, 1)) + + return ctx.get_or_create_cached_node("segment_ids_mask", _builder) + + +@torch.inference_mode() +def test_inject_custom_attention_mask(): + model = DualAttentionModel().eval() + input_ids = torch.tensor([[1, 2, 3, 4], [4, 3, 2, 1]], dtype=torch.int64) + position_ids = torch.arange(input_ids.shape[1], dtype=torch.int64).repeat(input_ids.shape[0], 1) + segment_ids = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 2]], dtype=torch.int64) + + gm = torch_export_to_gm(model, args=(input_ids, position_ids), clone=True) + + _provider_build_counts["mask"] = 0 + gm_transformed = InferenceOptimizer( + DummyFactory(), + { + "inject_custom_attention_mask": { + "stage": "pattern_matcher", + "backend": "torch_attention", + }, + }, + )(None, gm) + + attn_nodes = [ + node + for node in gm_transformed.graph.nodes + if is_op(node, torch.ops.auto_deploy.torch_attention) + ] + assert len(attn_nodes) == 2 + assert _provider_build_counts["mask"] == 1 + + mask_nodes = [ + node.args[3] if len(node.args) > 3 else node.kwargs["attn_mask"] for node in attn_nodes + ] + assert mask_nodes[0] is mask_nodes[1] + + placeholder_targets = { + node.target for node in gm_transformed.graph.nodes if node.op == "placeholder" + } + assert "segment_ids" in placeholder_targets + + expected_mask = _build_segment_mask(segment_ids) + expected = model.forward_with_mask(input_ids, expected_mask) + actual = gm_transformed(input_ids, position_ids, segment_ids=segment_ids) + + torch.testing.assert_close(actual, expected) + + 
+@torch.inference_mode() +def test_inject_gemma4_custom_attention_mask_for_torch_backend(): + model = DualAttentionModel().eval() + input_ids = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.int64) + position_ids = torch.arange(input_ids.shape[1], dtype=torch.int64).repeat(input_ids.shape[0], 1) + token_type_ids = torch.tensor([[0, 1, 1, 2, 2]], dtype=torch.int64) + + gm = torch_export_to_gm(model, args=(input_ids, position_ids), clone=True) + gm_transformed = InferenceOptimizer( + Gemma4Factory(), + { + "inject_custom_attention_mask": { + "stage": "pattern_matcher", + "backend": "torch_attention", + }, + }, + )(None, gm) + + attn_nodes = [ + node + for node in gm_transformed.graph.nodes + if is_op(node, torch.ops.auto_deploy.torch_attention) + ] + assert len(attn_nodes) == 2 + assert all( + (node.args[3] if len(node.args) > 3 else node.kwargs["attn_mask"]) is not None + for node in attn_nodes + ) + + # With outside-graph approach, the graph receives the finished mask, not token_type_ids. + custom_attn_mask = _build_token_type_mask(token_type_ids) + expected = model.forward_with_mask(input_ids, custom_attn_mask) + actual = gm_transformed(input_ids, position_ids, custom_attn_mask=custom_attn_mask) + torch.testing.assert_close(actual, expected) + + # Verify None mask produces standard causal output + actual_none = gm_transformed(input_ids, position_ids, custom_attn_mask=None) + expected_none = model(input_ids, position_ids) + torch.testing.assert_close(actual_none, expected_none) + + +@torch.inference_mode() +def test_inject_gemma4_custom_attention_mask_for_triton_paged_backend(): + model = DualAttentionModel().eval() + input_ids = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.int64) + position_ids = torch.arange(input_ids.shape[1], dtype=torch.int64).repeat(input_ids.shape[0], 1) + token_type_ids = torch.tensor([[0, 1, 1, 2, 2]], dtype=torch.int64) + + gm = torch_export_to_gm(model, args=(input_ids, position_ids), clone=True) + gm_transformed = InferenceOptimizer( + 
Gemma4Factory(), + { + "inject_custom_attention_mask": { + "stage": "pattern_matcher", + "backend": "triton_paged", + }, + }, + )(None, gm) + + attn_nodes = [ + node + for node in gm_transformed.graph.nodes + if is_op(node, torch.ops.auto_deploy.torch_attention) + ] + assert len(attn_nodes) == 2 + assert all( + (node.args[3] if len(node.args) > 3 else node.kwargs["attn_mask"]) is not None + for node in attn_nodes + ) + + # With outside-graph approach, the graph receives the finished mask, not token_type_ids. + custom_attn_mask = _build_token_type_mask(token_type_ids) + expected = model.forward_with_mask(input_ids, custom_attn_mask) + actual = gm_transformed(input_ids, position_ids, custom_attn_mask=custom_attn_mask) + torch.testing.assert_close(actual, expected)
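Reviewer note: the mask the two Gemma 4 tests pass in as `custom_attn_mask` can be hard to picture from the tensor code in `_build_token_type_mask`. Below is a minimal pure-Python sketch of the same rule for a single sequence: causal attention, plus bidirectional attention inside each contiguous run of non-text tokens. The function name `build_token_type_mask_reference` is hypothetical and not part of this PR; it only restates the helper's logic without batching or the extra head dimension.

```python
def build_token_type_mask_reference(token_type_ids):
    """Return mask[q][k] == True where query position q may attend to key position k.

    Hypothetical single-sequence restatement of the test helper:
    a position attends causally (k <= q), and additionally to every
    position in the same contiguous non-text blob (token type != 0).
    """
    blob_ids = []
    current_blob = 0
    prev_type = 0  # the helper pads the shifted sequence with a leading zero
    for token_type in token_type_ids:
        # A new blob starts at a non-text token whose type differs from its predecessor.
        if token_type != 0 and token_type != prev_type:
            current_blob += 1
        # Text tokens get blob id 0, which never matches as a media blob.
        blob_ids.append(current_blob if token_type != 0 else 0)
        prev_type = token_type

    seq_len = len(token_type_ids)
    return [
        [
            k <= q or (blob_ids[q] != 0 and blob_ids[q] == blob_ids[k])
            for k in range(seq_len)
        ]
        for q in range(seq_len)
    ]


# Same token types as the tests above: one text token, then two media blobs.
mask = build_token_type_mask_reference([0, 1, 1, 2, 2])
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))
```

Each printed row is one query position. Position 1 can see position 2 because both belong to the first blob, but it cannot see positions 3 and 4, which belong to the second blob and lie in its future.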