[BUG] mx.fast.scaled_dot_product_attention: position-0 output changes at K=8 → K=9, all dtypes #3573

@JacobLinCool

Describe the bug

With a causal mask, the SDPA output at chunk position 0 should depend only on q[0], the past cache, and (k_new[0], v_new[0]). Appending more queries (with their keys and values) at chunk positions 1..K-1 shouldn't change it.
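For a single head with the repro's additive mask (0 where a key is visible, a large negative value such as -1e9 where it is hidden), the position-0 output can be written out to show why the extra chunk positions should drop out. In this sketch, j ranges over the P past-cache entries followed by the K new ones; the notation (s, m, V) is introduced here for illustration only.

```latex
s_{0,j} = \mathrm{scale}\,(q_0 \cdot k_j) + m_{0,j},
\qquad
m_{0,j} =
\begin{cases}
0 & j \in \mathcal{V} = \{\mathrm{past}_1, \dots, \mathrm{past}_P,\ \mathrm{new}_0\} \\
-10^{9} & j \in \{\mathrm{new}_1, \dots, \mathrm{new}_{K-1}\}
\end{cases}

o_0 = \frac{\sum_{j} e^{s_{0,j}}\, v_j}{\sum_{j} e^{s_{0,j}}}
    = \frac{\sum_{j \in \mathcal{V}} e^{s_{0,j}}\, v_j}{\sum_{j \in \mathcal{V}} e^{s_{0,j}}}
```

Whether -1e9 stays finite or rounds to -inf in a narrower dtype, e^{s_{0,j}} for every hidden key is exactly 0, so o_0 is a function of q_0, the past cache, k_new[0] and v_new[0] alone. What can legitimately differ between kernels is how they round and accumulate the surviving terms, not which terms survive.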

In practice it does. Sweeping K from 1 to 12 with everything else held constant, out[..., 0, :] is bit-identical for K=1..8 and then jumps by a fixed non-zero delta as soon as K reaches 9. The delta is exactly the same for K = 9, 10, 11, 12, so this is not gradual reduction-order noise: it's one alternative kernel path taking over above 8 query rows. It reproduces in bfloat16, float16, and float32.

To Reproduce

```python
import mlx.core as mx
import numpy as np


def causal_mask(P, K, dtype, neg=-1e9):
    # Additive mask of shape (1, 1, K, P + K): every past-cache column is
    # visible (0); inside the new chunk, query i sees keys 0..i only.
    past = mx.zeros((1, 1, K, P), dtype=dtype)
    cols, rows = mx.arange(K).reshape(1, K), mx.arange(K).reshape(K, 1)
    chunk = mx.where(cols <= rows,
                     mx.array(0, dtype=dtype),
                     mx.array(neg, dtype=dtype)).reshape(1, 1, K, K)
    return mx.concatenate([past, chunk], axis=3)


def run(K, q, kn, vn, kp, vp, dtype, scale):
    # First K new queries attend over [past cache; first K new keys/values].
    k = mx.concatenate([kp, kn[:, :, :K, :]], axis=2)
    v = mx.concatenate([vp, vn[:, :, :K, :]], axis=2)
    return mx.fast.scaled_dot_product_attention(
        q[:, :, :K, :], k, v, scale=scale, mask=causal_mask(kp.shape[2], K, dtype)
    )


mx.random.seed(0)
B, H, dk, P, K_max = 1, 20, 64, 1, 12
scale = 1.0 / dk**0.5

for dtype in (mx.bfloat16, mx.float16, mx.float32):
    # Draw all K_max rows once so every K sees identical inputs.
    q  = mx.random.normal((B, H, K_max, dk)).astype(dtype)
    kn = mx.random.normal((B, H, K_max, dk)).astype(dtype)
    vn = mx.random.normal((B, H, K_max, dk)).astype(dtype)
    kp = mx.random.normal((B, H, P, dk)).astype(dtype)
    vp = mx.random.normal((B, H, P, dk)).astype(dtype)
    print(f"\n{dtype}")
    ref = None
    for K in range(1, K_max + 1):
        out = run(K, q, kn, vn, kp, vp, dtype, scale)
        mx.eval(out)
        # Position-0 output for all heads; it should not depend on K.
        pos0 = np.array(out[0, :, 0, :].astype(mx.float32))
        if ref is None:
            ref = pos0  # the K=1 result is the reference
        d = np.abs(pos0 - ref)
        print(f"  K={K:>2}: max|Δ|={d.max():.6f}  mean|Δ|={d.mean():.6f}")
```
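The step-versus-drift signature described in the report can be checked mechanically on the pos0 arrays the repro loop computes. The sketch below is a hypothetical helper (classify_sweep is not part of MLX): it reports the first K whose output is not bit-identical to the K=1 output, and whether every later K differs from K=1 by the exact same array, which is what a single dispatch switch would produce.

```python
import numpy as np


def classify_sweep(pos0_by_k):
    """Classify how the position-0 output changes across a K sweep.

    pos0_by_k maps K -> np.ndarray of position-0 outputs. Returns
    (first_k_that_differs, constant_after): first_k_that_differs is None if
    every K matches the smallest-K reference bit for bit; constant_after is
    True when all K from the first divergence onward share one identical
    delta (a step), and False when the delta keeps changing (drift).
    """
    ks = sorted(pos0_by_k)
    ref = pos0_by_k[ks[0]]
    changed = [k for k in ks if not np.array_equal(pos0_by_k[k], ref)]
    if not changed:
        return None, True
    first = changed[0]
    step = pos0_by_k[first] - ref
    constant_after = all(
        np.array_equal(pos0_by_k[k] - ref, step) for k in ks if k >= first
    )
    return first, constant_after
```

To use it, store `pos0_by_k[K] = pos0` inside the inner loop and call `classify_sweep(pos0_by_k)` once per dtype; going by the reported numbers, it should return (9, True) for all three dtypes.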

Output on Apple M5 Pro, MLX 0.31.2 (per-K lines condensed; Δ is measured against the K=1 output):

bfloat16:  K=1..8  →  0           K=9..12 →  max|Δ|=0.007812  (constant)
float16:   K=1..8  →  0           K=9..12 →  max|Δ|=0.001953  (constant)
float32:   K=1..8  →  0           K=9..12 →  max|Δ|=0.002021  (constant)

The fact that |Δ| is exactly constant from K=9 onward — and that it shows up in fp32 too — is what makes this look like a hard dispatch threshold rather than ordinary FP noise.
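For scale, the reported deltas can be compared with the gap between adjacent representable values (one ulp) in each dtype. This sketch assumes output magnitudes of roughly 1 to 4, which is plausible for attention over standard-normal values but is not stated in the report. ulp_bf16 and ulp are hypothetical helpers; the bfloat16 spacing is derived by hand because NumPy has no bfloat16 dtype.

```python
import math

import numpy as np


def ulp_bf16(x):
    """Gap between adjacent bfloat16 values near x (normal range only).

    bfloat16 has 7 explicit mantissa bits, so the spacing in [2^e, 2^(e+1))
    is 2^(e - 7).
    """
    return 2.0 ** (math.floor(math.log2(abs(x))) - 7)


def ulp(dtype_name, x):
    """One ulp at magnitude x for 'bfloat16', 'float16' or 'float32'."""
    if dtype_name == "bfloat16":
        return ulp_bf16(x)
    return float(np.spacing(np.dtype(dtype_name).type(x)))


# max|Δ| values from the report (K = 9..12 compared with K = 1)
reported = {"bfloat16": 0.007812, "float16": 0.001953, "float32": 0.002021}

for name, delta in reported.items():
    for mag in (1.0, 2.0):
        print(f"{name:>8} |out|~{mag}: {delta / ulp(name, mag):,.1f} ulp")
```

Under that assumption the bfloat16 and float16 deltas (2^-7 and 2^-9 to the printed precision) amount to one or two ulps, while the float32 delta is thousands of float32 ulps, which is consistent with the report's point that the fp32 jump is not last-bit rounding noise.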

Also reproduces with a boolean mask (True = attend), so it's not specific to additive-mask handling.
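The boolean-mask variant is not shown in the report; one way to build an equivalent mask is sketched below with NumPy (causal_mask_bool is a hypothetical helper). It follows the report's convention that True means "attend" and has the same (1, 1, K, P + K) layout as causal_mask.

```python
import numpy as np


def causal_mask_bool(P, K):
    """Boolean mask of shape (1, 1, K, P + K); True means the key is visible.

    Past-cache columns are always visible; inside the new chunk, query i
    sees keys 0..i only.
    """
    past = np.ones((K, P), dtype=bool)
    chunk = np.arange(K)[None, :] <= np.arange(K)[:, None]
    return np.concatenate([past, chunk], axis=1).reshape(1, 1, K, P + K)


m = causal_mask_bool(P=1, K=3)
print(m[0, 0].astype(int))
# [[1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 1]]
```

In run(), passing `mask=mx.array(causal_mask_bool(kp.shape[2], K))` instead of the additive mask should give the boolean-mask version of the sweep.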

Expected behavior

out[..., 0, :] shouldn't depend on K: the causal mask makes chunk positions 1..K-1 invisible from position 0.
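A float64 NumPy reference makes the expected behaviour concrete for one head, and could serve as ground truth when checking which of the two MLX paths (K ≤ 8 or K ≥ 9) is closer to the exact result. This is a sketch: naive_sdpa and causal_mask_np are hypothetical helpers written for illustration, not MLX APIs, and the inputs are fresh random arrays rather than the repro's tensors.

```python
import numpy as np


def naive_sdpa(q, k, v, scale, mask):
    """Reference attention: softmax(scale * q @ k.T + mask) @ v.

    q: (Lq, d); k, v: (Lk, d); mask: additive (Lq, Lk) array.
    """
    scores = scale * (q @ k.T) + mask
    scores = scores - scores.max(axis=-1, keepdims=True)  # stable softmax
    w = np.exp(scores)
    w /= w.sum(axis=-1, keepdims=True)
    return w @ v


def causal_mask_np(P, K, neg=-1e9):
    """Additive mask like the repro's: past visible, chunk causal."""
    chunk = np.where(np.arange(K)[None, :] <= np.arange(K)[:, None], 0.0, neg)
    return np.concatenate([np.zeros((K, P)), chunk], axis=1)


rng = np.random.default_rng(0)
d, P, K_max = 64, 1, 12
scale = 1.0 / np.sqrt(d)
q, kn, vn = (rng.standard_normal((K_max, d)) for _ in range(3))
kp, vp = (rng.standard_normal((P, d)) for _ in range(2))

pos0 = {}
for K in range(1, K_max + 1):
    k = np.concatenate([kp, kn[:K]])
    v = np.concatenate([vp, vn[:K]])
    pos0[K] = naive_sdpa(q[:K], k, v, scale, causal_mask_np(P, K))[0]

worst = max(np.abs(pos0[K] - pos0[1]).max() for K in pos0)
print(f"max |out_K[0] - out_1[0]| over K=1..{K_max}: {worst:.3e}")
```

In float64 the position-0 row agrees across all K to within rounding, which is the invariant the report expects the fused kernel to preserve.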

Why it bites

Greedy speculative decoding (and chunked-prefill verifiers in general) relies on exactly this invariant: the per-token reference path computes argmax(target_logits) one query at a time, while the verifier batches K+1 queries against the same cache to confirm K draft tokens. When the top-2 logit gap is smaller than the K-dependent drift above, the verifier's argmax can flip, and the chunked path silently disagrees with the per-token path. This is out of the implementer's hands.
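The failure mode can be sketched with a minimal greedy acceptance rule: draft token i is accepted only if it equals argmax of the verifier's logits at position i and all earlier drafts were accepted. The logits and the 0.002 drift below are made-up numbers for illustration; in a real model the drift enters through the attention output and reaches the logits only after the rest of the network. greedy_accept is a hypothetical helper, not an MLX API.

```python
import numpy as np


def greedy_accept(draft_tokens, verifier_logits):
    """Number of draft tokens accepted by greedy verification.

    verifier_logits has one row per chunk position (K + 1 rows for K drafts);
    row i is compared against draft token i.
    """
    accepted = 0
    for i, tok in enumerate(draft_tokens):
        if int(np.argmax(verifier_logits[i])) != tok:
            break
        accepted += 1
    return accepted


# Made-up logits with a top-2 gap of 0.001 at position 0.
per_token = np.array([[2.000, 2.001, 0.5],
                      [0.1, 0.2, 0.3]])
draft = [int(np.argmax(per_token[0]))]  # draft matches the per-token path: token 1

# Hypothetical K-dependent drift of 0.002 on token 0's logit.
chunked = per_token + np.array([[0.002, 0.0, 0.0],
                                [0.0, 0.0, 0.0]])

print(greedy_accept(draft, per_token))  # prints 1: per-token path accepts
print(greedy_accept(draft, chunked))    # prints 0: batched verifier rejects
```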

Environment

  • MLX 0.31.2
  • Apple M5 Pro GPU
  • macOS 26.4 (Darwin 25.4)
  • Python 3.12
