From 2aaaebae1fb779ebc340db10f7a76fe408b536e5 Mon Sep 17 00:00:00 2001 From: Dunfan Lu Date: Sat, 2 May 2026 06:05:46 +0000 Subject: [PATCH] [Pallas] When there are data-dependent loop bounds, also use fori_loop instead of unroll stack-info: PR: https://github.com/pytorch/helion/pull/2212, branch: AmesingFlank/stack/34 --- helion/autotuner/config_spec.py | 6 +++--- helion/language/loops.py | 4 ++-- test/test_pallas.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 33 insertions(+), 5 deletions(-) diff --git a/helion/autotuner/config_spec.py b/helion/autotuner/config_spec.py index 730c57e43d..c2b99b27ed 100644 --- a/helion/autotuner/config_spec.py +++ b/helion/autotuner/config_spec.py @@ -281,7 +281,7 @@ def __init__( self.epilogue_subtile_autotune_choices: tuple[int | None, ...] | None = None self.epilogue_subtile_k_hint: int = 0 self.has_pallas_inner_loops: bool = False - self.has_pallas_symbolic_bounds: bool = False + self.has_symbolic_or_data_dependent_bounds: bool = False self.cute_tcgen05_search_enabled: bool = False # Allowed values of tcgen05_cluster_m the autotuner is allowed to # *search* over. None means "use the default set defined by @@ -813,7 +813,7 @@ def normalize( f"{key} is only supported for tcgen05-enabled CuTe matmul kernels" ) if self.has_pallas_inner_loops: - if self.has_pallas_symbolic_bounds: + if self.has_symbolic_or_data_dependent_bounds: # "unroll" uses Python range() which can't handle traced bounds. # Between the remaining options, prefer "fori_loop": it handles # both DMA-aligned and unaligned inner blocks, while @@ -1162,7 +1162,7 @@ def _flat_fields( fields.update(self.backend_tunable_fragments) if self.has_pallas_inner_loops: choices = VALID_PALLAS_LOOP_TYPES - if self.has_pallas_symbolic_bounds: + if self.has_symbolic_or_data_dependent_bounds: # Exclude "unroll" (uses Python range(), can't handle traced # bounds) and put "fori_loop" first: it handles both DMA-aligned # and unaligned inner blocks, while "emit_pipeline" fails on diff --git a/helion/language/loops.py b/helion/language/loops.py index c2bceceed9..83fe4ace45 100644 --- a/helion/language/loops.py +++ b/helion/language/loops.py @@ -435,8 +435,8 @@ def _add_config_choices( ): _add_config_range_choice([block_id], allow_static_range=allow_static_range) - if has_symbolic_bounds and config_spec.backend_name == "pallas": - config_spec.has_pallas_symbolic_bounds = True + if has_symbolic_bounds or has_data_dependent_bounds: + config_spec.has_symbolic_or_data_dependent_bounds = True def _add_config_range_choice( diff --git a/test/test_pallas.py b/test/test_pallas.py index c371c233b3..dd399e904f 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -1987,6 +1987,34 @@ def double_use(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: ref = (acc_ref + running[:, :, None] * running[:, :, None]).to(a.dtype) torch.testing.assert_close(result, ref, rtol=1e-2, atol=1e-2) + def test_data_dependent_loop_bounds(self) -> None: + """Data-dependent loop: hl.tile(0, n) where n comes from a tensor.""" + + @helion.kernel(backend="pallas", static_shapes=True) + def data_dependent_sum( + data: torch.Tensor, lengths: torch.Tensor + ) -> torch.Tensor: + B = lengths.size(0) + out = torch.zeros([B], dtype=data.dtype, device=data.device) + for seg in hl.grid(B): + n = lengths[seg] + acc = hl.zeros([1], dtype=data.dtype) + for tile in hl.tile(0, n): + acc = acc + data[tile].sum(dim=0).unsqueeze(0) + out[seg] = acc.squeeze(0) + return out + + N = 256 + B = 4 + data = torch.randn(N, device=DEVICE, dtype=torch.float32) + lengths = torch.tensor([128, 256, 128, 256], device=DEVICE, dtype=torch.int32) + code, result = code_and_output( + data_dependent_sum, + (data, lengths), + ) + ref = torch.stack([data[: lengths[i]].sum() for i in range(B)]) + torch.testing.assert_close(result, ref, rtol=1e-4, atol=1e-4) + @skipUnlessPallas("JAX/Pallas TPU not available") class TestPallasIndirectGather(TestCase):