
[Pallas] Add test for fused_linear_jsd_fwd autograd path#2456

Draft
norx1991 wants to merge 2 commits into main from yifeixu/pallas-fused-linear-jsd-fwd

Conversation

Contributor

@norx1991 commented May 15, 2026

Summary

Adds test_fused_linear_jsd_fwd to test/test_examples.py covering the autograd-wrapped path (fused_linear_jsd_fwd → FusedLinearJSDFunction → jsd_kernel per chunk).

The existing test_fused_linear_jsd covers only the simpler fused_linear_jsd_kernel (JSD on pre-computed logits, no chunking, no gradient). The autograd path that the example advertises as the user-facing API ("For memory-efficient fused linear + JSD, use fused_linear_jsd_fwd instead") had no test, so a regression in FusedLinearJSDFunction's chunking math or in jsd_kernel's autotune would not surface in CI.

Shape (m=128, n=512, k=1024) is chosen so the chunk_size heuristic produces more than one chunk (the chunked path actually runs) and so the matmul yields softmax-stable logits at temperature = sqrt(hidden_dim). jsd_kernel's config is pinned (block_sizes=[16]) to skip autotune and keep the test cheap in CI.
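For readers unfamiliar with the loss being tested, here is a minimal plain-Python sketch of a temperature-scaled softmax followed by a generalized Jensen-Shannon divergence. This is an illustrative reference only: the beta parameter and the exact reduction are assumptions based on the Liger-style fused_linear_jsd API, not details confirmed by this PR, and the real kernel works on tensors, not lists.

```python
import math

def softmax(logits, temperature=1.0):
    # Temperature-scaled, max-shifted softmax (numerically stable).
    scaled = [x / temperature for x in logits]
    hi = max(scaled)
    exps = [math.exp(x - hi) for x in scaled]
    z = sum(exps)
    return [e / z for e in exps]

def kl(p, q):
    # KL(P || Q); terms with p_i == 0 contribute nothing.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def jsd(p, q, beta=0.5):
    # Generalized JSD: beta * KL(P || M) + (1 - beta) * KL(Q || M),
    # with mixture M = beta * P + (1 - beta) * Q. beta=0.5 recovers
    # the symmetric JSD, which is bounded by log(2).
    m = [beta * pi + (1 - beta) * qi for pi, qi in zip(p, q)]
    return beta * kl(p, m) + (1 - beta) * kl(q, m)
```

A test along these lines can sanity-check the kernel output: JSD of a distribution with itself is 0, the beta=0.5 form is symmetric, and the value never exceeds log(2).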

@meta-cla Bot added the CLA Signed label on May 15, 2026
Without a pinned config, fused_linear_jsd_fwd triggers full autotune on
jsd_kernel on first call. On H100 with no autotune cache, this pushed
the test suite past its job-level time budget (exit 137). Pin
block_sizes=[16] (small, safe across backends) so the test is fast and
deterministic.

Also shrink the shape from (512, 4096, 4096) to (128, 512, 1024) — still
exercises chunked path (chunk_size=8) but with much smaller matmul.
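The chunking math the test guards against regressing can be sketched in plain Python: a chunked forward must accumulate per-chunk sums and normalize once at the end, so the result matches the unchunked mean regardless of chunk boundaries. This is a hypothetical illustration of the invariant, assuming chunk_size means rows per chunk; it is not the PR's actual implementation.

```python
def chunked_mean_loss(per_row_losses, chunk_size):
    # Accumulate a global mean over fixed-size chunks, as a chunked
    # forward pass would: sum each chunk, then divide once by the
    # total row count. A common regression is dividing per chunk,
    # which silently changes the result when chunks differ in size.
    total = 0.0
    n = len(per_row_losses)
    for start in range(0, n, chunk_size):
        total += sum(per_row_losses[start:start + chunk_size])
    return total / n
```

With m=128 and chunk_size=8, this walks 16 chunks and reproduces the plain mean, which is exactly the equivalence the new test asserts against the reference implementation.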

Drop @skipIfXPU: the existing test_fused_linear_jsd has it for legacy
reasons, but with a pinned config and smaller shape there's no reason to
preemptively skip. CI will surface any XPU-specific issue.
