
[Runtime] Add Gemma 4 E2B prerequisites #346

Draft
MakotoUwu wants to merge 1 commit into mlc-ai:mlc from MakotoUwu:prep/gemma4-e2b-support

Conversation

@MakotoUwu

[Runtime] Add Gemma 4 E2B prerequisites

Target repo: mlc-ai/relax
Target branch: mlc

Summary

This PR adds the runtime/compiler changes needed for google/gemma-4-E2B-it text-only support on the MLC WebGPU path.

This is a prerequisite PR. The companion mlc-llm PR depends on it.

What this changes

  • paged_kv_cache.cc
    • hoist ReserveAppendLengthInSeq so page metadata reflects the current prefill when a request spans multiple blocks in one call
    • route AttnKind::kMHASliding through the MHA path in SelfAttention() / CrossAttention()
  • target_kind.cc
    • register max_shared_memory_per_block = 32768 on the webgpu target kind so Dlight does not schedule decode kernels against an unrealistically large shared-memory budget
  • position_embedding.py
    • extend rope_freq_gptj with freq_dim_base for Gemma 4 full-attention layers (partial_rotary_factor = 0.25)
    • mark nested rope primfuncs private=True to avoid duplicate global-symbol failures during NormalizeGlobalVar
    • align the apply_rope primfunc argument type with current in-tree callers (int64)
  • kv_cache.py, tree_attn.py, tirx/build.py
  • switch str(target.kind) comparisons to target.kind.name, since the WebGPU target now carries additional attributes
  • web/src/runtime.ts
    • add chunked loading for large weight records
    • add kTVMFFIShape unpacking in tensorCreateView for the chunked path
  • web/emcc/wasm_runtime.cc
    • keep ArrayDecodeStorage tolerant of f32-to-bf16-tagged payloads whose bytes are raw float32
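The kMHASliding routing change can be illustrated with a simplified Python model of the dispatch (an enum-based sketch, not the actual C++ logic in paged_kv_cache.cc):

```python
from enum import Enum

class AttnKind(Enum):
    MHA = "mha"
    MHA_SLIDING = "mha_sliding"
    MLA = "mla"

def dispatch_attention(kind: AttnKind) -> str:
    # Sliding-window layers reuse the regular MHA kernels; only true MLA
    # layers take the MLA path. Before the fix, MHA_SLIDING fell through
    # to the MLA arm and produced zero-initialised attention output.
    if kind in (AttnKind.MHA, AttnKind.MHA_SLIDING):
        return "mha_path"
    if kind is AttnKind.MLA:
        return "mla_path"
    raise ValueError(f"unsupported attention kind: {kind}")
```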

Why this is needed for Gemma 4

Gemma 4 E2B text-only exercises:

  • hybrid sliding/full attention
  • partial-rotary full-attention layers
  • tighter WebGPU shared-memory constraints on generated decode kernels
  • large weight records on the web runtime

Without these changes, the intended WebGPU path fails to compile and run correctly.

Validation

Validated together with the companion mlc-llm PR from a clean-room setup:

  • fresh upstream clones
  • only the staged relax + mlc-llm commits applied
  • fresh rebuild
  • no local lab-tree dependency

Canonical browser checks pass on the rebuilt WASM:

  • Hi -> valid short completion
  • What is the capital of France? Answer in one word. -> Paris on first attempt
  • short haiku generation -> coherent output with finish_reason = "stop"

This PR claims clean-room functional equivalence, not byte-identical parity across host environments.

Scope / non-goals

This PR does not add:

  • multimodal Gemma 4 support
  • web-llm built-in model-list integration
  • the broader Gemma 4 family

This PR only covers the lower-layer prerequisites required for Gemma 4 E2B text-only support.

Follow-up work intentionally left out of scope:

  • broader kMHASliding symmetry in the remaining sibling sites of paged_kv_cache.cc
  • later cleanup of generic web runtime behavior beyond what Gemma 4 E2B needs now

Contributor

@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces several enhancements to the TVM Relax frontend and Web runtime, including support for partial rotary embeddings in RoPE, improved target kind handling to prevent module shadowing, and chunked tensor loading for large weights. It also addresses a bug in the KV cache regarding intra-prefill shared-KV cross-attention, refines WebGPU target defaults, and improves the robustness of the f32-to-bf16 weight loader. Feedback suggests using loadI64 for shape sizes in the Wasm runtime to avoid truncation, ensuring attribute compatibility when grouping device functions by target kind, and adding explicit size validation in the f32-to-bf16 decoding logic.

Comment thread web/src/runtime.ts
```ts
if (callbackArg) {
  const shapeCellPtr = shapeObjPtr + SizeOf.ObjectHeader;
  const shapeDataPtr = this.memory.loadPointer(shapeCellPtr);
  const shapeLen = this.memory.loadUSize(shapeCellPtr + this.memory.sizeofPtr());
```
Contributor


Severity: medium

ShapeTupleNode::size is defined as a uint64_t (8 bytes) in the TVM C++ runtime. Using loadUSize (which typically loads 4 bytes on Wasm32) might truncate the length if it ever exceeds $2^{32}-1$. Although shape lengths are unlikely to be that large, it is safer to use loadI64 to match the underlying type and avoid potential issues with extremely large tensors.

Author


Swapping this call to loadI64() would not improve correctness in the current JS runtime. On wasm32, loadI64() already truncates to the low 32 bits, exactly as noted in web/src/memory.ts.

A real fix here would require end-to-end 64-bit / BigInt handling in the JS FFI rather than a one-line swap. Given the current wasm32 runtime model and practical ShapeTuple sizes, I left this as loadUSize().
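The truncation behaviour described above can be modelled in Python (a hypothetical helper; the real reads happen in web/src/memory.ts):

```python
def load_usize_wasm32(raw_u64: int) -> int:
    # On wasm32, a 4-byte read of an 8-byte little-endian field keeps
    # only the low 32 bits, so loadI64 and loadUSize end up returning
    # the same truncated value for oversized lengths.
    return raw_u64 & 0xFFFFFFFF
```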

Comment thread python/tvm/tirx/build.py
```python
kind = target.kind.name if target is not None else ""
# Keep the first target encountered for each kind as the canonical one
if kind not in kind2target:
    kind2target[kind] = target
```
Contributor


Severity: medium

Grouping device functions by target.kind.name and using the first encountered Target as the canonical one might lead to issues if different functions of the same kind have incompatible attributes (e.g., different max_num_threads or arch). While this avoids module shadowing as described in the comments, a more robust approach would be to verify that the attributes of subsequent targets are compatible with the canonical one or to merge them.

Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The grouping by target.kind.name is deliberate here to avoid module splitting from benign target-attribute drift introduced by different TIR passes.

Adding a stricter per-target compatibility check at this point would risk recreating the original shadowing problem for cases like WebGPU targets that differ only by attached attrs such as max_num_threads. Richer target merging / validation would make sense as a follow-up, but for this bug fix the goal was to keep one device module per actual backend kind.
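A minimal Python model of the grouping policy (mock dict targets for illustration; the in-tree code operates on tvm.target.Target objects):

```python
def group_by_kind(targets):
    # Keep the first target seen for each kind as canonical, so targets
    # that differ only in attached attrs (e.g. max_num_threads) do not
    # split into separate device modules that shadow one another.
    kind2target = {}
    for target in targets:
        kind = target["kind"] if target is not None else ""
        if kind not in kind2target:
            kind2target[kind] = target
    return kind2target
```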

Comment thread web/emcc/wasm_runtime.cc
```cpp
// to the generic byte copy. This makes the loader tolerant of weight
// shards produced by older / alternate quantisation pipelines that retain
// the "f32-to-bf16" tag without performing the bf16 truncation.
if (size == byte_size / 2) {
```
Contributor


medium

The logic for handling the f32-to-bf16 format assumes that if the byte size doesn't match the bf16 expectation (2 bytes per element), it must be raw float32. It would be safer to explicitly check if byte_size == size * 4 before falling through to CopyFromBytes, or at least add a diagnostic log if the size matches neither expectation, as an unexpected byte size could lead to incorrect data being copied into the tensor.

Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fallback here is not silently accepting arbitrary payload sizes. If the data is not in packed bf16 form, it falls through to cpu_arr.CopyFromBytes(...), and TensorCopyFromBytes already checks that the incoming byte count matches the tensor's exact storage size.

So malformed sizes still fail hard with a size-mismatch error rather than being accepted incorrectly. I can add an explicit size * 4 check or a clearer diagnostic if reviewers prefer, but the current path is already safe.
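The decode-or-fall-through logic can be sketched in Python (a simplified scalar model with an explicit size check, as the reviewer suggested; the real path is in web/emcc/wasm_runtime.cc):

```python
import struct

def decode_f32_to_bf16(payload: bytes, num_elements: int) -> list:
    # Packed bf16: each element is the high 16 bits of a float32.
    if len(payload) == num_elements * 2:
        out = []
        for i in range(num_elements):
            (hi,) = struct.unpack_from("<H", payload, i * 2)
            out.append(struct.unpack("<f", struct.pack("<I", hi << 16))[0])
        return out
    # Raw float32 mis-tagged as "f32-to-bf16": hard size check, then copy,
    # mirroring the size validation done by TensorCopyFromBytes.
    if len(payload) != num_elements * 4:
        raise ValueError("payload size matches neither bf16 nor f32 layout")
    return list(struct.unpack(f"<{num_elements}f", payload))
```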

Unblocks Gemma 4 E2B text-only support in mlc-llm. Ships the minimal
set of runtime / frontend changes the model exercises as soon as
prefill runs on WebGPU.

PagedKVCache hybrid dispatch (src/runtime/vm/paged_kv_cache.cc)
---------------------------------------------------------------
* Hoist `ReserveAppendLengthInSeq` above the aux-data loop so page
  metadata reflects the current prefill when a request spans multiple
  blocks in a single call. Previously the aux-data loop read block
  page counts before the blocks were reserved, producing empty
  `page_indptr` entries for the first call of a newly-created
  sequence whose length exceeded `page_size`.
* Route `AttnKind::kMHASliding` through the MHA dispatch arm in
  `SelfAttention()` and `CrossAttention()`. Without this, sliding
  layers in Gemma 4 fell through to the MLA path and returned
  zero-initialised output for their attention sub-graph.
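The ordering fix can be illustrated with a small Python model of the metadata build (heavily simplified; the real logic lives in paged_kv_cache.cc):

```python
import math

def build_page_indptr(append_lengths, page_size):
    # Reserve pages for the current prefill *before* reading page counts,
    # mirroring the hoisted ReserveAppendLengthInSeq. Reading counts first
    # would emit empty page_indptr entries for a fresh sequence whose
    # first prefill already spans more than one page.
    page_counts = [math.ceil(n / page_size) for n in append_lengths]
    indptr = [0]
    for count in page_counts:
        indptr.append(indptr[-1] + count)
    return indptr
```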

WebGPU target kind (src/target/target_kind.cc)
----------------------------------------------
* Register `max_shared_memory_per_block = 32768` on the `webgpu`
  target kind. Without this attribute, Dlight's shared-memory
  analysis falls back to the generic 48 KB default and generates
  decode kernels that exceed Chrome/Dawn's 32 KB workgroup-storage
  budget. Chrome currently exposes 32768 and the WebGPU spec
  mandates at least 16384, so 32768 is a safe default.
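A sketch of the budget selection (illustrative numbers taken from this PR; not Dlight's actual shared-memory analysis):

```python
GENERIC_DEFAULT = 49152   # 48 KB fallback when the target has no attribute
WEBGPU_BUDGET = 32768     # value registered on the webgpu target kind

def shared_memory_budget(target_attrs: dict) -> int:
    # Without the registered attribute, scheduling proceeds against the
    # generic 48 KB default and can exceed Chrome/Dawn's 32 KB limit.
    return target_attrs.get("max_shared_memory_per_block", GENERIC_DEFAULT)

def kernel_fits(shared_bytes_used: int, target_attrs: dict) -> bool:
    return shared_bytes_used <= shared_memory_budget(target_attrs)
```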

Relax nn.llm (python/tvm/relax/frontend/nn/llm/)
------------------------------------------------
* position_embedding.py:
  - Add a `freq_dim_base` parameter to `rope_freq_gptj` so callers
    can decouple the frequency-base dimension (HF's `head_dim`)
    from the rotated range (HF's `rotary_dim`). Required by Gemma 4's
    full-attention layers, which use `partial_rotary_factor=0.25`.
  - Mark `fused_rope` and `fused_rope_longrope_scaling` as
    `@T.prim_func(private=True)`. Gemma 4 builds two instances of
    `llama_rope_with_position_map` per model (one for sliding
    layers, one for partial-rotary full-attention layers), both of
    which would otherwise register a prim_func named `fused_rope`
    at module scope and trip `NormalizeGlobalVar`'s duplicate-
    symbol check. These prim_funcs are nested inside the factory
    and have no external callers, so scoping their global-symbol
    registration is safe.
  - Promote the `apply_rope` prim_func parameter from `T.int32` to
    `T.int64`. Both in-tree callers that pass `apply_rope`
    (mlc_llm/model/gemma4 and the existing mlc_llm/model/llama4)
    construct the immediate with `tirx.IntImm("int64", 1)`;
    aligning the parameter type avoids an implicit cast in the
    generated kernel and matches the caller convention.
* kv_cache.py, tree_attn.py: replace `str(target.kind)` with
  `target.kind.name` in dispatch-kind comparisons. Required because
  the mlc-llm side now adds `max_num_threads: 128` to the `webgpu`
  target preset, which makes `str(target.kind)` no longer equal the
  bare "webgpu" string.
* tirx/build.py: group device functions in `split_host_device_mods`
  by `kind.name` instead of the stringified target so multi-target
  attribute-pass output is not silently shadowed.
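The freq_dim_base decoupling can be sketched as a scalar Python function (a hypothetical signature for illustration; the in-tree rope_freq_gptj builds TIR expressions):

```python
import math

def rope_freq_gptj(pos, idx, rotary_dim, theta, freq_dim_base=None):
    # GPT-J layout: adjacent pairs (2i, 2i+1) share one frequency. The
    # exponent denominator defaults to rotary_dim; Gemma 4's partial-rotary
    # full-attention layers would pass freq_dim_base=head_dim while rotating
    # only rotary_dim = head_dim * partial_rotary_factor elements.
    base_dim = freq_dim_base if freq_dim_base is not None else rotary_dim
    freq = pos / (theta ** (2 * (idx // 2) / base_dim))
    return math.cos(freq), math.sin(freq)
```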

WebAssembly runtime (web/)
--------------------------
* emcc/wasm_runtime.cc: reorder FFI includes so `tvm_ffi::*` static
  initialisers run before `runtime::*` initialisers. Also extend
  `ArrayDecodeStorage` with a fall-through for payloads tagged
  `f32-to-bf16` whose byte length matches native float32; this
  keeps native-f32 shards (e.g. Gemma 4's tied-embedding weight,
  which `mlc_llm convert_weight` still labels `f32-to-bf16`)
  decodable.
* src/runtime.ts: chunked weight loading for records whose `nbytes`
  exceed a 128 MB per-call budget. Gemma 4's per-layer embedding
  shards are ~390 MB after the model-side
  `Gemma4SplitScaledEmbedding` split, which is still larger than
  the WebGPU single-write limit. Also add `kTVMFFIShape` unpacking
  in the FFI result path, required by the new chunking code that
  calls `tensorCreateView` with explicit shape tuples.
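The chunked-copy path can be sketched in Python (the real code issues WebGPU buffer writes from runtime.ts; the 128 MB budget is the figure from this change):

```python
CHUNK_BUDGET = 128 * 1024 * 1024  # per-call upload budget in bytes

def iter_chunks(nbytes: int, budget: int = CHUNK_BUDGET):
    # Yield (offset, length) slices covering [0, nbytes), each at most
    # `budget` bytes, so a ~390 MB record becomes several uploads plus
    # one tensorCreateView over the assembled buffer.
    offset = 0
    while offset < nbytes:
        length = min(budget, nbytes - offset)
        yield offset, length
        offset += length
```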
@MakotoUwu force-pushed the prep/gemma4-e2b-support branch from 6f780f8 to ac9cf7a on April 21, 2026 06:11