From 98e81e01ca41bdadf0a91c1063512e0fa457d7d8 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Sat, 20 Dec 2025 11:39:52 -0800
Subject: [PATCH] [mxfp8 moe training] update readme with kernel microbenchmarks for dsv3

stack-info: PR: https://github.com/pytorch/ao/pull/3521, branch: danielvegamyhre/stack/90
---
 torchao/prototype/moe_training/README.md | 45 ++++++++++++++++++++++++
 1 file changed, 45 insertions(+)

diff --git a/torchao/prototype/moe_training/README.md b/torchao/prototype/moe_training/README.md
index f88a151044..664fd9f915 100644
--- a/torchao/prototype/moe_training/README.md
+++ b/torchao/prototype/moe_training/README.md
@@ -222,6 +222,51 @@ cd benchmarks/prototype/moe_training/mxfp8
 python roofline_unified.py --K=7168 --N=2048 --G=8 --power_limit_percent=100 --breakdown_M=131072 --plot_file=dsv3_rooflines.png
 ```
 
+### MXFP8 Kernel Breakdown by Pass
+
+The following table provides a detailed breakdown of all MXFP8 kernels used in the forward and backward passes, with shapes representative of **DeepSeekV3 671B** (dim=7168, hidden_dim=2048, total_tokens=131072, groups=8, block_size=32).
+
+**Environment:**
+- torch: `2.11.0.dev20251216+cu128`
+- torchao: `0.15.0+gitd1305bc78`
+- NVIDIA B200
+
+| Pass | Kernel | Purpose | Input Shape | Time (µs) | Efficiency |
+|------|--------|---------|-------------|-----------|------------|
+| **Forward** | `triton_to_mxfp8_dim0` | Quantize A (activations) along dim0 | (131072, 7168) | 580.6 | 83.3% peak BW |
+| **Forward** | `mxfp8_quantize_cuda_3d` | Quantize B (weights) along dim0 | (8, 2048, 7168) | 76.8 | 78.8% peak BW |
+| **Forward** | `triton_mx_block_rearrange_2d_M_groups` | Convert A scales to blocked format | (131072, 224) | 198.7 | — |
+| **Forward** | `triton_mx_block_rearrange_per_group_3d` | Convert B scales to blocked format | (8, 2048, 224) | 11.4 | — |
+| **Forward** | `torch._scaled_grouped_mm` | 2D-3D scaled grouped GEMM | (131072, 7168) @ (8, 7168, 2048) | 1838.1 | 74.6% peak TFLOPS |
+| **Backward (dA)** | `triton_to_mxfp8_dim0` | Quantize grad_out along dim0 | (131072, 2048) | 166.0 | 83.3% peak BW |
+| **Backward (dA)** | `mxfp8_quantize_cuda_3d` | Quantize B along dim1 (N dimension) | (8, 2048, 7168) | 76.8 | 78.8% peak BW |
+| **Backward (dA)** | `triton_mx_block_rearrange_2d_M_groups` | Convert grad_out scales to blocked format | (131072, 64) | 192.5 | — |
+| **Backward (dA)** | `triton_mx_block_rearrange_per_group_3d` | Convert B scales to blocked format | (8, 7168, 64) | 11.0 | — |
+| **Backward (dA)** | `torch._scaled_grouped_mm` | 2D-3D scaled grouped GEMM | (131072, 2048) @ (8, 2048, 7168) | 1838.1 | 74.6% peak TFLOPS |
+| **Backward (dB)** | `mxfp8_quantize_cuda` | Quantize grad_out along dim1 (colwise) | (131072, 2048) | 191.7 | 72.1% peak BW |
+| **Backward (dB)** | `mxfp8_quantize_cuda` | Quantize A along dim1 (colwise) | (131072, 7168) | 670.7 | 72.1% peak BW |
+| **Backward (dB)** | `mx_block_rearrange_2d_K_groups_cuda` | Convert grad_out_t scales to blocked format | (2048, 4096) | 17.4 | — |
+| **Backward (dB)** | `mx_block_rearrange_2d_K_groups_cuda` | Convert A_t scales to blocked format | (7168, 4096) | 31.6 | — |
+| **Backward (dB)** | `torch._scaled_grouped_mm` | 2D-2D scaled grouped GEMM | (2048, 131072) @ (131072, 7168) | 2412.4 | 56.9% peak TFLOPS |
+
+**Notes:**
+- **Efficiency** is reported as a percentage of peak achievable memory bandwidth (for memory-bound quantization kernels) or of peak TFLOPS (for compute-bound GEMM kernels).
+- Scale rearrangement kernels are neither memory-bandwidth-bound nor compute-bound in the conventional sense, so only absolute runtime is reported.
+- Scale tensor shapes are derived from the input shapes by dividing the scaled dimension by `block_size=32`.
+- A detailed kernel-by-kernel timing breakdown is available in the roofline plots above (generated by `roofline_unified.py`).
+- All kernels can be benchmarked individually using the scripts in `benchmarks/prototype/moe_training/mxfp8/`.
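+
+As a rough illustration of how the "% peak BW" numbers for the memory-bound quantization kernels are derived, the sketch below times a stand-in dim0 quantization with CUDA events and divides the achieved bytes/s by an assumed peak bandwidth. The `bench_gpu_us` and `quantize_dim0_standin` helpers, the byte accounting, and the nominal 8 TB/s B200 bandwidth figure are illustrative assumptions; they are not taken from the torchao benchmark scripts listed below.
+
+```python
+# Minimal sketch (assumptions noted above): time a rowwise quantization of a
+# DeepSeekV3-sized activation tensor and report achieved vs. assumed peak bandwidth.
+import torch
+
+def bench_gpu_us(fn, *args, warmup=10, iters=50):
+    """Median CUDA-event time of fn(*args), in microseconds."""
+    for _ in range(warmup):
+        fn(*args)
+    torch.cuda.synchronize()
+    times = []
+    for _ in range(iters):
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+        start.record()
+        fn(*args)
+        end.record()
+        torch.cuda.synchronize()
+        times.append(start.elapsed_time(end) * 1e3)  # ms -> us
+    return sorted(times)[len(times) // 2]
+
+def quantize_dim0_standin(x: torch.Tensor, block_size: int = 32):
+    """Stand-in for a rowwise MXFP8 quantization kernel (e.g. triton_to_mxfp8_dim0):
+    one scale per block_size elements along the last dim, fp8 e4m3 data.
+    Numerics and performance are illustrative only."""
+    blocks = x.view(x.shape[0], -1, block_size)
+    scales = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
+    data = (blocks / scales).to(torch.float8_e4m3fn).view_as(x)
+    return data, scales
+
+M, K, block_size = 131072, 7168, 32  # DeepSeekV3-like activation shape
+x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+t_us = bench_gpu_us(quantize_dim0_standin, x)
+
+# Bytes moved: read bf16 input (2 B/elem), write fp8 output (1 B/elem),
+# plus one 1-byte scale per 32-element block.
+bytes_moved = M * K * (2 + 1) + (M * K // block_size)
+achieved_bw = bytes_moved / (t_us * 1e-6)
+peak_bw = 8.0e12  # nominal B200 HBM bandwidth (assumed); adjust for your hardware
+print(f"{t_us:.1f} us, {achieved_bw / peak_bw:.1%} of assumed peak bandwidth")
+```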
+
+**Benchmark Scripts:**
+| Kernel Type | Benchmark Script |
+|-------------|------------------|
+| 2D Quantization (dim0/dim1) | `benchmarks/mx_formats/cast_bench.py` |
+| 3D Quantization | `benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py` |
+| 2D M-groups Scale Rearrange | `benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_2d_M_groups.py` |
+| 2D K-groups Scale Rearrange | `benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_2d_K_groups.py` |
+| 3D Per-group Scale Rearrange | `benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_per_group_3d.py` |
+| Grouped GEMM (2D-3D, 2D-2D) | `benchmarks/prototype/moe_training/bench_2d_3d_grouped_gemm.py` |
+| Unified Roofline Analysis | `benchmarks/prototype/moe_training/mxfp8/roofline_unified.py` |
+
 ## Benchmark: single MoE layer forward + backward pass
 
 | Model | total_M | N | K | bf16 time (ms) | mxfp8 time (ms) | speedup |