Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion helion/_compiler/compile_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -1195,7 +1195,11 @@ def from_config(self, config: Config, block_size_info: BlockSizeInfo) -> int | N
len(config.reduction_loops) <= self.reduction_loop
or config.reduction_loops[self.reduction_loop] is None
):
return max(1, next_power_of_2(block_size_info.size_hint()))
size = max(1, block_size_info.size_hint())
# Backends override static_rdim_size to control whether the
# persistent-reduction extent is rounded up to a power of two
# (Triton/CuTe) or kept exact (Pallas).
return CompileEnvironment.current().backend.static_rdim_size(size)
return config.reduction_loops[self.reduction_loop]


Expand Down
14 changes: 12 additions & 2 deletions helion/_compiler/indexing_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,8 +951,18 @@ def compute_shape(
slice_size = compute_slice_size(k, size)

if slice_size != 1:
rdim = env.allocate_reduction_dimension(slice_size)
output_size.append(rdim.var)
if (
isinstance(slice_size, int)
and not env.backend.pad_factory_tensors_to_power_of_2
):
# On backends that don't pad factory ops to
# power-of-2, keep concrete dims concrete so shape
# inference can prove equality with concretely-sized
# buffers (matches _device_indexing_size).
output_size.append(slice_size)
else:
rdim = env.allocate_reduction_dimension(slice_size)
output_size.append(rdim.var)
else:
output_size.append(1)
elif isinstance(k, torch.Tensor):
Expand Down
13 changes: 13 additions & 0 deletions helion/_compiler/type_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,19 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
):
output_sizes.append(output_size)
continue
# On backends that don't pad factory ops to power-of-2,
# concrete int dims must stay concrete so subsequent shape
# inference can prove equality with other concretely-sized
# buffers (e.g. host-allocated accumulators via new_zeros).
# Allocating a reduction-dim block here would introduce a
# fresh unbacked symbol that does not unify with the int
# even when the hint matches.
if (
isinstance(output_size, int)
and not env.backend.pad_factory_tensors_to_power_of_2
):
output_sizes.append(output_size)
continue
rdim = CompileEnvironment.current().allocate_reduction_dimension(
output_size
)
Expand Down
52 changes: 51 additions & 1 deletion test/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -1211,7 +1211,10 @@ def test_layernorm_reduction_not_divisible(self):
)

@xfailIfCute("CuTe LayerNorm backward example still returns incorrect results")
@xfailIfPallas("InductorLoweringError")
@xfailIfPallas(
"block_sizes=[32, 1] violates TPU 128-alignment for 1D fp32 mean/rstd; "
"shapes also have dim < 128 (escape hatch needs full-dim coverage)"
)
@skipIfA10G("accuracy check fails on A10G GPUs")
def test_layernorm_bwd(self):
"""Test combined backward pass for layer norm with bias, including regression coverage."""
Expand Down Expand Up @@ -1281,6 +1284,53 @@ def test_layernorm_bwd(self):
atol=atol,
)

def test_layernorm_bwd_non_pow2_dim(self):
    """Run layer_norm_bwd with a feature dimension that is not a power of 2.

    Regression coverage for trace-time shape inference: a full-dim slice
    such as ``x[mb, :]`` used to get a fresh unbacked symbol for its
    trailing extent, whereas host-side accumulators created through
    ``new_zeros([n])`` kept a concrete size. Those two extents could not be
    unified while tracing, so ``grad_w_acc += torch.sum(...)`` failed with
    a shape-broadcast error.

    The test is skipped on every backend other than Pallas.
    """
    if _get_backend() != "pallas":
        self.skipTest(
            "Pallas-only regression: u1 mismatch on non-pow2 trailing dim"
        )

    eps = 1e-4
    n_rows = 512
    n_features = 384  # deliberately not a power of 2
    matrix_shape = [n_rows, n_features]
    vector_shape = [n_features]

    # Seed once, then draw x, weight, bias, grad_out in exactly this order
    # so the generated inputs are reproducible.
    torch.manual_seed(0)
    x = -2.3 + 0.5 * torch.randn(matrix_shape, device=DEVICE, dtype=HALF_DTYPE)
    weight = torch.randn(vector_shape, device=DEVICE, dtype=HALF_DTYPE)
    bias = torch.randn(vector_shape, device=DEVICE, dtype=HALF_DTYPE)
    grad_out = torch.randn(matrix_shape, device=DEVICE, dtype=HALF_DTYPE)

    # Forward statistics handed to the backward kernel, computed in fp32
    # over the last dimension.
    x_fp32 = x.to(torch.float32)
    mean = x_fp32.mean(dim=-1)
    variance = x_fp32.var(dim=-1, unbiased=False)
    rstd = torch.rsqrt(variance + eps)

    # Reference gradients from eager autograd on detached leaf copies.
    ref_leaves = [
        t.clone().detach().requires_grad_(True) for t in (x, weight, bias)
    ]
    x_ref, w_ref, b_ref = ref_leaves
    ref_out = torch.nn.functional.layer_norm(
        x_ref, vector_shape, w_ref, b_ref, eps
    )
    ref_out.backward(grad_out)
    expected = tuple(leaf.grad.detach() for leaf in ref_leaves)

    # The tolerances are intentionally loose: the half-precision reduction
    # over the row axis is noisy at this batch size, and this test guards
    # trace-time shape inference for non-pow2 trailing dims rather than
    # numerical fidelity.
    kernel_args = (grad_out, x, mean, rstd, weight, True)
    check_example(
        "layer_norm",
        kernel_args,
        expected,
        fn_name="layer_norm_bwd",
        block_sizes=[128, 128],
        rtol=1.0,
        atol=1.0,
    )

def test_softmax_bwd(self):
m, n = 2048, 2048
x = torch.randn([m, n], device=DEVICE, dtype=torch.bfloat16, requires_grad=True)
Expand Down
Loading