[Pallas] Slice store values to match clamped Pallas BlockSpec ref shape #2398

Open: thcmbs wants to merge 1 commit into pytorch:main from thcmbs:pallas-store-slice-clamped-ref

thcmbs (Collaborator) commented May 12, 2026

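When block_size > tensor.shape[dim], the launcher clamps the BlockSpec block shape to min(block_size, dim_size), but the kernel still computes block_size-shaped values, so the value being stored is larger than the ref it is written to. Slice the value to the ref's shape before storing to avoid the shape mismatch.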

This only applies to grid-tiled dimensions that produce ":" in index_str; pl.ds() dimensions are padded instead of clamped, so they need no slicing.
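To make the rule concrete, here is a minimal pure-Python sketch of the idea, not the code in this PR. The helper names (`clamped_ref_shape`, `store_value_slices`, `slice_nested`) are hypothetical, and it assumes index_str has already been split into one entry per dimension. Nested lists stand in for arrays so the snippet runs without JAX.

```python
def clamped_ref_shape(block_sizes, tensor_shape):
    """Shape the ref has after clamping each block to min(block_size, dim_size)."""
    return tuple(min(b, d) for b, d in zip(block_sizes, tensor_shape))


def store_value_slices(block_sizes, tensor_shape, index_parts):
    """Per-dimension slices to apply to the value before storing.

    Assumption: index_parts holds one index string per dimension, where ":"
    marks a grid-tiled (clamped) dimension and anything else (e.g. a pl.ds(...)
    expression) marks a padded dimension that needs no slicing.
    """
    slices = []
    for block, dim, part in zip(block_sizes, tensor_shape, index_parts):
        if part == ":" and block > dim:
            slices.append(slice(0, dim))  # trim the value to the clamped ref extent
        else:
            slices.append(slice(None))    # keep the full extent
    return tuple(slices)


def slice_nested(value, slices):
    """Apply per-dimension slices to a nested list (stand-in for an array)."""
    if not slices:
        return value
    head, *rest = slices
    return [slice_nested(row, rest) for row in value[head]]


# Example: block_size 8 on a dimension of size 5.
block_sizes, tensor_shape = (8, 4), (5, 4)
ref_shape = clamped_ref_shape(block_sizes, tensor_shape)
slices = store_value_slices(block_sizes, tensor_shape, (":", ":"))

value = [[r * 4 + c for c in range(4)] for r in range(8)]  # block_size-shaped value
stored = slice_nested(value, slices)                       # now matches ref_shape
```

In this example the kernel-side value has 8 rows while the clamped ref has only 5, so the first dimension is sliced to `0:5` and the second is left untouched.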

Unblocks the kl_div example on TPU.

Ideally, we may want to work on the autotuner so that it does not fall into this case too often, but the fix is still needed in case a user wants to force such a config for some reason.

meta-cla bot added the "CLA Signed" label May 12, 2026
thcmbs changed the title from "Slice store values to match clamped Pallas BlockSpec ref shape" to "[Pallas] Slice store values to match clamped Pallas BlockSpec ref shape" May 12, 2026
thcmbs force-pushed the pallas-store-slice-clamped-ref branch from 4f4670e to a7279ac May 12, 2026 08:40
thcmbs marked this pull request as ready for review May 12, 2026 09:26
thcmbs requested review from AmesingFlank and oulgen May 12, 2026 09:26