Commit 3f0ced4

make it easy to use
1 parent 92da26e commit 3f0ced4

4 files changed: 259 additions & 40 deletions

4 files changed

+259
-40
lines changed

helion/autotuner/aot_cache.py

Lines changed: 196 additions & 37 deletions
@@ -54,7 +54,7 @@

 def get_aot_mode() -> AOTMode:
     """Get the current AOT mode from environment."""
-    mode = os.environ.get(AOT_MODE_ENV, "disabled").lower()
+    mode = os.environ.get(AOT_MODE_ENV, "evaluate").lower()
     if mode in ("collect", "measure", "evaluate", "disabled"):
         return mode  # type: ignore[return-value]
     raise ValueError(
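
With the default flipped from "disabled" to "evaluate", kernels pick up AOT heuristics without any environment setup; a specific mode can still be forced through the variable named by AOT_MODE_ENV. A minimal usage sketch, assuming these names are importable from helion.autotuner.aot_cache:

import os

from helion.autotuner.aot_cache import AOT_MODE_ENV, get_aot_mode

print(get_aot_mode())                 # "evaluate" when the variable is unset (new default)
os.environ[AOT_MODE_ENV] = "collect"
print(get_aot_mode())                 # "collect"
os.environ[AOT_MODE_ENV] = "bogus"    # any other value makes get_aot_mode() raise ValueError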
@@ -101,6 +101,65 @@ def get_device_compute_id() -> tuple[str, str]:
     return ("cpu", platform.machine())


+# Known compute capabilities in descending order (newest first)
+# This allows fallback to older architectures when heuristics aren't available
+_CUDA_COMPUTE_CAPS: list[str] = [
+    "sm100",
+    "sm90",
+    "sm89",
+    "sm87",
+    "sm86",
+    "sm80",
+    "sm75",
+    "sm72",
+    "sm70",
+]
+
+_ROCM_ARCHS: list[str] = [
+    "gfx950",
+    "gfx942",
+    "gfx941",
+    "gfx940",
+    "gfx90a",
+    "gfx908",
+    "gfx906",
+    "gfx900",
+]
+
+
+def get_compatible_compute_ids(device_kind: str, compute_kind: str) -> list[str]:
+    """
+    Get a list of compatible compute IDs for fallback, ordered from current to oldest.
+
+    For CUDA/ROCm, returns the current compute capability followed by all older
+    compatible architectures. This allows using heuristics tuned on older hardware
+    when newer hardware-specific heuristics aren't available.
+
+    Args:
+        device_kind: 'cuda', 'rocm', or 'cpu'
+        compute_kind: The current compute capability (e.g., 'sm90', 'gfx942')
+
+    Returns:
+        List of compute IDs to try, starting with the exact match
+    """
+    if device_kind == "cuda":
+        arch_list = _CUDA_COMPUTE_CAPS
+    elif device_kind == "rocm":
+        arch_list = _ROCM_ARCHS
+    else:
+        # CPU or unknown - no fallback
+        return [compute_kind]
+
+    # Find current architecture in the list
+    try:
+        current_idx = arch_list.index(compute_kind)
+        # Return current and all older architectures
+        return arch_list[current_idx:]
+    except ValueError:
+        # Unknown architecture - try it alone, then try all known ones
+        return [compute_kind] + arch_list
+
+
 def get_heuristic_path_for_kernel(kernel_source_file: str | Path) -> Path:
     """
     Get the path where heuristics should be stored for a kernel.
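
To make the fallback order concrete, here is a standalone sketch of the same lookup (tables and logic copied from the hunk above; the example calls and printed results are illustrative only):

_CUDA_COMPUTE_CAPS = ["sm100", "sm90", "sm89", "sm87", "sm86", "sm80", "sm75", "sm72", "sm70"]
_ROCM_ARCHS = ["gfx950", "gfx942", "gfx941", "gfx940", "gfx90a", "gfx908", "gfx906", "gfx900"]


def get_compatible_compute_ids(device_kind: str, compute_kind: str) -> list[str]:
    if device_kind == "cuda":
        arch_list = _CUDA_COMPUTE_CAPS
    elif device_kind == "rocm":
        arch_list = _ROCM_ARCHS
    else:
        return [compute_kind]  # CPU or unknown device kind: no fallback
    try:
        # Current architecture plus everything older than it
        return arch_list[arch_list.index(compute_kind):]
    except ValueError:
        # Architecture not in the table: try it first, then every known one
        return [compute_kind] + arch_list


print(get_compatible_compute_ids("cuda", "sm89"))    # ['sm89', 'sm87', 'sm86', 'sm80', 'sm75', 'sm72', 'sm70']
print(get_compatible_compute_ids("rocm", "gfx90a"))  # ['gfx90a', 'gfx908', 'gfx906', 'gfx900']
print(get_compatible_compute_ids("cpu", "x86_64"))   # ['x86_64']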
@@ -481,6 +540,21 @@ class AOTAutotuneCache(AutotuneCacheBase):

     _mode_announced: set[str] = set()  # Class-level to avoid repeated messages

+    # Class-level caches for heuristic lookup (shared across instances)
+    # Maps heuristic file path -> loaded module
+    _heuristic_modules: dict[Path, Any] = {}
+    # Maps (kernel_source_file, kernel_name, shape_features_hash) -> Config
+    # Using source file ensures kernels with same name in different modules don't collide
+    _heuristic_results: dict[tuple[str, str, str], Config] = {}
+
+    @classmethod
+    def clear_caches(cls) -> None:
+        """Clear all class-level caches (heuristic modules and results)."""
+        cls._heuristic_modules.clear()
+        cls._heuristic_results.clear()
+        cls._mode_announced.clear()
+        log.debug("Cleared AOTAutotuneCache caches")
+
     def __init__(self, autotuner: BaseSearch) -> None:
         super().__init__(autotuner)
         self.mode = get_aot_mode()
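
The heuristic module and result caches are class-level, so they persist across AOTAutotuneCache instances. A test or benchmark harness that rewrites heuristic files on disk can reset them explicitly; assumed usage, given this file's import path:

from helion.autotuner.aot_cache import AOTAutotuneCache

# Drop cached heuristic modules/results (and the mode-announcement set) so the
# next lookup re-reads heuristic files from disk.
AOTAutotuneCache.clear_caches()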
@@ -516,11 +590,16 @@ def _create_shape_key(self) -> ShapeKey:
             hardware_id=self.hardware_id,
         )

-    def _extract_shape_features(self) -> dict[str, Any]:
+    def _extract_shape_features(
+        self, args: Sequence[object] | None = None
+    ) -> dict[str, Any]:
         """Extract numeric features from the shape for ML model."""
+        if args is None:
+            args = self.args
+
         features: dict[str, Any] = {}

-        for i, arg in enumerate(self.args):
+        for i, arg in enumerate(args):
             if isinstance(arg, torch.Tensor):
                 features[f"arg{i}_ndim"] = arg.ndim
                 for j, size in enumerate(arg.shape):
@@ -709,85 +788,165 @@ def measure_all_configs(self) -> list[tuple[Config, float]]:
             )
         return results

-    def _get_heuristic_config(self) -> Config | None:
+    def _find_heuristic_file(self) -> Path | None:
         """
-        Use the heuristic to select a config.
+        Find the heuristic file for this kernel.

-        Search order for heuristic files:
+        Search order:
         1. HELION_HEURISTIC_DIR env var (if set) - for comparing different heuristics
         2. Next to kernel source file: _<filename>_<device>_<compute>.py
-        3. AOT data directory: heuristic_<kernel_name>.py (fallback)
+        3. Fallback to older compute capabilities within the same device family
+        4. AOT data directory: heuristic_<kernel_name>.py (fallback)
         """
         kernel_name = self.kernel.kernel.name

         # Get the kernel source file path
         kernel_source_file = self.kernel.kernel.__code__.co_filename
+        source_path = Path(kernel_source_file)
+        base_name = source_path.stem
+
+        # Get device info and compatible compute capabilities
+        device_kind, compute_kind = get_device_compute_id()
+        compatible_computes = get_compatible_compute_ids(device_kind, compute_kind)

         # Build list of candidate heuristic files in priority order
         candidates: list[Path] = []

         # 1. Check HELION_HEURISTIC_DIR override (for comparing heuristics)
         if (heuristic_dir := os.environ.get(HEURISTIC_DIR_ENV)) is not None:
             heuristic_dir_path = Path(heuristic_dir)
-            # Look for the standard naming convention in override dir
-            device_kind, compute_kind = get_device_compute_id()
-            base_name = Path(kernel_source_file).stem
-            candidates.append(
-                heuristic_dir_path / f"_{base_name}_{device_kind}_{compute_kind}.py"
-            )
+            # Try each compatible compute capability in order
+            for compat_compute in compatible_computes:
+                candidates.append(
+                    heuristic_dir_path
+                    / f"_{base_name}_{device_kind}_{compat_compute}.py"
+                )
             # Also check kernel-specific file in override dir
             candidates.append(heuristic_dir_path / f"heuristic_{kernel_name}.py")

-        # 2. Check next to kernel source file: _<filename>_<device>_<compute>.py
-        candidates.append(get_heuristic_path_for_kernel(kernel_source_file))
+        # 2. Check next to kernel source file with compute capability fallback
+        for compat_compute in compatible_computes:
+            heuristic_name = f"_{base_name}_{device_kind}_{compat_compute}.py"
+            candidates.append(source_path.parent / heuristic_name)

         # 3. Check AOT data directory (fallback for backward compatibility)
         candidates.append(self.data_store.data_dir / f"heuristic_{kernel_name}.py")
         candidates.append(self.data_store.heuristic_file)

         # Find first existing heuristic file
-        heuristic_file = None
         for candidate in candidates:
             if candidate.exists():
-                heuristic_file = candidate
-                log.debug(f"Found heuristic file: {heuristic_file}")
-                break
+                log.debug(f"Found heuristic file: {candidate}")
+                return candidate
+
+        log.debug(
+            f"Heuristic file not found for {kernel_name}. Searched: "
+            f"{[str(c) for c in candidates[:3]]}..."
+        )
+        return None
+
+    def _get_heuristic_config(
+        self, args: Sequence[object] | None = None
+    ) -> Config | None:
+        """
+        Use the heuristic to select a config.
+
+        Args:
+            args: Optional arguments to use for shape feature extraction.
+                If None, uses self.args.

+        For CUDA/ROCm, if heuristics for the current compute capability aren't found,
+        we try older compatible architectures (e.g., sm80 heuristics on sm90 hardware).
+        """
+        heuristic_file = self._find_heuristic_file()
         if heuristic_file is None:
-            # Only warn in evaluate mode, not during normal operation
-            if self.mode == "evaluate":
-                log.warning(
-                    f"Heuristic file not found for {kernel_name}. Searched: "
-                    f"{[str(c) for c in candidates[:2]]}..."
-                )
             return None

+        kernel_name = self.kernel.kernel.name
+        kernel_source_file = self.kernel.kernel.__code__.co_filename
+
+        # Extract shape features and compute hash for caching
+        shape_features = self._extract_shape_features(args)
+        shape_hash = hashlib.sha256(
+            json.dumps(shape_features, sort_keys=True).encode()
+        ).hexdigest()[:16]
+
+        # Check if we already have a cached result for this kernel+shape
+        # Include source file in key to avoid collisions between kernels with same name
+        cache_key = (kernel_source_file, kernel_name, shape_hash)
+        if cache_key in AOTAutotuneCache._heuristic_results:
+            log.debug(
+                f"Using cached heuristic result for {kernel_name} shape={shape_hash}"
+            )
+            return AOTAutotuneCache._heuristic_results[cache_key]
+
         try:
-            # Import the heuristic module
-            import importlib.util
+            # Load heuristic module from cache or import fresh
+            if heuristic_file in AOTAutotuneCache._heuristic_modules:
+                module = AOTAutotuneCache._heuristic_modules[heuristic_file]
+            else:
+                import importlib.util

-            spec = importlib.util.spec_from_file_location("heuristic", heuristic_file)
-            if spec is None or spec.loader is None:
-                return None
-            module = importlib.util.module_from_spec(spec)
-            spec.loader.exec_module(module)
+                spec = importlib.util.spec_from_file_location(
+                    "heuristic", heuristic_file
+                )
+                if spec is None or spec.loader is None:
+                    return None
+                module = importlib.util.module_from_spec(spec)
+                spec.loader.exec_module(module)
+                AOTAutotuneCache._heuristic_modules[heuristic_file] = module
+                log.debug(f"Loaded heuristic module: {heuristic_file}")

             # Call the heuristic function
+            config: Config | None = None
             if hasattr(module, f"select_config_{kernel_name}"):
                 select_fn = getattr(module, f"select_config_{kernel_name}")
-                shape_features = self._extract_shape_features()
                 config_dict = select_fn(shape_features)
-                return Config(**config_dict)
-            if hasattr(module, "select_config"):
+                config = Config(**config_dict)
+            elif hasattr(module, "select_config"):
                 select_fn = module.select_config
-                shape_features = self._extract_shape_features()
                 config_dict = select_fn(kernel_name, shape_features)
-                return Config(**config_dict)
+                config = Config(**config_dict)
+
+            # Cache the result
+            if config is not None:
+                AOTAutotuneCache._heuristic_results[cache_key] = config
+                log.debug(
+                    f"Cached heuristic result for {kernel_name} shape={shape_hash}"
+                )
+
+            return config
         except Exception as e:
             log.warning(f"Failed to load heuristic from {heuristic_file}: {e}")

         return None

+    def supports_per_shape_config(self) -> bool:
+        """
+        Return True if heuristics are available for per-shape config selection.
+
+        When True, the kernel can use get_config_for_args() on each invocation
+        to get shape-specific configs, even when static_shapes=False.
+        """
+        # Only support per-shape config in evaluate mode with available heuristics
+        if self.mode != "evaluate":
+            return False
+        return self._find_heuristic_file() is not None
+
+    def get_config_for_args(self, args: Sequence[object]) -> Config | None:
+        """
+        Get a config for the given arguments using heuristics.
+
+        This enables per-shape config selection independent of static_shapes setting.
+
+        Args:
+            args: The kernel arguments for this invocation
+
+        Returns:
+            Config if heuristics provide a config for this shape, None otherwise
+        """
+        return self._get_heuristic_config(args)
+
     def _get_cache_key(self) -> LooseAutotuneCacheKey:
         """Return a cache key for compatibility."""
         return self.autotuner.kernel.kernel._create_bound_kernel_cache_key(

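For reference, a sketch of the heuristic-module interface the loader above expects: a file such as _my_kernels_cuda_sm90.py placed next to the kernel source can expose select_config_<kernel_name>(shape_features) and/or a generic select_config(kernel_name, shape_features), each returning a plain dict that is passed to Config(**config_dict). The file name, kernel name, feature keys, and config keys below are hypothetical:

# Hypothetical heuristic module: _my_kernels_cuda_sm90.py
from typing import Any


def select_config_my_kernel(shape_features: dict[str, Any]) -> dict[str, Any]:
    # shape_features carries numeric features from _extract_shape_features()
    # (per-argument ndim/sizes); the exact key names used here are assumptions.
    if shape_features.get("arg0_size0", 0) >= 4096:
        return {"block_sizes": [128, 64], "num_warps": 8}
    return {"block_sizes": [64, 32], "num_warps": 4}


def select_config(kernel_name: str, shape_features: dict[str, Any]) -> dict[str, Any]:
    # Generic fallback used when no kernel-specific function is defined.
    return {"block_sizes": [64, 32], "num_warps": 4}
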
helion/autotuner/aot_runner.py

Lines changed: 8 additions & 0 deletions
@@ -437,6 +437,14 @@ def run_full_workflow(config: RunConfig) -> bool:
     log.info("=" * 60)
     log.info("AOT autotuning workflow completed successfully!")
     log.info("=" * 60)
+    log.info("")
+    log.info("TIP: To use the generated heuristics automatically, add")
+    log.info("  autotune_cache='AOTAutotuneCache' to your kernel decorators:")
+    log.info("")
+    log.info("  @helion.kernel(autotune_cache='AOTAutotuneCache')")
+    log.info("  def my_kernel(...):")
+    log.info("      ...")
+    log.info("")
     return True


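Spelled out, the tip above amounts to the following decorator usage; the kernel body here is a hypothetical illustration, not part of this commit:

import torch

import helion
import helion.language as hl


@helion.kernel(autotune_cache="AOTAutotuneCache")
def add_one(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(x.size()):
        out[tile] = x[tile] + 1
    return out
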
helion/autotuner/base_cache.py

Lines changed: 26 additions & 0 deletions
@@ -157,6 +157,32 @@ def get(self) -> Config | None:
     def put(self, config: Config) -> None:
         raise NotImplementedError

+    def supports_per_shape_config(self) -> bool:
+        """
+        Return True if this cache supports per-shape config selection.
+
+        When True, get_config_for_args() can be called on each kernel invocation
+        to get shape-specific configs, bypassing the normal BoundKernel config caching.
+        This is useful for heuristic-based config selection where different shapes
+        should use different configs even when static_shapes=False.
+        """
+        return False
+
+    def get_config_for_args(self, args: Sequence[object]) -> Config | None:
+        """
+        Get a config for the given arguments.
+
+        This method is called on each kernel invocation when supports_per_shape_config()
+        returns True. Override this to provide per-shape config selection.
+
+        Args:
+            args: The kernel arguments for this invocation
+
+        Returns:
+            Config if a shape-specific config is available, None otherwise
+        """
+        return None
+
     def _get_cache_info_message(self) -> str:
         """Return a message describing where the cache is and how to clear it."""
         return ""

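A minimal sketch of a cache that opts into this hook, assuming the AutotuneCacheBase class from this file and helion's Config type (the config keys and size threshold are made up for illustration):

from collections.abc import Sequence

import torch

from helion import Config
from helion.autotuner.base_cache import AutotuneCacheBase


class ToyPerShapeCache(AutotuneCacheBase):
    """Picks a config from the first tensor argument's element count."""

    def supports_per_shape_config(self) -> bool:
        # Ask the caller to invoke get_config_for_args() on every kernel invocation.
        return True

    def get_config_for_args(self, args: Sequence[object]) -> Config | None:
        first = args[0] if args else None
        if not isinstance(first, torch.Tensor):
            return None
        if first.numel() >= 1 << 20:
            return Config(block_sizes=[128], num_warps=8)
        return Config(block_sizes=[64], num_warps=4)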