Describe the bug
When you run SDPA with a causal mask, the output at chunk position 0 should depend only on q[0], the past cache, and (k_new[0], v_new[0]). Adding more queries at chunk positions 1..K-1 shouldn't change it.
In practice it does. Sweeping K from 1 to 12 with everything else held constant, out[..., 0, :] is bit-identical for K=1..8 and then jumps by a fixed non-zero delta the moment K hits 9. The delta is exactly the same for K = 9, 10, 11, 12, so it isn't gradual reduction-order noise; it looks like a different kernel taking over once there are more than 8 query rows. Reproduces in bfloat16, float16, and float32.
To Reproduce
import mlx.core as mx
import numpy as np


def causal_mask(P, K, dtype, neg=-1e9):
    # Additive mask of shape (1, 1, K, P + K): the P past columns are fully
    # visible, the K chunk columns are lower-triangular.
    past = mx.zeros((1, 1, K, P), dtype=dtype)
    cols, rows = mx.arange(K).reshape(1, K), mx.arange(K).reshape(K, 1)
    chunk = mx.where(cols <= rows,
                     mx.array(0, dtype=dtype),
                     mx.array(neg, dtype=dtype)).reshape(1, 1, K, K)
    return mx.concatenate([past, chunk], axis=3)


def run(K, q, kn, vn, kp, vp, dtype, scale):
    # Attend the first K chunk queries to [past cache | first K new keys].
    k = mx.concatenate([kp, kn[:, :, :K, :]], axis=2)
    v = mx.concatenate([vp, vn[:, :, :K, :]], axis=2)
    return mx.fast.scaled_dot_product_attention(
        q[:, :, :K, :], k, v, scale=scale, mask=causal_mask(kp.shape[2], K, dtype)
    )


mx.random.seed(0)
B, H, dk, P, K_max = 1, 20, 64, 1, 12
scale = 1.0 / dk**0.5

for dtype in (mx.bfloat16, mx.float16, mx.float32):
    q = mx.random.normal((B, H, K_max, dk)).astype(dtype)
    kn = mx.random.normal((B, H, K_max, dk)).astype(dtype)
    vn = mx.random.normal((B, H, K_max, dk)).astype(dtype)
    kp = mx.random.normal((B, H, P, dk)).astype(dtype)
    vp = mx.random.normal((B, H, P, dk)).astype(dtype)
    print(f"\n{dtype}")
    ref = None
    for K in range(1, K_max + 1):
        out = run(K, q, kn, vn, kp, vp, dtype, scale)
        mx.eval(out)
        # Chunk position 0 for every head, compared against the K=1 result.
        pos0 = np.array(out[0, :, 0, :].astype(mx.float32))
        if ref is None:
            ref = pos0
        d = np.abs(pos0 - ref)
        print(f"  K={K:>2}: max|Δ|={d.max():.6f} mean|Δ|={d.mean():.6f}")
Output on Apple M5 Pro, MLX 0.31.2:
bfloat16: K=1..8 → 0 K=9..12 → max|Δ|=0.007812 (constant)
float16: K=1..8 → 0 K=9..12 → max|Δ|=0.001953 (constant)
float32: K=1..8 → 0 K=9..12 → max|Δ|=0.002021 (constant)
The fact that |Δ| is exactly constant from K=9 onward — and that it shows up in fp32 too — is what makes this look like a hard dispatch threshold rather than ordinary FP noise.
Also reproduces with a boolean mask (True = attend), so it's not specific to additive-mask handling.
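For reference, here is a minimal sketch of the boolean mask layout meant here (True = attend), over columns [P past keys | K chunk keys]. It uses plain Python lists rather than MLX arrays so it runs anywhere; `causal_bool_mask` is a hypothetical helper name, not something from the repro script, and the actual run would pass the same pattern as a boolean array reshaped to (1, 1, K, P + K).

```python
def causal_bool_mask(P, K):
    """Rows are chunk queries, columns are [P past keys | K chunk keys]."""
    # A query at chunk row i sees every past key and chunk keys 0..i.
    return [[j < P or (j - P) <= i for j in range(P + K)] for i in range(K)]


# Visualize the layout for P=1, K=4 ("x" = attend, "." = masked).
for row in causal_bool_mask(P=1, K=4):
    print("".join("x" if allowed else "." for allowed in row))
# xx...
# xxx..
# xxxx.
# xxxxx
```

Row 0 only ever sees the past column and chunk column 0, no matter how large K gets.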
Expected behavior
out[..., 0, :] shouldn't depend on K. Under the causal mask, chunk positions 1..K-1 are invisible from position 0.
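As a sanity check on that expectation, here is a small pure-Python reference (no MLX, float64, a single head) that computes only the row-0 attention output with the same additive-mask convention as the repro. The function name `row0_output`, the shrunken `dk = 8`, and the random inputs are illustrative assumptions, not part of the original script. Masked columns get a weight of `exp(-1e9 - max)`, which underflows to zero, so they should contribute nothing to row 0.

```python
import math
import random


def row0_output(K, q0, keys, values, P, scale, neg=-1e9):
    """Attention output for chunk position 0 over P past keys + K chunk keys."""
    n = P + K
    scores = []
    for j in range(n):
        s = scale * sum(a * b for a, b in zip(q0, keys[j]))
        if j > P:  # chunk columns 1..K-1 are masked for row 0
            s += neg
        scores.append(s)
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    dk = len(values[0])
    return [sum(weights[j] * values[j][d] for j in range(n)) / z
            for d in range(dk)]


random.seed(0)
dk, P, K_max = 8, 1, 12
scale = 1.0 / math.sqrt(dk)
q0 = [random.gauss(0, 1) for _ in range(dk)]
keys = [[random.gauss(0, 1) for _ in range(dk)] for _ in range(P + K_max)]
values = [[random.gauss(0, 1) for _ in range(dk)] for _ in range(P + K_max)]

# K=1 is the per-token reference; larger K only appends masked columns.
ref = row0_output(1, q0, keys[:P + 1], values[:P + 1], P, scale)
max_delta = 0.0
for K in range(1, K_max + 1):
    out = row0_output(K, q0, keys[:P + K], values[:P + K], P, scale)
    delta = max(abs(a - b) for a, b in zip(out, ref))
    max_delta = max(max_delta, delta)
    print(f"K={K:>2}: max|delta|={delta:.3e}")
```

This only illustrates the mathematical invariant; it says nothing about which kernel MLX dispatches to.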
Why it bites
Greedy speculative decoding (and chunked-prefill verifiers in general) relies on this exact invariant: the per-token reference path computes argmax(target_logits) one query at a time, while the verifier batches K+1 queries against the same cache to confirm K draft tokens. When the top-2 logit gap is smaller than the K-dependent drift above, the verifier's argmax flips and the chunked path silently disagrees with the per-token path. This is out of the implementer's hands.
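A toy illustration of that failure mode, with made-up logits: the drift value below reuses the bfloat16 max|Δ| from the sweep purely as an illustrative perturbation size (the reported delta is on the attention output, not directly on the logits, so treating it as a logit shift is an assumption). The helpers `argmax`, `top2_gap`, and `at_risk` are hypothetical names for this sketch.

```python
def argmax(xs):
    return max(range(len(xs)), key=lambda i: xs[i])


def top2_gap(logits):
    first, second = sorted(logits, reverse=True)[:2]
    return first - second


def at_risk(logits, drift):
    # Flags positions whose top-2 gap is smaller than the assumed drift.
    return top2_gap(logits) < drift


drift = 0.0078125                   # illustrative, taken from the bf16 row
per_token = [5.000, 4.996, 1.000]   # made-up logits, per-token path
chunked = list(per_token)
chunked[1] += drift                 # runner-up nudged by the drift

print(argmax(per_token), argmax(chunked))  # 0 1
print(at_risk(per_token, drift))           # True
```

A verifier could use a check like `at_risk` to flag positions where the two paths may disagree, but that only detects the problem rather than fixing it.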
Environment
- MLX 0.31.2
- Apple M5 Pro GPU
- macOS 26.4 (Darwin 25.4)
- Python 3.12
Related
- "mx.fast.scaled_dot_product_attention with sinks produces different results for S=1 vs S=N queries" #3452 (closed, retracted): same S=1 vs S=N divergence, originally attributed to sinks. Closing comment from the reporter: This is the standalone repro for that underlying divergence: no sinks, no model, reproducible in fp32.
- "MLXNN.RoPE produces row-asymmetric output for byte-identical row inputs at [B>1, H, 1, D]" #3496 (open): RoPE row-asymmetry at T=1, an analogous shape of bug on a different op.
- "[BUG] Result of scaled_dot_product_attention does not match PyTorch when S_Q != S_KV, mask="causal"" #2835 (closed): mask="causal" lower-right alignment. Different issue; this report uses an explicit additive (and boolean) mask.