|
54 | 54 |
|
55 | 55 | def get_aot_mode() -> AOTMode: |
56 | 56 | """Get the current AOT mode from environment.""" |
57 | | - mode = os.environ.get(AOT_MODE_ENV, "disabled").lower() |
| 57 | + mode = os.environ.get(AOT_MODE_ENV, "evaluate").lower() |
58 | 58 | if mode in ("collect", "measure", "evaluate", "disabled"): |
59 | 59 | return mode # type: ignore[return-value] |
60 | 60 | raise ValueError( |
@@ -101,6 +101,65 @@ def get_device_compute_id() -> tuple[str, str]: |
101 | 101 | return ("cpu", platform.machine()) |
102 | 102 |
|
103 | 103 |
|
| 104 | +# Known compute capabilities in descending order (newest first) |
| 105 | +# This allows fallback to older architectures when heuristics aren't available |
| 106 | +_CUDA_COMPUTE_CAPS: list[str] = [ |
| 107 | + "sm100", |
| 108 | + "sm90", |
| 109 | + "sm89", |
| 110 | + "sm87", |
| 111 | + "sm86", |
| 112 | + "sm80", |
| 113 | + "sm75", |
| 114 | + "sm72", |
| 115 | + "sm70", |
| 116 | +] |
| 117 | + |
| 118 | +_ROCM_ARCHS: list[str] = [ |
| 119 | + "gfx950", |
| 120 | + "gfx942", |
| 121 | + "gfx941", |
| 122 | + "gfx940", |
| 123 | + "gfx90a", |
| 124 | + "gfx908", |
| 125 | + "gfx906", |
| 126 | + "gfx900", |
| 127 | +] |
| 128 | + |
| 129 | + |
| 130 | +def get_compatible_compute_ids(device_kind: str, compute_kind: str) -> list[str]: |
| 131 | + """ |
| 132 | + Get a list of compatible compute IDs for fallback, ordered from current to oldest. |
| 133 | +
|
| 134 | + For CUDA/ROCm, returns the current compute capability followed by all older |
| 135 | + compatible architectures. This allows using heuristics tuned on older hardware |
| 136 | + when newer hardware-specific heuristics aren't available. |
| 137 | +
|
| 138 | + Args: |
| 139 | + device_kind: 'cuda', 'rocm', or 'cpu' |
| 140 | + compute_kind: The current compute capability (e.g., 'sm90', 'gfx942') |
| 141 | +
|
| 142 | + Returns: |
| 143 | + List of compute IDs to try, starting with the exact match |
| 144 | + """ |
| 145 | + if device_kind == "cuda": |
| 146 | + arch_list = _CUDA_COMPUTE_CAPS |
| 147 | + elif device_kind == "rocm": |
| 148 | + arch_list = _ROCM_ARCHS |
| 149 | + else: |
| 150 | + # CPU or unknown - no fallback |
| 151 | + return [compute_kind] |
| 152 | + |
| 153 | + # Find current architecture in the list |
| 154 | + try: |
| 155 | + current_idx = arch_list.index(compute_kind) |
| 156 | + # Return current and all older architectures |
| 157 | + return arch_list[current_idx:] |
| 158 | + except ValueError: |
| 159 | + # Unknown architecture - try it alone, then try all known ones |
| 160 | + return [compute_kind] + arch_list |
| 161 | + |
| 162 | + |
104 | 163 | def get_heuristic_path_for_kernel(kernel_source_file: str | Path) -> Path: |
105 | 164 | """ |
106 | 165 | Get the path where heuristics should be stored for a kernel. |
@@ -481,6 +540,21 @@ class AOTAutotuneCache(AutotuneCacheBase): |
481 | 540 |
|
482 | 541 | _mode_announced: set[str] = set() # Class-level to avoid repeated messages |
483 | 542 |
|
| 543 | + # Class-level caches for heuristic lookup (shared across instances) |
| 544 | + # Maps heuristic file path -> loaded module |
| 545 | + _heuristic_modules: dict[Path, Any] = {} |
| 546 | + # Maps (kernel_source_file, kernel_name, shape_features_hash) -> Config |
| 547 | + # Using source file ensures kernels with same name in different modules don't collide |
| 548 | + _heuristic_results: dict[tuple[str, str, str], Config] = {} |
| 549 | + |
| 550 | + @classmethod |
| 551 | + def clear_caches(cls) -> None: |
| 552 | + """Clear all class-level caches (heuristic modules and results).""" |
| 553 | + cls._heuristic_modules.clear() |
| 554 | + cls._heuristic_results.clear() |
| 555 | + cls._mode_announced.clear() |
| 556 | + log.debug("Cleared AOTAutotuneCache caches") |
| 557 | + |
484 | 558 | def __init__(self, autotuner: BaseSearch) -> None: |
485 | 559 | super().__init__(autotuner) |
486 | 560 | self.mode = get_aot_mode() |
@@ -516,11 +590,16 @@ def _create_shape_key(self) -> ShapeKey: |
516 | 590 | hardware_id=self.hardware_id, |
517 | 591 | ) |
518 | 592 |
|
519 | | - def _extract_shape_features(self) -> dict[str, Any]: |
| 593 | + def _extract_shape_features( |
| 594 | + self, args: Sequence[object] | None = None |
| 595 | + ) -> dict[str, Any]: |
520 | 596 | """Extract numeric features from the shape for ML model.""" |
| 597 | + if args is None: |
| 598 | + args = self.args |
| 599 | + |
521 | 600 | features: dict[str, Any] = {} |
522 | 601 |
|
523 | | - for i, arg in enumerate(self.args): |
| 602 | + for i, arg in enumerate(args): |
524 | 603 | if isinstance(arg, torch.Tensor): |
525 | 604 | features[f"arg{i}_ndim"] = arg.ndim |
526 | 605 | for j, size in enumerate(arg.shape): |
@@ -709,85 +788,165 @@ def measure_all_configs(self) -> list[tuple[Config, float]]: |
709 | 788 | ) |
710 | 789 | return results |
711 | 790 |
|
712 | | - def _get_heuristic_config(self) -> Config | None: |
| 791 | + def _find_heuristic_file(self) -> Path | None: |
713 | 792 | """ |
714 | | - Use the heuristic to select a config. |
| 793 | + Find the heuristic file for this kernel. |
715 | 794 |
|
716 | | - Search order for heuristic files: |
| 795 | + Search order: |
717 | 796 | 1. HELION_HEURISTIC_DIR env var (if set) - for comparing different heuristics |
718 | 797 | 2. Next to kernel source file: _<filename>_<device>_<compute>.py |
719 | | - 3. AOT data directory: heuristic_<kernel_name>.py (fallback) |
| 798 | + 3. Fallback to older compute capabilities within the same device family |
| 799 | + 4. AOT data directory: heuristic_<kernel_name>.py (fallback) |
720 | 800 | """ |
721 | 801 | kernel_name = self.kernel.kernel.name |
722 | 802 |
|
723 | 803 | # Get the kernel source file path |
724 | 804 | kernel_source_file = self.kernel.kernel.__code__.co_filename |
| 805 | + source_path = Path(kernel_source_file) |
| 806 | + base_name = source_path.stem |
| 807 | + |
| 808 | + # Get device info and compatible compute capabilities |
| 809 | + device_kind, compute_kind = get_device_compute_id() |
| 810 | + compatible_computes = get_compatible_compute_ids(device_kind, compute_kind) |
725 | 811 |
|
726 | 812 | # Build list of candidate heuristic files in priority order |
727 | 813 | candidates: list[Path] = [] |
728 | 814 |
|
729 | 815 | # 1. Check HELION_HEURISTIC_DIR override (for comparing heuristics) |
730 | 816 | if (heuristic_dir := os.environ.get(HEURISTIC_DIR_ENV)) is not None: |
731 | 817 | heuristic_dir_path = Path(heuristic_dir) |
732 | | - # Look for the standard naming convention in override dir |
733 | | - device_kind, compute_kind = get_device_compute_id() |
734 | | - base_name = Path(kernel_source_file).stem |
735 | | - candidates.append( |
736 | | - heuristic_dir_path / f"_{base_name}_{device_kind}_{compute_kind}.py" |
737 | | - ) |
| 818 | + # Try each compatible compute capability in order |
| 819 | + for compat_compute in compatible_computes: |
| 820 | + candidates.append( |
| 821 | + heuristic_dir_path |
| 822 | + / f"_{base_name}_{device_kind}_{compat_compute}.py" |
| 823 | + ) |
738 | 824 | # Also check kernel-specific file in override dir |
739 | 825 | candidates.append(heuristic_dir_path / f"heuristic_{kernel_name}.py") |
740 | 826 |
|
741 | | - # 2. Check next to kernel source file: _<filename>_<device>_<compute>.py |
742 | | - candidates.append(get_heuristic_path_for_kernel(kernel_source_file)) |
| 827 | + # 2. Check next to kernel source file with compute capability fallback |
| 828 | + for compat_compute in compatible_computes: |
| 829 | + heuristic_name = f"_{base_name}_{device_kind}_{compat_compute}.py" |
| 830 | + candidates.append(source_path.parent / heuristic_name) |
743 | 831 |
|
744 | 832 | # 3. Check AOT data directory (fallback for backward compatibility) |
745 | 833 | candidates.append(self.data_store.data_dir / f"heuristic_{kernel_name}.py") |
746 | 834 | candidates.append(self.data_store.heuristic_file) |
747 | 835 |
|
748 | 836 | # Find first existing heuristic file |
749 | | - heuristic_file = None |
750 | 837 | for candidate in candidates: |
751 | 838 | if candidate.exists(): |
752 | | - heuristic_file = candidate |
753 | | - log.debug(f"Found heuristic file: {heuristic_file}") |
754 | | - break |
| 839 | + log.debug(f"Found heuristic file: {candidate}") |
| 840 | + return candidate |
| 841 | + |
| 842 | + log.debug( |
| 843 | + f"Heuristic file not found for {kernel_name}. Searched: " |
| 844 | + f"{[str(c) for c in candidates[:3]]}..." |
| 845 | + ) |
| 846 | + return None |
| 847 | + |
    def _get_heuristic_config(
        self, args: Sequence[object] | None = None
    ) -> Config | None:
        """
        Use the heuristic to select a config.

        Args:
            args: Optional arguments to use for shape feature extraction.
                If None, uses self.args.

        For CUDA/ROCm, if heuristics for the current compute capability aren't found,
        we try older compatible architectures (e.g., sm80 heuristics on sm90 hardware).

        Returns None when no heuristic file exists, when the module provides no
        applicable select function, or when loading/selection raises (best-effort).
        """
        heuristic_file = self._find_heuristic_file()
        if heuristic_file is None:
            return None

        kernel_name = self.kernel.kernel.name
        kernel_source_file = self.kernel.kernel.__code__.co_filename

        # Extract shape features and compute hash for caching.
        # Key is the first 16 hex chars of the SHA-256 of the sorted-key JSON
        # dump; assumes feature values are JSON-serializable — TODO confirm.
        shape_features = self._extract_shape_features(args)
        shape_hash = hashlib.sha256(
            json.dumps(shape_features, sort_keys=True).encode()
        ).hexdigest()[:16]

        # Check if we already have a cached result for this kernel+shape.
        # The result cache is class-level, so it is shared across instances.
        # Include source file in key to avoid collisions between kernels with same name
        cache_key = (kernel_source_file, kernel_name, shape_hash)
        if cache_key in AOTAutotuneCache._heuristic_results:
            log.debug(
                f"Using cached heuristic result for {kernel_name} shape={shape_hash}"
            )
            return AOTAutotuneCache._heuristic_results[cache_key]

        try:
            # Load heuristic module from cache or import fresh.
            # Modules are cached per file path; they are never registered in
            # sys.modules (all share the name "heuristic"), only in the
            # class-level _heuristic_modules dict.
            if heuristic_file in AOTAutotuneCache._heuristic_modules:
                module = AOTAutotuneCache._heuristic_modules[heuristic_file]
            else:
                import importlib.util

                spec = importlib.util.spec_from_file_location(
                    "heuristic", heuristic_file
                )
                if spec is None or spec.loader is None:
                    return None
                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)
                AOTAutotuneCache._heuristic_modules[heuristic_file] = module
                log.debug(f"Loaded heuristic module: {heuristic_file}")

            # Call the heuristic function.
            # Kernel-specific select_config_<name>(features) takes priority over
            # the generic select_config(name, features) entry point.
            config: Config | None = None
            if hasattr(module, f"select_config_{kernel_name}"):
                select_fn = getattr(module, f"select_config_{kernel_name}")
                config_dict = select_fn(shape_features)
                config = Config(**config_dict)
            elif hasattr(module, "select_config"):
                select_fn = module.select_config
                config_dict = select_fn(kernel_name, shape_features)
                config = Config(**config_dict)

            # Cache the result (only successful selections are cached; a miss
            # will re-run the heuristic on the next call)
            if config is not None:
                AOTAutotuneCache._heuristic_results[cache_key] = config
                log.debug(
                    f"Cached heuristic result for {kernel_name} shape={shape_hash}"
                )

            return config
        except Exception as e:
            # Best-effort: any failure to load or evaluate the heuristic is
            # logged and treated as "no heuristic available"
            log.warning(f"Failed to load heuristic from {heuristic_file}: {e}")

        return None
790 | 923 |
|
| 924 | + def supports_per_shape_config(self) -> bool: |
| 925 | + """ |
| 926 | + Return True if heuristics are available for per-shape config selection. |
| 927 | +
|
| 928 | + When True, the kernel can use get_config_for_args() on each invocation |
| 929 | + to get shape-specific configs, even when static_shapes=False. |
| 930 | + """ |
| 931 | + # Only support per-shape config in evaluate mode with available heuristics |
| 932 | + if self.mode != "evaluate": |
| 933 | + return False |
| 934 | + return self._find_heuristic_file() is not None |
| 935 | + |
| 936 | + def get_config_for_args(self, args: Sequence[object]) -> Config | None: |
| 937 | + """ |
| 938 | + Get a config for the given arguments using heuristics. |
| 939 | +
|
| 940 | + This enables per-shape config selection independent of static_shapes setting. |
| 941 | +
|
| 942 | + Args: |
| 943 | + args: The kernel arguments for this invocation |
| 944 | +
|
| 945 | + Returns: |
| 946 | + Config if heuristics provide a config for this shape, None otherwise |
| 947 | + """ |
| 948 | + return self._get_heuristic_config(args) |
| 949 | + |
791 | 950 | def _get_cache_key(self) -> LooseAutotuneCacheKey: |
792 | 951 | """Return a cache key for compatibility.""" |
793 | 952 | return self.autotuner.kernel.kernel._create_bound_kernel_cache_key( |
|
0 commit comments