
dsv3: fix pure-JAX MoE einsum (wi_0 / wi_1 axis labels) #5

Open
ultrons wants to merge 3 commits into autoperf/dsv3_train_full from fix/moe-pure-jax-einsum

ultrons commented May 26, 2026


Summary

Fixes a latent correctness bug in DSv3's pure-JAX MoE backends
(--moe_backend=jax and --moe_backend=ragged_dot). Both paths raise
TypeError immediately due to misaligned einsum/ragged_dot axis labels
against the actual weight storage layout.

TypeError: dot_general requires contracting dimensions to have the
same shape, got (7168,) and (2048,).

Root cause

init_params (model.py around line 2929) stores wi_0 / wi_1 with
shape (cfg.E, cfg.D_moe, cfg.D) — D_moe in the middle axis, D as
the last. Production paths (gmm_v2 / megablox / kernel-agent)
consume this layout correctly.

Two pure-JAX consumers labeled the axes wrong:

  • expert_mlp_jax line 889-890 (EP=1 branch): einsum string
    "td,edm->tem" assumes (E, D, D_moe) — places d (contracting)
    on the middle axis, but the middle axis is actually D_moe.
  • _rd_block (expert_mlp_ragged_dot no-mesh path): ragged_dot
    expects rhs (G, K, N), but wi_0/wi_1 arrive as (E, D_moe, D)
    = (G, N, K) — same contracting mismatch.

The down-projection wo is unaffected: its contracting dim is D_moe,
which already aligns with the middle axis of storage.
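
To make the label mismatch concrete, here is a minimal sketch with toy sizes (T=5, E=4, D=8, D_moe=6 are illustrative assumptions, not the production config). It builds a weight in the (E, D_moe, D) storage layout described above and checks that only the "td,emd->tem" labelling contracts over D. The exception type caught here may differ between JAX versions.

```python
import jax
import jax.numpy as jnp

# Toy sizes (assumed for illustration); only D != D_moe matters here.
T, E, D, D_MOE = 5, 4, 8, 6

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
flat_x = jax.random.normal(key_x, (T, D))
wi_0 = jax.random.normal(key_w, (E, D_MOE, D))  # storage layout: (E, D_moe, D)

# Old labels: "edm" puts d on the middle axis, which really has size D_moe.
try:
    jnp.einsum("td,edm->tem", flat_x, wi_0)
    old_labels_failed = False
except (TypeError, ValueError) as err:  # exact type may vary by JAX version
    old_labels_failed = True
    print("old labels rejected:", type(err).__name__)

# Fixed labels: "emd" matches the storage layout, so d contracts over D.
gate_pre = jnp.einsum("td,emd->tem", flat_x, wi_0)

# Reference: per-expert matmul against the transposed (D, D_moe) slice.
reference = jnp.stack([flat_x @ wi_0[e].T for e in range(E)], axis=1)

print(gate_pre.shape)  # (5, 4, 6)
```

With D equal to D_moe the old labels would not raise at all and would silently compute the wrong product, which is why the toy sizes keep the two dimensions different.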

Fix

# expert_mlp_jax (EP=1)
- gate_all = jax.nn.silu(jnp.einsum("td,edm->tem", flat_x, wi_0))
- up_all   = jnp.einsum("td,edm->tem", flat_x, wi_1)
+ gate_all = jax.nn.silu(jnp.einsum("td,emd->tem", flat_x, wi_0))
+ up_all   = jnp.einsum("td,emd->tem", flat_x, wi_1)

# _rd_block (ragged_dot path)
- gate = jax.nn.silu(jax.lax.ragged_dot(sorted_x, wi_0, group_sizes))
- up   = jax.lax.ragged_dot(sorted_x, wi_1, group_sizes)
+ gate = jax.nn.silu(jax.lax.ragged_dot(sorted_x, wi_0.swapaxes(-1, -2), group_sizes))
+ up   = jax.lax.ragged_dot(sorted_x, wi_1.swapaxes(-1, -2), group_sizes)

Plus a comment block explaining the axis convention so future drifts
get caught at code-review time.
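
The ragged_dot half of the fix can be checked in isolation. The sketch below uses toy sizes and a hand-picked group_sizes (both assumptions for illustration) and compares `jax.lax.ragged_dot` on the swapped weight against an explicit per-expert loop over contiguous row blocks.

```python
import jax
import jax.numpy as jnp

# Toy sizes (assumed for illustration).
E, D, D_MOE = 3, 8, 6
sizes = [2, 0, 3]                      # tokens per expert, rows sorted by expert
group_sizes = jnp.array(sizes, dtype=jnp.int32)
T = sum(sizes)

key_x, key_w = jax.random.split(jax.random.PRNGKey(1))
sorted_x = jax.random.normal(key_x, (T, D))
wi_1 = jax.random.normal(key_w, (E, D_MOE, D))  # storage: (E, D_moe, D) = (G, N, K)

# ragged_dot wants rhs as (G, K, N), so swap the last two axes first.
up = jax.lax.ragged_dot(sorted_x, wi_1.swapaxes(-1, -2), group_sizes)

# Reference: slice each expert's contiguous row block and matmul it by hand.
blocks, start = [], 0
for e, size in enumerate(sizes):
    blocks.append(sorted_x[start:start + size] @ wi_1[e].T)
    start += size
reference = jnp.concatenate(blocks, axis=0)

print(up.shape)  # (5, 6)
```

The swap is written with negative axes so it would also apply unchanged if a leading batch or shard axis were ever added in front of the expert axis.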

Why this stayed latent

Production training (v304, v3xx) uses --moe_backend gmm_v2 or
--moe_backend megablox — never --moe_backend jax. No integration
test exercises the broken paths, so a weight-layout refactor (storing
D_moe in the middle, presumably for FSDP / tile-layout reasons) broke
the pure-JAX consumers silently.

Verification

JAX 0.10 / CPU.

Lowering (was broken; now works):

  • DSv3 mini (L=2, full D=7168) forward lowers cleanly under
    moe_backend=jax: 3198 HLO lines, 26 dot ops, no errors.
  • Same under moe_backend=ragged_dot: 3721 HLO lines, 25 dots.
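
A lowering check of this kind can be approximated outside the model with JAX's ahead-of-time API. In the sketch below, `gate_proj` is a hypothetical stand-in rather than the model's expert_mlp_jax, and the shapes are arbitrary small values. It counts `dot_general` occurrences in the lowered text, so its numbers are not comparable to the HLO line and dot counts quoted above.

```python
import jax
import jax.numpy as jnp

def gate_proj(flat_x, wi_0):
    # Hypothetical stand-in for the EP=1 gate projection with the fixed labels.
    return jax.nn.silu(jnp.einsum("td,emd->tem", flat_x, wi_0))

# Abstract shapes only: no arrays are allocated, so this is cheap on CPU.
flat_x = jax.ShapeDtypeStruct((16, 64), jnp.float32)   # (T, D)
wi_0 = jax.ShapeDtypeStruct((4, 32, 64), jnp.float32)  # (E, D_moe, D)

lowered = jax.jit(gate_proj).lower(flat_x, wi_0)
text = lowered.as_text()  # textual IR produced by lowering

num_lines = len(text.splitlines())
num_dots = text.count("dot_general")
print(f"{num_lines} lines, {num_dots} dot_general ops")
```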

Numerical (cross-check between the two pure-JAX paths on a nano
config, D=64, D_moe=32, E=4):

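  • fp32: max rel diff 2.0e-5, mean rel diff 3.0e-7. This is
    consistent with fp32 summation-order differences; it is not
    bit-exact.
  • bf16: max rel diff 8%, mean rel diff 0.2%, which is the expected
    summation-order noise at that precision.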
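
A cross-check of this shape can be sketched as follows. The nano sizes come from the list above; the token count, the top-1 routing, and the choice to normalise by the largest output magnitude are assumptions made for this example and may not match how the figures above were computed.

```python
import jax
import jax.numpy as jnp

# Nano sizes from the cross-check above; T and the routing are assumptions.
T, E, D, D_MOE = 16, 4, 64, 32

k_x, k_w, k_r = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k_x, (T, D), dtype=jnp.float32)
wi_1 = jax.random.normal(k_w, (E, D_MOE, D), dtype=jnp.float32) / jnp.sqrt(D)
expert_of_token = jax.random.randint(k_r, (T,), 0, E)  # top-1 routing for simplicity

# Dense path: compute every expert for every token, then pick the routed one.
dense_all = jnp.einsum("td,emd->tem", x, wi_1)       # (T, E, D_moe)
dense = dense_all[jnp.arange(T), expert_of_token]    # (T, D_moe)

# Ragged path: sort tokens by expert, run ragged_dot, then undo the sort.
order = jnp.argsort(expert_of_token)
group_sizes = jnp.bincount(expert_of_token, length=E).astype(jnp.int32)
ragged_sorted = jax.lax.ragged_dot(x[order], wi_1.swapaxes(-1, -2), group_sizes)
ragged = jnp.zeros_like(ragged_sorted).at[order].set(ragged_sorted)

# Normalise by the largest magnitude to avoid dividing by near-zero entries.
abs_diff = jnp.abs(dense - ragged)
scale = jnp.max(jnp.abs(dense))
max_rel = float(jnp.max(abs_diff) / scale)
mean_rel = float(jnp.mean(abs_diff) / scale)
print(f"max rel diff {max_rel:.2e}, mean rel diff {mean_rel:.2e}")
```

Casting `x` and `wi_1` to `jnp.bfloat16` before both paths is one way to reproduce the lower-precision comparison; the tolerance would need to be loosened accordingly.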

Why this matters now

Surfaced while building perfsim's HLO cross-compile validation harness
(perfsim PR ultrons/perfsim#88, task #110). perfsim needs the pure-JAX
path to lower on CPU since Pallas backends are TPU-only. Without this
fix the no-hardware HLO validation track is blocked.

Base branch

I targeted autoperf/dsv3_train_full since main doesn't carry the
dsv3 module. If a different base is preferred (e.g. a stable release
branch), happy to re-target.

Test plan

  • DSv3 mini forward lowers under moe_backend=jax
  • DSv3 mini forward lowers under moe_backend=ragged_dot
  • Numerical cross-check (fp32 to ~2e-5 max rel diff, bf16 noise-level)
  • Integration test (suggest adding one — see commit message for
    a minimal pytest skeleton)
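
One cheap form such a regression test could take is a shape-only trace at production-like sizes. The sketch below is an assumption about how that might look, not the skeleton from the commit message: `dense_gate` and `ragged_gate` are hypothetical stand-ins for the two fixed call sites, E, D and D_moe are taken from figures quoted in this PR and its commit log, and T is arbitrary. `jax.eval_shape` traces without allocating the weights, so a layout mismatch surfaces as a trace-time error even on a CPU-only machine.

```python
import jax
import jax.numpy as jnp

# Production-like sizes quoted in this PR and its commit log; T is arbitrary.
T, E, D, D_MOE = 8, 256, 7168, 2048

def dense_gate(flat_x, wi_0):
    # Hypothetical stand-in for the expert_mlp_jax EP=1 gate projection.
    return jnp.einsum("td,emd->tem", flat_x, wi_0)

def ragged_gate(sorted_x, wi_0, group_sizes):
    # Hypothetical stand-in for the _rd_block gate projection.
    return jax.lax.ragged_dot(sorted_x, wi_0.swapaxes(-1, -2), group_sizes)

def test_pure_jax_paths_trace_with_storage_layout():
    x = jax.ShapeDtypeStruct((T, D), jnp.bfloat16)
    w = jax.ShapeDtypeStruct((E, D_MOE, D), jnp.bfloat16)  # (E, D_moe, D)
    sizes = jax.ShapeDtypeStruct((E,), jnp.int32)

    # Abstract evaluation only: nothing of size E * D_moe * D is materialised.
    dense_out = jax.eval_shape(dense_gate, x, w)
    ragged_out = jax.eval_shape(ragged_gate, x, w, sizes)

    assert dense_out.shape == (T, E, D_MOE)
    assert ragged_out.shape == (T, D_MOE)
```

A test against the real model functions would replace the two stand-ins with imports from the dsv3 module and keep the same abstract inputs.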

ultrons added 3 commits May 22, 2026 22:47
…_ffn)

Surgical drop-in of kernel-agent's v_outside expert FFN (D.7 F-tiled, fits
production E=256 D=7168 F=2048 K=8 on v7x 64 MB VMEM, cluster-verified
upstream at d7-fix-5) inside _expert_mlp_gmm_ag_body. Swaps the 3
ragged_dot / gmm_v2 calls only; surrounding AG-dispatch + sort + scatter +
psum_scatter machinery unchanged.

cfg.moe_use_kernel_agent_ffn defaults False — production path
(v304 + iter-16 attn_proj_out SAVE) unchanged. CLI surface
--moe_use_kernel_agent_ffn flips it on. Mutually exclusive with
--moe_use_gmm_v2 and --moe_fp8_weights (kernel is bf16-only).

Also reverts the un-validated iter-19 SAVE-list expansion (kv_a) at
model.py:3112 → ("attn_proj_out",), restoring the iter-16 BASELINE
SAVE policy so this trial is apples-to-apples with iter-16 (1916 TPS/chip).

Files:
- jax_gpt/models/dsv3/kernels/kernel_agent/  (vendored from kernel-agent 2cda804)
  - expert_ffn.py             (auto-routes to D.7 when per-tile W1 > 4 MB)
  - expert_ffn_f_tiled.py     (D.7 F-tiling kernel)
  - expert_ffn_d_tiled.py     (D.6, retained for F=128 case)
- jax_gpt/models/dsv3/model.py        cfg flag + gated branch + cvjp plumbing
- jax_gpt/models/dsv3/train.py        --moe_use_kernel_agent_ffn CLI flag
- manifests/jobset.yaml.j2            register flag as bare --moe_use_kernel_agent_ffn
- research/dsv3/kernel-agent-snapshot-2cda804/   pinned upstream snapshot
- research/dsv3/kernel_agent_integration_notes.md  integration writeup
- research/dsv3/kernel_agent_945964d_feedback.md   initial usefulness assessment
- research/dsv3/aot_kernel_agent_integration.py    AOT probe (small + prod)
- research/dsv3/kernel_agent_aot_check.py          standalone kernel AOT probe
- tests/dsv3/kernels_test/exec_kernel_agent_ffn.py parity smoke vs gmm_v2

Verified at b4b63d1: parity smoke max_abs ~2.4e-4 vs ragged_dot baseline on
v4 TPU (better than gmm_v2's 4.9e-4 at the same shape); AOT compile PASS at
small shape (E=32 D=2048 K=4 EP=2 FSDP=4) on tpu7x:2x2x1; PASS for the
baseline path at production shape but FAIL at production with the kernel
flag on due to F=2048 VMEM. D.7 (commit 2cda804) closes that gap.

Refreshed local smoke + AOT re-runs are queued behind the TPU.
iter-16's cluster (bodaborg-super-rbq) was deleted between 2026-05-11
and 2026-05-22; the only currently-viable cluster with our hardware
(tpu7x, 4x8x8-compatible) is bodaborg-super-xpk-x8p. Per
~/infra/INSTRUCTIONS.md §4, that cluster routes jax-gpt-class workloads
via Kueue queue `multislice-queue` (not `lq`).

The 4x8x8 topology / 64-pod parallelism / podset-slice-required-topology
labels in the template already match x8p's dynamic-slice-composition
convention (4x8x8 composed from 4×4x4x4 sub-blocks at admission).
Verified by diff against ~/infra/manifests/jobset.yaml.j2 (the
known-working reference template).

Required for the kernel-agent-d7 trial. No docker rebuild needed —
the queue label is rendered at submit time, not baked into the image.

The pure-JAX MoE paths (moe_backend in {"jax", "ragged_dot"}) raise
TypeError immediately:

  dot_general requires contracting dimensions to have the same shape,
  got (7168,) and (2048,).

Root cause: init_params at model.py:2929 stores wi_0 / wi_1 with shape
(cfg.E, cfg.D_moe, cfg.D) — D_moe in the middle, D last.  Two consumers
were labelling the axes wrong:

  expert_mlp_jax line 897-898 (EP=1 branch):
    "td,edm->tem" — assumes (E, D, D_moe); contracts D_moe by accident.
  _rd_block line 2301-2303 (ragged_dot path):
    jax.lax.ragged_dot(sorted_x, wi_0, ...) — ragged_dot expects rhs
    of shape (G, K, N) but wi_0 is (E, N, K) → same contracting mismatch.

Production paths (gmm_v2 / megablox / kernel-agent) consume the
(E, D_moe, D) layout correctly and aren't affected.  The pure-JAX path
appears to have drifted out of sync during a layout change to the
weights init, and stayed broken because no integration test exercises
--moe_backend=jax.

Fix (one line each in the EP=1 branch + one swapaxes each in
_rd_block; wo is unchanged because its contracting dim is D_moe which
already aligns with the middle axis of storage):

  einsum:      "td,edm->tem"          →  "td,emd->tem"
  ragged_dot:  ragged_dot(x, wi_0, …) →  ragged_dot(x, wi_0.swapaxes(-1,-2), …)

Verified locally on JAX 0.10 / CPU:

  - DSv3 "mini" (L=2, full D=7168) forward lowers cleanly under
    moe_backend in {"jax", "ragged_dot"}.  Compiled HLO: 3198 / 3721
    lines respectively, ~25 dot ops, no errors.
  - Numerical cross-check on a nano config (D=64, D_moe=32, E=4):
    expert_mlp_jax vs expert_mlp_ragged_dot agree to max rel diff
    2.0e-5 in fp32 (mean 3.0e-7).  In bf16 the same comparison hits
    8% max-rel from summation-order noise (mean 0.2%) — expected;
    the two paths reduce in different orders.

Surfaced while building perfsim's HLO cross-compile validation
(perfsim PR #88, task #110) — needed the pure-JAX path to lower on
CPU since Pallas backends only run on TPU.  No production v304 /
v3xx training flow uses --moe_backend=jax so this commit is
strictly a correctness improvement on a previously broken path.