Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 39 additions & 16 deletions helion/autotuner/base_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from .logger import AutotuningLogger
from .metrics import AutotuneMetrics
from .metrics import _run_post_autotune_hooks
from .observed_heuristics import observed_heuristic_seed_configs_for_kernel
from .precompile_future import PrecompileFuture as PrecompileFuture
from helion._dist_utils import all_gather_object
from helion._dist_utils import is_master_rank
Expand Down Expand Up @@ -741,27 +742,25 @@ def make_unbenchmarked(self, flat_values: FlatConfig) -> PopulationMember | None

def _generate_best_available_population_flat(self) -> list[FlatConfig]:
"""
Generate initial population using default config, explicit seed configs,
and cached configs.
Generate initial population using explicit seed configs, observed
heuristic seeds, default fallback, and cached configs.

Always starts with the default configuration, then adds up to
MAX_BEST_AVAILABLE_CONFIGS matching cached configs from previous runs.
Explicit seed configs provided by the caller are added ahead of cached
configs and are not suppressed by cache-skip settings. No random configs
are added. Duplicate configs are discarded.
Exact observed heuristic matches replace the default slot. If there are
no explicit or observed seeds, the default configuration is used as the
fallback. Then up to MAX_BEST_AVAILABLE_CONFIGS matching cached configs
from previous runs are added. Explicit seed configs provided by the
caller are added ahead of observed heuristic and cached configs and are
not suppressed by cache-skip settings.
No random configs are added. Duplicate configs are discarded.

Returns:
A list of unique FlatConfig values for the initial population.
Minimum size is 1 (just default), plus any valid unique explicit
seed configs and up to autotune_best_available_max_configs cached
configs.
Minimum size is 1, either from an explicit/observed seed or from the
default fallback, plus up to autotune_best_available_max_configs
cached configs.
"""
# Always start with the default config
default_flat = self.config_gen.default_flat()
default_config = self.config_gen.unflatten(default_flat)
seen: set[Config] = {default_config}
result: list[FlatConfig] = [default_flat]
self.log("Starting with default config")
seen: set[Config] = set()
result: list[FlatConfig] = []

for config in self._best_available_seed_configs:
try:
Expand All @@ -774,6 +773,30 @@ def _generate_best_available_population_flat(self) -> list[FlatConfig]:
self.log(f"Failed to transfer explicit seed config: {e}")

max_configs = self.settings.autotune_best_available_max_configs
observed_configs = observed_heuristic_seed_configs_for_kernel(
self.kernel,
self.args,
config_spec=self.config_spec,
max_configs=max_configs,
)

for i, config in enumerate(observed_configs):
flat = self.config_gen.flatten(config)
transferred_config = self.config_gen.unflatten(flat)
if transferred_config not in seen:
seen.add(transferred_config)
result.append(flat)
self.log.debug(
f"Observed heuristic seed config {i + 1}: {transferred_config}"
)

if not result:
default_flat = self.config_gen.default_flat()
default_config = self.config_gen.unflatten(default_flat)
seen.add(default_config)
result.append(default_flat)
self.log("Starting with default config")

cached_entries = self._find_similar_cached_configs(max_configs)

if cached_entries:
Expand Down
57 changes: 57 additions & 0 deletions helion/autotuner/config_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

"""Validation helpers for sparse autotune config dictionaries."""

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from .config_spec import ConfigSpec


def is_positive_power_of_two_int(val: object) -> bool:
    """Tell whether `val` is an `int` that is a strictly positive power of two.

    Returns:
        True only for plain `int` values 1, 2, 4, 8, ...; False otherwise.
    """
    # Exact type match: subclasses of int (including bool) are rejected.
    if type(val) is not int:
        return False
    if val <= 0:
        return False
    # A power of two has a single set bit, so clearing the lowest set bit
    # leaves zero.
    return (val & (val - 1)) == 0


def _expected_json_array_length(field: object) -> int | None:
    """Look up the fixed JSON array length a sequence-like field requires.

    Args:
        field: A config field object, or any other value (including None).

    Returns:
        `len(field)` for a `BlockIdSequence`, `field.length` for a `ListOf`,
        and None for everything else.
    """
    # Imported inside the function rather than at module import time.
    from .block_id_sequence import BlockIdSequence
    from .config_fragment import ListOf

    required_len: int | None = None
    if isinstance(field, BlockIdSequence):
        required_len = len(field)
    elif isinstance(field, ListOf):
        required_len = field.length
    return required_len


def _validate_json_array_length(key: str, val: object, *, expected_len: int) -> None:
    """Check that `val` is a list containing exactly `expected_len` items.

    Args:
        key: Config key name, used only in the error message.
        val: The value to check.
        expected_len: The required number of elements.

    Raises:
        ValueError: If `val` is not a `list`, or its length differs from
            `expected_len`.
    """
    if isinstance(val, list):
        if len(val) == expected_len:
            return
        raise ValueError(f"{key} must have length {expected_len}, got {len(val)}")
    # Anything that is not a list (tuples included) is rejected outright.
    raise ValueError(
        f"{key} must be a JSON array of length {expected_len}, got {val!r}"
    )


def validate_sparse_config_shape(
    raw: dict[str, object], *, config_spec: ConfigSpec
) -> None:
    """Raise on shape/type mismatches in a sparse config dict rather than fixing them.

    Each entry of `raw` is checked against the field that
    `config_spec._flat_fields()` maps its key to:

    * `num_warps` must be a positive power-of-two integer.
    * A key whose field has a fixed array length must hold a list of exactly
      that length.
    * Any other key that has a field must not hold a list.

    Keys with no matching field are only subject to the `num_warps` rule.

    Args:
        raw: The sparse config dictionary to check.
        config_spec: The spec supplying the known fields.

    Raises:
        ValueError: On the first entry that violates one of the rules above.
    """
    spec_fields = config_spec._flat_fields()
    for key, val in raw.items():
        spec_field = spec_fields.get(key)
        if key == "num_warps":
            if not is_positive_power_of_two_int(val):
                raise ValueError(
                    f"num_warps must be a positive power-of-two integer, got {val!r}"
                )

        required_len = _expected_json_array_length(spec_field)
        if required_len is not None:
            # Fixed-length sequence field: must be a list of the right size.
            _validate_json_array_length(key, val, expected_len=required_len)
        elif spec_field is not None and isinstance(val, list):
            # Known field without a fixed array length: a list is a mismatch.
            raise ValueError(f"{key} must be a scalar value, got {val!r}")
Loading