diff --git a/helion/_compiler/backend.py b/helion/_compiler/backend.py
index 88e352caf8..caff12e4a9 100644
--- a/helion/_compiler/backend.py
+++ b/helion/_compiler/backend.py
@@ -56,6 +56,21 @@ def experimental(self) -> bool:
         """Whether this backend is experimental and should emit a warning."""
         return True
 
+    @property
+    def max_tensor_numel(self) -> int | None:
+        """Per-tile maximum tensor element count enforced during config search.
+
+        Triton has a hard internal ceiling (currently 2**20) past which its
+        codegen rejects the kernel, so the search must avoid generating
+        configs that exceed it. Pallas/Mosaic has no analogous compile-time
+        cap; tile size is bounded by VMEM bytes (already guarded at runtime
+        in :mod:`helion.runtime`). Backends that don't need the cap should
+        return ``None`` to disable the constraint.
+        """
+        from ..autotuner.config_generation import TRITON_MAX_TENSOR_NUMEL
+
+        return TRITON_MAX_TENSOR_NUMEL
+
     @property
     def codegen_name(self) -> str:
         """Backend name used to look up registered codegen functions."""
@@ -953,6 +968,12 @@ class PallasBackend(Backend):
     def name(self) -> str:
         return "pallas"
 
+    @property
+    def max_tensor_numel(self) -> int | None:
+        # No compile-time element cap on Pallas; VMEM byte budget is the
+        # real constraint and is enforced separately at runtime.
+        return None
+
     def max_reduction_threads(self) -> int | None:
         return None
 
diff --git a/helion/_compiler/compile_environment.py b/helion/_compiler/compile_environment.py
index 322fc77e99..daaf565b20 100644
--- a/helion/_compiler/compile_environment.py
+++ b/helion/_compiler/compile_environment.py
@@ -272,9 +272,14 @@ def finalize_config_spec(self) -> None:
 
     def _extract_tensor_numel_constraints(self) -> None:
         """Compile per-tensor numel constraints from kernel_tensor_sizes."""
-        from ..autotuner.config_generation import TRITON_MAX_TENSOR_NUMEL
         from ..autotuner.config_spec import TensorNumelConstraint
 
+        max_numel = self.backend.max_tensor_numel
+        if max_numel is None:
+            # Backend (e.g. Pallas) has no compile-time per-tile element cap;
+            # VMEM byte budget is enforced separately at runtime.
+            return
+
         block_sym_to_id: dict[sympy.Symbol, int] = {}
         for bs in self.block_sizes:
             block_sym_to_id[bs.symbol()] = bs.block_id
@@ -316,7 +321,7 @@ def _extract_tensor_numel_constraints(self) -> None:
             ordered = sorted(involved_syms, key=lambda s: sym_to_cs_idx[s])
             indices = tuple(sym_to_cs_idx[s] for s in ordered)
             # pyrefly: ignore[unsupported-operation]
-            constraint_expr = numel_expr <= TRITON_MAX_TENSOR_NUMEL
+            constraint_expr = numel_expr <= max_numel
             # srepr is more canonical than str() for dedup; a false
             # negative only causes a harmless duplicate, not a missed one.
             dedup_key = sympy.srepr(constraint_expr)
diff --git a/helion/autotuner/config_generation.py b/helion/autotuner/config_generation.py
index e6a4ab3fc1..2822a956ac 100644
--- a/helion/autotuner/config_generation.py
+++ b/helion/autotuner/config_generation.py
@@ -351,10 +351,21 @@ def shrink_config(
         if self.num_warps_index < 0 or not self.block_size_indices:
             return
         num_threads = warps_to_threads(cast("int", flat_config[self.num_warps_index]))
-        # Respect Triton's maximum tensor element limit
-        triton_limit = TRITON_MAX_TENSOR_NUMEL
+        # Respect the backend's per-tile element ceiling (Triton: 2**20;
+        # Pallas: None, since the real bound is VMEM bytes). Unit-test
+        # callers may invoke shrink_config without an active environment;
+        # default to the Triton limit in that case.
+        from .._compiler.compile_environment import CompileEnvironment
+
+        backend_limit: int | None = TRITON_MAX_TENSOR_NUMEL
+        if CompileEnvironment.has_current():
+            backend_limit = CompileEnvironment.current().backend.max_tensor_numel
         theoretical_max_elements = max_elements_per_thread * num_threads
-        max_elements = min(theoretical_max_elements, triton_limit)
+        max_elements = (
+            theoretical_max_elements
+            if backend_limit is None
+            else min(theoretical_max_elements, backend_limit)
+        )
         while self.block_numel(flat_config) > max_elements:
             changes = 0
             for i in self.block_size_indices:
diff --git a/test/test_tensor_numel_constraints.py b/test/test_tensor_numel_constraints.py
index b6541f3482..11070242d9 100644
--- a/test/test_tensor_numel_constraints.py
+++ b/test/test_tensor_numel_constraints.py
@@ -295,6 +295,7 @@ def test_mixed_symbol_shape_skipped(self) -> None:
 
         # Minimal mock of CompileEnvironment for _extract_tensor_numel_constraints
         env = object.__new__(CompileEnvironment)
+        env._backend = SimpleNamespace(max_tensor_numel=TRITON_MAX_TENSOR_NUMEL)
         env.block_sizes = [SimpleNamespace(symbol=lambda: b0, block_id=0)]
         env.config_spec = SimpleNamespace(
             block_sizes=SimpleNamespace(
@@ -317,6 +318,30 @@ def test_mixed_symbol_shape_skipped(self) -> None:
         )
 
 
+class TestBackendMaxTensorNumel(unittest.TestCase):
+    """Backends without a per-tile element cap (Pallas) skip the constraint."""
+
+    def test_backend_with_no_cap_extracts_no_constraints(self) -> None:
+        """When backend.max_tensor_numel is None, _extract_tensor_numel_constraints
+        emits no constraints regardless of how large the tile would be."""
+        from types import SimpleNamespace
+
+        from helion._compiler.compile_environment import CompileEnvironment
+
+        b0 = sympy.Symbol("u0", integer=True)
+        env = object.__new__(CompileEnvironment)
+        env._backend = SimpleNamespace(max_tensor_numel=None)
+        env.block_sizes = [SimpleNamespace(symbol=lambda: b0, block_id=0)]
+        env.config_spec = SimpleNamespace(
+            block_sizes=SimpleNamespace(block_id_to_index=lambda bid: bid),
+            tensor_numel_constraints=[],
+        )
+        # A tile that would otherwise trigger the cap (b0 * 16384).
+        env.kernel_tensor_sizes = [[b0, sympy.Integer(16384)]]
+        env._extract_tensor_numel_constraints()
+        self.assertEqual(env.config_spec.tensor_numel_constraints, [])
+
+
 class TestFixedPointOverlapping(unittest.TestCase):
     """Verify the fixed-point loop handles overlapping constraints."""
 