[Model] Add Gemma 4 E2B text-only support #3485
Conversation
Adds the text decoder for Gemma 4 E2B-it. Multimodal encoders are not wired up in this PR; the class is deliberately named `Gemma4ForCausalLM` (not `Gemma4ForConditionalGeneration`) so a future multimodal wrapper can be added without renaming. The architecture follows Gemma 3 with a hybrid sliding/full attention pattern (4 sliding : 1 full, cycling every 5 layers) and adds a per-layer input-gate mechanism that concatenates an auxiliary embedding stream with the main hidden stream. The full-attention layers use partial-rotary RoPE (`partial_rotary_factor=0.25`) with `head_dim` as the frequency base, routed through TVM's `rope_freq_gptj` via the companion PR's new `freq_dim_base` parameter. Sliding-attention layers use the standard GPT-J-style frequency. Final logits are softcapped at 30.0.

New files
---------

* `python/mlc_llm/model/gemma4/__init__.py`
* `python/mlc_llm/model/gemma4/gemma4_loader.py` — HuggingFace weight-name → MLC param-name mapping. Folds the embedding scale (`sqrt(hidden_size)`) into the tied embedding weight so the downstream `dequantize+take` fusion cannot drop the post-lookup multiply. Pads the 1-element per-layer scalar to 2 elements to meet WebGPU's 4-byte storage-buffer alignment.
* `python/mlc_llm/model/gemma4/gemma4_model.py` — Relax-level spec for the 35-layer text decoder. Includes a `Gemma4SplitScaledEmbedding` that shards the per-layer embedding table along the embedding dimension so each shard fits under WebGPU's per-buffer budget.
* `tests/python/model/test_gemma4.py` — four unit tests (registry presence, config round-trip, double-wide-MLP guard, `export_tvm` smoke). Mirrors `test_gemma3.py`. Runs in ~12 s with no GPU.

Registration
------------

* `python/mlc_llm/model/model.py`: register `gemma4` and wrap `make_quantization_functions` in `_rewrite_quantize_names` to strip the `language_model.` prefix from the quantization map. Required because `export_tvm()` walks the top-level `Gemma4ForCausalLM` module and flattens those names, but `make_quantization_functions()` visits the decoder as a sub-module and keeps the prefix. A cleaner long-term fix is to teach `make_quantization_functions()` to understand the visitor prefix; that touches every multimodal-shaped model and is left as a follow-up.
* `python/mlc_llm/model/model_preset.py`: adds a `gemma4_e2b_it` preset with the E2B-it text-config layout used by the new tests.

WebGPU compile-time infrastructure
----------------------------------

* `python/mlc_llm/support/auto_target.py`: add `max_num_threads: 128` to the `webgpu:generic` preset. With the default 256-thread workgroup, Dlight-generated decode kernels exceed Chrome/Dawn's 32 KB workgroup-storage budget.

Runtime prerequisites (separate PR)
-----------------------------------

Depends on the TVM changes on branch `prep/gemma4-e2b-support`: paged-KV hybrid sliding/full dispatch, `rope_freq_gptj` with a new `freq_dim_base` parameter, `max_shared_memory_per_block = 32768` on the `webgpu` target kind, ArrayDecodeStorage bf16-from-f32 tolerance, chunked weight loading, and a handful of `str(target.kind) -> target.kind.name` fixes made necessary by the `max_num_threads: 128` preset here. Those go as a separate PR against mlc-ai/relax.
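For orientation, here is a minimal NumPy sketch of three of the conventions described above: the 4-sliding : 1-full layer cycle, the load-time embedding-scale fold, and the per-layer scalar padding. Helper names, the phase of the cycle, and the f16 dtype are assumptions for illustration, not the actual `gemma4_model.py` / `gemma4_loader.py` code.

```python
import numpy as np

def is_full_attention_layer(layer_idx: int, period: int = 5) -> bool:
    """4 sliding : 1 full, cycling every 5 layers. Which index within each
    cycle uses full attention is an assumption here (last of each 5)."""
    return (layer_idx + 1) % period == 0

def fold_embedding_scale(embed_weight: np.ndarray, hidden_size: int) -> np.ndarray:
    """Pre-multiply the tied embedding by sqrt(hidden_size) at load time so the
    downstream dequantize+take fusion cannot drop the post-lookup multiply."""
    return (embed_weight * np.sqrt(hidden_size)).astype(embed_weight.dtype)

def pad_per_layer_scalar(scalar: np.ndarray) -> np.ndarray:
    """Pad the 1-element per-layer scalar to 2 elements so an f16 buffer meets
    WebGPU's 4-byte storage-buffer alignment (f16 dtype is an assumption)."""
    return np.pad(scalar, (0, 1))
```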
Code Review
This pull request introduces support for the Gemma 4 model architecture, including the model implementation, weight loader, and configuration presets. The review feedback identifies several critical issues that would prevent the model from functioning correctly: a typo in the weight name transformation logic that would cause loading failures, the omission of the standard attention scaling factor (1/sqrt(head_dim)) which is necessary for numerical stability, and a missing scale normalization for per-layer inputs that were pre-scaled during weight conversion.
```python
if name.startswith("language_model."):
    name = name[len("language_model.") :]
if name.startswith("model."):
    return f"model.language_{name}"
```
The `name_transform` function appears to have a typo in the prefix construction. It currently produces names like `model.language_model.layers...` (since `name` already starts with `model.`), whereas standard HuggingFace multimodal checkpoints typically use the `language_model.model.layers...` prefix. This will likely cause weight-loading failures.
| return f"model.language_{name}" | |
| return f"language_model.{name}" |
I checked this against Gemma 4-family HF checkpoint keys. The text decoder weights are stored under `model.language_model.*`, not `language_model.model.*`.
The current transform intentionally strips the outer `language_model.` wrapper from the MLC-side multimodal namespace and maps nested `model.*` names back to HF `model.language_model.*`. Changing this to `language_model.{name}` would point at non-existent keys and break loading. A sketch of the intended mapping is below.
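For reviewers following along, a small self-contained sketch of that mapping (the example key layout is assumed, not copied from the loader):

```python
def name_transform(name: str) -> str:
    # Strip the MLC-side multimodal wrapper, then re-nest the name under
    # HF's "model.language_model.*" layout.
    if name.startswith("language_model."):
        name = name[len("language_model.") :]
    if name.startswith("model."):
        return f"model.language_{name}"
    return name

# "language_model.model.layers.0.input_layernorm.weight"
# -> "model.language_model.layers.0.input_layernorm.weight"
assert name_transform("language_model.model.layers.0.input_layernorm.weight") == (
    "model.language_model.layers.0.input_layernorm.weight"
)
```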
```python
self.shared_cache_source = self.text_config.layer_cache_source(layer_idx)
self.rope_theta = self.text_config.layer_rope_theta(layer_idx)
self.rope_scaling = self.text_config.layer_rope_scaling(layer_idx)
self.scaling = 1.0
```
The attention scaling factor `sm_scale` is hardcoded to 1.0. For models with large head dimensions (256 or 512 in Gemma 4), omitting the `1/sqrt(head_dim)` scaling factor will lead to extremely large attention scores, causing softmax saturation and incorrect model behavior. It should be set to `self.local_head_dim**-0.5`.
Suggested change:

```diff
- self.scaling = 1.0
+ self.scaling = self.local_head_dim**-0.5
```
This matches the current HF Gemma 4 reference implementation. In transformers' `Gemma4TextAttention`, `self.scaling` is also set explicitly to 1.0 and passed into the attention interface.
So this is not missing the standard `1/sqrt(head_dim)` scaling; it is preserving the model-specific Gemma 4 behavior rather than falling back to the generic default. The sketch below shows where the factor enters the computation.
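A simplified NumPy illustration (not the paged-KV kernel path) of where `sm_scale` enters the score computation, to make the trade-off concrete:

```python
import numpy as np

def attention_weights(q: np.ndarray, k: np.ndarray, sm_scale: float) -> np.ndarray:
    """softmax(q @ k.T * sm_scale). The generic default would be
    sm_scale = head_dim ** -0.5; the Gemma 4 text attention discussed above
    deliberately passes sm_scale = 1.0 instead."""
    logits = q @ k.T * sm_scale
    logits -= logits.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(logits)
    return weights / weights.sum(axis=-1, keepdims=True)
```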
```python
projection = self.per_layer_projection_norm(projection)
if per_layer_inputs is None:
    return projection
return (projection + per_layer_inputs) * self.per_layer_input_scale
```
In `gemma4_loader.py`, the `embed_tokens_per_layer` weights are pre-scaled by `sqrt(hidden_size_per_layer_input)` to support scale folding in the quantized embedding fusion. However, this scale is not divided out here before adding `per_layer_inputs` to the unscaled projection, so the per-layer inputs will be incorrectly weighted by a factor of `sqrt(hidden_size_per_layer_input)` (e.g., 16x for E2B).
Suggested change:

```diff
- return (projection + per_layer_inputs) * self.per_layer_input_scale
+ return (projection + per_layer_inputs / self.text_config.hidden_size_per_layer_input**0.5) * self.per_layer_input_scale
```
This path also matches the current HF Gemma 4 reference. HF scales `embed_tokens_per_layer` by `sqrt(hidden_size_per_layer_input)` and combines it as `(per_layer_projection + per_layer_inputs) * self.per_layer_input_scale`.
On the MLC side, the loader-side pre-scale compensates for the quantized embedding fusion path dropping the post-lookup multiply. Dividing by `sqrt(hidden_size_per_layer_input)` again here would double-correct and diverge from the reference behavior. A toy numerical check follows.
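To make the argument concrete, here is a toy NumPy check (toy shapes; the 256 value is only implied by the 16x figure above) showing that the loader-side fold already reproduces the reference values without an extra division:

```python
import numpy as np

hidden_size_per_layer_input = 256          # assumed; its square root gives the 16x above
scale = hidden_size_per_layer_input ** 0.5

rng = np.random.default_rng(0)
table = rng.standard_normal((32, hidden_size_per_layer_input)).astype(np.float32)
token_ids = np.array([3, 7, 7, 1])

# HF reference path: scale the per-layer embeddings, then combine.
hf_per_layer_inputs = table[token_ids] * scale

# MLC path: the loader folds the same scale into the stored table,
# so the post-lookup rows already match the reference values.
mlc_table = table * scale
mlc_per_layer_inputs = mlc_table[token_ids]

assert np.allclose(hf_per_layer_inputs, mlc_per_layer_inputs)
# Dividing by sqrt(hidden_size_per_layer_input) at combine time would undo the fold.
```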
[Model] Add Gemma 4 E2B text-only support
Target repo: `mlc-ai/mlc-llm`
Target branch: `main`

Summary
This PR adds text-only support for `google/gemma-4-E2B-it` to `mlc-llm`. It depends on the companion TVM / Relax prerequisite PR (branch `prep/gemma4-e2b-support` against `mlc-ai/relax`).
Please review this as Gemma 4 E2B text-only support, not as full Gemma 4 family support.
Scope
Included here:

* the `gemma4` text decoder, loader, preset, tests, and registration described below
* the WebGPU compile-time adjustment on the `mlc-llm` side (`max_num_threads: 128`)

Explicitly out of scope:

* multimodal encoders and a `Gemma4ForConditionalGeneration` wrapper
* `web-llm` built-in integration / `prebuiltAppConfig`

What this adds
* `gemma4` model family and `Gemma4ForCausalLM` implementation for the E2B text path
* `gemma4_loader.py` weight-name mapping
* `gemma4_e2b_it` preset
* `gemma4` registered in `MODELS`
* `tests/python/model/test_gemma4.py`

Gemma 4-specific behavior covered here:
* hybrid sliding/full attention pattern (4 sliding : 1 full)
* per-layer input-gate mechanism on an auxiliary embedding stream
* partial-rotary RoPE on the full-attention layers
* final logit softcap at 30.0

Embedding / packaging adjustments:
* `Gemma4SplitScaledEmbedding` shards the per-layer embedding table along the embedding dimension so shards stay within the WebGPU per-buffer budget (a rough sketch of the idea follows this list)
* the embedding scale is folded into the weights so the `dequantize + take` fusion does not drop the post-lookup multiply
* `_rewrite_quantize_names` keeps the `language_model.*` namespace aligned between the loader and `export_tvm()`
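A rough NumPy stand-in for the sharding idea (not the Relax-level `Gemma4SplitScaledEmbedding` implementation; shard count and budget handling are simplified):

```python
import numpy as np

def split_embedding_lookup(table: np.ndarray, token_ids: np.ndarray, num_shards: int) -> np.ndarray:
    """Split the table along the embedding dimension so each shard stays under a
    per-buffer size budget, then concatenate the per-shard lookups."""
    shards = np.array_split(table, num_shards, axis=1)
    return np.concatenate([shard[token_ids] for shard in shards], axis=-1)

# Equivalent to a single-table lookup:
# np.array_equal(split_embedding_lookup(t, ids, 4), t[ids])
```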
Why this PR is text-only

This PR intentionally targets only the language-model path for `google/gemma-4-E2B-it`. I am deliberately not introducing:

* `Gemma4ForConditionalGeneration` or the multimodal encoders

The text-only path is the smallest reviewable unit that is already working, validated, and useful on the current MLC/WebLLM stack.
Validation
Result: `4 passed` in ~12 s.

Also validated from a clean-room setup:

* `mlc-ai/relax` and `mlc-ai/mlc-llm`

Canonical checks:

* `Hi` -> valid short completion
* `What is the capital of France? Answer in one word.` -> `Paris` on first attempt
* `finish_reason = "stop"`

No retry wrapper, unload/reload workaround, or debug-only path was required in the clean-room validation.
This PR claims clean-room functional equivalence, not byte-identical parity across host environments.
Reviewer notes
Why `_rewrite_quantize_names` exists

Gemma 4 stores the decoder under `language_model.*` so multimodal encoders can be added later without renaming existing keys. `export_tvm()` and `make_quantization_functions()` currently flatten / visit that namespace differently, so this PR normalizes the names on the `mlc-llm` side. A cleaner long-term fix would be to teach `make_quantization_functions()` to normalize the visitor prefix directly, but that is left as a follow-up. A rough sketch of the renaming is below.
Why `max_num_threads: 128` is included

With the default WebGPU preset, Dlight-generated decode kernels can exceed Chrome/Dawn's workgroup-storage limits. Together with the companion TVM-side shared-memory registration, this keeps the generated decode path within current WebGPU limits for Gemma 4 E2B.
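For reference, a hypothetical sketch of the preset entry (key names assumed; the actual change lives in `auto_target.py`'s `webgpu:generic` preset):

```python
WEBGPU_GENERIC_PRESET = {
    "kind": "webgpu",
    # Cap workgroups at 128 threads: with the default 256, Dlight-generated
    # decode kernels can exceed Chrome/Dawn's 32 KB workgroup-storage budget.
    "max_num_threads": 128,
}
```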
Non-goals / follow-ups
This PR does not claim:
* `web-llm` built-in integration
* full Gemma 4 family / multimodal support

Follow-up work after this text-only landing can include:

* a multimodal `Gemma4ForConditionalGeneration` wrapper on top of the existing decoder
* `web-llm` built-in integration / `prebuiltAppConfig` registration
* teaching `make_quantization_functions()` to normalize the visitor prefix directly