
[Pallas] Skip factory tensor padding for Pallas backend #2279

Draft
norx1991 wants to merge 2 commits into main from yifeixu/pallas-skip-factory-padding

Contributor

norx1991 commented May 5, 2026

Summary

_PadTensorFactoryMode rewrote on-device tensor factory ops (zeros, ones, empty, full, new_*) to round their integer dim sizes up to next_power_of_2. That is a convenience for Triton, but Pallas requires exact dims. When a kernel allocated x.new_zeros([n]) inside hl.tile with a non-power-of-2 n, the padded buffer had size next_power_of_2(n) while in-loop full-dim loads stayed at n. The result was a JAX broadcast error at trace time, such as: add got incompatible shapes for broadcasting: (block, n), (1, pow2).
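The mismatch described above can be reproduced with plain shape arithmetic. The following is a minimal sketch in pure Python, not Helion or JAX code: `next_power_of_2` and `broadcast_shapes` here are illustrative helpers written for this example, and the shapes are the ones reported in the test plan.

```python
from typing import Tuple


def next_power_of_2(n: int) -> int:
    """Smallest power of two that is >= n (for n >= 1)."""
    return 1 << (n - 1).bit_length()


def broadcast_shapes(a: Tuple[int, ...], b: Tuple[int, ...]) -> Tuple[int, ...]:
    """NumPy/JAX-style broadcasting, restricted to equal-rank shapes."""
    if len(a) != len(b):
        raise ValueError("this sketch only handles equal-rank shapes")
    out = []
    for x, y in zip(a, b):
        if x == y or y == 1:
            out.append(x)
        elif x == 1:
            out.append(y)
        else:
            raise ValueError(f"incompatible shapes for broadcasting: {a}, {b}")
    return tuple(out)


block, n = 32, 10240
padded = next_power_of_2(n)  # what the padded factory buffer would use

# Exact dims: the accumulator broadcasts cleanly against the full-dim load.
exact_ok = broadcast_shapes((block, n), (1, n))

# Padded accumulator vs. unpadded load: the trailing dims disagree.
try:
    broadcast_shapes((block, n), (1, padded))
    mismatch = None
except ValueError as exc:
    mismatch = str(exc)
```

With `n = 10240` the padded extent is 16384, so the second call fails on the trailing dimension while the first succeeds.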

This adds a Backend.pad_factory_tensors_to_power_of_2 flag (default True), overrides it to False on PallasBackend, and gates both call sites of patch_tensor_factories on it.
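As a rough illustration of the gating pattern, the sketch below uses stand-in classes rather than the real Helion ones. The property name comes from this PR; the body of `patch_tensor_factories`, the `factory_padding_context` helper, and the `trace` function are hypothetical and exist only to show the control flow.

```python
from contextlib import contextmanager, nullcontext

calls = []  # records which code paths ran, for illustration only


@contextmanager
def patch_tensor_factories():
    """Hypothetical stand-in for the helper that installs the padding mode."""
    calls.append("padding on")
    try:
        yield
    finally:
        calls.append("padding off")


class Backend:
    @property
    def pad_factory_tensors_to_power_of_2(self) -> bool:
        # Default: keep the existing power-of-2 padding behaviour.
        return True


class PallasBackend(Backend):
    @property
    def pad_factory_tensors_to_power_of_2(self) -> bool:
        # Pallas needs exact dims, so padding is switched off.
        return False


def factory_padding_context(backend: Backend):
    """Hypothetical helper: pick the padding context or a no-op."""
    if backend.pad_factory_tensors_to_power_of_2:
        return patch_tensor_factories()
    return nullcontext()


def trace(backend: Backend) -> None:
    """Hypothetical call site that would otherwise always patch factories."""
    with factory_padding_context(backend):
        calls.append(f"trace {type(backend).__name__}")
```

Keeping the decision on the backend object means each call site only needs one conditional, and other backends keep their current behaviour by inheriting the default.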

Test plan

Added test/test_pallas.py::TestPallas::test_new_zeros_non_pow2_full_dim, which runs in CI. Verified locally that the test fails on main (add got incompatible shapes for broadcasting: (32, 10240), (1, 16384)) and passes with the fix.

meta-cla bot added the CLA Signed label May 5, 2026
norx1991 added a commit that referenced this pull request May 5, 2026
After #2279 + this PR fix the trace-time bugs, the residual Pallas
failure on the existing shapes is a TPU alignment violation from the
hardcoded block_sizes=[32, 1] (offset arithmetic for 1D fp32 mean/rstd
which requires 128-alignment), not the original InductorLoweringError.
norx1991 added 2 commits May 6, 2026 13:27
Pallas does not require power-of-2 dims, so padding tensor factory ops
(zeros/ones/empty/full/...) caused broadcast mismatches against unpadded
full-tensor loads. This was masked when n was already a power of 2 but
surfaced in matmul_layernorm-style kernels where the trailing dim is a
specialized non-power-of-2 int.

Add Backend.pad_factory_tensors_to_power_of_2 (default True), override
to False on PallasBackend, and gate the two callers of
patch_tensor_factories on it.

Verifies new_zeros([n]) inside hl.tile keeps n exact on Pallas instead
of rounding to next_power_of_2. Previously triggered a JAX broadcast
error like (block, n), (1, pow2) at trace time.
norx1991 force-pushed the yifeixu/pallas-skip-factory-padding branch from 5c00665 to 75cb51b May 6, 2026 20:28
norx1991 added a commit that referenced this pull request May 6, 2026
…ference

When a kernel does `x[tile, :]` over a host-origin tensor and follows it
with operations that combine the result with a host-allocated buffer of
the same trailing extent (e.g. `acc = x.new_zeros([n]); acc + sum_result`),
shape inference used to allocate a fresh unbacked symbol for the slice's
trailing dim while the host accumulator stayed concrete (post #2279).
The two sides could not be unified at trace time, raising broadcast
errors at the binop.

This was previously masked by the dispatch-mode padding that forced both
sides through `next_power_of_2` rounding, but is exposed once Pallas
disables that padding.
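To make the unification failure concrete, here is a toy model of the two shape-inference behaviours. It is only a sketch: `Unbacked`, `fresh_symbol`, `slice_trailing_dim`, and `provably_equal` are hypothetical names invented for this example and do not correspond to the real symbolic-shape machinery.

```python
import itertools
from dataclasses import dataclass


@dataclass(frozen=True)
class Unbacked:
    """Toy stand-in for a size whose value is unknown at trace time."""
    name: str


_counter = itertools.count()


def fresh_symbol() -> Unbacked:
    """Allocate a new symbol that is unrelated to every other size."""
    return Unbacked(f"u{next(_counter)}")


def slice_trailing_dim(n: int, keep_concrete: bool):
    """Trailing dim of x[tile, :] when x has concrete trailing extent n."""
    return n if keep_concrete else fresh_symbol()


def provably_equal(a, b) -> bool:
    """Conservative check: only identical symbols or equal ints unify."""
    return a == b


n = 10240
acc_dim = n  # host accumulator stays concrete once padding is disabled

# Old behaviour: the slice gets a fresh symbol, so the binop cannot unify.
old_ok = provably_equal(slice_trailing_dim(n, keep_concrete=False), acc_dim)

# Behaviour after keeping concrete ints concrete: both sides are the same int.
new_ok = provably_equal(slice_trailing_dim(n, keep_concrete=True), acc_dim)
```

In this model `old_ok` is false and `new_ok` is true, which mirrors why the broadcast error appears only once the two sides stop being rounded the same way.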

Fix: on backends that don't pad factory ops (Pallas), keep concrete int
dims concrete in `_device_indexing_size` and `SubscriptIndexing.compute_shape`.
Also route `ReductionLoopBlockSizeSource.from_config` through
`backend.static_rdim_size` so the persistent-reduction extent matches
what the rest of the codegen computes.

Triton and CuTe paths are unchanged: their `pad_factory_tensors_to_power_of_2`
property remains `True`, so they keep allocating reduction-dim symbols
and `next_power_of_2`-rounding the persistent-reduction extent.
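The per-backend difference in the persistent-reduction extent can be pictured as follows. This is a sketch under assumptions: the method name `static_rdim_size` appears in the commit message, but its signature and the class bodies below are hypothetical and simplified for illustration.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two that is >= n (for n >= 1)."""
    return 1 << (n - 1).bit_length()


class Backend:
    """Stand-in for a backend that pads (e.g. the Triton-style default)."""

    def static_rdim_size(self, n: int) -> int:
        # Assumed behaviour: round the reduction extent up to a power of two.
        return next_power_of_2(n)


class PallasBackend(Backend):
    """Stand-in for a backend that requires exact dims."""

    def static_rdim_size(self, n: int) -> int:
        # Assumed behaviour: keep the reduction extent exactly as given.
        return n


def persistent_reduction_extent(backend: Backend, n: int) -> int:
    """Hypothetical caller: always ask the backend instead of rounding inline."""
    return backend.static_rdim_size(n)
```

Routing every caller through the backend keeps the reduction extent consistent with whatever the rest of code generation assumes for that backend.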
norx1991 added a commit that referenced this pull request May 6, 2026
After #2279 + this PR fix the trace-time bugs, the residual Pallas
failure on the existing shapes is a TPU alignment violation from the
hardcoded block_sizes=[32, 1] (offset arithmetic for 1D fp32 mean/rstd
which requires 128-alignment), not the original InductorLoweringError.