118 changes: 118 additions & 0 deletions challenges/hard/86_fp4_matmul/challenge.html
@@ -0,0 +1,118 @@
<p>
Implement an <strong>NVFP4</strong> matrix multiplication, the low-precision GEMM that powers
state-of-the-art LLM inference on NVIDIA Blackwell GPUs. Both operands are stored in 4-bit
floating point (FP4 E2M1) with per-block FP8 (E4M3) scales along the reduction dimension, plus
a single per-tensor FP32 scale. Given packed activations <code>x_q</code> of shape
<code>M &times; K</code>, packed weights <code>w_q</code> of shape <code>N &times; K</code>,
and their respective block scales, compute
<code>y = alpha &times; (x &times; w<sup>T</sup>)</code> of shape <code>M &times; N</code>
in float16.
</p>

<p>
<strong>FP4 E2M1 encoding:</strong> Each value is a 4-bit code
[sign | exponent (2 bits) | mantissa (1 bit)] representing one of sixteen values:
<code>{&plusmn;0, &plusmn;0.5, &plusmn;1, &plusmn;1.5, &plusmn;2, &plusmn;3, &plusmn;4, &plusmn;6}</code>.
The nibble-to-value mapping is:
</p>
<pre>
0x0 = 0.0 0x8 = -0.0
0x1 = 0.5 0x9 = -0.5
0x2 = 1.0 0xA = -1.0
0x3 = 1.5 0xB = -1.5
0x4 = 2.0 0xC = -2.0
0x5 = 3.0 0xD = -3.0
0x6 = 4.0 0xE = -4.0
0x7 = 6.0 0xF = -6.0
</pre>

<p>
<strong>Packing:</strong> Each byte of <code>x_q</code> / <code>w_q</code> stores two FP4
values. The high nibble (bits 7&ndash;4) holds the even-index value and the low nibble
(bits 3&ndash;0) holds the odd-index value.
</p>
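<p>
To make the byte layout concrete, here is a minimal Python sketch of unpacking one byte
into its two FP4 values, using the table above (an illustration of the convention, not a
required implementation):
</p>
<pre>
FP4_E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
            -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

def unpack_byte(b):
    even = FP4_E2M1[(b >> 4) & 0xF]  # high nibble -> value at even index k
    odd = FP4_E2M1[b & 0xF]          # low nibble  -> value at odd index k + 1
    return even, odd

# unpack_byte(0x22) == (1.0, 1.0); unpack_byte(0xAA) == (-1.0, -1.0)
</pre>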

<p>
<strong>Block scales:</strong> Each contiguous block of <strong>16</strong> FP4 values along
the <code>K</code> dimension shares one E4M3 (float8) scale. The scale tensors
<code>x_scales</code> and <code>w_scales</code> are passed as raw uint8 bytes holding the
E4M3 bit patterns. Dequantization is:
</p>
<pre>
x[m, k] = fp4_decode(x_q_nibble[m, k]) * e4m3_decode(x_scales[m, k // 16])
w[n, k] = fp4_decode(w_q_nibble[n, k]) * e4m3_decode(w_scales[n, k // 16])
y[m, n] = alpha * sum_k x[m, k] * w[n, k]
</pre>
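<p>
E4M3 has a 4-bit exponent (bias 7) and a 3-bit mantissa, with no infinities and a single
NaN encoding per sign (the <code>fn</code> variant). A minimal Python sketch of the scale
decode, assuming the same OCP FP8 semantics that <code>torch.float8_e4m3fn</code>
implements:
</p>
<pre>
def e4m3_decode(b):          # b: uint8 bit pattern
    s = (b >> 7) & 0x1       # sign bit
    e = (b >> 3) & 0xF       # 4 exponent bits, bias 7
    m = b & 0x7              # 3 mantissa bits
    if e == 0:               # subnormal: 2^-6 * (m / 8)
        v = (m / 8.0) * 2.0 ** -6
    elif e == 0xF and m == 0x7:
        v = float("nan")     # e4m3fn reserves only this code for NaN
    else:                    # normal: 2^(e - 7) * (1 + m / 8)
        v = (1.0 + m / 8.0) * 2.0 ** (e - 7)
    return -v if s else v

# e4m3_decode(0x38) == 1.0  (s = 0, e = 7, m = 0  ->  2^0 * 1.0)
</pre>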

<h2>Implementation Requirements</h2>
<ul>
<li>Use only native features (external libraries are not permitted)</li>
<li>The <code>solve</code> function signature must remain unchanged</li>
<li>The final result must be stored in <code>y</code> as float16</li>
</ul>

<h2>Example</h2>
<p>
Input (<code>M</code> = 2, <code>N</code> = 2, <code>K</code> = 16, <code>alpha</code> = 1.0):
</p>
<p>
Packed activations \(x\_q\) (uint8, \(2 \times 8\)) and decoded FP4 values (each row has
sixteen values):
\[
x\_q =
\begin{bmatrix}
\texttt{0x22} & \cdots & \texttt{0x22} \\
\texttt{0x11} & \cdots & \texttt{0x11}
\end{bmatrix}
\;\Rightarrow\;
x_{\text{fp4}} =
\begin{bmatrix}
1.0 & 1.0 & \cdots & 1.0 \\
0.5 & 0.5 & \cdots & 0.5
\end{bmatrix}
\]
Packed weights \(w\_q\) (uint8, \(2 \times 8\)):
\[
w\_q =
\begin{bmatrix}
\texttt{0x44} & \cdots & \texttt{0x44} \\
\texttt{0xAA} & \cdots & \texttt{0xAA}
\end{bmatrix}
\;\Rightarrow\;
w_{\text{fp4}} =
\begin{bmatrix}
2.0 & 2.0 & \cdots & 2.0 \\
-1.0 & -1.0 & \cdots & -1.0
\end{bmatrix}
\]
Block scales (one block per row since <code>K</code> = 16): both
<code>x_scales</code> and <code>w_scales</code> are uint8 \(2 \times 1\) with every byte
equal to <code>0x38</code>, which is the E4M3 bit pattern for 1.0. The dequantized operands
therefore equal the FP4 values above.
</p>
<p>
Output \(y = \alpha \cdot (x \times w^T)\) (float16, \(2 \times 2\)):
\[
\begin{bmatrix}
\sum 1.0 \cdot 2.0 & \sum 1.0 \cdot (-1.0) \\
\sum 0.5 \cdot 2.0 & \sum 0.5 \cdot (-1.0)
\end{bmatrix}
=
\begin{bmatrix}
32.0 & -16.0 \\
16.0 & -8.0
\end{bmatrix}
\]
</p>

<h2>Constraints</h2>
<ul>
<li>1 &le; <code>M</code>, <code>N</code> &le; 32,768</li>
<li>16 &le; <code>K</code> &le; 32,768</li>
<li><code>K</code> is divisible by <strong>16</strong> (the NVFP4 block size)</li>
<li>All tensors are stored in row-major order</li>
<li>Inputs: <code>x_q</code>, <code>w_q</code>, <code>x_scales</code>, <code>w_scales</code>
are <code>uint8</code>; <code>alpha</code> is <code>float32</code></li>
<li>Output: <code>y</code> is <code>float16</code></li>
<li>Performance is measured with <code>M</code> = 2,048, <code>N</code> = 18,432, <code>K</code> = 3,072</li>
</ul>
205 changes: 205 additions & 0 deletions challenges/hard/86_fp4_matmul/challenge.py
@@ -0,0 +1,205 @@
import ctypes
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase

# OCP FP4 E2M1 lookup table: 4-bit unsigned index -> float value.
# Bit layout: [sign | exp1 exp0 | mantissa].
FP4_E2M1_TABLE = [
0.0,
0.5,
1.0,
1.5,
2.0,
3.0,
4.0,
6.0,
-0.0,
-0.5,
-1.0,
-1.5,
-2.0,
-3.0,
-4.0,
-6.0,
]

# NVFP4 block size along the reduction dimension. Each block of 16 FP4
# values shares one E4M3 scale. Matches CUTLASS / qutlass NVFP4 layout.
BLOCK_SIZE = 16


def _decode_fp4_packed(packed: torch.Tensor, rows: int, cols: int) -> torch.Tensor:
"""Decode a (rows, cols/2) uint8 tensor of packed FP4 E2M1 nibbles into
a (rows, cols) float32 tensor. High nibble stores the even-index value,
low nibble stores the odd-index value."""
table = torch.tensor(FP4_E2M1_TABLE, device=packed.device, dtype=torch.float32)
high = ((packed >> 4) & 0xF).to(torch.long)
low = (packed & 0xF).to(torch.long)
decoded = torch.stack([table[high], table[low]], dim=-1).reshape(rows, cols)
return decoded
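
# For reference, a byte 0x22 decodes to the pair (1.0, 1.0): high nibble 0x2
# gives the even-index value, low nibble 0x2 the odd-index value. This is the
# case exercised by generate_example_test below.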


class Challenge(ChallengeBase):
def __init__(self):
super().__init__(
name="NVFP4 Matrix Multiplication",
atol=1e-01,
rtol=5e-02,
num_gpus=1,
access_tier="free",
)

def reference_impl(
self,
x_q: torch.Tensor,
x_scales: torch.Tensor,
w_q: torch.Tensor,
w_scales: torch.Tensor,
alpha: float,
y: torch.Tensor,
M: int,
N: int,
K: int,
):
assert K % BLOCK_SIZE == 0, "K must be divisible by 16 (NVFP4 block size)"
assert x_q.shape == (M, K // 2)
assert x_scales.shape == (M, K // BLOCK_SIZE)
assert w_q.shape == (N, K // 2)
assert w_scales.shape == (N, K // BLOCK_SIZE)
assert y.shape == (M, N)
assert x_q.dtype == torch.uint8
assert w_q.dtype == torch.uint8
assert x_scales.dtype == torch.uint8
assert w_scales.dtype == torch.uint8
assert y.dtype == torch.float16
assert x_q.device.type == "cuda"
assert x_scales.device.type == "cuda"
assert w_q.device.type == "cuda"
assert w_scales.device.type == "cuda"
assert y.device.type == "cuda"

# Decode packed FP4 operands to float32.
x_fp4 = _decode_fp4_packed(x_q, M, K)
w_fp4 = _decode_fp4_packed(w_q, N, K)

# Decode E4M3 per-block scales to float32.
xs = x_scales.view(torch.float8_e4m3fn).float() # (M, K/16)
ws = w_scales.view(torch.float8_e4m3fn).float() # (N, K/16)

# Apply per-block scales along the reduction dimension.
n_blocks = K // BLOCK_SIZE
x_dq = (x_fp4.reshape(M, n_blocks, BLOCK_SIZE) * xs.unsqueeze(-1)).reshape(M, K)
w_dq = (w_fp4.reshape(N, n_blocks, BLOCK_SIZE) * ws.unsqueeze(-1)).reshape(N, K)

# NVFP4 matmul: y = alpha * (x @ w^T), result cast to FP16.
out = float(alpha) * (x_dq @ w_dq.T)
y.copy_(out.half())

def get_solve_signature(self) -> Dict[str, tuple]:
return {
"x_q": (ctypes.POINTER(ctypes.c_uint8), "in"),
"x_scales": (ctypes.POINTER(ctypes.c_uint8), "in"),
"w_q": (ctypes.POINTER(ctypes.c_uint8), "in"),
"w_scales": (ctypes.POINTER(ctypes.c_uint8), "in"),
"alpha": (ctypes.c_float, "in"),
"y": (ctypes.POINTER(ctypes.c_uint16), "out"),
"M": (ctypes.c_int, "in"),
"N": (ctypes.c_int, "in"),
"K": (ctypes.c_int, "in"),
}

def _make_test_case(self, M: int, N: int, K: int, zero_x: bool = False, alpha: float = 1.0):
assert K % BLOCK_SIZE == 0, "K must be divisible by 16"
device = "cuda"
if zero_x:
x_q = torch.zeros(M, K // 2, dtype=torch.uint8, device=device)
else:
x_q = torch.randint(0, 256, (M, K // 2), dtype=torch.uint8, device=device)
w_q = torch.randint(0, 256, (N, K // 2), dtype=torch.uint8, device=device)

# Positive E4M3 block scales in a representable range.
x_scales_f = torch.rand(M, K // BLOCK_SIZE, device=device) * 1.5 + 0.5
w_scales_f = torch.rand(N, K // BLOCK_SIZE, device=device) * 1.5 + 0.5
x_scales = x_scales_f.to(torch.float8_e4m3fn).view(torch.uint8)
w_scales = w_scales_f.to(torch.float8_e4m3fn).view(torch.uint8)

y = torch.empty(M, N, device=device, dtype=torch.float16)
return {
"x_q": x_q,
"x_scales": x_scales,
"w_q": w_q,
"w_scales": w_scales,
"alpha": alpha,
"y": y,
"M": M,
"N": N,
"K": K,
}

def generate_example_test(self) -> Dict[str, Any]:
device = "cuda"
M, N, K = 2, 2, 16

# x row 0: sixteen FP4 values = 1.0 (nibble 0x2) -> 8 bytes of 0x22
# x row 1: sixteen FP4 values = 0.5 (nibble 0x1) -> 8 bytes of 0x11
x_q = torch.tensor(
[[0x22] * 8, [0x11] * 8],
dtype=torch.uint8,
device=device,
)
# w row 0: sixteen FP4 values = 2.0 (nibble 0x4) -> 8 bytes of 0x44
# w row 1: sixteen FP4 values = -1.0 (nibble 0xA) -> 8 bytes of 0xAA
w_q = torch.tensor(
[[0x44] * 8, [0xAA] * 8],
dtype=torch.uint8,
device=device,
)
# All block scales = E4M3 1.0 = 0x38. One block per row (K=16).
x_scales = torch.full((M, K // BLOCK_SIZE), 0x38, dtype=torch.uint8, device=device)
w_scales = torch.full((N, K // BLOCK_SIZE), 0x38, dtype=torch.uint8, device=device)

y = torch.empty(M, N, device=device, dtype=torch.float16)
return {
"x_q": x_q,
"x_scales": x_scales,
"w_q": w_q,
"w_scales": w_scales,
"alpha": 1.0,
"y": y,
"M": M,
"N": N,
"K": K,
}

def generate_functional_test(self) -> List[Dict[str, Any]]:
torch.manual_seed(42)
tests = []

# Edge cases with the minimum K = 16 (one block).
tests.append(self._make_test_case(1, 2, 16, zero_x=True))
tests.append(self._make_test_case(2, 4, 16))
tests.append(self._make_test_case(3, 5, 32))

# Power-of-2 shapes.
tests.append(self._make_test_case(16, 16, 32))
tests.append(self._make_test_case(32, 64, 64))
tests.append(self._make_test_case(128, 128, 256))

# Non-power-of-2 leading dims with valid K.
tests.append(self._make_test_case(30, 50, 64))
tests.append(self._make_test_case(100, 200, 128))
tests.append(self._make_test_case(255, 100, 128, alpha=0.125))

# Realistic attention-projection shape.
tests.append(self._make_test_case(512, 1024, 1024, alpha=1.0 / 64.0))

return tests

def generate_performance_test(self) -> Dict[str, Any]:
torch.manual_seed(0)
# Verbatim from AutoKernel Table 5: the row where Triton hit 2,898 TF/s
# against CUTLASS's 1,777 TF/s (1.63x speedup). A correct FP4 tensor
# core submission at this shape directly validates the paper's claim.
return self._make_test_case(2048, 18432, 3072, alpha=1.0 / 64.0)
7 changes: 7 additions & 0 deletions challenges/hard/86_fp4_matmul/starter/starter.cu
@@ -0,0 +1,7 @@
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <stdint.h>

// x_q, x_scales, w_q, w_scales, y are device pointers
extern "C" void solve(const uint8_t* x_q, const uint8_t* x_scales, const uint8_t* w_q,
const uint8_t* w_scales, float alpha, __half* y, int M, int N, int K) {}
18 changes: 18 additions & 0 deletions challenges/hard/86_fp4_matmul/starter/starter.cute.py
@@ -0,0 +1,18 @@
import cutlass
import cutlass.cute as cute


# x_q, x_scales, w_q, w_scales, y are tensors on the GPU
@cute.jit
def solve(
x_q: cute.Tensor,
x_scales: cute.Tensor,
w_q: cute.Tensor,
w_scales: cute.Tensor,
alpha: cute.Float32,
y: cute.Tensor,
M: cute.Int32,
N: cute.Int32,
K: cute.Int32,
):
pass
18 changes: 18 additions & 0 deletions challenges/hard/86_fp4_matmul/starter/starter.jax.py
@@ -0,0 +1,18 @@
import jax
import jax.numpy as jnp


# x_q, x_scales, w_q, w_scales are tensors on GPU
@jax.jit
def solve(
x_q: jax.Array,
x_scales: jax.Array,
w_q: jax.Array,
w_scales: jax.Array,
alpha: float,
M: int,
N: int,
K: int,
) -> jax.Array:
# return output tensor directly
pass
18 changes: 18 additions & 0 deletions challenges/hard/86_fp4_matmul/starter/starter.mojo
@@ -0,0 +1,18 @@
from gpu.host import DeviceContext
from memory import UnsafePointer


# x_q, x_scales, w_q, w_scales, y are device pointers
@export
def solve(
x_q: UnsafePointer[UInt8, MutExternalOrigin],
x_scales: UnsafePointer[UInt8, MutExternalOrigin],
w_q: UnsafePointer[UInt8, MutExternalOrigin],
w_scales: UnsafePointer[UInt8, MutExternalOrigin],
alpha: Float32,
y: UnsafePointer[Float16, MutExternalOrigin],
M: Int32,
N: Int32,
K: Int32,
) raises:
pass