-
Notifications
You must be signed in to change notification settings - Fork 2.3k
[None][feat] AutoDeploy: Moved to #12861 - Gemma4 vision #12810
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ac43945
202ba44
1846485
127748f
5842edf
273c79b
1927ae3
d86afe6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,299 @@ | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| "cells": [ | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| "cell_type": "markdown", | ||||||||||||||||||||||
| "metadata": {}, | ||||||||||||||||||||||
| "source": [ | ||||||||||||||||||||||
| "# Deploying Gemma 4 MoE with TensorRT-LLM (AutoDeploy)\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "This notebook walks you through serving **Gemma 4** (26B total, 4B activated MoE) with TensorRT-LLM using the **AutoDeploy** backend—same pattern as the Mistral Small 4 and GLM-4.7-Flash cookbooks in this folder.\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "[TensorRT-LLM](https://nvidia.github.io/TensorRT-LLM/) is NVIDIA's open-source library for accelerating LLM inference on NVIDIA GPUs. AutoDeploy uses Hugging Face `transformers` modeling code and TensorRT-LLM graph transforms. See the [AutoDeploy guide](https://nvidia.github.io/TensorRT-LLM/torch/auto_deploy/auto-deploy.html).\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "**Model resources:**\n", | ||||||||||||||||||||||
| "- [Gemma 4 collection (Hugging Face)](https://huggingface.co/collections/google/gemma-4)\n", | ||||||||||||||||||||||
| "- Instruction-tuned MoE: [`google/gemma-4-26B-A4B-it`](https://huggingface.co/google/gemma-4-26B-A4B-it)\n", | ||||||||||||||||||||||
| "- Base MoE (no chat template): [`google/gemma-4-26B-A4B`](https://huggingface.co/google/gemma-4-26B-A4B)\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "**Bundled AutoDeploy YAML (this branch):**\n", | ||||||||||||||||||||||
| "- **Instruction:** `examples/auto_deploy/model_registry/configs/gemma4_moe.yaml` — text-only export path; `attn_backend: triton_paged` (head_dim 512 / paged KV, CUDA-graph friendly).\n", | ||||||||||||||||||||||
| "- **Base:** `examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml` — same stack for the base checkpoint.\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "`trtllm-serve` takes **one** YAML path via `--extra_llm_api_options` (or `--config`). The bundled MoE YAMLs omit `world_size`; add it (or copy the YAML and edit) so it matches your GPU count when you use tensor parallel or multi-GPU loading.\n" | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| }, | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| "cell_type": "markdown", | ||||||||||||||||||||||
| "metadata": {}, | ||||||||||||||||||||||
| "source": [ | ||||||||||||||||||||||
| "## Prerequisites and environment\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "Run TensorRT-LLM in a GPU container, for example:\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "```shell\n", | ||||||||||||||||||||||
| "docker run --rm -it --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --gpus=all -p 8000:8000 nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc1\n", | ||||||||||||||||||||||
| "```\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "Use a TensorRT-LLM checkout that includes Gemma 4 AutoDeploy support (model card, tokenizer, and any required bridges should match your branch).\n" | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| }, | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| "cell_type": "code", | ||||||||||||||||||||||
| "execution_count": null, | ||||||||||||||||||||||
| "metadata": {}, | ||||||||||||||||||||||
| "outputs": [], | ||||||||||||||||||||||
| "source": [ | ||||||||||||||||||||||
| "# If pip is not available\n", | ||||||||||||||||||||||
| "!python -m ensurepip --default-pip" | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| }, | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| "cell_type": "code", | ||||||||||||||||||||||
| "execution_count": null, | ||||||||||||||||||||||
| "metadata": {}, | ||||||||||||||||||||||
| "outputs": [], | ||||||||||||||||||||||
| "source": [ | ||||||||||||||||||||||
| "%pip install torch openai" | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| }, | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| "cell_type": "markdown", | ||||||||||||||||||||||
| "metadata": {}, | ||||||||||||||||||||||
| "source": [ | ||||||||||||||||||||||
| "## Verify GPU\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "Confirm CUDA and visible devices before starting the server.\n" | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| }, | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| "cell_type": "code", | ||||||||||||||||||||||
| "execution_count": 1, | ||||||||||||||||||||||
| "metadata": {}, | ||||||||||||||||||||||
| "outputs": [ | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| "name": "stdout", | ||||||||||||||||||||||
| "output_type": "stream", | ||||||||||||||||||||||
| "text": [ | ||||||||||||||||||||||
| "Python: 3.12.3 (main, Jan 22 2026, 20:57:42) [GCC 13.3.0]\n", | ||||||||||||||||||||||
| "CUDA available: True\n", | ||||||||||||||||||||||
| "Num GPUs: 8\n", | ||||||||||||||||||||||
| "GPU[0]: NVIDIA H100 80GB HBM3\n", | ||||||||||||||||||||||
| "GPU[1]: NVIDIA H100 80GB HBM3\n", | ||||||||||||||||||||||
| "GPU[2]: NVIDIA H100 80GB HBM3\n", | ||||||||||||||||||||||
| "GPU[3]: NVIDIA H100 80GB HBM3\n", | ||||||||||||||||||||||
| "GPU[4]: NVIDIA H100 80GB HBM3\n", | ||||||||||||||||||||||
| "GPU[5]: NVIDIA H100 80GB HBM3\n", | ||||||||||||||||||||||
| "GPU[6]: NVIDIA H100 80GB HBM3\n", | ||||||||||||||||||||||
| "GPU[7]: NVIDIA H100 80GB HBM3\n" | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| ], | ||||||||||||||||||||||
| "source": [ | ||||||||||||||||||||||
| "import sys\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "import torch\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "print(f\"Python: {sys.version}\")\n", | ||||||||||||||||||||||
| "print(f\"CUDA available: {torch.cuda.is_available()}\")\n", | ||||||||||||||||||||||
| "print(f\"Num GPUs: {torch.cuda.device_count()}\")\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "if torch.cuda.is_available():\n", | ||||||||||||||||||||||
| " for i in range(torch.cuda.device_count()):\n", | ||||||||||||||||||||||
| " print(f\"GPU[{i}]: {torch.cuda.get_device_name(i)}\")" | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| }, | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| "cell_type": "markdown", | ||||||||||||||||||||||
| "metadata": {}, | ||||||||||||||||||||||
| "source": [ | ||||||||||||||||||||||
| "## OpenAI-compatible server\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "From a shell **inside** the container, at the TensorRT-LLM repo root, start `trtllm-serve` with AutoDeploy.\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "Use the Gemma 4 MoE YAML under `examples/auto_deploy/model_registry/configs/` (see the introduction). Add `world_size` to that YAML if your serve command needs an explicit tensor-parallel device count.\n" | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| }, | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| "cell_type": "markdown", | ||||||||||||||||||||||
| "metadata": {}, | ||||||||||||||||||||||
| "source": [ | ||||||||||||||||||||||
| "### Load the model\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "**Instruction-tuned (`gemma4_moe.yaml`):**\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "```shell\n", | ||||||||||||||||||||||
| "trtllm-serve \"google/gemma-4-26B-A4B-it\" \\\n", | ||||||||||||||||||||||
| " --host 0.0.0.0 \\\n", | ||||||||||||||||||||||
| " --port 8000 \\\n", | ||||||||||||||||||||||
| " --backend _autodeploy \\\n", | ||||||||||||||||||||||
| " --trust_remote_code \\\n", | ||||||||||||||||||||||
| " --extra_llm_api_options examples/auto_deploy/model_registry/configs/gemma4_moe.yaml\n", | ||||||||||||||||||||||
| "```\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "**Base checkpoint:** use model id `google/gemma-4-26B-A4B` and `examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml` instead.\n" | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| }, | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| "cell_type": "markdown", | ||||||||||||||||||||||
| "metadata": {}, | ||||||||||||||||||||||
| "source": [ | ||||||||||||||||||||||
| "When the server finishes loading weights, it is ready for requests.\n" | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| }, | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| "cell_type": "markdown", | ||||||||||||||||||||||
| "metadata": {}, | ||||||||||||||||||||||
| "source": [ | ||||||||||||||||||||||
| "## Use the API\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "Send chat completions with the OpenAI Python client pointed at the local server.\n" | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| }, | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| "cell_type": "code", | ||||||||||||||||||||||
| "execution_count": 2, | ||||||||||||||||||||||
| "metadata": {}, | ||||||||||||||||||||||
| "outputs": [], | ||||||||||||||||||||||
| "source": [ | ||||||||||||||||||||||
| "from openai import OpenAI\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "BASE_URL = \"http://0.0.0.0:8000/v1\"\n", | ||||||||||||||||||||||
| "API_KEY = \"null\"\n", | ||||||||||||||||||||||
| "MODEL_ID = \"google/gemma-4-26B-A4B-it\"\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "client = OpenAI(base_url=BASE_URL, api_key=API_KEY)" | ||||||||||||||||||||||
|
Comment on lines
+160
to
+164
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use a routable client URL here.
Suggested change-BASE_URL = "http://0.0.0.0:8000/v1"
+BASE_URL = "http://127.0.0.1:8000/v1"📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| }, | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| "cell_type": "code", | ||||||||||||||||||||||
| "execution_count": 3, | ||||||||||||||||||||||
| "metadata": {}, | ||||||||||||||||||||||
| "outputs": [ | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| "name": "stdout", | ||||||||||||||||||||||
| "output_type": "stream", | ||||||||||||||||||||||
| "text": [ | ||||||||||||||||||||||
| "Chat completion example\n", | ||||||||||||||||||||||
| "==================================================\n", | ||||||||||||||||||||||
| "Response:\n", | ||||||||||||||||||||||
| "To find 15% of 85, you can use a few different methods. Here are two easy ways to think about it:\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "### Method 1: The Breakdown Method (Mental Math)\n", | ||||||||||||||||||||||
| "This is often the easiest way to calculate percentages in your head by breaking the percentage into manageable parts (10% and 5%).\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "1. **Find 10% of 85:**\n", | ||||||||||||||||||||||
| " To find 10%, simply move the decimal point one place to the left.\n", | ||||||||||||||||||||||
| " $85 \\div 10 = 8.5$\n", | ||||||||||||||||||||||
| "2. **Find 5% of 85:**\n", | ||||||||||||||||||||||
| " Since 5% is half of 10%, just divide your previous answer by 2.\n", | ||||||||||||||||||||||
| " $8.5 \\div 2 = 4.25$\n", | ||||||||||||||||||||||
| "3. **Add them together:**\n", | ||||||||||||||||||||||
| " $10\\% + 5\\% = 15\\%$\n", | ||||||||||||||||||||||
| " $8.5 + 4.25 = 12.75$\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "***\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "### Method 2: The Multiplication Method (Calculator/Paper)\n", | ||||||||||||||||||||||
| "To find a percentage, you can convert the percentage into a decimal and multiply it by the total number.\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "1. **Convert 15% to a decimal:**\n", | ||||||||||||||||||||||
| " $15\\% = \\frac{15}{100} = 0.15$\n", | ||||||||||||||||||||||
| "2. **Multiply by 85:**\n", | ||||||||||||||||||||||
| " $85 \\times 0.15 = 12.75$\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "**Final Answer:**\n", | ||||||||||||||||||||||
| "15% of 85 is **12.75**.\n" | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| ], | ||||||||||||||||||||||
| "source": [ | ||||||||||||||||||||||
| "# Basic chat completion\n", | ||||||||||||||||||||||
| "print(\"Chat completion example\")\n", | ||||||||||||||||||||||
| "print(\"=\" * 50)\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "response = client.chat.completions.create(\n", | ||||||||||||||||||||||
| " model=MODEL_ID,\n", | ||||||||||||||||||||||
| " messages=[\n", | ||||||||||||||||||||||
| " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n", | ||||||||||||||||||||||
| " {\"role\": \"user\", \"content\": \"What is 15% of 85? Show your reasoning.\"},\n", | ||||||||||||||||||||||
| " ],\n", | ||||||||||||||||||||||
| " temperature=1.0,\n", | ||||||||||||||||||||||
| " top_p=0.95,\n", | ||||||||||||||||||||||
| " max_tokens=512,\n", | ||||||||||||||||||||||
| ")\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "print(\"Response:\")\n", | ||||||||||||||||||||||
| "print(response.choices[0].message.content)" | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| }, | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| "cell_type": "code", | ||||||||||||||||||||||
| "execution_count": 4, | ||||||||||||||||||||||
| "metadata": {}, | ||||||||||||||||||||||
| "outputs": [ | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| "name": "stdout", | ||||||||||||||||||||||
| "output_type": "stream", | ||||||||||||||||||||||
| "text": [ | ||||||||||||||||||||||
| "Streaming response:\n", | ||||||||||||||||||||||
| "==================================================\n", | ||||||||||||||||||||||
| "The first 5 prime numbers are **2, 3, 5, 7, and 11**." | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| ], | ||||||||||||||||||||||
| "source": [ | ||||||||||||||||||||||
| "# Streaming chat completion\n", | ||||||||||||||||||||||
| "print(\"Streaming response:\")\n", | ||||||||||||||||||||||
| "print(\"=\" * 50)\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "stream = client.chat.completions.create(\n", | ||||||||||||||||||||||
| " model=MODEL_ID,\n", | ||||||||||||||||||||||
| " messages=[\n", | ||||||||||||||||||||||
| " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n", | ||||||||||||||||||||||
| " {\"role\": \"user\", \"content\": \"What are the first 5 prime numbers?\"},\n", | ||||||||||||||||||||||
| " ],\n", | ||||||||||||||||||||||
| " temperature=0.7,\n", | ||||||||||||||||||||||
| " max_tokens=1024,\n", | ||||||||||||||||||||||
| " stream=True,\n", | ||||||||||||||||||||||
| ")\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "for chunk in stream:\n", | ||||||||||||||||||||||
| " if chunk.choices[0].delta.content:\n", | ||||||||||||||||||||||
| " print(chunk.choices[0].delta.content, end=\"\", flush=True)" | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| }, | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| "cell_type": "markdown", | ||||||||||||||||||||||
| "metadata": {}, | ||||||||||||||||||||||
| "source": [ | ||||||||||||||||||||||
| "## Additional resources\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "- [Gemma 4 collection (Hugging Face)](https://huggingface.co/collections/google/gemma-4)\n", | ||||||||||||||||||||||
| "- [TensorRT-LLM documentation](https://nvidia.github.io/TensorRT-LLM/)\n", | ||||||||||||||||||||||
| "- [AutoDeploy guide](https://nvidia.github.io/TensorRT-LLM/torch/auto_deploy/auto-deploy.html)\n", | ||||||||||||||||||||||
| "- [`gemma4_moe.yaml`](../model_registry/configs/gemma4_moe.yaml), [`gemma4_moe_base.yaml`](../model_registry/configs/gemma4_moe_base.yaml), [`models.yaml`](../model_registry/models.yaml)\n" | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| ], | ||||||||||||||||||||||
| "metadata": { | ||||||||||||||||||||||
| "kernelspec": { | ||||||||||||||||||||||
| "display_name": ".venv", | ||||||||||||||||||||||
| "language": "python", | ||||||||||||||||||||||
| "name": "python3" | ||||||||||||||||||||||
| }, | ||||||||||||||||||||||
| "language_info": { | ||||||||||||||||||||||
| "codemirror_mode": { | ||||||||||||||||||||||
| "name": "ipython", | ||||||||||||||||||||||
| "version": 3 | ||||||||||||||||||||||
| }, | ||||||||||||||||||||||
| "file_extension": ".py", | ||||||||||||||||||||||
| "mimetype": "text/x-python", | ||||||||||||||||||||||
| "name": "python", | ||||||||||||||||||||||
| "nbconvert_exporter": "python", | ||||||||||||||||||||||
| "pygments_lexer": "ipython3", | ||||||||||||||||||||||
| "version": "3.12.3" | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| }, | ||||||||||||||||||||||
| "nbformat": 4, | ||||||||||||||||||||||
| "nbformat_minor": 4 | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| runtime: trtllm | ||
| compile_backend: torch-cudagraph | ||
| model_factory: AutoModelForCausalLM | ||
| max_seq_len: 512 | ||
| max_batch_size: 8 | ||
| world_size: 1 | ||
|
|
||
| # Gemma 3n uses shared-KV decode semantics in the tail layers. FlashInfer | ||
| # supports the read-only shared-KV cache path and alternating sliding windows. | ||
| attn_backend: flashinfer |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| # Gemma 4 MoE (26B total, 4B activated) — text-only AD export path. | ||
| # Uses triton paged attention backend: supports head_dim=512 (global_head_dim), | ||
| # paged KV cache, CUDA-graph-compatible, FlashDecoding for decode. | ||
| model_factory: Gemma4ForConditionalGeneration | ||
| tokenizer: google/gemma-4-26B-A4B-it | ||
| attn_backend: triton_paged | ||
| compile_backend: torch-cudagraph | ||
| cuda_graph_config: | ||
| batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] | ||
| max_num_tokens: 8192 | ||
| max_batch_size: 512 | ||
| max_seq_len: 8192 | ||
| enable_chunked_prefill: true | ||
| kv_cache_config: | ||
| enable_block_reuse: false | ||
| free_gpu_memory_fraction: 0.8 | ||
| transforms: | ||
| compile_model: | ||
| piecewise_enabled: true | ||
| mlir_elementwise_fusion: | ||
| enabled: true | ||
| gather_logits_before_lm_head: | ||
| enabled: true | ||
| fuse_gemms: | ||
| enabled: true |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,22 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # SPDX-License-Identifier: Apache-2.0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Gemma 4 MoE base (26B total, 4B activated) — text-only AD export path. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Uses triton paged attention backend: supports head_dim=512 (global_head_dim), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # paged KV cache, CUDA-graph-compatible, FlashDecoding for decode. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model_factory: Gemma4ForConditionalGeneration | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tokenizer: google/gemma-4-26B-A4B | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| attn_backend: triton_paged | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| compile_backend: torch-cudagraph | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cuda_graph_config: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_num_tokens: 8192 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_batch_size: 512 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_seq_len: 8192 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| enable_chunked_prefill: true | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kv_cache_config: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| enable_block_reuse: false | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| free_gpu_memory_fraction: 0.8 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| transforms: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| compile_model: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| piecewise_enabled: true | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+7
to
+22
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Enable This config uses the same 🔧 Suggested config change transforms:
compile_model:
piecewise_enabled: true
+ gather_logits_before_lm_head:
+ enabled: true📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't reinstall
torchinside the TRT-LLM container.The release image already ships a Torch build that matches its CUDA/TensorRT stack.
%pip install torchhere can replace it with an incompatible wheel and break the rest of the notebook; onlyopenaishould be installed on top, or this should be guarded by an import check.Suggested change
📝 Committable suggestion
🤖 Prompt for AI Agents