-
Notifications
You must be signed in to change notification settings - Fork 89
Draft: first draft of a aot autotuning runner and cache and heuristics gener… #1278
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
v0i0
wants to merge
6
commits into
main
Choose a base branch
from
v0i0/autotune-heuristic
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
fdf0655
first draft of a aot autotuning runner and cache and heuristics gener…
v0i0 40eb06c
add data type support
v0i0 a5e9030
fixing code gen
v0i0 ff0c8cb
heuristics bbackends
v0i0 e0ab1f7
better decision tree, better output in runner, more analysis
v0i0 15f568d
make it easy to use
v0i0 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,105 @@ | ||
| #!/usr/bin/env python3 | ||
| """ | ||
| AOT Autotuning Example | ||
| ====================== | ||
|
|
||
| This example demonstrates how to use the AOT (Ahead-of-Time) autotuning | ||
| workflow for Helion kernels. | ||
|
|
||
| The AOT workflow consists of three phases: | ||
| 1. Collect: Run benchmarks, autotuning each shape individually | ||
| 2. Measure: Re-run benchmarks, measuring all configs across all shapes | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (as we discussed, maybe "all shapes" is not exactly accurate, as user can customize what shapes to run in each phase) |
||
| 3. Evaluate: Generate heuristics and validate performance | ||
|
|
||
| Usage: | ||
| # Run the full workflow using the AOT runner | ||
| python -m helion.autotuner.aot_runner --benchmark "python examples/aot_autotuning_example.py" | ||
|
|
||
| # Or run individual phases: | ||
| HELION_AOT_MODE=collect HELION_AOT_DATA_DIR=./aot_data python examples/aot_autotuning_example.py | ||
| HELION_AOT_MODE=measure HELION_AOT_DATA_DIR=./aot_data python examples/aot_autotuning_example.py | ||
| python -c "from helion.autotuner.heuristic_generator import generate_heuristic; from pathlib import Path; generate_heuristic(Path('./aot_data/measurements_*.csv'), Path('./aot_data'))" | ||
| HELION_AOT_MODE=evaluate HELION_AOT_DATA_DIR=./aot_data python examples/aot_autotuning_example.py | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import os | ||
|
|
||
| import torch | ||
|
|
||
| import helion | ||
| from helion._testing import DEVICE | ||
| import helion.language as hl | ||
|
|
||
|
|
||
| # Define a simple kernel for demonstration | ||
| @helion.kernel | ||
| def vector_scale(x: torch.Tensor, scale: float) -> torch.Tensor: | ||
| """Scale a vector by a constant.""" | ||
| n = x.size(0) | ||
| out = torch.empty_like(x) | ||
| for tile_n in hl.tile(n): | ||
| out[tile_n] = x[tile_n] * scale | ||
| return out | ||
|
|
||
|
|
||
| @helion.kernel | ||
| def rms_norm_simple(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor: | ||
| """Simplified RMS normalization.""" | ||
| m, n = x.size() | ||
| out = torch.empty_like(x) | ||
| for tile_m in hl.tile(m): | ||
| x_tile = x[tile_m, :].to(torch.float32) | ||
| rms = torch.sqrt(torch.mean(x_tile * x_tile, dim=-1) + eps) | ||
| out[tile_m, :] = (x_tile / rms[:, None]).to(out.dtype) | ||
| return out | ||
|
|
||
|
|
||
| def benchmark_kernels() -> None: | ||
| """Run benchmarks on various shapes.""" | ||
| print(f"AOT Mode: {os.environ.get('HELION_AOT_MODE', 'disabled')}") | ||
| print(f"AOT Data Dir: {os.environ.get('HELION_AOT_DATA_DIR', 'N/A')}") | ||
| print() | ||
|
|
||
| # Test vector_scale with various sizes | ||
| print("=== vector_scale kernel ===") | ||
| for n in [1024, 4096, 16384]: | ||
| x = torch.randn(n, device=DEVICE, dtype=torch.float16) | ||
| result = vector_scale(x, 2.0) | ||
| print(f" Shape ({n},): output sum = {result.sum().item():.2f}") | ||
|
|
||
| # Test rms_norm_simple with various shapes | ||
| print() | ||
| print("=== rms_norm_simple kernel ===") | ||
| for m, n in [(128, 512), (256, 1024), (512, 2048)]: | ||
| x = torch.randn(m, n, device=DEVICE, dtype=torch.float16) | ||
| result = rms_norm_simple(x) | ||
| print(f" Shape ({m}, {n}): output sum = {result.sum().item():.2f}") | ||
|
|
||
|
|
||
| def main() -> None: | ||
| """Main entry point.""" | ||
| # Check if we're in AOT mode | ||
| aot_mode = os.environ.get("HELION_AOT_MODE", "disabled") | ||
|
|
||
| if aot_mode == "disabled": | ||
| print("Running in normal mode (no AOT)") | ||
| print("Set HELION_AOT_MODE=collect|measure|evaluate to enable AOT workflow") | ||
| print() | ||
|
|
||
| # Enable AOT cache if in AOT mode | ||
| if aot_mode != "disabled": | ||
| os.environ["HELION_AUTOTUNE_CACHE"] = "AOTAutotuneCache" | ||
|
|
||
| benchmark_kernels() | ||
|
|
||
| if aot_mode != "disabled": | ||
| print() | ||
| print(f"AOT {aot_mode} phase completed!") | ||
| data_dir = os.environ.get("HELION_AOT_DATA_DIR", ".helion_aot") | ||
| print(f"Data saved to: {data_dir}") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,171 @@ | ||
| #!/usr/bin/env python3 | ||
| """ | ||
| AOT Multi-Config Autotuning Example | ||
| =================================== | ||
|
|
||
| This example demonstrates AOT autotuning with a kernel that requires | ||
| different configurations for different input shapes. The kernel has | ||
| 2D block sizes that have different optimal values for: | ||
| - Tall-and-skinny matrices (M >> N): Small block_m, large block_n | ||
| - Short-and-wide matrices (M << N): Large block_m, small block_n | ||
| - Square matrices: Balanced block sizes | ||
|
|
||
| This generates heuristics with actual decision tree logic rather than | ||
| a single config. | ||
|
|
||
| Usage: | ||
| python -m helion.autotuner.aot_runner --benchmark "python examples/aot_multiconfig_example.py" | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import os | ||
|
|
||
| import torch | ||
|
|
||
| import helion | ||
| from helion._testing import DEVICE | ||
| import helion.language as hl | ||
|
|
||
|
|
||
| @helion.kernel() | ||
| def row_softmax(x: torch.Tensor) -> torch.Tensor: | ||
| """ | ||
| Row-wise softmax with explicit 2D tiling. | ||
|
|
||
| The optimal block sizes depend on the matrix shape: | ||
| - Tall matrices benefit from larger row tiles | ||
| - Wide matrices benefit from larger column tiles | ||
| """ | ||
| m, n = x.size() | ||
| out = torch.empty_like(x) | ||
|
|
||
| # Explicit block size registration allows tuning | ||
| block_m = hl.register_block_size(m) | ||
| block_n = hl.register_block_size(n) | ||
|
|
||
| for tile_m in hl.tile(m, block_size=block_m): | ||
| # First pass: compute max and sum for numerical stability | ||
| row_max = hl.full([tile_m], float("-inf"), dtype=torch.float32) | ||
| row_sum = hl.zeros([tile_m], dtype=torch.float32) | ||
|
|
||
| for tile_n in hl.tile(n, block_size=block_n): | ||
| values = x[tile_m, tile_n].to(torch.float32) | ||
| local_max = torch.amax(values, dim=1) | ||
| new_max = torch.maximum(row_max, local_max) | ||
| # Rescale previous sum and add new contributions | ||
| row_sum = row_sum * torch.exp(row_max - new_max) + torch.sum( | ||
| torch.exp(values - new_max[:, None]), dim=1 | ||
| ) | ||
| row_max = new_max | ||
|
|
||
| # Second pass: compute softmax output | ||
| for tile_n in hl.tile(n, block_size=block_n): | ||
| values = x[tile_m, tile_n].to(torch.float32) | ||
| out[tile_m, tile_n] = ( | ||
| torch.exp(values - row_max[:, None]) / row_sum[:, None] | ||
| ).to(out.dtype) | ||
|
|
||
| return out | ||
|
|
||
|
|
||
| @helion.kernel() | ||
| def col_reduce_sum(x: torch.Tensor) -> torch.Tensor: | ||
| """ | ||
| Column-wise sum reduction with 2D tiling. | ||
|
|
||
| For tall matrices, we want to process many rows in parallel. | ||
| For wide matrices, we want larger column blocks. | ||
| """ | ||
| m, n = x.size() | ||
| out = torch.zeros(n, dtype=x.dtype, device=x.device) | ||
|
|
||
| block_m = hl.register_block_size(m) | ||
| block_n = hl.register_block_size(n) | ||
|
|
||
| for tile_n in hl.tile(n, block_size=block_n): | ||
| col_acc = hl.zeros([tile_n], dtype=torch.float32) | ||
|
|
||
| for tile_m in hl.tile(m, block_size=block_m): | ||
| col_acc += torch.sum(x[tile_m, tile_n].to(torch.float32), dim=0) | ||
|
|
||
| out[tile_n] = col_acc.to(out.dtype) | ||
|
|
||
| return out | ||
|
|
||
|
|
||
| def benchmark_kernels() -> None: | ||
| """Run benchmarks on various shapes and dtypes.""" | ||
| print(f"AOT Mode: {os.environ.get('HELION_AOT_MODE', 'disabled')}") | ||
| print(f"AOT Data Dir: {os.environ.get('HELION_AOT_DATA_DIR', 'N/A')}") | ||
| print() | ||
|
|
||
| # Define shapes covering different aspect ratios | ||
| shapes = [ | ||
| # Tall and skinny (M >> N) | ||
| (8192, 64), | ||
| (4096, 128), | ||
| (2048, 256), | ||
| # Square-ish | ||
| (1024, 1024), | ||
| (2048, 512), | ||
| (512, 2048), | ||
| # Short and wide (M << N) | ||
| (256, 2048), | ||
| (128, 4096), | ||
| (64, 8192), | ||
| ] | ||
|
|
||
| # Test multiple dtypes - different dtypes often need different tile sizes | ||
| dtypes = [torch.float16, torch.float32] # , torch.bfloat16] | ||
|
|
||
| print("=== row_softmax kernel ===") | ||
| print("Testing across shapes and dtypes:") | ||
| for dtype in dtypes: | ||
| print(f"\n dtype={dtype}:") | ||
| for m, n in shapes: | ||
| x = torch.randn(m, n, device=DEVICE, dtype=dtype) | ||
| result = row_softmax(x) | ||
| # Verify softmax property: each row sums to 1 | ||
| row_sums = result.sum(dim=1) | ||
| avg_sum = row_sums.mean().item() | ||
| print(f" Shape ({m:5d}, {n:5d}): row_sum mean = {avg_sum:.4f}") | ||
|
|
||
| # print() | ||
| # print("=== col_reduce_sum kernel ===") | ||
| # print("Testing across shapes and dtypes:") | ||
| # for dtype in dtypes: | ||
| # print(f"\n dtype={dtype}:") | ||
| # for m, n in shapes: | ||
| # x = torch.randn(m, n, device=DEVICE, dtype=dtype) | ||
| # result = col_reduce_sum(x) | ||
| # # Compare with torch reference | ||
| # expected = x.sum(dim=0) | ||
| # max_diff = (result - expected).abs().max().item() | ||
| # print(f" Shape ({m:5d}, {n:5d}): max_diff = {max_diff:.6f}") | ||
|
|
||
|
|
||
| def main() -> None: | ||
| """Main entry point.""" | ||
| aot_mode = os.environ.get("HELION_AOT_MODE", "disabled") | ||
|
|
||
| if aot_mode == "disabled": | ||
| print("Running in normal mode (no AOT)") | ||
| print("Set HELION_AOT_MODE=collect|measure|evaluate to enable AOT workflow") | ||
| print() | ||
|
|
||
| # Enable AOT cache if in AOT mode | ||
| if aot_mode != "disabled": | ||
| os.environ["HELION_AUTOTUNE_CACHE"] = "AOTAutotuneCache" | ||
|
|
||
| benchmark_kernels() | ||
|
|
||
| if aot_mode != "disabled": | ||
| print() | ||
| print(f"AOT {aot_mode} phase completed!") | ||
| data_dir = os.environ.get("HELION_AOT_DATA_DIR", ".helion_aot") | ||
| print(f"Data saved to: {data_dir}") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thinking about Horace's example, and curious: Would it make sense to support a "benchmark only" mode in the
collectphase (or a separate phase) that skips autotuning and just measures existing configs against additional shapes (similar tosecondary_inputsin Horace's RFC)? This would let users:collecton a small set of representative shapes (to do full autotune on)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The script already lets you specify different benchmarks for all three phases of measurement, so you can collect on a different benchmark than the one you measure.
Is that what you are asking about or something else?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah nice! curious should we add an example showing this workflow? cc. @Chillee would this cover the original need of
primary_inputs/secondary_inputs?