Skip to content

[Pallas] Keep concrete dim sizes concrete in slice/reduction shape inference#2280

Closed
norx1991 wants to merge 2 commits into
yifeixu/pallas-skip-factory-paddingfrom
yifeixu/pallas-keep-concrete-dims
Closed

[Pallas] Keep concrete dim sizes concrete in slice/reduction shape inference#2280
norx1991 wants to merge 2 commits into
yifeixu/pallas-skip-factory-paddingfrom
yifeixu/pallas-keep-concrete-dims

Conversation

@norx1991
Copy link
Copy Markdown
Contributor

@norx1991 norx1991 commented May 5, 2026

Summary

Stacked on #2279.

When a Pallas kernel does x[tile, :] over a host-origin tensor and then combines the result with a host-allocated buffer of the same trailing extent (e.g. acc = x.new_zeros([n]); acc + sum_result), shape inference used to allocate a fresh unbacked symbol for the slice's trailing dim while the host accumulator stayed concrete (post-#2279). The two sides could not be unified at trace time, raising broadcast errors at the binop:

RuntimeError: Attempting to broadcast a dimension of length 512 at -1!
Mismatching argument at index 1 had torch.Size([512]); but expected shape should be broadcastable to [384]
While processing: grad_w_acc += torch.sum(dy_mb * x_hat, dim=0)

This was previously masked by the dispatch-mode padding (_PadTensorFactoryMode) that forced both sides through next_power_of_2 rounding so they happened to share an unbacked symbol. Disabling that padding for Pallas (#2279) exposes the trace-time mismatch.

This PR:

  • _device_indexing_size (type_propagation.py) and SubscriptIndexing.compute_shape (indexing_strategy.py): on backends that don't pad factory ops to power-of-2, keep concrete int slice extents concrete instead of allocating a fresh reduction-dim symbol.
  • ReductionLoopBlockSizeSource.from_config (compile_environment.py): route the persistent-reduction extent through backend.static_rdim_size so the host-side _RDIM_SIZE_* constant agrees with what the rest of the codegen computes.

Triton and CuTe paths are unchanged: pad_factory_tensors_to_power_of_2 stays True for them, so they keep allocating reduction-dim symbols and next_power_of_2-rounding the persistent-reduction extent.

Together with #2279, this unblocks tracing of examples/layer_norm.py::layer_norm_bwd on Pallas with non-power-of-2 trailing dims.

Test plan

Added test/test_examples.py::TestExamples::test_layernorm_bwd_non_pow2_dim (Pallas-only, runs in CI). Verified locally that without this PR the test fails with the broadcast error above; with this PR + #2279 it passes. All existing Pallas tests in test/test_pallas.py still pass (88 pass, 4 xfail unchanged).

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label May 5, 2026
@norx1991 norx1991 force-pushed the yifeixu/pallas-skip-factory-padding branch from 5c00665 to 75cb51b Compare May 6, 2026 20:28
norx1991 added 2 commits May 6, 2026 13:28
…ference

When a kernel does `x[tile, :]` over a host-origin tensor and follows it
with operations that combine the result with a host-allocated buffer of
the same trailing extent (e.g. `acc = x.new_zeros([n]); acc + sum_result`),
shape inference used to allocate a fresh unbacked symbol for the slice's
trailing dim while the host accumulator stayed concrete (post #2279).
The two sides could not be unified at trace time, raising broadcast
errors at the binop.

This was previously masked by the dispatch-mode padding that forced both
sides through `next_power_of_2` rounding, but is exposed once Pallas
disables that padding.

Fix: on backends that don't pad factory ops (Pallas), keep concrete int
dims concrete in `_device_indexing_size` and `SubscriptIndexing.compute_shape`.
Also route `ReductionLoopBlockSizeSource.from_config` through
`backend.static_rdim_size` so the persistent-reduction extent matches
what the rest of the codegen computes.

Triton and CuTe paths are unchanged: their `pad_factory_tensors_to_power_of_2`
property remains `True`, so they keep allocating reduction-dim symbols
and `next_power_of_2`-rounding the persistent-reduction extent.
After #2279 + this PR fix the trace-time bugs, the residual Pallas
failure on the existing shapes is a TPU alignment violation from the
hardcoded block_sizes=[32, 1] (offset arithmetic for 1D fp32 mean/rstd
which requires 128-alignment), not the original InductorLoweringError.
@norx1991 norx1991 force-pushed the yifeixu/pallas-keep-concrete-dims branch from 22a3941 to 51c1183 Compare May 6, 2026 20:28
@norx1991
Copy link
Copy Markdown
Contributor Author

norx1991 commented May 6, 2026

Closing for now — the "keep concrete dim sizes concrete" change in this PR breaks test/test_pallas.py::TestPallas::test_nested_fori_loop_scratch_scoping (ShapeMismatch: [u2, 8, 256] vs [u3, 8, 256]). The mechanism is that bypassing allocate_reduction_dimension for concrete-int slice extents perturbs symbol allocation across nested-tile broadcasting in a way I haven't fully understood yet. The u1 shape-prop bug at non-pow2 trailing dims still needs a fix, but a different shape than this. #2282 has been re-stacked directly on #2279 so the per-tile element cap fix can land independently.

@norx1991 norx1991 closed this May 6, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant