From e7f46c9907c56164eb93c3de648bdce04cac126c Mon Sep 17 00:00:00 2001 From: eche Date: Sat, 9 May 2026 18:40:23 -0700 Subject: [PATCH 1/9] Add compiler seed heuristics --- docs/aot_autotuning.md | 2 +- helion/_compiler/seed_heuristics/__init__.py | 41 ++ helion/_compiler/seed_heuristics/common.py | 70 ++ helion/_compiler/seed_heuristics/cute.py | 97 +++ helion/_compiler/seed_heuristics/registry.py | 27 + helion/_compiler/seed_heuristics/triton.py | 76 ++ helion/_hardware.py | 146 ++++ helion/autotuner/aot_cache.py | 147 +--- helion/autotuner/base_search.py | 2 +- helion/autotuner/config_generation.py | 8 +- helion/autotuner/config_spec.py | 73 +- helion/autotuner/heuristic_generator.py | 2 +- helion/experimental/aot_runner.py | 2 +- helion/language/matmul_ops.py | 15 + helion/runtime/kernel.py | 4 + test/test_aot_autotuning.py | 9 +- test/test_dot_requirements.py | 124 ---- test/test_seed_heuristics.py | 701 +++++++++++++++++++ 18 files changed, 1212 insertions(+), 334 deletions(-) create mode 100644 helion/_compiler/seed_heuristics/__init__.py create mode 100644 helion/_compiler/seed_heuristics/common.py create mode 100644 helion/_compiler/seed_heuristics/cute.py create mode 100644 helion/_compiler/seed_heuristics/registry.py create mode 100644 helion/_compiler/seed_heuristics/triton.py create mode 100644 helion/_hardware.py create mode 100644 test/test_seed_heuristics.py diff --git a/docs/aot_autotuning.md b/docs/aot_autotuning.md index 5c7ce13085..97f87bb861 100644 --- a/docs/aot_autotuning.md +++ b/docs/aot_autotuning.md @@ -204,7 +204,7 @@ A heuristic file is specific to one device kind + compute capability — to add support for a new GPU, generate a fresh heuristic on that hardware and commit it alongside any existing files. The recipe works for any GPU; the device-kind / compute-capability suffix is whatever -{py:class}`~helion.autotuner.aot_cache.HardwareInfo` reports for the +{py:class}`~helion._hardware.HardwareInfo` reports for the target. ### Step-by-step diff --git a/helion/_compiler/seed_heuristics/__init__.py b/helion/_compiler/seed_heuristics/__init__.py new file mode 100644 index 0000000000..d525f76503 --- /dev/null +++ b/helion/_compiler/seed_heuristics/__init__.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .common import dedupe_configs +from .cute import CuteTcgen05ClusterM2Heuristic +from .triton import TritonSkinnyGemmHeuristic + +if TYPE_CHECKING: + from ...runtime.config import Config + from ..compile_environment import CompileEnvironment + from ..device_ir import DeviceIR + from .registry import SeedHeuristicType + +# All active heuristics by backend +HEURISTICS_BY_BACKEND: dict[str, tuple[SeedHeuristicType, ...]] = { + "cute": (CuteTcgen05ClusterM2Heuristic,), + "triton": (TritonSkinnyGemmHeuristic,), +} + + +def get_heuristics(backend: str) -> tuple[SeedHeuristicType, ...]: + return HEURISTICS_BY_BACKEND.get(backend, ()) + + +def compiler_seed_configs( + env: CompileEnvironment, + device_ir: DeviceIR, +) -> list[Config]: + configs: list[Config] = [] + env.config_spec.compiler_seed_heuristics = [] + for heuristic in get_heuristics(env.backend_name): + if not heuristic.is_eligible(env, device_ir): + continue + + # If the heuristic is eligible, we must get a valid config + # We add the heuristic name to list of applied heuristics + config = heuristic.get_config(env, device_ir) + configs.append(config) + env.config_spec.compiler_seed_heuristics.append(heuristic.name) + return dedupe_configs(configs) diff --git a/helion/_compiler/seed_heuristics/common.py b/helion/_compiler/seed_heuristics/common.py new file mode 100644 index 0000000000..473945f16f --- /dev/null +++ b/helion/_compiler/seed_heuristics/common.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from typing import Iterable +from typing import Sequence +from typing import cast + +if TYPE_CHECKING: + from ...autotuner.config_spec import BlockSizeSpec + from ...runtime.config import Config + from ..compile_environment import CompileEnvironment + +HardwareTarget = tuple[str, str | None] + + +def dedupe_configs(configs: Iterable[Config]) -> list[Config]: + result: list[Config] = [] + seen: set[Config] = set() + for config in configs: + if config in seen: + continue + seen.add(config) + result.append(config) + return result + + +def matches_hardware( + env: CompileEnvironment, + targets: tuple[HardwareTarget, ...], +) -> bool: + from ..._hardware import get_hardware_info + + hardware = get_hardware_info(env.device) + return ( + (hardware.device_kind, hardware.compute_capability) in targets + or (hardware.device_kind, None) in targets + ) + + +def clamp_block_size_targets( + env: CompileEnvironment, + block_dims: Sequence[tuple[int, int, int]], +) -> list[int] | None: + """Clamp block-size targets against the live ConfigSpec constraints. + + Each entry in *block_dims* is ``(block_id, static_dim, target)``. + Returns the clamped block sizes, or ``None`` if any axis cannot + satisfy its floor/ceiling constraints. + """ + block_sizes: list[int] = [] + for block_id, static_dim, target in block_dims: + try: + spec = cast( + "BlockSizeSpec", + env.config_spec.block_sizes.block_id_lookup(block_id), + ) + except KeyError: + return None + candidate = min(target, static_dim) + if candidate < 1: + return None + candidate = 1 << (candidate.bit_length() - 1) + floor = max(spec.min_size, spec.autotuner_min) + if candidate < floor: + return None + candidate = min(candidate, spec.max_size) + if candidate < floor: + return None + block_sizes.append(candidate) + return block_sizes diff --git a/helion/_compiler/seed_heuristics/cute.py b/helion/_compiler/seed_heuristics/cute.py new file mode 100644 index 0000000000..43fdf729e0 --- /dev/null +++ b/helion/_compiler/seed_heuristics/cute.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from typing import cast + +from ...runtime.config import Config +from ..cute.tcgen05_constants import TCGEN05_TWO_CTA_BLOCK_M +from ..cute.tcgen05_constants import TCGEN05_TWO_CTA_BLOCK_N +from ..cute.tcgen05_constants import TCGEN05_TWO_CTA_SEED_L2_GROUPING +from ..cute.tcgen05_constants import TCGEN05_TWO_CTA_SEED_PID_TYPE +from .registry import SeedHeuristic + +if TYPE_CHECKING: + from ...autotuner.config_fragment import BlockSizeFragment + from ..compile_environment import CompileEnvironment + from ..device_ir import DeviceIR + + +class CuteTcgen05ClusterM2Heuristic(SeedHeuristic): + name = "cute_tcgen05_cluster_m2" + backend = "cute" + + @classmethod + def is_eligible(cls, env: CompileEnvironment, device_ir: DeviceIR) -> bool: + spec = env.config_spec + constraints = spec._tcgen05_cluster_m2_search_constraints + if ( + constraints is None + or TCGEN05_TWO_CTA_SEED_PID_TYPE not in spec.allowed_pid_types + ): + return False + if len(spec.block_sizes) != 3: + return False + + bm_fragment = cast("BlockSizeFragment", spec.block_sizes[0]._fragment(spec)) + bn_fragment = cast("BlockSizeFragment", spec.block_sizes[1]._fragment(spec)) + return ( + bm_fragment.low <= TCGEN05_TWO_CTA_BLOCK_M <= bm_fragment.high + and bn_fragment.low <= TCGEN05_TWO_CTA_BLOCK_N <= bn_fragment.high + and cls._select_bk(env) is not None + ) + + @classmethod + def get_config(cls, env: CompileEnvironment, device_ir: DeviceIR) -> Config: + spec = env.config_spec + bk = cls._select_bk(env) + if bk is None: + raise AssertionError(f"{cls.name} get_config called while ineligible") + + block_sizes = [ + TCGEN05_TWO_CTA_BLOCK_M, + TCGEN05_TWO_CTA_BLOCK_N, + bk, + ] + if spec.indexing.length == 3: + # Pure matmul has exactly the A/B/C indexing slots. Fused epilogues + # add more memory ops, so leave those seeds to the spec default + # rather than constructing a partial list. + return Config( + block_sizes=block_sizes, + l2_groupings=[TCGEN05_TWO_CTA_SEED_L2_GROUPING], + pid_type=TCGEN05_TWO_CTA_SEED_PID_TYPE, + tcgen05_cluster_m=2, + tcgen05_num_epi_warps=4, + indexing=[ + "tensor_descriptor", + "tensor_descriptor", + "tensor_descriptor", + ], + ) + + return Config( + block_sizes=[ + TCGEN05_TWO_CTA_BLOCK_M, + TCGEN05_TWO_CTA_BLOCK_N, + bk, + ], + l2_groupings=[TCGEN05_TWO_CTA_SEED_L2_GROUPING], + pid_type=TCGEN05_TWO_CTA_SEED_PID_TYPE, + tcgen05_cluster_m=2, + # Matches the validated tcgen05 search restriction. + tcgen05_num_epi_warps=4, + ) + + @staticmethod + def _select_bk(env: CompileEnvironment) -> int | None: + spec = env.config_spec + constraints = spec._tcgen05_cluster_m2_search_constraints + if constraints is None or len(spec.block_sizes) != 3: + return None + bk_fragment = cast("BlockSizeFragment", spec.block_sizes[2]._fragment(spec)) + bk = bk_fragment.high + while bk >= bk_fragment.low: + if spec._tcgen05_cluster_m2_bk_is_valid(bk, constraints): + return bk + bk //= 2 + return None diff --git a/helion/_compiler/seed_heuristics/registry.py b/helion/_compiler/seed_heuristics/registry.py new file mode 100644 index 0000000000..70b42dc0ce --- /dev/null +++ b/helion/_compiler/seed_heuristics/registry.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from typing import ClassVar + +if TYPE_CHECKING: + from ...runtime.config import Config + from ..compile_environment import CompileEnvironment + from ..device_ir import DeviceIR + + +class SeedHeuristic: + """Base class for compiler-owned autotune seed heuristics.""" + + name: ClassVar[str] + backend: ClassVar[str] + + @classmethod + def is_eligible(cls, env: CompileEnvironment, device_ir: DeviceIR) -> bool: + raise NotImplementedError + + @classmethod + def get_config(cls, env: CompileEnvironment, device_ir: DeviceIR) -> Config: + raise NotImplementedError + + +SeedHeuristicType = type[SeedHeuristic] diff --git a/helion/_compiler/seed_heuristics/triton.py b/helion/_compiler/seed_heuristics/triton.py new file mode 100644 index 0000000000..e7e23c0f88 --- /dev/null +++ b/helion/_compiler/seed_heuristics/triton.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ...runtime.config import Config +from .common import clamp_block_size_targets +from .common import matches_hardware +from .registry import SeedHeuristic + +if TYPE_CHECKING: + from ..compile_environment import CompileEnvironment + from ..device_ir import DeviceIR + + +class TritonSkinnyGemmHeuristic(SeedHeuristic): + name = "triton_skinny_gemm" + backend = "triton" + MIN_ASPECT_RATIO = 8 + BLOCK_TARGETS = (64, 64, 256) + HARDWARE_TARGETS = (("cuda", "sm90"), ("rocm", "gfx950")) + + @classmethod + def is_eligible(cls, env: CompileEnvironment, device_ir: DeviceIR) -> bool: + if not matches_hardware(env, cls.HARDWARE_TARGETS): + return False + facts = env.config_spec.matmul_facts + if len(facts) != 1: + return False + fact = facts[0] + if fact.lhs_ndim != 2 or fact.rhs_ndim != 2: + return False + if ( + fact.static_m is None + or fact.static_n is None + or fact.static_k is None + or fact.m_block_id is None + or fact.n_block_id is None + or fact.k_block_id is None + ): + return False + if max(fact.static_m, fact.static_n) < cls.MIN_ASPECT_RATIO * min( + fact.static_m, fact.static_n + ): + return False + return ( + clamp_block_size_targets( + env, + [ + (fact.m_block_id, fact.static_m, cls.BLOCK_TARGETS[0]), + (fact.n_block_id, fact.static_n, cls.BLOCK_TARGETS[1]), + (fact.k_block_id, fact.static_k, cls.BLOCK_TARGETS[2]), + ], + ) + is not None + ) + + @classmethod + def get_config(cls, env: CompileEnvironment, device_ir: DeviceIR) -> Config: + assert len(env.config_spec.matmul_facts) == 1 + fact = env.config_spec.matmul_facts[0] + assert fact.static_m is not None + assert fact.static_n is not None + assert fact.static_k is not None + assert fact.m_block_id is not None + assert fact.n_block_id is not None + assert fact.k_block_id is not None + block_sizes = clamp_block_size_targets( + env, + [ + (fact.m_block_id, fact.static_m, cls.BLOCK_TARGETS[0]), + (fact.n_block_id, fact.static_n, cls.BLOCK_TARGETS[1]), + (fact.k_block_id, fact.static_k, cls.BLOCK_TARGETS[2]), + ], + ) + assert block_sizes is not None + return Config(block_sizes=block_sizes) diff --git a/helion/_hardware.py b/helion/_hardware.py new file mode 100644 index 0000000000..667ae9621f --- /dev/null +++ b/helion/_hardware.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +import dataclasses +import functools + +import torch + +# Compute capability lists for fallback (newest to oldest) +_CUDA_COMPUTE_CAPS: list[str] = [ + "sm100", + "sm90", + "sm89", + "sm87", + "sm86", + "sm80", + "sm75", + "sm72", + "sm70", +] + +_ROCM_ARCHS: list[str] = [ + "gfx950", + "gfx942", + "gfx941", + "gfx940", + "gfx90a", + "gfx908", + "gfx906", + "gfx900", +] + + +@dataclasses.dataclass(frozen=True) +class HardwareInfo: + """ + Hardware information for cache keys and heuristic selection. + + Attributes: + device_kind: Device type ('cuda', 'rocm', 'xpu') + hardware_name: Device name (e.g., 'NVIDIA H100', 'gfx90a') + runtime_version: Runtime version (e.g., '12.4', 'gfx90a') + compute_capability: Compute capability for heuristics (e.g., 'sm90', 'gfx90a') + """ + + device_kind: str + hardware_name: str + runtime_version: str + compute_capability: str + + @property + def hardware_id(self) -> str: + """Get a unique identifier string for this hardware.""" + safe_name = self.hardware_name.replace(" ", "_") + return f"{self.device_kind}_{safe_name}_{self.runtime_version}" + + def get_compatible_compute_ids(self) -> list[str]: + """ + Get a list of compatible compute IDs for fallback, ordered from current to oldest. + + For CUDA/ROCm, returns the current compute capability followed by all older + compatible architectures. This allows using heuristics tuned on older hardware + when newer hardware-specific heuristics aren't available. + """ + if self.device_kind == "cuda": + arch_list = _CUDA_COMPUTE_CAPS + elif self.device_kind == "rocm": + arch_list = _ROCM_ARCHS + else: + return [self.compute_capability] + + try: + current_idx = arch_list.index(self.compute_capability) + return arch_list[current_idx:] + except ValueError: + return [self.compute_capability, *arch_list] + + +@functools.cache +def get_hardware_info(device: torch.device | None = None) -> HardwareInfo: + """ + Get hardware information for the current or specified device. + + Args: + device: Optional device to get info for. If None, uses first available GPU or CPU. + + Returns: + HardwareInfo with device details for caching and heuristic lookup. + """ + # XPU (Intel) path + if ( + device is not None + and device.type == "xpu" + and getattr(torch, "xpu", None) is not None + and torch.xpu.is_available() + ): + props = torch.xpu.get_device_properties(device) + return HardwareInfo( + device_kind="xpu", + hardware_name=props.name, + runtime_version=props.driver_version, + compute_capability=props.name, # XPU doesn't have compute capability + ) + + # CUDA/ROCm path + if torch.cuda.is_available(): + dev = ( + device + if device is not None and device.type == "cuda" + else torch.device("cuda:0") + ) + props = torch.cuda.get_device_properties(dev) + + if torch.version.cuda is not None: + return HardwareInfo( + device_kind="cuda", + hardware_name=props.name, + runtime_version=str(torch.version.cuda), + compute_capability=f"sm{props.major}{props.minor}", + ) + if torch.version.hip is not None: + return HardwareInfo( + device_kind="rocm", + hardware_name=props.gcnArchName, + runtime_version=torch.version.hip, + compute_capability=props.gcnArchName, + ) + + # TPU / Pallas path + try: + import jax + + tpu_devices = [d for d in jax.devices() if d.platform == "tpu"] + if tpu_devices: + first_tpu = tpu_devices[0] + return HardwareInfo( + device_kind="tpu", + hardware_name=first_tpu.device_kind, + runtime_version=jax.__version__, + compute_capability=first_tpu.device_kind, + ) + except ImportError: + pass + + raise RuntimeError( + "No supported GPU or TPU device found. Helion requires CUDA, ROCm, XPU, or TPU." + ) diff --git a/helion/autotuner/aot_cache.py b/helion/autotuner/aot_cache.py index a466f9c64c..11ee1a6de8 100644 --- a/helion/autotuner/aot_cache.py +++ b/helion/autotuner/aot_cache.py @@ -18,9 +18,7 @@ from __future__ import annotations import csv -import dataclasses from dataclasses import dataclass -import functools import hashlib import importlib import importlib.util @@ -39,6 +37,7 @@ import torch +from .._hardware import get_hardware_info from ..experimental.aot_kernel import _flatten_key_value from ..experimental.aot_kernel import extract_key_features from ..experimental.aot_kernel import extract_shape_features @@ -54,150 +53,6 @@ log: logging.Logger = logging.getLogger(__name__) -# Compute capability lists for fallback (newest to oldest) -_CUDA_COMPUTE_CAPS: list[str] = [ - "sm100", - "sm90", - "sm89", - "sm87", - "sm86", - "sm80", - "sm75", - "sm72", - "sm70", -] - -_ROCM_ARCHS: list[str] = [ - "gfx950", - "gfx942", - "gfx941", - "gfx940", - "gfx90a", - "gfx908", - "gfx906", - "gfx900", -] - - -@dataclasses.dataclass(frozen=True) -class HardwareInfo: - """ - Hardware information for cache keys and heuristic file discovery. - - Attributes: - device_kind: Device type ('cuda', 'rocm', 'xpu') - hardware_name: Device name (e.g., 'NVIDIA H100', 'gfx90a') - runtime_version: Runtime version (e.g., '12.4', 'gfx90a') - compute_capability: Compute capability for heuristics (e.g., 'sm90', 'gfx90a') - """ - - device_kind: str - hardware_name: str - runtime_version: str - compute_capability: str - - @property - def hardware_id(self) -> str: - """Get a unique identifier string for this hardware.""" - safe_name = self.hardware_name.replace(" ", "_") - return f"{self.device_kind}_{safe_name}_{self.runtime_version}" - - def get_compatible_compute_ids(self) -> list[str]: - """ - Get a list of compatible compute IDs for fallback, ordered from current to oldest. - - For CUDA/ROCm, returns the current compute capability followed by all older - compatible architectures. This allows using heuristics tuned on older hardware - when newer hardware-specific heuristics aren't available. - """ - if self.device_kind == "cuda": - arch_list = _CUDA_COMPUTE_CAPS - elif self.device_kind == "rocm": - arch_list = _ROCM_ARCHS - else: - return [self.compute_capability] - - try: - current_idx = arch_list.index(self.compute_capability) - return arch_list[current_idx:] - except ValueError: - return [self.compute_capability, *arch_list] - - -@functools.cache -def get_hardware_info(device: torch.device | None = None) -> HardwareInfo: - """ - Get hardware information for the current or specified device. - - This is the single source of truth for hardware detection, used by both - local cache and AOT cache. - - Args: - device: Optional device to get info for. If None, uses first available GPU or CPU. - - Returns: - HardwareInfo with device details for caching and heuristic lookup. - """ - # XPU (Intel) path - if ( - device is not None - and device.type == "xpu" - and getattr(torch, "xpu", None) is not None - and torch.xpu.is_available() - ): - props = torch.xpu.get_device_properties(device) - return HardwareInfo( - device_kind="xpu", - hardware_name=props.name, - runtime_version=props.driver_version, - compute_capability=props.name, # XPU doesn't have compute capability - ) - - # CUDA/ROCm path - if torch.cuda.is_available(): - dev = ( - device - if device is not None and device.type == "cuda" - else torch.device("cuda:0") - ) - props = torch.cuda.get_device_properties(dev) - - if torch.version.cuda is not None: - return HardwareInfo( - device_kind="cuda", - hardware_name=props.name, - runtime_version=str(torch.version.cuda), - compute_capability=f"sm{props.major}{props.minor}", - ) - if torch.version.hip is not None: - return HardwareInfo( - device_kind="rocm", - hardware_name=props.gcnArchName, - runtime_version=torch.version.hip, - compute_capability=props.gcnArchName, - ) - - # TPU / Pallas path - try: - import jax - - tpu_devices = [d for d in jax.devices() if d.platform == "tpu"] - if tpu_devices: - first_tpu = tpu_devices[0] - return HardwareInfo( - device_kind="tpu", - hardware_name=first_tpu.device_kind, - runtime_version=jax.__version__, - compute_capability=first_tpu.device_kind, - ) - except ImportError: - pass - - raise RuntimeError( - "No supported GPU or TPU device found. Helion requires CUDA, ROCm, XPU, or TPU." - ) - - # Environment variable to control AOT mode AOT_MODE_ENV = "HELION_AOT_MODE" AOT_DATA_DIR_ENV = "HELION_AOT_DATA_DIR" diff --git a/helion/autotuner/base_search.py b/helion/autotuner/base_search.py index 55f64367e7..e02bfd3ad8 100644 --- a/helion/autotuner/base_search.py +++ b/helion/autotuner/base_search.py @@ -795,7 +795,7 @@ def _generate_best_available_population_flat(self) -> list[FlatConfig]: seen.add(transferred_config) result.append(flat) - # Compiler-owned seeds come from ConfigSpec.autotune_seed_configs(); + # Compiler-owned seeds come from ConfigSpec.compiler_seed_configs; # they encode backend/compiler heuristics and complement user seed configs. for flat, transferred_config in self.config_gen.seed_flat_config_pairs(): if transferred_config not in seen: diff --git a/helion/autotuner/config_generation.py b/helion/autotuner/config_generation.py index e6a4ab3fc1..5ca13a431a 100644 --- a/helion/autotuner/config_generation.py +++ b/helion/autotuner/config_generation.py @@ -384,14 +384,14 @@ def default_flat(self) -> FlatConfig: def seed_flat_config_pairs(self) -> list[tuple[FlatConfig, Config]]: """Return ConfigSpec-provided seeds as flat and normalized configs. - ``ConfigSpec.autotune_seed_configs()`` is compiler-owned and must - return configs that match the live spec structurally. ``InvalidConfig`` + ``ConfigSpec.compiler_seed_configs`` is compiler-owned and must + contain configs that match the live spec structurally. ``InvalidConfig`` means overrides make a seed inapplicable; other flatten/unflatten exceptions are programming errors and intentionally surface. """ result: list[tuple[FlatConfig, Config]] = [] seen: set[Config] = set() - for config in self.config_spec.autotune_seed_configs(): + for config in self.config_spec.compiler_seed_configs: try: flat = self.flatten(config) normalized = self.unflatten(flat) @@ -473,7 +473,7 @@ def random_population_flat( # Initial population order is default -> user seed configs -> compiler seeds # -> random. This preserves user seed priority without dropping built-in - # backend/compiler seeds from ConfigSpec.autotune_seed_configs(). + # backend/compiler seeds from ConfigSpec.compiler_seed_configs. for flat, _config in self.user_seed_flat_config_pairs( user_seed_configs, log_func ): diff --git a/helion/autotuner/config_spec.py b/helion/autotuner/config_spec.py index 679c3c9e9b..bd7804d2df 100644 --- a/helion/autotuner/config_spec.py +++ b/helion/autotuner/config_spec.py @@ -78,7 +78,6 @@ from .._compiler.cute.tcgen05_constants import TCGEN05_TWO_CTA_BLOCK_M from .._compiler.cute.tcgen05_constants import TCGEN05_TWO_CTA_BLOCK_N from .._compiler.cute.tcgen05_constants import TCGEN05_TWO_CTA_MAX_K_TILES -from .._compiler.cute.tcgen05_constants import TCGEN05_TWO_CTA_SEED_L2_GROUPING from .._compiler.cute.tcgen05_constants import TCGEN05_TWO_CTA_SEED_PID_TYPE from ..exc import InvalidConfig from .block_id_sequence import BlockIdSequence @@ -125,6 +124,21 @@ class Tcgen05ClusterM2SearchConstraints(NamedTuple): max_k_tiles: int +class MatmulFact(NamedTuple): + """Shape facts recorded when matmul requirements are applied.""" + + lhs_ndim: int + rhs_ndim: int + m_block_id: int | None + n_block_id: int | None + k_block_id: int | None + static_m: int | None + static_n: int | None + static_k: int | None + lhs_dtype: torch.dtype + rhs_dtype: torch.dtype + + def shrink_block_sizes_for_numel_constraints( constraints: list[TensorNumelConstraint], block_sizes: list[int], @@ -399,6 +413,9 @@ def __init__( # no loud crash to alert a user who bypasses autotune via an # explicit config, so normalize() must reject the unsafe values. self._tcgen05_num_epi_warps_validation_choices: tuple[int, ...] | None = None + self.compiler_seed_configs: list[helion.Config] = [] + self.compiler_seed_heuristics: list[str] = [] + self.matmul_facts: list[MatmulFact] = [] self.store_indices: list[int] = [] self.backend_tunable_fragments = self.backend.tunable_fragments() unknown_tunables = set(self.backend_tunable_fragments) - BACKEND_TUNABLE_KEYS @@ -545,53 +562,6 @@ def allow_tcgen05_cluster_m2_search( # reopening the search choices. self.restrict_tcgen05_cluster_m_search((1, 2)) - def _tcgen05_cluster_m2_seed_config(self) -> helion.Config | None: - constraints = self._tcgen05_cluster_m2_search_constraints - if ( - constraints is None - or TCGEN05_TWO_CTA_SEED_PID_TYPE not in self.allowed_pid_types - ): - return None - if len(self.block_sizes) != 3: - return None - - bm_fragment = cast("BlockSizeFragment", self.block_sizes[0]._fragment(self)) - bn_fragment = cast("BlockSizeFragment", self.block_sizes[1]._fragment(self)) - bk_fragment = cast("BlockSizeFragment", self.block_sizes[2]._fragment(self)) - if not ( - bm_fragment.low <= TCGEN05_TWO_CTA_BLOCK_M <= bm_fragment.high - and bn_fragment.low <= TCGEN05_TWO_CTA_BLOCK_N <= bn_fragment.high - ): - return None - - bk = bk_fragment.high - while bk >= bk_fragment.low: - if self._tcgen05_cluster_m2_bk_is_valid(bk, constraints): - seed_config: dict[str, Any] = { - "block_sizes": [ - TCGEN05_TWO_CTA_BLOCK_M, - TCGEN05_TWO_CTA_BLOCK_N, - bk, - ], - "l2_groupings": [TCGEN05_TWO_CTA_SEED_L2_GROUPING], - "pid_type": TCGEN05_TWO_CTA_SEED_PID_TYPE, - "tcgen05_cluster_m": 2, - # Matches the validated tcgen05 search restriction. - "tcgen05_num_epi_warps": 4, - } - # Pure matmul has exactly the A/B/C indexing slots. Fused - # epilogues add more memory ops, so leave those seeds to the - # spec default rather than constructing a partial list. - if self.indexing.length == 3: - seed_config["indexing"] = [ - "tensor_descriptor", - "tensor_descriptor", - "tensor_descriptor", - ] - return helion.Config(**seed_config) - bk //= 2 - return None - @staticmethod def _tcgen05_cluster_m2_bk_is_valid( bk: int, constraints: Tcgen05ClusterM2SearchConstraints @@ -600,13 +570,6 @@ def _tcgen05_cluster_m2_bk_is_valid( constraints.static_k // bk <= constraints.max_k_tiles ) - def autotune_seed_configs(self) -> list[helion.Config]: - """Return validated extra configs that should be benchmarked early.""" - cluster_m2_seed = self._tcgen05_cluster_m2_seed_config() - if cluster_m2_seed is None: - return [] - return [cluster_m2_seed] - def _fix_tcgen05_cluster_m2_search_config(self, config: dict[str, object]) -> None: """Canonicalize unvalidated search-only ``cluster_m=2`` products.""" if not ( diff --git a/helion/autotuner/heuristic_generator.py b/helion/autotuner/heuristic_generator.py index 17a8efc777..2479e72dd0 100644 --- a/helion/autotuner/heuristic_generator.py +++ b/helion/autotuner/heuristic_generator.py @@ -881,7 +881,7 @@ def generate_heuristic( Returns: Dictionary mapping kernel names to HeuristicResult """ - from .aot_cache import get_hardware_info + from .._hardware import get_hardware_info if target is None: target = PerformanceTarget() diff --git a/helion/experimental/aot_runner.py b/helion/experimental/aot_runner.py index f99b2d314b..4417210164 100644 --- a/helion/experimental/aot_runner.py +++ b/helion/experimental/aot_runner.py @@ -33,7 +33,7 @@ from typing import Any import uuid -from ..autotuner.aot_cache import get_hardware_info +from .._hardware import get_hardware_info from ..autotuner.heuristic_generator import PerformanceTarget from ..autotuner.heuristic_generator import evaluate_heuristic from ..autotuner.heuristic_generator import generate_heuristic diff --git a/helion/language/matmul_ops.py b/helion/language/matmul_ops.py index c50d4037e7..82ee290783 100644 --- a/helion/language/matmul_ops.py +++ b/helion/language/matmul_ops.py @@ -31,6 +31,7 @@ from .._compiler.matmul_utils import _emit_tl_dot_scaled from .._compiler.matmul_utils import _needs_f32_accumulator from .._compiler.matmul_utils import emit_tl_dot_with_padding +from ..autotuner.config_spec import MatmulFact from . import _decorators if TYPE_CHECKING: @@ -268,6 +269,20 @@ def static_problem_extent(size: int | torch.SymInt) -> int | None: static_m = static_problem_extent(m) static_n = static_problem_extent(n) static_k = static_problem_extent(k) + env.config_spec.matmul_facts.append( + MatmulFact( + lhs_ndim=lhs.ndim, + rhs_ndim=rhs.ndim, + m_block_id=env.get_block_id(m), + n_block_id=env.get_block_id(n), + k_block_id=env.get_block_id(k), + static_m=static_m, + static_n=static_n, + static_k=static_k, + lhs_dtype=lhs.dtype, + rhs_dtype=rhs.dtype, + ) + ) if ( env.backend_name == "cute" and lhs.ndim == 2 diff --git a/helion/runtime/kernel.py b/helion/runtime/kernel.py index 989186e957..22d698a50a 100644 --- a/helion/runtime/kernel.py +++ b/helion/runtime/kernel.py @@ -49,6 +49,7 @@ from .._compiler.inductor_lowering_extra import patch_inductor_lowerings from .._compiler.kernel_compiler import KernelCompiler from .._compiler.output_header import assert_no_conflicts +from .._compiler.seed_heuristics import compiler_seed_configs from .._compiler.variable_origin import ArgumentOrigin from .._dist_utils import _find_process_group_name from .._dist_utils import check_config_consistancy as dist_check_config_consistancy @@ -490,6 +491,9 @@ def __init__( raise self.env.config_spec.configure_epilogue_subtile_autotune(args) + self.env.config_spec.compiler_seed_configs = compiler_seed_configs( + self.env, self.host_function.device_ir + ) def _apply_mark_static(self, args: tuple[object, ...]) -> None: """ diff --git a/test/test_aot_autotuning.py b/test/test_aot_autotuning.py index 65f0f2f30c..cc653e0d7f 100644 --- a/test/test_aot_autotuning.py +++ b/test/test_aot_autotuning.py @@ -14,6 +14,7 @@ import pytest import torch +from helion._hardware import HardwareInfo from helion._testing import onlyBackends from helion.autotuner.aot_cache import ShapeKey from helion.autotuner.aot_cache import _deserialize_tuple @@ -33,10 +34,16 @@ class TestShapeKey: """Tests for ShapeKey class.""" def test_to_dict_and_back(self) -> None: + hardware = HardwareInfo( + device_kind="cuda", + hardware_name="RTX4090", + runtime_version="12.4", + compute_capability="sm89", + ) key = ShapeKey( kernel_name="test_kernel", specialization_key=(1024, 2048, "float32"), - hardware_id="cuda_RTX4090_12.4", + hardware_id=hardware.hardware_id, ) d = key.to_dict() restored = ShapeKey.from_dict(d) diff --git a/test/test_dot_requirements.py b/test/test_dot_requirements.py index def39730a4..44a03c44f8 100644 --- a/test/test_dot_requirements.py +++ b/test/test_dot_requirements.py @@ -10,7 +10,6 @@ from helion._compiler.cute.tcgen05_constants import TCGEN05_ONE_CTA_MAX_BLOCK_M from helion._compiler.cute.tcgen05_constants import TCGEN05_TWO_CTA_BLOCK_M from helion._compiler.cute.tcgen05_constants import TCGEN05_TWO_CTA_BLOCK_N -from helion._compiler.cute.tcgen05_constants import TCGEN05_TWO_CTA_SEED_L2_GROUPING from helion._testing import DEVICE from helion._testing import HALF_DTYPE from helion._testing import RefEagerTestDisabled @@ -19,8 +18,6 @@ from helion._testing import onlyBackends from helion._testing import patch_cute_mma_support from helion._testing import skipIfMTIA -from helion.autotuner.pattern_search import InitialPopulationStrategy -from helion.autotuner.pattern_search import PatternSearch from helion.exc import InvalidConfig import helion.language as hl @@ -357,127 +354,6 @@ def cute_matmul_mma(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: self.assertEqual(two_cta_config["pid_type"], "persistent_interleaved") self.assertEqual(two_cta_config["block_sizes"][:3], [256, 256, 16]) - @onlyBackends(["cute"]) - def test_cute_tcgen05_two_cta_seeded_in_initial_populations(self) -> None: - @helion.kernel(backend="cute") - def cute_matmul_mma(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - m, k = x.size() - _, n = y.size() - out = torch.empty([m, n], dtype=x.dtype, device=x.device) - for tile_m, tile_n in hl.tile([m, n]): - acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) - for tile_k in hl.tile(k): - acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) - out[tile_m, tile_n] = acc.to(x.dtype) - return out - - args = ( - torch.empty([4096, 4096], device=DEVICE, dtype=HALF_DTYPE), - torch.empty([4096, 4096], device=DEVICE, dtype=HALF_DTYPE), - ) - with patch_cute_mma_support(): - bound = cute_matmul_mma.bind(args) - - def assert_seeded(configs: list[helion.Config]) -> None: - seeded = [ - config.config - for config in configs - if config.config["tcgen05_cluster_m"] == 2 - ] - self.assertEqual(len(seeded), 1) - seed = seeded[0] - self.assertEqual( - seed["block_sizes"][:3], - [TCGEN05_TWO_CTA_BLOCK_M, TCGEN05_TWO_CTA_BLOCK_N, 128], - ) - self.assertEqual( - seed["indexing"], - ["tensor_descriptor", "tensor_descriptor", "tensor_descriptor"], - ) - self.assertEqual(seed["l2_groupings"], [TCGEN05_TWO_CTA_SEED_L2_GROUPING]) - self.assertEqual(seed["pid_type"], "persistent_interleaved") - self.assertEqual(seed["tcgen05_num_epi_warps"], 4) - - config_gen = bound.config_spec.create_config_generation() - zero_flat = config_gen.random_population_flat(0) - self.assertEqual(len(zero_flat), 1) - zero_config = config_gen.unflatten(zero_flat[0]) - self.assertEqual(zero_config.config["tcgen05_cluster_m"], 1) - one_flat = config_gen.random_population_flat(1) - self.assertEqual(len(one_flat), 1) - one_config = config_gen.unflatten(one_flat[0]) - self.assertEqual(one_config.config["tcgen05_cluster_m"], 1) - one_config_population = config_gen.random_population(1) - self.assertEqual(len(one_config_population), 1) - self.assertEqual(one_config_population[0].config["tcgen05_cluster_m"], 1) - assert_seeded(config_gen.random_population(2)) - - acf_config_gen = bound.config_spec.create_config_generation( - advanced_controls_files=["/tmp/helion-test.acf"] - ) - acf_configs = acf_config_gen.random_population(2) - self.assertEqual(len(acf_configs), 2) - self.assertEqual( - {config.config["advanced_controls_file"] for config in acf_configs}, - {"/tmp/helion-test.acf"}, - ) - assert_seeded(acf_configs) - - with patch.object( - PatternSearch, "_find_similar_cached_configs", return_value=[] - ): - search = PatternSearch( - bound, - args, - initial_population=30, - initial_population_strategy=InitialPopulationStrategy.FROM_BEST_AVAILABLE, - best_available_pad_random=False, - ) - configs = [ - search.config_gen.unflatten(flat) - for flat in search._generate_initial_population_flat() - ] - self.assertEqual(len(configs), 2) - self.assertEqual(configs[0].config["tcgen05_cluster_m"], 1) - assert_seeded(configs) - - @onlyBackends(["cute"]) - def test_cute_tcgen05_two_cta_seed_indexing_matches_live_spec(self) -> None: - @helion.kernel(backend="cute") - def cute_matmul_mma_epilogue( - x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor - ) -> torch.Tensor: - m, k = x.size() - _, n = y.size() - out = torch.empty([m, n], dtype=x.dtype, device=x.device) - for tile_m, tile_n in hl.tile([m, n]): - acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) - for tile_k in hl.tile(k): - acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) - out[tile_m, tile_n] = (acc + bias[tile_n]).to(x.dtype) - return out - - args = ( - torch.empty([4096, 4096], device=DEVICE, dtype=HALF_DTYPE), - torch.empty([4096, 4096], device=DEVICE, dtype=HALF_DTYPE), - torch.empty([4096], device=DEVICE, dtype=HALF_DTYPE), - ) - with patch_cute_mma_support(): - bound = cute_matmul_mma_epilogue.bind(args) - self.assertGreater(bound.config_spec.indexing.length, 3) - - configs = bound.config_spec.create_config_generation().random_population(2) - seeded = [ - config.config - for config in configs - if config.config["tcgen05_cluster_m"] == 2 - ] - self.assertEqual(len(seeded), 1) - self.assertEqual( - len(seeded[0]["indexing"]), - bound.config_spec.indexing.length, - ) - @onlyBackends(["cute"]) def test_cute_tcgen05_two_cta_projection_falls_back_before_mutation( self, diff --git a/test/test_seed_heuristics.py b/test/test_seed_heuristics.py new file mode 100644 index 0000000000..c2da37ad7f --- /dev/null +++ b/test/test_seed_heuristics.py @@ -0,0 +1,701 @@ +from __future__ import annotations + +from unittest.mock import MagicMock +from unittest.mock import patch + +import torch + +import helion +from helion._compiler.backend import TritonBackend +from helion._compiler.cute.tcgen05_constants import TCGEN05_TWO_CTA_BLOCK_M +from helion._compiler.cute.tcgen05_constants import TCGEN05_TWO_CTA_BLOCK_N +from helion._compiler.cute.tcgen05_constants import TCGEN05_TWO_CTA_SEED_L2_GROUPING +from helion._compiler.seed_heuristics import compiler_seed_configs +from helion._compiler.seed_heuristics.cute import CuteTcgen05ClusterM2Heuristic +from helion._compiler.seed_heuristics.triton import TritonSkinnyGemmHeuristic +from helion._hardware import HardwareInfo +from helion._testing import DEVICE +from helion._testing import HALF_DTYPE +from helion._testing import TestCase +from helion._testing import default_cute_mma_support +from helion._testing import onlyBackends +from helion._testing import patch_cute_mma_support +from helion.autotuner.config_spec import BlockSizeSpec +from helion.autotuner.config_spec import ConfigSpec +from helion.autotuner.config_spec import MatmulFact +from helion.autotuner.pattern_search import InitialPopulationStrategy +from helion.autotuner.pattern_search import PatternSearch +import helion.language as hl + +HOPPER_HARDWARE = HardwareInfo( + device_kind="cuda", + hardware_name="NVIDIA H100", + runtime_version="12.8", + compute_capability="sm90", +) +MI350_HARDWARE = HardwareInfo( + device_kind="rocm", + hardware_name="AMD MI350", + runtime_version="7.0", + compute_capability="gfx950", +) + + +class TestMatmulFacts(TestCase): + @onlyBackends(["triton"]) + def test_matmul_facts_record_kernel_structure(self) -> None: + @helion.kernel(backend="triton") + def triton_matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + m, k = x.size() + _, n = y.size() + out = torch.empty([m, n], dtype=x.dtype, device=x.device) + for tile_m, tile_n in hl.tile([m, n]): + acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + for tile_k in hl.tile(k): + acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + out[tile_m, tile_n] = acc.to(x.dtype) + return out + + @helion.kernel(backend="triton") + def triton_matmul_epilogue( + x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor + ) -> torch.Tensor: + m, k = x.size() + _, n = y.size() + out = torch.empty([m, n], dtype=x.dtype, device=x.device) + for tile_m, tile_n in hl.tile([m, n]): + acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + for tile_k in hl.tile(k): + acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + out[tile_m, tile_n] = (acc + bias[tile_n]).to(x.dtype) + return out + + @helion.kernel(backend="triton") + def triton_two_matmuls( + x: torch.Tensor, y: torch.Tensor, z: torch.Tensor + ) -> torch.Tensor: + m, k = x.size() + _, n = y.size() + out = torch.empty([m, n], dtype=x.dtype, device=x.device) + for tile_m, tile_n in hl.tile([m, n]): + acc0 = hl.zeros([tile_m, tile_n], dtype=torch.float32) + acc1 = hl.zeros([tile_m, tile_n], dtype=torch.float32) + for tile_k in hl.tile(k): + acc0 = torch.addmm(acc0, x[tile_m, tile_k], y[tile_k, tile_n]) + acc1 = torch.addmm(acc1, x[tile_m, tile_k], z[tile_k, tile_n]) + out[tile_m, tile_n] = (acc0 + acc1).to(x.dtype) + return out + + @helion.kernel(backend="triton") + def triton_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + m = x.size(0) + out = torch.empty_like(x) + for tile_m in hl.tile(m): + out[tile_m] = x[tile_m] + y[tile_m] + return out + + x = torch.empty([1024, 4096], device=DEVICE, dtype=HALF_DTYPE) + y = torch.empty([4096, 8192], device=DEVICE, dtype=HALF_DTYPE) + z = torch.empty([4096, 8192], device=DEVICE, dtype=HALF_DTYPE) + bias = torch.empty([8192], device=DEVICE, dtype=HALF_DTYPE) + add_x = torch.empty([1024], device=DEVICE, dtype=HALF_DTYPE) + add_y = torch.empty([1024], device=DEVICE, dtype=HALF_DTYPE) + + cases = ( + ("gemm", triton_matmul, (x, y), 1), + ("gemm_epilogue", triton_matmul_epilogue, (x, y, bias), 1), + ("gemm_gemm", triton_two_matmuls, (x, y, z), 2), + ("add", triton_add, (add_x, add_y), 0), + ) + + for name, kernel, args, expected_facts in cases: + with ( + self.subTest(name=name), + patch( + "helion._hardware.get_hardware_info", + return_value=HOPPER_HARDWARE, + ), + ): + bound = kernel.bind(args) + + self.assertEqual(len(bound.config_spec.matmul_facts), expected_facts) + if expected_facts == 0: + self.assertEqual(bound.config_spec.compiler_seed_configs, []) + self.assertEqual(bound.config_spec.compiler_seed_heuristics, []) + for fact in bound.config_spec.matmul_facts: + self.assertEqual(fact.lhs_ndim, 2) + self.assertEqual(fact.rhs_ndim, 2) + self.assertEqual( + (fact.static_m, fact.static_n, fact.static_k), + (1024, 8192, 4096), + ) + self.assertIsNotNone(fact.m_block_id) + self.assertIsNotNone(fact.n_block_id) + self.assertIsNotNone(fact.k_block_id) + self.assertEqual(fact.lhs_dtype, HALF_DTYPE) + self.assertEqual(fact.rhs_dtype, HALF_DTYPE) + + +class TestTritonSkinnyGemmHeuristic(TestCase): + def _make_triton_env_with_block_sizes( + self, + m_max: int = 8192, + n_max: int = 8192, + k_max: int = 8192, + ) -> MagicMock: + spec = ConfigSpec(backend=TritonBackend()) + spec.block_sizes.append(BlockSizeSpec(block_id=0, size_hint=m_max)) + spec.block_sizes.append(BlockSizeSpec(block_id=1, size_hint=n_max)) + spec.block_sizes.append(BlockSizeSpec(block_id=2, size_hint=k_max)) + env = MagicMock() + env.backend_name = "triton" + env.config_spec = spec + env.device = DEVICE + return env + + def _matmul_fact( + self, + static_m: int = 1024, + static_n: int = 8192, + static_k: int = 4096, + *, + lhs_ndim: int = 2, + rhs_ndim: int = 2, + m_block_id: int | None = 0, + n_block_id: int | None = 1, + k_block_id: int | None = 2, + ) -> MatmulFact: + return MatmulFact( + lhs_ndim=lhs_ndim, + rhs_ndim=rhs_ndim, + m_block_id=m_block_id, + n_block_id=n_block_id, + k_block_id=k_block_id, + static_m=static_m, + static_n=static_n, + static_k=static_k, + lhs_dtype=HALF_DTYPE, + rhs_dtype=HALF_DTYPE, + ) + + def test_triton_skinny_gemm_seed_fact_default_empty(self) -> None: + env = self._make_triton_env_with_block_sizes() + heuristic = TritonSkinnyGemmHeuristic + + self.assertEqual(env.config_spec.matmul_facts, []) + self.assertFalse(heuristic.is_eligible(env, MagicMock())) + + def test_triton_skinny_gemm_seed_surfaces_through_compiler_seed_configs( + self, + ) -> None: + env = self._make_triton_env_with_block_sizes() + env.config_spec.matmul_facts.append(self._matmul_fact()) + + class DuplicateTritonSkinnyGemmHeuristic(TritonSkinnyGemmHeuristic): + name = "triton_skinny_gemm_duplicate" + + duplicate_heuristics = ( + TritonSkinnyGemmHeuristic, + DuplicateTritonSkinnyGemmHeuristic, + ) + for hardware in (HOPPER_HARDWARE, MI350_HARDWARE): + with ( + self.subTest(hardware=hardware.compute_capability), + patch( + "helion._hardware.get_hardware_info", + return_value=hardware, + ), + patch( + "helion._compiler.seed_heuristics.HEURISTICS_BY_BACKEND", + {"triton": duplicate_heuristics}, + ), + ): + configs = compiler_seed_configs(env, MagicMock()) + + self.assertEqual( + [config.config["block_sizes"] for config in configs], + [[64, 64, 256]], + ) + self.assertEqual( + env.config_spec.compiler_seed_heuristics, + [ + TritonSkinnyGemmHeuristic.name, + DuplicateTritonSkinnyGemmHeuristic.name, + ], + ) + + def test_triton_skinny_gemm_seed_skinny_n_returns_target_blocks(self) -> None: + env = self._make_triton_env_with_block_sizes( + m_max=1024, + n_max=8192, + k_max=4096, + ) + env.config_spec.matmul_facts.append( + self._matmul_fact(static_m=1024, static_n=8192, static_k=4096) + ) + + config = TritonSkinnyGemmHeuristic.get_config(env, MagicMock()) + + self.assertEqual(config.config["block_sizes"], [64, 64, 256]) + + def test_triton_skinny_gemm_seed_skinny_m_returns_target_blocks(self) -> None: + env = self._make_triton_env_with_block_sizes( + m_max=8192, + n_max=1024, + k_max=4096, + ) + env.config_spec.matmul_facts.append( + self._matmul_fact(static_m=8192, static_n=1024, static_k=4096) + ) + + config = TritonSkinnyGemmHeuristic.get_config(env, MagicMock()) + + self.assertEqual(config.config["block_sizes"], [64, 64, 256]) + + def test_triton_skinny_gemm_seed_caps_at_static_dim(self) -> None: + env = self._make_triton_env_with_block_sizes( + m_max=16, + n_max=8192, + k_max=128, + ) + env.config_spec.matmul_facts.append( + self._matmul_fact(static_m=16, static_n=8192, static_k=128) + ) + + config = TritonSkinnyGemmHeuristic.get_config(env, MagicMock()) + + self.assertEqual(config.config["block_sizes"], [16, 64, 128]) + + def test_triton_skinny_gemm_seed_returns_none_when_floor_violated(self) -> None: + env = self._make_triton_env_with_block_sizes( + m_max=1024, + n_max=8192, + k_max=4096, + ) + env.config_spec.block_sizes.block_id_lookup(0).autotuner_min = 256 + env.config_spec.matmul_facts.append( + self._matmul_fact(static_m=1024, static_n=8192, static_k=4096) + ) + + with patch( + "helion._hardware.get_hardware_info", + return_value=HOPPER_HARDWARE, + ): + self.assertFalse( + TritonSkinnyGemmHeuristic.is_eligible(env, MagicMock()) + ) + + def test_triton_skinny_gemm_seed_returns_none_when_block_id_missing(self) -> None: + env = self._make_triton_env_with_block_sizes() + env.config_spec.matmul_facts.append(self._matmul_fact(m_block_id=None)) + + with patch( + "helion._hardware.get_hardware_info", + return_value=HOPPER_HARDWARE, + ): + self.assertFalse( + TritonSkinnyGemmHeuristic.is_eligible(env, MagicMock()) + ) + + def test_triton_skinny_gemm_seed_requires_single_matmul_fact(self) -> None: + env = self._make_triton_env_with_block_sizes() + env.config_spec.matmul_facts.append(self._matmul_fact()) + env.config_spec.matmul_facts.append(self._matmul_fact()) + + with patch( + "helion._hardware.get_hardware_info", + return_value=HOPPER_HARDWARE, + ): + self.assertFalse( + TritonSkinnyGemmHeuristic.is_eligible(env, MagicMock()) + ) + self.assertEqual(compiler_seed_configs(env, MagicMock()), []) + self.assertEqual(env.config_spec.compiler_seed_heuristics, []) + + def test_triton_skinny_gemm_seed_rejects_batched_matmul_fact(self) -> None: + env = self._make_triton_env_with_block_sizes() + env.config_spec.matmul_facts.append( + self._matmul_fact(lhs_ndim=3, rhs_ndim=3) + ) + + with patch( + "helion._hardware.get_hardware_info", + return_value=HOPPER_HARDWARE, + ): + self.assertFalse( + TritonSkinnyGemmHeuristic.is_eligible(env, MagicMock()) + ) + self.assertEqual(compiler_seed_configs(env, MagicMock()), []) + self.assertEqual(env.config_spec.compiler_seed_heuristics, []) + + @onlyBackends(["triton"]) + def test_triton_skinny_gemm_seed_in_initial_population(self) -> None: + @helion.kernel(backend="triton") + def triton_matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + m, k = x.size() + _, n = y.size() + out = torch.empty([m, n], dtype=x.dtype, device=x.device) + for tile_m, tile_n in hl.tile([m, n]): + acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + for tile_k in hl.tile(k): + acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + out[tile_m, tile_n] = acc.to(x.dtype) + return out + + @helion.kernel(backend="triton") + def triton_matmul_epilogue( + x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor + ) -> torch.Tensor: + m, k = x.size() + _, n = y.size() + out = torch.empty([m, n], dtype=x.dtype, device=x.device) + for tile_m, tile_n in hl.tile([m, n]): + acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + for tile_k in hl.tile(k): + acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + out[tile_m, tile_n] = (acc + bias[tile_n]).to(x.dtype) + return out + + @helion.kernel(backend="triton") + def triton_two_matmuls( + x: torch.Tensor, y: torch.Tensor, z: torch.Tensor + ) -> torch.Tensor: + m, k = x.size() + _, n = y.size() + out = torch.empty([m, n], dtype=x.dtype, device=x.device) + for tile_m, tile_n in hl.tile([m, n]): + acc0 = hl.zeros([tile_m, tile_n], dtype=torch.float32) + acc1 = hl.zeros([tile_m, tile_n], dtype=torch.float32) + for tile_k in hl.tile(k): + acc0 = torch.addmm(acc0, x[tile_m, tile_k], y[tile_k, tile_n]) + acc1 = torch.addmm(acc1, x[tile_m, tile_k], z[tile_k, tile_n]) + out[tile_m, tile_n] = (acc0 + acc1).to(x.dtype) + return out + + x = torch.empty([1024, 4096], device=DEVICE, dtype=HALF_DTYPE) + y = torch.empty([4096, 8192], device=DEVICE, dtype=HALF_DTYPE) + z = torch.empty([4096, 8192], device=DEVICE, dtype=HALF_DTYPE) + bias = torch.empty([8192], device=DEVICE, dtype=HALF_DTYPE) + cases = ( + ("gemm", triton_matmul, (x, y), True), + ("gemm_epilogue", triton_matmul_epilogue, (x, y, bias), True), + ("gemm_gemm", triton_two_matmuls, (x, y, z), False), + ) + seed_block_sizes = [64, 64, 256] + + def assert_skinny_gemm_seeded(configs: list[helion.Config]) -> None: + self.assertIn( + seed_block_sizes, + [config.config["block_sizes"] for config in configs], + ) + + for name, kernel, args, expect_seed in cases: + with ( + self.subTest(name=name), + patch( + "helion._hardware.get_hardware_info", + return_value=HOPPER_HARDWARE, + ), + ): + bound = kernel.bind(args) + heuristic = TritonSkinnyGemmHeuristic + + config_gen = bound.config_spec.create_config_generation() + compiler_seed_block_sizes = [ + config.config["block_sizes"] + for config in bound.config_spec.compiler_seed_configs + ] + seed_pair_block_sizes = [ + config.config["block_sizes"] + for _flat, config in config_gen.seed_flat_config_pairs() + ] + + if expect_seed: + self.assertIn( + TritonSkinnyGemmHeuristic.name, + bound.config_spec.compiler_seed_heuristics, + ) + self.assertTrue( + heuristic.is_eligible(bound.env, bound.host_function.device_ir) + ) + self.assertEqual( + heuristic.get_config( + bound.env, bound.host_function.device_ir + ).config["block_sizes"], + seed_block_sizes, + ) + self.assertIn(seed_block_sizes, compiler_seed_block_sizes) + self.assertIn(seed_block_sizes, seed_pair_block_sizes) + + zero_flat = config_gen.random_population_flat(0) + self.assertEqual(len(zero_flat), 1) + zero_config = config_gen.unflatten(zero_flat[0]) + self.assertNotEqual( + zero_config.config["block_sizes"], + seed_block_sizes, + ) + one_flat = config_gen.random_population_flat(1) + self.assertEqual(len(one_flat), 1) + one_config = config_gen.unflatten(one_flat[0]) + self.assertNotEqual( + one_config.config["block_sizes"], + seed_block_sizes, + ) + one_config_population = config_gen.random_population(1) + self.assertEqual(len(one_config_population), 1) + self.assertNotEqual( + one_config_population[0].config["block_sizes"], + seed_block_sizes, + ) + assert_skinny_gemm_seeded(config_gen.random_population(2)) + + acf_config_gen = bound.config_spec.create_config_generation( + advanced_controls_files=["/tmp/helion-test.acf"] + ) + acf_configs = acf_config_gen.random_population(2) + # Future heuristics may add more compiler seeds; this test + # only requires the skinny GEMM seed to be present. + self.assertGreaterEqual(len(acf_configs), 2) + self.assertEqual( + {config.config["advanced_controls_file"] for config in acf_configs}, + {"/tmp/helion-test.acf"}, + ) + assert_skinny_gemm_seeded(acf_configs) + + with patch.object( + PatternSearch, "_find_similar_cached_configs", return_value=[] + ): + search = PatternSearch( + bound, + args, + initial_population=30, + initial_population_strategy=InitialPopulationStrategy.FROM_BEST_AVAILABLE, + best_available_pad_random=False, + ) + configs = [ + search.config_gen.unflatten(flat) + for flat in search._generate_initial_population_flat() + ] + # Future heuristics may add more compiler seeds; this test + # only requires the skinny GEMM seed to be present. + self.assertGreaterEqual(len(configs), 2) + self.assertNotEqual( + configs[0].config["block_sizes"], + seed_block_sizes, + ) + assert_skinny_gemm_seeded(configs) + else: + self.assertFalse( + heuristic.is_eligible(bound.env, bound.host_function.device_ir) + ) + self.assertNotIn( + TritonSkinnyGemmHeuristic.name, + bound.config_spec.compiler_seed_heuristics, + ) + with patch.object( + PatternSearch, "_find_similar_cached_configs", return_value=[] + ): + search = PatternSearch( + bound, + args, + initial_population=30, + initial_population_strategy=InitialPopulationStrategy.FROM_BEST_AVAILABLE, + best_available_pad_random=False, + ) + configs = [ + search.config_gen.unflatten(flat) + for flat in search._generate_initial_population_flat() + ] + self.assertGreaterEqual(len(configs), 1) + self.assertNotIn( + TritonSkinnyGemmHeuristic.name, + bound.config_spec.compiler_seed_heuristics, + ) + + +class TestCuteTcgen05ClusterM2Heuristic(TestCase): + def _assert_cute_tcgen05_cluster_m2_seeded( + self, + configs: list[helion.Config], + ) -> None: + seeded = [ + config.config + for config in configs + if config.config["tcgen05_cluster_m"] == 2 + ] + self.assertEqual(len(seeded), 1) + seed = seeded[0] + self.assertEqual( + seed["block_sizes"][:3], + [TCGEN05_TWO_CTA_BLOCK_M, TCGEN05_TWO_CTA_BLOCK_N, 128], + ) + self.assertEqual( + seed["indexing"], + ["tensor_descriptor", "tensor_descriptor", "tensor_descriptor"], + ) + self.assertEqual(seed["l2_groupings"], [TCGEN05_TWO_CTA_SEED_L2_GROUPING]) + self.assertEqual(seed["pid_type"], "persistent_interleaved") + self.assertEqual(seed["tcgen05_num_epi_warps"], 4) + + @onlyBackends(["cute"]) + def test_cute_tcgen05_cluster_m2_seed_heuristic(self) -> None: + @helion.kernel(backend="cute") + def cute_matmul_mma(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + m, k = x.size() + _, n = y.size() + out = torch.empty([m, n], dtype=x.dtype, device=x.device) + for tile_m, tile_n in hl.tile([m, n]): + acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + for tile_k in hl.tile(k): + acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + out[tile_m, tile_n] = acc.to(x.dtype) + return out + + args = ( + torch.empty([4096, 4096], device=DEVICE, dtype=HALF_DTYPE), + torch.empty([4096, 4096], device=DEVICE, dtype=HALF_DTYPE), + ) + with patch_cute_mma_support(): + bound = cute_matmul_mma.bind(args) + + heuristic = CuteTcgen05ClusterM2Heuristic + self.assertIn( + CuteTcgen05ClusterM2Heuristic.name, + bound.config_spec.compiler_seed_heuristics, + ) + self.assertTrue( + heuristic.is_eligible(bound.env, bound.host_function.device_ir) + ) + self._assert_cute_tcgen05_cluster_m2_seeded( + [heuristic.get_config(bound.env, bound.host_function.device_ir)], + ) + + with patch_cute_mma_support( + default_cute_mma_support(tcgen05_f16bf16=False) + ): + unsupported_args = ( + torch.empty([2048, 2048], device=DEVICE, dtype=HALF_DTYPE), + torch.empty([2048, 2048], device=DEVICE, dtype=HALF_DTYPE), + ) + unsupported_bound = cute_matmul_mma.bind(unsupported_args) + self.assertFalse( + heuristic.is_eligible( + unsupported_bound.env, + unsupported_bound.host_function.device_ir, + ) + ) + self.assertNotIn( + CuteTcgen05ClusterM2Heuristic.name, + unsupported_bound.config_spec.compiler_seed_heuristics, + ) + + @onlyBackends(["cute"]) + def test_cute_tcgen05_two_cta_seeded_in_initial_populations(self) -> None: + @helion.kernel(backend="cute") + def cute_matmul_mma(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + m, k = x.size() + _, n = y.size() + out = torch.empty([m, n], dtype=x.dtype, device=x.device) + for tile_m, tile_n in hl.tile([m, n]): + acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + for tile_k in hl.tile(k): + acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + out[tile_m, tile_n] = acc.to(x.dtype) + return out + + args = ( + torch.empty([4096, 4096], device=DEVICE, dtype=HALF_DTYPE), + torch.empty([4096, 4096], device=DEVICE, dtype=HALF_DTYPE), + ) + with patch_cute_mma_support(): + bound = cute_matmul_mma.bind(args) + self.assertIn( + CuteTcgen05ClusterM2Heuristic.name, + bound.config_spec.compiler_seed_heuristics, + ) + + config_gen = bound.config_spec.create_config_generation() + zero_flat = config_gen.random_population_flat(0) + self.assertEqual(len(zero_flat), 1) + zero_config = config_gen.unflatten(zero_flat[0]) + self.assertEqual(zero_config.config["tcgen05_cluster_m"], 1) + one_flat = config_gen.random_population_flat(1) + self.assertEqual(len(one_flat), 1) + one_config = config_gen.unflatten(one_flat[0]) + self.assertEqual(one_config.config["tcgen05_cluster_m"], 1) + one_config_population = config_gen.random_population(1) + self.assertEqual(len(one_config_population), 1) + self.assertEqual(one_config_population[0].config["tcgen05_cluster_m"], 1) + self._assert_cute_tcgen05_cluster_m2_seeded( + config_gen.random_population(2) + ) + + acf_config_gen = bound.config_spec.create_config_generation( + advanced_controls_files=["/tmp/helion-test.acf"] + ) + acf_configs = acf_config_gen.random_population(2) + # Future heuristics may add more compiler seeds; this test only + # requires the CuTe cluster-m2 seed to be present. + self.assertGreaterEqual(len(acf_configs), 2) + self.assertEqual( + {config.config["advanced_controls_file"] for config in acf_configs}, + {"/tmp/helion-test.acf"}, + ) + self._assert_cute_tcgen05_cluster_m2_seeded(acf_configs) + + with patch.object( + PatternSearch, "_find_similar_cached_configs", return_value=[] + ): + search = PatternSearch( + bound, + args, + initial_population=30, + initial_population_strategy=InitialPopulationStrategy.FROM_BEST_AVAILABLE, + best_available_pad_random=False, + ) + configs = [ + search.config_gen.unflatten(flat) + for flat in search._generate_initial_population_flat() + ] + # Future heuristics may add more compiler seeds; this test only + # requires the CuTe cluster-m2 seed to be present. + self.assertGreaterEqual(len(configs), 2) + self.assertEqual(configs[0].config["tcgen05_cluster_m"], 1) + self._assert_cute_tcgen05_cluster_m2_seeded(configs) + + @onlyBackends(["cute"]) + def test_cute_tcgen05_two_cta_seed_indexing_matches_live_spec(self) -> None: + @helion.kernel(backend="cute") + def cute_matmul_mma_epilogue( + x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor + ) -> torch.Tensor: + m, k = x.size() + _, n = y.size() + out = torch.empty([m, n], dtype=x.dtype, device=x.device) + for tile_m, tile_n in hl.tile([m, n]): + acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + for tile_k in hl.tile(k): + acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + out[tile_m, tile_n] = (acc + bias[tile_n]).to(x.dtype) + return out + + args = ( + torch.empty([4096, 4096], device=DEVICE, dtype=HALF_DTYPE), + torch.empty([4096, 4096], device=DEVICE, dtype=HALF_DTYPE), + torch.empty([4096], device=DEVICE, dtype=HALF_DTYPE), + ) + with patch_cute_mma_support(): + bound = cute_matmul_mma_epilogue.bind(args) + self.assertGreater(bound.config_spec.indexing.length, 3) + + configs = bound.config_spec.create_config_generation().random_population(2) + seeded = [ + config.config + for config in configs + if config.config["tcgen05_cluster_m"] == 2 + ] + self.assertEqual(len(seeded), 1) + self.assertEqual( + len(seeded[0]["indexing"]), + bound.config_spec.indexing.length, + ) From e9612a7957017e12829fb7e0f1ab3249e7733808 Mon Sep 17 00:00:00 2001 From: eche Date: Sat, 9 May 2026 19:38:11 -0700 Subject: [PATCH 2/9] Make compiler seed heuristics best effort --- helion/_compiler/seed_heuristics/__init__.py | 20 +++- helion/_compiler/seed_heuristics/common.py | 8 +- helion/autotuner/base_search.py | 4 +- helion/autotuner/config_generation.py | 24 +++-- test/test_seed_heuristics.py | 107 ++++++++++++++----- 5 files changed, 119 insertions(+), 44 deletions(-) diff --git a/helion/_compiler/seed_heuristics/__init__.py b/helion/_compiler/seed_heuristics/__init__.py index d525f76503..ed5dc77428 100644 --- a/helion/_compiler/seed_heuristics/__init__.py +++ b/helion/_compiler/seed_heuristics/__init__.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging from typing import TYPE_CHECKING from .common import dedupe_configs @@ -18,6 +19,8 @@ "triton": (TritonSkinnyGemmHeuristic,), } +log: logging.Logger = logging.getLogger(__name__) + def get_heuristics(backend: str) -> tuple[SeedHeuristicType, ...]: return HEURISTICS_BY_BACKEND.get(backend, ()) @@ -30,12 +33,19 @@ def compiler_seed_configs( configs: list[Config] = [] env.config_spec.compiler_seed_heuristics = [] for heuristic in get_heuristics(env.backend_name): - if not heuristic.is_eligible(env, device_ir): + try: + if not heuristic.is_eligible(env, device_ir): + continue + + config = heuristic.get_config(env, device_ir) + except Exception as e: + log.debug( + "Compiler seed heuristic %s failed: %s", + heuristic.name, + e, + exc_info=True, + ) continue - - # If the heuristic is eligible, we must get a valid config - # We add the heuristic name to list of applied heuristics - config = heuristic.get_config(env, device_ir) configs.append(config) env.config_spec.compiler_seed_heuristics.append(heuristic.name) return dedupe_configs(configs) diff --git a/helion/_compiler/seed_heuristics/common.py b/helion/_compiler/seed_heuristics/common.py index 473945f16f..4567f23772 100644 --- a/helion/_compiler/seed_heuristics/common.py +++ b/helion/_compiler/seed_heuristics/common.py @@ -31,10 +31,10 @@ def matches_hardware( from ..._hardware import get_hardware_info hardware = get_hardware_info(env.device) - return ( - (hardware.device_kind, hardware.compute_capability) in targets - or (hardware.device_kind, None) in targets - ) + return (hardware.device_kind, hardware.compute_capability) in targets or ( + hardware.device_kind, + None, + ) in targets def clamp_block_size_targets( diff --git a/helion/autotuner/base_search.py b/helion/autotuner/base_search.py index e02bfd3ad8..215f2afaae 100644 --- a/helion/autotuner/base_search.py +++ b/helion/autotuner/base_search.py @@ -797,7 +797,9 @@ def _generate_best_available_population_flat(self) -> list[FlatConfig]: # Compiler-owned seeds come from ConfigSpec.compiler_seed_configs; # they encode backend/compiler heuristics and complement user seed configs. - for flat, transferred_config in self.config_gen.seed_flat_config_pairs(): + for flat, transferred_config in self.config_gen.seed_flat_config_pairs( + self.log + ): if transferred_config not in seen: seen.add(transferred_config) result.append(flat) diff --git a/helion/autotuner/config_generation.py b/helion/autotuner/config_generation.py index 5ca13a431a..a842ef103d 100644 --- a/helion/autotuner/config_generation.py +++ b/helion/autotuner/config_generation.py @@ -381,21 +381,31 @@ def default_flat(self) -> FlatConfig: self._repair_cute_num_threads(config) return config - def seed_flat_config_pairs(self) -> list[tuple[FlatConfig, Config]]: + def seed_flat_config_pairs( + self, + log_func: Callable[[str], None] | None = None, + ) -> list[tuple[FlatConfig, Config]]: """Return ConfigSpec-provided seeds as flat and normalized configs. ``ConfigSpec.compiler_seed_configs`` is compiler-owned and must - contain configs that match the live spec structurally. ``InvalidConfig`` - means overrides make a seed inapplicable; other flatten/unflatten - exceptions are programming errors and intentionally surface. + contain configs that match the live spec structurally. Invalid seeds + are skipped with the same transfer policy as user-provided seed configs. """ result: list[tuple[FlatConfig, Config]] = [] seen: set[Config] = set() - for config in self.config_spec.compiler_seed_configs: + for i, config in enumerate(self.config_spec.compiler_seed_configs): try: flat = self.flatten(config) normalized = self.unflatten(flat) - except InvalidConfig: + except ( + InvalidConfig, + ValueError, + TypeError, + KeyError, + AssertionError, + ) as e: + if log_func is not None: + log_func(f"Failed to transfer compiler seed config {i + 1}: {e}") continue if normalized in seen: continue @@ -483,7 +493,7 @@ def random_population_flat( if len(result) >= n: return result[:n] - for flat, _config in self.seed_flat_config_pairs(): + for flat, _config in self.seed_flat_config_pairs(log_func): if any(flat == existing for existing in result): continue result.append(flat) diff --git a/test/test_seed_heuristics.py b/test/test_seed_heuristics.py index c2da37ad7f..a2f709631b 100644 --- a/test/test_seed_heuristics.py +++ b/test/test_seed_heuristics.py @@ -12,6 +12,7 @@ from helion._compiler.cute.tcgen05_constants import TCGEN05_TWO_CTA_SEED_L2_GROUPING from helion._compiler.seed_heuristics import compiler_seed_configs from helion._compiler.seed_heuristics.cute import CuteTcgen05ClusterM2Heuristic +from helion._compiler.seed_heuristics.registry import SeedHeuristic from helion._compiler.seed_heuristics.triton import TritonSkinnyGemmHeuristic from helion._hardware import HardwareInfo from helion._testing import DEVICE @@ -41,6 +42,74 @@ ) +class TestSeedHeuristic(TestCase): + def test_compiler_seed_configs_skips_heuristic_get_config_error(self) -> None: + class FailingSeedHeuristic(SeedHeuristic): + name = "failing_seed_heuristic" + backend = "triton" + + @classmethod + def is_eligible(cls, env: object, device_ir: object) -> bool: + return True + + @classmethod + def get_config(cls, env: object, device_ir: object) -> helion.Config: + raise RuntimeError("synthetic compiler seed failure") + + class ValidSeedHeuristic(SeedHeuristic): + name = "valid_seed_heuristic" + backend = "triton" + + @classmethod + def is_eligible(cls, env: object, device_ir: object) -> bool: + return True + + @classmethod + def get_config(cls, env: object, device_ir: object) -> helion.Config: + return helion.Config(block_sizes=[64]) + + env = MagicMock() + env.backend_name = "triton" + env.config_spec = MagicMock() + heuristics = (FailingSeedHeuristic, ValidSeedHeuristic) + + with ( + self.assertLogs("helion._compiler.seed_heuristics", level="DEBUG") as logs, + patch( + "helion._compiler.seed_heuristics.HEURISTICS_BY_BACKEND", + {"triton": heuristics}, + ), + ): + configs = compiler_seed_configs(env, MagicMock()) + + self.assertEqual([config.config for config in configs], [{"block_sizes": [64]}]) + self.assertEqual( + env.config_spec.compiler_seed_heuristics, + [ValidSeedHeuristic.name], + ) + self.assertIn(FailingSeedHeuristic.name, "\n".join(logs.output)) + self.assertIn("synthetic compiler seed failure", "\n".join(logs.output)) + + def test_seed_flat_config_pairs_skips_invalid_compiler_seed(self) -> None: + spec = ConfigSpec(backend=TritonBackend()) + spec.block_sizes.append(BlockSizeSpec(block_id=0, size_hint=1024)) + spec.compiler_seed_configs = [ + helion.Config(block_sizes=["invalid"]), + helion.Config(block_sizes=[64]), + ] + config_gen = spec.create_config_generation() + messages: list[str] = [] + + pairs = config_gen.seed_flat_config_pairs(messages.append) + + self.assertEqual( + [config.config["block_sizes"] for _flat, config in pairs], + [[64]], + ) + self.assertEqual(len(messages), 1) + self.assertIn("Failed to transfer compiler seed config 1", messages[0]) + + class TestMatmulFacts(TestCase): @onlyBackends(["triton"]) def test_matmul_facts_record_kernel_structure(self) -> None: @@ -281,9 +350,7 @@ def test_triton_skinny_gemm_seed_returns_none_when_floor_violated(self) -> None: "helion._hardware.get_hardware_info", return_value=HOPPER_HARDWARE, ): - self.assertFalse( - TritonSkinnyGemmHeuristic.is_eligible(env, MagicMock()) - ) + self.assertFalse(TritonSkinnyGemmHeuristic.is_eligible(env, MagicMock())) def test_triton_skinny_gemm_seed_returns_none_when_block_id_missing(self) -> None: env = self._make_triton_env_with_block_sizes() @@ -293,9 +360,7 @@ def test_triton_skinny_gemm_seed_returns_none_when_block_id_missing(self) -> Non "helion._hardware.get_hardware_info", return_value=HOPPER_HARDWARE, ): - self.assertFalse( - TritonSkinnyGemmHeuristic.is_eligible(env, MagicMock()) - ) + self.assertFalse(TritonSkinnyGemmHeuristic.is_eligible(env, MagicMock())) def test_triton_skinny_gemm_seed_requires_single_matmul_fact(self) -> None: env = self._make_triton_env_with_block_sizes() @@ -306,25 +371,19 @@ def test_triton_skinny_gemm_seed_requires_single_matmul_fact(self) -> None: "helion._hardware.get_hardware_info", return_value=HOPPER_HARDWARE, ): - self.assertFalse( - TritonSkinnyGemmHeuristic.is_eligible(env, MagicMock()) - ) + self.assertFalse(TritonSkinnyGemmHeuristic.is_eligible(env, MagicMock())) self.assertEqual(compiler_seed_configs(env, MagicMock()), []) self.assertEqual(env.config_spec.compiler_seed_heuristics, []) def test_triton_skinny_gemm_seed_rejects_batched_matmul_fact(self) -> None: env = self._make_triton_env_with_block_sizes() - env.config_spec.matmul_facts.append( - self._matmul_fact(lhs_ndim=3, rhs_ndim=3) - ) + env.config_spec.matmul_facts.append(self._matmul_fact(lhs_ndim=3, rhs_ndim=3)) with patch( "helion._hardware.get_hardware_info", return_value=HOPPER_HARDWARE, ): - self.assertFalse( - TritonSkinnyGemmHeuristic.is_eligible(env, MagicMock()) - ) + self.assertFalse(TritonSkinnyGemmHeuristic.is_eligible(env, MagicMock())) self.assertEqual(compiler_seed_configs(env, MagicMock()), []) self.assertEqual(env.config_spec.compiler_seed_heuristics, []) @@ -457,7 +516,10 @@ def assert_skinny_gemm_seeded(configs: list[helion.Config]) -> None: # only requires the skinny GEMM seed to be present. self.assertGreaterEqual(len(acf_configs), 2) self.assertEqual( - {config.config["advanced_controls_file"] for config in acf_configs}, + { + config.config["advanced_controls_file"] + for config in acf_configs + }, {"/tmp/helion-test.acf"}, ) assert_skinny_gemm_seeded(acf_configs) @@ -479,10 +541,6 @@ def assert_skinny_gemm_seeded(configs: list[helion.Config]) -> None: # Future heuristics may add more compiler seeds; this test # only requires the skinny GEMM seed to be present. self.assertGreaterEqual(len(configs), 2) - self.assertNotEqual( - configs[0].config["block_sizes"], - seed_block_sizes, - ) assert_skinny_gemm_seeded(configs) else: self.assertFalse( @@ -570,9 +628,7 @@ def cute_matmul_mma(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: [heuristic.get_config(bound.env, bound.host_function.device_ir)], ) - with patch_cute_mma_support( - default_cute_mma_support(tcgen05_f16bf16=False) - ): + with patch_cute_mma_support(default_cute_mma_support(tcgen05_f16bf16=False)): unsupported_args = ( torch.empty([2048, 2048], device=DEVICE, dtype=HALF_DTYPE), torch.empty([2048, 2048], device=DEVICE, dtype=HALF_DTYPE), @@ -626,9 +682,7 @@ def cute_matmul_mma(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: one_config_population = config_gen.random_population(1) self.assertEqual(len(one_config_population), 1) self.assertEqual(one_config_population[0].config["tcgen05_cluster_m"], 1) - self._assert_cute_tcgen05_cluster_m2_seeded( - config_gen.random_population(2) - ) + self._assert_cute_tcgen05_cluster_m2_seeded(config_gen.random_population(2)) acf_config_gen = bound.config_spec.create_config_generation( advanced_controls_files=["/tmp/helion-test.acf"] @@ -660,7 +714,6 @@ def cute_matmul_mma(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # Future heuristics may add more compiler seeds; this test only # requires the CuTe cluster-m2 seed to be present. self.assertGreaterEqual(len(configs), 2) - self.assertEqual(configs[0].config["tcgen05_cluster_m"], 1) self._assert_cute_tcgen05_cluster_m2_seeded(configs) @onlyBackends(["cute"]) From a285ccff82e3fcdb17978cd08506a55d93940fa2 Mon Sep 17 00:00:00 2001 From: eche Date: Sat, 9 May 2026 20:02:16 -0700 Subject: [PATCH 3/9] Harden seed heuristic tests without GPU --- test/test_seed_heuristics.py | 44 +++++++++++++++++++++++++++++------- 1 file changed, 36 insertions(+), 8 deletions(-) diff --git a/test/test_seed_heuristics.py b/test/test_seed_heuristics.py index a2f709631b..189ca9cbdb 100644 --- a/test/test_seed_heuristics.py +++ b/test/test_seed_heuristics.py @@ -40,6 +40,12 @@ runtime_version="7.0", compute_capability="gfx950", ) +BLACKWELL_HARDWARE = HardwareInfo( + device_kind="cuda", + hardware_name="NVIDIA B200", + runtime_version="12.8", + compute_capability="sm100", +) class TestSeedHeuristic(TestCase): @@ -252,14 +258,20 @@ def test_triton_skinny_gemm_seed_fact_default_empty(self) -> None: heuristic = TritonSkinnyGemmHeuristic self.assertEqual(env.config_spec.matmul_facts, []) - self.assertFalse(heuristic.is_eligible(env, MagicMock())) + with patch( + "helion._hardware.get_hardware_info", + return_value=HOPPER_HARDWARE, + ): + self.assertFalse(heuristic.is_eligible(env, MagicMock())) - def test_triton_skinny_gemm_seed_surfaces_through_compiler_seed_configs( + def test_triton_skinny_gemm_seed_dedupes_configs_and_records_heuristics( self, ) -> None: env = self._make_triton_env_with_block_sizes() env.config_spec.matmul_facts.append(self._matmul_fact()) + # Same config, different heuristic name: the config should dedupe, but + # both successful heuristic applications should remain visible. class DuplicateTritonSkinnyGemmHeuristic(TritonSkinnyGemmHeuristic): name = "triton_skinny_gemm_duplicate" @@ -267,7 +279,26 @@ class DuplicateTritonSkinnyGemmHeuristic(TritonSkinnyGemmHeuristic): TritonSkinnyGemmHeuristic, DuplicateTritonSkinnyGemmHeuristic, ) - for hardware in (HOPPER_HARDWARE, MI350_HARDWARE): + cases = ( + ( + HOPPER_HARDWARE, + [[64, 64, 256]], + [ + TritonSkinnyGemmHeuristic.name, + DuplicateTritonSkinnyGemmHeuristic.name, + ], + ), + ( + MI350_HARDWARE, + [[64, 64, 256]], + [ + TritonSkinnyGemmHeuristic.name, + DuplicateTritonSkinnyGemmHeuristic.name, + ], + ), + (BLACKWELL_HARDWARE, [], []), + ) + for hardware, expected_block_sizes, expected_heuristics in cases: with ( self.subTest(hardware=hardware.compute_capability), patch( @@ -283,14 +314,11 @@ class DuplicateTritonSkinnyGemmHeuristic(TritonSkinnyGemmHeuristic): self.assertEqual( [config.config["block_sizes"] for config in configs], - [[64, 64, 256]], + expected_block_sizes, ) self.assertEqual( env.config_spec.compiler_seed_heuristics, - [ - TritonSkinnyGemmHeuristic.name, - DuplicateTritonSkinnyGemmHeuristic.name, - ], + expected_heuristics, ) def test_triton_skinny_gemm_seed_skinny_n_returns_target_blocks(self) -> None: From 751327929a5632a2d2eb0b97e8b965b3ea0b77ab Mon Sep 17 00:00:00 2001 From: eche Date: Sat, 9 May 2026 20:05:48 -0700 Subject: [PATCH 4/9] Document Triton skinny GEMM heuristic source --- helion/_compiler/seed_heuristics/triton.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/helion/_compiler/seed_heuristics/triton.py b/helion/_compiler/seed_heuristics/triton.py index e7e23c0f88..e3c26dd1f5 100644 --- a/helion/_compiler/seed_heuristics/triton.py +++ b/helion/_compiler/seed_heuristics/triton.py @@ -12,6 +12,8 @@ from ..device_ir import DeviceIR +# Heuristic was originally contributed by @umechand-amd +# in https://github.com/pytorch/helion/pull/2357. class TritonSkinnyGemmHeuristic(SeedHeuristic): name = "triton_skinny_gemm" backend = "triton" From 8497da31d09f723251e3c2362dbe76cc3ac7e536 Mon Sep 17 00:00:00 2001 From: eche Date: Sat, 9 May 2026 21:10:33 -0700 Subject: [PATCH 5/9] Skip seed heuristic compiler tests in ref eager --- test/test_seed_heuristics.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/test_seed_heuristics.py b/test/test_seed_heuristics.py index 189ca9cbdb..976c91d043 100644 --- a/test/test_seed_heuristics.py +++ b/test/test_seed_heuristics.py @@ -21,6 +21,7 @@ from helion._testing import default_cute_mma_support from helion._testing import onlyBackends from helion._testing import patch_cute_mma_support +from helion._testing import skipIfRefEager from helion.autotuner.config_spec import BlockSizeSpec from helion.autotuner.config_spec import ConfigSpec from helion.autotuner.config_spec import MatmulFact @@ -118,6 +119,7 @@ def test_seed_flat_config_pairs_skips_invalid_compiler_seed(self) -> None: class TestMatmulFacts(TestCase): @onlyBackends(["triton"]) + @skipIfRefEager("Compiler matmul facts are not collected in ref eager mode") def test_matmul_facts_record_kernel_structure(self) -> None: @helion.kernel(backend="triton") def triton_matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: @@ -416,6 +418,7 @@ def test_triton_skinny_gemm_seed_rejects_batched_matmul_fact(self) -> None: self.assertEqual(env.config_spec.compiler_seed_heuristics, []) @onlyBackends(["triton"]) + @skipIfRefEager("Compiler seed configs are not generated in ref eager mode") def test_triton_skinny_gemm_seed_in_initial_population(self) -> None: @helion.kernel(backend="triton") def triton_matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: From 9082997d2ed8a754fab4101c19ff161b2e650914 Mon Sep 17 00:00:00 2001 From: eche Date: Sun, 10 May 2026 22:04:10 -0700 Subject: [PATCH 6/9] Generalize seed heuristics as autotuner heuristics --- .../__init__.py | 16 +- .../common.py | 0 .../cute.py | 8 +- .../registry.py | 12 +- .../triton.py | 6 +- helion/autotuner/config_spec.py | 2 +- helion/runtime/kernel.py | 2 +- ...istics.py => test_autotuner_heuristics.py} | 320 +++++------------- 8 files changed, 115 insertions(+), 251 deletions(-) rename helion/_compiler/{seed_heuristics => autotuner_heuristics}/__init__.py (68%) rename helion/_compiler/{seed_heuristics => autotuner_heuristics}/common.py (100%) rename helion/_compiler/{seed_heuristics => autotuner_heuristics}/cute.py (92%) rename helion/_compiler/{seed_heuristics => autotuner_heuristics}/registry.py (63%) rename helion/_compiler/{seed_heuristics => autotuner_heuristics}/triton.py (93%) rename test/{test_seed_heuristics.py => test_autotuner_heuristics.py} (66%) diff --git a/helion/_compiler/seed_heuristics/__init__.py b/helion/_compiler/autotuner_heuristics/__init__.py similarity index 68% rename from helion/_compiler/seed_heuristics/__init__.py rename to helion/_compiler/autotuner_heuristics/__init__.py index ed5dc77428..9089519ed6 100644 --- a/helion/_compiler/seed_heuristics/__init__.py +++ b/helion/_compiler/autotuner_heuristics/__init__.py @@ -11,10 +11,10 @@ from ...runtime.config import Config from ..compile_environment import CompileEnvironment from ..device_ir import DeviceIR - from .registry import SeedHeuristicType + from .registry import AutotunerHeuristicType # All active heuristics by backend -HEURISTICS_BY_BACKEND: dict[str, tuple[SeedHeuristicType, ...]] = { +HEURISTICS_BY_BACKEND: dict[str, tuple[AutotunerHeuristicType, ...]] = { "cute": (CuteTcgen05ClusterM2Heuristic,), "triton": (TritonSkinnyGemmHeuristic,), } @@ -22,7 +22,7 @@ log: logging.Logger = logging.getLogger(__name__) -def get_heuristics(backend: str) -> tuple[SeedHeuristicType, ...]: +def get_heuristics(backend: str) -> tuple[AutotunerHeuristicType, ...]: return HEURISTICS_BY_BACKEND.get(backend, ()) @@ -31,21 +31,23 @@ def compiler_seed_configs( device_ir: DeviceIR, ) -> list[Config]: configs: list[Config] = [] - env.config_spec.compiler_seed_heuristics = [] + env.config_spec.autotuner_heuristics = [] for heuristic in get_heuristics(env.backend_name): try: if not heuristic.is_eligible(env, device_ir): continue - config = heuristic.get_config(env, device_ir) + config = heuristic.get_seed_config(env, device_ir) except Exception as e: log.debug( - "Compiler seed heuristic %s failed: %s", + "Autotuner heuristic %s failed while generating compiler seed config: %s", heuristic.name, e, exc_info=True, ) continue + if config is None: + continue configs.append(config) - env.config_spec.compiler_seed_heuristics.append(heuristic.name) + env.config_spec.autotuner_heuristics.append(heuristic.name) return dedupe_configs(configs) diff --git a/helion/_compiler/seed_heuristics/common.py b/helion/_compiler/autotuner_heuristics/common.py similarity index 100% rename from helion/_compiler/seed_heuristics/common.py rename to helion/_compiler/autotuner_heuristics/common.py diff --git a/helion/_compiler/seed_heuristics/cute.py b/helion/_compiler/autotuner_heuristics/cute.py similarity index 92% rename from helion/_compiler/seed_heuristics/cute.py rename to helion/_compiler/autotuner_heuristics/cute.py index 43fdf729e0..765aa93f9f 100644 --- a/helion/_compiler/seed_heuristics/cute.py +++ b/helion/_compiler/autotuner_heuristics/cute.py @@ -8,7 +8,7 @@ from ..cute.tcgen05_constants import TCGEN05_TWO_CTA_BLOCK_N from ..cute.tcgen05_constants import TCGEN05_TWO_CTA_SEED_L2_GROUPING from ..cute.tcgen05_constants import TCGEN05_TWO_CTA_SEED_PID_TYPE -from .registry import SeedHeuristic +from .registry import AutotunerHeuristic if TYPE_CHECKING: from ...autotuner.config_fragment import BlockSizeFragment @@ -16,7 +16,7 @@ from ..device_ir import DeviceIR -class CuteTcgen05ClusterM2Heuristic(SeedHeuristic): +class CuteTcgen05ClusterM2Heuristic(AutotunerHeuristic): name = "cute_tcgen05_cluster_m2" backend = "cute" @@ -41,11 +41,11 @@ def is_eligible(cls, env: CompileEnvironment, device_ir: DeviceIR) -> bool: ) @classmethod - def get_config(cls, env: CompileEnvironment, device_ir: DeviceIR) -> Config: + def get_seed_config(cls, env: CompileEnvironment, device_ir: DeviceIR) -> Config: spec = env.config_spec bk = cls._select_bk(env) if bk is None: - raise AssertionError(f"{cls.name} get_config called while ineligible") + raise AssertionError(f"{cls.name} get_seed_config called while ineligible") block_sizes = [ TCGEN05_TWO_CTA_BLOCK_M, diff --git a/helion/_compiler/seed_heuristics/registry.py b/helion/_compiler/autotuner_heuristics/registry.py similarity index 63% rename from helion/_compiler/seed_heuristics/registry.py rename to helion/_compiler/autotuner_heuristics/registry.py index 70b42dc0ce..bd1ee3e1f1 100644 --- a/helion/_compiler/seed_heuristics/registry.py +++ b/helion/_compiler/autotuner_heuristics/registry.py @@ -9,8 +9,8 @@ from ..device_ir import DeviceIR -class SeedHeuristic: - """Base class for compiler-owned autotune seed heuristics.""" +class AutotunerHeuristic: + """Base class for compiler-owned autotuner heuristics.""" name: ClassVar[str] backend: ClassVar[str] @@ -20,8 +20,10 @@ def is_eligible(cls, env: CompileEnvironment, device_ir: DeviceIR) -> bool: raise NotImplementedError @classmethod - def get_config(cls, env: CompileEnvironment, device_ir: DeviceIR) -> Config: - raise NotImplementedError + def get_seed_config( + cls, env: CompileEnvironment, device_ir: DeviceIR + ) -> Config | None: + return None -SeedHeuristicType = type[SeedHeuristic] +AutotunerHeuristicType = type[AutotunerHeuristic] diff --git a/helion/_compiler/seed_heuristics/triton.py b/helion/_compiler/autotuner_heuristics/triton.py similarity index 93% rename from helion/_compiler/seed_heuristics/triton.py rename to helion/_compiler/autotuner_heuristics/triton.py index e3c26dd1f5..871c6321eb 100644 --- a/helion/_compiler/seed_heuristics/triton.py +++ b/helion/_compiler/autotuner_heuristics/triton.py @@ -5,7 +5,7 @@ from ...runtime.config import Config from .common import clamp_block_size_targets from .common import matches_hardware -from .registry import SeedHeuristic +from .registry import AutotunerHeuristic if TYPE_CHECKING: from ..compile_environment import CompileEnvironment @@ -14,7 +14,7 @@ # Heuristic was originally contributed by @umechand-amd # in https://github.com/pytorch/helion/pull/2357. -class TritonSkinnyGemmHeuristic(SeedHeuristic): +class TritonSkinnyGemmHeuristic(AutotunerHeuristic): name = "triton_skinny_gemm" backend = "triton" MIN_ASPECT_RATIO = 8 @@ -57,7 +57,7 @@ def is_eligible(cls, env: CompileEnvironment, device_ir: DeviceIR) -> bool: ) @classmethod - def get_config(cls, env: CompileEnvironment, device_ir: DeviceIR) -> Config: + def get_seed_config(cls, env: CompileEnvironment, device_ir: DeviceIR) -> Config: assert len(env.config_spec.matmul_facts) == 1 fact = env.config_spec.matmul_facts[0] assert fact.static_m is not None diff --git a/helion/autotuner/config_spec.py b/helion/autotuner/config_spec.py index bd7804d2df..d67db8f093 100644 --- a/helion/autotuner/config_spec.py +++ b/helion/autotuner/config_spec.py @@ -414,7 +414,7 @@ def __init__( # explicit config, so normalize() must reject the unsafe values. self._tcgen05_num_epi_warps_validation_choices: tuple[int, ...] | None = None self.compiler_seed_configs: list[helion.Config] = [] - self.compiler_seed_heuristics: list[str] = [] + self.autotuner_heuristics: list[str] = [] self.matmul_facts: list[MatmulFact] = [] self.store_indices: list[int] = [] self.backend_tunable_fragments = self.backend.tunable_fragments() diff --git a/helion/runtime/kernel.py b/helion/runtime/kernel.py index 4d3db3d014..0002d27a40 100644 --- a/helion/runtime/kernel.py +++ b/helion/runtime/kernel.py @@ -40,6 +40,7 @@ from .. import exc from .._compile_time import measure from .._compiler.ast_extension import unparse +from .._compiler.autotuner_heuristics import compiler_seed_configs from .._compiler.backend import TritonBackend from .._compiler.compile_environment import CompileEnvironment from .._compiler.compile_environment import TensorDescriptorLayoutGuard @@ -50,7 +51,6 @@ from .._compiler.inductor_lowering_extra import patch_inductor_lowerings from .._compiler.kernel_compiler import KernelCompiler from .._compiler.output_header import assert_no_conflicts -from .._compiler.seed_heuristics import compiler_seed_configs from .._compiler.variable_origin import ArgumentOrigin from .._dist_utils import _find_process_group_name from .._dist_utils import check_config_consistancy as dist_check_config_consistancy diff --git a/test/test_seed_heuristics.py b/test/test_autotuner_heuristics.py similarity index 66% rename from test/test_seed_heuristics.py rename to test/test_autotuner_heuristics.py index 976c91d043..9bfeeef88a 100644 --- a/test/test_seed_heuristics.py +++ b/test/test_autotuner_heuristics.py @@ -6,14 +6,14 @@ import torch import helion +from helion._compiler.autotuner_heuristics import compiler_seed_configs +from helion._compiler.autotuner_heuristics.cute import CuteTcgen05ClusterM2Heuristic +from helion._compiler.autotuner_heuristics.registry import AutotunerHeuristic +from helion._compiler.autotuner_heuristics.triton import TritonSkinnyGemmHeuristic from helion._compiler.backend import TritonBackend from helion._compiler.cute.tcgen05_constants import TCGEN05_TWO_CTA_BLOCK_M from helion._compiler.cute.tcgen05_constants import TCGEN05_TWO_CTA_BLOCK_N from helion._compiler.cute.tcgen05_constants import TCGEN05_TWO_CTA_SEED_L2_GROUPING -from helion._compiler.seed_heuristics import compiler_seed_configs -from helion._compiler.seed_heuristics.cute import CuteTcgen05ClusterM2Heuristic -from helion._compiler.seed_heuristics.registry import SeedHeuristic -from helion._compiler.seed_heuristics.triton import TritonSkinnyGemmHeuristic from helion._hardware import HardwareInfo from helion._testing import DEVICE from helion._testing import HALF_DTYPE @@ -49,10 +49,12 @@ ) -class TestSeedHeuristic(TestCase): - def test_compiler_seed_configs_skips_heuristic_get_config_error(self) -> None: - class FailingSeedHeuristic(SeedHeuristic): - name = "failing_seed_heuristic" +class TestAutotunerHeuristic(TestCase): + def test_compiler_seed_configs_handles_failed_optional_and_duplicate_seeds( + self, + ) -> None: + class FailingAutotunerHeuristic(AutotunerHeuristic): + name = "failing_autotuner_heuristic" backend = "triton" @classmethod @@ -60,11 +62,19 @@ def is_eligible(cls, env: object, device_ir: object) -> bool: return True @classmethod - def get_config(cls, env: object, device_ir: object) -> helion.Config: + def get_seed_config(cls, env: object, device_ir: object) -> helion.Config: raise RuntimeError("synthetic compiler seed failure") - class ValidSeedHeuristic(SeedHeuristic): - name = "valid_seed_heuristic" + class NoSeedAutotunerHeuristic(AutotunerHeuristic): + name = "no_seed_autotuner_heuristic" + backend = "triton" + + @classmethod + def is_eligible(cls, env: object, device_ir: object) -> bool: + return True + + class ValidAutotunerHeuristic(AutotunerHeuristic): + name = "valid_autotuner_heuristic" backend = "triton" @classmethod @@ -72,18 +82,28 @@ def is_eligible(cls, env: object, device_ir: object) -> bool: return True @classmethod - def get_config(cls, env: object, device_ir: object) -> helion.Config: + def get_seed_config(cls, env: object, device_ir: object) -> helion.Config: return helion.Config(block_sizes=[64]) + class DuplicateAutotunerHeuristic(ValidAutotunerHeuristic): + name = "duplicate_autotuner_heuristic" + env = MagicMock() env.backend_name = "triton" env.config_spec = MagicMock() - heuristics = (FailingSeedHeuristic, ValidSeedHeuristic) + heuristics = ( + FailingAutotunerHeuristic, + NoSeedAutotunerHeuristic, + ValidAutotunerHeuristic, + DuplicateAutotunerHeuristic, + ) with ( - self.assertLogs("helion._compiler.seed_heuristics", level="DEBUG") as logs, + self.assertLogs( + "helion._compiler.autotuner_heuristics", level="DEBUG" + ) as logs, patch( - "helion._compiler.seed_heuristics.HEURISTICS_BY_BACKEND", + "helion._compiler.autotuner_heuristics.HEURISTICS_BY_BACKEND", {"triton": heuristics}, ), ): @@ -91,10 +111,10 @@ def get_config(cls, env: object, device_ir: object) -> helion.Config: self.assertEqual([config.config for config in configs], [{"block_sizes": [64]}]) self.assertEqual( - env.config_spec.compiler_seed_heuristics, - [ValidSeedHeuristic.name], + env.config_spec.autotuner_heuristics, + [ValidAutotunerHeuristic.name, DuplicateAutotunerHeuristic.name], ) - self.assertIn(FailingSeedHeuristic.name, "\n".join(logs.output)) + self.assertIn(FailingAutotunerHeuristic.name, "\n".join(logs.output)) self.assertIn("synthetic compiler seed failure", "\n".join(logs.output)) def test_seed_flat_config_pairs_skips_invalid_compiler_seed(self) -> None: @@ -198,7 +218,7 @@ def triton_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: self.assertEqual(len(bound.config_spec.matmul_facts), expected_facts) if expected_facts == 0: self.assertEqual(bound.config_spec.compiler_seed_configs, []) - self.assertEqual(bound.config_spec.compiler_seed_heuristics, []) + self.assertEqual(bound.config_spec.autotuner_heuristics, []) for fact in bound.config_spec.matmul_facts: self.assertEqual(fact.lhs_ndim, 2) self.assertEqual(fact.rhs_ndim, 2) @@ -255,62 +275,55 @@ def _matmul_fact( rhs_dtype=HALF_DTYPE, ) - def test_triton_skinny_gemm_seed_fact_default_empty(self) -> None: - env = self._make_triton_env_with_block_sizes() - heuristic = TritonSkinnyGemmHeuristic - - self.assertEqual(env.config_spec.matmul_facts, []) - with patch( - "helion._hardware.get_hardware_info", - return_value=HOPPER_HARDWARE, - ): - self.assertFalse(heuristic.is_eligible(env, MagicMock())) - - def test_triton_skinny_gemm_seed_dedupes_configs_and_records_heuristics( + def test_triton_skinny_gemm_seed_eligibility_and_config( self, ) -> None: - env = self._make_triton_env_with_block_sizes() - env.config_spec.matmul_facts.append(self._matmul_fact()) - - # Same config, different heuristic name: the config should dedupe, but - # both successful heuristic applications should remain visible. - class DuplicateTritonSkinnyGemmHeuristic(TritonSkinnyGemmHeuristic): - name = "triton_skinny_gemm_duplicate" - - duplicate_heuristics = ( - TritonSkinnyGemmHeuristic, - DuplicateTritonSkinnyGemmHeuristic, - ) cases = ( ( + "hopper", HOPPER_HARDWARE, + [self._matmul_fact()], [[64, 64, 256]], - [ - TritonSkinnyGemmHeuristic.name, - DuplicateTritonSkinnyGemmHeuristic.name, - ], + [TritonSkinnyGemmHeuristic.name], ), ( + "mi350", MI350_HARDWARE, + [self._matmul_fact()], [[64, 64, 256]], - [ - TritonSkinnyGemmHeuristic.name, - DuplicateTritonSkinnyGemmHeuristic.name, - ], + [TritonSkinnyGemmHeuristic.name], + ), + ( + "blackwell", + BLACKWELL_HARDWARE, + [self._matmul_fact()], + [], + [], + ), + ( + "balanced_shape", + HOPPER_HARDWARE, + [self._matmul_fact(static_m=4096, static_n=4096)], + [], + [], + ), + ( + "multiple_matmuls", + HOPPER_HARDWARE, + [self._matmul_fact(), self._matmul_fact()], + [], + [], ), - (BLACKWELL_HARDWARE, [], []), ) - for hardware, expected_block_sizes, expected_heuristics in cases: + for name, hardware, facts, expected_block_sizes, expected_heuristics in cases: + env = self._make_triton_env_with_block_sizes() + env.config_spec.matmul_facts.extend(facts) with ( - self.subTest(hardware=hardware.compute_capability), + self.subTest(name=name), patch( "helion._hardware.get_hardware_info", return_value=hardware, ), - patch( - "helion._compiler.seed_heuristics.HEURISTICS_BY_BACKEND", - {"triton": duplicate_heuristics}, - ), ): configs = compiler_seed_configs(env, MagicMock()) @@ -319,39 +332,11 @@ class DuplicateTritonSkinnyGemmHeuristic(TritonSkinnyGemmHeuristic): expected_block_sizes, ) self.assertEqual( - env.config_spec.compiler_seed_heuristics, + env.config_spec.autotuner_heuristics, expected_heuristics, ) - def test_triton_skinny_gemm_seed_skinny_n_returns_target_blocks(self) -> None: - env = self._make_triton_env_with_block_sizes( - m_max=1024, - n_max=8192, - k_max=4096, - ) - env.config_spec.matmul_facts.append( - self._matmul_fact(static_m=1024, static_n=8192, static_k=4096) - ) - - config = TritonSkinnyGemmHeuristic.get_config(env, MagicMock()) - - self.assertEqual(config.config["block_sizes"], [64, 64, 256]) - - def test_triton_skinny_gemm_seed_skinny_m_returns_target_blocks(self) -> None: - env = self._make_triton_env_with_block_sizes( - m_max=8192, - n_max=1024, - k_max=4096, - ) - env.config_spec.matmul_facts.append( - self._matmul_fact(static_m=8192, static_n=1024, static_k=4096) - ) - - config = TritonSkinnyGemmHeuristic.get_config(env, MagicMock()) - - self.assertEqual(config.config["block_sizes"], [64, 64, 256]) - - def test_triton_skinny_gemm_seed_caps_at_static_dim(self) -> None: + def test_triton_skinny_gemm_seed_clamps_to_static_dims(self) -> None: env = self._make_triton_env_with_block_sizes( m_max=16, n_max=8192, @@ -361,62 +346,11 @@ def test_triton_skinny_gemm_seed_caps_at_static_dim(self) -> None: self._matmul_fact(static_m=16, static_n=8192, static_k=128) ) - config = TritonSkinnyGemmHeuristic.get_config(env, MagicMock()) + config = TritonSkinnyGemmHeuristic.get_seed_config(env, MagicMock()) + assert config is not None self.assertEqual(config.config["block_sizes"], [16, 64, 128]) - def test_triton_skinny_gemm_seed_returns_none_when_floor_violated(self) -> None: - env = self._make_triton_env_with_block_sizes( - m_max=1024, - n_max=8192, - k_max=4096, - ) - env.config_spec.block_sizes.block_id_lookup(0).autotuner_min = 256 - env.config_spec.matmul_facts.append( - self._matmul_fact(static_m=1024, static_n=8192, static_k=4096) - ) - - with patch( - "helion._hardware.get_hardware_info", - return_value=HOPPER_HARDWARE, - ): - self.assertFalse(TritonSkinnyGemmHeuristic.is_eligible(env, MagicMock())) - - def test_triton_skinny_gemm_seed_returns_none_when_block_id_missing(self) -> None: - env = self._make_triton_env_with_block_sizes() - env.config_spec.matmul_facts.append(self._matmul_fact(m_block_id=None)) - - with patch( - "helion._hardware.get_hardware_info", - return_value=HOPPER_HARDWARE, - ): - self.assertFalse(TritonSkinnyGemmHeuristic.is_eligible(env, MagicMock())) - - def test_triton_skinny_gemm_seed_requires_single_matmul_fact(self) -> None: - env = self._make_triton_env_with_block_sizes() - env.config_spec.matmul_facts.append(self._matmul_fact()) - env.config_spec.matmul_facts.append(self._matmul_fact()) - - with patch( - "helion._hardware.get_hardware_info", - return_value=HOPPER_HARDWARE, - ): - self.assertFalse(TritonSkinnyGemmHeuristic.is_eligible(env, MagicMock())) - self.assertEqual(compiler_seed_configs(env, MagicMock()), []) - self.assertEqual(env.config_spec.compiler_seed_heuristics, []) - - def test_triton_skinny_gemm_seed_rejects_batched_matmul_fact(self) -> None: - env = self._make_triton_env_with_block_sizes() - env.config_spec.matmul_facts.append(self._matmul_fact(lhs_ndim=3, rhs_ndim=3)) - - with patch( - "helion._hardware.get_hardware_info", - return_value=HOPPER_HARDWARE, - ): - self.assertFalse(TritonSkinnyGemmHeuristic.is_eligible(env, MagicMock())) - self.assertEqual(compiler_seed_configs(env, MagicMock()), []) - self.assertEqual(env.config_spec.compiler_seed_heuristics, []) - @onlyBackends(["triton"]) @skipIfRefEager("Compiler seed configs are not generated in ref eager mode") def test_triton_skinny_gemm_seed_in_initial_population(self) -> None: @@ -495,110 +429,32 @@ def assert_skinny_gemm_seeded(configs: list[helion.Config]) -> None: config.config["block_sizes"] for config in bound.config_spec.compiler_seed_configs ] - seed_pair_block_sizes = [ - config.config["block_sizes"] - for _flat, config in config_gen.seed_flat_config_pairs() - ] if expect_seed: self.assertIn( TritonSkinnyGemmHeuristic.name, - bound.config_spec.compiler_seed_heuristics, + bound.config_spec.autotuner_heuristics, ) self.assertTrue( heuristic.is_eligible(bound.env, bound.host_function.device_ir) ) + seed_config = heuristic.get_seed_config( + bound.env, bound.host_function.device_ir + ) + assert seed_config is not None self.assertEqual( - heuristic.get_config( - bound.env, bound.host_function.device_ir - ).config["block_sizes"], + seed_config.config["block_sizes"], seed_block_sizes, ) self.assertIn(seed_block_sizes, compiler_seed_block_sizes) - self.assertIn(seed_block_sizes, seed_pair_block_sizes) - - zero_flat = config_gen.random_population_flat(0) - self.assertEqual(len(zero_flat), 1) - zero_config = config_gen.unflatten(zero_flat[0]) - self.assertNotEqual( - zero_config.config["block_sizes"], - seed_block_sizes, - ) - one_flat = config_gen.random_population_flat(1) - self.assertEqual(len(one_flat), 1) - one_config = config_gen.unflatten(one_flat[0]) - self.assertNotEqual( - one_config.config["block_sizes"], - seed_block_sizes, - ) - one_config_population = config_gen.random_population(1) - self.assertEqual(len(one_config_population), 1) - self.assertNotEqual( - one_config_population[0].config["block_sizes"], - seed_block_sizes, - ) assert_skinny_gemm_seeded(config_gen.random_population(2)) - - acf_config_gen = bound.config_spec.create_config_generation( - advanced_controls_files=["/tmp/helion-test.acf"] - ) - acf_configs = acf_config_gen.random_population(2) - # Future heuristics may add more compiler seeds; this test - # only requires the skinny GEMM seed to be present. - self.assertGreaterEqual(len(acf_configs), 2) - self.assertEqual( - { - config.config["advanced_controls_file"] - for config in acf_configs - }, - {"/tmp/helion-test.acf"}, - ) - assert_skinny_gemm_seeded(acf_configs) - - with patch.object( - PatternSearch, "_find_similar_cached_configs", return_value=[] - ): - search = PatternSearch( - bound, - args, - initial_population=30, - initial_population_strategy=InitialPopulationStrategy.FROM_BEST_AVAILABLE, - best_available_pad_random=False, - ) - configs = [ - search.config_gen.unflatten(flat) - for flat in search._generate_initial_population_flat() - ] - # Future heuristics may add more compiler seeds; this test - # only requires the skinny GEMM seed to be present. - self.assertGreaterEqual(len(configs), 2) - assert_skinny_gemm_seeded(configs) else: self.assertFalse( heuristic.is_eligible(bound.env, bound.host_function.device_ir) ) self.assertNotIn( TritonSkinnyGemmHeuristic.name, - bound.config_spec.compiler_seed_heuristics, - ) - with patch.object( - PatternSearch, "_find_similar_cached_configs", return_value=[] - ): - search = PatternSearch( - bound, - args, - initial_population=30, - initial_population_strategy=InitialPopulationStrategy.FROM_BEST_AVAILABLE, - best_available_pad_random=False, - ) - configs = [ - search.config_gen.unflatten(flat) - for flat in search._generate_initial_population_flat() - ] - self.assertGreaterEqual(len(configs), 1) - self.assertNotIn( - TritonSkinnyGemmHeuristic.name, - bound.config_spec.compiler_seed_heuristics, + bound.config_spec.autotuner_heuristics, ) @@ -650,13 +506,17 @@ def cute_matmul_mma(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: heuristic = CuteTcgen05ClusterM2Heuristic self.assertIn( CuteTcgen05ClusterM2Heuristic.name, - bound.config_spec.compiler_seed_heuristics, + bound.config_spec.autotuner_heuristics, ) self.assertTrue( heuristic.is_eligible(bound.env, bound.host_function.device_ir) ) + seed_config = heuristic.get_seed_config( + bound.env, bound.host_function.device_ir + ) + assert seed_config is not None self._assert_cute_tcgen05_cluster_m2_seeded( - [heuristic.get_config(bound.env, bound.host_function.device_ir)], + [seed_config], ) with patch_cute_mma_support(default_cute_mma_support(tcgen05_f16bf16=False)): @@ -673,7 +533,7 @@ def cute_matmul_mma(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ) self.assertNotIn( CuteTcgen05ClusterM2Heuristic.name, - unsupported_bound.config_spec.compiler_seed_heuristics, + unsupported_bound.config_spec.autotuner_heuristics, ) @onlyBackends(["cute"]) @@ -698,7 +558,7 @@ def cute_matmul_mma(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: bound = cute_matmul_mma.bind(args) self.assertIn( CuteTcgen05ClusterM2Heuristic.name, - bound.config_spec.compiler_seed_heuristics, + bound.config_spec.autotuner_heuristics, ) config_gen = bound.config_spec.create_config_generation() From 82de6fa16ba8775e8f36adfeb91606164bb7a628 Mon Sep 17 00:00:00 2001 From: eche Date: Mon, 11 May 2026 08:27:16 -0700 Subject: [PATCH 7/9] Update tcgen05 seed gate test for compiler seeds --- test/test_dot_requirements.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/test_dot_requirements.py b/test/test_dot_requirements.py index 01e11f64b9..0f3d186371 100644 --- a/test/test_dot_requirements.py +++ b/test/test_dot_requirements.py @@ -427,7 +427,9 @@ def cute_matmul_mma(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # shape where it has no productive lever. self.assertEqual(spec._tcgen05_cluster_m_search_choices, (1,)) self.assertIsNone(spec._tcgen05_cluster_m2_search_constraints) - self.assertEqual(spec.autotune_seed_configs(), []) + self.assertEqual(spec.compiler_seed_configs, []) + self.assertEqual(spec.autotuner_heuristics, []) + self.assertEqual(ConfigGeneration(spec).seed_flat_config_pairs(), []) # Persistent pid types are still allowed (the static- # full-tile gate above this is unaffected) — only the # cluster_m search arm narrows. From 2125d311e6f58d057c221ffb8c9407481ca251b6 Mon Sep 17 00:00:00 2001 From: eche Date: Mon, 11 May 2026 08:32:40 -0700 Subject: [PATCH 8/9] Narrow tcgen05 seed gate assertion --- test/test_dot_requirements.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/test/test_dot_requirements.py b/test/test_dot_requirements.py index 0f3d186371..73f46b4447 100644 --- a/test/test_dot_requirements.py +++ b/test/test_dot_requirements.py @@ -8,6 +8,7 @@ import helion from helion import _compat +from helion._compiler.autotuner_heuristics.cute import CuteTcgen05ClusterM2Heuristic from helion._compiler.cute.strategies import ROLE_LOCAL_MONOLITHIC_DEFAULT_WARP_SPEC from helion._compiler.cute.strategies import Tcgen05LayoutOverrides from helion._compiler.cute.strategies import Tcgen05LayoutStrategy @@ -427,9 +428,12 @@ def cute_matmul_mma(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # shape where it has no productive lever. self.assertEqual(spec._tcgen05_cluster_m_search_choices, (1,)) self.assertIsNone(spec._tcgen05_cluster_m2_search_constraints) - self.assertEqual(spec.compiler_seed_configs, []) - self.assertEqual(spec.autotuner_heuristics, []) - self.assertEqual(ConfigGeneration(spec).seed_flat_config_pairs(), []) + # Keep this assertion scoped to the cluster_m=2 seed heuristic: + # future unrelated heuristics may still apply to these shapes. + self.assertNotIn( + CuteTcgen05ClusterM2Heuristic.name, + spec.autotuner_heuristics, + ) # Persistent pid types are still allowed (the static- # full-tile gate above this is unaffected) — only the # cluster_m search arm narrows. From 3bf3445b9588d3031fcba7ce296329a983e58577 Mon Sep 17 00:00:00 2001 From: eche Date: Mon, 11 May 2026 09:44:23 -0700 Subject: [PATCH 9/9] Add autotuner heuristic disable setting --- .../autotuner_heuristics/__init__.py | 3 ++ helion/runtime/settings.py | 10 +++++ test/test_autotuner_heuristics.py | 39 +++++++++++++++++++ 3 files changed, 52 insertions(+) diff --git a/helion/_compiler/autotuner_heuristics/__init__.py b/helion/_compiler/autotuner_heuristics/__init__.py index 9089519ed6..814f47ddc4 100644 --- a/helion/_compiler/autotuner_heuristics/__init__.py +++ b/helion/_compiler/autotuner_heuristics/__init__.py @@ -32,6 +32,9 @@ def compiler_seed_configs( ) -> list[Config]: configs: list[Config] = [] env.config_spec.autotuner_heuristics = [] + if env.settings.disable_autotuner_heuristics: + return configs + for heuristic in get_heuristics(env.backend_name): try: if not heuristic.is_eligible(env, device_ir): diff --git a/helion/runtime/settings.py b/helion/runtime/settings.py index 51af559a28..5d39f11320 100644 --- a/helion/runtime/settings.py +++ b/helion/runtime/settings.py @@ -481,6 +481,11 @@ class _Settings: default_factory=_get_autotune_config_overrides ) autotune_seed_configs: ConfigLike | Sequence[ConfigLike] | None = None + disable_autotuner_heuristics: bool = dataclasses.field( + default_factory=functools.partial( + _env_get_bool, "HELION_DISABLE_AUTOTUNER_HEURISTICS", False + ) + ) autotune_effort: AutotuneEffort = dataclasses.field( default_factory=functools.partial( _env_get_literal, @@ -628,6 +633,11 @@ class Settings(_Settings): "A Config or sequence of Configs to seed the autotuner initial population " "without constraining the search space." ), + "disable_autotuner_heuristics": ( + "If True, disable compiler/autotuner heuristics such as compiler seed " + "configs. User-provided autotune_seed_configs are unaffected. " + "Set HELION_DISABLE_AUTOTUNER_HEURISTICS=1 to disable globally." + ), "allow_warp_specialize": "If True, allow warp specialization for tl.range calls on CUDA devices.", "debug_dtype_asserts": "If True, emit tl.static_assert checks for dtype after each device node.", "ref_mode": "Reference mode for kernel execution. Can be RefMode.OFF or RefMode.EAGER.", diff --git a/test/test_autotuner_heuristics.py b/test/test_autotuner_heuristics.py index 9bfeeef88a..b85ec2a8ab 100644 --- a/test/test_autotuner_heuristics.py +++ b/test/test_autotuner_heuristics.py @@ -1,5 +1,6 @@ from __future__ import annotations +import os from unittest.mock import MagicMock from unittest.mock import patch @@ -28,6 +29,7 @@ from helion.autotuner.pattern_search import InitialPopulationStrategy from helion.autotuner.pattern_search import PatternSearch import helion.language as hl +from helion.runtime.settings import Settings HOPPER_HARDWARE = HardwareInfo( device_kind="cuda", @@ -50,6 +52,17 @@ class TestAutotunerHeuristic(TestCase): + def test_disable_autotuner_heuristics_setting_env(self) -> None: + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("HELION_DISABLE_AUTOTUNER_HEURISTICS", None) + self.assertFalse(Settings().disable_autotuner_heuristics) + + with patch.dict( + os.environ, + {"HELION_DISABLE_AUTOTUNER_HEURISTICS": "1"}, + ): + self.assertTrue(Settings().disable_autotuner_heuristics) + def test_compiler_seed_configs_handles_failed_optional_and_duplicate_seeds( self, ) -> None: @@ -91,6 +104,7 @@ class DuplicateAutotunerHeuristic(ValidAutotunerHeuristic): env = MagicMock() env.backend_name = "triton" env.config_spec = MagicMock() + env.settings = Settings() heuristics = ( FailingAutotunerHeuristic, NoSeedAutotunerHeuristic, @@ -117,6 +131,30 @@ class DuplicateAutotunerHeuristic(ValidAutotunerHeuristic): self.assertIn(FailingAutotunerHeuristic.name, "\n".join(logs.output)) self.assertIn("synthetic compiler seed failure", "\n".join(logs.output)) + def test_compiler_seed_configs_respects_disable_setting(self) -> None: + class EnabledAutotunerHeuristic(AutotunerHeuristic): + name = "enabled_autotuner_heuristic" + backend = "triton" + + @classmethod + def is_eligible(cls, env: object, device_ir: object) -> bool: + raise AssertionError("disabled heuristics should not be queried") + + env = MagicMock() + env.backend_name = "triton" + env.config_spec = MagicMock() + env.config_spec.autotuner_heuristics = ["stale"] + env.settings = Settings(disable_autotuner_heuristics=True) + + with patch( + "helion._compiler.autotuner_heuristics.HEURISTICS_BY_BACKEND", + {"triton": (EnabledAutotunerHeuristic,)}, + ): + configs = compiler_seed_configs(env, MagicMock()) + + self.assertEqual(configs, []) + self.assertEqual(env.config_spec.autotuner_heuristics, []) + def test_seed_flat_config_pairs_skips_invalid_compiler_seed(self) -> None: spec = ConfigSpec(backend=TritonBackend()) spec.block_sizes.append(BlockSizeSpec(block_id=0, size_hint=1024)) @@ -248,6 +286,7 @@ def _make_triton_env_with_block_sizes( env.backend_name = "triton" env.config_spec = spec env.device = DEVICE + env.settings = Settings() return env def _matmul_fact(