# Patch summary (reviewer notes; these "#" lines sit before the first
# "diff --git" header and belong to no hunk):
# - backend.py: Backend gains a `pad_factory_tensors_to_power_of_2` property
#   that returns True; PallasBackend overrides it to return False.
# - host_function.py (__init__): patch_tensor_factories() is entered only when
#   env.backend.pad_factory_tensors_to_power_of_2 is true, otherwise a
#   contextlib.nullcontext() is used in its place.
# - type_propagation.py (visit_For): the existing device_loop_depth > 0 check
#   is and-ed with the same backend property.
# - test_pallas.py: adds a Pallas kernel and a regression test using a
#   non-power-of-2 dim (n = 10240).
# NOTE(review): whitespace below was reconstructed from a collapsed copy of
# the patch; confirm context-line indentation against the target files.
diff --git a/helion/_compiler/backend.py b/helion/_compiler/backend.py
index fed5c90c4..08bce7806 100644
--- a/helion/_compiler/backend.py
+++ b/helion/_compiler/backend.py
@@ -55,6 +55,17 @@ def experimental(self) -> bool:
         """Whether this backend is experimental and should emit a warning."""
         return True
 
+    @property
+    def pad_factory_tensors_to_power_of_2(self) -> bool:
+        """Whether on-device tensor factory ops (zeros/ones/empty/full/...) should
+        have their integer dim sizes rounded up to the next power of 2.
+
+        Triton requires power-of-2 block sizes, so the default is True. Pallas
+        does not require this and the padding causes broadcast mismatches
+        against unpadded full-tensor loads.
+        """
+        return True
+
     @property
     def codegen_name(self) -> str:
         """Backend name used to look up registered codegen functions."""
@@ -952,6 +963,10 @@ class PallasBackend(Backend):
     def name(self) -> str:
         return "pallas"
 
+    @property
+    def pad_factory_tensors_to_power_of_2(self) -> bool:
+        return False  # overrides the base-class default of True
+
     def max_reduction_threads(self) -> int | None:
         return None
 
diff --git a/helion/_compiler/host_function.py b/helion/_compiler/host_function.py
index cdf31f67b..ea8dc3e3a 100644
--- a/helion/_compiler/host_function.py
+++ b/helion/_compiler/host_function.py
@@ -151,9 +151,14 @@ def __init__(
                 propagate_types(self)
             with measure("HostFunction.finalize_config_spec"):
                 env.finalize_config_spec()
+            _factory_padding = (
+                patch_tensor_factories()
+                if env.backend.pad_factory_tensors_to_power_of_2
+                else contextlib.nullcontext()  # no-op stand-in: factories unpatched
+            )
             with (
                 measure("HostFunction.lower_to_device_ir"),
-                patch_tensor_factories(),
+                _factory_padding,
             ):
                 self.device_ir = lower_to_device_ir(self)
 
diff --git a/helion/_compiler/type_propagation.py b/helion/_compiler/type_propagation.py
index 8d5c03a62..05688e476 100644
--- a/helion/_compiler/type_propagation.py
+++ b/helion/_compiler/type_propagation.py
@@ -2428,7 +2428,10 @@ def visit_For(self, node: ast.For) -> TypeInfo:
             self.device_loop_depth += device_loop
             _maybe_patch_tensor_factories = (
                 patch_tensor_factories
-                if self.device_loop_depth > 0
+                if (  # same depth check as before, plus the backend opt-in
+                    self.device_loop_depth > 0
+                    and CompileEnvironment.current().backend.pad_factory_tensors_to_power_of_2
+                )
                 else contextlib.nullcontext
             )
             with _maybe_patch_tensor_factories():
diff --git a/test/test_pallas.py b/test/test_pallas.py
index 43ff071b2..25a391275 100644
--- a/test/test_pallas.py
+++ b/test/test_pallas.py
@@ -149,6 +149,16 @@ def pallas_sum_reduce_dim0(x: torch.Tensor) -> torch.Tensor:
     return out
 
 
+@helion.kernel(backend="pallas", static_shapes=True)
+def pallas_new_zeros_full_dim(x: torch.Tensor) -> torch.Tensor:
+    m, n = x.size()
+    out = torch.empty_like(x)
+    for tile_m in hl.tile(m):
+        zeros = x.new_zeros([n], dtype=x.dtype)  # sized by full dim n, not tile_m
+        out[tile_m, :] = x[tile_m, :] + zeros[None, :]
+    return out
+
+
 @helion.kernel(backend="pallas", static_shapes=True)
 def pallas_sum_reduce_middle(x: torch.Tensor) -> torch.Tensor:
     b, _n, m = x.size()
@@ -538,6 +548,24 @@ def test_inplace_add(self) -> None:
         # x should be mutated in place
         torch.testing.assert_close(x, expected)
 
+    def test_new_zeros_non_pow2_full_dim(self) -> None:
+        """Tensor factories inside a tile loop must size buffers to the exact
+        dim, not next_power_of_2, when the dim is non-power-of-2.
+
+        Regression test: a 1D ``new_zeros([n])`` allocated inside ``hl.tile``
+        used to be silently rounded up to ``next_power_of_2(n)`` for all
+        backends. On Pallas this produced scratch shapes that did not match
+        the actual tensor extent, raising a JAX broadcast error like
+        ``add got incompatible shapes for broadcasting: (block, n), (1, pow2)``
+        at trace time.
+        """
+        n = 10240  # not a power of 2; next_power_of_2(10240) == 16384
+        x = torch.randn(4096, n, device=DEVICE, dtype=torch.bfloat16)
+        code, result = code_and_output(pallas_new_zeros_full_dim, (x,))
+        torch.testing.assert_close(result, x)
+        # Generated code should reference the exact dim, not its pow-2 padding.
+        self.assertNotIn("16384", code)
+
     def test_pointwise_mul(self) -> None:
         args = (
             torch.randn(1024, device=DEVICE, dtype=torch.float32),