5 changes: 5 additions & 0 deletions cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
100755 → 100644
@@ -534,6 +534,11 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
.def("count_reusable_blocks", &BaseKVCacheManager::countReusableBlocks, nb::arg("unique_tokens"),
nb::arg("llm_request"), nb::arg("only_allocated") = false, nb::call_guard<nb::gil_scoped_release>())
.def("get_cache_block_ids", &BaseKVCacheManager::getCacheBlockIds, nb::call_guard<nb::gil_scoped_release>())
.def(
"get_num_front_blocks_removed",
[](BaseKVCacheManager const& self, tb::LlmRequest::RequestIdType requestId)
{ return self.getSequence(requestId).getNumFrontBlocksRemoved(); },
nb::call_guard<nb::gil_scoped_release>())
.def("get_batch_cache_block_ids", &BaseKVCacheManager::getBatchCacheBlockIds,
nb::call_guard<nb::gil_scoped_release>())
.def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents,
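The new binding is a thin lambda: it looks up the sequence for a request id and forwards to the sequence's counter. A minimal Python stand-in sketches that call path — `FakeSequence` and `FakeKVCacheManager` are illustrative stand-ins, not the real `tensorrt_llm` bindings:

```python
# Sketch of the call path exposed by the new get_num_front_blocks_removed
# binding. FakeSequence / FakeKVCacheManager are hypothetical stand-ins,
# not the actual tensorrt_llm API.

class FakeSequence:
    def __init__(self, num_front_blocks_removed: int):
        self._removed = num_front_blocks_removed

    def get_num_front_blocks_removed(self) -> int:
        # Plays the role of getNumFrontBlocksRemoved() on the C++ sequence.
        return self._removed


class FakeKVCacheManager:
    def __init__(self):
        self._sequences: dict[int, FakeSequence] = {}

    def add_sequence(self, request_id: int, seq: FakeSequence) -> None:
        self._sequences[request_id] = seq

    def get_num_front_blocks_removed(self, request_id: int) -> int:
        # The nanobind lambda does exactly this: getSequence(requestId)
        # followed by the counter accessor.
        return self._sequences[request_id].get_num_front_blocks_removed()


manager = FakeKVCacheManager()
manager.add_sequence(42, FakeSequence(num_front_blocks_removed=3))
print(manager.get_num_front_blocks_removed(42))  # → 3
```

Like the neighboring bindings, the real method releases the GIL via `nb::call_guard<nb::gil_scoped_release>()`, so it is safe to call from multi-threaded Python runtimes.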
4 changes: 4 additions & 0 deletions docs/source/models/supported-models.md
@@ -13,6 +13,8 @@ The following is a table of supported models for the PyTorch backend:
| `Exaone4ForCausalLM` | EXAONE 4.0 | `LGAI-EXAONE/EXAONE-4.0-32B` |
| `ExaoneMoEForCausalLM` | K-EXAONE | `LGAI-EXAONE/K-EXAONE-236B-A23B` |
| `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it` |
| `Gemma3nForConditionalGeneration` [^8]| Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it` |
| `Gemma4ForConditionalGeneration` [^7]| Gemma 4 | `google/gemma-4-26B-A4B-it` |
| `Glm4MoeForCausalLM` | GLM-4.5, GLM-4.6, GLM-4.7 | `THUDM/GLM-4-100B-A10B` |
| `Glm4MoeLiteForCausalLM` [^6] | GLM-4.7-Flash | `zai-org/GLM-4.7-Flash` |
| `GlmMoeDsaForCausalLM` | GLM-5 | `zai-org/GLM-5` |
@@ -60,6 +62,8 @@ Note: Support for other models may vary. Features marked "N/A" are not applicable.
[^4]: Overlap scheduler isn't supported when using EAGLE-3 (Two Model Engine) for GPT-OSS.
[^5]: Supported via the [AutoDeploy](../features/auto_deploy/auto-deploy.md) backend. See [AD config](../../../examples/auto_deploy/model_registry/configs/qwen3.5_moe_400b.yaml).
[^6]: Supported via the [AutoDeploy](../features/auto_deploy/auto-deploy.md) backend. See [AD config](../../../examples/auto_deploy/model_registry/configs/glm-4.7-flash.yaml).
[^7]: Text-only support via the [AutoDeploy](../features/auto_deploy/auto-deploy.md) backend. See [AD config](../../../examples/auto_deploy/model_registry/configs/gemma4_moe.yaml).
[^8]: Text-only support via the [AutoDeploy](../features/auto_deploy/auto-deploy.md) backend. See [AD config](../../../examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml).


# Multimodal Feature Support Matrix (PyTorch Backend)
299 changes: 299 additions & 0 deletions examples/auto_deploy/cookbooks/gemma_4_trtllm_cookbook.ipynb
@@ -0,0 +1,299 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Deploying Gemma 4 MoE with TensorRT-LLM (AutoDeploy)\n",
"\n",
"This notebook walks you through serving **Gemma 4** (26B total, 4B activated MoE) with TensorRT-LLM using the **AutoDeploy** backend—same pattern as the Mistral Small 4 and GLM-4.7-Flash cookbooks in this folder.\n",
"\n",
"[TensorRT-LLM](https://nvidia.github.io/TensorRT-LLM/) is NVIDIA's open-source library for accelerating LLM inference on NVIDIA GPUs. AutoDeploy uses Hugging Face `transformers` modeling code and TensorRT-LLM graph transforms. See the [AutoDeploy guide](https://nvidia.github.io/TensorRT-LLM/torch/auto_deploy/auto-deploy.html).\n",
"\n",
"**Model resources:**\n",
"- [Gemma 4 collection (Hugging Face)](https://huggingface.co/collections/google/gemma-4)\n",
"- Instruction-tuned MoE: [`google/gemma-4-26B-A4B-it`](https://huggingface.co/google/gemma-4-26B-A4B-it)\n",
"- Base MoE (no chat template): [`google/gemma-4-26B-A4B`](https://huggingface.co/google/gemma-4-26B-A4B)\n",
"\n",
"**Bundled AutoDeploy YAML (this branch):**\n",
"- **Instruction:** `examples/auto_deploy/model_registry/configs/gemma4_moe.yaml` — text-only export path; `attn_backend: triton_paged` (head_dim 512 / paged KV, CUDA-graph friendly).\n",
"- **Base:** `examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml` — same stack for the base checkpoint.\n",
"\n",
"`trtllm-serve` takes **one** YAML path via `--extra_llm_api_options` (or `--config`). The bundled MoE YAMLs omit `world_size`; add it (or copy and edit the YAML) so it matches your GPU count when using tensor parallelism or multi-GPU loading.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prerequisites and environment\n",
"\n",
"Run TensorRT-LLM in a GPU container, for example:\n",
"\n",
"```shell\n",
"docker run --rm -it --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --gpus=all -p 8000:8000 nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc1\n",
"```\n",
"\n",
"Use a TensorRT-LLM checkout that includes Gemma 4 AutoDeploy support (model card, tokenizer, and any required bridges should match your branch).\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# If pip is not available\n",
"!python -m ensurepip --default-pip"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install torch openai"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Verify GPU\n",
"\n",
"Confirm CUDA and visible devices before starting the server.\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Python: 3.12.3 (main, Jan 22 2026, 20:57:42) [GCC 13.3.0]\n",
"CUDA available: True\n",
"Num GPUs: 8\n",
"GPU[0]: NVIDIA H100 80GB HBM3\n",
"GPU[1]: NVIDIA H100 80GB HBM3\n",
"GPU[2]: NVIDIA H100 80GB HBM3\n",
"GPU[3]: NVIDIA H100 80GB HBM3\n",
"GPU[4]: NVIDIA H100 80GB HBM3\n",
"GPU[5]: NVIDIA H100 80GB HBM3\n",
"GPU[6]: NVIDIA H100 80GB HBM3\n",
"GPU[7]: NVIDIA H100 80GB HBM3\n"
]
}
],
"source": [
"import sys\n",
"\n",
"import torch\n",
"\n",
"print(f\"Python: {sys.version}\")\n",
"print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
"print(f\"Num GPUs: {torch.cuda.device_count()}\")\n",
"\n",
"if torch.cuda.is_available():\n",
" for i in range(torch.cuda.device_count()):\n",
" print(f\"GPU[{i}]: {torch.cuda.get_device_name(i)}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## OpenAI-compatible server\n",
"\n",
"From a shell **inside** the container, at the TensorRT-LLM repo root, start `trtllm-serve` with AutoDeploy.\n",
"\n",
"Use the Gemma 4 MoE YAML under `examples/auto_deploy/model_registry/configs/` (see the introduction). Add `world_size` to that YAML if your serve command needs an explicit tensor-parallel device count.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load the model\n",
"\n",
"**Instruction-tuned (`gemma4_moe.yaml`):**\n",
"\n",
"```shell\n",
"trtllm-serve \"google/gemma-4-26B-A4B-it\" \\\n",
" --host 0.0.0.0 \\\n",
" --port 8000 \\\n",
" --backend _autodeploy \\\n",
" --trust_remote_code \\\n",
" --extra_llm_api_options examples/auto_deploy/model_registry/configs/gemma4_moe.yaml\n",
"```\n",
"\n",
"**Base checkpoint:** use model id `google/gemma-4-26B-A4B` and `examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml` instead.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When the server finishes loading weights, it is ready for requests.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Use the API\n",
"\n",
"Send chat completions with the OpenAI Python client pointed at the local server.\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from openai import OpenAI\n",
"\n",
"BASE_URL = \"http://0.0.0.0:8000/v1\"\n",
"API_KEY = \"null\"\n",
"MODEL_ID = \"google/gemma-4-26B-A4B-it\"\n",
"\n",
"client = OpenAI(base_url=BASE_URL, api_key=API_KEY)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Chat completion example\n",
"==================================================\n",
"Response:\n",
"To find 15% of 85, you can use a few different methods. Here are two easy ways to think about it:\n",
"\n",
"### Method 1: The Breakdown Method (Mental Math)\n",
"This is often the easiest way to calculate percentages in your head by breaking the percentage into manageable parts (10% and 5%).\n",
"\n",
"1. **Find 10% of 85:**\n",
" To find 10%, simply move the decimal point one place to the left.\n",
" $85 \\div 10 = 8.5$\n",
"2. **Find 5% of 85:**\n",
" Since 5% is half of 10%, just divide your previous answer by 2.\n",
" $8.5 \\div 2 = 4.25$\n",
"3. **Add them together:**\n",
" $10\\% + 5\\% = 15\\%$\n",
" $8.5 + 4.25 = 12.75$\n",
"\n",
"***\n",
"\n",
"### Method 2: The Multiplication Method (Calculator/Paper)\n",
"To find a percentage, you can convert the percentage into a decimal and multiply it by the total number.\n",
"\n",
"1. **Convert 15% to a decimal:**\n",
" $15\\% = \\frac{15}{100} = 0.15$\n",
"2. **Multiply by 85:**\n",
" $85 \\times 0.15 = 12.75$\n",
"\n",
"**Final Answer:**\n",
"15% of 85 is **12.75**.\n"
]
}
],
"source": [
"# Basic chat completion\n",
"print(\"Chat completion example\")\n",
"print(\"=\" * 50)\n",
"\n",
"response = client.chat.completions.create(\n",
" model=MODEL_ID,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
" {\"role\": \"user\", \"content\": \"What is 15% of 85? Show your reasoning.\"},\n",
" ],\n",
" temperature=1.0,\n",
" top_p=0.95,\n",
" max_tokens=512,\n",
")\n",
"\n",
"print(\"Response:\")\n",
"print(response.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Streaming response:\n",
"==================================================\n",
"The first 5 prime numbers are **2, 3, 5, 7, and 11**."
]
}
],
"source": [
"# Streaming chat completion\n",
"print(\"Streaming response:\")\n",
"print(\"=\" * 50)\n",
"\n",
"stream = client.chat.completions.create(\n",
" model=MODEL_ID,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
" {\"role\": \"user\", \"content\": \"What are the first 5 prime numbers?\"},\n",
" ],\n",
" temperature=0.7,\n",
" max_tokens=1024,\n",
" stream=True,\n",
")\n",
"\n",
"for chunk in stream:\n",
" if chunk.choices[0].delta.content:\n",
" print(chunk.choices[0].delta.content, end=\"\", flush=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Additional resources\n",
"\n",
"- [Gemma 4 collection (Hugging Face)](https://huggingface.co/collections/google/gemma-4)\n",
"- [TensorRT-LLM documentation](https://nvidia.github.io/TensorRT-LLM/)\n",
"- [AutoDeploy guide](https://nvidia.github.io/TensorRT-LLM/torch/auto_deploy/auto-deploy.html)\n",
"- [`gemma4_moe.yaml`](../model_registry/configs/gemma4_moe.yaml), [`gemma4_moe_base.yaml`](../model_registry/configs/gemma4_moe_base.yaml), [`models.yaml`](../model_registry/models.yaml)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
13 changes: 13 additions & 0 deletions examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml
@@ -0,0 +1,13 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

runtime: trtllm
compile_backend: torch-cudagraph
model_factory: AutoModelForCausalLM
max_seq_len: 512
max_batch_size: 8
world_size: 1

# Gemma 3n uses shared-KV decode semantics in the tail layers. FlashInfer
# supports the read-only shared-KV cache path and alternating sliding windows.
attn_backend: flashinfer
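Following the same `trtllm-serve` pattern documented in the Gemma 4 cookbook above, this config could be served as shown below. This is a sketch: the model id comes from the supported-models table, and the relative config path assumes you run from the TensorRT-LLM repo root inside the container.

```shell
trtllm-serve "google/gemma-3n-E2B-it" \
  --host 0.0.0.0 \
  --port 8000 \
  --backend _autodeploy \
  --trust_remote_code \
  --extra_llm_api_options examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml
```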
28 changes: 28 additions & 0 deletions examples/auto_deploy/model_registry/configs/gemma4_moe.yaml
@@ -0,0 +1,28 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Gemma 4 MoE (26B total, 4B activated) — text-only AD export path.
# Uses triton paged attention backend: supports head_dim=512 (global_head_dim),
# paged KV cache, CUDA-graph-compatible, FlashDecoding for decode.
model_factory: Gemma4ForConditionalGeneration
tokenizer: google/gemma-4-26B-A4B-it
attn_backend: triton_paged
compile_backend: torch-cudagraph
cuda_graph_config:
batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
max_num_tokens: 8192
max_batch_size: 512
max_seq_len: 8192
enable_chunked_prefill: true
kv_cache_config:
enable_block_reuse: false
free_gpu_memory_fraction: 0.8
transforms:
compile_model:
piecewise_enabled: true
mlir_elementwise_fusion:
enabled: true
gather_logits_before_lm_head:
enabled: true
fuse_gemms:
enabled: true
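The cookbook notes that the bundled MoE YAMLs omit `world_size`. For a multi-GPU run you would copy this file and append the field; the value below is an illustrative assumption (an 8-GPU node), not a recommendation:

```yaml
# Appended to a copy of gemma4_moe.yaml; set to the number of GPUs you
# want to shard across (8 here is only an example value).
world_size: 8
```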
@@ -0,0 +1,22 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Gemma 4 MoE base (26B total, 4B activated) — text-only AD export path.
# Uses triton paged attention backend: supports head_dim=512 (global_head_dim),
# paged KV cache, CUDA-graph-compatible, FlashDecoding for decode.
model_factory: Gemma4ForConditionalGeneration
tokenizer: google/gemma-4-26B-A4B
attn_backend: triton_paged
compile_backend: torch-cudagraph
cuda_graph_config:
batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
max_num_tokens: 8192
max_batch_size: 512
max_seq_len: 8192
enable_chunked_prefill: true
kv_cache_config:
enable_block_reuse: false
free_gpu_memory_fraction: 0.8
transforms:
compile_model:
piecewise_enabled: true