Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions helion/_compiler/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,17 @@ def experimental(self) -> bool:
"""Whether this backend is experimental and should emit a warning."""
return True

@property
def pad_factory_tensors_to_power_of_2(self) -> bool:
"""Whether on-device tensor factory ops (zeros/ones/empty/full/...) should
have their integer dim sizes rounded up to the next power of 2.

Triton requires power-of-2 block sizes, so the default is True. Pallas
does not require this and the padding causes broadcast mismatches
against unpadded full-tensor loads.
"""
return True

@property
def codegen_name(self) -> str:
"""Backend name used to look up registered codegen functions."""
Expand Down Expand Up @@ -952,6 +963,10 @@ class PallasBackend(Backend):
def name(self) -> str:
return "pallas"

@property
def pad_factory_tensors_to_power_of_2(self) -> bool:
return False

def max_reduction_threads(self) -> int | None:
return None

Expand Down
7 changes: 6 additions & 1 deletion helion/_compiler/host_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,14 @@ def __init__(
propagate_types(self)
with measure("HostFunction.finalize_config_spec"):
env.finalize_config_spec()
_factory_padding = (
patch_tensor_factories()
if env.backend.pad_factory_tensors_to_power_of_2
else contextlib.nullcontext()
)
with (
measure("HostFunction.lower_to_device_ir"),
patch_tensor_factories(),
_factory_padding,
):
self.device_ir = lower_to_device_ir(self)

Expand Down
5 changes: 4 additions & 1 deletion helion/_compiler/type_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2428,7 +2428,10 @@ def visit_For(self, node: ast.For) -> TypeInfo:
self.device_loop_depth += device_loop
_maybe_patch_tensor_factories = (
patch_tensor_factories
if self.device_loop_depth > 0
if (
self.device_loop_depth > 0
and CompileEnvironment.current().backend.pad_factory_tensors_to_power_of_2
)
else contextlib.nullcontext
)
with _maybe_patch_tensor_factories():
Expand Down
28 changes: 28 additions & 0 deletions test/test_pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,16 @@ def pallas_sum_reduce_dim0(x: torch.Tensor) -> torch.Tensor:
return out


@helion.kernel(backend="pallas", static_shapes=True)
def pallas_new_zeros_full_dim(x: torch.Tensor) -> torch.Tensor:
m, n = x.size()
out = torch.empty_like(x)
for tile_m in hl.tile(m):
zeros = x.new_zeros([n], dtype=x.dtype)
out[tile_m, :] = x[tile_m, :] + zeros[None, :]
return out


@helion.kernel(backend="pallas", static_shapes=True)
def pallas_sum_reduce_middle(x: torch.Tensor) -> torch.Tensor:
b, _n, m = x.size()
Expand Down Expand Up @@ -538,6 +548,24 @@ def test_inplace_add(self) -> None:
# x should be mutated in place
torch.testing.assert_close(x, expected)

def test_new_zeros_non_pow2_full_dim(self) -> None:
"""Tensor factories inside a tile loop must size buffers to the exact
dim, not next_power_of_2, when the dim is non-power-of-2.

Regression test: a 1D ``new_zeros([n])`` allocated inside ``hl.tile``
used to be silently rounded up to ``next_power_of_2(n)`` for all
backends. On Pallas this produced scratch shapes that did not match
the actual tensor extent, raising a JAX broadcast error like
``add got incompatible shapes for broadcasting: (block, n), (1, pow2)``
at trace time.
"""
n = 10240 # not a power of 2; next_power_of_2(10240) == 16384
x = torch.randn(4096, n, device=DEVICE, dtype=torch.bfloat16)
code, result = code_and_output(pallas_new_zeros_full_dim, (x,))
torch.testing.assert_close(result, x)
# Generated code should reference the exact dim, not its pow-2 padding.
self.assertNotIn("16384", code)

def test_pointwise_mul(self) -> None:
args = (
torch.randn(1024, device=DEVICE, dtype=torch.float32),
Expand Down
Loading