[Runtime] Add Gemma 4 E2B prerequisites #346
Conversation
Code Review
This pull request introduces several enhancements to the TVM Relax frontend and Web runtime, including support for partial rotary embeddings in RoPE, improved target kind handling to prevent module shadowing, and chunked tensor loading for large weights. It also addresses a bug in the KV cache regarding intra-prefill shared-KV cross-attention, refines WebGPU target defaults, and improves the robustness of the f32-to-bf16 weight loader. Feedback suggests using loadI64 for shape sizes in the Wasm runtime to avoid truncation, ensuring attribute compatibility when grouping device functions by target kind, and adding explicit size validation in the f32-to-bf16 decoding logic.
```ts
if (callbackArg) {
  const shapeCellPtr = shapeObjPtr + SizeOf.ObjectHeader;
  const shapeDataPtr = this.memory.loadPointer(shapeCellPtr);
  const shapeLen = this.memory.loadUSize(shapeCellPtr + this.memory.sizeofPtr());
```
ShapeTupleNode::size is defined as a uint64_t (8 bytes) in the TVM C++ runtime. Using loadUSize (which typically loads 4 bytes on Wasm32) might truncate the length if it ever exceeds the 32-bit range. Consider using loadI64 to match the underlying type and avoid potential issues with extremely large tensors.
Swapping this call to loadI64() would not improve correctness in the current JS runtime. On wasm32, loadI64() already truncates to the low 32 bits, exactly as noted in web/src/memory.ts.
A real fix here would require end-to-end 64-bit / BigInt handling in the JS FFI rather than a one-line swap. Given the current wasm32 runtime model and practical ShapeTuple sizes, I left this as loadUSize().
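To make the "a swap changes nothing" point concrete, here is a small Python model of the two readers as they behave on wasm32. The helper names mirror the TypeScript memory helpers but are illustrative, not the actual `web/src/memory.ts` code:

```python
import struct

def load_usize_wasm32(buf: bytes, off: int) -> int:
    # USize on wasm32 is 4 bytes: reads only the low word of the u64 field.
    return struct.unpack_from("<I", buf, off)[0]

def load_i64_wasm32(buf: bytes, off: int) -> int:
    # The current JS runtime's loadI64 likewise keeps only the low 32 bits.
    return struct.unpack_from("<Q", buf, off)[0] & 0xFFFFFFFF

# A realistic shape length stored as little-endian uint64: both agree.
small = struct.pack("<Q", 5)
assert load_usize_wasm32(small, 0) == load_i64_wasm32(small, 0) == 5

# A pathological >4G length: both readers truncate identically, so
# swapping loadUSize for loadI64 would not improve correctness.
huge = struct.pack("<Q", (1 << 32) + 7)
assert load_usize_wasm32(huge, 0) == load_i64_wasm32(huge, 0) == 7
```

A genuine fix would have to return a BigInt (or split hi/lo words) all the way through the FFI, which is the end-to-end change described above.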
```python
kind = target.kind.name if target is not None else ""
# Keep the first target encountered for each kind as the canonical one
if kind not in kind2target:
    kind2target[kind] = target
```
Grouping device functions by target.kind.name and using the first encountered Target as the canonical one might lead to issues if different functions of the same kind have incompatible attributes (e.g., different max_num_threads or arch). While this avoids module shadowing as described in the comments, a more robust approach would be to verify that the attributes of subsequent targets are compatible with the canonical one or to merge them.
The grouping by target.kind.name is deliberate here to avoid module splitting from benign target-attribute drift introduced by different TIR passes.
Adding a stricter per-target compatibility check at this point would risk recreating the original shadowing problem for cases like WebGPU targets that differ only by attached attrs such as max_num_threads. Richer target merging / validation would make sense as a follow-up, but for this bug fix the goal was to keep one device module per actual backend kind.
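The intended behaviour can be sketched in a few lines. This is a minimal stand-in using plain dicts for targets (not the actual `tvm.target.Target` API): functions of the same kind land in one module, and benign attribute drift does not split it.

```python
def group_by_kind(device_funcs):
    # device_funcs: iterable of (func_name, target_dict) pairs (hypothetical shape)
    kind2target, kind2funcs = {}, {}
    for func, target in device_funcs:
        kind = target["kind"] if target is not None else ""
        # Keep the first target encountered for each kind as the canonical one
        kind2target.setdefault(kind, target)
        kind2funcs.setdefault(kind, []).append(func)
    return kind2target, kind2funcs

funcs = [
    ("decode", {"kind": "webgpu"}),
    ("prefill", {"kind": "webgpu", "max_num_threads": 128}),  # attr drift from a TIR pass
]
kind2target, kind2funcs = group_by_kind(funcs)
assert list(kind2funcs) == ["webgpu"]               # one device module, no shadowing
assert kind2funcs["webgpu"] == ["decode", "prefill"]
```

Keying on the stringified target instead would have produced two "webgpu" modules here, with the second silently shadowing the first.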
```cpp
// to the generic byte copy. This makes the loader tolerant of weight
// shards produced by older / alternate quantisation pipelines that retain
// the "f32-to-bf16" tag without performing the bf16 truncation.
if (size == byte_size / 2) {
```
The logic for handling the f32-to-bf16 format assumes that if the byte size doesn't match the bf16 expectation (2 bytes per element), it must be raw float32. It would be safer to explicitly check if byte_size == size * 4 before falling through to CopyFromBytes, or at least add a diagnostic log if the size matches neither expectation, as an unexpected byte size could lead to incorrect data being copied into the tensor.
The fallback here is not silently accepting arbitrary payload sizes. If the data is not in packed bf16 form, it falls through to cpu_arr.CopyFromBytes(...), and TensorCopyFromBytes already checks that the incoming byte count matches the tensor's exact storage size.
So malformed sizes still fail hard with a size-mismatch error rather than being accepted incorrectly. I can add an explicit size * 4 check or a clearer diagnostic if reviewers prefer, but the current path is already safe.
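The effective decision tree is easy to state in isolation. A hedged Python sketch of the size logic (names illustrative, not the actual `ArrayDecodeStorage` code):

```python
def classify_payload(payload_nbytes: int, tensor_nbytes: int) -> str:
    # tensor_nbytes is the f32 storage size; a packed-bf16 payload is exactly half.
    if payload_nbytes == tensor_nbytes // 2:
        return "packed-bf16"   # expand each 2-byte bf16 element to 4-byte f32
    if payload_nbytes == tensor_nbytes:
        return "raw-f32"       # fall through to a plain byte copy
    # Mirrors the hard size-mismatch failure TensorCopyFromBytes raises.
    raise ValueError(f"size mismatch: {payload_nbytes} vs {tensor_nbytes}")

assert classify_payload(8, 16) == "packed-bf16"
assert classify_payload(16, 16) == "raw-f32"
```

Any third size fails hard on the copy path, which is why the fallback is not silently permissive.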
Unblocks Gemma 4 E2B text-only support in mlc-llm. Ships the minimal
set of runtime / frontend changes the model exercises as soon as
prefill runs on WebGPU.
PagedKVCache hybrid dispatch (src/runtime/vm/paged_kv_cache.cc)
---------------------------------------------------------------
* Hoist `ReserveAppendLengthInSeq` above the aux-data loop so page
metadata reflects the current prefill when a request spans multiple
blocks in a single call. Previously the aux-data loop read block
page counts before the blocks were reserved, producing empty
`page_indptr` entries for the first call of a newly-created
sequence whose length exceeded `page_size`.
* Route `AttnKind::kMHASliding` through the MHA dispatch arm in
`SelfAttention()` and `CrossAttention()`. Without this, sliding
layers in Gemma 4 fell through to the MLA path and returned
zero-initialised output for their attention sub-graph.
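The routing change can be sketched as a tiny dispatch function. The enum below is an illustrative subset of the runtime's attention kinds, not the actual C++ definition:

```python
from enum import Enum

class AttnKind(Enum):
    MHA = 0
    MLA = 1
    MHA_SLIDING = 2

def attention_arm(kind: AttnKind) -> str:
    # Sliding-window layers reuse the MHA compute path; only true MLA
    # layers should take the MLA arm. Before the fix, MHA_SLIDING fell
    # through to "mla" and produced zero-initialised attention output.
    if kind in (AttnKind.MHA, AttnKind.MHA_SLIDING):
        return "mha"
    return "mla"

assert attention_arm(AttnKind.MHA_SLIDING) == "mha"
```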
WebGPU target kind (src/target/target_kind.cc)
----------------------------------------------
* Register `max_shared_memory_per_block = 32768` on the `webgpu`
target kind. Without this attribute, Dlight's shared-memory
analysis falls back to the generic 48 KB default and generates
decode kernels that exceed Chrome/Dawn's 32 KB workgroup-storage
budget. Chrome currently exposes 32768 and the WebGPU spec
mandates at least 16384, so 32768 is a safe default.
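A back-of-envelope check of the budgets involved (constants from the text above; the helper is illustrative, not Dlight's actual shared-memory analysis):

```python
GENERIC_DEFAULT = 48 * 1024     # 49152 B: fallback when the attr is missing
CHROME_DAWN_LIMIT = 32768       # workgroup storage Chrome/Dawn exposes today
WEBGPU_SPEC_MINIMUM = 16384     # lower bound the WebGPU spec mandates

def kernel_fits(shared_bytes: int, budget: int) -> bool:
    return shared_bytes <= budget

# A hypothetical decode kernel using 40 KB of workgroup storage is legal
# under the generic 48 KB default but rejected by Chrome/Dawn:
assert kernel_fits(40 * 1024, GENERIC_DEFAULT)
assert not kernel_fits(40 * 1024, CHROME_DAWN_LIMIT)
assert WEBGPU_SPEC_MINIMUM <= CHROME_DAWN_LIMIT
```

With `max_shared_memory_per_block = 32768` registered, the scheduler never plans past the real budget in the first place.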
Relax nn.llm (python/tvm/relax/frontend/nn/llm/)
------------------------------------------------
* position_embedding.py:
- Add a `freq_dim_base` parameter to `rope_freq_gptj` so callers
can decouple the frequency-base dimension (HF's `head_dim`)
from the rotated range (HF's `rotary_dim`). Required by Gemma 4's
full-attention layers, which use `partial_rotary_factor=0.25`.
- Mark `fused_rope` and `fused_rope_longrope_scaling` as
`@T.prim_func(private=True)`. Gemma 4 builds two instances of
`llama_rope_with_position_map` per model (one for sliding
layers, one for partial-rotary full-attention layers), both of
which would otherwise register a prim_func named `fused_rope`
at module scope and trip `NormalizeGlobalVar`'s duplicate-
symbol check. These prim_funcs are nested inside the factory
and have no external callers, so scoping their global-symbol
registration is safe.
- Promote the `apply_rope` prim_func parameter from `T.int32` to
`T.int64`. Both in-tree callers that pass `apply_rope`
(mlc_llm/model/gemma4 and the existing mlc_llm/model/llama4)
construct the immediate with `tirx.IntImm("int64", 1)`;
aligning the parameter type avoids an implicit cast in the
generated kernel and matches the caller convention.
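The `freq_dim_base` decoupling can be illustrated with a scalar model of GPT-J-style RoPE frequencies. The function shape and the dimension numbers below are hypothetical, not the exact `position_embedding.py` signature:

```python
import math

def rope_freq_gptj(pos, dim_idx, rotary_dim, theta=10000.0, freq_dim_base=None):
    # Frequencies are computed over freq_dim_base (HF's head_dim) while only
    # dims [0, rotary_dim) are actually rotated; without the parameter the
    # two were forced to coincide.
    base_dim = freq_dim_base if freq_dim_base is not None else rotary_dim
    freq = pos / theta ** (2 * (dim_idx // 2) / base_dim)
    return math.cos(freq), math.sin(freq)

# Hypothetical partial-rotary layer: head_dim=128, partial_rotary_factor=0.25
head_dim, rotary_dim = 128, 32
cos0, sin0 = rope_freq_gptj(0, 0, rotary_dim, freq_dim_base=head_dim)
assert (cos0, sin0) == (1.0, 0.0)   # position 0 leaves every pair unrotated
```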
* kv_cache.py, tree_attn.py: replace `str(target.kind)` with
`target.kind.name` in dispatch-kind comparisons. Required because
the mlc-llm side now adds `max_num_threads: 128` to the `webgpu`
target preset, which makes `str(target.kind)` no longer equal the
bare "webgpu" string.
* tirx/build.py: group device functions in `split_host_device_mods`
by `kind.name` instead of the stringified target so multi-target
attribute-pass output is not silently shadowed.
WebAssembly runtime (web/)
--------------------------
* emcc/wasm_runtime.cc: reorder FFI includes so `tvm_ffi::*` static
initialisers run before `runtime::*` initialisers. Also extend
`ArrayDecodeStorage` with a fall-through for payloads tagged
`f32-to-bf16` whose byte length matches native float32; this
keeps native-f32 shards (e.g. Gemma 4's tied-embedding weight,
which `mlc_llm convert_weight` still labels `f32-to-bf16`)
decodable.
* src/runtime.ts: chunked weight loading for records whose `nbytes`
exceed a 128 MB per-call budget. Gemma 4's per-layer embedding
shards are ~390 MB after the model-side
`Gemma4SplitScaledEmbedding` split, which is still larger than
the WebGPU single-write limit. Also add `kTVMFFIShape` unpacking
in the FFI result path, required by the new chunking code that
calls `tensorCreateView` with explicit shape tuples.
Force-pushed from 6f780f8 to ac9cf7a.
[Runtime] Add Gemma 4 E2B prerequisites
Target repo: mlc-ai/relax
Target branch: mlc

Summary
-------
This PR adds the runtime/compiler changes needed for google/gemma-4-E2B-it
text-only support on the MLC WebGPU path. This is a prerequisite PR; the
companion mlc-llm PR depends on it.

What this changes
-----------------
* `paged_kv_cache.cc`
  - Hoist `ReserveAppendLengthInSeq` so page metadata reflects the current
    prefill when a request spans multiple blocks in one call
  - Route `AttnKind::kMHASliding` through the MHA path in `SelfAttention()` /
    `CrossAttention()`
* `target_kind.cc`
  - Register `max_shared_memory_per_block = 32768` on the `webgpu` target
    kind so Dlight does not schedule decode kernels against an unrealistically
    large shared-memory budget
* `position_embedding.py`
  - Extend `rope_freq_gptj` with `freq_dim_base` for Gemma 4 full-attention
    layers (`partial_rotary_factor = 0.25`)
  - Mark the nested rope prim_funcs `private=True` to avoid duplicate
    global-symbol failures during `NormalizeGlobalVar`
  - Align the `apply_rope` prim_func argument type with current in-tree
    callers (`int64`)
* `kv_cache.py`, `tree_attn.py`, `tirx/build.py`
  - Switch `str(target.kind)` comparisons to `target.kind.name` so dispatch
    keeps working once the WebGPU target carries additional attributes
* `web/src/runtime.ts`
  - Chunked weight loading, plus `kTVMFFIShape` unpacking in
    `tensorCreateView` for the chunked path
* `web/emcc/wasm_runtime.cc`
  - Make `ArrayDecodeStorage` tolerant of `f32-to-bf16`-tagged payloads
    whose bytes are raw `float32`

Why this is needed for Gemma 4
------------------------------
Gemma 4 E2B text-only exercises all of the paths above; without these
changes, the intended WebGPU path does not compile and run correctly.

Validation
----------
Validated together with the companion mlc-llm PR from a clean-room setup:

* `relax` + `mlc-llm` commits applied

Canonical browser checks pass on the rebuilt WASM:

* "Hi" -> valid short completion
* "What is the capital of France? Answer in one word." -> "Paris" on first
  attempt, `finish_reason = "stop"`

This PR claims clean-room functional equivalence, not byte-identical parity
across host environments.

Scope / non-goals
-----------------
This PR does not add:

* `web-llm` built-in model-list integration

This PR only covers the lower-layer prerequisites required for Gemma 4 E2B
text-only support.

Follow-up work intentionally left out of scope:

* `kMHASliding` symmetry in the remaining sibling sites of
  `paged_kv_cache.cc`