diff --git a/torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp b/torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp
index 3b49b4c1c0..23b60d9fd7 100644
--- a/torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp
+++ b/torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp
@@ -114,13 +114,11 @@ mxfp8_quantize(const at::Tensor& input, bool rowwise, bool colwise,
   if (colwise) {
     const int64_t num_row_blocks = (rows + scale_dim_y - 1) / scale_dim_y;
     output_colwise = at::empty_strided({rows, cols}, {1, rows}, options_fp8);
-    // Need scales_colwise to be this shape so the 'col' dim stride is 1,
-    // for colwise scaling, we can avoid uncoalesced writes to global memory.
-    // This is because each of the 32 threads in a warp will be computing
-    // a scale for a different column of 32 input data values, then each writing
-    // that scale to global memory - so the stride along this `col` dim should be 1
-    // so writes can be coalesced into a single transaction.
-    scales_colwise = at::empty_strided({cols, num_row_blocks}, {1, cols}, options_scale);
+
+    // Accept uncoalesced global stores for the scale tensor, since row-major is much more favorable for the subsequent
+    // per-group blocked format kernel.
+    // Microbenchmarks show the memory bandwidth utilization is virtually identical to coalesced global stores.
+    scales_colwise = at::empty({cols, num_row_blocks}, options_scale);
   } else {
     output_colwise = at::empty({0}, options_fp8);
     scales_colwise = at::empty({0}, options_scale);
diff --git a/torchao/prototype/moe_training/kernels/mxfp8/__init__.py b/torchao/prototype/moe_training/kernels/mxfp8/__init__.py
index 69bff9e1a0..d7d41aba1e 100644
--- a/torchao/prototype/moe_training/kernels/mxfp8/__init__.py
+++ b/torchao/prototype/moe_training/kernels/mxfp8/__init__.py
@@ -1,4 +1,5 @@
 from torchao.prototype.moe_training.kernels.mxfp8.quant import (
+    mx_block_rearrange_2d_K_groups_cuda,  # noqa: F401
     mxfp8_quantize_cuda_3d,  # noqa: F401
     torch_to_blocked_2d_K_groups,  # noqa: F401
     torch_to_blocked_2d_M_groups,  # noqa: F401
diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py
index accc8853cb..98a5fd39e5 100644
--- a/torchao/prototype/moe_training/scaled_grouped_mm.py
+++ b/torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -17,8 +17,8 @@
     triton_fp8_rowwise_3d_transpose_rhs,
 )
 from torchao.prototype.moe_training.kernels.mxfp8 import (
+    mx_block_rearrange_2d_K_groups_cuda,
     mxfp8_quantize_cuda_3d,
-    triton_mx_block_rearrange_2d_K_groups,
     triton_mx_block_rearrange_2d_M_groups,
     triton_mx_block_rearrange_per_group_3d,
 )
@@ -449,11 +449,11 @@ def backward(ctx, grad_out: torch.Tensor):
 
         # Convert scales to blocked format for 2d-2d grouped mm
         scale_group_offsets = offs // block_size
-        grad_out_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups(
+        grad_out_t_scales_blocked = mx_block_rearrange_2d_K_groups_cuda(
             grad_out_t_scales,
             scale_group_offsets,
         )
-        A_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups(
+        A_t_scales_blocked = mx_block_rearrange_2d_K_groups_cuda(
             A_t_scales,
             scale_group_offsets,
         )
diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py
index 236aa3db53..061502060e 100644
--- a/torchao/prototype/mx_formats/kernels.py
+++ b/torchao/prototype/mx_formats/kernels.py
@@ -1219,10 +1219,10 @@ def _fake_mxfp8_quantize(
         (rows, cols), (1, rows), dtype=torch.float8_e4m3fn, device=x.device
     )
 
-    # colwise scales are written in column-major format to avoid uncoalesced global memory accesses
-    scales_colwise = torch.empty_strided(
+    # Row-major scales are much more favorable for the subsequent per-group blocked format kernel,
+    # and microbenchmarks show memory bandwidth utilization is virtually identical to coalesced stores.
+    scales_colwise = torch.empty(
         (cols, num_row_blocks),
-        (1, cols),
         dtype=torch.float8_e8m0fnu,
         device=x.device,
     )
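A minimal sketch of the layout change for the colwise scale tensor (illustrative only, not part of the diff; the sample sizes for `cols` and `num_row_blocks` are assumptions, while the shapes, strides, and dtype mirror the kernels above):

    import torch

    # Hypothetical sizes for illustration.
    cols, num_row_blocks = 4096, 128

    # Before: column-major scales (stride 1 along the `cols` dim), so the 32
    # threads of a warp, each writing a scale for a different column, coalesce
    # their stores into a single transaction.
    scales_colmajor = torch.empty_strided(
        (cols, num_row_blocks), (1, cols), dtype=torch.float8_e8m0fnu
    )

    # After: contiguous row-major scales; the stores are uncoalesced, but the
    # layout matches what the per-group blocked format kernel consumes next.
    scales_rowmajor = torch.empty((cols, num_row_blocks), dtype=torch.float8_e8m0fnu)

    assert scales_colmajor.stride() == (1, cols)
    assert scales_rowmajor.stride() == (num_row_blocks, 1)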