```
cd benchmarks/prototype/moe_training/mxfp8
python roofline_unified.py --K=7168 --N=2048 --G=8 --power_limit_percent=100 --breakdown_M=131072 --plot_file=dsv3_rooflines.png
```

### MXFP8 Kernel Breakdown by Pass

The following table provides a detailed breakdown of all MXFP8 kernels used in the forward and backward passes, with shapes representative of **DeepSeekV3 671B** (dim=7168, hidden_dim=2048, total_tokens=131072, groups=8, block_size=32).

**Environment:**
- torch: `2.11.0.dev20251216+cu128`
- torchao: `0.15.0+gitd1305bc78`
- NVIDIA B200

| Pass | Kernel | Purpose | Input Shape | Time (µs) | Efficiency |
|------|--------|---------|-------------|-----------|------------|
| **Forward** | `triton_to_mxfp8_dim0` | Quantize A (activations) along dim0 | (131072, 7168) | 580.6 | 83.3% peak BW |
| **Forward** | `mxfp8_quantize_cuda_3d` | Quantize B (weights) along dim0 | (8, 2048, 7168) | 76.8 | 78.8% peak BW |
| **Forward** | `triton_mx_block_rearrange_2d_M_groups` | Convert A scales to blocked format | (131072, 224) | 198.7 | — |
| **Forward** | `triton_mx_block_rearrange_per_group_3d` | Convert B scales to blocked format | (8, 2048, 224) | 11.4 | — |
| **Forward** | `torch._scaled_grouped_mm` | 2D-3D scaled grouped GEMM | (131072, 7168) @ (8, 7168, 2048) | 1838.1 | 74.6% peak TFLOPS |
| **Backward (dA)** | `triton_to_mxfp8_dim0` | Quantize grad_out along dim0 | (131072, 2048) | 166.0 | 83.3% peak BW |
| **Backward (dA)** | `mxfp8_quantize_cuda_3d` | Quantize B along dim1 (N dimension) | (8, 2048, 7168) | 76.8 | 78.8% peak BW |
| **Backward (dA)** | `triton_mx_block_rearrange_2d_M_groups` | Convert grad_out scales to blocked format | (131072, 64) | 192.5 | — |
| **Backward (dA)** | `triton_mx_block_rearrange_per_group_3d` | Convert B scales to blocked format | (8, 7168, 64) | 11.0 | — |
| **Backward (dA)** | `torch._scaled_grouped_mm` | 2D-3D scaled grouped GEMM | (131072, 2048) @ (8, 2048, 7168) | 1838.1 | 74.6% peak TFLOPS |
| **Backward (dB)** | `mxfp8_quantize_cuda` | Quantize grad_out along dim1 (colwise) | (131072, 2048) | 191.7 | 72.1% peak BW |
| **Backward (dB)** | `mxfp8_quantize_cuda` | Quantize A along dim1 (colwise) | (131072, 7168) | 670.7 | 72.1% peak BW |
| **Backward (dB)** | `mx_block_rearrange_2d_K_groups_cuda` | Convert grad_out_t scales to blocked format | (2048, 4096) | 17.4 | — |
| **Backward (dB)** | `mx_block_rearrange_2d_K_groups_cuda` | Convert A_t scales to blocked format | (7168, 4096) | 31.6 | — |
| **Backward (dB)** | `torch._scaled_grouped_mm` | 2D-2D scaled grouped GEMM | (2048, 131072) @ (131072, 7168) | 2412.4 | 56.9% peak TFLOPS |
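The quantization kernels in the table all compute the same MX-style transform: split each row into `block_size=32` element blocks, derive a power-of-two (e8m0) scale per block, and cast the scaled elements to `float8_e4m3fn`. Below is a minimal pure-PyTorch reference of that transform, an illustrative sketch only (the fused Triton/CUDA kernels benchmarked above are not implemented like this, and `ref_to_mxfp8_dim0` is a hypothetical name, not part of torchao's API):

```python
import torch

def ref_to_mxfp8_dim0(x: torch.Tensor, block_size: int = 32):
    """Reference MXFP8 quantization with blocks along the last dim.

    Returns fp8 data shaped like `x` and power-of-two scales of shape
    (M, K // block_size), matching the scale shapes in the table above.
    """
    M, K = x.shape
    blocks = x.reshape(M, K // block_size, block_size).float()
    amax = blocks.abs().amax(dim=-1).clamp(min=2**-127)  # avoid log2(0)
    # OCP MX convention: shared exponent = floor(log2(amax)) - emax,
    # where emax = 8 for float8_e4m3fn.
    scale = torch.exp2(torch.floor(torch.log2(amax)) - 8)  # (M, K // 32)
    # Clamp to the e4m3 representable range before casting.
    q = (blocks / scale.unsqueeze(-1)).clamp(-448.0, 448.0)
    return q.to(torch.float8_e4m3fn).reshape(M, K), scale

# Example with the DeepSeekV3 activation shape from the table:
# A = torch.randn(131072, 7168, device="cuda", dtype=torch.bfloat16)
# A_fp8, A_scales = ref_to_mxfp8_dim0(A)   # scales: (131072, 224)
```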

**Notes:**
- **Efficiency** is reported as a percentage of peak achievable memory bandwidth (for memory-bound quantization kernels) or of peak TFLOPS (for compute-bound GEMM kernels)
- Scale rearrangement kernels are neither conventionally memory-bandwidth-bound nor compute-bound, so we report absolute runtime only
- Scale tensor shapes are derived from the input shapes by dividing the scaled dimension by `block_size=32` (see the worked example after this list)
- A detailed per-kernel timing breakdown is available in the roofline plots above (generated by `roofline_unified.py`)
- All kernels can be benchmarked individually using the scripts in `benchmarks/prototype/moe_training/mxfp8/`
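As a worked check of the scale-shape rule, here is how it plays out for the forward-pass operands (shapes taken from the table above):

```python
block_size = 32
A_shape = (131072, 7168)      # activations
B_shape = (8, 2048, 7168)     # expert weights

A_scales = (A_shape[0], A_shape[1] // block_size)       # (131072, 224)
B_scales = (*B_shape[:-1], B_shape[-1] // block_size)   # (8, 2048, 224)
```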

**Benchmark Scripts:**
| Kernel Type | Benchmark Script |
|-------------|------------------|
| 2D Quantization (dim0/dim1) | `benchmarks/mx_formats/cast_bench.py` |
| 3D Quantization | `benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py` |
| 2D M-groups Scale Rearrange | `benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_2d_M_groups.py` |
| 2D K-groups Scale Rearrange | `benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_2d_K_groups.py` |
| 3D Per-group Scale Rearrange | `benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_per_group_3d.py` |
| Grouped GEMM (2D-3D, 2D-2D) | `benchmarks/prototype/moe_training/bench_2d_3d_grouped_gemm.py` |
| Unified Roofline Analysis | `benchmarks/prototype/moe_training/mxfp8/roofline_unified.py` |
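For a quick standalone measurement outside these scripts, one common pattern is timing a kernel with `triton.testing.do_bench` (a real Triton utility). The sketch below reuses the hypothetical `ref_to_mxfp8_dim0` reference from earlier and a rough traffic model; the actual benchmark scripts time the fused kernels and may use different methodology:

```python
import torch
from triton.testing import do_bench

x = torch.randn(131072, 7168, device="cuda", dtype=torch.bfloat16)

# Mean kernel time in milliseconds over repeated runs.
ms = do_bench(lambda: ref_to_mxfp8_dim0(x))

# Rough achieved bandwidth: read bf16 input + write fp8 output + scales
# (the reference keeps float32 scales; the fused kernels store 1-byte e8m0).
bytes_moved = x.numel() * (2 + 1) + (x.numel() // 32) * 4
print(f"{ms:.3f} ms, ~{bytes_moved / 1e9 / (ms / 1e3):.0f} GB/s")
```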

## Benchmark: single MoE layer forward + backward pass

| Model | total_M | N | K | bf16 time (ms) | mxfp8 time (ms) | speedup |