[Pallas] Add test for fused_linear_jsd_fwd autograd path #2456
Draft
norx1991 wants to merge 2 commits into
Conversation
The existing `test_fused_linear_jsd` covers only the simpler
`fused_linear_jsd_kernel` (JSD on pre-computed logits, no chunking, no
gradient). The autograd path that the example advertises as the
user-facing API ("For memory-efficient fused linear + JSD, use
fused_linear_jsd_fwd instead") had no test, so a regression in
`FusedLinearJSDFunction`'s chunking math or in `jsd_kernel`'s autotune
would not surface in CI.
Shape `(m=512, n=4096, k=4096)` is chosen so the chunk_size heuristic picks a
chunk_size > 1 (so the chunked path actually runs) and so the matmul
produces softmax-stable logits at `temperature = sqrt(hidden_dim)`.
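For context, the unfused computation the example fuses is roughly the following (a minimal plain-PyTorch sketch; the helper name, argument order, and the exact loss weighting/reduction used by the example are assumptions here, not its actual API):

```python
import math

import torch
import torch.nn.functional as F


def reference_linear_jsd(student_in, student_w, teacher_in, teacher_w):
    """Hypothetical reference: linear projection to logits, then JSD.

    For illustration only; names and argument order are not taken from the
    example's real signature.
    """
    # Logits scaled by temperature = sqrt(hidden_dim), as in the test's setup.
    temperature = math.sqrt(student_in.shape[-1])
    student_logp = F.log_softmax(student_in @ student_w.T / temperature, dim=-1)
    teacher_logp = F.log_softmax(teacher_in @ teacher_w.T / temperature, dim=-1)

    # Jensen-Shannon divergence: mean of the two KLs against the mixture m.
    log_m = torch.logaddexp(student_logp, teacher_logp) - math.log(2.0)
    return 0.5 * (
        F.kl_div(log_m, student_logp, log_target=True, reduction="batchmean")
        + F.kl_div(log_m, teacher_logp, log_target=True, reduction="batchmean")
    )
```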
Without a pinned config, `fused_linear_jsd_fwd` triggers a full autotune of `jsd_kernel` on first call. On H100 with no autotune cache, this pushed the test suite past its job-level time budget (exit 137). Pin `block_sizes=[16]` (small, safe across backends) so the test is fast and deterministic.

Also shrink the shape from (512, 4096, 4096) to (128, 512, 1024): this still exercises the chunked path (chunk_size=8) but with a much smaller matmul.

Drop `@skipIfXPU`: the existing `test_fused_linear_jsd` has it for legacy reasons, but with a pinned config and a smaller shape there is no reason to preemptively skip. CI will surface any XPU-specific issue.
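A rough sketch of what the added test body could look like under these choices (the import path, the signature of `fused_linear_jsd_fwd`, and which tensors receive gradients are assumptions; the real test lives in `test/test_examples.py` and pins `block_sizes=[16]` through the suite's usual config mechanism):

```python
import torch

# Hypothetical import path; the real example module name may differ.
from examples.fused_linear_jsd import fused_linear_jsd_fwd


def run_fused_linear_jsd_fwd_smoke(device: str = "cuda") -> None:
    # (m, n, k) = (128, 512, 1024): small matmul, but chunk_size is still > 1 (8).
    m, n, k = 128, 512, 1024
    student_input = torch.randn(m, k, device=device, requires_grad=True)
    teacher_input = torch.randn(m, k, device=device)
    student_weight = torch.randn(n, k, device=device, requires_grad=True)
    teacher_weight = torch.randn(n, k, device=device)

    # Forward through the autograd-wrapped path:
    # fused_linear_jsd_fwd -> FusedLinearJSDFunction -> jsd_kernel per chunk.
    # Argument order is assumed for illustration.
    loss = fused_linear_jsd_fwd(
        student_weight, teacher_weight, student_input, teacher_input
    )
    loss.backward()

    # Regression guard on the chunking math: gradients must exist and be finite.
    assert student_input.grad is not None and torch.isfinite(student_input.grad).all()
    assert student_weight.grad is not None and torch.isfinite(student_weight.grad).all()
```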
Summary
Adds `test_fused_linear_jsd_fwd` to `test/test_examples.py` covering the autograd-wrapped path (`fused_linear_jsd_fwd` → `FusedLinearJSDFunction` → `jsd_kernel` per chunk).

The existing `test_fused_linear_jsd` covers only the simpler `fused_linear_jsd_kernel` (JSD on pre-computed logits, no chunking, no gradient). The autograd path that the example advertises as the user-facing API ("For memory-efficient fused linear + JSD, use fused_linear_jsd_fwd instead") had no test, so a regression in `FusedLinearJSDFunction`'s chunking math or in `jsd_kernel`'s autotune would not surface in CI.

Shape `(m=128, n=512, k=1024)` is chosen so the chunk_size heuristic picks a chunk_size > 1 (the chunked path actually runs) and so the matmul produces softmax-stable logits at `temperature = sqrt(hidden_dim)`. `jsd_kernel`'s config is pinned (`block_sizes=[16]`) to skip autotune and keep the test cheap in CI.