12 changes: 5 additions & 7 deletions torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp
@@ -114,13 +114,11 @@ mxfp8_quantize(const at::Tensor& input, bool rowwise, bool colwise,
   if (colwise) {
     const int64_t num_row_blocks = (rows + scale_dim_y - 1) / scale_dim_y;
     output_colwise = at::empty_strided({rows, cols}, {1, rows}, options_fp8);
-    // Need scales_colwise to be this shape so the 'col' dim stride is 1,
-    // for colwise scaling, we can avoid uncoalesced writes to global memory.
-    // This is because each of the 32 threads in a warp will be computing
-    // a scale for a different column of 32 input data values, then each writing
-    // that scale to global memory - so the stride along this `col` dim should be 1
-    // so writes can be coalesced into a single transaction.
-    scales_colwise = at::empty_strided({cols, num_row_blocks}, {1, cols}, options_scale);
+
+    // Accept uncoalesced global stores for the scale tensor, since row major is much more
+    // favorable for the subsequent per-group blocked format kernel.
+    // Microbenchmarks show the memory bandwidth utilization is virtually identical to coalesced global stores.
+    scales_colwise = at::empty({cols, num_row_blocks}, options_scale);
   } else {
     output_colwise = at::empty({0}, options_fp8);
     scales_colwise = at::empty({0}, options_scale);
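
To make the layout trade-off in the new comment concrete, here is a minimal standalone sketch (Python, with hypothetical sizes; assumes a PyTorch build that has `torch.float8_e8m0fnu`) contrasting the previous column-major scale allocation with the new row-major one:

```python
import torch

# Hypothetical sizes for illustration only.
cols, num_row_blocks = 128, 4

# Previous allocation: column-major, so the 32 scale writes from a warp are
# adjacent along the `col` dim and coalesce into a single global transaction.
scales_colmajor = torch.empty_strided(
    (cols, num_row_blocks), (1, cols), dtype=torch.float8_e8m0fnu
)

# New allocation: default row-major. The stores are uncoalesced, but this is
# the layout the subsequent per-group blocked format kernel consumes.
scales_rowmajor = torch.empty((cols, num_row_blocks), dtype=torch.float8_e8m0fnu)

print(scales_colmajor.stride())  # (1, 128) -> column-major
print(scales_rowmajor.stride())  # (4, 1)   -> row-major
```

Only the strides differ; element count and dtype are identical, so the change trades coalesced scale stores for a layout the downstream kernel can read contiguously.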
1 change: 1 addition & 0 deletions torchao/prototype/moe_training/kernels/mxfp8/__init__.py
@@ -1,4 +1,5 @@
 from torchao.prototype.moe_training.kernels.mxfp8.quant import (
+    mx_block_rearrange_2d_K_groups_cuda,  # noqa: F401
     mxfp8_quantize_cuda_3d,  # noqa: F401
     torch_to_blocked_2d_K_groups,  # noqa: F401
     torch_to_blocked_2d_M_groups,  # noqa: F401
6 changes: 3 additions & 3 deletions torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -17,8 +17,8 @@
     triton_fp8_rowwise_3d_transpose_rhs,
 )
 from torchao.prototype.moe_training.kernels.mxfp8 import (
+    mx_block_rearrange_2d_K_groups_cuda,
     mxfp8_quantize_cuda_3d,
-    triton_mx_block_rearrange_2d_K_groups,
     triton_mx_block_rearrange_2d_M_groups,
     triton_mx_block_rearrange_per_group_3d,
 )
@@ -449,11 +449,11 @@ def backward(ctx, grad_out: torch.Tensor):
 
         # Convert scales to blocked format for 2d-2d grouped mm
         scale_group_offsets = offs // block_size
-        grad_out_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups(
+        grad_out_t_scales_blocked = mx_block_rearrange_2d_K_groups_cuda(
             grad_out_t_scales,
             scale_group_offsets,
         )
-        A_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups(
+        A_t_scales_blocked = mx_block_rearrange_2d_K_groups_cuda(
             A_t_scales,
             scale_group_offsets,
         )
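
For reference, a hedged usage sketch of the swapped-in CUDA kernel: the call shape and import are taken from the diff itself, while the tensor shapes, dtype, and offsets below are invented for illustration.

```python
import torch
from torchao.prototype.moe_training.kernels.mxfp8 import (
    mx_block_rearrange_2d_K_groups_cuda,
)

block_size = 32  # MX block size along K

# Hypothetical per-group end offsets along K (e.g., tokens per expert group).
offs = torch.tensor([1024, 2048], dtype=torch.int32, device="cuda")
scale_group_offsets = offs // block_size  # offsets in scale columns, not elements

# Hypothetical e8m0 scales for a (512, 2048) operand quantized along K.
A_t_scales = torch.empty(
    512, 2048 // block_size, dtype=torch.float8_e8m0fnu, device="cuda"
)

# Same call shape as the backward pass above.
A_t_scales_blocked = mx_block_rearrange_2d_K_groups_cuda(
    A_t_scales, scale_group_offsets
)
```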
5 changes: 2 additions & 3 deletions torchao/prototype/mx_formats/kernels.py
@@ -1219,10 +1219,9 @@ def _fake_mxfp8_quantize(
         (rows, cols), (1, rows), dtype=torch.float8_e4m3fn, device=x.device
     )
 
-    # colwise scales are written in column-major format to avoid uncoalesced global memory accesses
-    scales_colwise = torch.empty_strided(
+    # Scales are row-major to favor the subsequent blocked format kernel; microbenchmarks show bandwidth is virtually identical.
+    scales_colwise = torch.empty(
         (cols, num_row_blocks),
-        (1, cols),
         dtype=torch.float8_e8m0fnu,
         device=x.device,
     )
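
The fake implementation only has to report output metadata (shapes, strides, dtypes) that matches the real CUDA kernel for torch.compile tracing, which is why it changes in lockstep with the `at::empty` switch above. A condensed, hypothetical sketch of just the colwise outputs:

```python
import torch

def fake_colwise_outputs(x: torch.Tensor, block_size: int = 32):
    # Hypothetical condensation of _fake_mxfp8_quantize's colwise branch.
    rows, cols = x.shape
    num_row_blocks = (rows + block_size - 1) // block_size
    # fp8 data stays column-major, matching the real kernel's empty_strided.
    output_colwise = torch.empty_strided(
        (rows, cols), (1, rows), dtype=torch.float8_e4m3fn, device=x.device
    )
    # Scales are now plain row-major, matching the new at::empty in the C++ above.
    scales_colwise = torch.empty(
        (cols, num_row_blocks), dtype=torch.float8_e8m0fnu, device=x.device
    )
    return output_colwise, scales_colwise
```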