dsv3: fix pure-JAX MoE einsum (wi_0 / wi_1 axis labels) #5
Open
ultrons wants to merge 3 commits into
…_ffn)
Surgical drop-in of kernel-agent's v_outside expert FFN (D.7 F-tiled, fits
production E=256 D=7168 F=2048 K=8 on v7x 64 MB VMEM, cluster-verified
upstream at d7-fix-5) inside _expert_mlp_gmm_ag_body. Swaps the 3
ragged_dot / gmm_v2 calls only; surrounding AG-dispatch + sort + scatter +
psum_scatter machinery unchanged.
cfg.moe_use_kernel_agent_ffn defaults False — production path
(v304 + iter-16 attn_proj_out SAVE) unchanged. CLI surface
--moe_use_kernel_agent_ffn flips it on. Mutually exclusive with
--moe_use_gmm_v2 and --moe_fp8_weights (kernel is bf16-only).
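A minimal sketch of how the mutual-exclusion rule above could be enforced at argument-parsing time. The flag names come from this commit message; treating --moe_use_gmm_v2 and --moe_fp8_weights as bare boolean flags, and the helper names build_parser / validate, are assumptions for illustration and not a copy of train.py.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser: only the three flags named in the commit message.
    parser = argparse.ArgumentParser()
    parser.add_argument("--moe_use_kernel_agent_ffn", action="store_true")
    parser.add_argument("--moe_use_gmm_v2", action="store_true")  # assumed bare flag
    parser.add_argument("--moe_fp8_weights", action="store_true")  # assumed bare flag
    return parser


def validate(args: argparse.Namespace) -> argparse.Namespace:
    # The kernel-agent FFN is bf16-only and replaces the gmm_v2 calls,
    # so reject either conflicting flag when it is enabled.
    if args.moe_use_kernel_agent_ffn:
        conflicts = [
            name
            for name in ("moe_use_gmm_v2", "moe_fp8_weights")
            if getattr(args, name)
        ]
        if conflicts:
            raise ValueError(
                "--moe_use_kernel_agent_ffn is mutually exclusive with: "
                + ", ".join("--" + name for name in conflicts)
            )
    return args


# Default: the flag is off, so the production path is untouched.
default_args = validate(build_parser().parse_args([]))

# Turning the kernel on alone is fine; combining it with fp8 weights is not.
kernel_args = validate(build_parser().parse_args(["--moe_use_kernel_agent_ffn"]))
try:
    validate(
        build_parser().parse_args(
            ["--moe_use_kernel_agent_ffn", "--moe_fp8_weights"]
        )
    )
    conflict_rejected = False
except ValueError:
    conflict_rejected = True
```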
Also reverts the un-validated iter-19 SAVE-list expansion (kv_a) at
model.py:3112 → ("attn_proj_out",), restoring the iter-16 BASELINE
SAVE policy so this trial is apples-to-apples with iter-16 (1916 TPS/chip).
Files:
- jax_gpt/models/dsv3/kernels/kernel_agent/ (vendored from kernel-agent 2cda804)
- expert_ffn.py (auto-routes to D.7 when per-tile W1 > 4 MB)
- expert_ffn_f_tiled.py (D.7 F-tiling kernel)
- expert_ffn_d_tiled.py (D.6, retained for F=128 case)
- jax_gpt/models/dsv3/model.py cfg flag + gated branch + cvjp plumbing
- jax_gpt/models/dsv3/train.py --moe_use_kernel_agent_ffn CLI flag
- manifests/jobset.yaml.j2 register flag as bare --moe_use_kernel_agent_ffn
- research/dsv3/kernel-agent-snapshot-2cda804/ pinned upstream snapshot
- research/dsv3/kernel_agent_integration_notes.md integration writeup
- research/dsv3/kernel_agent_945964d_feedback.md initial usefulness assessment
- research/dsv3/aot_kernel_agent_integration.py AOT probe (small + prod)
- research/dsv3/kernel_agent_aot_check.py standalone kernel AOT probe
- tests/dsv3/kernels_test/exec_kernel_agent_ffn.py parity smoke vs gmm_v2
Verified at b4b63d1: parity smoke max_abs ~2.4e-4 vs ragged_dot baseline on
v4 TPU (better than gmm_v2's 4.9e-4 at the same shape); AOT compile PASS at
small shape (E=32 D=2048 K=4 EP=2 FSDP=4) on tpu7x:2x2x1; PASS for the
baseline path at production shape but FAIL at production with the kernel
flag on due to F=2048 VMEM. D.7 (commit 2cda804) closes that gap.
Refreshed local smoke + AOT re-runs are queued behind the TPU.
iter-16's cluster (bodaborg-super-rbq) was deleted between 2026-05-11 and 2026-05-22; the only currently-viable cluster with our hardware (tpu7x, 4x8x8-compatible) is bodaborg-super-xpk-x8p. Per ~/infra/INSTRUCTIONS.md §4, that cluster routes jax-gpt-class workloads via Kueue queue `multislice-queue` (not `lq`). The 4x8x8 topology / 64-pod parallelism / podset-slice-required-topology labels in the template already match x8p's dynamic-slice-composition convention (4x8x8 composed from 4×4x4x4 sub-blocks at admission). Verified by diff against ~/infra/manifests/jobset.yaml.j2 (the known-working reference template). Required for the kernel-agent-d7 trial. No docker rebuild needed — the queue label is rendered at submit time, not baked into the image.
The pure-JAX MoE paths (moe_backend in {"jax", "ragged_dot"}) raise
TypeError immediately:
dot_general requires contracting dimensions to have the same shape,
got (7168,) and (2048,).
Root cause: init_params at model.py:2929 stores wi_0 / wi_1 with shape
(cfg.E, cfg.D_moe, cfg.D) — D_moe in the middle, D last. Two consumers
were labelling the axes wrong:
expert_mlp_jax lines 897-898 (EP=1 branch):
"td,edm->tem" assumes (E, D, D_moe), so the contracting label d lands on the D_moe axis (2048) instead of D (7168).
_rd_block line 2301-2303 (ragged_dot path):
jax.lax.ragged_dot(sorted_x, wi_0, ...) — ragged_dot expects rhs
of shape (G, K, N) but wi_0 is (E, N, K) → same contracting mismatch.
Production paths (gmm_v2 / megablox / kernel-agent) consume the
(E, D_moe, D) layout correctly and aren't affected. The pure-JAX path
appears to have drifted out of sync during a layout change to the
weights init, and stayed broken because no integration test exercises
--moe_backend=jax.
Fix (one line each in the EP=1 branch + one swapaxes each in
_rd_block; wo is unchanged because its contracting dim is D_moe which
already aligns with the middle axis of storage):
einsum: "td,edm->tem" → "td,emd->tem"
ragged_dot: ragged_dot(x, wi_0, …) → ragged_dot(x, wi_0.swapaxes(-1,-2), …)
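The einsum half of the fix can be reproduced on toy shapes. The sketch below is illustrative only: the sizes are made up, and up_proj_old / up_proj_fixed are hypothetical stand-ins for the EP=1 branch of expert_mlp_jax, not the actual model code. It uses jax.eval_shape so only shapes are checked and nothing is computed.

```python
import jax
import jax.numpy as jnp

# Toy sizes (assumed for illustration; production uses D=7168, D_moe=2048).
T, E, D, D_moe = 5, 4, 16, 8

x = jnp.ones((T, D), dtype=jnp.float32)
wi_0 = jnp.ones((E, D_moe, D), dtype=jnp.float32)  # storage layout: (E, D_moe, D)


def up_proj_old(x, w):
    # Old labels read w as (E, D, D_moe): `d` is bound to the D_moe axis.
    return jnp.einsum("td,edm->tem", x, w)


def up_proj_fixed(x, w):
    # Fixed labels match storage: `m` is the middle axis, `d` the last.
    return jnp.einsum("td,emd->tem", x, w)


try:
    jax.eval_shape(up_proj_old, x, wi_0)
    old_failed = False
except Exception:  # exact exception type depends on the JAX version
    old_failed = True

fixed_shape = jax.eval_shape(up_proj_fixed, x, wi_0).shape  # (T, E, D_moe)
```

Note that when D == D_moe the old labels would not raise at all; they would silently contract the wrong axis, which is one reason to test with distinct sizes.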
Verified locally on JAX 0.10 / CPU:
- DSv3 "mini" (L=2, full D=7168) forward lowers cleanly under
moe_backend in {"jax", "ragged_dot"}. Compiled HLO: 3198 / 3721
lines respectively, ~25 dot ops, no errors.
- Numerical cross-check on a nano config (D=64, D_moe=32, E=4):
expert_mlp_jax vs expert_mlp_ragged_dot agree to max rel diff
2.0e-5 in fp32 (mean 3.0e-7). In bf16 the same comparison hits
8% max-rel from summation-order noise (mean 0.2%) — expected;
the two paths reduce in different orders.
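The lowering check above can be approximated without the model. The following is a rough sketch under stated assumptions: up_proj is a hypothetical stand-in for the full DSv3 forward, the sizes are toy values, and the text counted here is the StableHLO emitted by lowering, so the line and dot counts will not match the compiled-HLO figures quoted above.

```python
import jax
import jax.numpy as jnp

# Toy sizes (assumed); the real check lowers the DSv3 "mini" config.
T, E, D, D_moe = 8, 4, 64, 32


def up_proj(x, wi_0):
    # Stand-in for one MoE up-projection using the fixed einsum labels.
    return jax.nn.silu(jnp.einsum("td,emd->tem", x, wi_0))


# Abstract inputs: lowering needs only shapes and dtypes, not real data,
# so this runs on a CPU-only machine.
x_spec = jax.ShapeDtypeStruct((T, D), jnp.float32)
w_spec = jax.ShapeDtypeStruct((E, D_moe, D), jnp.float32)

lowered = jax.jit(up_proj).lower(x_spec, w_spec)
hlo_text = lowered.as_text()

num_lines = len(hlo_text.splitlines())
num_dots = hlo_text.count("dot_general")
print(f"lowered OK: {num_lines} lines, {num_dots} dot_general ops")
```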
Surfaced while building perfsim's HLO cross-compile validation
(perfsim PR #88, task #110) — needed the pure-JAX path to lower on
CPU since Pallas backends only run on TPU. No production v304 /
v3xx training flow uses --moe_backend=jax so this commit is
strictly a correctness improvement on a previously broken path.
Summary
Fixes a latent correctness bug in DSv3's pure-JAX MoE backends
(--moe_backend=jax and --moe_backend=ragged_dot). Both paths raise
TypeError immediately due to misaligned einsum/ragged_dot axis labels
against the actual weight storage layout.
Root cause
init_params (model.py around line 2929) stores wi_0 / wi_1 with shape
(cfg.E, cfg.D_moe, cfg.D): D_moe in the middle axis, D as the last.
Production paths (gmm_v2 / megablox / kernel-agent)
consume this layout correctly.
Two pure-JAX consumers labeled the axes wrong:
expert_mlp_jax lines 889-890 (EP=1 branch): einsum string "td,edm->tem" assumes (E, D, D_moe), placing d (contracting) on the middle axis, but the middle axis is actually D_moe.
_rd_block (expert_mlp_ragged_dot no-mesh path): ragged_dot expects rhs (G, K, N), but wi_0 / wi_1 arrive as (E, D_moe, D) = (G, N, K), the same contracting mismatch.
The down-projection wo is unaffected: its contracting dim is D_moe, which already aligns with the middle axis of storage.
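The ragged_dot half of the mismatch can be checked in isolation. This is a sketch on assumed toy shapes, with rd_old / rd_fixed as hypothetical stand-ins for the call inside _rd_block; jax.lax.ragged_dot takes lhs of shape (M, K), rhs of shape (G, K, N), and integer group_sizes of shape (G,).

```python
import jax
import jax.numpy as jnp

# Toy sizes (assumed for illustration).
T, E, D, D_moe = 6, 3, 16, 8

sorted_x = jnp.ones((T, D), jnp.float32)           # tokens already sorted by expert
wi_0 = jnp.ones((E, D_moe, D), jnp.float32)        # stored as (G, N, K)
group_sizes = jnp.array([1, 2, 3], jnp.int32)      # tokens per expert, sums to T


def rd_old(x, w, gs):
    # Passes (G, N, K) where (G, K, N) is expected: K=16 meets N=8.
    return jax.lax.ragged_dot(x, w, gs)


def rd_fixed(x, w, gs):
    # Swap the last two axes so rhs becomes (G, K, N).
    return jax.lax.ragged_dot(x, w.swapaxes(-1, -2), gs)


try:
    jax.eval_shape(rd_old, sorted_x, wi_0, group_sizes)
    old_failed = False
except Exception:  # exact exception type depends on the JAX version
    old_failed = True

fixed_shape = jax.eval_shape(rd_fixed, sorted_x, wi_0, group_sizes).shape  # (T, D_moe)
```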
Fix
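einsum: "td,edm->tem" → "td,emd->tem"
ragged_dot: ragged_dot(x, wi_0, …) → ragged_dot(x, wi_0.swapaxes(-1,-2), …)
Plus a comment block explaining the axis convention so future drift
gets caught at code-review time.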
Why this stayed latent
Production training (v304, v3xx) uses
--moe_backend gmm_v2 or --moe_backend megablox, never
--moe_backend jax. No integration test exercises the broken paths, so a
weight-layout refactor (storing D_moe in the middle, presumably for
FSDP / tile-layout reasons) broke the pure-JAX consumers silently.
Verification
JAX 0.10 / CPU.
Lowering (was broken; now works):
DSv3 mini (L=2, full D=7168) forward lowers cleanly under
- moe_backend=jax: 3198 HLO lines, 26 dot ops, no errors.
- moe_backend=ragged_dot: 3721 HLO lines, 25 dots.
Numerical (cross-check between the two pure-JAX paths on a nano
config, D=64, D_moe=32, E=4):
- fp32: max rel diff 2.0e-5, mean rel diff 3.0e-7, i.e. agreement at the level of fp32 rounding and summation order.
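A cross-check of this kind can be sketched on the same nano dimensions. Everything beyond D, D_moe, and E is an assumption here: the token count, the random top-1 routing, and the restriction to a single up-projection are simplifications, and the error metric (max absolute difference normalized by the largest output magnitude) is one reasonable choice rather than the one used for the numbers above.

```python
import jax
import jax.numpy as jnp

# Nano dimensions from the PR; T and the routing are assumed.
T, E, D, D_moe = 32, 4, 64, 32
kx, kw, ke = jax.random.split(jax.random.PRNGKey(0), 3)

x = jax.random.normal(kx, (T, D), jnp.float32)
wi_0 = jax.random.normal(kw, (E, D_moe, D), jnp.float32) / jnp.sqrt(D)
expert_id = jax.random.randint(ke, (T,), 0, E)  # hypothetical top-1 routing

# Dense path: project every token through every expert, then pick one.
dense_all = jnp.einsum("td,emd->tem", x, wi_0)        # (T, E, D_moe)
dense_out = dense_all[jnp.arange(T), expert_id]       # (T, D_moe)

# Ragged path: sort tokens by expert, grouped matmul, then undo the sort.
order = jnp.argsort(expert_id)
group_sizes = jnp.bincount(expert_id, length=E).astype(jnp.int32)
ragged_sorted = jax.lax.ragged_dot(
    x[order], wi_0.swapaxes(-1, -2), group_sizes
)                                                     # (T, D_moe), sorted order
ragged_out = jnp.zeros_like(ragged_sorted).at[order].set(ragged_sorted)

# Normalized max error; avoids dividing by near-zero individual outputs.
max_abs = float(jnp.abs(dense_out - ragged_out).max())
scale = float(jnp.abs(dense_out).max())
print(f"max abs diff / max magnitude = {max_abs / scale:.3e}")
```

Rerunning the same comparison with inputs cast to bf16 should show a much larger gap, since the two paths accumulate in different orders.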
Why this matters now
Surfaced while building perfsim's HLO cross-compile validation harness
(perfsim PR ultrons/perfsim#88, task #110). perfsim needs the pure-JAX
path to lower on CPU since Pallas backends are TPU-only. Without this
fix the no-hardware HLO validation track is blocked.
Base branch
I targeted autoperf/dsv3_train_full since main doesn't carry the
dsv3 module. If a different base is preferred (e.g. a stable release
branch), happy to re-target.
Test plan
- moe_backend=jax
- moe_backend=ragged_dot
- … a minimal pytest skeleton)