Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions helion/autotuner/config_fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,26 @@ def differential_mutation(self, a: object, b: object, c: object) -> int:
return a


class DefaultBiasedIntegerFragment(IntegerFragment):
"""Integer fragment that samples its default value most of the time."""

def __init__(
self,
low: int,
high: int,
default_val: int | None = None,
*,
default_probability: float,
) -> None:
super().__init__(low, high, default_val)
self.default_probability = default_probability

def random(self) -> int:
if random.random() < self.default_probability:
return self.default()
return super().random()


@dataclasses.dataclass
class EnumFragment(ConfigSpecFragment):
choices: tuple[object, ...]
Expand Down Expand Up @@ -285,6 +305,17 @@ def encode(self, value: object) -> list[float]:
return [1.0 if i == choice_idx else 0.0 for i in range(len(self.choices))]


@dataclasses.dataclass
class DefaultBiasedEnumFragment(EnumFragment):
default_probability: float

def random(self) -> object:
if random.random() < self.default_probability:
return self.default()
choices = [choice for choice in self.choices if choice != self.default()]
return random.choice(choices)


class BooleanFragment(ConfigSpecFragment):
def default(self) -> bool:
return False
Expand Down
24 changes: 21 additions & 3 deletions helion/autotuner/config_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@
from .config_fragment import BlockSizeFragment
from .config_fragment import BooleanFragment
from .config_fragment import ConfigSpecFragment
from .config_fragment import DefaultBiasedEnumFragment
from .config_fragment import DefaultBiasedIntegerFragment
from .config_fragment import EnumFragment
from .config_fragment import IntegerFragment
from .config_fragment import ListOf
Expand Down Expand Up @@ -347,6 +349,8 @@ def _get_backend_tunable_keys() -> frozenset[str]:
EPILOGUE_SUBTILE_DEFAULT_CHOICES = (None, 2)
EPILOGUE_SUBTILE_MIN_K_HINT = 1024
EPILOGUE_SUBTILE_MIN_K_HINT_EXTENDED = 16384
RANGE_INT_RANDOM_DEFAULT_PROBABILITY = 0.95
RANGE_WARP_SPECIALIZE_RANDOM_DEFAULT_PROBABILITY = 0.90
# maxnreg values: None means no limit, otherwise limit to this many registers per thread
# Lower values allow higher occupancy but may hurt performance for register-heavy kernels
VALID_MAXNREG = (None, 32, 64, 128, 256)
Expand Down Expand Up @@ -3056,16 +3060,30 @@ def _fill_missing(self) -> None:

class RangeUnrollFactorSpec(_OptionalIntSpec):
def _fragment(self, base: ConfigSpec) -> IntegerFragment:
return IntegerFragment(0, 4, 0)
return DefaultBiasedIntegerFragment(
0,
4,
0,
default_probability=RANGE_INT_RANDOM_DEFAULT_PROBABILITY,
)


class RangeWarpSpecializeSpec(_OptionalBoolSpec):
pass
def _fragment(self, base: ConfigSpec) -> EnumFragment:
return DefaultBiasedEnumFragment(
(None, False, True),
default_probability=RANGE_WARP_SPECIALIZE_RANDOM_DEFAULT_PROBABILITY,
)


class RangeNumStagesSpec(_OptionalIntSpec):
def _fragment(self, base: ConfigSpec) -> IntegerFragment:
return IntegerFragment(0, 4, 0)
return DefaultBiasedIntegerFragment(
0,
4,
0,
default_probability=RANGE_INT_RANDOM_DEFAULT_PROBABILITY,
)


class RangeMultiBufferSpec(_OptionalBoolSpec):
Expand Down
Loading
Loading