[Pallas] Skip factory tensor padding for Pallas backend #2279
Draft
norx1991 wants to merge 2 commits into
norx1991 added a commit that referenced this pull request May 5, 2026
With #2279 and this PR fixing the trace-time bugs, the remaining Pallas failure on the existing shapes is a TPU alignment violation caused by the hardcoded block_sizes=[32, 1] (offset arithmetic for the 1D fp32 mean/rstd, which requires 128-alignment), not the original InductorLoweringError.
Pallas does not require power-of-2 dims, so padding tensor factory ops (zeros/ones/empty/full/...) caused broadcast mismatches against unpadded full-tensor loads. This was masked when n was already a power of 2 but surfaced in matmul_layernorm-style kernels where the trailing dim is a specialized non-power-of-2 int. Add Backend.pad_factory_tensors_to_power_of_2 (default True), override to False on PallasBackend, and gate the two callers of patch_tensor_factories on it.
Verifies new_zeros([n]) inside hl.tile keeps n exact on Pallas instead of rounding to next_power_of_2. Previously triggered a JAX broadcast error like (block, n), (1, pow2) at trace time.
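The flag-and-gate change described in the commit message above can be sketched roughly as follows. This is a minimal illustration, not the code in the PR: the class bodies, the `log` argument, and the `factory_patch_context` and `trace` helpers are hypothetical stand-ins, and the real `patch_tensor_factories` almost certainly has a different signature.

```python
from contextlib import contextmanager, nullcontext


class Backend:
    """Hypothetical stand-in for the base backend class."""

    @property
    def pad_factory_tensors_to_power_of_2(self) -> bool:
        # Default: keep the existing power-of-2 padding behavior.
        return True


class PallasBackend(Backend):
    """Hypothetical stand-in for the Pallas backend."""

    @property
    def pad_factory_tensors_to_power_of_2(self) -> bool:
        # Pallas wants exact dims, so it opts out of padding.
        return False


@contextmanager
def patch_tensor_factories(log):
    # Placeholder for the real patching context (assumed signature);
    # here it only records that patching was entered and left.
    log.append("patched")
    try:
        yield
    finally:
        log.append("restored")


def factory_patch_context(backend, log):
    """Return the padding context only if the backend asks for padding."""
    if backend.pad_factory_tensors_to_power_of_2:
        return patch_tensor_factories(log)
    return nullcontext()


def trace(backend):
    """Pretend to trace a kernel and report which steps ran."""
    log = []
    with factory_patch_context(backend, log):
        log.append("trace")
    return log
```

With this shape, each call site makes a single check on the backend, and backends that do not override the property keep their current behavior.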
5c00665 to 75cb51b
norx1991 added a commit that referenced this pull request May 6, 2026
…ference

When a kernel does `x[tile, :]` over a host-origin tensor and follows it with operations that combine the result with a host-allocated buffer of the same trailing extent (e.g. `acc = x.new_zeros([n]); acc + sum_result`), shape inference used to allocate a fresh unbacked symbol for the slice's trailing dim while the host accumulator stayed concrete (post #2279). The two sides could not be unified at trace time, raising broadcast errors at the binop. This was previously masked by the dispatch-mode padding that forced both sides through `next_power_of_2` rounding, but is exposed once Pallas disables that padding.

Fix: on backends that don't pad factory ops (Pallas), keep concrete int dims concrete in `_device_indexing_size` and `SubscriptIndexing.compute_shape`. Also route `ReductionLoopBlockSizeSource.from_config` through `backend.static_rdim_size` so the persistent-reduction extent matches what the rest of the codegen computes.

Triton and CuTe paths are unchanged: their `pad_factory_tensors_to_power_of_2` property remains `True`, so they keep allocating reduction-dim symbols and `next_power_of_2`-rounding the persistent-reduction extent.
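As a rough illustration of why keeping the dim concrete matters, the toy model below contrasts the two behaviors. It is a simplified sketch under stated assumptions: `UnbackedSymbol`, `fresh_symbol`, `slice_dim_size`, and `dims_unify` are hypothetical names, and the real shape inference works on symbolic expressions rather than plain equality.

```python
import itertools
from dataclasses import dataclass


@dataclass(frozen=True)
class UnbackedSymbol:
    """Toy stand-in for a symbolic dim with no known concrete value."""
    name: str


_counter = itertools.count()


def fresh_symbol() -> UnbackedSymbol:
    # Every call yields a distinct symbol, as a fresh allocation would.
    return UnbackedSymbol(f"u{next(_counter)}")


def slice_dim_size(dim, pads_factory_tensors: bool):
    """Size recorded for a full-slice (`:`) dim in this toy model.

    On a backend that does not pad factory tensors, a concrete int stays
    an int; otherwise a fresh symbol is allocated (the old behavior).
    """
    if isinstance(dim, int) and not pads_factory_tensors:
        return dim
    return fresh_symbol()


def dims_unify(a, b) -> bool:
    # Toy unification: two dims match only if they are equal.
    return a == b
```

In this model, a host accumulator of extent `10240` unifies with `slice_dim_size(10240, False)` but not with `slice_dim_size(10240, True)`, which mirrors the mismatch the commit message describes.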
Summary
`_PadTensorFactoryMode` rewrote on-device tensor factory ops (`zeros`, `ones`, `empty`, `full`, `new_*`) to round their integer dim sizes up to `next_power_of_2`. That's a Triton convenience, but Pallas requires exact dims: when a kernel allocates `x.new_zeros([n])` inside `hl.tile` with non-power-of-2 `n`, the padded buffer was `next_power_of_2(n)` while in-loop full-dim loads stayed at `n`, producing a JAX broadcast error like `add got incompatible shapes for broadcasting: (block, n), (1, pow2)` at trace time.

This adds a `Backend.pad_factory_tensors_to_power_of_2` flag (default `True`), overrides it to `False` on `PallasBackend`, and gates both call sites of `patch_tensor_factories` on it.

Test plan

Added `test/test_pallas.py::TestPallas::test_new_zeros_non_pow2_full_dim`, which runs in CI. Verified locally that the test fails on `main` (`add got incompatible shapes for broadcasting: (32, 10240), (1, 16384)`) and passes with the fix.
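The shapes in the failure message follow directly from the rounding. The sketch below reproduces the arithmetic in plain Python. Both helpers are written here for illustration only: `next_power_of_2` is a local reimplementation, not the project's function, and `broadcast_shapes` is a hypothetical NumPy-style shape check rather than JAX's own.

```python
import itertools


def next_power_of_2(n: int) -> int:
    """Smallest power of two that is >= n (for n >= 1)."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def broadcast_shapes(a, b):
    """NumPy-style broadcast of two shapes; raises ValueError on mismatch."""
    result = []
    # Compare dims from the trailing end, treating missing dims as 1.
    for x, y in itertools.zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x == y or y == 1:
            result.append(x)
        elif x == 1:
            result.append(y)
        else:
            raise ValueError(f"incompatible shapes for broadcasting: {a}, {b}")
    return tuple(reversed(result))


n = 10240                    # trailing dim from the failing test
padded = next_power_of_2(n)  # what the padded factory op allocated

# Exact dims broadcast cleanly.
exact = broadcast_shapes((32, n), (1, n))

# The padded accumulator cannot broadcast against the unpadded load.
try:
    broadcast_shapes((32, n), (1, padded))
    mismatch = None
except ValueError as err:
    mismatch = str(err)
```

Here `padded` is 16384 and `mismatch` names the same pair of shapes as the error in the test plan, while the exact-dim case broadcasts to `(32, 10240)`.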