105 changes: 105 additions & 0 deletions examples/aot_autotuning_example.py
@@ -0,0 +1,105 @@
#!/usr/bin/env python3
@yf225 (Contributor), Dec 17, 2025:
Thinking about Horace's example, and curious: would it make sense to support a "benchmark only" mode in the collect phase (or a separate phase) that skips autotuning and just measures existing configs against additional shapes (similar to secondary_inputs in Horace's RFC)? This would let users:

  1. Run collect on a small set of representative shapes (to do full autotune on)
  2. Run benchmark-only on a larger set of shapes (just measure)
  3. Build heuristics using the full timing matrix

Contributor Author:
The script already lets you specify different benchmarks for all three phases of measurement, so you can collect on a different benchmark than the one you measure.

Is that what you are asking about or something else?

Contributor:
Ah nice! Curious: should we add an example showing this workflow? cc @Chillee, would this cover the original need of primary_inputs / secondary_inputs?
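
For reference, a sketch of that split workflow using the env-var interface from the example below (bench_small_shapes.py and bench_all_shapes.py are hypothetical benchmark scripts; all phases share one data directory):

    # Full autotune on a small, representative shape set
    HELION_AOT_MODE=collect HELION_AOT_DATA_DIR=./aot_data python bench_small_shapes.py
    # Measure the collected configs on a larger shape set (no autotuning)
    HELION_AOT_MODE=measure HELION_AOT_DATA_DIR=./aot_data python bench_all_shapes.py
    # Build heuristics from the full timing matrix
    python -c "from helion.autotuner.heuristic_generator import generate_heuristic; from pathlib import Path; generate_heuristic(Path('./aot_data/measurements_*.csv'), Path('./aot_data'))"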

"""
AOT Autotuning Example
======================

This example demonstrates how to use the AOT (Ahead-of-Time) autotuning
workflow for Helion kernels.

The AOT workflow consists of three phases:
1. Collect: Run benchmarks, autotuning each shape individually
2. Measure: Re-run benchmarks, measuring all configs across all shapes
@yf225 (Contributor), Dec 17, 2025:
(As we discussed, "all shapes" may not be exactly accurate, since the user can customize which shapes to run in each phase.)

3. Evaluate: Generate heuristics and validate performance

Usage:
# Run the full workflow using the AOT runner
python -m helion.autotuner.aot_runner --benchmark "python examples/aot_autotuning_example.py"

# Or run individual phases:
HELION_AOT_MODE=collect HELION_AOT_DATA_DIR=./aot_data python examples/aot_autotuning_example.py
HELION_AOT_MODE=measure HELION_AOT_DATA_DIR=./aot_data python examples/aot_autotuning_example.py
python -c "from helion.autotuner.heuristic_generator import generate_heuristic; from pathlib import Path; generate_heuristic(Path('./aot_data/measurements_*.csv'), Path('./aot_data'))"
HELION_AOT_MODE=evaluate HELION_AOT_DATA_DIR=./aot_data python examples/aot_autotuning_example.py
"""

from __future__ import annotations

import os

import torch

import helion
from helion._testing import DEVICE
import helion.language as hl


# Define a simple kernel for demonstration
@helion.kernel
def vector_scale(x: torch.Tensor, scale: float) -> torch.Tensor:
"""Scale a vector by a constant."""
n = x.size(0)
out = torch.empty_like(x)
for tile_n in hl.tile(n):
out[tile_n] = x[tile_n] * scale
return out


@helion.kernel
def rms_norm_simple(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
"""Simplified RMS normalization."""
m, n = x.size()
out = torch.empty_like(x)
for tile_m in hl.tile(m):
x_tile = x[tile_m, :].to(torch.float32)
rms = torch.sqrt(torch.mean(x_tile * x_tile, dim=-1) + eps)
out[tile_m, :] = (x_tile / rms[:, None]).to(out.dtype)
return out
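

# A sketch of a correctness check against eager PyTorch (hypothetical helper,
# not in the original diff): rms_norm_simple should match x / sqrt(mean(x^2) + eps).
def _check_rms_norm_reference() -> None:
    x = torch.randn(128, 512, device=DEVICE, dtype=torch.float16)
    xf = x.to(torch.float32)
    expected = (xf / torch.sqrt((xf * xf).mean(dim=-1, keepdim=True) + 1e-5)).to(x.dtype)
    torch.testing.assert_close(rms_norm_simple(x), expected, rtol=1e-2, atol=1e-2)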


def benchmark_kernels() -> None:
"""Run benchmarks on various shapes."""
print(f"AOT Mode: {os.environ.get('HELION_AOT_MODE', 'disabled')}")
print(f"AOT Data Dir: {os.environ.get('HELION_AOT_DATA_DIR', 'N/A')}")
print()

# Test vector_scale with various sizes
print("=== vector_scale kernel ===")
for n in [1024, 4096, 16384]:
x = torch.randn(n, device=DEVICE, dtype=torch.float16)
result = vector_scale(x, 2.0)
print(f" Shape ({n},): output sum = {result.sum().item():.2f}")

# Test rms_norm_simple with various shapes
print()
print("=== rms_norm_simple kernel ===")
for m, n in [(128, 512), (256, 1024), (512, 2048)]:
x = torch.randn(m, n, device=DEVICE, dtype=torch.float16)
result = rms_norm_simple(x)
print(f" Shape ({m}, {n}): output sum = {result.sum().item():.2f}")


def main() -> None:
"""Main entry point."""
# Check if we're in AOT mode
aot_mode = os.environ.get("HELION_AOT_MODE", "disabled")

if aot_mode == "disabled":
print("Running in normal mode (no AOT)")
print("Set HELION_AOT_MODE=collect|measure|evaluate to enable AOT workflow")
print()

# Enable AOT cache if in AOT mode
if aot_mode != "disabled":
os.environ["HELION_AUTOTUNE_CACHE"] = "AOTAutotuneCache"

benchmark_kernels()

if aot_mode != "disabled":
print()
print(f"AOT {aot_mode} phase completed!")
data_dir = os.environ.get("HELION_AOT_DATA_DIR", ".helion_aot")
print(f"Data saved to: {data_dir}")


if __name__ == "__main__":
main()
171 changes: 171 additions & 0 deletions examples/aot_multiconfig_example.py
@@ -0,0 +1,171 @@
#!/usr/bin/env python3
"""
AOT Multi-Config Autotuning Example
===================================

This example demonstrates AOT autotuning with a kernel that requires
different configurations for different input shapes. The kernel has
2D block sizes that have different optimal values for:
- Tall-and-skinny matrices (M >> N): Small block_m, large block_n
- Short-and-wide matrices (M << N): Large block_m, small block_n
- Square matrices: Balanced block sizes

This generates heuristics with actual decision tree logic rather than
a single config.

Usage:
python -m helion.autotuner.aot_runner --benchmark "python examples/aot_multiconfig_example.py"
"""

from __future__ import annotations

import os

import torch

import helion
from helion._testing import DEVICE
import helion.language as hl


@helion.kernel()
def row_softmax(x: torch.Tensor) -> torch.Tensor:
"""
Row-wise softmax with explicit 2D tiling.

The optimal block sizes depend on the matrix shape:
- Tall matrices benefit from larger row tiles
- Wide matrices benefit from larger column tiles
"""
m, n = x.size()
out = torch.empty_like(x)

# Explicit block size registration allows tuning
block_m = hl.register_block_size(m)
block_n = hl.register_block_size(n)

for tile_m in hl.tile(m, block_size=block_m):
# First pass: compute max and sum for numerical stability
row_max = hl.full([tile_m], float("-inf"), dtype=torch.float32)
row_sum = hl.zeros([tile_m], dtype=torch.float32)

for tile_n in hl.tile(n, block_size=block_n):
values = x[tile_m, tile_n].to(torch.float32)
local_max = torch.amax(values, dim=1)
new_max = torch.maximum(row_max, local_max)
# Rescale previous sum and add new contributions
row_sum = row_sum * torch.exp(row_max - new_max) + torch.sum(
torch.exp(values - new_max[:, None]), dim=1
)
row_max = new_max

# Second pass: compute softmax output
for tile_n in hl.tile(n, block_size=block_n):
values = x[tile_m, tile_n].to(torch.float32)
out[tile_m, tile_n] = (
torch.exp(values - row_max[:, None]) / row_sum[:, None]
).to(out.dtype)

return out
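

# The first pass relies on the online-softmax rescaling identity: for an
# updated running max M' = max(M, m_local),
#     sum_j exp(x_j - M') == exp(M - M') * sum_j exp(x_j - M)
# which is why row_sum is multiplied by exp(row_max - new_max) before new
# contributions are added. A sketch of an equivalence check against eager
# softmax (hypothetical helper, not in the original diff):
def _check_row_softmax_reference() -> None:
    x = torch.randn(256, 1024, device=DEVICE, dtype=torch.float16)
    expected = torch.softmax(x.to(torch.float32), dim=1).to(x.dtype)
    torch.testing.assert_close(row_softmax(x), expected, rtol=1e-2, atol=1e-2)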


@helion.kernel()
def col_reduce_sum(x: torch.Tensor) -> torch.Tensor:
"""
Column-wise sum reduction with 2D tiling.

For tall matrices, we want to process many rows in parallel.
For wide matrices, we want larger column blocks.
"""
m, n = x.size()
out = torch.zeros(n, dtype=x.dtype, device=x.device)

block_m = hl.register_block_size(m)
block_n = hl.register_block_size(n)

for tile_n in hl.tile(n, block_size=block_n):
col_acc = hl.zeros([tile_n], dtype=torch.float32)

for tile_m in hl.tile(m, block_size=block_m):
col_acc += torch.sum(x[tile_m, tile_n].to(torch.float32), dim=0)

out[tile_n] = col_acc.to(out.dtype)

return out


def benchmark_kernels() -> None:
"""Run benchmarks on various shapes and dtypes."""
print(f"AOT Mode: {os.environ.get('HELION_AOT_MODE', 'disabled')}")
print(f"AOT Data Dir: {os.environ.get('HELION_AOT_DATA_DIR', 'N/A')}")
print()

# Define shapes covering different aspect ratios
shapes = [
# Tall and skinny (M >> N)
(8192, 64),
(4096, 128),
(2048, 256),
# Square-ish
(1024, 1024),
(2048, 512),
(512, 2048),
# Short and wide (M << N)
(256, 2048),
(128, 4096),
(64, 8192),
]

# Test multiple dtypes - different dtypes often need different tile sizes
dtypes = [torch.float16, torch.float32] # , torch.bfloat16]

print("=== row_softmax kernel ===")
print("Testing across shapes and dtypes:")
for dtype in dtypes:
print(f"\n dtype={dtype}:")
for m, n in shapes:
x = torch.randn(m, n, device=DEVICE, dtype=dtype)
result = row_softmax(x)
# Verify softmax property: each row sums to 1
row_sums = result.sum(dim=1)
avg_sum = row_sums.mean().item()
print(f" Shape ({m:5d}, {n:5d}): row_sum mean = {avg_sum:.4f}")

# print()
# print("=== col_reduce_sum kernel ===")
# print("Testing across shapes and dtypes:")
# for dtype in dtypes:
# print(f"\n dtype={dtype}:")
# for m, n in shapes:
# x = torch.randn(m, n, device=DEVICE, dtype=dtype)
# result = col_reduce_sum(x)
# # Compare with torch reference
# expected = x.sum(dim=0)
# max_diff = (result - expected).abs().max().item()
# print(f" Shape ({m:5d}, {n:5d}): max_diff = {max_diff:.6f}")


def main() -> None:
"""Main entry point."""
aot_mode = os.environ.get("HELION_AOT_MODE", "disabled")

if aot_mode == "disabled":
print("Running in normal mode (no AOT)")
print("Set HELION_AOT_MODE=collect|measure|evaluate to enable AOT workflow")
print()

# Enable AOT cache if in AOT mode
if aot_mode != "disabled":
os.environ["HELION_AUTOTUNE_CACHE"] = "AOTAutotuneCache"

benchmark_kernels()

if aot_mode != "disabled":
print()
print(f"AOT {aot_mode} phase completed!")
data_dir = os.environ.get("HELION_AOT_DATA_DIR", ".helion_aot")
print(f"Data saved to: {data_dir}")


if __name__ == "__main__":
main()
2 changes: 2 additions & 0 deletions helion/autotuner/__init__.py
@@ -1,5 +1,6 @@
from __future__ import annotations

from .aot_cache import AOTAutotuneCache as AOTAutotuneCache
from .config_fragment import BooleanFragment as BooleanFragment
from .config_fragment import EnumFragment as EnumFragment
from .config_fragment import IntegerFragment as IntegerFragment
@@ -34,4 +35,5 @@
cache_classes = {
"LocalAutotuneCache": LocalAutotuneCache,
"StrictLocalAutotuneCache": StrictLocalAutotuneCache,
"AOTAutotuneCache": AOTAutotuneCache,
}
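
Registering AOTAutotuneCache here is what lets the examples select it by name via HELION_AUTOTUNE_CACHE=AOTAutotuneCache. A minimal sketch of that name-based lookup, assuming the env-var convention used in the examples (Helion's actual resolution code may differ):

    import os

    from helion.autotuner import cache_classes

    # Resolve the cache class by name; "LocalAutotuneCache" is an assumed default.
    cache_cls = cache_classes[os.environ.get("HELION_AUTOTUNE_CACHE", "LocalAutotuneCache")]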