diff --git a/helion/_compiler/compile_environment.py b/helion/_compiler/compile_environment.py index 322fc77e99..fa8977dae2 100644 --- a/helion/_compiler/compile_environment.py +++ b/helion/_compiler/compile_environment.py @@ -1195,7 +1195,11 @@ def from_config(self, config: Config, block_size_info: BlockSizeInfo) -> int | N len(config.reduction_loops) <= self.reduction_loop or config.reduction_loops[self.reduction_loop] is None ): - return max(1, next_power_of_2(block_size_info.size_hint())) + size = max(1, block_size_info.size_hint()) + # Backends override static_rdim_size to control whether the + # persistent-reduction extent is rounded up to a power of two + # (Triton/CuTe) or kept exact (Pallas). + return CompileEnvironment.current().backend.static_rdim_size(size) return config.reduction_loops[self.reduction_loop] diff --git a/helion/_compiler/indexing_strategy.py b/helion/_compiler/indexing_strategy.py index b6212eabc9..f6bc18e454 100644 --- a/helion/_compiler/indexing_strategy.py +++ b/helion/_compiler/indexing_strategy.py @@ -951,8 +951,18 @@ def compute_shape( slice_size = compute_slice_size(k, size) if slice_size != 1: - rdim = env.allocate_reduction_dimension(slice_size) - output_size.append(rdim.var) + if ( + isinstance(slice_size, int) + and not env.backend.pad_factory_tensors_to_power_of_2 + ): + # On backends that don't pad factory ops to + # power-of-2, keep concrete dims concrete so shape + # inference can prove equality with concretely-sized + # buffers (matches _device_indexing_size). + output_size.append(slice_size) + else: + rdim = env.allocate_reduction_dimension(slice_size) + output_size.append(rdim.var) else: output_size.append(1) elif isinstance(k, torch.Tensor): diff --git a/helion/_compiler/type_propagation.py b/helion/_compiler/type_propagation.py index 05688e4765..c2e41a4969 100644 --- a/helion/_compiler/type_propagation.py +++ b/helion/_compiler/type_propagation.py @@ -501,6 +501,19 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]: ): output_sizes.append(output_size) continue + # On backends that don't pad factory ops to power-of-2, + # concrete int dims must stay concrete so subsequent shape + # inference can prove equality with other concretely-sized + # buffers (e.g. host-allocated accumulators via new_zeros). + # Allocating a reduction-dim block here would introduce a + # fresh unbacked symbol that does not unify with the int + # even when the hint matches. + if ( + isinstance(output_size, int) + and not env.backend.pad_factory_tensors_to_power_of_2 + ): + output_sizes.append(output_size) + continue rdim = CompileEnvironment.current().allocate_reduction_dimension( output_size ) diff --git a/test/test_examples.py b/test/test_examples.py index 5305deb1db..91965691d2 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -1211,7 +1211,10 @@ def test_layernorm_reduction_not_divisible(self): ) @xfailIfCute("CuTe LayerNorm backward example still returns incorrect results") - @xfailIfPallas("InductorLoweringError") + @xfailIfPallas( + "block_sizes=[32, 1] violates TPU 128-alignment for 1D fp32 mean/rstd; " + "shapes also have dim < 128 (escape hatch needs full-dim coverage)" + ) @skipIfA10G("accuracy check fails on A10G GPUs") def test_layernorm_bwd(self): """Test combined backward pass for layer norm with bias, including regression coverage.""" @@ -1281,6 +1284,53 @@ def test_layernorm_bwd(self): atol=atol, ) + def test_layernorm_bwd_non_pow2_dim(self): + """layer_norm_bwd traces and runs when the feature dim is not a power of 2. + + Regression test: full-dim slices like ``x[mb, :]`` used to allocate a + fresh unbacked symbol for the trailing extent, while host-side + accumulators allocated via ``new_zeros([n])`` stayed concrete. The two + sides could not be unified at trace time, raising a shape-broadcast + error at ``grad_w_acc += torch.sum(...)``. + """ + if _get_backend() != "pallas": + self.skipTest( + "Pallas-only regression: u1 mismatch on non-pow2 trailing dim" + ) + + batch_size, dim = 512, 384 # dim=384 is not a power of 2 + eps = 1e-4 + torch.manual_seed(0) + x = -2.3 + 0.5 * torch.randn([batch_size, dim], device=DEVICE, dtype=HALF_DTYPE) + weight = torch.randn([dim], device=DEVICE, dtype=HALF_DTYPE) + bias = torch.randn([dim], device=DEVICE, dtype=HALF_DTYPE) + grad_out = torch.randn([batch_size, dim], device=DEVICE, dtype=HALF_DTYPE) + + x_fp32 = x.to(torch.float32) + mean = x_fp32.mean(dim=-1) + rstd = torch.rsqrt(x_fp32.var(dim=-1, unbiased=False) + eps) + + x_ref = x.clone().detach().requires_grad_(True) + w_ref = weight.clone().detach().requires_grad_(True) + b_ref = bias.clone().detach().requires_grad_(True) + torch.nn.functional.layer_norm(x_ref, [dim], w_ref, b_ref, eps).backward( + grad_out + ) + expected = (x_ref.grad.detach(), w_ref.grad.detach(), b_ref.grad.detach()) + + # Tolerances are loose because the M-axis reduction in bf16 is + # inherently noisy at this batch size; this test guards trace-time + # shape inference for non-pow2 trailing dims, not numerical fidelity. + check_example( + "layer_norm", + (grad_out, x, mean, rstd, weight, True), + expected, + fn_name="layer_norm_bwd", + block_sizes=[128, 128], + rtol=1.0, + atol=1.0, + ) + def test_softmax_bwd(self): m, n = 2048, 2048 x = torch.randn([m, n], device=DEVICE, dtype=torch.bfloat16, requires_grad=True)