Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion helion/_compiler/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,18 @@ def autotune(
if bound_kernel.settings.autotune_effort == "none" and (
force or not bound_kernel.kernel.configs
):
config = bound_kernel.config_spec.default_config()
from ..autotuner.matmul_heuristics import (
matmul_heuristic_default_config_for_kernel,
)

config = (
matmul_heuristic_default_config_for_kernel(
bound_kernel,
args,
config_spec=bound_kernel.config_spec,
)
or bound_kernel.config_spec.default_config()
)
elif not force and bound_kernel.kernel.configs:
if len(bound_kernel.kernel.configs) == 1:
(config,) = bound_kernel.kernel.configs
Expand Down
64 changes: 52 additions & 12 deletions helion/autotuner/base_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
from .benchmark_provider import _unset_fn
from .benchmarking import interleaved_bench
from .logger import AutotuningLogger
from .matmul_heuristics import matmul_heuristic_seed_configs_for_kernel
from .matmul_heuristics import matmul_heuristics_supported_on_args
from .metrics import AutotuneMetrics
from .metrics import _run_post_autotune_hooks
from .precompile_future import PrecompileFuture as PrecompileFuture
Expand Down Expand Up @@ -687,6 +689,26 @@ def get_kwargs_from_profile(
**super().get_kwargs_from_profile(profile, settings),
}

def _heuristic_seed_configs(self, max_configs: int = 1) -> list[Config]:
    """Return up to *max_configs* matmul-heuristic seed configs.

    Args:
        max_configs: Upper bound on the number of heuristic configs requested.

    Returns:
        Heuristic seed configs for this kernel, or an empty list when the
        matmul heuristics do not apply to ``self.args``.
    """
    if matmul_heuristics_supported_on_args(self.args):
        return matmul_heuristic_seed_configs_for_kernel(
            self.kernel,
            self.args,
            config_spec=self.config_gen.config_spec,
            max_configs=max_configs,
        )
    return []

def _autotune_seed_configs_with_heuristics(self) -> list[Config]:
    """Heuristic seed configs (when available) followed by the regular seeds."""
    seeds = list(self._heuristic_seed_configs())
    seeds.extend(self._autotune_seed_configs())
    return seeds

def _random_population_flat_with_heuristics(self, n: int) -> list[FlatConfig]:
    """Generate ``n`` random flat configs, seeding with heuristic + user configs.

    Args:
        n: Size of the random population to generate.

    Returns:
        Flat configs produced by the config generator, seeded with the
        combined heuristic and autotune seed configs.
    """
    seed_configs = self._autotune_seed_configs_with_heuristics()
    return self.config_gen.random_population_flat(
        n,
        user_seed_configs=seed_configs,
        log_func=self.log,
    )

@property
def best(self) -> PopulationMember:
"""
Expand Down Expand Up @@ -778,24 +800,43 @@ def _generate_best_available_population_flat(self) -> list[FlatConfig]:
Generate initial population using default config, explicit seed configs,
and cached configs.

Always starts with the default configuration, then adds up to
Starts with a matching heuristic config when available, otherwise starts
with the default configuration, then adds up to
MAX_BEST_AVAILABLE_CONFIGS matching cached configs from previous runs.
Explicit seed configs provided by the caller are added ahead of cached
configs and are not suppressed by cache-skip settings. No random configs
are added. Duplicate configs are discarded.

Returns:
A list of unique FlatConfig values for the initial population.
Minimum size is 1 (just default), plus any valid unique explicit
seed configs and up to autotune_best_available_max_configs cached
configs.
A list of unique FlatConfig values for the initial population. Minimum
size is 1 (heuristic or default), plus any valid unique explicit seed
configs and up to autotune_best_available_max_configs cached configs.
"""
# Always start with the default config
default_flat = self.config_gen.default_flat()
default_config = self.config_gen.unflatten(default_flat)
seen: set[Config] = {default_config}
result: list[FlatConfig] = [default_flat]
self.log("Starting with default config")
max_configs = self.settings.autotune_best_available_max_configs
matmul_heuristic_configs = self._heuristic_seed_configs(max_configs=max_configs)

seen: set[Config] = set()
result: list[FlatConfig] = []
for i, config in enumerate(matmul_heuristic_configs):
try:
flat = self.config_gen.flatten(config)
transferred_config = self.config_gen.unflatten(flat)
seen.add(transferred_config)
result.append(flat)
self.log(
f"Starting with matmul heuristic config {i + 1}: "
f"{transferred_config}"
)
break
except (ValueError, TypeError, KeyError, AssertionError) as e:
self.log(f"Failed to transfer matmul initial config: {e}")

if not result:
default_flat = self.config_gen.default_flat()
default_config = self.config_gen.unflatten(default_flat)
seen.add(default_config)
result.append(default_flat)
self.log("Starting with default config")

# User seed configs are explicit requests, so try them before compiler-owned
# seeds and cached configs while still deduplicating normalized configs.
Expand Down Expand Up @@ -825,7 +866,6 @@ def _generate_best_available_population_flat(self) -> list[FlatConfig]:
except (ValueError, TypeError, KeyError, AssertionError) as e:
self.log(f"Failed to transfer explicit seed config: {e}")

max_configs = self.settings.autotune_best_available_max_configs
cached_entries = self._find_similar_cached_configs(max_configs)

if cached_entries:
Expand Down
6 changes: 1 addition & 5 deletions helion/autotuner/differential_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,7 @@ def _generate_initial_population_flat(self) -> list[FlatConfig]:
return pop[:target]
return pop

return self.config_gen.random_population_flat(
self.population_size * 2,
user_seed_configs=self._autotune_seed_configs(),
log_func=self.log,
)
return self._random_population_flat_with_heuristics(self.population_size * 2)

def initial_two_generations(self) -> None:
# The initial population is 2x larger so we can throw out the slowest half and give the tuning process a head start
Expand Down
Loading
Loading