Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions helion/autotuner/config_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def __init__(
self.epilogue_subtile_autotune_choices: tuple[int | None, ...] | None = None
self.epilogue_subtile_k_hint: int = 0
self.has_pallas_inner_loops: bool = False
self.has_pallas_symbolic_bounds: bool = False
self.has_symbolic_or_data_dependent_bounds: bool = False
self.cute_tcgen05_search_enabled: bool = False
# Allowed values of tcgen05_cluster_m the autotuner is allowed to
# *search* over. None means "use the default set defined by
Expand Down Expand Up @@ -813,7 +813,7 @@ def normalize(
f"{key} is only supported for tcgen05-enabled CuTe matmul kernels"
)
if self.has_pallas_inner_loops:
if self.has_pallas_symbolic_bounds:
if self.has_symbolic_or_data_dependent_bounds:
# "unroll" uses Python range() which can't handle traced bounds.
# Between the remaining options, prefer "fori_loop": it handles
# both DMA-aligned and unaligned inner blocks, while
Expand Down Expand Up @@ -1162,7 +1162,7 @@ def _flat_fields(
fields.update(self.backend_tunable_fragments)
if self.has_pallas_inner_loops:
choices = VALID_PALLAS_LOOP_TYPES
if self.has_pallas_symbolic_bounds:
if self.has_symbolic_or_data_dependent_bounds:
# Exclude "unroll" (uses Python range(), can't handle traced
# bounds) and put "fori_loop" first: it handles both DMA-aligned
# and unaligned inner blocks, while "emit_pipeline" fails on
Expand Down
4 changes: 2 additions & 2 deletions helion/language/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,8 +435,8 @@ def _add_config_choices(
):
_add_config_range_choice([block_id], allow_static_range=allow_static_range)

if has_symbolic_bounds and config_spec.backend_name == "pallas":
config_spec.has_pallas_symbolic_bounds = True
if has_symbolic_bounds or has_data_dependent_bounds:
config_spec.has_symbolic_or_data_dependent_bounds = True


def _add_config_range_choice(
Expand Down
28 changes: 28 additions & 0 deletions test/test_pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -1987,6 +1987,34 @@ def double_use(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
ref = (acc_ref + running[:, :, None] * running[:, :, None]).to(a.dtype)
torch.testing.assert_close(result, ref, rtol=1e-2, atol=1e-2)

def test_data_dependent_loop_bounds(self) -> None:
"""Data-dependent loop: hl.tile(0, n) where n comes from a tensor."""

@helion.kernel(backend="pallas", static_shapes=True)
def data_dependent_sum(
data: torch.Tensor, lengths: torch.Tensor
) -> torch.Tensor:
B = lengths.size(0)
out = torch.zeros([B], dtype=data.dtype, device=data.device)
for seg in hl.grid(B):
n = lengths[seg]
acc = hl.zeros([1], dtype=data.dtype)
for tile in hl.tile(0, n):
acc = acc + data[tile].sum(dim=0).unsqueeze(0)
out[seg] = acc.squeeze(0)
return out

N = 256
B = 4
data = torch.randn(N, device=DEVICE, dtype=torch.float32)
lengths = torch.tensor([128, 256, 128, 256], device=DEVICE, dtype=torch.int32)
code, result = code_and_output(
data_dependent_sum,
(data, lengths),
)
ref = torch.stack([data[: lengths[i]].sum() for i in range(B)])
torch.testing.assert_close(result, ref, rtol=1e-4, atol=1e-4)


@skipUnlessPallas("JAX/Pallas TPU not available")
class TestPallasIndirectGather(TestCase):
Expand Down
Loading