Hi, for 1-D inputs the dot product `mx.inner` is at least 3 times slower than this kernel implemented with `mx.fast.metal_kernel`.
import statistics
import time

import mlx
import mlx.core as mx
import numpy as np
def _read_kernel_source(path: str) -> str:
    """Return the Metal kernel body stored in the file at *path*.

    Raises:
        ValueError: if the file is empty or contains only whitespace. An empty
            body would build a kernel that never writes its output, so the
            result would silently be the launch's ``init_value``.
    """
    with open(path, encoding="utf-8") as f:
        text = f.read()
    # f.read() returns "" (never None) for an empty file.
    if not text.strip():
        raise ValueError(f"nothing to read in {path.rsplit('/', 1)[-1]}")
    return text


source = _read_kernel_source("kernels/dot.msl")
source2 = _read_kernel_source("kernels/reduce_add.msl")
# Custom-kernel objects built from the .msl bodies read above. Each exposes a
# single output buffer named "output"; dtype, shape and launch geometry are
# chosen per call.
dot_kernel = mx.fast.metal_kernel(
    source=source,
    name="dot",
    output_names=["output"],
    input_names=["a", "b"],
)

# Single-input kernel used for the follow-up passes over the partial sums.
reduce_add_kernel = mx.fast.metal_kernel(
    source=source2,
    name="reduce_add",
    output_names=["output"],
    input_names=["a"],
)
def _run_reduction_pass(kernel, inputs: list, n: int, tg_size: int) -> mx.array:
    """Launch one pass of *kernel* over *n* elements; return its partial sums.

    One thread is dispatched per element, and each threadgroup of ``tg_size``
    threads writes one float32 value, so the result has
    ``ceil(n / tg_size)`` elements.
    """
    num_groups = (n + tg_size - 1) // tg_size
    return kernel(
        inputs=inputs,
        template=[("F32", mx.float32), ("n", n)],
        grid=(n, 1, 1),
        threadgroup=(tg_size, 1, 1),
        output_shapes=[(num_groups,)],
        output_dtypes=[mx.float32],
        init_value=0,
    )[0]


def __dot(a: mx.array, b: mx.array) -> mx.array:
    """Return the dot product of two 1-D arrays using the custom Metal kernels.

    The first pass runs ``dot_kernel`` on ``a`` and ``b``, and each
    threadgroup writes one partial sum. Further ``reduce_add_kernel`` passes
    shrink the partial sums until a single value remains.

    Args:
        a: 1-D input array.
        b: 1-D input array with the same shape as ``a``.

    Returns:
        A 0-d float32 ``mx.array`` (use ``float(...)`` or ``mx.eval`` to
        materialize it). An empty input yields 0.0.

    Raises:
        ValueError: if the shapes differ or the inputs are not 1-D.
    """
    if a.shape != b.shape:
        raise ValueError("a and b must have same dimension")
    if len(a.shape) != 1:
        raise ValueError("dot is supported only for 1d array")
    n = a.size
    if n == 0:
        # A zero-sized grid produces a zero-length output, and the reduction
        # loop below would never reach a single element (infinite loop).
        return mx.array(0.0, dtype=mx.float32)
    tg_size = 512
    partials = _run_reduction_pass(dot_kernel, [a, b], n, tg_size)
    # Each pass divides the length by tg_size (rounding up), so for n >= 1
    # this always terminates at exactly one element.
    while partials.size > 1:
        partials = _run_reduction_pass(
            reduce_add_kernel, [partials], partials.size, tg_size
        )
    return partials[0]
def bench(fn, rounds=20, label="", warmup=3):
    """Time ``fn`` and print median/min/max wall-clock latency in milliseconds.

    Each call's result is passed to ``mx.eval`` so that MLX's lazy
    computation is included in the measured time.

    Args:
        fn: Zero-argument callable to benchmark.
        rounds: Number of timed calls; must be at least 1.
        label: Heading printed before the timing line.
        warmup: Number of untimed calls made first.

    Returns:
        The result of the last timed call.

    Raises:
        ValueError: if ``rounds`` is less than 1.
    """
    if rounds < 1:
        raise ValueError("rounds must be at least 1")
    for _ in range(warmup):
        mx.eval(fn())
    times = []
    for _ in range(rounds):
        # The previous iteration's mx.eval(r) has already waited for its work,
        # so no extra synchronization is needed before starting the clock.
        t0 = time.perf_counter()
        r = fn()
        mx.eval(r)
        times.append(time.perf_counter() - t0)
    # statistics.median averages the two middle samples for an even count.
    median = statistics.median(times)
    best = min(times)
    worst = max(times)
    print(label)
    print(f"median={median*1000:.3f}ms | min={best*1000:.3f}ms | max={worst*1000:.3f}ms")
    return r
if __name__ == "__main__":
    # Benchmark the custom kernel against mx.inner and NumPy on the same data.
    a = mx.random.normal(shape=(50_000_000,), dtype=mx.float32)
    b = mx.random.normal(shape=(50_000_000,), dtype=mx.float32)
    # np.asarray instead of np.array(..., copy=False): under NumPy >= 2,
    # copy=False means "never copy" and raises if a copy is required.
    a_np = np.asarray(a)
    b_np = np.asarray(b)

    custom = bench(lambda: __dot(a, b), label="Custom MLX kernel")
    native = bench(lambda: mx.inner(a, b), label="MLX native")
    reference = bench(lambda: np.dot(a_np, b_np), label="NumPy")

    print(f"custom : {float(custom)}")
    print(f"mx.inner : {float(native)}")
    print(f"numpy : {float(reference)}")
Kernels (the bodies loaded from `kernels/dot.msl` and `kernels/reduce_add.msl`):
// reduce_add pass: this body reads only `a` (no `b`), so it matches the
// single-input reduce_add kernel (input_names=["a"]), not the dot product.
// Each threadgroup sums its slice of `a` and writes one partial sum to
// output[threadgroup_position_in_grid.x]. `F32` and `n` are template
// parameters supplied by the Python launch (template=[("F32", ...), ("n", n)]).
// NOTE(review): assumes a simdgroup width of 32 -- the /32 below, smem[32]
// and the `tid < 32` stage all rely on it, capping threadgroups at 1024 threads.
uint gid = thread_position_in_grid.x;
uint tid = thread_position_in_threadgroup.x;
uint lane = thread_index_in_simdgroup;
uint simd_id = simdgroup_index_in_threadgroup;
uint tg_size = threads_per_threadgroup.x;
// Number of simdgroups in this threadgroup, rounded up.
uint warps = (tg_size + 31) / 32;
// Threads past the end of the input contribute 0 to the sum.
F32 c = (gid < uint(n)) ? a[gid] : 0.0f;
// One slot per simdgroup for that simdgroup's partial sum.
threadgroup F32 smem[32];
// warp reduce: sum `c` across the simdgroup; lane 0 stores the result.
c = simd_sum(c);
if (lane == 0) smem[simd_id] = c;
// Make every simdgroup's partial visible before the final stage reads smem.
threadgroup_barrier(mem_flags::mem_threadgroup);
// Final stage: threads 0..31 load the per-simdgroup partials (0 for unused
// slots), sum them, and thread 0 writes this threadgroup's result.
if (tid < 32) {
c = (tid < warps) ? smem[tid] : 0.0f;
c = simd_sum(c);
if (tid == 0) output[threadgroup_position_in_grid.x] = c;
}
// dot pass: this body multiplies `a[gid] * b[gid]`, so it matches the
// two-input dot kernel (input_names=["a", "b"]), not the reduce_add pass.
// Each threadgroup sums the products over its slice and writes one partial
// sum to output[threadgroup_position_in_grid.x]. `F32` and `n` are template
// parameters supplied by the Python launch (template=[("F32", ...), ("n", n)]).
// NOTE(review): assumes a simdgroup width of 32 -- the /32 below, smem[32]
// and the `tid < 32` stage all rely on it, capping threadgroups at 1024 threads.
uint gid = thread_position_in_grid.x;
uint tid = thread_position_in_threadgroup.x;
uint lane = thread_index_in_simdgroup;
uint simd_id = simdgroup_index_in_threadgroup;
uint tg_size = threads_per_threadgroup.x;
// Number of simdgroups in this threadgroup, rounded up.
uint warps = (tg_size + 31) / 32;
// Element-wise product; threads past the end of the input contribute 0.
F32 c = (gid < uint(n)) ? a[gid] * b[gid] : 0.0f;
// One slot per simdgroup for that simdgroup's partial sum.
threadgroup F32 smem[32];
// warp reduce: sum `c` across the simdgroup; lane 0 stores the result.
c = simd_sum(c);
if (lane == 0) smem[simd_id] = c;
// Make every simdgroup's partial visible before the final stage reads smem.
threadgroup_barrier(mem_flags::mem_threadgroup);
// Final stage: threads 0..31 load the per-simdgroup partials (0 for unused
// slots), sum them, and thread 0 writes this threadgroup's result.
if (tid < 32) {
c = (tid < warps) ? smem[tid] : 0.0f;
c = simd_sum(c);
if (tid == 0) output[threadgroup_position_in_grid.x] = c;
}
Results:
Custom MLX kernel
median=2.547ms | min=2.492ms | max=2.580ms
MLX native
median=18.859ms | min=18.741ms | max=19.403ms
NumPy
median=6.327ms | min=6.278ms | max=6.400ms
custom : -1250.99365234375
mx.inner : -1251.0242919921875
numpy : -1251.072509765625
Machine: MacBook Pro (M2)
Can this be added to the existing function (`mx.inner`)?
Hi, for 1-D inputs the dot product (`mx.inner`) is at least 3 times slower than this kernel implemented using `metal_kernel`.
Kernels:
Results:
Machine: MacBook Pro (M2)
Can this be added to the existing function (`mx.inner`)?