4 changes: 4 additions & 0 deletions docs/source/models/supported-models.md
@@ -13,6 +13,8 @@ The following is a table of supported models for the PyTorch backend:
| `Exaone4ForCausalLM` | EXAONE 4.0 | `LGAI-EXAONE/EXAONE-4.0-32B` |
| `ExaoneMoEForCausalLM` | K-EXAONE | `LGAI-EXAONE/K-EXAONE-236B-A23B` |
| `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it` |
| `Gemma3nForConditionalGeneration` [^8]| Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it` |
| `Gemma4ForConditionalGeneration` [^7]| Gemma 4 | `google/gemma-4-26B-A4B-it` |
| `Glm4MoeForCausalLM` | GLM-4.5, GLM-4.6, GLM-4.7 | `THUDM/GLM-4-100B-A10B` |
| `Glm4MoeLiteForCausalLM` [^6] | GLM-4.7-Flash | `zai-org/GLM-4.7-Flash` |
| `GlmMoeDsaForCausalLM` | GLM-5 | `zai-org/GLM-5` |
@@ -60,6 +62,8 @@ Note: Support for other models may vary. Features marked "N/A" are not applicable.
[^4]: Overlap scheduler isn't supported when using EAGLE-3 (Two Model Engine) for GPT-OSS.
[^5]: Supported via the [AutoDeploy](../features/auto_deploy/auto-deploy.md) backend. See [AD config](../../../examples/auto_deploy/model_registry/configs/qwen3.5_moe_400b.yaml).
[^6]: Supported via the [AutoDeploy](../features/auto_deploy/auto-deploy.md) backend. See [AD config](../../../examples/auto_deploy/model_registry/configs/glm-4.7-flash.yaml).
[^7]: Text-only support via the [AutoDeploy](../features/auto_deploy/auto-deploy.md) backend. See [AD config](../../../examples/auto_deploy/model_registry/configs/gemma4_moe.yaml).
[^8]: Text-only support via the [AutoDeploy](../features/auto_deploy/auto-deploy.md) backend. See [AD config](../../../examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml).


# Multimodal Feature Support Matrix (PyTorch Backend)
299 changes: 299 additions & 0 deletions examples/auto_deploy/cookbooks/gemma_4_trtllm_cookbook.ipynb
@@ -0,0 +1,299 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Deploying Gemma 4 MoE with TensorRT-LLM (AutoDeploy)\n",
"\n",
"This notebook walks you through serving **Gemma 4** (26B total, 4B activated MoE) with TensorRT-LLM using the **AutoDeploy** backend—same pattern as the Mistral Small 4 and GLM-4.7-Flash cookbooks in this folder.\n",
"\n",
"[TensorRT-LLM](https://nvidia.github.io/TensorRT-LLM/) is NVIDIA's open-source library for accelerating LLM inference on NVIDIA GPUs. AutoDeploy uses Hugging Face `transformers` modeling code and TensorRT-LLM graph transforms. See the [AutoDeploy guide](https://nvidia.github.io/TensorRT-LLM/torch/auto_deploy/auto-deploy.html).\n",
"\n",
"**Model resources:**\n",
"- [Gemma 4 collection (Hugging Face)](https://huggingface.co/collections/google/gemma-4)\n",
"- Instruction-tuned MoE: [`google/gemma-4-26B-A4B-it`](https://huggingface.co/google/gemma-4-26B-A4B-it)\n",
"- Base MoE (no chat template): [`google/gemma-4-26B-A4B`](https://huggingface.co/google/gemma-4-26B-A4B)\n",
"\n",
"**Bundled AutoDeploy YAML (this branch):**\n",
"- **Instruction:** `examples/auto_deploy/model_registry/configs/gemma4_moe.yaml` — text-only export path; `attn_backend: triton_paged` (head_dim 512 / paged KV, CUDA-graph friendly).\n",
"- **Base:** `examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml` — same stack for the base checkpoint.\n",
"\n",
"`trtllm-serve` takes **one** YAML path via `--extra_llm_api_options` (or `--config`). The bundled MoE YAMLs omit `world_size`; add it (or copy the YAML and edit) so it matches your GPU count when you use tensor parallel or multi-GPU loading.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prerequisites and environment\n",
"\n",
"Run TensorRT-LLM in a GPU container, for example:\n",
"\n",
"```shell\n",
"docker run --rm -it --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --gpus=all -p 8000:8000 nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc1\n",
"```\n",
"\n",
"Use a TensorRT-LLM checkout that includes Gemma 4 AutoDeploy support (model card, tokenizer, and any required bridges should match your branch).\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# If pip is not available\n",
"!python -m ensurepip --default-pip"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install torch openai"
]
Review comment on lines +56 to +57 — ⚠️ Potential issue | 🟠 Major

Don't reinstall torch inside the TRT-LLM container.

The release image already ships a Torch build that matches its CUDA/TensorRT stack. `%pip install torch` here can replace it with an incompatible wheel and break the rest of the notebook; only `openai` should be installed on top, or the install should be guarded by an import check.

Suggested change:

```diff
-%pip install torch openai
+%pip install openai
```

},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Verify GPU\n",
"\n",
"Confirm CUDA and visible devices before starting the server.\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Python: 3.12.3 (main, Jan 22 2026, 20:57:42) [GCC 13.3.0]\n",
"CUDA available: True\n",
"Num GPUs: 8\n",
"GPU[0]: NVIDIA H100 80GB HBM3\n",
"GPU[1]: NVIDIA H100 80GB HBM3\n",
"GPU[2]: NVIDIA H100 80GB HBM3\n",
"GPU[3]: NVIDIA H100 80GB HBM3\n",
"GPU[4]: NVIDIA H100 80GB HBM3\n",
"GPU[5]: NVIDIA H100 80GB HBM3\n",
"GPU[6]: NVIDIA H100 80GB HBM3\n",
"GPU[7]: NVIDIA H100 80GB HBM3\n"
]
}
],
"source": [
"import sys\n",
"\n",
"import torch\n",
"\n",
"print(f\"Python: {sys.version}\")\n",
"print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
"print(f\"Num GPUs: {torch.cuda.device_count()}\")\n",
"\n",
"if torch.cuda.is_available():\n",
" for i in range(torch.cuda.device_count()):\n",
" print(f\"GPU[{i}]: {torch.cuda.get_device_name(i)}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## OpenAI-compatible server\n",
"\n",
"From a shell **inside** the container, at the TensorRT-LLM repo root, start `trtllm-serve` with AutoDeploy.\n",
"\n",
"Use the Gemma 4 MoE YAML under `examples/auto_deploy/model_registry/configs/` (see the introduction). Add `world_size` to that YAML if your serve command needs an explicit tensor-parallel device count.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load the model\n",
"\n",
"**Instruction-tuned (`gemma4_moe.yaml`):**\n",
"\n",
"```shell\n",
"trtllm-serve \"google/gemma-4-26B-A4B-it\" \\\n",
" --host 0.0.0.0 \\\n",
" --port 8000 \\\n",
" --backend _autodeploy \\\n",
" --trust_remote_code \\\n",
" --extra_llm_api_options examples/auto_deploy/model_registry/configs/gemma4_moe.yaml\n",
"```\n",
"\n",
"**Base checkpoint:** use model id `google/gemma-4-26B-A4B` and `examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml` instead.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When the server finishes loading weights, it is ready for requests.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Use the API\n",
"\n",
"Send chat completions with the OpenAI Python client pointed at the local server.\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from openai import OpenAI\n",
"\n",
"BASE_URL = \"http://0.0.0.0:8000/v1\"\n",
"API_KEY = \"null\"\n",
"MODEL_ID = \"google/gemma-4-26B-A4B-it\"\n",
"\n",
"client = OpenAI(base_url=BASE_URL, api_key=API_KEY)"
Review comment on lines +160 to +164 — ⚠️ Potential issue | 🟠 Major

Use a routable client URL here.

`0.0.0.0` is a bind address, not a client destination. The server can listen on `0.0.0.0`, but the OpenAI client should connect to `127.0.0.1` or `localhost`.

Suggested change:

```diff
-BASE_URL = "http://0.0.0.0:8000/v1"
+BASE_URL = "http://127.0.0.1:8000/v1"
```

]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Chat completion example\n",
"==================================================\n",
"Response:\n",
"To find 15% of 85, you can use a few different methods. Here are two easy ways to think about it:\n",
"\n",
"### Method 1: The Breakdown Method (Mental Math)\n",
"This is often the easiest way to calculate percentages in your head by breaking the percentage into manageable parts (10% and 5%).\n",
"\n",
"1. **Find 10% of 85:**\n",
" To find 10%, simply move the decimal point one place to the left.\n",
" $85 \\div 10 = 8.5$\n",
"2. **Find 5% of 85:**\n",
" Since 5% is half of 10%, just divide your previous answer by 2.\n",
" $8.5 \\div 2 = 4.25$\n",
"3. **Add them together:**\n",
" $10\\% + 5\\% = 15\\%$\n",
" $8.5 + 4.25 = 12.75$\n",
"\n",
"***\n",
"\n",
"### Method 2: The Multiplication Method (Calculator/Paper)\n",
"To find a percentage, you can convert the percentage into a decimal and multiply it by the total number.\n",
"\n",
"1. **Convert 15% to a decimal:**\n",
" $15\\% = \\frac{15}{100} = 0.15$\n",
"2. **Multiply by 85:**\n",
" $85 \\times 0.15 = 12.75$\n",
"\n",
"**Final Answer:**\n",
"15% of 85 is **12.75**.\n"
]
}
],
"source": [
"# Basic chat completion\n",
"print(\"Chat completion example\")\n",
"print(\"=\" * 50)\n",
"\n",
"response = client.chat.completions.create(\n",
" model=MODEL_ID,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
" {\"role\": \"user\", \"content\": \"What is 15% of 85? Show your reasoning.\"},\n",
" ],\n",
" temperature=1.0,\n",
" top_p=0.95,\n",
" max_tokens=512,\n",
")\n",
"\n",
"print(\"Response:\")\n",
"print(response.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Streaming response:\n",
"==================================================\n",
"The first 5 prime numbers are **2, 3, 5, 7, and 11**."
]
}
],
"source": [
"# Streaming chat completion\n",
"print(\"Streaming response:\")\n",
"print(\"=\" * 50)\n",
"\n",
"stream = client.chat.completions.create(\n",
" model=MODEL_ID,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
" {\"role\": \"user\", \"content\": \"What are the first 5 prime numbers?\"},\n",
" ],\n",
" temperature=0.7,\n",
" max_tokens=1024,\n",
" stream=True,\n",
")\n",
"\n",
"for chunk in stream:\n",
" if chunk.choices[0].delta.content:\n",
" print(chunk.choices[0].delta.content, end=\"\", flush=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Additional resources\n",
"\n",
"- [Gemma 4 collection (Hugging Face)](https://huggingface.co/collections/google/gemma-4)\n",
"- [TensorRT-LLM documentation](https://nvidia.github.io/TensorRT-LLM/)\n",
"- [AutoDeploy guide](https://nvidia.github.io/TensorRT-LLM/torch/auto_deploy/auto-deploy.html)\n",
"- [`gemma4_moe.yaml`](../model_registry/configs/gemma4_moe.yaml), [`gemma4_moe_base.yaml`](../model_registry/configs/gemma4_moe_base.yaml), [`models.yaml`](../model_registry/models.yaml)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
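The notebook's introduction notes that the bundled MoE YAMLs omit `world_size` and that you should add it (or copy the YAML and edit) to match your GPU count. A minimal sketch of that edit, treating the config as plain text and assuming `world_size` is a top-level scalar key, as it is in `gemma3n_e2b_it.yaml` (the helper and sample excerpt are illustrative, not part of TRT-LLM):

```python
# Sketch: set a top-level world_size key in AutoDeploy YAML text.
# with_world_size and the excerpt below are illustrative assumptions.

def with_world_size(yaml_text: str, world_size: int) -> str:
    """Return the YAML text with any existing top-level world_size replaced."""
    kept = [ln for ln in yaml_text.splitlines() if not ln.startswith("world_size:")]
    kept.append(f"world_size: {world_size}")
    return "\n".join(kept) + "\n"

# Excerpt standing in for gemma4_moe.yaml, which omits world_size.
excerpt = "attn_backend: triton_paged\ncompile_backend: torch-cudagraph\n"
print(with_world_size(excerpt, 2))
```

Write the result to a copy of the YAML and pass that copy's path to `--extra_llm_api_options`.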
13 changes: 13 additions & 0 deletions examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml
@@ -0,0 +1,13 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

runtime: trtllm
compile_backend: torch-cudagraph
model_factory: AutoModelForCausalLM
max_seq_len: 512
max_batch_size: 8
world_size: 1

# Gemma 3n uses shared-KV decode semantics in the tail layers. FlashInfer
# supports the read-only shared-KV cache path and alternating sliding windows.
attn_backend: flashinfer
28 changes: 28 additions & 0 deletions examples/auto_deploy/model_registry/configs/gemma4_moe.yaml
@@ -0,0 +1,28 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Gemma 4 MoE (26B total, 4B activated) — text-only AD export path.
# Uses triton paged attention backend: supports head_dim=512 (global_head_dim),
# paged KV cache, CUDA-graph-compatible, FlashDecoding for decode.
model_factory: Gemma4ForConditionalGeneration
tokenizer: google/gemma-4-26B-A4B-it
attn_backend: triton_paged
compile_backend: torch-cudagraph
cuda_graph_config:
batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
max_num_tokens: 8192
max_batch_size: 512
max_seq_len: 8192
enable_chunked_prefill: true
kv_cache_config:
enable_block_reuse: false
free_gpu_memory_fraction: 0.8
transforms:
compile_model:
piecewise_enabled: true
mlir_elementwise_fusion:
enabled: true
gather_logits_before_lm_head:
enabled: true
fuse_gemms:
enabled: true
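The CUDA-graph and batching limits above have to agree with one another: there is little point capturing graphs for batch sizes the scheduler can never form, and the token budget should be coverable by the sequence and batch limits. A small sanity-check sketch over these fields (the invariants are assumptions about how AutoDeploy consumes them, not documented contracts):

```python
# Sanity-check assumed relationships between fields in gemma4_moe.yaml.
# These checks encode assumptions, not documented AutoDeploy contracts.
cfg = {
    "cuda_graph_config": {"batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]},
    "max_num_tokens": 8192,
    "max_batch_size": 512,
    "max_seq_len": 8192,
}

# Largest captured graph should be formable by the scheduler.
assert max(cfg["cuda_graph_config"]["batch_sizes"]) <= cfg["max_batch_size"]
# Chunked-prefill token budget fits within seq-len * batch capacity.
assert cfg["max_num_tokens"] <= cfg["max_seq_len"] * cfg["max_batch_size"]
print("config invariants hold")
```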
22 changes: 22 additions & 0 deletions examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml
@@ -0,0 +1,22 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Gemma 4 MoE base (26B total, 4B activated) — text-only AD export path.
# Uses triton paged attention backend: supports head_dim=512 (global_head_dim),
# paged KV cache, CUDA-graph-compatible, FlashDecoding for decode.
model_factory: Gemma4ForConditionalGeneration
tokenizer: google/gemma-4-26B-A4B
attn_backend: triton_paged
compile_backend: torch-cudagraph
cuda_graph_config:
batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
max_num_tokens: 8192
max_batch_size: 512
max_seq_len: 8192
enable_chunked_prefill: true
kv_cache_config:
enable_block_reuse: false
free_gpu_memory_fraction: 0.8
transforms:
compile_model:
piecewise_enabled: true
Review comment on lines +7 to +22 — ⚠️ Potential issue | 🟠 Major

Enable gather_logits_before_lm_head in the base Gemma4 config too.

This config uses the same Gemma4ForConditionalGeneration export path and `piecewise_enabled: true` setup as `examples/auto_deploy/model_registry/configs/gemma4_moe.yaml`, but it never opts into the transform that the new Gemma4 softcapping test is guarding. Without it, the base model can still materialize `[num_tokens, vocab_size]` before the gather and hit the same piecewise CUDA-graph memory regression.

Suggested change:

```diff
 transforms:
   compile_model:
     piecewise_enabled: true
+  gather_logits_before_lm_head:
+    enabled: true
```

5 changes: 5 additions & 0 deletions examples/auto_deploy/model_registry/models.yaml
@@ -308,6 +308,11 @@ models:
yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml', 'multimodal.yaml']
- name: google/gemma-3n-E4B-it
yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml', 'multimodal.yaml']
# --- Gemma 4 (2026) - MoE with K=V attention ---
- name: google/gemma-4-26B-A4B
yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml', 'gemma4_moe_base.yaml']
- name: google/gemma-4-26B-A4B-it
yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml', 'gemma4_moe.yaml']
# --- JetBrains Mellum (Apr 2025) - code specialist ---
- name: JetBrains/Mellum-4b-sft-all
yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml']
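Each `models.yaml` entry layers several YAML fragments via `yaml_extra`. A sketch of how such overlays might compose, assuming a last-wins deep merge (the actual merge semantics live in the AutoDeploy registry code and may differ; the fragment contents below are hypothetical):

```python
# Sketch: last-wins deep merge of yaml_extra overlays.
# Merge order and semantics are assumptions; consult the registry code.

def deep_merge(base: dict, overlay: dict) -> dict:
    """Recursively merge overlay into base; overlay values win on conflict."""
    out = dict(base)
    for key, val in overlay.items():
        if isinstance(val, dict) and isinstance(out.get(key), dict):
            out[key] = deep_merge(out[key], val)
        else:
            out[key] = val
    return out

# Hypothetical contents of the fragments layered for gemma-4-26B-A4B-it.
dashboard_default = {"max_seq_len": 4096,
                     "kv_cache_config": {"free_gpu_memory_fraction": 0.9}}
world_size_1 = {"world_size": 1}
gemma4_moe = {"max_seq_len": 8192,
              "kv_cache_config": {"enable_block_reuse": False}}

cfg = {}
for frag in (dashboard_default, world_size_1, gemma4_moe):
    cfg = deep_merge(cfg, frag)
print(cfg)
```

With these inputs, the later `gemma4_moe` fragment overrides `max_seq_len` while nested `kv_cache_config` keys from both fragments survive.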