
[Model] Add Gemma 4 E2B text-only support #3485

Draft
MakotoUwu wants to merge 1 commit into mlc-ai:main from MakotoUwu:prep/gemma4-e2b-text-only

Conversation

MakotoUwu commented Apr 19, 2026

[Model] Add Gemma 4 E2B text-only support

Target repo: mlc-ai/mlc-llm
Target branch: main

Summary

This PR adds text-only support for google/gemma-4-E2B-it to mlc-llm.

This PR depends on the companion TVM / Relax prerequisite PR on branch prep/gemma4-e2b-support (see "Runtime prerequisites" below).

Please review this as Gemma 4 E2B text-only support, not as full Gemma 4 family support.

Scope

Included here:

  • Gemma 4 model registration
  • Gemma 4 loader/model implementation for the E2B text path
  • E2B preset
  • Gemma 4 unit tests
  • the minimal WebGPU-target adjustment needed on the mlc-llm side (max_num_threads: 128)

Explicitly out of scope:

  • multimodal Gemma 4
  • additional Gemma 4 variants such as E4B
  • web-llm built-in integration / prebuiltAppConfig
  • unrelated local fork cleanup

What this adds

  • adds a new gemma4 model family and Gemma4ForCausalLM implementation for the E2B text path
  • adds HF -> MLC parameter mapping in gemma4_loader.py
  • adds the gemma4_e2b_it preset
  • registers gemma4 in MODELS
  • adds tests/python/model/test_gemma4.py

Gemma 4-specific behavior covered here (a toy sketch follows the list):

  • 35-layer hybrid sliding/full-attention layout
  • partial-rotary full-attention layers
  • per-layer input-gate mechanism
  • final-logit softcapping at 30.0
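
A minimal toy sketch of the layout and softcapping items above, assuming the full-attention layer closes each five-layer cycle (the helper names are illustrative, not this PR's API):

import math

def is_full_attention_layer(layer_idx: int) -> bool:
    # 4 sliding : 1 full, cycling every 5 layers; assumes the full layer
    # is the last of each cycle (layers 4, 9, 14, ... of the 35).
    return (layer_idx + 1) % 5 == 0

def softcap(logit: float, cap: float = 30.0) -> float:
    # Final-logit softcapping: squashes logits smoothly into (-cap, cap).
    return cap * math.tanh(logit / cap)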

Embedding / packaging adjustments (sketched after the list):

  • Gemma4SplitScaledEmbedding shards the per-layer embedding table along the embedding dimension so shards stay within the WebGPU per-buffer budget
  • the loader folds the embedding scale into the tied embedding weight so the downstream dequantize + take fusion does not drop the post-lookup multiply
  • the 1-element per-layer scalar is padded to satisfy WebGPU storage-buffer alignment
  • _rewrite_quantize_names keeps the language_model.* namespace aligned between the loader and export_tvm()
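
A rough numpy sketch of the first three adjustments under assumed shapes and dtypes (not this PR's code; the helper names and buffer budget are illustrative):

import numpy as np

def fold_embedding_scale(weight: np.ndarray, hidden_size: int) -> np.ndarray:
    # Fold the sqrt(hidden_size) post-lookup multiply into the tied weight
    # so a fused dequantize + take lookup already returns scaled embeddings.
    return (weight * np.sqrt(hidden_size)).astype(weight.dtype)

def shard_embedding(weight: np.ndarray, max_bytes: int) -> list:
    # Split along the embedding dimension so each shard stays within the
    # WebGPU per-buffer budget.
    bytes_per_col = weight.shape[0] * weight.dtype.itemsize
    cols = max(1, max_bytes // bytes_per_col)
    return [weight[:, i:i + cols] for i in range(0, weight.shape[1], cols)]

def pad_per_layer_scalar(scalar: np.ndarray) -> np.ndarray:
    # Pad the 1-element scalar (2 bytes in a 16-bit dtype) to 2 elements
    # so the buffer satisfies WebGPU's 4-byte storage alignment.
    return np.pad(scalar, (0, 1))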

Why this PR is text-only

This PR intentionally targets only the language-model path for google/gemma-4-E2B-it.

I am deliberately not introducing:

  • vision/audio encoders
  • multimodal wrapper logic
  • Gemma4ForConditionalGeneration

The text-only path is the smallest reviewable unit that is already working, validated, and useful on the current MLC/WebLLM stack.

Validation

python -m pytest tests/python/model/test_gemma4.py -v

Result: 4 passed in ~12s

Also validated from a clean-room setup:

  • fresh upstream clones of mlc-ai/relax and mlc-ai/mlc-llm
  • only the two staged commits applied
  • fresh rebuild
  • canonical browser checks run on the rebuilt WebGPU artifact

Canonical checks:

  • "Hi" -> valid short completion
  • "What is the capital of France? Answer in one word." -> "Paris" on first attempt
  • short haiku generation -> coherent output with finish_reason = "stop"

No retry wrapper, unload/reload workaround, or debug-only path was required in the clean-room validation.

This PR claims clean-room functional equivalence, not byte-identical parity across host environments.

Reviewer notes

Why _rewrite_quantize_names exists

Gemma 4 stores the decoder under language_model.* so multimodal encoders can be added later without renaming existing keys.

export_tvm() and make_quantization_functions() currently flatten / visit that namespace differently, so this PR normalizes the names on the mlc-llm side. A cleaner long-term fix would be to teach make_quantization_functions() to normalize the visitor prefix directly, but that is left as a follow-up.
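
A hedged sketch of what that normalization amounts to, assuming the quantization map is keyed by parameter name (the real wrapper in this PR may differ in shape):

def _rewrite_quantize_names(quantize_map: dict) -> dict:
    # export_tvm() flattens the language_model.* namespace, while the
    # quantization visitor keeps the prefix; strip it so the two agree.
    prefix = "language_model."
    return {
        (name[len(prefix):] if name.startswith(prefix) else name): spec
        for name, spec in quantize_map.items()
    }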

Why max_num_threads: 128 is included

With the default WebGPU preset, Dlight-generated decode kernels can exceed Chrome/Dawn's workgroup-storage limits. Together with the companion TVM-side shared-memory registration, this keeps the generated decode path within current WebGPU limits for Gemma 4 E2B.
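
Expressed directly through TVM's Target API, the adjusted preset amounts to roughly the following (the actual change edits the webgpu:generic preset in auto_target.py):

import tvm

# 128-thread workgroups keep Dlight-generated decode kernels within
# Chrome/Dawn's 32 KB workgroup-storage budget.
target = tvm.target.Target({"kind": "webgpu", "max_num_threads": 128})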

Non-goals / follow-ups

This PR does not claim:

  • full Gemma 4 support
  • multimodal support
  • E4B support
  • web-llm built-in integration

Follow-up work after this text-only landing can include:

  • additional Gemma 4 variants
  • multimodal wrapper support
  • later cleanup around quantization-name handling once the broader multimodal-shaped path is addressed

Adds the text decoder for Gemma 4 E2B-it. Multimodal encoders are not
wired up in this PR; the class is deliberately named
`Gemma4ForCausalLM` (not `Gemma4ForConditionalGeneration`) so a
future multimodal wrapper can be added without renaming.

Architecture follows Gemma 3 with a hybrid sliding/full attention
pattern (4 sliding : 1 full, cycling every 5 layers) and adds a
per-layer input-gate mechanism that concatenates an auxiliary
embedding stream with the main hidden stream. The full-attention
layers use partial-rotary RoPE (`partial_rotary_factor=0.25`) with
`head_dim` as the frequency base, which is routed through TVM's
`rope_freq_gptj` via the companion PR's new `freq_dim_base`
parameter. Sliding-attention layers use the standard GPT-J-style
frequency. Final logits are softcapped at 30.0.
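
A toy numpy sketch of that frequency layout (illustrative only: the real computation goes through rope_freq_gptj, and the theta values below are assumed example numbers):

import numpy as np

def gptj_rotary_freqs(head_dim: int, theta: float,
                      partial_rotary_factor: float = 1.0) -> np.ndarray:
    # Only the first partial_rotary_factor fraction of head_dim rotates;
    # GPT-J style pairs (2i, 2i+1) share one frequency. The exponent is
    # normalized by head_dim (the new freq_dim_base), not by rot_dim.
    rot_dim = int(head_dim * partial_rotary_factor)
    return theta ** (-np.arange(0, rot_dim, 2) / head_dim)

# Full-attention layers rotate only a quarter of head_dim; sliding layers
# use the standard full-width GPT-J frequencies.
full_freqs = gptj_rotary_freqs(256, 1_000_000.0, partial_rotary_factor=0.25)
sliding_freqs = gptj_rotary_freqs(256, 10_000.0)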

New files
---------
* `python/mlc_llm/model/gemma4/__init__.py`
* `python/mlc_llm/model/gemma4/gemma4_loader.py` — HuggingFace
  weight-name → MLC param-name mapping. Folds the embedding-scale
  (`sqrt(hidden_size)`) into the tied embedding weight so the
  downstream `dequantize+take` fusion cannot drop the post-lookup
  multiply. Pads the 1-element per-layer scalar to 2 elements to
  meet WebGPU's 4-byte storage-buffer alignment.
* `python/mlc_llm/model/gemma4/gemma4_model.py` — Relax-level spec
  for the 35-layer text decoder. Includes a
  `Gemma4SplitScaledEmbedding` that shards the per-layer embedding
  table along the embedding dimension so each shard fits under
  WebGPU's per-buffer budget.
* `tests/python/model/test_gemma4.py` — four unit tests (registry
  presence, config round-trip, double-wide-MLP guard, `export_tvm`
  smoke). Mirrors `test_gemma3.py`. Runs in ~12 s with no GPU.

Registration
------------
* `python/mlc_llm/model/model.py`: register `gemma4` and wrap
  `make_quantization_functions` in `_rewrite_quantize_names` to
  strip the `language_model.` prefix from the quantization map.
  Required because `export_tvm()` walks the top-level
  `Gemma4ForCausalLM` module and flattens those names, but
  `make_quantization_functions()` visits the decoder as a
  sub-module and keeps the prefix. A cleaner long-term fix is to
  teach `make_quantization_functions()` to understand the visitor
  prefix; that touches every multimodal-shaped model and is left
  as a follow-up.
* `python/mlc_llm/model/model_preset.py`: adds `gemma4_e2b_it`
  preset with the E2B-it text-config layout used by the new tests.

WebGPU compile-time infrastructure
----------------------------------
* `python/mlc_llm/support/auto_target.py`: add
  `max_num_threads: 128` to the `webgpu:generic` preset. With the
  default 256-thread workgroup, Dlight-generated decode kernels
  exceed Chrome/Dawn's 32 KB workgroup-storage budget.

Runtime prerequisites (separate PR)
-----------------------------------
Depends on the TVM changes on branch `prep/gemma4-e2b-support`:
paged-KV hybrid sliding/full dispatch, `rope_freq_gptj` with a new
`freq_dim_base` parameter, `max_shared_memory_per_block = 32768` on
the `webgpu` target kind, ArrayDecodeStorage bf16-from-f32
tolerance, chunked weight loading, and a handful of
`str(target.kind) -> target.kind.name` fixes made necessary by the
`max_num_threads: 128` preset here. Those go as a separate PR
against mlc-ai/relax.
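
The shape of those target-kind fixes, sketched on a hypothetical call site:

import tvm

def is_webgpu(target: tvm.target.Target) -> bool:
    # was: return str(target.kind) == "webgpu"
    return target.kind.name == "webgpu"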
gemini-code-assist (bot) left a comment:

Code Review

This pull request introduces support for the Gemma 4 model architecture, including the model implementation, weight loader, and configuration presets. The review feedback identifies several critical issues that would prevent the model from functioning correctly: a typo in the weight name transformation logic that would cause loading failures, the omission of the standard attention scaling factor (1/sqrt(head_dim)) which is necessary for numerical stability, and a missing scale normalization for per-layer inputs that were pre-scaled during weight conversion.

if name.startswith("language_model."):
    name = name[len("language_model.") :]
if name.startswith("model."):
    return f"model.language_{name}"
gemini-code-assist (bot) commented (severity: high):

The name_transform function appears to have a typo in the prefix construction. It currently produces names like model.language_model.layers... (since name already starts with model.), whereas standard HuggingFace multimodal checkpoints typically use the language_model.model.layers... prefix. This will likely cause weight loading failures.

Suggested change:
- return f"model.language_{name}"
+ return f"language_model.{name}"

MakotoUwu (author) replied:

I checked this against Gemma 4-family HF checkpoint keys. The text decoder weights are stored under model.language_model.*, not language_model.model.*.

The current transform intentionally strips the outer language_model. wrapper from the MLC-side multimodal namespace and maps nested model.* names back to HF model.language_model.*. Changing this to language_model.{name} would point at non-existent keys and break loading.
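
A worked example with a hypothetical checkpoint key shows the quoted transform landing exactly on that HF namespace:

name = "language_model.model.layers.0.self_attn.q_proj.weight"
name = name[len("language_model."):]  # "model.layers.0.self_attn.q_proj.weight"
hf_key = f"model.language_{name}"     # "language_" + "model..." fuse together
assert hf_key == "model.language_model.layers.0.self_attn.q_proj.weight"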

self.shared_cache_source = self.text_config.layer_cache_source(layer_idx)
self.rope_theta = self.text_config.layer_rope_theta(layer_idx)
self.rope_scaling = self.text_config.layer_rope_scaling(layer_idx)
self.scaling = 1.0
gemini-code-assist (bot) commented (severity: high):

The attention scaling factor sm_scale is hardcoded to 1.0. For models with large head dimensions (256 or 512 in Gemma 4), omitting the 1/sqrt(head_dim) scaling factor will lead to extremely large attention scores, causing softmax saturation and incorrect model behavior. It should be set to self.local_head_dim**-0.5.

Suggested change:
- self.scaling = 1.0
+ self.scaling = self.local_head_dim**-0.5

MakotoUwu (author) replied:

This matches the current HF Gemma 4 reference implementation: in transformers' Gemma4TextAttention, self.scaling is likewise set explicitly to 1.0 and passed into the attention interface.

So this is not missing the standard 1/sqrt(head_dim) scaling; it is preserving the model-specific Gemma 4 behavior rather than falling back to the generic default.

projection = self.per_layer_projection_norm(projection)
if per_layer_inputs is None:
    return projection
return (projection + per_layer_inputs) * self.per_layer_input_scale
gemini-code-assist (bot) commented (severity: high):

In gemma4_loader.py, the embed_tokens_per_layer weights are pre-scaled by sqrt(hidden_size_per_layer_input) to support scale folding in quantized embedding fusion. However, this scale is not divided out here before adding per_layer_inputs to the unscaled projection. This will cause the per-layer inputs to be incorrectly weighted by a factor of sqrt(hidden_size_per_layer_input) (e.g., 16x for E2B).

Suggested change:
- return (projection + per_layer_inputs) * self.per_layer_input_scale
+ return (projection + per_layer_inputs / self.text_config.hidden_size_per_layer_input**0.5) * self.per_layer_input_scale

MakotoUwu (author) replied:

This path also matches the current HF Gemma 4 reference: HF scales embed_tokens_per_layer by sqrt(hidden_size_per_layer_input) and combines it as (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.

On the MLC side, the loader-side pre-scale is compensating for the quantized embedding fusion path dropping the post-lookup multiply. Dividing by sqrt(hidden_size_per_layer_input) again here would double-correct and diverge from the reference behavior.
