[Performance] qmv kernel: non-linear cost step at M=3 for large MLP shapes #3553

@AirRunner

mx.quantized_matmul(transpose=True) shows a discontinuous cost increase at M=3 for the large asymmetric MLP shapes. M=2 costs nearly the same as M=1, but M=3 costs 28-37% more than M=1. The step does not appear for the square shape (5120→5120).

Per-op microbench (group_size=64, bits=4)

import mlx.core as mx
import time, statistics

def bench_qmm(K, N, M, n_iters=200, warmup=30):
    # Random (N, K) fp16 weight, quantized to 4 bits with group size 64.
    w_f = mx.random.normal((N, K), dtype=mx.float32).astype(mx.float16)
    w_q, scales, biases = mx.quantize(w_f, group_size=64, bits=4)
    x = mx.random.normal((M, K), dtype=mx.float32).astype(mx.float16)
    # Materialize the inputs up front so lazy setup work is not timed.
    mx.eval(w_q, scales, biases, x)
    times = []
    for i in range(warmup + n_iters):
        t0 = time.perf_counter()
        out = mx.quantized_matmul(x, w_q, scales, biases, transpose=True, group_size=64, bits=4)
        mx.eval(out)  # force execution before reading the clock
        if i >= warmup:
            times.append((time.perf_counter() - t0) * 1000)
    # Mean wall-clock time per call, in milliseconds.
    return statistics.mean(times)

shapes = [
    ('5120 → 5120  (attn q/o)', 5120, 5120),
    ('5120 → 1024  (attn k/v)', 5120, 1024),
    ('5120 → 13824 (MLP gate/up)', 5120, 13824),
    ('13824 → 5120 (MLP down)', 13824, 5120),
]

for label, K, N in shapes:
    ms = [bench_qmm(K, N, M) for M in [1, 2, 3, 4]]
    print(f'{label}: M=1={ms[0]:.3f}ms M=2={ms[1]:.3f}ms M=3={ms[2]:.3f}ms M=4={ms[3]:.3f}ms  (M3/M1={ms[2]/ms[0]:.2f}x)')

Output on M4 Pro:

5120 → 5120  (attn q/o):    M=1=0.218ms  M=2=0.200ms  M=3=0.195ms  M=4=0.233ms  (M3/M1=0.89x)
5120 → 1024  (attn k/v):    M=1=0.112ms  M=2=0.118ms  M=3=0.127ms  M=4=0.136ms  (M3/M1=1.13x)
5120 → 13824 (MLP gate/up): M=1=0.270ms  M=2=0.272ms  M=3=0.346ms  M=4=0.425ms  (M3/M1=1.28x)
13824 → 5120 (MLP down):    M=1=0.277ms  M=2=0.292ms  M=3=0.379ms  M=4=0.477ms  (M3/M1=1.37x)

The step is shape-dependent: it appears on the two MLP shapes (5120→13824 and 13824→5120), not on the square 5120→5120 shape. The small 5120→1024 shape shows only a gradual increase (1.13x at M=3).

Full model forward (Qwen3.6-27B 4-bit, post-prefill 512 tokens)

| M | Time    | Ratio vs M=1 |
|---|---------|--------------|
| 1 | 68.0ms  | 1.00x        |
| 2 | 69.3ms  | 1.02x        |
| 3 | 93.8ms  | 1.38x        |
| 4 | 121.9ms | 1.79x        |
| 5 | 150.2ms | 2.21x        |
| 6 | 179.3ms | 2.64x        |

From M=3 onward, cost grows linearly at about +28ms/step. Non-quantized ops (RMSNorm, softmax) are flat across M=1-4, so the step is entirely in the linear projections.
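For reference, the per-step increments can be recomputed from the full-model timings above. This is only arithmetic on the reported numbers; the helper name `marginal_costs` is made up for illustration.

```python
# Full-model forward times in ms, copied from the table above.
times_ms = {1: 68.0, 2: 69.3, 3: 93.8, 4: 121.9, 5: 150.2, 6: 179.3}

def marginal_costs(times):
    """Return {M: time(M) - time(M-1)} for consecutive M values, rounded to 0.1 ms."""
    ms = sorted(times)
    return {m: round(times[m] - times[prev], 1) for prev, m in zip(ms, ms[1:])}

steps = marginal_costs(times_ms)
for m, delta in steps.items():
    print(f"M={m - 1}->{m}: +{delta:.1f} ms")
```

The increments come out as +1.3, +24.5, +28.1, +28.3 and +29.1 ms, so the second row is nearly free and every row after it costs roughly the same amount.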

Context

Every M tested here is below vector_limit (=10 for M4 Pro at K,N > 4096, from get_qmv_batch_limit using applegpu_g16s), so M=3 dispatches to qmv_fast_impl, not qmm_splitk.

qmv_fast_impl uses grid = (M, ceil(N/8), B): one threadgroup per M row per N-tile. The cause of the discontinuity at M=3 is not clear to me. I tried two things:

  • Fused kernel (process M rows within a single threadgroup to share weight loads): correct output but slower than stock. The M threadgroups run in parallel in the current dispatch, and serializing them loses more than is gained.
  • Lowering vector_limit to route M=3+ to qmm_splitk: significantly worse (M3/M1 goes from 1.3x to 2.7-2.9x), which makes sense since qmm_splitk is designed for larger M.
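To put numbers on the dispatch geometry, here is a small sketch that counts threadgroups for grid = (M, ceil(N/8), B) on the four benchmarked shapes. It assumes B=1 and takes the factor 8 from the grid expression above; `qmv_threadgroups` is a hypothetical helper, not an MLX function.

```python
import math

def qmv_threadgroups(M, N, B=1, cols_per_group=8):
    """Threadgroup count for grid = (M, ceil(N / 8), B). B=1 is an assumption."""
    return M * math.ceil(N / cols_per_group) * B

# (label, K, N) as in the microbenchmark; K does not enter the grid size.
shapes = [
    ("5120 → 5120  (attn q/o)", 5120, 5120),
    ("5120 → 1024  (attn k/v)", 5120, 1024),
    ("5120 → 13824 (MLP gate/up)", 5120, 13824),
    ("13824 → 5120 (MLP down)", 13824, 5120),
]

counts = {
    label: [qmv_threadgroups(M, N) for M in (1, 2, 3, 4)]
    for label, K, N in shapes
}
for label, per_m in counts.items():
    print(f"{label}: threadgroups for M=1..4 = {per_m}")
```

The MLP down shape launches the same number of threadgroups as the square shape (640 per M row), yet only the former shows the step. An explanation based purely on threadgroup count would therefore also have to account for K.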

The step appears systematically across all asymmetric projection shapes.

→ Does this look like a GPU scheduling effect (wave quantization, occupancy change)? Is there a profiling approach that could help narrow it down?
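One cheap check before reaching for a GPU profiler is to report the minimum and median per-call time alongside the mean, which helps separate a real kernel-time step from scheduling noise in the host-side timing. The sketch below is a generic timing helper; `time_op` and its return keys are names I made up, and the workload is a stand-in so the block runs anywhere.

```python
import statistics
import time

def time_op(fn, n_iters=200, warmup=30):
    """Time fn() repeatedly and return min/median/p90/mean in milliseconds."""
    samples = []
    for i in range(warmup + n_iters):
        t0 = time.perf_counter()
        fn()
        dt_ms = (time.perf_counter() - t0) * 1000
        if i >= warmup:  # discard warmup iterations
            samples.append(dt_ms)
    samples.sort()
    return {
        "min": samples[0],
        "median": statistics.median(samples),
        "p90": samples[int(0.9 * (len(samples) - 1))],
        "mean": statistics.mean(samples),
    }

# Stand-in workload; replace with the quantized matmul call when running on MLX.
stats = time_op(lambda: sum(range(1000)), n_iters=50, warmup=5)
print({k: f"{v:.4f}ms" for k, v in stats.items()})
```

With the tensors from bench_qmm, fn would be `lambda: mx.eval(mx.quantized_matmul(x, w_q, scales, biases, transpose=True, group_size=64, bits=4))`. If the M=3 step shows up in the minimum as well as the mean, it is unlikely to be measurement noise.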

Impact

This affects any workload batching M=3-8 tokens together: small-batch server inference (3 concurrent requests), speculative decoding verify passes (draft_length=2 produces M=3), beam search. The +28% jump at M=3 is abrupt and hard to work around at the application level.

PRs #1861 (faster small-batch qmv) and #3120 (split-K qmm) address adjacent regimes, but M=3-9 on large asymmetric shapes seems to remain unaddressed.
