-
Notifications
You must be signed in to change notification settings - Fork 147
[compiler][autotuner] Autotuner heuristics #2392
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
e7f46c9
Add compiler seed heuristics
e9612a7
Make compiler seed heuristics best effort
813817c
Merge remote-tracking branch 'origin/main' into compiler-seed-heuristics
a285ccf
Harden seed heuristic tests without GPU
7513279
Document Triton skinny GEMM heuristic source
8497da3
Skip seed heuristic compiler tests in ref eager
9082997
Generalize seed heuristics as autotuner heuristics
0cc6043
Merge remote-tracking branch 'origin/main' into compiler-seed-heuristics
82de6fa
Update tcgen05 seed gate test for compiler seeds
2125d31
Narrow tcgen05 seed gate assertion
3bf3445
Add autotuner heuristic disable setting
4ee76e9
Merge branch 'main' into compiler-seed-heuristics
ethche File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,56 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import logging | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| from .common import dedupe_configs | ||
| from .cute import CuteTcgen05ClusterM2Heuristic | ||
| from .triton import TritonSkinnyGemmHeuristic | ||
|
|
||
| if TYPE_CHECKING: | ||
| from ...runtime.config import Config | ||
| from ..compile_environment import CompileEnvironment | ||
| from ..device_ir import DeviceIR | ||
| from .registry import AutotunerHeuristicType | ||
|
|
||
# Registry of the active heuristic classes, keyed by backend name.
HEURISTICS_BY_BACKEND: dict[str, tuple[AutotunerHeuristicType, ...]] = {
    "cute": (CuteTcgen05ClusterM2Heuristic,),
    "triton": (TritonSkinnyGemmHeuristic,),
}

log: logging.Logger = logging.getLogger(__name__)


def get_heuristics(backend: str) -> tuple[AutotunerHeuristicType, ...]:
    """Return the heuristics registered for *backend*.

    Backends without an entry in ``HEURISTICS_BY_BACKEND`` yield an empty
    tuple rather than raising.
    """
    try:
        return HEURISTICS_BY_BACKEND[backend]
    except KeyError:
        return ()
|
|
||
|
|
||
def compiler_seed_configs(
    env: CompileEnvironment,
    device_ir: DeviceIR,
) -> list[Config]:
    """Collect seed configs from the heuristics of the current backend.

    ``env.config_spec.autotuner_heuristics`` is reset on every call and then
    receives the name of each heuristic that contributed a config. When
    ``env.settings.disable_autotuner_heuristics`` is set, no heuristic is
    consulted and an empty list is returned.

    Heuristics are best effort: one that raises while checking eligibility
    or building its seed is logged at debug level and skipped.

    Returns:
        The contributed configs with duplicates removed.
    """
    seeds: list[Config] = []
    env.config_spec.autotuner_heuristics = []
    if env.settings.disable_autotuner_heuristics:
        return seeds

    for heuristic in get_heuristics(env.backend_name):
        try:
            # An ineligible heuristic contributes nothing, just like one
            # whose get_seed_config returns None.
            seed = (
                heuristic.get_seed_config(env, device_ir)
                if heuristic.is_eligible(env, device_ir)
                else None
            )
        except Exception as exc:
            log.debug(
                "Autotuner heuristic %s failed while generating compiler seed config: %s",
                heuristic.name,
                exc,
                exc_info=True,
            )
            continue
        if seed is not None:
            seeds.append(seed)
            env.config_spec.autotuner_heuristics.append(heuristic.name)
    return dedupe_configs(seeds)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,70 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING | ||
| from typing import Iterable | ||
| from typing import Sequence | ||
| from typing import cast | ||
|
|
||
| if TYPE_CHECKING: | ||
| from ...autotuner.config_spec import BlockSizeSpec | ||
| from ...runtime.config import Config | ||
| from ..compile_environment import CompileEnvironment | ||
|
|
||
| HardwareTarget = tuple[str, str | None] | ||
|
|
||
|
|
||
def dedupe_configs(configs: Iterable[Config]) -> list[Config]:
    """Return *configs* without duplicates, keeping first occurrences in order.

    Items must be hashable; equality and hashing decide what is a duplicate.
    """
    # dict keys keep insertion order and retain the first of any equal keys.
    return list(dict.fromkeys(configs))
|
|
||
|
|
||
def matches_hardware(
    env: CompileEnvironment,
    targets: tuple[HardwareTarget, ...],
) -> bool:
    """Report whether the hardware of ``env.device`` is listed in *targets*.

    A target matches on the exact ``(device_kind, compute_capability)``
    pair, or on ``(device_kind, None)``, which accepts any compute
    capability of that device kind.
    """
    from ..._hardware import get_hardware_info

    info = get_hardware_info(env.device)
    exact = (info.device_kind, info.compute_capability)
    if exact in targets:
        return True
    any_capability = (info.device_kind, None)
    return any_capability in targets
|
|
||
|
|
||
def clamp_block_size_targets(
    env: CompileEnvironment,
    block_dims: Sequence[tuple[int, int, int]],
) -> list[int] | None:
    """Clamp block-size targets against the live ConfigSpec constraints.

    Each entry in *block_dims* is ``(block_id, static_dim, target)``. For
    every entry the target is capped by the static dimension, rounded down
    to a power of two, and capped again by the spec's ``max_size``.

    Returns:
        The clamped block sizes in input order, or ``None`` if a block id
        is unknown to the spec or any axis cannot satisfy its
        floor/ceiling constraints.
    """
    clamped: list[int] = []
    for block_id, static_dim, target in block_dims:
        try:
            spec = cast(
                "BlockSizeSpec",
                env.config_spec.block_sizes.block_id_lookup(block_id),
            )
        except KeyError:
            return None
        limit = min(target, static_dim)
        if limit < 1:
            return None
        # Largest power of two that does not exceed ``limit``.
        pow2 = 1 << (limit.bit_length() - 1)
        lower = max(spec.min_size, spec.autotuner_min)
        # Reject when the size is below the floor either before or after
        # applying the spec ceiling.
        if pow2 < lower or (size := min(pow2, spec.max_size)) < lower:
            return None
        clamped.append(size)
    return clamped
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,97 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING | ||
| from typing import cast | ||
|
|
||
| from ...runtime.config import Config | ||
| from ..cute.tcgen05_constants import TCGEN05_TWO_CTA_BLOCK_M | ||
| from ..cute.tcgen05_constants import TCGEN05_TWO_CTA_BLOCK_N | ||
| from ..cute.tcgen05_constants import TCGEN05_TWO_CTA_SEED_L2_GROUPING | ||
| from ..cute.tcgen05_constants import TCGEN05_TWO_CTA_SEED_PID_TYPE | ||
| from .registry import AutotunerHeuristic | ||
|
|
||
| if TYPE_CHECKING: | ||
| from ...autotuner.config_fragment import BlockSizeFragment | ||
| from ..compile_environment import CompileEnvironment | ||
| from ..device_ir import DeviceIR | ||
|
|
||
|
|
||
class CuteTcgen05ClusterM2Heuristic(AutotunerHeuristic):
    """Seed heuristic for the cute backend's tcgen05 ``cluster_m=2`` config.

    The seed pins the M and N block sizes to the two-CTA constants and
    picks the largest K block size that the spec's cluster_m=2 constraints
    accept.
    """

    name = "cute_tcgen05_cluster_m2"
    backend = "cute"

    @classmethod
    def is_eligible(cls, env: CompileEnvironment, device_ir: DeviceIR) -> bool:
        """Return True when the spec can host the cluster_m=2 seed.

        Requires cluster_m=2 search constraints on the spec, the seed pid
        type among the allowed pid types, exactly three block sizes, the
        two-CTA M/N block sizes within their fragment ranges, and a valid
        K block size.
        """
        spec = env.config_spec
        constraints = spec._tcgen05_cluster_m2_search_constraints
        if (
            constraints is None
            or TCGEN05_TWO_CTA_SEED_PID_TYPE not in spec.allowed_pid_types
        ):
            return False
        if len(spec.block_sizes) != 3:
            return False

        bm_fragment = cast("BlockSizeFragment", spec.block_sizes[0]._fragment(spec))
        bn_fragment = cast("BlockSizeFragment", spec.block_sizes[1]._fragment(spec))
        return (
            bm_fragment.low <= TCGEN05_TWO_CTA_BLOCK_M <= bm_fragment.high
            and bn_fragment.low <= TCGEN05_TWO_CTA_BLOCK_N <= bn_fragment.high
            and cls._select_bk(env) is not None
        )

    @classmethod
    def get_seed_config(cls, env: CompileEnvironment, device_ir: DeviceIR) -> Config:
        """Build the cluster_m=2 seed config.

        Raises:
            AssertionError: if no valid K block size exists, i.e. the
                heuristic was invoked while ineligible.
        """
        spec = env.config_spec
        bk = cls._select_bk(env)
        if bk is None:
            raise AssertionError(f"{cls.name} get_seed_config called while ineligible")

        block_sizes = [
            TCGEN05_TWO_CTA_BLOCK_M,
            TCGEN05_TWO_CTA_BLOCK_N,
            bk,
        ]
        if spec.indexing.length == 3:
            # Pure matmul has exactly the A/B/C indexing slots. Fused epilogues
            # add more memory ops, so leave those seeds to the spec default
            # rather than constructing a partial list.
            return Config(
                block_sizes=block_sizes,
                l2_groupings=[TCGEN05_TWO_CTA_SEED_L2_GROUPING],
                pid_type=TCGEN05_TWO_CTA_SEED_PID_TYPE,
                tcgen05_cluster_m=2,
                tcgen05_num_epi_warps=4,
                indexing=[
                    "tensor_descriptor",
                    "tensor_descriptor",
                    "tensor_descriptor",
                ],
            )

        return Config(
            block_sizes=block_sizes,
            l2_groupings=[TCGEN05_TWO_CTA_SEED_L2_GROUPING],
            pid_type=TCGEN05_TWO_CTA_SEED_PID_TYPE,
            tcgen05_cluster_m=2,
            # Matches the validated tcgen05 search restriction.
            tcgen05_num_epi_warps=4,
        )

    @staticmethod
    def _select_bk(env: CompileEnvironment) -> int | None:
        """Return the largest valid K block size, or None if there is none.

        Starts at the K fragment's upper bound and halves until a size
        passes ``_tcgen05_cluster_m2_bk_is_valid`` or drops below the
        fragment's lower bound.
        """
        spec = env.config_spec
        constraints = spec._tcgen05_cluster_m2_search_constraints
        if constraints is None or len(spec.block_sizes) != 3:
            return None
        bk_fragment = cast("BlockSizeFragment", spec.block_sizes[2]._fragment(spec))
        bk = bk_fragment.high
        # ``bk > 0`` guarantees termination: halving stalls at 0, so a
        # non-positive lower bound would otherwise loop forever.
        while bk > 0 and bk >= bk_fragment.low:
            if spec._tcgen05_cluster_m2_bk_is_valid(bk, constraints):
                return bk
            bk //= 2
        return None
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,29 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING | ||
| from typing import ClassVar | ||
|
|
||
| if TYPE_CHECKING: | ||
| from ...runtime.config import Config | ||
| from ..compile_environment import CompileEnvironment | ||
| from ..device_ir import DeviceIR | ||
|
|
||
|
|
||
class AutotunerHeuristic:
    """Base class for compiler-owned autotuner heuristics.

    Subclasses set ``name`` and ``backend`` and override the classmethods
    below; heuristics are used as classes, not instances.
    """

    # Identifier of the heuristic.
    name: ClassVar[str]
    # Name of the backend the heuristic belongs to.
    backend: ClassVar[str]

    @classmethod
    def is_eligible(cls, env: CompileEnvironment, device_ir: DeviceIR) -> bool:
        """Report whether the heuristic applies; subclasses must override."""
        raise NotImplementedError

    @classmethod
    def get_seed_config(
        cls, env: CompileEnvironment, device_ir: DeviceIR
    ) -> Config | None:
        """Return a seed config, or ``None`` (the default) to contribute none."""
        return None


# Heuristics are referred to by class, so this alias names the class type.
AutotunerHeuristicType = type[AutotunerHeuristic]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,78 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING | ||
|
|
||
| from ...runtime.config import Config | ||
| from .common import clamp_block_size_targets | ||
| from .common import matches_hardware | ||
| from .registry import AutotunerHeuristic | ||
|
|
||
| if TYPE_CHECKING: | ||
| from ..compile_environment import CompileEnvironment | ||
| from ..device_ir import DeviceIR | ||
|
|
||
|
|
||
# Heuristic was originally contributed by @umechand-amd
# in https://github.com/pytorch/helion/pull/2357.
class TritonSkinnyGemmHeuristic(AutotunerHeuristic):
    """Seed heuristic for a single 2-D matmul with a skewed M/N aspect ratio.

    Applies only on the listed hardware targets, when all matmul dimensions
    and block ids are statically known, and when the larger of M and N is
    at least ``MIN_ASPECT_RATIO`` times the smaller.
    """

    name = "triton_skinny_gemm"
    backend = "triton"
    # Minimum ratio between the larger and the smaller of static M and N.
    MIN_ASPECT_RATIO = 8
    # Block-size targets for the M, N and K axes, in that order.
    BLOCK_TARGETS = (64, 64, 256)
    # (device_kind, compute_capability) pairs checked via matches_hardware.
    HARDWARE_TARGETS = (("cuda", "sm90"), ("rocm", "gfx950"))

    @classmethod
    def _block_dims(cls, env: CompileEnvironment) -> list[tuple[int, int, int]] | None:
        """Return ``(block_id, static_dim, target)`` triples for M, N and K.

        Returns ``None`` unless there is exactly one matmul fact whose
        static dimensions and block ids are all known.
        """
        facts = env.config_spec.matmul_facts
        if len(facts) != 1:
            return None
        fact = facts[0]
        static_m, static_n, static_k = fact.static_m, fact.static_n, fact.static_k
        m_id, n_id, k_id = fact.m_block_id, fact.n_block_id, fact.k_block_id
        if (
            static_m is None
            or static_n is None
            or static_k is None
            or m_id is None
            or n_id is None
            or k_id is None
        ):
            return None
        target_m, target_n, target_k = cls.BLOCK_TARGETS
        return [
            (m_id, static_m, target_m),
            (n_id, static_n, target_n),
            (k_id, static_k, target_k),
        ]

    @classmethod
    def is_eligible(cls, env: CompileEnvironment, device_ir: DeviceIR) -> bool:
        """Return True when the skinny-GEMM seed applies to this kernel."""
        if not matches_hardware(env, cls.HARDWARE_TARGETS):
            return False
        facts = env.config_spec.matmul_facts
        if len(facts) != 1:
            return False
        fact = facts[0]
        if fact.lhs_ndim != 2 or fact.rhs_ndim != 2:
            return False
        block_dims = cls._block_dims(env)
        if block_dims is None:
            return False
        (_, static_m, _), (_, static_n, _), _ = block_dims
        if max(static_m, static_n) < cls.MIN_ASPECT_RATIO * min(static_m, static_n):
            return False
        return clamp_block_size_targets(env, block_dims) is not None

    @classmethod
    def get_seed_config(cls, env: CompileEnvironment, device_ir: DeviceIR) -> Config:
        """Build the seed config from the clamped block-size targets.

        Raises:
            AssertionError: if the matmul facts are incomplete or the
                targets cannot be clamped, i.e. the heuristic was invoked
                while ineligible. Raised explicitly so the check survives
                ``python -O``.
        """
        block_dims = cls._block_dims(env)
        block_sizes = (
            None if block_dims is None else clamp_block_size_targets(env, block_dims)
        )
        if block_sizes is None:
            raise AssertionError(f"{cls.name} get_seed_config called while ineligible")
        return Config(block_sizes=block_sizes)
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am thinking longer term we might want mutation heuristics as well (a custom function to get neighbors). Should we give this a more general name (without heuristics in the name)?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting. To clarify, would this heuristic override the get_neighbors function, or seed the list of neighbors? Would be helpful to see an example.
In this case, would the design look something like this?
To replace the name SeedHeuristic, how do you feel about Template? Other options I'm thinking of:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would add the neighbors to the existing set.
Maybe AutotunerHeuristic?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great. I re-named to AutotunerHeuristic with the following template: