Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions helion/_compiler/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,21 @@ def experimental(self) -> bool:
"""Whether this backend is experimental and should emit a warning."""
return True

@property
def max_tensor_numel(self) -> int | None:
    """Return the per-tile element-count ceiling applied during config search.

    The default is Triton's limit: Triton's codegen rejects kernels whose
    tensors exceed a hard internal ceiling (currently 2**20 elements), so
    the search must not generate configs beyond it. Pallas/Mosaic has no
    comparable compile-time cap; there, tile size is bounded by VMEM bytes,
    which is already guarded at runtime in :mod:`helion.runtime`.

    Returns:
        The maximum number of elements a tile may hold. Backends that do
        not need the cap should override this and return ``None`` to
        disable the constraint.
    """
    # Imported lazily inside the property rather than at module import time.
    from ..autotuner.config_generation import (
        TRITON_MAX_TENSOR_NUMEL as triton_numel_ceiling,
    )

    return triton_numel_ceiling

@property
def codegen_name(self) -> str:
"""Backend name used to look up registered codegen functions."""
Expand Down Expand Up @@ -953,6 +968,12 @@ class PallasBackend(Backend):
def name(self) -> str:
return "pallas"

@property
def max_tensor_numel(self) -> int | None:
# No compile-time element cap on Pallas; VMEM byte budget is the
# real constraint and is enforced separately at runtime.
return None

def max_reduction_threads(self) -> int | None:
return None

Expand Down
9 changes: 7 additions & 2 deletions helion/_compiler/compile_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,14 @@ def finalize_config_spec(self) -> None:

def _extract_tensor_numel_constraints(self) -> None:
"""Compile per-tensor numel constraints from kernel_tensor_sizes."""
from ..autotuner.config_generation import TRITON_MAX_TENSOR_NUMEL
from ..autotuner.config_spec import TensorNumelConstraint

max_numel = self.backend.max_tensor_numel
if max_numel is None:
# Backend (e.g. Pallas) has no compile-time per-tile element cap;
# VMEM byte budget is enforced separately at runtime.
return None

block_sym_to_id: dict[sympy.Symbol, int] = {}
for bs in self.block_sizes:
block_sym_to_id[bs.symbol()] = bs.block_id
Expand Down Expand Up @@ -316,7 +321,7 @@ def _extract_tensor_numel_constraints(self) -> None:
ordered = sorted(involved_syms, key=lambda s: sym_to_cs_idx[s])
indices = tuple(sym_to_cs_idx[s] for s in ordered)
# pyrefly: ignore[unsupported-operation]
constraint_expr = numel_expr <= TRITON_MAX_TENSOR_NUMEL
constraint_expr = numel_expr <= max_numel
# srepr is more canonical than str() for dedup; a false
# negative only causes a harmless duplicate, not a missed one.
dedup_key = sympy.srepr(constraint_expr)
Expand Down
17 changes: 14 additions & 3 deletions helion/autotuner/config_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,10 +351,21 @@ def shrink_config(
if self.num_warps_index < 0 or not self.block_size_indices:
return
num_threads = warps_to_threads(cast("int", flat_config[self.num_warps_index]))
# Respect Triton's maximum tensor element limit
triton_limit = TRITON_MAX_TENSOR_NUMEL
# Respect the backend's per-tile element ceiling (Triton: 2**20;
# Pallas: None, since the real bound is VMEM bytes). Unit-test
# callers may invoke shrink_config without an active environment;
# default to the Triton limit in that case.
from .._compiler.compile_environment import CompileEnvironment

backend_limit: int | None = TRITON_MAX_TENSOR_NUMEL
if CompileEnvironment.has_current():
backend_limit = CompileEnvironment.current().backend.max_tensor_numel
theoretical_max_elements = max_elements_per_thread * num_threads
max_elements = min(theoretical_max_elements, triton_limit)
max_elements = (
theoretical_max_elements
if backend_limit is None
else min(theoretical_max_elements, backend_limit)
)
while self.block_numel(flat_config) > max_elements:
changes = 0
for i in self.block_size_indices:
Expand Down
25 changes: 25 additions & 0 deletions test/test_tensor_numel_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ def test_mixed_symbol_shape_skipped(self) -> None:

# Minimal mock of CompileEnvironment for _extract_tensor_numel_constraints
env = object.__new__(CompileEnvironment)
env._backend = SimpleNamespace(max_tensor_numel=TRITON_MAX_TENSOR_NUMEL)
env.block_sizes = [SimpleNamespace(symbol=lambda: b0, block_id=0)]
env.config_spec = SimpleNamespace(
block_sizes=SimpleNamespace(
Expand All @@ -317,6 +318,30 @@ def test_mixed_symbol_shape_skipped(self) -> None:
)


class TestBackendMaxTensorNumel(unittest.TestCase):
    """Backends that report no per-tile element cap (Pallas) add no constraint."""

    def test_backend_with_no_cap_extracts_no_constraints(self) -> None:
        """A ``None`` ``max_tensor_numel`` yields zero numel constraints.

        The tile used here would normally trip the cap, so an empty result
        shows that extraction is skipped however large the tile is.
        """
        from types import SimpleNamespace

        from helion._compiler.compile_environment import CompileEnvironment

        block_sym = sympy.Symbol("u0", integer=True)

        # Stand-ins for the collaborators the method under test reads.
        uncapped_backend = SimpleNamespace(max_tensor_numel=None)
        block_size_stub = SimpleNamespace(symbol=lambda: block_sym, block_id=0)
        config_spec_stub = SimpleNamespace(
            block_sizes=SimpleNamespace(block_id_to_index=lambda bid: bid),
            tensor_numel_constraints=[],
        )

        # Bypass __init__ so only the attributes needed here are populated.
        env = object.__new__(CompileEnvironment)
        env._backend = uncapped_backend
        env.block_sizes = [block_size_stub]
        env.config_spec = config_spec_stub
        # A tile that would otherwise trigger the cap (block_sym * 16384).
        env.kernel_tensor_sizes = [[block_sym, sympy.Integer(16384)]]

        env._extract_tensor_numel_constraints()

        self.assertEqual(env.config_spec.tensor_numel_constraints, [])


class TestFixedPointOverlapping(unittest.TestCase):
"""Verify the fixed-point loop handles overlapping constraints."""

Expand Down
Loading