[Pallas] Add fused_linear_jsd and grpo_loss to TPU benchmark sweep#2421

Merged: norx1991 merged 1 commit into main from yifeixu/tpu-bench-jsd-grpo on May 15, 2026


norx1991 (Contributor) commented May 14, 2026

Summary

  • Adds fused_linear_jsd and grpo_loss to benchmarks/run_tpu.py::KERNEL_MAPPINGS so they can be exercised by the full-coverage TPU sweep. The slim 9-kernel dashboard nightly is unchanged.
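The layout of KERNEL_MAPPINGS is not shown in this PR. As a rough illustration only, the sketch below assumes each entry pairs a kernel with a baseline and a shape generator (the three pieces the commit message mentions). `BenchEntry`, `REGISTRY`, `sweep`, and `_placeholder` are hypothetical names, not identifiers from benchmarks/run_tpu.py; the shapes are the ones reported in the results table.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple

Shape = Tuple[int, ...]


@dataclass(frozen=True)
class BenchEntry:
    """Hypothetical registry row: kernel under test, reference, and shapes."""

    kernel: Callable[..., object]
    baseline: Callable[..., object]
    shapes: Callable[[], List[Shape]]


def _placeholder(*args: object) -> None:
    """Stand-in for a real kernel or baseline callable."""
    return None


# Hypothetical registry; only the shapes come from this PR.
REGISTRY: Dict[str, BenchEntry] = {
    "fused_linear_jsd": BenchEntry(
        kernel=_placeholder,
        baseline=_placeholder,
        shapes=lambda: [(1024, 4096), (4096, 4096)],
    ),
    "grpo_loss": BenchEntry(
        kernel=_placeholder,
        baseline=_placeholder,
        shapes=lambda: [(4, 512, 2048), (4, 1024, 32000)],
    ),
}


def sweep(names: Optional[Iterable[str]] = None) -> Iterator[Tuple[str, Shape]]:
    """Yield (kernel name, shape) pairs; names=None selects every entry."""
    selected = list(REGISTRY) if names is None else list(names)
    for name in selected:
        for shape in REGISTRY[name].shapes():
            yield name, shape
```

Under this assumed design, a full-coverage sweep iterates every registered entry, while a slimmer run passes an explicit name list and is unaffected by new registrations.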

Scope (what works on Pallas today)

  • fused_linear_jsd: benches the JSD-only fused_linear_jsd_kernel (the test-covered path) against a new PyTorch baseline defined in run_tpu.py that mirrors the JSD-only math. The example file's fused_linear_jsd_pytorch does the full linear+JSD computation and takes inputs+weights, so it doesn't match the JSD-only kernel's logits-only signature.
  • grpo_loss: benches helion_grpo_loss forward only, against the example's existing torch_grpo_loss (with a small wrapper that upcasts logits to fp32, matching the kernel's internal precision). The backward kernel is marked @xfailIfPallas("InductorLoweringError") in test_examples.py, so this row tracks only what works today.
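The two baselines described above are not reproduced in this description. As a minimal sketch of the general idea, the code below computes a textbook generalized Jensen-Shannon divergence directly from two logits tensors and shows one way to upcast a baseline's logits to fp32. The function names, the `beta` mixing parameter, the mean reduction over rows, and the assumption that logits are the first positional argument are all illustrative choices; they may differ from what run_tpu.py does.

```python
import math

import torch
import torch.nn.functional as F


def jsd_from_logits(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    beta: float = 0.5,
) -> torch.Tensor:
    """Generalized JSD over the last dim, averaged over rows (0 < beta < 1)."""
    # Upcast so low-precision inputs are reduced in fp32.
    log_q = F.log_softmax(student_logits.float(), dim=-1)
    log_p = F.log_softmax(teacher_logits.float(), dim=-1)
    # Mixture M = beta * P + (1 - beta) * Q, formed in log space for stability.
    log_m = torch.logsumexp(
        torch.stack([log_p + math.log(beta), log_q + math.log(1.0 - beta)]),
        dim=0,
    )
    kl_p_m = (log_p.exp() * (log_p - log_m)).sum(dim=-1)
    kl_q_m = (log_q.exp() * (log_q - log_m)).sum(dim=-1)
    return (beta * kl_p_m + (1.0 - beta) * kl_q_m).mean()


def upcast_first_arg(fn):
    """Wrap fn so its first tensor argument is cast to fp32 before the call."""

    def wrapper(logits: torch.Tensor, *args, **kwargs):
        return fn(logits.float(), *args, **kwargs)

    return wrapper
```

With `beta=0.5` this is the symmetric JSD, which is zero for identical inputs and bounded above by ln 2. A wrapper like `upcast_first_arg` keeps the reference function untouched while making its input precision comparable to a kernel that upcasts internally.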

Autotune results (TPUv7, full effort)

| Kernel | Shape | Helion (ms) | Torch (ms) | Helion speedup | torch.compile speedup |
| --- | --- | --- | --- | --- | --- |
| fused_linear_jsd | [1024, 4096] | 0.6384 | 0.8579 | 1.34x | 2.28x |
| fused_linear_jsd | [4096, 4096] | 0.6742 | 1.2014 | 1.78x | 2.39x |
| grpo_loss | [4, 512, 2048] | 1.0945 | 1.4744 | 1.35x | 10.34x |
| grpo_loss | [4, 1024, 32000] | 1.1172 | 1.3331 | 1.19x | 10.38x |
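The Helion speedup column is consistent with dividing the Torch time by the Helion time for the same row. The short check below recomputes it from the timings in the table; `ROWS`, `speedup`, and `format_speedup` are names made up for this illustration. The torch.compile column cannot be rederived the same way because the compiled timings are not listed.

```python
# Rows copied from the results table: (kernel, shape, helion_ms, torch_ms).
ROWS = [
    ("fused_linear_jsd", (1024, 4096), 0.6384, 0.8579),
    ("fused_linear_jsd", (4096, 4096), 0.6742, 1.2014),
    ("grpo_loss", (4, 512, 2048), 1.0945, 1.4744),
    ("grpo_loss", (4, 1024, 32000), 1.1172, 1.3331),
]


def speedup(baseline_ms: float, candidate_ms: float) -> float:
    """Return how many times faster the candidate is than the baseline."""
    return baseline_ms / candidate_ms


def format_speedup(value: float) -> str:
    """Format a ratio the way the table does, e.g. 1.34x."""
    return f"{value:.2f}x"


helion_speedups = [
    format_speedup(speedup(torch_ms, helion_ms))
    for _, _, helion_ms, torch_ms in ROWS
]
print(helion_speedups)  # ['1.34x', '1.78x', '1.35x', '1.19x']
```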

Both kernels already pass test_examples.py on Pallas (no @xfailIfPallas)
but were missing from benchmarks/run_tpu.py's KERNEL_MAPPINGS, so they
weren't exercised by the full-coverage TPU sweep.

This wires both into KERNEL_MAPPINGS with a baseline + shape generator:
- fused_linear_jsd: bench fused_linear_jsd_kernel (JSD-only on logits)
  against a local PyTorch baseline. The example's autograd-wrapped
  fused_linear_jsd_fwd path uses jsd_kernel internally and fails Pallas
  accuracy check, so this PR covers only the test-passing JSD-only path.
- grpo_loss: bench helion_grpo_loss against torch_grpo_loss from the
  example file. Logits cast to fp32 in the baseline to match the
  kernel's internal upcast. The backward kernel is xfailIfPallas, so
  this row covers forward only.
meta-cla bot added the "CLA Signed" label May 14, 2026
norx1991 marked this pull request as ready for review May 15, 2026 22:41
norx1991 merged commit 70458c3 into main May 15, 2026
35 of 38 checks passed
norx1991 added a commit that referenced this pull request May 16, 2026
…ernel list

Picks up the two kernels landed via #2421. The benchmark_dispatch
workflow will now exercise all 27 KERNEL_MAPPINGS entries when this
wrapper is triggered.