diff --git a/challenges/hard/86_fp4_matmul/challenge.html b/challenges/hard/86_fp4_matmul/challenge.html
new file mode 100644
index 00000000..bb29afd2
--- /dev/null
+++ b/challenges/hard/86_fp4_matmul/challenge.html
@@ -0,0 +1,118 @@
+

+ Implement an NVFP4 matrix multiplication, the low-precision GEMM that powers
+ state-of-the-art LLM inference on Blackwell GPUs. Both operands are stored in 4-bit
+ floating point (FP4 E2M1) with per-block FP8 (E4M3) scales along the reduction dimension, plus
+ a single per-tensor FP32 scale alpha. Given packed activations x_q representing an
+ M × K matrix, packed weights w_q representing an N × K matrix,
+ and their respective block scales, compute
+ y = alpha × (x × wᵀ) of shape M × N
+ in float16.

+ +

+ FP4 E2M1 encoding: Each value is 4 bits,
+ [sign | exp (2 bits) | mantissa (1 bit)], representing one of sixteen values:
+ {±0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6}.
+ The nibble-to-value mapping is:

+
+  0x0 =  0.0    0x8 = -0.0
+  0x1 =  0.5    0x9 = -0.5
+  0x2 =  1.0    0xA = -1.0
+  0x3 =  1.5    0xB = -1.5
+  0x4 =  2.0    0xC = -2.0
+  0x5 =  3.0    0xD = -3.0
+  0x6 =  4.0    0xE = -4.0
+  0x7 =  6.0    0xF = -6.0
+
+ +
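+ For illustration only, a minimal Python sketch of this lookup (the table and helper
+ name here are illustrative, not part of the required interface):
+
+ FP4_E2M1_TABLE = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
+                   -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
+
+ def fp4_decode(nibble: int) -> float:
+     # Map a 4-bit code (0..15) to its FP4 E2M1 value via the table above.
+     return FP4_E2M1_TABLE[nibble & 0xF]
+
+ assert fp4_decode(0x5) == 3.0 and fp4_decode(0xD) == -3.0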

+ Packing: Each byte of x_q / w_q stores two FP4
+ values. The high nibble (bits 7–4) holds the even-index value and the low nibble
+ (bits 3–0) holds the odd-index value.

+ +
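+ As a sketch of the byte layout (reusing the hypothetical fp4_decode helper above),
+ unpacking one byte yields the even-index value first, then the odd-index value:
+
+ def unpack_fp4_byte(byte: int) -> tuple[float, float]:
+     # Bits 7-4 (high nibble) -> even-index element, bits 3-0 (low nibble) -> odd-index element.
+     return fp4_decode((byte >> 4) & 0xF), fp4_decode(byte & 0xF)
+
+ assert unpack_fp4_byte(0x22) == (1.0, 1.0)    # activation byte from the example below
+ assert unpack_fp4_byte(0xAA) == (-1.0, -1.0)  # weight byte from the example below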

+ Block scales: Each contiguous block of 16 FP4 values along
+ the K dimension shares one E4M3 (float8) scale. The scale tensors
+ x_scales and w_scales are passed as raw uint8 bytes holding the
+ E4M3 bit patterns. Dequantization is:

+
+x[m, k] = fp4_decode(x_q_nibble[m, k]) * e4m3_decode(x_scales[m, k // 16])
+w[n, k] = fp4_decode(w_q_nibble[n, k]) * e4m3_decode(w_scales[n, k // 16])
+y[m, n] = alpha * sum_k x[m, k] * w[n, k]
+
+ +
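+ These formulas amount to a plain dequantize-then-matmul. A hedged PyTorch sketch
+ (function and helper names are illustrative, assuming uint8 tensors laid out as
+ described above; a correctness reference, not an optimized kernel):
+
+ import torch
+
+ def nvfp4_matmul_reference(x_q, x_scales, w_q, w_scales, alpha, M, N, K):
+     table = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
+                           -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
+                          device=x_q.device)
+
+     def dequant(q, scales, rows):
+         hi = table[((q >> 4) & 0xF).long()]           # even-index FP4 values
+         lo = table[(q & 0xF).long()]                  # odd-index FP4 values
+         vals = torch.stack([hi, lo], dim=-1).reshape(rows, K)
+         s = scales.view(torch.float8_e4m3fn).float()  # one E4M3 scale per 16 values
+         return (vals.reshape(rows, K // 16, 16) * s.unsqueeze(-1)).reshape(rows, K)
+
+     x = dequant(x_q, x_scales, M)                     # (M, K) float32
+     w = dequant(w_q, w_scales, N)                     # (N, K) float32
+     return (alpha * (x @ w.T)).half()                 # y = alpha * (x @ w^T) in float16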

+ Implementation Requirements

+ + +

+ Example

+

+ Input (M = 2, N = 2, K = 16, alpha = 1.0): +

+

+ Packed activations \(x\_q\) (uint8, \(2 \times 8\)) and decoded FP4 values (each row has
+ sixteen values):
+ \[
+ x\_q =
+ \begin{bmatrix}
+ \texttt{0x22} & \cdots & \texttt{0x22} \\
+ \texttt{0x11} & \cdots & \texttt{0x11}
+ \end{bmatrix}
+ \;\Rightarrow\;
+ x_{\text{fp4}} =
+ \begin{bmatrix}
+ 1.0 & 1.0 & \cdots & 1.0 \\
+ 0.5 & 0.5 & \cdots & 0.5
+ \end{bmatrix}
+ \]
+ Packed weights \(w\_q\) (uint8, \(2 \times 8\)):
+ \[
+ w\_q =
+ \begin{bmatrix}
+ \texttt{0x44} & \cdots & \texttt{0x44} \\
+ \texttt{0xAA} & \cdots & \texttt{0xAA}
+ \end{bmatrix}
+ \;\Rightarrow\;
+ w_{\text{fp4}} =
+ \begin{bmatrix}
+ 2.0 & 2.0 & \cdots & 2.0 \\
+ -1.0 & -1.0 & \cdots & -1.0
+ \end{bmatrix}
+ \]
+ Block scales (one block per row since K = 16): both
+ x_scales and w_scales are uint8 \(2 \times 1\) with every byte
+ equal to 0x38, which is the E4M3 bit pattern for 1.0. The dequantized operands
+ therefore equal the FP4 values above.

+

+ Output \(y = \alpha \cdot (x \times w^T)\) (float16, \(2 \times 2\)):
+ \[
+ \begin{bmatrix}
+ \sum 1.0 \cdot 2.0 & \sum 1.0 \cdot (-1.0) \\
+ \sum 0.5 \cdot 2.0 & \sum 0.5 \cdot (-1.0)
+ \end{bmatrix}
+ =
+ \begin{bmatrix}
+ 32.0 & -16.0 \\
+ 16.0 & -8.0
+ \end{bmatrix}
+ \]

+ +

+ Constraints

+
diff --git a/challenges/hard/86_fp4_matmul/challenge.py b/challenges/hard/86_fp4_matmul/challenge.py
new file mode 100644
index 00000000..1376cdac
--- /dev/null
+++ b/challenges/hard/86_fp4_matmul/challenge.py
@@ -0,0 +1,205 @@
+import ctypes
+from typing import Any, Dict, List
+
+import torch
+from core.challenge_base import ChallengeBase
+
+# OCP FP4 E2M1 lookup table: 4-bit unsigned index -> float value.
+# Bit layout: [sign | exp1 exp0 | mantissa].
+FP4_E2M1_TABLE = [
+    0.0,
+    0.5,
+    1.0,
+    1.5,
+    2.0,
+    3.0,
+    4.0,
+    6.0,
+    -0.0,
+    -0.5,
+    -1.0,
+    -1.5,
+    -2.0,
+    -3.0,
+    -4.0,
+    -6.0,
+]
+
+# NVFP4 block size along the reduction dimension. Each block of 16 FP4
+# values shares one E4M3 scale. Matches CUTLASS / qutlass NVFP4 layout.
+BLOCK_SIZE = 16
+
+
+def _decode_fp4_packed(packed: torch.Tensor, rows: int, cols: int) -> torch.Tensor:
+    """Decode a (rows, cols/2) uint8 tensor of packed FP4 E2M1 nibbles into
+    a (rows, cols) float32 tensor. High nibble stores the even-index value,
+    low nibble stores the odd-index value."""
+    table = torch.tensor(FP4_E2M1_TABLE, device=packed.device, dtype=torch.float32)
+    high = ((packed >> 4) & 0xF).to(torch.long)
+    low = (packed & 0xF).to(torch.long)
+    decoded = torch.stack([table[high], table[low]], dim=-1).reshape(rows, cols)
+    return decoded
+
+
+class Challenge(ChallengeBase):
+    def __init__(self):
+        super().__init__(
+            name="NVFP4 Matrix Multiplication",
+            atol=1e-01,
+            rtol=5e-02,
+            num_gpus=1,
+            access_tier="free",
+        )
+
+    def reference_impl(
+        self,
+        x_q: torch.Tensor,
+        x_scales: torch.Tensor,
+        w_q: torch.Tensor,
+        w_scales: torch.Tensor,
+        alpha: float,
+        y: torch.Tensor,
+        M: int,
+        N: int,
+        K: int,
+    ):
+        assert K % BLOCK_SIZE == 0, "K must be divisible by 16 (NVFP4 block size)"
+        assert x_q.shape == (M, K // 2)
+        assert x_scales.shape == (M, K // BLOCK_SIZE)
+        assert w_q.shape == (N, K // 2)
+        assert w_scales.shape == (N, K // BLOCK_SIZE)
+        assert y.shape == (M, N)
+        assert x_q.dtype == torch.uint8
+        assert w_q.dtype == torch.uint8
+        assert x_scales.dtype == torch.uint8
+        assert w_scales.dtype == torch.uint8
+        assert y.dtype == torch.float16
+        assert x_q.device.type == "cuda"
+        assert x_scales.device.type == "cuda"
+        assert w_q.device.type == "cuda"
+        assert w_scales.device.type == "cuda"
+        assert y.device.type == "cuda"
+
+        # Decode packed FP4 operands to float32.
+        x_fp4 = _decode_fp4_packed(x_q, M, K)
+        w_fp4 = _decode_fp4_packed(w_q, N, K)
+
+        # Decode E4M3 per-block scales to float32.
+        xs = x_scales.view(torch.float8_e4m3fn).float()  # (M, K/16)
+        ws = w_scales.view(torch.float8_e4m3fn).float()  # (N, K/16)
+
+        # Apply per-block scales along the reduction dimension.
+        n_blocks = K // BLOCK_SIZE
+        x_dq = (x_fp4.reshape(M, n_blocks, BLOCK_SIZE) * xs.unsqueeze(-1)).reshape(M, K)
+        w_dq = (w_fp4.reshape(N, n_blocks, BLOCK_SIZE) * ws.unsqueeze(-1)).reshape(N, K)
+
+        # NVFP4 matmul: y = alpha * (x @ w^T), result cast to FP16.
+        out = float(alpha) * (x_dq @ w_dq.T)
+        y.copy_(out.half())
+
+    def get_solve_signature(self) -> Dict[str, tuple]:
+        return {
+            "x_q": (ctypes.POINTER(ctypes.c_uint8), "in"),
+            "x_scales": (ctypes.POINTER(ctypes.c_uint8), "in"),
+            "w_q": (ctypes.POINTER(ctypes.c_uint8), "in"),
+            "w_scales": (ctypes.POINTER(ctypes.c_uint8), "in"),
+            "alpha": (ctypes.c_float, "in"),
+            "y": (ctypes.POINTER(ctypes.c_uint16), "out"),
+            "M": (ctypes.c_int, "in"),
+            "N": (ctypes.c_int, "in"),
+            "K": (ctypes.c_int, "in"),
+        }
+
+    def _make_test_case(self, M: int, N: int, K: int, zero_x: bool = False, alpha: float = 1.0):
+        assert K % BLOCK_SIZE == 0, "K must be divisible by 16"
+        device = "cuda"
+        if zero_x:
+            x_q = torch.zeros(M, K // 2, dtype=torch.uint8, device=device)
+        else:
+            x_q = torch.randint(0, 256, (M, K // 2), dtype=torch.uint8, device=device)
+        w_q = torch.randint(0, 256, (N, K // 2), dtype=torch.uint8, device=device)
+
+        # Positive E4M3 block scales in a representable range.
+        x_scales_f = torch.rand(M, K // BLOCK_SIZE, device=device) * 1.5 + 0.5
+        w_scales_f = torch.rand(N, K // BLOCK_SIZE, device=device) * 1.5 + 0.5
+        x_scales = x_scales_f.to(torch.float8_e4m3fn).view(torch.uint8)
+        w_scales = w_scales_f.to(torch.float8_e4m3fn).view(torch.uint8)
+
+        y = torch.empty(M, N, device=device, dtype=torch.float16)
+        return {
+            "x_q": x_q,
+            "x_scales": x_scales,
+            "w_q": w_q,
+            "w_scales": w_scales,
+            "alpha": alpha,
+            "y": y,
+            "M": M,
+            "N": N,
+            "K": K,
+        }
+
+    def generate_example_test(self) -> Dict[str, Any]:
+        device = "cuda"
+        M, N, K = 2, 2, 16
+
+        # x row 0: sixteen FP4 values = 1.0 (nibble 0x2) -> 8 bytes of 0x22
+        # x row 1: sixteen FP4 values = 0.5 (nibble 0x1) -> 8 bytes of 0x11
+        x_q = torch.tensor(
+            [[0x22] * 8, [0x11] * 8],
+            dtype=torch.uint8,
+            device=device,
+        )
+        # w row 0: sixteen FP4 values = 2.0 (nibble 0x4) -> 8 bytes of 0x44
+        # w row 1: sixteen FP4 values = -1.0 (nibble 0xA) -> 8 bytes of 0xAA
+        w_q = torch.tensor(
+            [[0x44] * 8, [0xAA] * 8],
+            dtype=torch.uint8,
+            device=device,
+        )
+        # All block scales = E4M3 1.0 = 0x38. One block per row (K=16).
+        x_scales = torch.full((M, K // BLOCK_SIZE), 0x38, dtype=torch.uint8, device=device)
+        w_scales = torch.full((N, K // BLOCK_SIZE), 0x38, dtype=torch.uint8, device=device)
+
+        y = torch.empty(M, N, device=device, dtype=torch.float16)
+        return {
+            "x_q": x_q,
+            "x_scales": x_scales,
+            "w_q": w_q,
+            "w_scales": w_scales,
+            "alpha": 1.0,
+            "y": y,
+            "M": M,
+            "N": N,
+            "K": K,
+        }
+
+    def generate_functional_test(self) -> List[Dict[str, Any]]:
+        torch.manual_seed(42)
+        tests = []
+
+        # Edge cases with the minimum K = 16 (one block).
+        tests.append(self._make_test_case(1, 2, 16, zero_x=True))
+        tests.append(self._make_test_case(2, 4, 16))
+        tests.append(self._make_test_case(3, 5, 32))
+
+        # Power-of-2 shapes.
+        tests.append(self._make_test_case(16, 16, 32))
+        tests.append(self._make_test_case(32, 64, 64))
+        tests.append(self._make_test_case(128, 128, 256))
+
+        # Non-power-of-2 leading dims with valid K.
+        tests.append(self._make_test_case(30, 50, 64))
+        tests.append(self._make_test_case(100, 200, 128))
+        tests.append(self._make_test_case(255, 100, 128, alpha=0.125))
+
+        # Realistic attention-projection shape.
+        tests.append(self._make_test_case(512, 1024, 1024, alpha=1.0 / 64.0))
+
+        return tests
+
+    def generate_performance_test(self) -> Dict[str, Any]:
+        torch.manual_seed(0)
+        # Verbatim from AutoKernel Table 5: the row where Triton hit 2,898 TF/s
+        # against CUTLASS's 1,777 TF/s (1.63x speedup). A correct FP4 tensor
+        # core submission at this shape directly validates the paper's claim.
+        return self._make_test_case(2048, 18432, 3072, alpha=1.0 / 64.0)
diff --git a/challenges/hard/86_fp4_matmul/starter/starter.cu b/challenges/hard/86_fp4_matmul/starter/starter.cu
new file mode 100644
index 00000000..8faee62a
--- /dev/null
+++ b/challenges/hard/86_fp4_matmul/starter/starter.cu
@@ -0,0 +1,7 @@
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cstdint>
+
+// x_q, x_scales, w_q, w_scales, y are device pointers
+extern "C" void solve(const uint8_t* x_q, const uint8_t* x_scales, const uint8_t* w_q,
+                      const uint8_t* w_scales, float alpha, __half* y, int M, int N, int K) {}
diff --git a/challenges/hard/86_fp4_matmul/starter/starter.cute.py b/challenges/hard/86_fp4_matmul/starter/starter.cute.py
new file mode 100644
index 00000000..07ca9580
--- /dev/null
+++ b/challenges/hard/86_fp4_matmul/starter/starter.cute.py
@@ -0,0 +1,18 @@
+import cutlass
+import cutlass.cute as cute
+
+
+# x_q, x_scales, w_q, w_scales, y are tensors on the GPU
+@cute.jit
+def solve(
+    x_q: cute.Tensor,
+    x_scales: cute.Tensor,
+    w_q: cute.Tensor,
+    w_scales: cute.Tensor,
+    alpha: cute.Float32,
+    y: cute.Tensor,
+    M: cute.Int32,
+    N: cute.Int32,
+    K: cute.Int32,
+):
+    pass
diff --git a/challenges/hard/86_fp4_matmul/starter/starter.jax.py b/challenges/hard/86_fp4_matmul/starter/starter.jax.py
new file mode 100644
index 00000000..a490abc6
--- /dev/null
+++ b/challenges/hard/86_fp4_matmul/starter/starter.jax.py
@@ -0,0 +1,18 @@
+import jax
+import jax.numpy as jnp
+
+
+# x_q, x_scales, w_q, w_scales are tensors on GPU
+@jax.jit
+def solve(
+    x_q: jax.Array,
+    x_scales: jax.Array,
+    w_q: jax.Array,
+    w_scales: jax.Array,
+    alpha: float,
+    M: int,
+    N: int,
+    K: int,
+) -> jax.Array:
+    # return output tensor directly
+    pass
diff --git a/challenges/hard/86_fp4_matmul/starter/starter.mojo b/challenges/hard/86_fp4_matmul/starter/starter.mojo
new file mode 100644
index 00000000..8e2d31ee
--- /dev/null
+++ b/challenges/hard/86_fp4_matmul/starter/starter.mojo
@@ -0,0 +1,18 @@
+from std.gpu.host import DeviceContext
+from std.memory import UnsafePointer
+
+
+# x_q, x_scales, w_q, w_scales, y are device pointers
+@export
+def solve(
+    x_q: UnsafePointer[UInt8, MutExternalOrigin],
+    x_scales: UnsafePointer[UInt8, MutExternalOrigin],
+    w_q: UnsafePointer[UInt8, MutExternalOrigin],
+    w_scales: UnsafePointer[UInt8, MutExternalOrigin],
+    alpha: Float32,
+    y: UnsafePointer[Float16, MutExternalOrigin],
+    M: Int32,
+    N: Int32,
+    K: Int32,
+) raises:
+    pass
diff --git a/challenges/hard/86_fp4_matmul/starter/starter.pytorch.py b/challenges/hard/86_fp4_matmul/starter/starter.pytorch.py
new file mode 100644
index 00000000..9ca0eeaf
--- /dev/null
+++ b/challenges/hard/86_fp4_matmul/starter/starter.pytorch.py
@@ -0,0 +1,16 @@
+import torch
+
+
+# x_q, x_scales, w_q, w_scales, y are tensors on the GPU
+def solve(
+    x_q: torch.Tensor,
+    x_scales: torch.Tensor,
+    w_q: torch.Tensor,
+    w_scales: torch.Tensor,
+    alpha: float,
+    y: torch.Tensor,
+    M: int,
+    N: int,
+    K: int,
+):
+    pass
diff --git a/challenges/hard/86_fp4_matmul/starter/starter.triton.py b/challenges/hard/86_fp4_matmul/starter/starter.triton.py
new file mode 100644
index 00000000..97d2730a
--- /dev/null
+++ b/challenges/hard/86_fp4_matmul/starter/starter.triton.py
@@ -0,0 +1,18 @@
+import torch
+import triton
+import triton.language as tl
+
+
+# x_q, x_scales, w_q, w_scales, y are tensors on the GPU
+def solve(
+    x_q: torch.Tensor,
+    x_scales: torch.Tensor,
+    w_q: torch.Tensor,
+    w_scales: torch.Tensor,
+    alpha: float,
+    y: torch.Tensor,
+    M: int,
+    N: int,
+    K: int,
+):
+    pass