diff --git a/helion/autotuner/base_search.py b/helion/autotuner/base_search.py index 2667835f3f..e6668a2988 100644 --- a/helion/autotuner/base_search.py +++ b/helion/autotuner/base_search.py @@ -38,6 +38,7 @@ from .logger import AutotuningLogger from .metrics import AutotuneMetrics from .metrics import _run_post_autotune_hooks +from .observed_heuristics import observed_heuristic_seed_configs_for_kernel from .precompile_future import PrecompileFuture as PrecompileFuture from helion._dist_utils import all_gather_object from helion._dist_utils import is_master_rank @@ -741,27 +742,25 @@ def make_unbenchmarked(self, flat_values: FlatConfig) -> PopulationMember | None def _generate_best_available_population_flat(self) -> list[FlatConfig]: """ - Generate initial population using default config, explicit seed configs, - and cached configs. + Generate initial population using explicit seed configs, observed + heuristic seeds, default fallback, and cached configs. - Always starts with the default configuration, then adds up to - MAX_BEST_AVAILABLE_CONFIGS matching cached configs from previous runs. - Explicit seed configs provided by the caller are added ahead of cached - configs and are not suppressed by cache-skip settings. No random configs - are added. Duplicate configs are discarded. + Exact observed heuristic matches replace the default slot. If there are + no explicit or observed seeds, the default configuration is used as the + fallback. Then up to MAX_BEST_AVAILABLE_CONFIGS matching cached configs + from previous runs are added. Explicit seed configs provided by the + caller are added ahead of observed heuristic and cached configs and are + not suppressed by cache-skip settings. + No random configs are added. Duplicate configs are discarded. Returns: A list of unique FlatConfig values for the initial population. - Minimum size is 1 (just default), plus any valid unique explicit - seed configs and up to autotune_best_available_max_configs cached - configs. + Minimum size is 1, either from an explicit/observed seed or from the + default fallback, plus up to autotune_best_available_max_configs + cached configs. """ - # Always start with the default config - default_flat = self.config_gen.default_flat() - default_config = self.config_gen.unflatten(default_flat) - seen: set[Config] = {default_config} - result: list[FlatConfig] = [default_flat] - self.log("Starting with default config") + seen: set[Config] = set() + result: list[FlatConfig] = [] for config in self._best_available_seed_configs: try: @@ -774,6 +773,30 @@ def _generate_best_available_population_flat(self) -> list[FlatConfig]: self.log(f"Failed to transfer explicit seed config: {e}") max_configs = self.settings.autotune_best_available_max_configs + observed_configs = observed_heuristic_seed_configs_for_kernel( + self.kernel, + self.args, + config_spec=self.config_spec, + max_configs=max_configs, + ) + + for i, config in enumerate(observed_configs): + flat = self.config_gen.flatten(config) + transferred_config = self.config_gen.unflatten(flat) + if transferred_config not in seen: + seen.add(transferred_config) + result.append(flat) + self.log.debug( + f"Observed heuristic seed config {i + 1}: {transferred_config}" + ) + + if not result: + default_flat = self.config_gen.default_flat() + default_config = self.config_gen.unflatten(default_flat) + seen.add(default_config) + result.append(default_flat) + self.log("Starting with default config") + cached_entries = self._find_similar_cached_configs(max_configs) if cached_entries: diff --git a/helion/autotuner/config_validation.py b/helion/autotuner/config_validation.py new file mode 100644 index 0000000000..1a82ee159a --- /dev/null +++ b/helion/autotuner/config_validation.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +"""Validation helpers for sparse autotune config dictionaries.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .config_spec import ConfigSpec + + +def is_positive_power_of_two_int(val: object) -> bool: + """Return whether `val` is a strictly positive power-of-two integer.""" + return type(val) is int and val > 0 and (val & (val - 1)) == 0 + + +def _expected_json_array_length(field: object) -> int | None: + """Return the required JSON array length for fixed-length sequence fields.""" + from .block_id_sequence import BlockIdSequence + from .config_fragment import ListOf + + if isinstance(field, BlockIdSequence): + return len(field) + if isinstance(field, ListOf): + return field.length + return None + + +def _validate_json_array_length(key: str, val: object, *, expected_len: int) -> None: + """Validate that a JSON value is a list with the expected fixed length.""" + if not isinstance(val, list): + raise ValueError( + f"{key} must be a JSON array of length {expected_len}, got {val!r}" + ) + if len(val) != expected_len: + raise ValueError(f"{key} must have length {expected_len}, got {len(val)}") + + +def validate_sparse_config_shape( + raw: dict[str, object], *, config_spec: ConfigSpec +) -> None: + """Reject sparse config shape/type mismatches instead of silently repairing them.""" + flat_fields = config_spec._flat_fields() + for key, val in raw.items(): + field = flat_fields.get(key) + if key == "num_warps" and not is_positive_power_of_two_int(val): + raise ValueError( + f"num_warps must be a positive power-of-two integer, got {val!r}" + ) + + if (expected_len := _expected_json_array_length(field)) is not None: + _validate_json_array_length(key, val, expected_len=expected_len) + continue + + if field is not None and isinstance(val, list): + raise ValueError(f"{key} must be a scalar value, got {val!r}") diff --git a/helion/autotuner/data/observed_heuristics_b200.json b/helion/autotuner/data/observed_heuristics_b200.json new file mode 100644 index 0000000000..3433596c3e --- /dev/null +++ b/helion/autotuner/data/observed_heuristics_b200.json @@ -0,0 +1,3118 @@ +{ + "schema_version": 2, + "description": "B200 observed seed heuristics for generic autotune seed configs. Rules are exact-bucket only and limited to buckets with validated seed-over-baseline improvement. Extended with quantized-GEMM rules (matmul_int4, matmul_int16, matmul_fp4) derived from 120 fully-tuned B200 archive shapes (run 20260509_q2_llmseeded) and live-validated against a 12-shape grid per kernel (Q5/Q6 experiments). Plus per-(kernel_class, shape-group) fallbacks used when exact bucket lookup misses (5 groups for matmul family: small_m/n/k + balanced + rect).", + "policy": { + "target": "seed-only round0_best_geo <= 0.80 versus non-heuristic baseline per promoted workload", + "device": "NVIDIA B200 sm100", + "matching": "exact shape bucket only", + "validation": "h56_clean_policy_opus47_seed_validation_20260509_101149; promoted geomean 0.732x over 3 repeats; attention+GEMV geomean 0.720x; softmax_2k_8k geomean 0.839x with one neutral repeat", + "quantized_validation": "Q5 (no-autotune baseline): family heldout 0.142 (~7x vs Helion default). Q6 (LLMGuidedSearch max_rounds=1, Opus 4.7 via Bedrock): family heldout 0.663 (~34% over LLM-alone round-0). Per-rule validation blocks use q5_* and q6_* keys.", + "fallbacks": "Each fallback is the median-perf winning config for its shape group in the archive. Used only when no exact-bucket rule matches; acts as a coarse safety net between the strict rule lookup and Helion default." + }, + "rules": [ + { + "kernel_class": "attention", + "shape_bucket": { + "batch_heads_bin": "<=32", + "dtype": "fp16_bf16", + "head_dim_bin": "<=64", + "seq_bin": "<=2048" + }, + "match_exact_only": true, + "source": { + "source_rule_index": 8, + "source_shape_bucket": { + "batch_heads_bin": "<=64", + "dtype": "fp16_bf16", + "head_dim_bin": "<=64", + "seq_bin": "<=2048" + }, + "source_workloads": [ + "attention_2k_d64" + ], + "measurement": "h55_combined_strict_opus47_seed_validation_20260509_005646" + }, + "validation": { + "h55_seed_over_baseline_geo": 0.784, + "h55_repeat_range": [ + 0.767, + 0.793 + ] + }, + "templates": [ + { + "template": { + "block_sizes": [ + 1, + 128, + 128 + ], + "l2_groupings": [ + 8 + ], + "num_stages": 3, + "num_warps": 4, + "pid_type": "flat" + } + }, + { + "template": { + "block_sizes": [ + 1, + 128, + 128 + ], + "l2_groupings": [ + 16 + ], + "num_stages": 3, + "num_warps": 4, + "pid_type": "flat" + } + } + ] + }, + { + "kernel_class": "attention", + "shape_bucket": { + "batch_heads_bin": "<=32", + "dtype": "fp16_bf16", + "head_dim_bin": "<=128", + "seq_bin": "<=2048" + }, + "match_exact_only": true, + "source": { + "source_rule_index": 9, + "source_shape_bucket": { + "batch_heads_bin": "<=64", + "dtype": "fp16_bf16", + "head_dim_bin": "<=128", + "seq_bin": "<=2048" + }, + "source_workloads": [ + "attention_2k_d128" + ], + "measurement": "h55_combined_strict_opus47_seed_validation_20260509_005646" + }, + "validation": { + "h55_seed_over_baseline_geo": 0.731, + "h55_repeat_range": [ + 0.728, + 0.736 + ] + }, + "templates": [ + { + "template": { + "block_sizes": [ + 1, + 128, + 64 + ], + "loop_orders": [ + [ + 0, + 1 + ] + ], + "l2_groupings": [ + 1 + ], + "range_unroll_factors": [ + 0, + 1 + ], + "range_warp_specializes": [ + null, + null + ], + "range_num_stages": [ + 0, + 0 + ], + "range_multi_buffers": [ + null, + null + ], + "range_flattens": [ + null, + true + ], + "load_eviction_policies": [ + "first", + "", + "first" + ], + "num_warps": 4, + "num_stages": 3, + "indexing": [ + "pointer", + "pointer", + "tensor_descriptor", + "tensor_descriptor" + ], + "atomic_indexing": [], + "pid_type": "flat" + } + }, + { + "template": { + "block_sizes": [ + 1, + 128, + 64 + ], + "l2_groupings": [ + 1 + ], + "num_stages": 3, + "num_warps": 4, + "pid_type": "flat" + } + }, + { + "template": { + "block_sizes": [ + 1, + 128, + 64 + ], + "l2_groupings": [ + 8 + ], + "num_stages": 3, + "num_warps": 4, + "pid_type": "flat" + } + }, + { + "template": { + "block_sizes": [ + 1, + 128, + 64 + ], + "l2_groupings": [ + 16 + ], + "num_stages": 3, + "num_warps": 4, + "pid_type": "flat" + } + }, + { + "template": { + "block_sizes": [ + 1, + 128, + 64 + ], + "l2_groupings": [ + 16 + ], + "num_sm_multiplier": 2, + "num_stages": 4, + "num_warps": 4, + "pid_type": "persistent_interleaved" + } + }, + { + "template": { + "block_sizes": [ + 1, + 128, + 64 + ], + "loop_orders": [ + [ + 0, + 1 + ] + ], + "l2_groupings": [ + 1 + ], + "range_unroll_factors": [ + 0, + 0 + ], + "range_warp_specializes": [ + null, + null + ], + "range_num_stages": [ + 0, + 0 + ], + "range_multi_buffers": [ + null, + null + ], + "range_flattens": [ + null, + null + ], + "load_eviction_policies": [ + "", + "", + "" + ], + "num_sm_multiplier": 2, + "num_warps": 4, + "num_stages": 3, + "indexing": [ + "pointer", + "pointer", + "pointer", + "pointer" + ], + "atomic_indexing": [], + "pid_type": "persistent_interleaved" + } + } + ] + }, + { + "kernel_class": "attention", + "shape_bucket": { + "batch_heads_bin": "<=32", + "dtype": "fp16_bf16", + "head_dim_bin": "<=64", + "seq_bin": "<=4096" + }, + "match_exact_only": true, + "source": { + "source_rule_index": 11, + "source_shape_bucket": { + "batch_heads_bin": "<=128", + "dtype": "fp16_bf16", + "head_dim_bin": "<=64", + "seq_bin": "<=4096" + }, + "source_workloads": [ + "attention_4k_d64" + ], + "measurement": "h55_combined_strict_opus47_seed_validation_20260509_005646" + }, + "validation": { + "h55_seed_over_baseline_geo": 0.764, + "h55_repeat_range": [ + 0.734, + 0.805 + ] + }, + "templates": [ + { + "template": { + "block_sizes": [ + 1, + 128, + 128 + ], + "l2_groupings": [ + 32 + ], + "num_stages": 3, + "num_warps": 4, + "pid_type": "flat" + } + }, + { + "template": { + "block_sizes": [ + 1, + 128, + 128 + ], + "l2_groupings": [ + 16 + ], + "num_stages": 3, + "num_warps": 4, + "pid_type": "flat" + } + }, + { + "template": { + "block_sizes": [ + 1, + 128, + 128 + ], + "l2_groupings": [ + 32 + ], + "num_sm_multiplier": 2, + "num_stages": 3, + "num_warps": 4, + "pid_type": "persistent_blocked" + } + }, + { + "template": { + "block_sizes": [ + 1, + 128, + 128 + ], + "l2_groupings": [ + 16 + ], + "num_sm_multiplier": 2, + "num_stages": 3, + "num_warps": 4, + "pid_type": "persistent_blocked" + } + } + ] + }, + { + "kernel_class": "attention", + "shape_bucket": { + "batch_heads_bin": "<=32", + "dtype": "fp16_bf16", + "head_dim_bin": "<=64", + "seq_bin": "<=8192" + }, + "match_exact_only": true, + "source": { + "source_rule_index": 6, + "source_shape_bucket": { + "batch_heads_bin": "<=32", + "dtype": "fp16_bf16", + "head_dim_bin": "<=64", + "seq_bin": "<=8192" + }, + "source_workloads": [ + "attention_8k_d64_bh8" + ], + "measurement": "h55_combined_strict_opus47_seed_validation_20260509_005646" + }, + "validation": { + "h55_seed_over_baseline_geo": 0.718, + "h55_repeat_range": [ + 0.706, + 0.738 + ] + }, + "templates": [ + { + "template": { + "block_sizes": [ + 1, + 128, + 128 + ], + "loop_orders": [ + [ + 0, + 1 + ] + ], + "l2_groupings": [ + 8 + ], + "range_unroll_factors": [ + 0, + 0 + ], + "range_warp_specializes": [ + null, + null + ], + "range_num_stages": [ + 0, + 0 + ], + "range_multi_buffers": [ + null, + null + ], + "range_flattens": [ + null, + null + ], + "load_eviction_policies": [ + "", + "", + "" + ], + "num_sm_multiplier": 2, + "num_warps": 4, + "num_stages": 3, + "indexing": [ + "pointer", + "pointer", + "pointer", + "pointer" + ], + "atomic_indexing": [], + "pid_type": "persistent_interleaved" + } + }, + { + "template": { + "block_sizes": [ + 1, + 128, + 128 + ], + "loop_orders": [ + [ + 0, + 1 + ] + ], + "l2_groupings": [ + 8 + ], + "range_unroll_factors": [ + 0, + 0 + ], + "range_warp_specializes": [ + null, + null + ], + "range_num_stages": [ + 0, + 0 + ], + "range_multi_buffers": [ + null, + null + ], + "range_flattens": [ + null, + null + ], + "load_eviction_policies": [ + "", + "", + "" + ], + "num_warps": 4, + "num_stages": 3, + "indexing": [ + "pointer", + "pointer", + "pointer", + "pointer" + ], + "atomic_indexing": [], + "pid_type": "flat" + } + } + ] + }, + { + "kernel_class": "attention", + "shape_bucket": { + "batch_heads_bin": "<=32", + "dtype": "fp16_bf16", + "head_dim_bin": "<=64", + "seq_bin": "<=16384" + }, + "match_exact_only": true, + "source": { + "source_rule_index": 7, + "source_shape_bucket": { + "batch_heads_bin": "<=32", + "dtype": "fp16_bf16", + "head_dim_bin": "<=64", + "seq_bin": "<=16384" + }, + "source_workloads": [ + "attention_16k_d64_bh4" + ], + "measurement": "h55_combined_strict_opus47_seed_validation_20260509_005646" + }, + "validation": { + "h55_seed_over_baseline_geo": 0.797, + "h55_repeat_range": [ + 0.761, + 0.815 + ] + }, + "templates": [ + { + "template": { + "block_sizes": [ + 1, + 128, + 128 + ], + "loop_orders": [ + [ + 0, + 1 + ] + ], + "l2_groupings": [ + 8 + ], + "range_unroll_factors": [ + 0, + 0 + ], + "range_warp_specializes": [ + null, + null + ], + "range_num_stages": [ + 0, + 0 + ], + "range_multi_buffers": [ + null, + null + ], + "range_flattens": [ + null, + null + ], + "load_eviction_policies": [ + "", + "", + "" + ], + "num_warps": 4, + "num_stages": 3, + "indexing": [ + "pointer", + "pointer", + "pointer", + "pointer" + ], + "atomic_indexing": [], + "pid_type": "flat" + } + } + ] + }, + { + "kernel_class": "attention", + "shape_bucket": { + "batch_heads_bin": "<=32", + "dtype": "fp16_bf16", + "head_dim_bin": "<=64", + "kv_seq_bin": "<=16384", + "seq_bin": "<=16" + }, + "match_exact_only": true, + "source": { + "source_rule_index": 23, + "source_shape_bucket": { + "batch_heads_bin": "<=32", + "dtype": "fp16_bf16", + "head_dim_bin": "<=64", + "kv_seq_bin": "<=16384", + "seq_bin": "<=16" + }, + "source_workloads": [ + "attention_decode_q16_k16k_d64_bh4" + ], + "measurement": "h55_combined_strict_opus47_seed_validation_20260509_005646" + }, + "validation": { + "h55_seed_over_baseline_geo": 0.69, + "h55_repeat_range": [ + 0.682, + 0.694 + ] + }, + "templates": [ + { + "template": { + "atomic_indexing": [], + "block_sizes": [ + 1, + 1, + 512 + ], + "indexing": [ + "pointer", + "tensor_descriptor", + "pointer", + "pointer" + ], + "l2_groupings": [ + 1 + ], + "load_eviction_policies": [ + "", + "last", + "first" + ], + "loop_orders": [ + [ + 0, + 1 + ] + ], + "num_stages": 2, + "num_warps": 8, + "pid_type": "flat", + "range_flattens": [ + null, + null + ], + "range_multi_buffers": [ + null, + null + ], + "range_num_stages": [ + 0, + 0 + ], + "range_unroll_factors": [ + 0, + 1 + ], + "range_warp_specializes": [ + null, + null + ] + } + } + ] + }, + { + "kernel_class": "matmul", + "shape_bucket": { + "aspect": "skinny_m", + "dtype": "fp16_bf16", + "k_bin": "<=32768", + "m_bin": "<=8", + "n_bin": ">4096" + }, + "match_exact_only": true, + "source": { + "source_rule_index": 21, + "source_shape_bucket": { + "aspect": "skinny_m", + "dtype": "fp16_bf16", + "k_bin": "<=32768", + "m_bin": "<=8", + "n_bin": ">4096" + }, + "source_workloads": [ + "matmul_gemv_m8_8k" + ], + "measurement": "h55_combined_strict_opus47_seed_validation_20260509_005646" + }, + "validation": { + "h55_seed_over_baseline_geo": 0.62, + "h55_repeat_range": [ + 0.618, + 0.622 + ] + }, + "templates": [ + { + "template": { + "atomic_indexing": [], + "block_sizes": [ + 8, + 64, + 256 + ], + "indexing": [ + "pointer", + "pointer", + "pointer" + ], + "l2_groupings": [ + 1 + ], + "load_eviction_policies": [ + "", + "" + ], + "loop_orders": [ + [ + 0, + 1 + ] + ], + "num_stages": 5, + "num_warps": 8, + "pid_type": "flat", + "range_flattens": [ + null, + null + ], + "range_multi_buffers": [ + null, + null + ], + "range_num_stages": [ + 0, + 0 + ], + "range_unroll_factors": [ + 0, + 0 + ], + "range_warp_specializes": [ + null, + null + ] + } + } + ] + }, + { + "kernel_class": "matmul", + "shape_bucket": { + "aspect": "skinny_m", + "dtype": "fp16_bf16", + "k_bin": "<=32768", + "m_bin": "<=4", + "n_bin": ">4096" + }, + "match_exact_only": true, + "source": { + "source_rule_index": 22, + "source_shape_bucket": { + "aspect": "skinny_m", + "dtype": "fp16_bf16", + "k_bin": "<=32768", + "m_bin": "<=4", + "n_bin": ">4096" + }, + "source_workloads": [ + "matmul_gemv_m4_k16k_n8k" + ], + "measurement": "h55_combined_strict_opus47_seed_validation_20260509_005646" + }, + "validation": { + "h55_seed_over_baseline_geo": 0.701, + "h55_repeat_range": [ + 0.637, + 0.821 + ] + }, + "templates": [ + { + "template": { + "atomic_indexing": [], + "block_sizes": [ + 8, + 64, + 256 + ], + "indexing": [ + "pointer", + "pointer", + "pointer" + ], + "l2_groupings": [ + 1 + ], + "load_eviction_policies": [ + "", + "" + ], + "loop_orders": [ + [ + 0, + 1 + ] + ], + "num_stages": 5, + "num_warps": 8, + "pid_type": "flat", + "range_flattens": [ + null, + null + ], + "range_multi_buffers": [ + null, + null + ], + "range_num_stages": [ + 0, + 0 + ], + "range_unroll_factors": [ + 0, + 0 + ], + "range_warp_specializes": [ + null, + null + ] + } + } + ] + }, + { + "kernel_class": "row_softmax", + "shape_bucket": { + "cols_bin": "<=8192", + "dtype": "fp16_bf16", + "rows_bin": "<=2048" + }, + "match_exact_only": true, + "source": { + "source_shape_bucket": { + "cols_bin": "<=8192", + "dtype": "fp16_bf16", + "rows_bin": "<=4096" + }, + "source_workloads": [ + "softmax_2k_8k" + ], + "measurement": "h56_clean_policy_opus47_seed_validation_20260509_101149", + "confidence": "useful_but_below_20_percent_target" + }, + "validation": { + "h56_seed_over_baseline_geo": 0.839, + "h56_repeat_range": [ + 0.765, + 1.004 + ] + }, + "templates": [ + { + "geomean_slowdown": 1.0046, + "p90_slowdown": 1.0221, + "shape_coverage": 64, + "template": { + "block_sizes": [ + 1 + ], + "num_stages": 7, + "num_warps": 8, + "pid_type": "flat", + "reduction_loops": [ + null + ] + }, + "win_count": 24 + }, + { + "geomean_slowdown": 1.0057, + "p90_slowdown": 1.0278, + "shape_coverage": 64, + "template": { + "block_sizes": [ + 1 + ], + "num_stages": 8, + "num_warps": 8, + "pid_type": "flat", + "reduction_loops": [ + null + ] + }, + "win_count": 13 + }, + { + "geomean_slowdown": 1.0058, + "p90_slowdown": 1.026, + "shape_coverage": 64, + "template": { + "block_sizes": [ + 1 + ], + "num_stages": 2, + "num_warps": 8, + "pid_type": "flat", + "reduction_loops": [ + null + ] + }, + "win_count": 3 + } + ] + }, + { + "kernel_class": "matmul_fp4", + "shape_bucket": { + "aspect": "balanced", + "dtype": "fp16_bf16", + "k_bin": "<=4096", + "m_bin": "<=4096", + "n_bin": "<=4096" + }, + "match_exact_only": true, + "source": { + "source_measurement": "20260509_q2_llmseeded", + "source_n_archive_shapes": 5, + "live_validation_workloads": [ + "FP4_004", + "FP4_005", + "FP4_006", + "FP4_007", + "FP4_008" + ] + }, + "validation": { + "q5_exp1_seed_over_baseline_geo": 0.12, + "q5_repeat_range": [ + 0.102, + 0.166 + ], + "q5_n_pairs": 15, + "q6_exp2_seed_over_baseline_geo": 0.46, + "q6_repeat_range": [ + 0.339, + 0.688 + ], + "q6_n_pairs": 15 + }, + "templates": [ + { + "template": { + "block_sizes": [ + 8, + 128, + 128 + ], + "l2_groupings": [ + 1 + ], + "num_stages": 5, + "num_warps": 4, + "pid_type": "flat" + } + }, + { + "template": { + "block_sizes": [ + 16, + 128, + 256 + ], + "l2_groupings": [ + 1 + ], + "num_stages": 4, + "num_warps": 4, + "pid_type": "flat" + } + }, + { + "template": { + "block_sizes": [ + 8, + 256, + 128 + ], + "l2_groupings": [ + 1 + ], + "num_stages": 7, + "num_warps": 4, + "pid_type": "flat" + } + } + ] + }, + { + "kernel_class": "matmul_fp4", + "shape_bucket": { + "aspect": "skinny_m", + "dtype": "fp16_bf16", + "k_bin": "<=4096", + "m_bin": "<=1024", + "n_bin": "<=4096" + }, + "match_exact_only": true, + "source": { + "source_measurement": "20260509_q2_llmseeded", + "source_n_archive_shapes": 1, + "live_validation_workloads": [ + "FP4_010" + ] + }, + "validation": { + "q5_exp1_seed_over_baseline_geo": 0.205, + "q5_repeat_range": [ + 0.205, + 0.205 + ], + "q5_n_pairs": 3, + "q6_exp2_seed_over_baseline_geo": 0.706, + "q6_repeat_range": [ + 0.628, + 0.764 + ], + "q6_n_pairs": 3 + }, + "templates": [ + { + "template": { + "block_sizes": [ + 16, + 128, + 256 + ], + "l2_groupings": [ + 2 + ], + "num_stages": 3, + "num_warps": 4, + "pid_type": "flat" + } + } + ] + }, + { + "kernel_class": "matmul_fp4", + "shape_bucket": { + "aspect": "balanced", + "dtype": "fp16_bf16", + "k_bin": "<=1024", + "m_bin": "<=1024", + "n_bin": "<=1024" + }, + "match_exact_only": true, + "source": { + "source_measurement": "20260509_q2_llmseeded", + "source_n_archive_shapes": 2, + "live_validation_workloads": [ + "FP4_003" + ] + }, + "validation": { + "q5_exp1_seed_over_baseline_geo": 0.238, + "q5_repeat_range": [ + 0.238, + 0.238 + ], + "q5_n_pairs": 3, + "q6_exp2_seed_over_baseline_geo": 0.628, + "q6_repeat_range": [ + 0.614, + 0.658 + ], + "q6_n_pairs": 3 + }, + "templates": [ + { + "template": { + "block_sizes": [ + 64, + 128, + 32 + ], + "l2_groupings": [ + 1 + ], + "num_stages": 1, + "num_warps": 8, + "pid_type": "flat" + } + }, + { + "template": { + "block_sizes": [ + 128, + 128, + 32 + ], + "l2_groupings": [ + 4 + ], + "num_stages": 1, + "num_warps": 8, + "pid_type": "flat" + } + } + ] + }, + { + "kernel_class": "matmul_fp4", + "shape_bucket": { + "aspect": "skinny_n", + "dtype": "fp16_bf16", + "k_bin": "<=4096", + "m_bin": "<=4096", + "n_bin": "<=128" + }, + "match_exact_only": true, + "source": { + "source_measurement": "20260509_q2_llmseeded", + "source_n_archive_shapes": 2, + "live_validation_workloads": [ + "FP4_012" + ] + }, + "validation": { + "q5_exp1_seed_over_baseline_geo": 0.282, + "q5_repeat_range": [ + 0.28, + 0.286 + ], + "q5_n_pairs": 3, + "q6_exp2_seed_over_baseline_geo": 0.921, + "q6_repeat_range": [ + 0.851, + 0.958 + ], + "q6_n_pairs": 3 + }, + "templates": [ + { + "template": { + "block_sizes": [ + 64, + 64, + 16 + ], + "l2_groupings": [ + 4 + ], + "num_stages": 3, + "num_warps": 2, + "pid_type": "flat" + } + }, + { + "template": { + "block_sizes": [ + 128, + 128, + 32 + ], + "l2_groupings": [ + 1 + ], + "num_sm_multiplier": 1, + "num_stages": 1, + "num_warps": 8, + "pid_type": "persistent_blocked" + } + } + ] + }, + { + "kernel_class": "matmul_fp4", + "shape_bucket": { + "aspect": "skinny_m", + "dtype": "fp16_bf16", + "k_bin": "<=4096", + "m_bin": "<=128", + "n_bin": "<=4096" + }, + "match_exact_only": true, + "source": { + "source_measurement": "20260509_q2_llmseeded", + "source_n_archive_shapes": 1, + "live_validation_workloads": [ + "FP4_011" + ] + }, + "validation": { + "q5_exp1_seed_over_baseline_geo": 0.403, + "q5_repeat_range": [ + 0.403, + 0.403 + ], + "q5_n_pairs": 3, + "q6_exp2_seed_over_baseline_geo": 0.728, + "q6_repeat_range": [ + 0.686, + 0.787 + ], + "q6_n_pairs": 3 + }, + "templates": [ + { + "template": { + "block_sizes": [ + 64, + 64, + 16 + ], + "l2_groupings": [ + 1 + ], + "num_stages": 3, + "num_warps": 2, + "pid_type": "flat" + } + } + ] + }, + { + "kernel_class": "matmul_fp4", + "shape_bucket": { + "aspect": "balanced", + "dtype": "fp16_bf16", + "k_bin": "<=512", + "m_bin": "<=512", + "n_bin": "<=512" + }, + "match_exact_only": true, + "source": { + "source_measurement": "20260509_q2_llmseeded", + "source_n_archive_shapes": 2, + "live_validation_workloads": [ + "FP4_002" + ] + }, + "validation": { + "q5_exp1_seed_over_baseline_geo": 0.52, + "q5_repeat_range": [ + 0.52, + 0.52 + ], + "q5_n_pairs": 3, + "q6_exp2_seed_over_baseline_geo": 0.793, + "q6_repeat_range": [ + 0.793, + 0.793 + ], + "q6_n_pairs": 3 + }, + "templates": [ + { + "template": { + "block_sizes": [ + 64, + 64, + 16 + ], + "l2_groupings": [ + 1 + ], + "num_stages": 1, + "num_warps": 8, + "pid_type": "flat" + } + }, + { + "template": { + "block_sizes": [ + 128, + 128, + 16 + ], + "l2_groupings": [ + 1 + ], + "num_stages": 1, + "num_warps": 8, + "pid_type": "flat" + } + } + ] + }, + { + "kernel_class": "matmul_fp4", + "shape_bucket": { + "aspect": "balanced", + "dtype": "fp16_bf16", + "k_bin": "<=256", + "m_bin": "<=256", + "n_bin": "<=256" + }, + "match_exact_only": true, + "source": { + "source_measurement": "20260509_q2_llmseeded", + "source_n_archive_shapes": 1, + "live_validation_workloads": [ + "FP4_001" + ] + }, + "validation": { + "q5_exp1_seed_over_baseline_geo": 0.648, + "q5_repeat_range": [ + 0.647, + 0.649 + ], + "q5_n_pairs": 3, + "q6_exp2_seed_over_baseline_geo": 0.807, + "q6_repeat_range": [ + 0.732, + 0.849 + ], + "q6_n_pairs": 3 + }, + "templates": [ + { + "template": { + "block_sizes": [ + 128, + 32, + 16 + ], + "l2_groupings": [ + 1 + ], + "num_sm_multiplier": 1, + "num_stages": 3, + "num_warps": 4, + "pid_type": "persistent_blocked" + } + } + ] + }, + { + "kernel_class": "matmul_int16", + "shape_bucket": { + "aspect": "skinny_m", + "dtype": "fp16_bf16", + "k_bin": "<=4096", + "m_bin": "<=1024", + "n_bin": "<=4096" + }, + "match_exact_only": true, + "source": { + "source_measurement": "20260509_q2_llmseeded", + "source_n_archive_shapes": 1, + "live_validation_workloads": [ + "I16_010" + ] + }, + "validation": { + "q5_exp1_seed_over_baseline_geo": 0.058, + "q5_repeat_range": [ + 0.058, + 0.058 + ], + "q5_n_pairs": 3, + "q6_exp2_seed_over_baseline_geo": 0.813, + "q6_repeat_range": [ + 0.813, + 0.813 + ], + "q6_n_pairs": 3 + }, + "templates": [ + { + "template": { + "block_sizes": [ + 256, + 128, + 64 + ], + "l2_groupings": [ + 8 + ], + "num_sm_multiplier": 1, + "num_stages": 3, + "num_warps": 8, + "pid_type": "persistent_blocked" + } + } + ] + }, + { + "kernel_class": "matmul_int16", + "shape_bucket": { + "aspect": "balanced", + "dtype": "fp16_bf16", + "k_bin": "<=4096", + "m_bin": "<=4096", + "n_bin": "<=4096" + }, + "match_exact_only": true, + "source": { + "source_measurement": "20260509_q2_llmseeded", + "source_n_archive_shapes": 5, + "live_validation_workloads": [ + "I16_004", + "I16_005", + "I16_006", + "I16_007", + "I16_008" + ] + }, + "validation": { + "q5_exp1_seed_over_baseline_geo": 0.069, + "q5_repeat_range": [ + 0.044, + 0.171 + ], + "q5_n_pairs": 15, + "q6_exp2_seed_over_baseline_geo": 0.847, + "q6_repeat_range": [ + 0.694, + 1.001 + ], + "q6_n_pairs": 15 + }, + "templates": [ + { + "template": { + "block_sizes": [ + 256, + 128, + 64 + ], + "l2_groupings": [ + 1 + ], + "num_stages": 3, + "num_warps": 4, + "pid_type": "flat" + } + }, + { + "template": { + "block_sizes": [ + 128, + 128, + 64 + ], + "l2_groupings": [ + 4 + ], + "num_sm_multiplier": 1, + "num_stages": 4, + "num_warps": 8, + "pid_type": "persistent_blocked" + } + }, + { + "template": { + "block_sizes": [ + 256, + 256, + 64 + ], + "l2_groupings": [ + 1 + ], + "num_stages": 3, + "num_warps": 4, + "pid_type": "flat" + } + } + ] + }, + { + "kernel_class": "matmul_int16", + "shape_bucket": { + "aspect": "balanced", + "dtype": "fp16_bf16", + "k_bin": "<=1024", + "m_bin": "<=1024", + "n_bin": "<=1024" + }, + "match_exact_only": true, + "source": { + "source_measurement": "20260509_q2_llmseeded", + "source_n_archive_shapes": 2, + "live_validation_workloads": [ + "I16_003" + ] + }, + "validation": { + "q5_exp1_seed_over_baseline_geo": 0.187, + "q5_repeat_range": [ + 0.187, + 0.187 + ], + "q5_n_pairs": 3, + "q6_exp2_seed_over_baseline_geo": 0.962, + "q6_repeat_range": [ + 0.895, + 0.998 + ], + "q6_n_pairs": 3 + }, + "templates": [ + { + "template": { + "block_sizes": [ + 128, + 64, + 64 + ], + "l2_groupings": [ + 4 + ], + "num_stages": 3, + "num_warps": 8, + "pid_type": "flat" + } + }, + { + "template": { + "block_sizes": [ + 128, + 32, + 128 + ], + "l2_groupings": [ + 1 + ], + "num_sm_multiplier": 1, + "num_stages": 4, + "num_warps": 8, + "pid_type": "persistent_interleaved" + } + } + ] + }, + { + "kernel_class": "matmul_int16", + "shape_bucket": { + "aspect": "skinny_m", + "dtype": "fp16_bf16", + "k_bin": "<=4096", + "m_bin": "<=128", + "n_bin": "<=4096" + }, + "match_exact_only": true, + "source": { + "source_measurement": "20260509_q2_llmseeded", + "source_n_archive_shapes": 1, + "live_validation_workloads": [ + "I16_011" + ] + }, + "validation": { + "q5_exp1_seed_over_baseline_geo": 0.2, + "q5_repeat_range": [ + 0.2, + 0.2 + ], + "q5_n_pairs": 3, + "q6_exp2_seed_over_baseline_geo": 0.669, + "q6_repeat_range": [ + 0.587, + 0.81 + ], + "q6_n_pairs": 3 + }, + "templates": [ + { + "template": { + "block_sizes": [ + 64, + 32, + 128 + ], + "l2_groupings": [ + 1 + ], + "num_stages": 4, + "num_warps": 8, + "pid_type": "flat" + } + } + ] + }, + { + "kernel_class": "matmul_int16", + "shape_bucket": { + "aspect": "skinny_n", + "dtype": "fp16_bf16", + "k_bin": "<=4096", + "m_bin": "<=4096", + "n_bin": "<=128" + }, + "match_exact_only": true, + "source": { + "source_measurement": "20260509_q2_llmseeded", + "source_n_archive_shapes": 2, + "live_validation_workloads": [ + "I16_012" + ] + }, + "validation": { + "q5_exp1_seed_over_baseline_geo": 0.264, + "q5_repeat_range": [ + 0.264, + 0.264 + ], + "q5_n_pairs": 3, + "q6_exp2_seed_over_baseline_geo": 0.912, + "q6_repeat_range": [ + 0.703, + 1.239 + ], + "q6_n_pairs": 3 + }, + "templates": [ + { + "template": { + "block_sizes": [ + 128, + 32, + 128 + ], + "l2_groupings": [ + 1 + ], + "num_stages": 4, + "num_warps": 2, + "pid_type": "flat" + } + }, + { + "template": { + "block_sizes": [ + 64, + 32, + 256 + ], + "l2_groupings": [ + 16 + ], + "num_stages": 3, + "num_warps": 8, + "pid_type": "flat" + } + } + ] + }, + { + "kernel_class": "matmul_int16", + "shape_bucket": { + "aspect": "balanced", + "dtype": "fp16_bf16", + "k_bin": "<=512", + "m_bin": "<=512", + "n_bin": "<=512" + }, + "match_exact_only": true, + "source": { + "source_measurement": "20260509_q2_llmseeded", + "source_n_archive_shapes": 2, + "live_validation_workloads": [ + "I16_002" + ] + }, + "validation": { + "q5_exp1_seed_over_baseline_geo": 0.36, + "q5_repeat_range": [ + 0.36, + 0.36 + ], + "q5_n_pairs": 3, + "q6_exp2_seed_over_baseline_geo": 0.693, + "q6_repeat_range": [ + 0.692, + 0.693 + ], + "q6_n_pairs": 3 + }, + "templates": [ + { + "template": { + "block_sizes": [ + 32, + 32, + 128 + ], + "l2_groupings": [ + 1 + ], + "num_sm_multiplier": 1, + "num_stages": 3, + "num_warps": 8, + "pid_type": "persistent_blocked" + } + }, + { + "template": { + "block_sizes": [ + 32, + 64, + 128 + ], + "l2_groupings": [ + 1 + ], + "num_stages": 4, + "num_warps": 4, + "pid_type": "flat" + } + } + ] + }, + { + "kernel_class": "matmul_int16", + "shape_bucket": { + "aspect": "balanced", + "dtype": "fp16_bf16", + "k_bin": "<=256", + "m_bin": "<=256", + "n_bin": "<=256" + }, + "match_exact_only": true, + "source": { + "source_measurement": "20260509_q2_llmseeded", + "source_n_archive_shapes": 1, + "live_validation_workloads": [ + "I16_001" + ] + }, + "validation": { + "q5_exp1_seed_over_baseline_geo": 0.474, + "q5_repeat_range": [ + 0.472, + 0.475 + ], + "q5_n_pairs": 3, + "q6_exp2_seed_over_baseline_geo": 0.686, + "q6_repeat_range": [ + 0.64, + 0.782 + ], + "q6_n_pairs": 3 + }, + "templates": [ + { + "template": { + "block_sizes": [ + 16, + 32, + 256 + ], + "l2_groupings": [ + 2 + ], + "num_sm_multiplier": 1, + "num_stages": 1, + "num_warps": 8, + "pid_type": "persistent_blocked" + } + } + ] + }, + { + "kernel_class": "matmul_int4", + "shape_bucket": { + "aspect": "balanced", + "dtype": "fp16_bf16", + "k_bin": "<=4096", + "m_bin": "<=4096", + "n_bin": "<=4096" + }, + "match_exact_only": true, + "source": { + "source_measurement": "20260509_q2_llmseeded", + "source_n_archive_shapes": 5, + "live_validation_workloads": [ + "I4_004", + "I4_005", + "I4_006", + "I4_007", + "I4_008" + ] + }, + "validation": { + "q5_exp1_seed_over_baseline_geo": 0.065, + "q5_repeat_range": [ + 0.052, + 0.098 + ], + "q5_n_pairs": 15, + "q6_exp2_seed_over_baseline_geo": 0.391, + "q6_repeat_range": [ + 0.284, + 0.577 + ], + "q6_n_pairs": 15 + }, + "templates": [ + { + "template": { + "block_sizes": [ + 16, + 128, + 128 + ], + "l2_groupings": [ + 2 + ], + "num_stages": 6, + "num_warps": 4, + "pid_type": "flat" + } + }, + { + "template": { + "block_sizes": [ + 16, + 128, + 128 + ], + "l2_groupings": [ + 2 + ], + "num_stages": 3, + "num_warps": 4, + "pid_type": "flat" + } + }, + { + "template": { + "block_sizes": [ + 16, + 128, + 256 + ], + "l2_groupings": [ + 1 + ], + "num_stages": 8, + "num_warps": 8, + "pid_type": "flat" + } + } + ] + }, + { + "kernel_class": "matmul_int4", + "shape_bucket": { + "aspect": "skinny_m", + "dtype": "fp16_bf16", + "k_bin": "<=4096", + "m_bin": "<=1024", + "n_bin": "<=4096" + }, + "match_exact_only": true, + "source": { + "source_measurement": "20260509_q2_llmseeded", + "source_n_archive_shapes": 1, + "live_validation_workloads": [ + "I4_010" + ] + }, + "validation": { + "q5_exp1_seed_over_baseline_geo": 0.116, + "q5_repeat_range": [ + 0.116, + 0.116 + ], + "q5_n_pairs": 3, + "q6_exp2_seed_over_baseline_geo": 0.66, + "q6_repeat_range": [ + 0.573, + 0.758 + ], + "q6_n_pairs": 3 + }, + "templates": [ + { + "template": { + "block_sizes": [ + 16, + 256, + 128 + ], + "l2_groupings": [ + 1 + ], + "num_stages": 2, + "num_warps": 8, + "pid_type": "flat" + } + } + ] + }, + { + "kernel_class": "matmul_int4", + "shape_bucket": { + "aspect": "skinny_n", + "dtype": "fp16_bf16", + "k_bin": "<=4096", + "m_bin": "<=4096", + "n_bin": "<=128" + }, + "match_exact_only": true, + "source": { + "source_measurement": "20260509_q2_llmseeded", + "source_n_archive_shapes": 2, + "live_validation_workloads": [ + "I4_012" + ] + }, + "validation": { + "q5_exp1_seed_over_baseline_geo": 0.232, + "q5_repeat_range": [ + 0.232, + 0.232 + ], + "q5_n_pairs": 3, + "q6_exp2_seed_over_baseline_geo": 0.964, + "q6_repeat_range": [ + 0.947, + 0.972 + ], + "q6_n_pairs": 3 + }, + "templates": [ + { + "template": { + "block_sizes": [ + 64, + 64, + 16 + ], + "l2_groupings": [ + 1 + ], + "num_stages": 3, + "num_warps": 2, + "pid_type": "flat" + } + }, + { + "template": { + "block_sizes": [ + 64, + 64, + 32 + ], + "l2_groupings": [ + 1 + ], + "num_stages": 3, + "num_warps": 2, + "pid_type": "flat" + } + } + ] + }, + { + "kernel_class": "matmul_int4", + "shape_bucket": { + "aspect": "balanced", + "dtype": "fp16_bf16", + "k_bin": "<=1024", + "m_bin": "<=1024", + "n_bin": "<=1024" + }, + "match_exact_only": true, + "source": { + "source_measurement": "20260509_q2_llmseeded", + "source_n_archive_shapes": 2, + "live_validation_workloads": [ + "I4_003" + ] + }, + "validation": { + "q5_exp1_seed_over_baseline_geo": 0.282, + "q5_repeat_range": [ + 0.281, + 0.284 + ], + "q5_n_pairs": 3, + "q6_exp2_seed_over_baseline_geo": 0.958, + "q6_repeat_range": [ + 0.744, + 1.181 + ], + "q6_n_pairs": 3 + }, + "templates": [ + { + "template": { + "block_sizes": [ + 16, + 128, + 64 + ], + "l2_groupings": [ + 1 + ], + "num_stages": 3, + "num_warps": 4, + "pid_type": "flat" + } + }, + { + "template": { + "block_sizes": [ + 64, + 64, + 64 + ], + "l2_groupings": [ + 16 + ], + "num_sm_multiplier": 1, + "num_stages": 3, + "num_warps": 8, + "pid_type": "persistent_interleaved" + } + } + ] + }, + { + "kernel_class": "matmul_int4", + "shape_bucket": { + "aspect": "skinny_m", + "dtype": "fp16_bf16", + "k_bin": "<=4096", + "m_bin": "<=128", + "n_bin": "<=4096" + }, + "match_exact_only": true, + "source": { + "source_measurement": "20260509_q2_llmseeded", + "source_n_archive_shapes": 1, + "live_validation_workloads": [ + "I4_011" + ] + }, + "validation": { + "q5_exp1_seed_over_baseline_geo": 0.367, + "q5_repeat_range": [ + 0.367, + 0.367 + ], + "q5_n_pairs": 3, + "q6_exp2_seed_over_baseline_geo": 0.826, + "q6_repeat_range": [ + 0.645, + 0.935 + ], + "q6_n_pairs": 3 + }, + "templates": [ + { + "template": { + "block_sizes": [ + 128, + 64, + 32 + ], + "l2_groupings": [ + 1 + ], + "num_stages": 3, + "num_warps": 8, + "pid_type": "flat" + } + } + ] + }, + { + "kernel_class": "matmul_int4", + "shape_bucket": { + "aspect": "balanced", + "dtype": "fp16_bf16", + "k_bin": "<=256", + "m_bin": "<=256", + "n_bin": "<=256" + }, + "match_exact_only": true, + "source": { + "source_measurement": "20260509_q2_llmseeded", + "source_n_archive_shapes": 1, + "live_validation_workloads": [ + "I4_001" + ] + }, + "validation": { + "q5_exp1_seed_over_baseline_geo": 0.601, + "q5_repeat_range": [ + 0.601, + 0.601 + ], + "q5_n_pairs": 3, + "q6_exp2_seed_over_baseline_geo": 0.819, + "q6_repeat_range": [ + 0.819, + 0.819 + ], + "q6_n_pairs": 3 + }, + "templates": [ + { + "template": { + "block_sizes": [ + 128, + 32, + 16 + ], + "l2_groupings": [ + 1 + ], + "num_sm_multiplier": 1, + "num_stages": 2, + "num_warps": 4, + "pid_type": "persistent_interleaved" + } + } + ] + }, + { + "kernel_class": "matmul_int4", + "shape_bucket": { + "aspect": "balanced", + "dtype": "fp16_bf16", + "k_bin": "<=512", + "m_bin": "<=512", + "n_bin": "<=512" + }, + "match_exact_only": true, + "source": { + "source_measurement": "20260509_q2_llmseeded", + "source_n_archive_shapes": 2, + "live_validation_workloads": [ + "I4_002" + ] + }, + "validation": { + "q5_exp1_seed_over_baseline_geo": 0.629, + "q5_repeat_range": [ + 0.628, + 0.63 + ], + "q5_n_pairs": 3, + "q6_exp2_seed_over_baseline_geo": 1.001, + "q6_repeat_range": [ + 1.0, + 1.002 + ], + "q6_n_pairs": 3 + }, + "templates": [ + { + "template": { + "block_sizes": [ + 64, + 32, + 32 + ], + "l2_groupings": [ + 1 + ], + "num_sm_multiplier": 1, + "num_stages": 3, + "num_warps": 4, + "pid_type": "persistent_interleaved" + } + }, + { + "template": { + "block_sizes": [ + 32, + 32, + 32 + ], + "l2_groupings": [ + 1 + ], + "num_stages": 3, + "num_warps": 2, + "pid_type": "flat" + } + } + ] + } + ], + "fallbacks": { + "matmul_int4": { + "small_m": { + "template": { + "atomic_indexing": [], + "block_sizes": [ + 128, + 16, + 32 + ], + "indexing": [ + "pointer", + "pointer", + "pointer" + ], + "l2_groupings": [ + 4 + ], + "load_eviction_policies": [ + "last", + "" + ], + "loop_orders": [ + [ + 0, + 1 + ] + ], + "num_stages": 5, + "num_warps": 4, + "pid_type": "flat", + "range_flattens": [ + null, + null + ], + "range_multi_buffers": [ + null, + null + ], + "range_num_stages": [ + 0, + 0 + ], + "range_unroll_factors": [ + 0, + 0 + ], + "range_warp_specializes": [ + null, + null + ] + }, + "source": { + "archive_shape": { + "M": 16, + "K": 4096, + "N": 4096 + }, + "archive_perf_ms": 0.027648000046610832, + "n_archive_shapes": 9 + } + }, + "balanced": { + "template": { + "atomic_indexing": [], + "block_sizes": [ + 64, + 128, + 64 + ], + "indexing": [ + "tensor_descriptor", + "pointer", + "tensor_descriptor" + ], + "l2_groupings": [ + 1 + ], + "load_eviction_policies": [ + "", + "" + ], + "loop_orders": [ + [ + 0, + 1 + ] + ], + "num_stages": 3, + "num_warps": 8, + "pid_type": "flat", + "range_flattens": [ + null, + null + ], + "range_multi_buffers": [ + null, + false + ], + "range_num_stages": [ + 0, + 2 + ], + "range_unroll_factors": [ + 0, + 0 + ], + "range_warp_specializes": [ + null, + null + ] + }, + "source": { + "archive_shape": { + "M": 1024, + "K": 1024, + "N": 1024 + }, + "archive_perf_ms": 0.02969600073993206, + "n_archive_shapes": 7 + } + }, + "small_n": { + "template": { + "atomic_indexing": [], + "block_sizes": [ + 64, + 64, + 16 + ], + "indexing": [ + "pointer", + "tensor_descriptor", + "pointer" + ], + "l2_groupings": [ + 1 + ], + "load_eviction_policies": [ + "", + "" + ], + "loop_orders": [ + [ + 0, + 1 + ] + ], + "num_stages": 3, + "num_warps": 2, + "pid_type": "flat", + "range_flattens": [ + null, + null + ], + "range_multi_buffers": [ + null, + null + ], + "range_num_stages": [ + 0, + 0 + ], + "range_unroll_factors": [ + 0, + 3 + ], + "range_warp_specializes": [ + null, + null + ] + }, + "source": { + "archive_shape": { + "M": 4096, + "K": 4096, + "N": 32 + }, + "archive_perf_ms": 0.03379200026392937, + "n_archive_shapes": 8 + } + }, + "small_k": { + "template": { + "atomic_indexing": [], + "block_sizes": [ + 16, + 128, + 256 + ], + "indexing": [ + "pointer", + "tensor_descriptor", + "tensor_descriptor" + ], + "l2_groupings": [ + 1 + ], + "load_eviction_policies": [ + "", + "" + ], + "loop_orders": [ + [ + 1, + 0 + ] + ], + "num_stages": 4, + "num_warps": 8, + "pid_type": "flat", + "range_flattens": [ + null, + null + ], + "range_multi_buffers": [ + null, + true + ], + "range_num_stages": [ + 0, + 0 + ], + "range_unroll_factors": [ + 0, + 2 + ], + "range_warp_specializes": [ + null, + null + ] + }, + "source": { + "archive_shape": { + "M": 2048, + "K": 256, + "N": 2048 + }, + "archive_perf_ms": 0.01945599913597107, + "n_archive_shapes": 6 + } + }, + "rect": { + "template": { + "atomic_indexing": [], + "block_sizes": [ + 16, + 128, + 128 + ], + "indexing": [ + "pointer", + "tensor_descriptor", + "tensor_descriptor" + ], + "l2_groupings": [ + 2 + ], + "load_eviction_policies": [ + "", + "first" + ], + "loop_orders": [ + [ + 0, + 1 + ] + ], + "num_stages": 3, + "num_warps": 4, + "pid_type": "flat", + "range_flattens": [ + null, + null + ], + "range_multi_buffers": [ + null, + null + ], + "range_num_stages": [ + 0, + 0 + ], + "range_unroll_factors": [ + 0, + 3 + ], + "range_warp_specializes": [ + null, + null + ] + }, + "source": { + "archive_shape": { + "M": 1536, + "K": 3072, + "N": 1536 + }, + "archive_perf_ms": 0.07280000299215317, + "n_archive_shapes": 10 + } + } + }, + "matmul_int16": { + "small_m": { + "template": { + "atomic_indexing": [], + "block_sizes": [ + 64, + 32, + 256 + ], + "indexing": [ + "pointer", + "pointer", + "pointer" + ], + "l2_groupings": [ + 1 + ], + "load_eviction_policies": [ + "", + "" + ], + "loop_orders": [ + [ + 0, + 1 + ] + ], + "num_stages": 3, + "num_warps": 8, + "pid_type": "flat", + "range_flattens": [ + null, + false + ], + "range_multi_buffers": [ + null, + null + ], + "range_num_stages": [ + 0, + 0 + ], + "range_unroll_factors": [ + 0, + 0 + ], + "range_warp_specializes": [ + null, + null + ] + }, + "source": { + "archive_shape": { + "M": 128, + "K": 2048, + "N": 2048 + }, + "archive_perf_ms": 0.01539199985563755, + "n_archive_shapes": 9 + } + }, + "balanced": { + "template": { + "atomic_indexing": [], + "block_sizes": [ + 128, + 64, + 64 + ], + "indexing": [ + "pointer", + "pointer", + "tensor_descriptor" + ], + "l2_groupings": [ + 4 + ], + "load_eviction_policies": [ + "", + "" + ], + "loop_orders": [ + [ + 0, + 1 + ] + ], + "num_stages": 3, + "num_warps": 8, + "pid_type": "flat", + "range_flattens": [ + null, + null + ], + "range_multi_buffers": [ + null, + null + ], + "range_num_stages": [ + 0, + 0 + ], + "range_unroll_factors": [ + 0, + 0 + ], + "range_warp_specializes": [ + null, + null + ] + }, + "source": { + "archive_shape": { + "M": 1024, + "K": 1024, + "N": 1024 + }, + "archive_perf_ms": 0.015359999611973763, + "n_archive_shapes": 7 + } + }, + "small_n": { + "template": { + "atomic_indexing": [], + "block_sizes": [ + 128, + 32, + 128 + ], + "indexing": [ + "pointer", + "pointer", + "tensor_descriptor" + ], + "l2_groupings": [ + 2 + ], + "load_eviction_policies": [ + "last", + "first" + ], + "loop_orders": [ + [ + 0, + 1 + ] + ], + "num_stages": 3, + "num_warps": 8, + "pid_type": "flat", + "range_flattens": [ + null, + null + ], + "range_multi_buffers": [ + null, + null + ], + "range_num_stages": [ + 0, + 0 + ], + "range_unroll_factors": [ + 0, + 0 + ], + "range_warp_specializes": [ + null, + null + ] + }, + "source": { + "archive_shape": { + "M": 2048, + "K": 2048, + "N": 256 + }, + "archive_perf_ms": 0.019392000511288643, + "n_archive_shapes": 8 + } + }, + "small_k": { + "template": { + "atomic_indexing": [], + "block_sizes": [ + 256, + 128, + 64 + ], + "indexing": [ + "pointer", + "pointer", + "tensor_descriptor" + ], + "l2_groupings": [ + 8 + ], + "load_eviction_policies": [ + "", + "" + ], + "loop_orders": [ + [ + 0, + 1 + ] + ], + "maxnreg": 256, + "num_sm_multiplier": 1, + "num_stages": 3, + "num_warps": 8, + "pid_type": "persistent_blocked", + "range_flattens": [ + null, + null + ], + "range_multi_buffers": [ + null, + null + ], + "range_num_stages": [ + 0, + 0 + ], + "range_unroll_factors": [ + 0, + 0 + ], + "range_warp_specializes": [ + null, + null + ] + }, + "source": { + "archive_shape": { + "M": 2048, + "K": 256, + "N": 2048 + }, + "archive_perf_ms": 0.013279999606311321, + "n_archive_shapes": 6 + } + }, + "rect": { + "template": { + "atomic_indexing": [], + "block_sizes": [ + 256, + 256, + 64 + ], + "indexing": [ + "pointer", + "pointer", + "tensor_descriptor" + ], + "l2_groupings": [ + 8 + ], + "load_eviction_policies": [ + "last", + "last" + ], + "loop_orders": [ + [ + 0, + 1 + ] + ], + "num_stages": 3, + "num_warps": 8, + "pid_type": "flat", + "range_flattens": [ + null, + null + ], + "range_multi_buffers": [ + null, + false + ], + "range_num_stages": [ + 0, + 0 + ], + "range_unroll_factors": [ + 0, + 0 + ], + "range_warp_specializes": [ + null, + null + ] + }, + "source": { + "archive_shape": { + "M": 4096, + "K": 512, + "N": 4096 + }, + "archive_perf_ms": 0.03788800165057182, + "n_archive_shapes": 10 + } + } + }, + "matmul_fp4": { + "small_m": { + "template": { + "atomic_indexing": [], + "block_sizes": [ + 64, + 64, + 16 + ], + "indexing": [ + "pointer", + "tensor_descriptor", + "tensor_descriptor" + ], + "l2_groupings": [ + 1 + ], + "load_eviction_policies": [ + "", + "first" + ], + "loop_orders": [ + [ + 0, + 1 + ] + ], + "num_stages": 3, + "num_warps": 2, + "pid_type": "flat", + "range_flattens": [ + null, + null + ], + "range_multi_buffers": [ + null, + null + ], + "range_num_stages": [ + 0, + 0 + ], + "range_unroll_factors": [ + 0, + 0 + ], + "range_warp_specializes": [ + null, + null + ] + }, + "source": { + "archive_shape": { + "M": 128, + "K": 2048, + "N": 2048 + }, + "archive_perf_ms": 0.035840000957250595, + "n_archive_shapes": 9 + } + }, + "balanced": { + "template": { + "atomic_indexing": [], + "block_sizes": [ + 64, + 128, + 32 + ], + "indexing": [ + "pointer", + "pointer", + "tensor_descriptor" + ], + "l2_groupings": [ + 1 + ], + "load_eviction_policies": [ + "last", + "" + ], + "loop_orders": [ + [ + 0, + 1 + ] + ], + "num_stages": 1, + "num_warps": 8, + "pid_type": "flat", + "range_flattens": [ + null, + null + ], + "range_multi_buffers": [ + null, + true + ], + "range_num_stages": [ + 0, + 0 + ], + "range_unroll_factors": [ + 0, + 0 + ], + "range_warp_specializes": [ + null, + false + ] + }, + "source": { + "archive_shape": { + "M": 1024, + "K": 1024, + "N": 1024 + }, + "archive_perf_ms": 0.035840000957250595, + "n_archive_shapes": 7 + } + }, + "small_n": { + "template": { + "atomic_indexing": [], + "block_sizes": [ + 64, + 128, + 16 + ], + "indexing": [ + "pointer", + "pointer", + "pointer" + ], + "l2_groupings": [ + 4 + ], + "load_eviction_policies": [ + "last", + "last" + ], + "loop_orders": [ + [ + 0, + 1 + ] + ], + "num_stages": 3, + "num_warps": 8, + "pid_type": "flat", + "range_flattens": [ + null, + null + ], + "range_multi_buffers": [ + null, + null + ], + "range_num_stages": [ + 0, + 1 + ], + "range_unroll_factors": [ + 0, + 0 + ], + "range_warp_specializes": [ + null, + null + ] + }, + "source": { + "archive_shape": { + "M": 2048, + "K": 2048, + "N": 256 + }, + "archive_perf_ms": 0.04608000069856644, + "n_archive_shapes": 8 + } + }, + "small_k": { + "template": { + "atomic_indexing": [], + "block_sizes": [ + 8, + 128, + 128 + ], + "indexing": [ + "tensor_descriptor", + "tensor_descriptor", + "tensor_descriptor" + ], + "l2_groupings": [ + 1 + ], + "load_eviction_policies": [ + "", + "first" + ], + "loop_orders": [ + [ + 0, + 1 + ] + ], + "num_stages": 3, + "num_warps": 4, + "pid_type": "flat", + "range_flattens": [ + null, + null + ], + "range_multi_buffers": [ + null, + null + ], + "range_num_stages": [ + 0, + 0 + ], + "range_unroll_factors": [ + 0, + 1 + ], + "range_warp_specializes": [ + null, + null + ] + }, + "source": { + "archive_shape": { + "M": 2048, + "K": 256, + "N": 2048 + }, + "archive_perf_ms": 0.025599999353289604, + "n_archive_shapes": 6 + } + }, + "rect": { + "template": { + "atomic_indexing": [], + "block_sizes": [ + 8, + 256, + 128 + ], + "indexing": [ + "tensor_descriptor", + "pointer", + "tensor_descriptor" + ], + "l2_groupings": [ + 16 + ], + "load_eviction_policies": [ + "", + "" + ], + "loop_orders": [ + [ + 0, + 1 + ] + ], + "num_stages": 6, + "num_warps": 4, + "pid_type": "flat", + "range_flattens": [ + null, + false + ], + "range_multi_buffers": [ + null, + true + ], + "range_num_stages": [ + 0, + 4 + ], + "range_unroll_factors": [ + 0, + 3 + ], + "range_warp_specializes": [ + null, + false + ] + }, + "source": { + "archive_shape": { + "M": 2048, + "K": 1024, + "N": 4096 + }, + "archive_perf_ms": 0.11363200098276138, + "n_archive_shapes": 10 + } + } + } + } +} diff --git a/helion/autotuner/finite_search.py b/helion/autotuner/finite_search.py index 32faa584c2..e4aa8c5899 100644 --- a/helion/autotuner/finite_search.py +++ b/helion/autotuner/finite_search.py @@ -4,6 +4,7 @@ from .. import exc from .base_search import BaseSearch +from .observed_heuristics import observed_heuristic_seed_configs_for_kernel if TYPE_CHECKING: from collections.abc import Callable @@ -71,8 +72,15 @@ def __init__( if max_configs is not None else self.settings.autotune_best_available_max_configs ) + observed = observed_heuristic_seed_configs_for_kernel( + kernel, + args, + config_spec=self.config_spec, + max_configs=cap, + ) cached: list[Config] = [] - for i, entry in enumerate(self._find_similar_cached_configs(cap)): + cached_cap = max(0, cap - len(observed)) + for i, entry in enumerate(self._find_similar_cached_configs(cached_cap)): try: cached.append(self.config_gen.unflatten(entry.to_mutable_flat_config())) except ( @@ -83,8 +91,11 @@ def __init__( exc.InvalidConfig, ) as e: self.log(f"from_cache: failed to transfer cached config {i + 1}: {e}") - self.log(f"from_cache: resolved {len(cached)} cached config(s) (cap={cap})") - self.configs: list[Config] = [*cached, *list(configs)] + self.log( + f"from_cache: resolved {len(observed)} observed config(s) " + f"and {len(cached)} cached config(s) (cap={cap})" + ) + self.configs: list[Config] = [*observed, *cached, *list(configs)] if len(self.configs) < 2: raise exc.NotEnoughConfigs(len(self.configs)) diff --git a/helion/autotuner/llm/configs.py b/helion/autotuner/llm/configs.py index 7dae110364..9c0843905c 100644 --- a/helion/autotuner/llm/configs.py +++ b/helion/autotuner/llm/configs.py @@ -6,6 +6,7 @@ from typing import Any from typing import cast +from ..config_validation import validate_sparse_config_shape from .parsing import parse_jsonish if TYPE_CHECKING: @@ -15,53 +16,6 @@ from ..logger import AutotuningLogger -def is_positive_power_of_two_int(val: object) -> bool: - """Return whether `val` is a strictly positive power-of-two integer.""" - return type(val) is int and val > 0 and (val & (val - 1)) == 0 - - -def _expected_json_array_length(field: object) -> int | None: - """Return the required JSON array length for fixed-length sequence fields.""" - from ..block_id_sequence import BlockIdSequence - from ..config_fragment import ListOf - - if isinstance(field, BlockIdSequence): - return len(field) - if isinstance(field, ListOf): - return field.length - return None - - -def _validate_json_array_length(key: str, val: object, *, expected_len: int) -> None: - """Validate that a JSON value is a list with the expected fixed length.""" - if not isinstance(val, list): - raise ValueError( - f"{key} must be a JSON array of length {expected_len}, got {val!r}" - ) - if len(val) != expected_len: - raise ValueError(f"{key} must have length {expected_len}, got {len(val)}") - - -def validate_sparse_config_shape( - raw: dict[str, object], *, config_spec: ConfigSpec -) -> None: - """Reject common LLM shape/type mismatches instead of silently repairing them.""" - flat_fields = config_spec._flat_fields() - for key, val in raw.items(): - field = flat_fields.get(key) - if key == "num_warps" and not is_positive_power_of_two_int(val): - raise ValueError( - f"num_warps must be a positive power-of-two integer, got {val!r}" - ) - - if (expected_len := _expected_json_array_length(field)) is not None: - _validate_json_array_length(key, val, expected_len=expected_len) - continue - - if field is not None and isinstance(val, list): - raise ValueError(f"{key} must be a scalar value, got {val!r}") - - def parse_response_configs( response: str, *, diff --git a/helion/autotuner/llm/prompting.py b/helion/autotuner/llm/prompting.py index eca7989a13..4f4f359def 100644 --- a/helion/autotuner/llm/prompting.py +++ b/helion/autotuner/llm/prompting.py @@ -5,12 +5,12 @@ import textwrap from typing import TYPE_CHECKING +from ..workload import detect_workload_traits from .configs import describe_config_space from .feedback import MAX_CHANGED_FIELDS_PER_CONFIG from .feedback import format_config_for_prompt from .workload import compute_workload_hints from .workload import describe_kernel -from .workload import detect_workload_traits if TYPE_CHECKING: from collections.abc import Mapping diff --git a/helion/autotuner/llm/workload.py b/helion/autotuner/llm/workload.py index 07ef5a8eb6..8638c7ef4d 100644 --- a/helion/autotuner/llm/workload.py +++ b/helion/autotuner/llm/workload.py @@ -1,56 +1,19 @@ -"""Infer workload traits and render kernel context for LLM prompts.""" +"""Render workload context for LLM prompts.""" from __future__ import annotations -import inspect -import textwrap from typing import TYPE_CHECKING import torch from ..._compat import get_device_name from ..._compat import num_compute_units +from ..workload import kernel_source_text if TYPE_CHECKING: - from collections.abc import Iterator from collections.abc import Sequence from ..base_search import _AutotunableKernel - from ..config_spec import ConfigSpec - - -MATMUL_TARGETS = frozenset( - { - torch.matmul, - torch.ops.aten.mm.default, - torch.ops.aten.addmm.default, - torch.ops.aten.bmm.default, - torch.ops.aten.baddbmm.default, - } -) -MATMUL_API_NAMES = frozenset({"dot", "dot_scaled"}) -BATCH_MATMUL_TARGET_NAMES = frozenset({"bmm", "baddbmm"}) -REDUCTION_TARGET_NAMES = frozenset({"amax", "sum", "softmax", "logsumexp"}) -EXP_TARGET_NAMES = frozenset({"exp", "exp2"}) - - -def _kernel_source_text(kernel: _AutotunableKernel) -> str: - """Extract the underlying kernel source when it is available.""" - try: - inner_kernel = getattr(kernel, "kernel", None) - if inner_kernel is None or not hasattr(inner_kernel, "fn"): - return "# Source unavailable" - raw_source = inspect.getsource(inner_kernel.fn) - except (OSError, TypeError): - return "# Source unavailable" - - source_lines = textwrap.dedent(raw_source).splitlines() - start_idx = 0 - while start_idx < len(source_lines) and not source_lines[ - start_idx - ].lstrip().startswith("def "): - start_idx += 1 - return "\n".join(source_lines[start_idx:]) def _tensor_args(args: Sequence[object]) -> list[torch.Tensor]: @@ -91,7 +54,7 @@ def _gpu_hardware_lines(device: torch.device) -> list[str]: def describe_kernel(kernel: _AutotunableKernel, args: Sequence[object]) -> str: """Build a description of the kernel, its inputs, and the target GPU.""" - parts = [f"## Kernel Source Code\n```python\n{_kernel_source_text(kernel)}\n```"] + parts = [f"## Kernel Source Code\n```python\n{kernel_source_text(kernel)}\n```"] if tensor_lines := _input_tensor_lines(args): parts.append("## Input Tensors\n" + "\n".join(tensor_lines)) @@ -102,69 +65,6 @@ def describe_kernel(kernel: _AutotunableKernel, args: Sequence[object]) -> str: return "\n\n".join(parts) -def _target_name_parts(target: object) -> frozenset[str]: - """Extract coarse name tokens for a traced call target.""" - parts: set[str] = set() - for raw in ( - getattr(target, "__name__", None), - getattr(target, "name", None), - str(target), - ): - if not isinstance(raw, str): - continue - parts.add(raw) - parts.update(piece for piece in raw.split(".") if piece) - return frozenset(parts) - - -def _iter_call_targets(kernel: _AutotunableKernel) -> Iterator[object]: - """Yield traced call targets from compiler-generated FX graphs.""" - host_function = getattr(kernel, "host_function", None) - device_ir = getattr(host_function, "device_ir", None) - for graph_info in getattr(device_ir, "graphs", ()): - graph = getattr(graph_info, "graph", None) - if not isinstance(graph, torch.fx.Graph): - continue - for node in graph.nodes: - if node.op == "call_function": - yield node.target - - -def detect_workload_traits( - kernel: _AutotunableKernel | None, - *, - config_spec: ConfigSpec | None = None, -) -> frozenset[str]: - """Infer coarse workload traits from compiler-traced graphs.""" - if kernel is None: - return frozenset() - - saw_matmul = False - saw_batched_matmul = False - saw_reduction = bool(config_spec is not None and config_spec.reduction_loops) - saw_exp = False - - for target in _iter_call_targets(kernel): - name_parts = _target_name_parts(target) - if target in MATMUL_TARGETS or name_parts & MATMUL_API_NAMES: - saw_matmul = True - if name_parts & BATCH_MATMUL_TARGET_NAMES: - saw_batched_matmul = True - if name_parts & REDUCTION_TARGET_NAMES: - saw_reduction = True - if name_parts & EXP_TARGET_NAMES: - saw_exp = True - - traits: set[str] = set() - if saw_matmul: - traits.add("matmul") - if saw_reduction: - traits.add("reduction") - if saw_matmul and saw_reduction and (saw_batched_matmul or saw_exp): - traits.add("attention_reduction") - return frozenset(traits) - - def _summary_hints( tensors: Sequence[torch.Tensor], *, diff --git a/helion/autotuner/llm_search.py b/helion/autotuner/llm_search.py index de65cb8735..5caa26b810 100644 --- a/helion/autotuner/llm_search.py +++ b/helion/autotuner/llm_search.py @@ -5,8 +5,8 @@ config so the first LLM call sees both the workload description and the available tuning knobs. 2. Round 0 launches the first LLM call immediately, then benchmarks the - default config plus a few random seed configs while that request is in - flight. + default config plus observed heuristic and random seed configs while that + request is in flight. 3. When the round-0 LLM response arrives, the search benchmarks its new unique configs and folds those results into the running set of top configs. 4. The top configs are then rebenchmarked before the next prompt is built, so each @@ -60,6 +60,7 @@ from .llm.transport import DEFAULT_REQUEST_TIMEOUT_S from .llm.transport import call_provider as _call_provider from .llm.transport import infer_provider as _infer_provider +from .observed_heuristics import observed_heuristic_seed_configs_for_kernel if TYPE_CHECKING: from collections.abc import Iterator @@ -364,10 +365,33 @@ def _initialize_prompt_state(self) -> None: ] def _build_seed_configs(self) -> list[Config]: - """Build the initial benchmark set: default plus a few random seeds.""" - # Start from default and add only distinct random configs that unflatten cleanly. - seed_configs: list[Config] = [self.config_spec.default_config()] - seen_config_keys = {self._config_key(seed_configs[0])} + """Build the initial benchmark set: observed/default seed, then random seeds.""" + # Exact observed heuristic matches replace the default slot. Unsupported + # kernels keep the old default-plus-random seed behavior. + target_count = 1 + self.initial_random_configs + seed_configs: list[Config] = [] + seen_config_keys: set[str] = set() + + heuristic_configs = observed_heuristic_seed_configs_for_kernel( + self.kernel, + self.args, + config_spec=self.config_spec, + max_configs=target_count, + ) + for cfg in heuristic_configs: + key = self._config_key(cfg) + if key in seen_config_keys: + continue + seen_config_keys.add(key) + seed_configs.append(cfg) + if len(seed_configs) >= target_count: + return seed_configs + + if not seed_configs: + default_config = self.config_spec.default_config() + seed_configs.append(default_config) + seen_config_keys.add(self._config_key(default_config)) + for flat in self.config_gen.random_population_flat( self.initial_random_configs + 1 )[1:]: @@ -380,13 +404,15 @@ def _build_seed_configs(self) -> list[Config]: continue seen_config_keys.add(key) seed_configs.append(cfg) + if len(seed_configs) >= target_count: + break return seed_configs def _dedupe_new_configs( self, configs: list[Config], seen_config_keys: set[str] ) -> list[Config]: """Filter out configs that have already been seen in earlier rounds.""" - # Drop configs that were already benchmarked or queued in prior rounds. + # Drop configs that were already benchmarked or queued in earlier rounds. new_configs: list[Config] = [] for cfg in configs: key = self._config_key(cfg) @@ -531,14 +557,12 @@ def _update_early_stop_state(self, state: _SearchLoopState) -> bool: def _run_initial_round(self, state: _SearchLoopState) -> None: """Run round 0 by overlapping the initial LLM request with seed benchmarking.""" # Launch the first request before benchmarking because round 0 does not need - # any prior search feedback to build its prompt. + # any search feedback to build its prompt. seed_configs = self._build_seed_configs() state.seen_config_keys.update(self._config_key(cfg) for cfg in seed_configs) - self.log( f"Round 0: starting initial LLM call while benchmarking " - f"{len(seed_configs)} seed configs (1 default + " - f"{max(0, len(seed_configs) - 1)} random)" + f"{len(seed_configs)} seed configs" ) llm_future: concurrent.futures.Future[str] | None = None @@ -546,7 +570,8 @@ def _run_initial_round(self, state: _SearchLoopState) -> None: llm_future = self._call_llm_async(self._build_llm_messages()) except Exception: self.log.warning( - "Round 0: could not start initial LLM call, continuing with seed configs" + "Round 0: could not start initial LLM call, continuing with " + "seed configs" ) if seed_configs: @@ -573,7 +598,7 @@ def _run_initial_round(self, state: _SearchLoopState) -> None: def _run_refinement_round(self, round_num: int, state: _SearchLoopState) -> bool: """Run one post-seed refinement round and report whether search should stop.""" - # Build the next prompt from the stabilized prior round, then benchmark new configs. + # Build the next prompt from the stabilized search round, then benchmark new configs. prompt = self._build_refinement_prompt(round_num) try: llm_response = self._call_llm(self._build_llm_messages(prompt)) diff --git a/helion/autotuner/observed_heuristics.py b/helion/autotuner/observed_heuristics.py new file mode 100644 index 0000000000..e570a07b12 --- /dev/null +++ b/helion/autotuner/observed_heuristics.py @@ -0,0 +1,751 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +"""Data-derived autotune heuristics. + +The seed templates in this module are not hand-authored configs. They are read +from ``helion/autotuner/data/observed_heuristics_b200.json``, which is generated +from measured AOT CSV data. Runtime code only classifies the active kernel into +a workload class and compact shape bucket, then looks up validated structural +templates for that exact bucket. +""" + +from __future__ import annotations + +import functools +import json +import os +from pathlib import Path +from typing import TYPE_CHECKING + +from .config_validation import validate_sparse_config_shape +from .workload import detect_workload_traits +from .workload import kernel_source_text + +if TYPE_CHECKING: + from collections.abc import Sequence + + from ..runtime.config import Config + from .base_search import _AutotunableKernel + from .config_spec import ConfigSpec + + +OBSERVED_HEURISTICS_ENV = "HELION_AUTOTUNE_OBSERVED_HEURISTICS" +OBSERVED_HEURISTIC_SEEDS_ENV = "HELION_AUTOTUNE_OBSERVED_HEURISTIC_SEEDS" +OBSERVED_HEURISTICS_PATH_ENV = "HELION_AUTOTUNE_OBSERVED_HEURISTICS_PATH" +OBSERVED_HEURISTIC_MAX_TEMPLATES_ENV = ( + "HELION_AUTOTUNE_OBSERVED_HEURISTIC_MAX_TEMPLATES" +) +OBSERVED_HEURISTIC_DISABLED_CLASSES_ENV = ( + "HELION_AUTOTUNE_OBSERVED_HEURISTIC_DISABLED_CLASSES" +) +_RUNTIME_HEURISTICS_PATH = ( + Path(__file__).resolve().parent / "data" / "observed_heuristics_b200.json" +) + + +def _env_value(name: str) -> str | None: + return os.environ.get(name) + + +def _env_flag_enabled(name: str, *, default: bool) -> bool: + value = _env_value(name) + if value is None: + return default + return value.lower() in {"1", "true", "yes", "on"} + + +def observed_heuristics_enabled() -> bool: + """Return whether data-derived autotune heuristics should be used.""" + return _env_flag_enabled( + OBSERVED_HEURISTICS_ENV, + default=True, + ) + + +def observed_heuristic_seeds_enabled() -> bool: + """Return whether data-derived first-round seed configs should be used.""" + return observed_heuristics_enabled() and _env_flag_enabled( + OBSERVED_HEURISTIC_SEEDS_ENV, + default=True, + ) + + +@functools.cache +def _runtime_heuristics() -> dict[str, object]: + """Load the observed-heuristic JSON once per process.""" + path = Path(_env_value(OBSERVED_HEURISTICS_PATH_ENV) or _RUNTIME_HEURISTICS_PATH) + if not path.exists(): + return {"rules": [], "class_summary": {}} + with path.open(encoding="utf-8") as handle: + data = json.load(handle) + if not isinstance(data, dict): + return {"rules": [], "class_summary": {}} + return data + + +def _disabled_kernel_classes() -> set[str]: + value = ( + _env_value(OBSERVED_HEURISTIC_DISABLED_CLASSES_ENV) or "" + ) + return {item.strip() for item in value.split(",") if item.strip()} + + +def _max_observed_templates() -> int | None: + value = _env_value(OBSERVED_HEURISTIC_MAX_TEMPLATES_ENV) + if value is None: + return None + try: + parsed = int(value) + except ValueError: + return None + if parsed <= 0: + return None + return parsed + + +def _stable_json(value: object) -> str: + return json.dumps(value, sort_keys=True, separators=(",", ":")) + + +def _float_field(raw: dict[str, object], name: str, default: float) -> float: + try: + return float(raw.get(name, default)) + except (TypeError, ValueError): + return default + + +def _int_field(raw: dict[str, object], name: str, default: int) -> int: + try: + return int(raw.get(name, default)) + except (TypeError, ValueError): + return default + + +def _template_rank_key( + raw_template: dict[str, object], +) -> tuple[int, int, float, float]: + return ( + -_int_field(raw_template, "win_count", 0), + -_int_field(raw_template, "shape_coverage", 0), + _float_field(raw_template, "geomean_slowdown", float("inf")), + _float_field(raw_template, "p90_slowdown", float("inf")), + ) + + +def _has_template_rank_metadata(raw_template: dict[str, object]) -> bool: + return any( + name in raw_template + for name in ( + "win_count", + "shape_coverage", + "geomean_slowdown", + "p90_slowdown", + ) + ) + + +def _selected_templates(rule: dict[str, object]) -> list[dict[str, object]]: + templates = rule.get("templates", []) + if not isinstance(templates, list): + return [] + valid_templates = [ + raw_template for raw_template in templates if isinstance(raw_template, dict) + ] + if any(_has_template_rank_metadata(raw_template) for raw_template in valid_templates): + ranked = [ + raw_template + for _index, raw_template in sorted( + enumerate(valid_templates), + key=lambda item: (_template_rank_key(item[1]), item[0]), + ) + ] + else: + ranked = valid_templates + limit = _max_observed_templates() + if limit is not None: + ranked = ranked[:limit] + return ranked + + +def _default_block_rank(config_spec: ConfigSpec) -> int: + default = dict(config_spec.default_config()) + block_sizes = default.get("block_sizes") + if isinstance(block_sizes, list): + return len(block_sizes) + return 0 + + +def _has_reduction(config_spec: ConfigSpec) -> bool: + return bool(config_spec.reduction_loops) + + +def _flat_fields(config_spec: ConfigSpec) -> dict[str, object]: + return dict(config_spec._flat_fields()) + + +def _tensor_shapes(args: Sequence[object]) -> list[tuple[int, ...]]: + shapes: list[tuple[int, ...]] = [] + for arg in args: + shape = getattr(arg, "shape", None) + if shape is None: + continue + try: + shapes.append(tuple(int(dim) for dim in shape)) + except (TypeError, ValueError): + continue + return shapes + + +def _tensor_dtypes(args: Sequence[object]) -> list[str]: + dtypes: list[str] = [] + for arg in args: + dtype = getattr(arg, "dtype", None) + if dtype is not None: + dtypes.append(str(dtype)) + return dtypes + + +def _dtype_family(args: Sequence[object]) -> str: + dtypes = _tensor_dtypes(args) + dtype = dtypes[0] if dtypes else "unknown" + if "float8" in dtype: + return "fp8" + if "float16" in dtype or "bfloat16" in dtype: + return "fp16_bf16" + if "float32" in dtype: + return "fp32" + if "int" in dtype: + return "int" + return "other" + + +def _bin_le(value: int | None, bins: Sequence[int]) -> str: + if value is None: + return "unknown" + for bound in bins: + if value <= bound: + return f"<={bound}" + return f">{bins[-1]}" + + +def _numel(shape: Sequence[int]) -> int: + result = 1 + for dim in shape: + result *= dim + return result + + +def _is_fp8_matmul(args: Sequence[object]) -> bool: + return any("float8" in dtype for dtype in _tensor_dtypes(args)[:2]) + + +_QUANTIZED_KERNEL_FINGERPRINTS: tuple[tuple[str, frozenset[str]], ...] = ( + ("matmul_fp4", frozenset({"e2m1", "fp4", "nvfp4"})), + ("matmul_int4", frozenset({"int4", "pack_int4", "unpack_int4"})), +) + + +def _infer_quantized_matmul_class( + kernel: _AutotunableKernel | None, + args: Sequence[object], +) -> str | None: + """Return a quantized-matmul sub-class if the kernel is one of int4/int16/fp4. + + int16 is unambiguous from arg1 dtype. int4 vs fp4 share the packed-int8 + signature, so we fall back to scanning the kernel source for the + distinguishing term. Order matters: check fp4 before int4 (the fp4 source + can mention "int4" only in comments that refer to nibbles). + """ + dtypes = _tensor_dtypes(args) + if len(dtypes) < 2: + return None + if "int16" in dtypes[1]: + return "matmul_int16" + if "int8" not in dtypes[1]: + return None + # Packed-int8 shape signature: arg1 dim0 is K//2 of arg0 dim1. + shapes = _tensor_shapes(args) + if len(shapes) < 2 or len(shapes[0]) < 2 or len(shapes[1]) < 2: + return None + if shapes[0][-1] != shapes[1][-2] * 2: + return None + if kernel is None: + return None + try: + source = kernel_source_text(kernel).lower() + except Exception: # noqa: BLE001 + return None + for class_name, markers in _QUANTIZED_KERNEL_FINGERPRINTS: + if any(marker in source for marker in markers): + return class_name + return None + + +def _matmul_shape(shapes: Sequence[tuple[int, ...]]) -> tuple[int, int, int] | None: + if len(shapes) < 2 or len(shapes[0]) < 2 or len(shapes[1]) < 2: + return None + return (shapes[0][-2], shapes[1][-1], shapes[0][-1]) + + +def _row_shape(shapes: Sequence[tuple[int, ...]]) -> tuple[int | None, int | None]: + if not shapes or len(shapes[0]) < 2: + return None, None + rows = 1 + for dim in shapes[0][:-1]: + rows *= dim + return rows, shapes[0][-1] + + +def _attention_shape( + shapes: Sequence[tuple[int, ...]], +) -> tuple[int | None, int | None, int | None, int | None] | None: + if not shapes or len(shapes[0]) < 4: + return None + batch_heads = shapes[0][0] * shapes[0][1] + q_seq = shapes[0][-2] + kv_seq = shapes[1][-2] if len(shapes) > 1 and len(shapes[1]) >= 4 else q_seq + return batch_heads, q_seq, kv_seq, shapes[0][-1] + + +def _aspect_bucket(m: int | None, n: int | None, k: int | None) -> str: + values = [value for value in (m, n, k) if value is not None and value > 0] + if len(values) != 3: + return "unknown" + min_dim = min(values) + max_dim = max(values) + if max_dim / min_dim < 4: + return "balanced" + if m == min_dim: + return "skinny_m" + if n == min_dim: + return "skinny_n" + return "skinny_k" + + +def _rank1_reduction_class( + shapes: Sequence[tuple[int, ...]], + dtypes: Sequence[str], + workload_traits: frozenset[str], +) -> str | None: + if "cross_entropy" in workload_traits: + return "row_cross_entropy" + if ( + len(shapes) >= 2 + and len(dtypes) >= 2 + and dtypes[1] + in { + "torch.int64", + "torch.int32", + } + ): + return "row_cross_entropy" + + cols = shapes[0][-1] if shapes and shapes[0] else None + + def rank1_matches_cols(shape: tuple[int, ...]) -> bool: + return cols is not None and len(shape) == 1 and shape[0] == cols + + if ( + len(shapes) >= 3 + and rank1_matches_cols(shapes[1]) + and rank1_matches_cols(shapes[2]) + ): + return "row_norm_layer" + if len(shapes) >= 2 and rank1_matches_cols(shapes[1]): + return "row_norm_rms" + if "softmax" in workload_traits or ( + "exp" in workload_traits and "sum_reduction" in workload_traits + ): + return "row_softmax" + return None + + +def _looks_dense_elementwise(shapes: Sequence[tuple[int, ...]]) -> bool: + if not shapes: + return False + output_shape = shapes[0] + if len(shapes) == 1: + return True + return all(shape == output_shape for shape in shapes[1:]) + + +def classify_runtime_kernel( + args: Sequence[object], + *, + workload_traits: frozenset[str], + config_spec: ConfigSpec, + kernel: _AutotunableKernel | None = None, +) -> str | None: + """Classify the runtime kernel into the CSV-derived workload taxonomy. + + ``kernel`` is optional and used only for fingerprinting specific quantized + matmul variants (int4 vs fp4, which are otherwise indistinguishable from + arg shapes/dtypes alone). Classification works without it for every class + except ``matmul_int4``/``matmul_fp4``, which will fall back to ``matmul``. + """ + block_rank = _default_block_rank(config_spec) + shapes = _tensor_shapes(args) + if "attention_reduction" in workload_traits and block_rank == 3: + return "attention" + if "matmul" in workload_traits and "split_k" in _flat_fields(config_spec): + return "split_k_matmul" + if "matmul" in workload_traits and block_rank == 4: + if len(shapes) >= 2 and len(shapes[0]) == 3 and len(shapes[1]) == 3: + return "batched_matmul" + return "grouped_matmul" + if "matmul" in workload_traits and block_rank == 3: + if _is_fp8_matmul(args): + return "matmul_fp8" + quantized = _infer_quantized_matmul_class(kernel, args) + if quantized is not None: + return quantized + return "matmul" + # Quantized matmuls (int4/fp4) don't call hl.dot directly — they do a + # manual outer-product + sum-reduction over unpacked weights — so they + # emit {"reduction", "sum_reduction"} instead of {"matmul"}. Classify by + # shape signature + source fingerprint. + if ( + block_rank == 3 + and "sum_reduction" in workload_traits + and "reduction" in workload_traits + ): + quantized = _infer_quantized_matmul_class(kernel, args) + if quantized is not None: + return quantized + if _has_reduction(config_spec) and block_rank == 1: + return _rank1_reduction_class(shapes, _tensor_dtypes(args), workload_traits) + if ( + not _has_reduction(config_spec) + and block_rank == 1 + and _looks_dense_elementwise(shapes) + ): + return "elementwise" + return None + + +def _shape_bucket_for_class( + kernel_class: str, args: Sequence[object] +) -> dict[str, object]: + shapes = _tensor_shapes(args) + dtype_family = _dtype_family(args) + if kernel_class == "attention": + attention_shape = _attention_shape(shapes) + batch_heads, q_seq, kv_seq, head_dim = ( + attention_shape + if attention_shape is not None + else (None, None, None, None) + ) + bucket = { + "batch_heads_bin": _bin_le(batch_heads, [32, 64, 128, 256]), + "dtype": dtype_family, + "head_dim_bin": _bin_le(head_dim, [64, 128, 256]), + "seq_bin": _bin_le(q_seq, [1, 16, 1024, 2048, 4096, 8192, 16384]), + } + if q_seq != kv_seq: + bucket["kv_seq_bin"] = _bin_le(kv_seq, [1024, 2048, 4096, 8192, 16384]) + return bucket + if kernel_class in { + "matmul", + "matmul_fp8", + "grouped_matmul", + "matmul_int4", + "matmul_int16", + "matmul_fp4", + }: + matmul_shape = _matmul_shape(shapes) + m, n, k = matmul_shape if matmul_shape is not None else (None, None, None) + return { + "aspect": _aspect_bucket(m, n, k), + "dtype": dtype_family, + "k_bin": _bin_le(k, [64, 128, 256, 512, 1024, 4096, 32768]), + "m_bin": _bin_le(m, [4, 8, 16, 64, 128, 256, 512, 1024, 4096]), + "n_bin": _bin_le(n, [64, 128, 256, 512, 1024, 4096]), + } + if kernel_class.startswith("row_"): + rows, cols = _row_shape(shapes) + return { + "cols_bin": _bin_le(cols, [512, 1024, 2048, 4096, 8192, 16384, 32768]), + "dtype": dtype_family, + "rows_bin": _bin_le(rows, [512, 2048, 4096, 16384, 65536, 262144]), + } + if kernel_class == "elementwise": + numel = _numel(shapes[0]) if shapes else None + return { + "dtype": dtype_family, + "numel_bin": _bin_le(numel, [4096, 65536, 1048576, 16777216, 134217728]), + } + return {"dtype": dtype_family} + + +def _fallback_group_for_class( + kernel_class: str, args: Sequence[object] +) -> str | None: + """Coarse shape-group label for fallback lookup. + + Used only when exact-bucket rule lookup misses. The grouping is + deliberately much coarser than ``_shape_bucket_for_class`` — + enough partitions to capture the dominant config-shape + correlation (skinny axes, balanced-vs-rect), not fine enough + to need many archive shapes per group. + + Returns None for kernel classes without a defined grouping. The + lookup will simply skip fallbacks in that case (safe default). + """ + shapes = _tensor_shapes(args) + if kernel_class in { + "matmul", + "matmul_fp8", + "grouped_matmul", + "matmul_int4", + "matmul_int16", + "matmul_fp4", + }: + matmul_shape = _matmul_shape(shapes) + if matmul_shape is None: + return None + m, n, k = matmul_shape + if m is None or n is None or k is None: + return None + if m <= 256: + return "small_m" + if n <= 256: + return "small_n" + if k <= 256: + return "small_k" + dims = [m, n, k] + if max(dims) / max(1, min(dims)) < 2: + return "balanced" + return "rect" + if kernel_class.startswith("row_"): + rows, cols = _row_shape(shapes) + if rows is None or cols is None: + return None + if rows <= 512: + return "short" + if cols <= 1024: + return "narrow" + if cols >= 8192: + return "wide" + return "square" + if kernel_class == "elementwise": + numel = _numel(shapes[0]) if shapes else None + if numel is None: + return None + if numel <= 65536: + return "tiny" + if numel <= 1048576: + return "mid" + return "huge" + if kernel_class == "attention": + attention_shape = _attention_shape(shapes) + if attention_shape is None: + return None + batch_heads, q_seq, _kv_seq, head_dim = attention_shape + if q_seq is None or head_dim is None: + return None + if q_seq <= 1024: + return "short_seq" + if q_seq >= 8192: + return "long_seq" + if head_dim is not None and head_dim <= 64: + return "small_head" + return "mid_seq" + return None + + +@functools.cache +def _rules_by_key() -> dict[str, dict[str, object]]: + data = _runtime_heuristics() + raw_rules = data.get("rules", []) + if not isinstance(raw_rules, list): + return {} + rules: dict[str, dict[str, object]] = {} + for raw_rule in raw_rules: + if not isinstance(raw_rule, dict): + continue + kernel_class = raw_rule.get("kernel_class") + shape_bucket = raw_rule.get("shape_bucket") + if not isinstance(kernel_class, str) or not isinstance(shape_bucket, dict): + continue + rules[f"{kernel_class}:{_stable_json(shape_bucket)}"] = raw_rule + return rules + + +def _find_rule( + kernel_class: str, + shape_bucket: dict[str, object], +) -> dict[str, object] | None: + if kernel_class in _disabled_kernel_classes(): + return None + return _rules_by_key().get(f"{kernel_class}:{_stable_json(shape_bucket)}") + + +@functools.cache +def _fallbacks_by_class() -> dict[str, dict[str, dict[str, object]]]: + """Return the ``fallbacks`` map from the loaded JSON, or empty.""" + data = _runtime_heuristics() + raw = data.get("fallbacks", {}) + if not isinstance(raw, dict): + return {} + out: dict[str, dict[str, dict[str, object]]] = {} + for kernel_class, group_map in raw.items(): + if not isinstance(kernel_class, str) or not isinstance(group_map, dict): + continue + clean_groups: dict[str, dict[str, object]] = {} + for group, entry in group_map.items(): + if isinstance(group, str) and isinstance(entry, dict): + clean_groups[group] = entry + if clean_groups: + out[kernel_class] = clean_groups + return out + + +def _find_fallback( + kernel_class: str, + group: str | None, +) -> dict[str, object] | None: + """Look up the fallback entry for a (kernel_class, group) pair. + + Returns the raw entry (same shape as a rule's ``templates[i]``) + or None if no fallback is defined. + """ + if group is None: + return None + if kernel_class in _disabled_kernel_classes(): + return None + return _fallbacks_by_class().get(kernel_class, {}).get(group) + + +def _matched_rule( + args: Sequence[object], + *, + workload_traits: frozenset[str], + config_spec: ConfigSpec, + kernel: _AutotunableKernel | None = None, +) -> tuple[str | None, dict[str, object], dict[str, object] | None]: + kernel_class = classify_runtime_kernel( + args, + workload_traits=workload_traits, + config_spec=config_spec, + kernel=kernel, + ) + if kernel_class is None: + return None, {}, None + shape_bucket = _shape_bucket_for_class(kernel_class, args) + return kernel_class, shape_bucket, _find_rule(kernel_class, shape_bucket) + + +def _supported_sparse_config( + raw: dict[str, object], + *, + config_spec: ConfigSpec, +) -> dict[str, object]: + flat_fields = _flat_fields(config_spec) + supported = {key: value for key, value in raw.items() if key in flat_fields} + if ( + "pid_type" in supported + and supported["pid_type"] not in config_spec.allowed_pid_types + ): + supported.pop("pid_type") + validate_sparse_config_shape(supported, config_spec=config_spec) + return supported + + +def _materialize_config( + raw: dict[str, object], + *, + config_spec: ConfigSpec, +) -> Config: + import helion + + supported = _supported_sparse_config(raw, config_spec=config_spec) + merged = dict(config_spec.default_config()) + merged.update(supported) + config_spec.normalize(merged, _fix_invalid=True) + return helion.Config(**merged) + + +def observed_heuristic_seed_configs( + args: Sequence[object], + *, + workload_traits: frozenset[str], + config_spec: ConfigSpec, + max_configs: int, + kernel: _AutotunableKernel | None = None, +) -> list[Config]: + """Return valid CSV-derived seed configs for this config space.""" + if max_configs <= 0 or not observed_heuristic_seeds_enabled(): + return [] + + kernel_class, _shape_bucket, rule = _matched_rule( + args, + workload_traits=workload_traits, + config_spec=config_spec, + kernel=kernel, + ) + templates: list[dict[str, object]] + if rule is not None: + templates = _selected_templates(rule) + elif kernel_class is not None: + # Exact-bucket lookup missed — try the per-kernel-class fallback. + group = _fallback_group_for_class(kernel_class, args) + fallback_entry = _find_fallback(kernel_class, group) + if fallback_entry is None: + return [] + templates = [fallback_entry] + else: + return [] + + seeds: list[Config] = [] + seen: set[str] = set() + for raw_template in templates: + if not isinstance(raw_template, dict): + continue + raw = raw_template.get("template") + if not isinstance(raw, dict): + continue + config = _materialize_config(raw, config_spec=config_spec) + key = repr(config) + if key in seen: + continue + seen.add(key) + seeds.append(config) + if len(seeds) >= max_configs: + break + return seeds + + +def observed_heuristic_seed_configs_for_kernel( + kernel: _AutotunableKernel | None, + args: Sequence[object], + *, + config_spec: ConfigSpec, + max_configs: int, +) -> list[Config]: + """Return observed seed configs after inferring workload traits from a kernel.""" + if max_configs <= 0 or not observed_heuristic_seeds_enabled(): + return [] + + return observed_heuristic_seed_configs( + args, + workload_traits=detect_workload_traits(kernel, config_spec=config_spec), + config_spec=config_spec, + max_configs=max_configs, + kernel=kernel, + ) + + +def observed_heuristic_default_config( + kernel: _AutotunableKernel | None, + args: Sequence[object], + *, + config_spec: ConfigSpec, +) -> Config | None: + """Return the first observed seed config for no-autotune/default execution.""" + configs = observed_heuristic_seed_configs_for_kernel( + kernel, + args, + config_spec=config_spec, + max_configs=1, + ) + return configs[0] if configs else None diff --git a/helion/autotuner/random_search.py b/helion/autotuner/random_search.py index 942debf51d..cc3c19eb56 100644 --- a/helion/autotuner/random_search.py +++ b/helion/autotuner/random_search.py @@ -4,11 +4,13 @@ from .effort_profile import RANDOM_SEARCH_DEFAULTS from .finite_search import FiniteSearch +from .observed_heuristics import observed_heuristic_seed_configs_for_kernel if TYPE_CHECKING: from collections.abc import Sequence from ..autotuner.effort_profile import AutotuneEffortProfile + from ..runtime.config import Config from .base_search import _AutotunableKernel from helion.runtime.settings import Settings @@ -35,14 +37,42 @@ def __init__( args: Sequence[object], count: int = RANDOM_SEARCH_DEFAULTS.count, ) -> None: + config_gen = kernel.config_spec.create_config_generation( + overrides=kernel.settings.autotune_config_overrides or None, + advanced_controls_files=kernel.settings.autotune_search_acf or None, + process_group_name=kernel.env.process_group_name, + ) + random_configs = config_gen.random_population(count) + seed_configs = observed_heuristic_seed_configs_for_kernel( + kernel, + args, + config_spec=kernel.config_spec, + max_configs=count, + ) + configs: list[Config] = [] + seen: set[Config] = set() + leading_configs = seed_configs or random_configs[:1] + # Keep the requested population size stable: observed seeds replace the + # default slot first, and random configs fill any remaining slots. + for config in [*leading_configs, *random_configs[1:]]: + if config in seen: + continue + seen.add(config) + configs.append(config) + if len(configs) >= count: + break + attempts = 0 + while len(configs) < count and attempts < 64: + attempts += 1 + config = config_gen.unflatten(config_gen.random_flat()) + if config in seen: + continue + seen.add(config) + configs.append(config) super().__init__( kernel, args, - configs=kernel.config_spec.create_config_generation( - overrides=kernel.settings.autotune_config_overrides or None, - advanced_controls_files=kernel.settings.autotune_search_acf or None, - process_group_name=kernel.env.process_group_name, - ).random_population(count), + configs=configs, ) @classmethod diff --git a/helion/autotuner/workload.py b/helion/autotuner/workload.py new file mode 100644 index 0000000000..a037a0e125 --- /dev/null +++ b/helion/autotuner/workload.py @@ -0,0 +1,134 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +"""Infer coarse workload traits for autotune policy selection.""" + +from __future__ import annotations + +import inspect +import textwrap +from typing import TYPE_CHECKING + +import torch + +if TYPE_CHECKING: + from collections.abc import Iterator + + from .base_search import _AutotunableKernel + from .config_spec import ConfigSpec + + +MATMUL_TARGETS = frozenset( + { + torch.matmul, + torch.ops.aten.mm.default, + torch.ops.aten.addmm.default, + torch.ops.aten.bmm.default, + torch.ops.aten.baddbmm.default, + } +) +MATMUL_API_NAMES = frozenset({"dot", "dot_scaled"}) +BATCH_MATMUL_TARGET_NAMES = frozenset({"bmm", "baddbmm"}) +REDUCTION_TARGET_NAMES = frozenset({"amax", "sum", "softmax", "logsumexp"}) +EXP_TARGET_NAMES = frozenset({"exp", "exp2"}) +SOFTMAX_TARGET_NAMES = frozenset({"softmax", "_softmax"}) +LOGSUMEXP_TARGET_NAMES = frozenset({"logsumexp"}) +SUM_TARGET_NAMES = frozenset({"sum"}) + + +def kernel_source_text(kernel: _AutotunableKernel) -> str: + """Extract the underlying kernel source when it is available.""" + try: + inner_kernel = getattr(kernel, "kernel", None) + if inner_kernel is None or not hasattr(inner_kernel, "fn"): + return "# Source unavailable" + raw_source = inspect.getsource(inner_kernel.fn) + except (OSError, TypeError): + return "# Source unavailable" + + source_lines = textwrap.dedent(raw_source).splitlines() + start_idx = 0 + while start_idx < len(source_lines) and not source_lines[ + start_idx + ].lstrip().startswith("def "): + start_idx += 1 + return "\n".join(source_lines[start_idx:]) + + +def _target_name_parts(target: object) -> frozenset[str]: + """Extract coarse name tokens for a traced call target.""" + parts: set[str] = set() + for raw in ( + getattr(target, "__name__", None), + getattr(target, "name", None), + str(target), + ): + if not isinstance(raw, str): + continue + parts.add(raw) + parts.update(piece for piece in raw.split(".") if piece) + return frozenset(parts) + + +def _iter_call_targets(kernel: _AutotunableKernel) -> Iterator[object]: + """Yield traced call targets from compiler-generated FX graphs.""" + host_function = getattr(kernel, "host_function", None) + device_ir = getattr(host_function, "device_ir", None) + for graph_info in getattr(device_ir, "graphs", ()): + graph = getattr(graph_info, "graph", None) + if not isinstance(graph, torch.fx.Graph): + continue + for node in graph.nodes: + if node.op == "call_function": + yield node.target + + +def detect_workload_traits( + kernel: _AutotunableKernel | None, + *, + config_spec: ConfigSpec | None = None, +) -> frozenset[str]: + """Infer coarse workload traits from compiler-traced graphs.""" + if kernel is None: + return frozenset() + + saw_matmul = False + saw_batched_matmul = False + saw_reduction = bool(config_spec is not None and config_spec.reduction_loops) + saw_exp = False + saw_softmax = False + saw_logsumexp = False + saw_sum = False + + for target in _iter_call_targets(kernel): + name_parts = _target_name_parts(target) + if target in MATMUL_TARGETS or name_parts & MATMUL_API_NAMES: + saw_matmul = True + if name_parts & BATCH_MATMUL_TARGET_NAMES: + saw_batched_matmul = True + if name_parts & REDUCTION_TARGET_NAMES: + saw_reduction = True + if name_parts & EXP_TARGET_NAMES: + saw_exp = True + if name_parts & SOFTMAX_TARGET_NAMES: + saw_softmax = True + if name_parts & LOGSUMEXP_TARGET_NAMES: + saw_logsumexp = True + if name_parts & SUM_TARGET_NAMES: + saw_sum = True + + traits: set[str] = set() + if saw_matmul: + traits.add("matmul") + if saw_reduction: + traits.add("reduction") + if saw_exp: + traits.add("exp") + if saw_softmax: + traits.add("softmax") + if saw_logsumexp: + traits.add("cross_entropy") + if saw_sum: + traits.add("sum_reduction") + if saw_matmul and saw_reduction and (saw_batched_matmul or saw_exp): + traits.add("attention_reduction") + return frozenset(traits) diff --git a/helion/runtime/kernel.py b/helion/runtime/kernel.py index 630403df4d..ceaa66078e 100644 --- a/helion/runtime/kernel.py +++ b/helion/runtime/kernel.py @@ -51,6 +51,7 @@ from .._logging import LazyString from .._utils import counters from ..autotuner.base_search import _AutotunableKernel +from ..autotuner.observed_heuristics import observed_heuristic_default_config from ..language.constexpr import ConstExpr from .config import Config from .ref_mode import RefModeContext @@ -873,7 +874,22 @@ def stride_extractor( extractors.append(make_extractor(source)) return extractors - def _user_provided_config(self) -> Config | None: + def _observed_heuristic_config( + self, args: Sequence[object] | None = None + ) -> Config | None: + # Some codegen paths ask for an implicit config before real runtime args + # are available; fake_args carry the same specialized shape/dtype data. + runtime_args = args if args is not None else self.fake_args + + return observed_heuristic_default_config( + self, + runtime_args, + config_spec=self.config_spec, + ) + + def _user_provided_config( + self, args: Sequence[object] | None = None + ) -> Config | None: """Return a config if the user explicitly provided one, else None. Checks the kernel's config list and settings to determine if @@ -885,29 +901,35 @@ def _user_provided_config(self) -> Config | None: if len(configs) == 1: return configs[0] if len(configs) == 0 and self.kernel.settings.autotune_effort == "none": - config = self.config_spec.default_config() + config_source = "observed heuristic" + config = self._observed_heuristic_config(args) + if config is None: + config_source = "default" + config = self.config_spec.default_config() if not is_ref_mode_enabled(self.kernel.settings): kernel_decorator = self.format_kernel_decorator(config, self.settings) print( - f"Using default config: {kernel_decorator}", + f"Using {config_source} config: {kernel_decorator}", file=sys.stderr, ) return config return None - def _implicit_config(self) -> Config | None: + def _implicit_config(self, args: Sequence[object] | None = None) -> Config | None: """ Returns a single config that is implicitly used by this kernel, if any. """ if self._config is not None: return self._config - return self._user_provided_config() + return self._user_provided_config(args) - def _require_implicit_config(self) -> Config: + def _require_implicit_config( + self, args: Sequence[object] | None = None + ) -> Config: """ Returns the implicit config for this kernel, or raises an error if no implicit config is available. """ - if (config := self._implicit_config()) is None: + if (config := self._implicit_config(args)) is None: raise RuntimeError("no config provided and no implicit config available") return config @@ -920,7 +942,7 @@ def ensure_config_exists(self, args: Sequence[object]) -> None: """ if self._config is not None: return # Already have a config - if (config := self._implicit_config()) is not None: + if (config := self._implicit_config(args)) is not None: with measure("BoundKernel.set_config"): self.set_config(config) else: @@ -954,7 +976,7 @@ def __call__(self, *args: object) -> _R: """ if self._run is None: if is_ref_mode_enabled(self.kernel.settings): - if (config := self._implicit_config()) is not None: + if (config := self._implicit_config(args)) is not None: self._config = config return self.run_ref(*args) self.ensure_config_exists(args) diff --git a/pyproject.toml b/pyproject.toml index df52b6ec3e..d094287e50 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -120,6 +120,7 @@ packages = ["helion"] include = [ "helion/**/*.py", "helion/**/*.pyi", + "helion/**/*.json", "LICENSE", ] exclude = [ diff --git a/test/test_best_available.py b/test/test_best_available.py index 0261d3c537..4c5e741f6b 100644 --- a/test/test_best_available.py +++ b/test/test_best_available.py @@ -987,9 +987,13 @@ def _make_mock_search(self, config_gen, cached_configs): ) mock_search = MagicMock() mock_search.config_gen = config_gen + mock_search.config_spec = config_gen.config_spec + mock_search.kernel = MagicMock() + mock_search.args = () mock_search.settings = Settings() mock_search.log = MagicMock() mock_search.log.debug = MagicMock() + mock_search._best_available_seed_configs = [] mock_search._find_similar_cached_configs = MagicMock(return_value=entries) return mock_search @@ -1025,6 +1029,26 @@ def test_cached_configs_added(self): self.assertEqual(result[1][num_warps_idx], 8) self.assertEqual(result[2][num_warps_idx], 2) + def test_observed_heuristic_configs_replace_default_slot_before_cache(self): + """Observed heuristic seeds replace default and run before cache.""" + config_gen = self._make_config_gen() + observed = [Config(block_sizes=[32, 64], num_warps=8, num_stages=2)] + cached = [Config(block_sizes=[128, 256], num_warps=2, num_stages=4)] + mock_search = self._make_mock_search(config_gen, cached) + + with patch( + "helion.autotuner.base_search.observed_heuristic_seed_configs_for_kernel", + return_value=observed, + ): + result = PopulationBasedSearch._generate_best_available_population_flat( + mock_search + ) + + self.assertEqual(len(result), 2) + num_warps_idx = config_gen._key_to_flat_indices["num_warps"][0][0] + self.assertEqual(result[0][num_warps_idx], 8) + self.assertEqual(result[1][num_warps_idx], 2) + def test_duplicate_configs_deduplicated(self): """Duplicate cached configs are discarded.""" config_gen = self._make_config_gen() diff --git a/test/test_observed_heuristics.py b/test/test_observed_heuristics.py new file mode 100644 index 0000000000..c75d289030 --- /dev/null +++ b/test/test_observed_heuristics.py @@ -0,0 +1,177 @@ +from __future__ import annotations + +import os +from types import SimpleNamespace +from unittest.mock import patch + +import torch + +import helion +from helion.autotuner.config_fragment import EnumFragment +from helion.autotuner.config_fragment import IntegerFragment +from helion.autotuner.config_fragment import ListOf +from helion.autotuner.config_fragment import PowerOfTwoFragment +from helion.autotuner.observed_heuristics import classify_runtime_kernel +from helion.autotuner.observed_heuristics import observed_heuristic_seed_configs +from helion.runtime.settings import Settings + + +def _row_config_spec() -> SimpleNamespace: + return SimpleNamespace( + reduction_loops=[object()], + allowed_pid_types=("flat",), + default_config=lambda: helion.Config( + block_sizes=[1], + num_warps=4, + num_stages=1, + pid_type="flat", + ), + _flat_fields=lambda: { + "block_sizes": ListOf(IntegerFragment(1, 4096, 1), length=1), + "num_warps": PowerOfTwoFragment(1, 32, 4), + "num_stages": IntegerFragment(1, 8, 1), + "pid_type": EnumFragment(("flat",)), + }, + normalize=lambda raw, _fix_invalid=False: None, + ) + + +def test_observed_heuristic_seeds_are_default_on_and_disableable() -> None: + config_spec = _row_config_spec() + x = torch.empty((2048, 8192), dtype=torch.bfloat16) + + with patch.dict(os.environ, {}, clear=True): + assert observed_heuristic_seed_configs( + (x,), + workload_traits=frozenset({"reduction", "softmax"}), + config_spec=config_spec, + max_configs=3, + ) + + with patch.dict( + os.environ, {"HELION_AUTOTUNE_OBSERVED_HEURISTICS": "0"}, clear=True + ): + assert ( + observed_heuristic_seed_configs( + (x,), + workload_traits=frozenset({"reduction", "softmax"}), + config_spec=config_spec, + max_configs=3, + ) + == [] + ) + + +def test_observed_heuristics_generate_valid_matmul_seeds() -> None: + config_spec = SimpleNamespace( + reduction_loops=[], + allowed_pid_types=("flat",), + default_config=lambda: helion.Config( + block_sizes=[1, 1, 1], + l2_groupings=[1], + num_warps=4, + num_stages=1, + pid_type="flat", + ), + _flat_fields=lambda: { + "block_sizes": ListOf(IntegerFragment(1, 4096, 1), length=3), + "l2_groupings": ListOf(IntegerFragment(1, 64, 1), length=1), + "num_warps": PowerOfTwoFragment(1, 32, 4), + "num_stages": IntegerFragment(1, 8, 1), + "pid_type": EnumFragment(("flat",)), + }, + normalize=lambda raw, _fix_invalid=False: None, + ) + + x = torch.empty((4, 16384), dtype=torch.bfloat16) + y = torch.empty((16384, 8192), dtype=torch.bfloat16) + seeds = observed_heuristic_seed_configs( + (x, y), + workload_traits=frozenset({"matmul"}), + config_spec=config_spec, + max_configs=3, + ) + + assert len(seeds) == 1 + seed = dict(seeds[0]) + assert len(seed["block_sizes"]) == 3 + assert all(isinstance(block, int) and block > 0 for block in seed["block_sizes"]) + assert seed["num_warps"] in {1, 2, 4, 8, 16, 32} + assert 1 <= seed["num_stages"] <= 8 + + +def test_autotune_effort_none_uses_observed_heuristic_config() -> None: + from helion.runtime.kernel import BoundKernel + + observed_config = helion.Config(num_warps=8) + settings = Settings(autotune_effort="none") + config_spec = SimpleNamespace( + default_config=lambda: helion.Config(num_warps=4), + ) + bound = BoundKernel.__new__(BoundKernel) + bound.kernel = SimpleNamespace( + configs=[], + settings=settings, + ) + bound._env = SimpleNamespace(config_spec=config_spec) + bound.fake_args = [] + assert bound.env is bound._env + + with ( + patch( + "helion.runtime.kernel.observed_heuristic_default_config", + return_value=observed_config, + ), + patch("helion.runtime.kernel.is_ref_mode_enabled", return_value=True), + ): + config = BoundKernel._user_provided_config(bound, ()) + + assert config is observed_config + + +def test_row_reduction_classification_uses_runtime_structure() -> None: + config_spec = SimpleNamespace( + reduction_loops=[object()], + default_config=lambda: helion.Config(block_sizes=[1]), + _flat_fields=lambda: { + "block_sizes": ListOf(IntegerFragment(1, 4096, 1), length=1), + }, + ) + + x = torch.empty((2048, 8192), dtype=torch.bfloat16) + weight = torch.empty((8192,), dtype=torch.bfloat16) + bias = torch.empty((8192,), dtype=torch.bfloat16) + labels = torch.empty((2048,), dtype=torch.int64) + + assert ( + classify_runtime_kernel( + (x, weight, bias), + workload_traits=frozenset({"reduction"}), + config_spec=config_spec, + ) + == "row_norm_layer" + ) + assert ( + classify_runtime_kernel( + (x, weight), + workload_traits=frozenset({"reduction"}), + config_spec=config_spec, + ) + == "row_norm_rms" + ) + assert ( + classify_runtime_kernel( + (x, labels), + workload_traits=frozenset({"reduction", "exp", "sum_reduction"}), + config_spec=config_spec, + ) + == "row_cross_entropy" + ) + assert ( + classify_runtime_kernel( + (x,), + workload_traits=frozenset({"reduction", "exp", "sum_reduction"}), + config_spec=config_spec, + ) + == "row_softmax" + )