[Pallas] Keep concrete dim sizes concrete in slice/reduction shape inference#2280
Closed
norx1991 wants to merge 2 commits into
5c00665 to 75cb51b
…ference

When a kernel does `x[tile, :]` over a host-origin tensor and follows it with operations that combine the result with a host-allocated buffer of the same trailing extent (e.g. `acc = x.new_zeros([n]); acc + sum_result`), shape inference used to allocate a fresh unbacked symbol for the slice's trailing dim while the host accumulator stayed concrete (post #2279). The two sides could not be unified at trace time, raising broadcast errors at the binop. This was previously masked by the dispatch-mode padding that forced both sides through `next_power_of_2` rounding, but is exposed once Pallas disables that padding.

Fix: on backends that don't pad factory ops (Pallas), keep concrete int dims concrete in `_device_indexing_size` and `SubscriptIndexing.compute_shape`. Also route `ReductionLoopBlockSizeSource.from_config` through `backend.static_rdim_size` so the persistent-reduction extent matches what the rest of the codegen computes.

Triton and CuTe paths are unchanged: their `pad_factory_tensors_to_power_of_2` property remains `True`, so they keep allocating reduction-dim symbols and `next_power_of_2`-rounding the persistent-reduction extent.
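To make the unification failure concrete, here is a minimal toy model in plain Python. It is not the project's shape-inference code: `Sym`, `fresh_symbol`, and `broadcast_dim` are hypothetical names invented for this sketch, and the broadcasting rule is deliberately simplified to "equal, or one side is 1".

```python
from dataclasses import dataclass
from itertools import count

_ids = count()


@dataclass(frozen=True)
class Sym:
    """Hypothetical stand-in for an unbacked symbolic dim (value unknown at trace time)."""
    name: str


def fresh_symbol() -> Sym:
    # Every call yields a distinct symbol, so two fresh symbols never compare equal.
    return Sym(f"u{next(_ids)}")


def broadcast_dim(a, b):
    """Unify two dims under a simplified broadcasting rule, or raise."""
    if a == b:
        return a
    if a == 1:
        return b
    if b == 1:
        return a
    raise ValueError(f"cannot broadcast {a!r} with {b!r}")


n = 384                      # non-power-of-2 trailing extent
acc_dim = n                  # host accumulator keeps its concrete size
slice_dim = fresh_symbol()   # old behavior: slice gets a fresh symbol

try:
    broadcast_dim(acc_dim, slice_dim)
    mismatch = False
except ValueError:
    mismatch = True          # concrete 384 cannot be proven equal to the symbol

# Behavior after the fix, in this toy model: the slice dim stays concrete.
unified = broadcast_dim(acc_dim, n)
```

In this model the failure comes only from mixing a concrete extent with a symbol that carries no known value; keeping both sides concrete removes the need to prove anything.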
Once #2279 and this PR fix the trace-time bugs, the residual Pallas failure on the existing shapes is a TPU alignment violation caused by the hardcoded `block_sizes=[32, 1]` (the offset arithmetic for the 1D fp32 mean/rstd tensors requires 128-alignment), not the original `InductorLoweringError`.
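As a rough illustration of why a block size of 32 conflicts with a 128-alignment requirement, the sketch below lists the block start offsets that are not multiples of 128. The helper `misaligned_offsets` is hypothetical, and treating the constraint as "every block start offset must be a multiple of 128" is a simplifying assumption, not a statement of the actual TPU lowering rules.

```python
def misaligned_offsets(block_size: int, extent: int, alignment: int = 128) -> list[int]:
    """Return block start offsets in [0, extent) that are not multiples of `alignment`.

    Assumption for this sketch: blocks start at 0, block_size, 2 * block_size, ...
    """
    return [off for off in range(0, extent, block_size) if off % alignment != 0]


# With a block size of 32, most block starts fall between 128-aligned boundaries.
bad_32 = misaligned_offsets(block_size=32, extent=256)

# With a block size of 128, every block start is aligned.
bad_128 = misaligned_offsets(block_size=128, extent=256)
```

Under this assumption, any block size that is not a multiple of 128 produces at least one misaligned start once the extent exceeds one block.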
22a3941 to 51c1183
Closing for now — the "keep concrete dim sizes concrete" change in this PR breaks |
Summary
Stacked on #2279.
When a Pallas kernel does `x[tile, :]` over a host-origin tensor and then combines the result with a host-allocated buffer of the same trailing extent (e.g. `acc = x.new_zeros([n]); acc + sum_result`), shape inference used to allocate a fresh unbacked symbol for the slice's trailing dim while the host accumulator stayed concrete (post-#2279). The two sides could not be unified at trace time, raising broadcast errors at the binop.

This was previously masked by the dispatch-mode padding (`_PadTensorFactoryMode`) that forced both sides through `next_power_of_2` rounding so they happened to share an unbacked symbol. Disabling that padding for Pallas (#2279) exposes the trace-time mismatch.

This PR:

- `_device_indexing_size` (`type_propagation.py`) and `SubscriptIndexing.compute_shape` (`indexing_strategy.py`): on backends that don't pad factory ops to power-of-2, keep concrete `int` slice extents concrete instead of allocating a fresh reduction-dim symbol.
- `ReductionLoopBlockSizeSource.from_config` (`compile_environment.py`): route the persistent-reduction extent through `backend.static_rdim_size` so the host-side `_RDIM_SIZE_*` constant agrees with what the rest of the codegen computes.

Triton and CuTe paths are unchanged: `pad_factory_tensors_to_power_of_2` stays `True` for them, so they keep allocating reduction-dim symbols and `next_power_of_2`-rounding the persistent-reduction extent.

Together with #2279, this unblocks tracing of `examples/layer_norm.py::layer_norm_bwd` on Pallas with non-power-of-2 trailing dims.

Test plan

Added `test/test_examples.py::TestExamples::test_layernorm_bwd_non_pow2_dim` (Pallas-only, runs in CI). Verified locally that without this PR the test fails with the broadcast error above; with this PR + #2279 it passes. All existing Pallas tests in `test/test_pallas.py` still pass (88 pass, 4 xfail unchanged).
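The backend-dependent behavior described in the summary can be sketched as follows. `ToyBackend` is a hypothetical class written for illustration; it borrows the names `pad_factory_tensors_to_power_of_2` and `static_rdim_size` from this PR, but the signature and body shown here are assumptions, not the project's actual implementation.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two that is >= n (for n >= 1)."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


class ToyBackend:
    """Hypothetical backend model: only the padding flag matters for this sketch."""

    def __init__(self, pad_factory_tensors_to_power_of_2: bool) -> None:
        self.pad_factory_tensors_to_power_of_2 = pad_factory_tensors_to_power_of_2

    def static_rdim_size(self, size: int) -> int:
        # Assumed behavior: padding backends round the reduction extent up,
        # non-padding backends keep the concrete extent unchanged.
        if self.pad_factory_tensors_to_power_of_2:
            return next_power_of_2(size)
        return size


padding_backend = ToyBackend(pad_factory_tensors_to_power_of_2=True)      # Triton/CuTe-like
non_padding_backend = ToyBackend(pad_factory_tensors_to_power_of_2=False)  # Pallas-like

n = 384  # non-power-of-2 trailing dim
padded_extent = padding_backend.static_rdim_size(n)
exact_extent = non_padding_backend.static_rdim_size(n)
```

Routing every consumer of the reduction extent through one such method is what keeps the host-side constant and the generated code in agreement: both see either the rounded size or the exact size, never a mix.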