From e099a71b0a4f94662ba191b196fd72880064d426 Mon Sep 17 00:00:00 2001 From: Kunal Mansukhani Date: Sun, 5 Apr 2026 15:33:14 -0700 Subject: [PATCH 1/2] Add FP4 matmul hard challenge Weight-only FP4 E2M1 quantized matmul (W4A16) with group-wise FP16 scales, the kernel powering low-precision LLM inference on Hopper and Blackwell. Two FP4 values are packed per uint8 byte; each contiguous block of group_size weights along K shares one scale. Co-Authored-By: Claude Opus 4.6 (1M context) --- challenges/hard/86_fp4_matmul/challenge.html | 113 ++++++++++++ challenges/hard/86_fp4_matmul/challenge.py | 174 ++++++++++++++++++ .../hard/86_fp4_matmul/starter/starter.cu | 7 + .../86_fp4_matmul/starter/starter.cute.py | 17 ++ .../hard/86_fp4_matmul/starter/starter.jax.py | 17 ++ .../hard/86_fp4_matmul/starter/starter.mojo | 17 ++ .../86_fp4_matmul/starter/starter.pytorch.py | 15 ++ .../86_fp4_matmul/starter/starter.triton.py | 17 ++ 8 files changed, 377 insertions(+) create mode 100644 challenges/hard/86_fp4_matmul/challenge.html create mode 100644 challenges/hard/86_fp4_matmul/challenge.py create mode 100644 challenges/hard/86_fp4_matmul/starter/starter.cu create mode 100644 challenges/hard/86_fp4_matmul/starter/starter.cute.py create mode 100644 challenges/hard/86_fp4_matmul/starter/starter.jax.py create mode 100644 challenges/hard/86_fp4_matmul/starter/starter.mojo create mode 100644 challenges/hard/86_fp4_matmul/starter/starter.pytorch.py create mode 100644 challenges/hard/86_fp4_matmul/starter/starter.triton.py diff --git a/challenges/hard/86_fp4_matmul/challenge.html b/challenges/hard/86_fp4_matmul/challenge.html new file mode 100644 index 00000000..4653167a --- /dev/null +++ b/challenges/hard/86_fp4_matmul/challenge.html @@ -0,0 +1,113 @@ +

+ Implement an FP4 weight-only quantized matrix multiplication, the kernel at the heart of
+ modern low-precision LLM inference on Hopper and Blackwell GPUs. Given a float16 activation
+ matrix x of shape M × K and a weight matrix stored in packed
+ FP4 E2M1 format, compute y = x × W^T of shape
+ M × N, where W is the dequantized float16 weight matrix of
+ shape N × K.

+ +

+ FP4 E2M1 format: Each weight is encoded in 4 bits as + [sign | exponent (2 bits) | mantissa (1 bit)], representing one of sixteen values: + {±0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6}. + The nibble-to-value mapping is: +

+
+  0x0 =  0.0    0x8 = -0.0
+  0x1 =  0.5    0x9 = -0.5
+  0x2 =  1.0    0xA = -1.0
+  0x3 =  1.5    0xB = -1.5
+  0x4 =  2.0    0xC = -2.0
+  0x5 =  3.0    0xD = -3.0
+  0x6 =  4.0    0xE = -4.0
+  0x7 =  6.0    0xF = -6.0
+
+ +

+ Packing: Each byte of w_q stores two FP4 weights. The high + nibble (bits 7–4) holds w[n, 2i] and the low nibble (bits 3–0) holds + w[n, 2i+1]. +

+ +

+ Dequantization: Weights are dequantized group-wise. Each contiguous block of + group_size weights along the K dimension shares one float16 scale: +

+
+W[n, k] = fp4_decode(w_q_nibble[n, k]) * scales[n, k // group_size]
+
+ +
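The formula above can be read directly as a few lines of PyTorch. The sketch below is illustrative only (the helper name dequantize_w4 is introduced here and is not part of the challenge); it follows the nibble layout and group-wise float16 scales described above, mirroring the reference implementation later in this patch.

import torch

FP4_E2M1_TABLE = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                  -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

def dequantize_w4(w_q: torch.Tensor, scales: torch.Tensor, group_size: int) -> torch.Tensor:
    # w_q: (N, K//2) uint8 packed FP4; scales: (N, K//group_size) float16 -> (N, K) float32
    n, k = w_q.shape[0], w_q.shape[1] * 2
    table = torch.tensor(FP4_E2M1_TABLE, device=w_q.device, dtype=torch.float32)
    high = table[((w_q >> 4) & 0xF).long()]   # w[n, 2i]   (high nibble)
    low = table[(w_q & 0xF).long()]           # w[n, 2i+1] (low nibble)
    w_fp4 = torch.stack([high, low], dim=-1).reshape(n, k)
    # Each contiguous block of group_size weights along K shares one scale.
    return w_fp4 * scales.float().repeat_interleave(group_size, dim=1)

# y = (x.float() @ dequantize_w4(w_q, scales, group_size).T).half()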

Implementation Requirements

+
+  • Use only native features (external libraries are not permitted)
+  • The solve function signature must remain unchanged
+  • The final result must be stored in y
+

Example

+

+ Input (M = 2, N = 4, K = 4, group_size = 2): +

+

+ Activations \(x\) (float16, \(2 \times 4\)): + \[ + \begin{bmatrix} + 1.0 & 0.0 & 1.0 & 0.0 \\ + 0.0 & 1.0 & 0.0 & 1.0 + \end{bmatrix} + \] + Packed weights \(w\_q\) (uint8, \(4 \times 2\)) decoded via the FP4 E2M1 table: + \[ + \begin{bmatrix} + \texttt{0x22} & \texttt{0x22} \\ + \texttt{0x44} & \texttt{0x44} \\ + \texttt{0xAA} & \texttt{0xAA} \\ + \texttt{0x00} & \texttt{0x00} + \end{bmatrix} + \;\Rightarrow\; + W_{\text{fp4}} = + \begin{bmatrix} + 1.0 & 1.0 & 1.0 & 1.0 \\ + 2.0 & 2.0 & 2.0 & 2.0 \\ + -1.0 & -1.0 & -1.0 & -1.0 \\ + 0.0 & 0.0 & 0.0 & 0.0 + \end{bmatrix} + \] + Scales (float16, \(4 \times 2\), all entries 0.5): + \[ + \begin{bmatrix} + 0.5 & 0.5 \\ + 0.5 & 0.5 \\ + 0.5 & 0.5 \\ + 0.5 & 0.5 + \end{bmatrix} + \;\Rightarrow\; + W_{\text{dequant}} = + \begin{bmatrix} + 0.5 & 0.5 & 0.5 & 0.5 \\ + 1.0 & 1.0 & 1.0 & 1.0 \\ + -0.5 & -0.5 & -0.5 & -0.5 \\ + 0.0 & 0.0 & 0.0 & 0.0 + \end{bmatrix} + \] + Output \(y = x \times W^T\) (float16, \(2 \times 4\)): + \[ + \begin{bmatrix} + 1.0 & 2.0 & -1.0 & 0.0 \\ + 1.0 & 2.0 & -1.0 & 0.0 + \end{bmatrix} + \] +

+ +

Constraints

+ diff --git a/challenges/hard/86_fp4_matmul/challenge.py b/challenges/hard/86_fp4_matmul/challenge.py new file mode 100644 index 00000000..b07f6a8d --- /dev/null +++ b/challenges/hard/86_fp4_matmul/challenge.py @@ -0,0 +1,174 @@ +import ctypes +from typing import Any, Dict, List + +import torch +from core.challenge_base import ChallengeBase + +# OCP FP4 E2M1 lookup table: 4-bit unsigned index -> float value. +# Bit layout: [sign | exp1 exp0 | mantissa]. Sixteen representable values. +FP4_E2M1_TABLE = [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, +] + + +class Challenge(ChallengeBase): + def __init__(self): + super().__init__( + name="FP4 MatMul", + atol=5e-02, + rtol=5e-02, + num_gpus=1, + access_tier="free", + ) + + def reference_impl( + self, + x: torch.Tensor, + w_q: torch.Tensor, + scales: torch.Tensor, + y: torch.Tensor, + M: int, + N: int, + K: int, + group_size: int, + ): + assert x.shape == (M, K) + assert w_q.shape == (N, K // 2) + assert scales.shape == (N, K // group_size) + assert y.shape == (M, N) + assert x.dtype == torch.float16 + assert w_q.dtype == torch.uint8 + assert scales.dtype == torch.float16 + assert y.dtype == torch.float16 + assert x.device.type == "cuda" + assert w_q.device.type == "cuda" + assert scales.device.type == "cuda" + assert y.device.type == "cuda" + + # Decode packed FP4 E2M1 nibbles via lookup table. + # w_q[n, i] holds two FP4 values: w[n, 2*i] in the high nibble (bits 7:4) + # and w[n, 2*i+1] in the low nibble (bits 3:0). + table = torch.tensor(FP4_E2M1_TABLE, device=x.device, dtype=torch.float32) + high = ((w_q >> 4) & 0xF).to(torch.long) # [N, K//2] + low = (w_q & 0xF).to(torch.long) # [N, K//2] + w_high = table[high] # [N, K//2] + w_low = table[low] # [N, K//2] + w_fp4 = torch.stack([w_high, w_low], dim=-1).reshape(N, K) # [N, K] + + # Apply group-wise FP16 scales: each contiguous block of `group_size` + # weights along K shares one scale. + n_groups = K // group_size + w_groups = w_fp4.reshape(N, n_groups, group_size) + scales_f = scales.float().unsqueeze(-1) # [N, n_groups, 1] + w_dequant = (w_groups * scales_f).reshape(N, K) + + y.copy_((x.float() @ w_dequant.T).half()) + + def get_solve_signature(self) -> Dict[str, tuple]: + return { + "x": (ctypes.POINTER(ctypes.c_uint16), "in"), + "w_q": (ctypes.POINTER(ctypes.c_uint8), "in"), + "scales": (ctypes.POINTER(ctypes.c_uint16), "in"), + "y": (ctypes.POINTER(ctypes.c_uint16), "out"), + "M": (ctypes.c_int, "in"), + "N": (ctypes.c_int, "in"), + "K": (ctypes.c_int, "in"), + "group_size": (ctypes.c_int, "in"), + } + + def _make_test_case(self, M: int, N: int, K: int, group_size: int, zero_x: bool = False): + device = "cuda" + if zero_x: + x = torch.zeros(M, K, device=device, dtype=torch.float16) + else: + x = torch.randn(M, K, device=device, dtype=torch.float16) * 0.5 + w_q = torch.randint(0, 256, (N, K // 2), dtype=torch.uint8, device=device) + scales = torch.rand(N, K // group_size, device=device, dtype=torch.float16) * 0.1 + 0.01 + y = torch.empty(M, N, device=device, dtype=torch.float16) + return { + "x": x, + "w_q": w_q, + "scales": scales, + "y": y, + "M": M, + "N": N, + "K": K, + "group_size": group_size, + } + + def generate_example_test(self) -> Dict[str, Any]: + device = "cuda" + M, N, K, group_size = 2, 4, 4, 2 + + x = torch.tensor( + [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]], + device=device, + dtype=torch.float16, + ) + # Packed FP4 E2M1 weights (high nibble first). 
+ # Row 0: FP4 [1.0,1.0,1.0,1.0] -> nibbles [0x2,0x2,0x2,0x2] -> bytes [0x22,0x22] = [34,34] + # Row 1: FP4 [2.0,2.0,2.0,2.0] -> nibbles [0x4,0x4,0x4,0x4] -> bytes [0x44,0x44] = [68,68] + # Row 2: FP4 [-1,-1,-1,-1] -> nibbles [0xA,0xA,0xA,0xA] -> bytes [0xAA,0xAA] = [170,170] + # Row 3: FP4 [0.0,0.0,0.0,0.0] -> nibbles [0x0,0x0,0x0,0x0] -> bytes [0x00,0x00] = [0,0] + w_q = torch.tensor( + [[34, 34], [68, 68], [170, 170], [0, 0]], + dtype=torch.uint8, + device=device, + ) + scales = torch.full((N, K // group_size), 0.5, device=device, dtype=torch.float16) + y = torch.empty(M, N, device=device, dtype=torch.float16) + + return { + "x": x, + "w_q": w_q, + "scales": scales, + "y": y, + "M": M, + "N": N, + "K": K, + "group_size": group_size, + } + + def generate_functional_test(self) -> List[Dict[str, Any]]: + torch.manual_seed(42) + tests = [] + + # Edge cases with tiny shapes. + tests.append(self._make_test_case(1, 2, 4, 2, zero_x=True)) + tests.append(self._make_test_case(2, 4, 4, 2)) + tests.append(self._make_test_case(3, 5, 8, 4)) + + # Power-of-2 shapes. + tests.append(self._make_test_case(16, 16, 32, 16)) + tests.append(self._make_test_case(32, 64, 64, 32)) + tests.append(self._make_test_case(128, 128, 256, 32)) + + # Non-power-of-2 shapes. + tests.append(self._make_test_case(30, 50, 64, 32)) + tests.append(self._make_test_case(100, 200, 128, 32)) + tests.append(self._make_test_case(255, 100, 128, 32)) + + # Realistic LLM inference shape. + tests.append(self._make_test_case(512, 1024, 1024, 32)) + + return tests + + def generate_performance_test(self) -> Dict[str, Any]: + torch.manual_seed(0) + # Matches the FP4 matmul shapes reported in AutoKernel community results. + return self._make_test_case(2048, 8192, 3072, 32) diff --git a/challenges/hard/86_fp4_matmul/starter/starter.cu b/challenges/hard/86_fp4_matmul/starter/starter.cu new file mode 100644 index 00000000..b9f592ab --- /dev/null +++ b/challenges/hard/86_fp4_matmul/starter/starter.cu @@ -0,0 +1,7 @@ +#include +#include +#include + +// x, w_q, scales, y are device pointers +extern "C" void solve(const __half* x, const uint8_t* w_q, const __half* scales, __half* y, int M, + int N, int K, int group_size) {} diff --git a/challenges/hard/86_fp4_matmul/starter/starter.cute.py b/challenges/hard/86_fp4_matmul/starter/starter.cute.py new file mode 100644 index 00000000..09776ca5 --- /dev/null +++ b/challenges/hard/86_fp4_matmul/starter/starter.cute.py @@ -0,0 +1,17 @@ +import cutlass +import cutlass.cute as cute + + +# x, w_q, scales, y are tensors on the GPU +@cute.jit +def solve( + x: cute.Tensor, + w_q: cute.Tensor, + scales: cute.Tensor, + y: cute.Tensor, + M: cute.Int32, + N: cute.Int32, + K: cute.Int32, + group_size: cute.Int32, +): + pass diff --git a/challenges/hard/86_fp4_matmul/starter/starter.jax.py b/challenges/hard/86_fp4_matmul/starter/starter.jax.py new file mode 100644 index 00000000..ee84481d --- /dev/null +++ b/challenges/hard/86_fp4_matmul/starter/starter.jax.py @@ -0,0 +1,17 @@ +import jax +import jax.numpy as jnp + + +# x, w_q, scales are tensors on GPU +@jax.jit +def solve( + x: jax.Array, + w_q: jax.Array, + scales: jax.Array, + M: int, + N: int, + K: int, + group_size: int, +) -> jax.Array: + # return output tensor directly + pass diff --git a/challenges/hard/86_fp4_matmul/starter/starter.mojo b/challenges/hard/86_fp4_matmul/starter/starter.mojo new file mode 100644 index 00000000..ca1aebb8 --- /dev/null +++ b/challenges/hard/86_fp4_matmul/starter/starter.mojo @@ -0,0 +1,17 @@ +from std.gpu.host import DeviceContext 
+from std.memory import UnsafePointer + + +# x, w_q, scales, y are device pointers +@export +def solve( + x: UnsafePointer[Float16, MutExternalOrigin], + w_q: UnsafePointer[UInt8, MutExternalOrigin], + scales: UnsafePointer[Float16, MutExternalOrigin], + y: UnsafePointer[Float16, MutExternalOrigin], + M: Int32, + N: Int32, + K: Int32, + group_size: Int32, +) raises: + pass diff --git a/challenges/hard/86_fp4_matmul/starter/starter.pytorch.py b/challenges/hard/86_fp4_matmul/starter/starter.pytorch.py new file mode 100644 index 00000000..82cb31b4 --- /dev/null +++ b/challenges/hard/86_fp4_matmul/starter/starter.pytorch.py @@ -0,0 +1,15 @@ +import torch + + +# x, w_q, scales, y are tensors on the GPU +def solve( + x: torch.Tensor, + w_q: torch.Tensor, + scales: torch.Tensor, + y: torch.Tensor, + M: int, + N: int, + K: int, + group_size: int, +): + pass diff --git a/challenges/hard/86_fp4_matmul/starter/starter.triton.py b/challenges/hard/86_fp4_matmul/starter/starter.triton.py new file mode 100644 index 00000000..07b1eb84 --- /dev/null +++ b/challenges/hard/86_fp4_matmul/starter/starter.triton.py @@ -0,0 +1,17 @@ +import torch +import triton +import triton.language as tl + + +# x, w_q, scales, y are tensors on the GPU +def solve( + x: torch.Tensor, + w_q: torch.Tensor, + scales: torch.Tensor, + y: torch.Tensor, + M: int, + N: int, + K: int, + group_size: int, +): + pass From ba161de0e7ee757df323e795856a7766a378d173 Mon Sep 17 00:00:00 2001 From: Kunal Mansukhani Date: Sun, 5 Apr 2026 15:47:23 -0700 Subject: [PATCH 2/2] Rewrite FP4 matmul as NVFP4 FP4xFP4 GEMM Restructures the challenge so a submission directly verifies AutoKernel's FP4 matmul claim (Table 5 of the paper): both operands are packed FP4 E2M1 with E4M3 per-block scales and a per-tensor FP32 alpha, matching the NVFP4 layout used by CUTLASS and qutlass. Previous revision was W4A16 weight-only quant, which cannot reach the TF/s regime the paper reports because x was still FP16. Key changes: - Both x and w are packed FP4 uint8 (nibbles); block size = 16. - Scales are raw E4M3 bytes (torch.float8_e4m3fn bit patterns). - Reference is a pure FP32 dequant + matmul oracle. - Performance shape (M=2048, N=18432, K=3072) taken verbatim from the Triton vs CUTLASS row in Table 5 so TF/s is directly comparable. - Tolerances loosened to atol=0.1, rtol=0.05 to admit FP16 accumulation used by tensor-core paths. Co-Authored-By: Claude Opus 4.6 (1M context) --- challenges/hard/86_fp4_matmul/challenge.html | 111 ++++++------ challenges/hard/86_fp4_matmul/challenge.py | 171 +++++++++++------- .../hard/86_fp4_matmul/starter/starter.cu | 6 +- .../86_fp4_matmul/starter/starter.cute.py | 9 +- .../hard/86_fp4_matmul/starter/starter.jax.py | 9 +- .../hard/86_fp4_matmul/starter/starter.mojo | 9 +- .../86_fp4_matmul/starter/starter.pytorch.py | 9 +- .../86_fp4_matmul/starter/starter.triton.py | 9 +- 8 files changed, 187 insertions(+), 146 deletions(-) diff --git a/challenges/hard/86_fp4_matmul/challenge.html b/challenges/hard/86_fp4_matmul/challenge.html index 4653167a..bb29afd2 100644 --- a/challenges/hard/86_fp4_matmul/challenge.html +++ b/challenges/hard/86_fp4_matmul/challenge.html @@ -1,15 +1,17 @@

- Implement an FP4 weight-only quantized matrix multiplication, the kernel at the heart of
- modern low-precision LLM inference on Hopper and Blackwell GPUs. Given a float16 activation
- matrix x of shape M × K and a weight matrix stored in packed
- FP4 E2M1 format, compute y = x × W^T of shape
- M × N, where W is the dequantized float16 weight matrix of
- shape N × K.
+ Implement an NVFP4 matrix multiplication, the low-precision GEMM that powers
+ state-of-the-art LLM inference on Hopper and Blackwell GPUs. Both operands are stored in 4-bit
+ floating point (FP4 E2M1) with per-block FP8 (E4M3) scales along the reduction dimension, plus
+ a single per-tensor FP32 scale. Given packed activations x_q of shape
+ M × K, packed weights w_q of shape N × K,
+ and their respective block scales, compute
+ y = alpha × (x × w^T) of shape M × N
+ in float16.

- FP4 E2M1 format: Each weight is encoded in 4 bits as - [sign | exponent (2 bits) | mantissa (1 bit)], representing one of sixteen values: + FP4 E2M1 encoding: Each weight is 4 bits + [sign | exp (2 bits) | mantissa (1 bit)] representing one of sixteen values: {±0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6}. The nibble-to-value mapping is:

@@ -25,89 +27,92 @@

- Packing: Each byte of w_q stores two FP4 weights. The high - nibble (bits 7–4) holds w[n, 2i] and the low nibble (bits 3–0) holds - w[n, 2i+1]. + Packing: Each byte of x_q / w_q stores two FP4 + values. The high nibble (bits 7–4) holds the even-index value and the low nibble + (bits 3–0) holds the odd-index value.

- Dequantization: Weights are dequantized group-wise. Each contiguous block of - group_size weights along the K dimension shares one float16 scale: + Block scales: Each contiguous block of 16 FP4 values along + the K dimension shares one E4M3 (float8) scale. The scale tensors + x_scales and w_scales are passed as raw uint8 bytes holding the + E4M3 bit patterns. Dequantization is:

-W[n, k] = fp4_decode(w_q_nibble[n, k]) * scales[n, k // group_size]
+x[m, k] = fp4_decode(x_q_nibble[m, k]) * e4m3_decode(x_scales[m, k // 16])
+w[n, k] = fp4_decode(w_q_nibble[n, k]) * e4m3_decode(w_scales[n, k // 16])
+y[m, n] = alpha * sum_k x[m, k] * w[n, k]
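As a concrete reading of these three equations, here is a minimal PyTorch sketch (illustrative only; the helper names dequant_nvfp4 and nvfp4_matmul_reference are introduced here). It uses the same E2M1 lookup table and float8_e4m3fn view trick as the reference implementation in challenge.py:

import torch

FP4_E2M1_TABLE = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                  -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
BLOCK_SIZE = 16  # NVFP4 block size along K

def dequant_nvfp4(q: torch.Tensor, scales_u8: torch.Tensor) -> torch.Tensor:
    # q: (rows, K//2) packed FP4 nibbles; scales_u8: (rows, K//16) raw E4M3 bytes -> (rows, K) float32
    rows, k = q.shape[0], q.shape[1] * 2
    table = torch.tensor(FP4_E2M1_TABLE, device=q.device, dtype=torch.float32)
    vals = torch.stack([table[((q >> 4) & 0xF).long()],   # even-index values (high nibble)
                        table[(q & 0xF).long()]], dim=-1).reshape(rows, k)
    scales = scales_u8.view(torch.float8_e4m3fn).float()  # decode E4M3 bit patterns
    return vals * scales.repeat_interleave(BLOCK_SIZE, dim=1)

def nvfp4_matmul_reference(x_q, x_scales, w_q, w_scales, alpha):
    x = dequant_nvfp4(x_q, x_scales)      # (M, K)
    w = dequant_nvfp4(w_q, w_scales)      # (N, K)
    return (alpha * (x @ w.T)).half()     # y[m, n] = alpha * sum_k x[m, k] * w[n, k]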
 

Implementation Requirements

  • Use only native features (external libraries are not permitted)
  • The solve function signature must remain unchanged
- • The final result must be stored in y
+ • The final result must be stored in y as float16

Example

- Input (M = 2, N = 4, K = 4, group_size = 2): + Input (M = 2, N = 2, K = 16, alpha = 1.0):

- Activations \(x\) (float16, \(2 \times 4\)): + Packed activations \(x\_q\) (uint8, \(2 \times 8\)) and decoded FP4 values (each row has + sixteen values): \[ + x\_q = \begin{bmatrix} - 1.0 & 0.0 & 1.0 & 0.0 \\ - 0.0 & 1.0 & 0.0 & 1.0 - \end{bmatrix} - \] - Packed weights \(w\_q\) (uint8, \(4 \times 2\)) decoded via the FP4 E2M1 table: - \[ - \begin{bmatrix} - \texttt{0x22} & \texttt{0x22} \\ - \texttt{0x44} & \texttt{0x44} \\ - \texttt{0xAA} & \texttt{0xAA} \\ - \texttt{0x00} & \texttt{0x00} + \texttt{0x22} & \cdots & \texttt{0x22} \\ + \texttt{0x11} & \cdots & \texttt{0x11} \end{bmatrix} \;\Rightarrow\; - W_{\text{fp4}} = + x_{\text{fp4}} = \begin{bmatrix} - 1.0 & 1.0 & 1.0 & 1.0 \\ - 2.0 & 2.0 & 2.0 & 2.0 \\ - -1.0 & -1.0 & -1.0 & -1.0 \\ - 0.0 & 0.0 & 0.0 & 0.0 + 1.0 & 1.0 & \cdots & 1.0 \\ + 0.5 & 0.5 & \cdots & 0.5 \end{bmatrix} \] - Scales (float16, \(4 \times 2\), all entries 0.5): + Packed weights \(w\_q\) (uint8, \(2 \times 8\)): \[ + w\_q = \begin{bmatrix} - 0.5 & 0.5 \\ - 0.5 & 0.5 \\ - 0.5 & 0.5 \\ - 0.5 & 0.5 + \texttt{0x44} & \cdots & \texttt{0x44} \\ + \texttt{0xAA} & \cdots & \texttt{0xAA} \end{bmatrix} \;\Rightarrow\; - W_{\text{dequant}} = + w_{\text{fp4}} = \begin{bmatrix} - 0.5 & 0.5 & 0.5 & 0.5 \\ - 1.0 & 1.0 & 1.0 & 1.0 \\ - -0.5 & -0.5 & -0.5 & -0.5 \\ - 0.0 & 0.0 & 0.0 & 0.0 + 2.0 & 2.0 & \cdots & 2.0 \\ + -1.0 & -1.0 & \cdots & -1.0 \end{bmatrix} \] - Output \(y = x \times W^T\) (float16, \(2 \times 4\)): + Block scales (one block per row since K = 16): both + x_scales and w_scales are uint8 \(2 \times 1\) with every byte + equal to 0x38, which is the E4M3 bit pattern for 1.0. The dequantized operands + therefore equal the FP4 values above. +

+

+ Output \(y = \alpha \cdot (x \times w^T)\) (float16, \(2 \times 2\)): \[ \begin{bmatrix} - 1.0 & 2.0 & -1.0 & 0.0 \\ - 1.0 & 2.0 & -1.0 & 0.0 + \sum 1.0 \cdot 2.0 & \sum 1.0 \cdot (-1.0) \\ + \sum 0.5 \cdot 2.0 & \sum 0.5 \cdot (-1.0) + \end{bmatrix} + = + \begin{bmatrix} + 32.0 & -16.0 \\ + 16.0 & -8.0 \end{bmatrix} \]
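A quick way to sanity-check the scale bytes in this example (not part of the challenge code): reinterpret the 0x38 byte as float8_e4m3fn, exactly as the reference implementation does, and confirm it decodes to 1.0.

import torch

s = torch.tensor([0x38], dtype=torch.uint8).view(torch.float8_e4m3fn).float()
print(s.item())  # 1.0 -> every block scale is 1.0, so the dequantized operands equal the raw FP4 values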

Constraints

- • 1 ≤ M, N ≤ 8,192
- • 1 ≤ K ≤ 8,192
- • K is divisible by 2 and by group_size
- • group_size ∈ {2, 4, 8, 16, 32}
+ • 1 ≤ M, N ≤ 32,768
+ • 16 ≤ K ≤ 32,768
+ • K is divisible by 16 (the NVFP4 block size)
  • All tensors are stored in row-major order
- • Input dtype: x and scales are float16; w_q is uint8
- • Output dtype: y is float16
- • Performance is measured with M = 2,048, N = 8,192, K = 3,072, group_size = 32
+ • Inputs: x_q, w_q, x_scales, w_scales are uint8; alpha is float32
+ • Output: y is float16
+ • Performance is measured with M = 2,048, N = 18,432, K = 3,072
diff --git a/challenges/hard/86_fp4_matmul/challenge.py b/challenges/hard/86_fp4_matmul/challenge.py index b07f6a8d..1376cdac 100644 --- a/challenges/hard/86_fp4_matmul/challenge.py +++ b/challenges/hard/86_fp4_matmul/challenge.py @@ -5,7 +5,7 @@ from core.challenge_base import ChallengeBase # OCP FP4 E2M1 lookup table: 4-bit unsigned index -> float value. -# Bit layout: [sign | exp1 exp0 | mantissa]. Sixteen representable values. +# Bit layout: [sign | exp1 exp0 | mantissa]. FP4_E2M1_TABLE = [ 0.0, 0.5, @@ -25,12 +25,27 @@ -6.0, ] +# NVFP4 block size along the reduction dimension. Each block of 16 FP4 +# values shares one E4M3 scale. Matches CUTLASS / qutlass NVFP4 layout. +BLOCK_SIZE = 16 + + +def _decode_fp4_packed(packed: torch.Tensor, rows: int, cols: int) -> torch.Tensor: + """Decode a (rows, cols/2) uint8 tensor of packed FP4 E2M1 nibbles into + a (rows, cols) float32 tensor. High nibble stores the even-index value, + low nibble stores the odd-index value.""" + table = torch.tensor(FP4_E2M1_TABLE, device=packed.device, dtype=torch.float32) + high = ((packed >> 4) & 0xF).to(torch.long) + low = (packed & 0xF).to(torch.long) + decoded = torch.stack([table[high], table[low]], dim=-1).reshape(rows, cols) + return decoded + class Challenge(ChallengeBase): def __init__(self): super().__init__( - name="FP4 MatMul", - atol=5e-02, + name="NVFP4 Matrix Multiplication", + atol=1e-01, rtol=5e-02, num_gpus=1, access_tier="free", @@ -38,137 +53,153 @@ def __init__(self): def reference_impl( self, - x: torch.Tensor, + x_q: torch.Tensor, + x_scales: torch.Tensor, w_q: torch.Tensor, - scales: torch.Tensor, + w_scales: torch.Tensor, + alpha: float, y: torch.Tensor, M: int, N: int, K: int, - group_size: int, ): - assert x.shape == (M, K) + assert K % BLOCK_SIZE == 0, "K must be divisible by 16 (NVFP4 block size)" + assert x_q.shape == (M, K // 2) + assert x_scales.shape == (M, K // BLOCK_SIZE) assert w_q.shape == (N, K // 2) - assert scales.shape == (N, K // group_size) + assert w_scales.shape == (N, K // BLOCK_SIZE) assert y.shape == (M, N) - assert x.dtype == torch.float16 + assert x_q.dtype == torch.uint8 assert w_q.dtype == torch.uint8 - assert scales.dtype == torch.float16 + assert x_scales.dtype == torch.uint8 + assert w_scales.dtype == torch.uint8 assert y.dtype == torch.float16 - assert x.device.type == "cuda" + assert x_q.device.type == "cuda" + assert x_scales.device.type == "cuda" assert w_q.device.type == "cuda" - assert scales.device.type == "cuda" + assert w_scales.device.type == "cuda" assert y.device.type == "cuda" - # Decode packed FP4 E2M1 nibbles via lookup table. - # w_q[n, i] holds two FP4 values: w[n, 2*i] in the high nibble (bits 7:4) - # and w[n, 2*i+1] in the low nibble (bits 3:0). - table = torch.tensor(FP4_E2M1_TABLE, device=x.device, dtype=torch.float32) - high = ((w_q >> 4) & 0xF).to(torch.long) # [N, K//2] - low = (w_q & 0xF).to(torch.long) # [N, K//2] - w_high = table[high] # [N, K//2] - w_low = table[low] # [N, K//2] - w_fp4 = torch.stack([w_high, w_low], dim=-1).reshape(N, K) # [N, K] - - # Apply group-wise FP16 scales: each contiguous block of `group_size` - # weights along K shares one scale. - n_groups = K // group_size - w_groups = w_fp4.reshape(N, n_groups, group_size) - scales_f = scales.float().unsqueeze(-1) # [N, n_groups, 1] - w_dequant = (w_groups * scales_f).reshape(N, K) - - y.copy_((x.float() @ w_dequant.T).half()) + # Decode packed FP4 operands to float32. 
+ x_fp4 = _decode_fp4_packed(x_q, M, K) + w_fp4 = _decode_fp4_packed(w_q, N, K) + + # Decode E4M3 per-block scales to float32. + xs = x_scales.view(torch.float8_e4m3fn).float() # (M, K/16) + ws = w_scales.view(torch.float8_e4m3fn).float() # (N, K/16) + + # Apply per-block scales along the reduction dimension. + n_blocks = K // BLOCK_SIZE + x_dq = (x_fp4.reshape(M, n_blocks, BLOCK_SIZE) * xs.unsqueeze(-1)).reshape(M, K) + w_dq = (w_fp4.reshape(N, n_blocks, BLOCK_SIZE) * ws.unsqueeze(-1)).reshape(N, K) + + # NVFP4 matmul: y = alpha * (x @ w^T), result cast to FP16. + out = float(alpha) * (x_dq @ w_dq.T) + y.copy_(out.half()) def get_solve_signature(self) -> Dict[str, tuple]: return { - "x": (ctypes.POINTER(ctypes.c_uint16), "in"), + "x_q": (ctypes.POINTER(ctypes.c_uint8), "in"), + "x_scales": (ctypes.POINTER(ctypes.c_uint8), "in"), "w_q": (ctypes.POINTER(ctypes.c_uint8), "in"), - "scales": (ctypes.POINTER(ctypes.c_uint16), "in"), + "w_scales": (ctypes.POINTER(ctypes.c_uint8), "in"), + "alpha": (ctypes.c_float, "in"), "y": (ctypes.POINTER(ctypes.c_uint16), "out"), "M": (ctypes.c_int, "in"), "N": (ctypes.c_int, "in"), "K": (ctypes.c_int, "in"), - "group_size": (ctypes.c_int, "in"), } - def _make_test_case(self, M: int, N: int, K: int, group_size: int, zero_x: bool = False): + def _make_test_case(self, M: int, N: int, K: int, zero_x: bool = False, alpha: float = 1.0): + assert K % BLOCK_SIZE == 0, "K must be divisible by 16" device = "cuda" if zero_x: - x = torch.zeros(M, K, device=device, dtype=torch.float16) + x_q = torch.zeros(M, K // 2, dtype=torch.uint8, device=device) else: - x = torch.randn(M, K, device=device, dtype=torch.float16) * 0.5 + x_q = torch.randint(0, 256, (M, K // 2), dtype=torch.uint8, device=device) w_q = torch.randint(0, 256, (N, K // 2), dtype=torch.uint8, device=device) - scales = torch.rand(N, K // group_size, device=device, dtype=torch.float16) * 0.1 + 0.01 + + # Positive E4M3 block scales in a representable range. + x_scales_f = torch.rand(M, K // BLOCK_SIZE, device=device) * 1.5 + 0.5 + w_scales_f = torch.rand(N, K // BLOCK_SIZE, device=device) * 1.5 + 0.5 + x_scales = x_scales_f.to(torch.float8_e4m3fn).view(torch.uint8) + w_scales = w_scales_f.to(torch.float8_e4m3fn).view(torch.uint8) + y = torch.empty(M, N, device=device, dtype=torch.float16) return { - "x": x, + "x_q": x_q, + "x_scales": x_scales, "w_q": w_q, - "scales": scales, + "w_scales": w_scales, + "alpha": alpha, "y": y, "M": M, "N": N, "K": K, - "group_size": group_size, } def generate_example_test(self) -> Dict[str, Any]: device = "cuda" - M, N, K, group_size = 2, 4, 4, 2 + M, N, K = 2, 2, 16 - x = torch.tensor( - [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]], + # x row 0: sixteen FP4 values = 1.0 (nibble 0x2) -> 8 bytes of 0x22 + # x row 1: sixteen FP4 values = 0.5 (nibble 0x1) -> 8 bytes of 0x11 + x_q = torch.tensor( + [[0x22] * 8, [0x11] * 8], + dtype=torch.uint8, device=device, - dtype=torch.float16, ) - # Packed FP4 E2M1 weights (high nibble first). 
- # Row 0: FP4 [1.0,1.0,1.0,1.0] -> nibbles [0x2,0x2,0x2,0x2] -> bytes [0x22,0x22] = [34,34] - # Row 1: FP4 [2.0,2.0,2.0,2.0] -> nibbles [0x4,0x4,0x4,0x4] -> bytes [0x44,0x44] = [68,68] - # Row 2: FP4 [-1,-1,-1,-1] -> nibbles [0xA,0xA,0xA,0xA] -> bytes [0xAA,0xAA] = [170,170] - # Row 3: FP4 [0.0,0.0,0.0,0.0] -> nibbles [0x0,0x0,0x0,0x0] -> bytes [0x00,0x00] = [0,0] + # w row 0: sixteen FP4 values = 2.0 (nibble 0x4) -> 8 bytes of 0x44 + # w row 1: sixteen FP4 values = -1.0 (nibble 0xA) -> 8 bytes of 0xAA w_q = torch.tensor( - [[34, 34], [68, 68], [170, 170], [0, 0]], + [[0x44] * 8, [0xAA] * 8], dtype=torch.uint8, device=device, ) - scales = torch.full((N, K // group_size), 0.5, device=device, dtype=torch.float16) - y = torch.empty(M, N, device=device, dtype=torch.float16) + # All block scales = E4M3 1.0 = 0x38. One block per row (K=16). + x_scales = torch.full((M, K // BLOCK_SIZE), 0x38, dtype=torch.uint8, device=device) + w_scales = torch.full((N, K // BLOCK_SIZE), 0x38, dtype=torch.uint8, device=device) + y = torch.empty(M, N, device=device, dtype=torch.float16) return { - "x": x, + "x_q": x_q, + "x_scales": x_scales, "w_q": w_q, - "scales": scales, + "w_scales": w_scales, + "alpha": 1.0, "y": y, "M": M, "N": N, "K": K, - "group_size": group_size, } def generate_functional_test(self) -> List[Dict[str, Any]]: torch.manual_seed(42) tests = [] - # Edge cases with tiny shapes. - tests.append(self._make_test_case(1, 2, 4, 2, zero_x=True)) - tests.append(self._make_test_case(2, 4, 4, 2)) - tests.append(self._make_test_case(3, 5, 8, 4)) + # Edge cases with the minimum K = 16 (one block). + tests.append(self._make_test_case(1, 2, 16, zero_x=True)) + tests.append(self._make_test_case(2, 4, 16)) + tests.append(self._make_test_case(3, 5, 32)) # Power-of-2 shapes. - tests.append(self._make_test_case(16, 16, 32, 16)) - tests.append(self._make_test_case(32, 64, 64, 32)) - tests.append(self._make_test_case(128, 128, 256, 32)) + tests.append(self._make_test_case(16, 16, 32)) + tests.append(self._make_test_case(32, 64, 64)) + tests.append(self._make_test_case(128, 128, 256)) - # Non-power-of-2 shapes. - tests.append(self._make_test_case(30, 50, 64, 32)) - tests.append(self._make_test_case(100, 200, 128, 32)) - tests.append(self._make_test_case(255, 100, 128, 32)) + # Non-power-of-2 leading dims with valid K. + tests.append(self._make_test_case(30, 50, 64)) + tests.append(self._make_test_case(100, 200, 128)) + tests.append(self._make_test_case(255, 100, 128, alpha=0.125)) - # Realistic LLM inference shape. - tests.append(self._make_test_case(512, 1024, 1024, 32)) + # Realistic attention-projection shape. + tests.append(self._make_test_case(512, 1024, 1024, alpha=1.0 / 64.0)) return tests def generate_performance_test(self) -> Dict[str, Any]: torch.manual_seed(0) - # Matches the FP4 matmul shapes reported in AutoKernel community results. - return self._make_test_case(2048, 8192, 3072, 32) + # Verbatim from AutoKernel Table 5: the row where Triton hit 2,898 TF/s + # against CUTLASS's 1,777 TF/s (1.63x speedup). A correct FP4 tensor + # core submission at this shape directly validates the paper's claim. 
+ return self._make_test_case(2048, 18432, 3072, alpha=1.0 / 64.0) diff --git a/challenges/hard/86_fp4_matmul/starter/starter.cu b/challenges/hard/86_fp4_matmul/starter/starter.cu index b9f592ab..8faee62a 100644 --- a/challenges/hard/86_fp4_matmul/starter/starter.cu +++ b/challenges/hard/86_fp4_matmul/starter/starter.cu @@ -2,6 +2,6 @@ #include #include -// x, w_q, scales, y are device pointers -extern "C" void solve(const __half* x, const uint8_t* w_q, const __half* scales, __half* y, int M, - int N, int K, int group_size) {} +// x_q, x_scales, w_q, w_scales, y are device pointers +extern "C" void solve(const uint8_t* x_q, const uint8_t* x_scales, const uint8_t* w_q, + const uint8_t* w_scales, float alpha, __half* y, int M, int N, int K) {} diff --git a/challenges/hard/86_fp4_matmul/starter/starter.cute.py b/challenges/hard/86_fp4_matmul/starter/starter.cute.py index 09776ca5..07ca9580 100644 --- a/challenges/hard/86_fp4_matmul/starter/starter.cute.py +++ b/challenges/hard/86_fp4_matmul/starter/starter.cute.py @@ -2,16 +2,17 @@ import cutlass.cute as cute -# x, w_q, scales, y are tensors on the GPU +# x_q, x_scales, w_q, w_scales, y are tensors on the GPU @cute.jit def solve( - x: cute.Tensor, + x_q: cute.Tensor, + x_scales: cute.Tensor, w_q: cute.Tensor, - scales: cute.Tensor, + w_scales: cute.Tensor, + alpha: cute.Float32, y: cute.Tensor, M: cute.Int32, N: cute.Int32, K: cute.Int32, - group_size: cute.Int32, ): pass diff --git a/challenges/hard/86_fp4_matmul/starter/starter.jax.py b/challenges/hard/86_fp4_matmul/starter/starter.jax.py index ee84481d..a490abc6 100644 --- a/challenges/hard/86_fp4_matmul/starter/starter.jax.py +++ b/challenges/hard/86_fp4_matmul/starter/starter.jax.py @@ -2,16 +2,17 @@ import jax.numpy as jnp -# x, w_q, scales are tensors on GPU +# x_q, x_scales, w_q, w_scales are tensors on GPU @jax.jit def solve( - x: jax.Array, + x_q: jax.Array, + x_scales: jax.Array, w_q: jax.Array, - scales: jax.Array, + w_scales: jax.Array, + alpha: float, M: int, N: int, K: int, - group_size: int, ) -> jax.Array: # return output tensor directly pass diff --git a/challenges/hard/86_fp4_matmul/starter/starter.mojo b/challenges/hard/86_fp4_matmul/starter/starter.mojo index ca1aebb8..8e2d31ee 100644 --- a/challenges/hard/86_fp4_matmul/starter/starter.mojo +++ b/challenges/hard/86_fp4_matmul/starter/starter.mojo @@ -2,16 +2,17 @@ from std.gpu.host import DeviceContext from std.memory import UnsafePointer -# x, w_q, scales, y are device pointers +# x_q, x_scales, w_q, w_scales, y are device pointers @export def solve( - x: UnsafePointer[Float16, MutExternalOrigin], + x_q: UnsafePointer[UInt8, MutExternalOrigin], + x_scales: UnsafePointer[UInt8, MutExternalOrigin], w_q: UnsafePointer[UInt8, MutExternalOrigin], - scales: UnsafePointer[Float16, MutExternalOrigin], + w_scales: UnsafePointer[UInt8, MutExternalOrigin], + alpha: Float32, y: UnsafePointer[Float16, MutExternalOrigin], M: Int32, N: Int32, K: Int32, - group_size: Int32, ) raises: pass diff --git a/challenges/hard/86_fp4_matmul/starter/starter.pytorch.py b/challenges/hard/86_fp4_matmul/starter/starter.pytorch.py index 82cb31b4..9ca0eeaf 100644 --- a/challenges/hard/86_fp4_matmul/starter/starter.pytorch.py +++ b/challenges/hard/86_fp4_matmul/starter/starter.pytorch.py @@ -1,15 +1,16 @@ import torch -# x, w_q, scales, y are tensors on the GPU +# x_q, x_scales, w_q, w_scales, y are tensors on the GPU def solve( - x: torch.Tensor, + x_q: torch.Tensor, + x_scales: torch.Tensor, w_q: torch.Tensor, - scales: torch.Tensor, + w_scales: 
torch.Tensor, + alpha: float, y: torch.Tensor, M: int, N: int, K: int, - group_size: int, ): pass diff --git a/challenges/hard/86_fp4_matmul/starter/starter.triton.py b/challenges/hard/86_fp4_matmul/starter/starter.triton.py index 07b1eb84..97d2730a 100644 --- a/challenges/hard/86_fp4_matmul/starter/starter.triton.py +++ b/challenges/hard/86_fp4_matmul/starter/starter.triton.py @@ -3,15 +3,16 @@ import triton.language as tl -# x, w_q, scales, y are tensors on the GPU +# x_q, x_scales, w_q, w_scales, y are tensors on the GPU def solve( - x: torch.Tensor, + x_q: torch.Tensor, + x_scales: torch.Tensor, w_q: torch.Tensor, - scales: torch.Tensor, + w_scales: torch.Tensor, + alpha: float, y: torch.Tensor, M: int, N: int, K: int, - group_size: int, ): pass
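For orientation, a naive dequantize-then-matmul solve in the starter.pytorch.py signature above would look roughly like the sketch below. This is only a correctness baseline under the stated layout assumptions; it will not approach the FP4 tensor-core throughput the performance test targets.

import torch

FP4_E2M1_TABLE = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                  -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

def solve(x_q, x_scales, w_q, w_scales, alpha, y, M, N, K):
    table = torch.tensor(FP4_E2M1_TABLE, device=x_q.device, dtype=torch.float32)

    def dequant(q, scales_u8, rows):
        # Unpack high/low nibbles, then apply one E4M3 scale per block of 16 values along K.
        vals = torch.stack([table[((q >> 4) & 0xF).long()],
                            table[(q & 0xF).long()]], dim=-1).reshape(rows, K)
        scales = scales_u8.view(torch.float8_e4m3fn).float().repeat_interleave(16, dim=1)
        return vals * scales

    x = dequant(x_q, x_scales, M)   # (M, K)
    w = dequant(w_q, w_scales, N)   # (N, K)
    y.copy_((alpha * (x @ w.T)).half())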