Skip to content

Slow 1D dot product: `mx.inner` compared to a custom kernel #3533

@SuriyaaMM

Description

@SuriyaaMM

Hi, for 1D inputs the dot product `mx.inner` is at least 3 times slower than the following kernel implemented with `mx.fast.metal_kernel`.

import statistics
import time

import mlx
import mlx.core as mx
import numpy as np

def _load_kernel_source(path: str) -> str:
    """Return the Metal kernel body stored in ``path``.

    Raises:
        ValueError: if the file is empty or contains only whitespace.
    """
    # f.read() always returns a str, so only emptiness needs checking
    # (a None check could never fire).
    with open(path, encoding="utf-8") as f:
        text = f.read()
    if not text.strip():
        raise ValueError(f"nothing to read in {path.rsplit('/', 1)[-1]}")
    return text


source = _load_kernel_source("kernels/dot.msl")
source2 = _load_kernel_source("kernels/reduce_add.msl")

# First pass: expects dot.msl to write one partial sum of a[i] * b[i]
# per threadgroup into ``output``.
dot_kernel = mx.fast.metal_kernel(
    name="dot",
    input_names=["a", "b"],
    output_names=["output"],
    source=source,
)

# Follow-up passes: expects reduce_add.msl to write one partial sum of a[i]
# per threadgroup, so repeated launches shrink the partials down to one value.
reduce_add_kernel = mx.fast.metal_kernel(
    name="reduce_add",
    input_names=["a"],
    output_names=["output"],
    source=source2,
)

# The kernels keep one threadgroup-memory slot per simdgroup (smem[32]) and
# assume 32-lane simdgroups, so a threadgroup may hold at most 32 * 32 threads.
_MAX_TG_SIZE = 1024


def _launch_reduction(kernel, inputs, n: int, tg_size: int) -> mx.array:
    """Run one pass of a per-threadgroup reduction kernel over ``n`` elements.

    Launches ``n`` threads in threadgroups of ``tg_size`` threads and returns
    the float32 array of ``ceil(n / tg_size)`` partial sums, one per group.
    """
    num_groups = (n + tg_size - 1) // tg_size
    return kernel(
        inputs=inputs,
        # NOTE(review): ``n`` is passed as a template argument, so each distinct
        # length presumably compiles its own kernel specialization — confirm.
        template=[("F32", mx.float32), ("n", n)],
        grid=(n, 1, 1),
        threadgroup=(tg_size, 1, 1),
        output_shapes=[(num_groups,)],
        output_dtypes=[mx.float32],
        init_value=0,
    )[0]


def __dot(a: mx.array, b: mx.array, *, tg_size: int = 512) -> mx.array:
    """Return the dot product of two 1-D arrays as a 0-d float32 ``mx.array``.

    The first pass multiplies element-wise and reduces each threadgroup to a
    partial sum. Further ``reduce_add`` passes shrink the partials until one
    value remains.

    Args:
        a: 1-D input array.
        b: 1-D input array with the same shape as ``a``.
        tg_size: threads per threadgroup, in ``[1, 1024]``.

    Returns:
        A 0-d float32 array holding ``sum(a[i] * b[i])``; ``0.0`` for empty input.

    Raises:
        ValueError: on mismatched shapes, non-1-D input, or an invalid ``tg_size``.
    """
    if a.shape != b.shape:
        raise ValueError("a and b must have same dimension")
    if len(a.shape) != 1:
        raise ValueError("dot is supported only for 1d array")
    if not 1 <= tg_size <= _MAX_TG_SIZE:
        raise ValueError(f"tg_size must be in [1, {_MAX_TG_SIZE}], got {tg_size}")

    n = a.size
    if n == 0:
        # Without this guard ceil(0 / tg_size) == 0 and the reduction loop
        # below would never reach a single element.
        return mx.array(0.0, dtype=mx.float32)

    partials = _launch_reduction(dot_kernel, [a, b], n, tg_size)
    while partials.size > 1:
        partials = _launch_reduction(
            reduce_add_kernel, [partials], partials.size, tg_size
        )
    return partials[0]

def bench(fn, rounds=20, label="", *, warmup=3):
    """Time ``fn`` and print its median/min/max latency in milliseconds.

    ``fn`` is first called ``warmup`` times untimed, so one-time costs stay
    out of the measurements. It is then called ``rounds`` times under the
    timer. Each result is passed to ``mx.eval`` inside the timed region,
    because MLX builds computations lazily.

    Args:
        fn: zero-argument callable whose result ``mx.eval`` accepts.
        rounds: number of timed calls; must be at least 1.
        label: header line printed above the timings.
        warmup: number of untimed calls made before timing.

    Returns:
        The result of the last timed call.

    Raises:
        ValueError: if ``rounds`` is less than 1.
    """
    if rounds < 1:
        raise ValueError("rounds must be >= 1")

    for _ in range(warmup):
        mx.eval(fn())

    times = []
    r = None
    for _ in range(rounds):
        t0 = time.perf_counter()
        r = fn()
        mx.eval(r)
        times.append(time.perf_counter() - t0)

    # statistics.median averages the two middle samples for an even count
    # instead of silently reporting the upper one.
    median = statistics.median(times)
    best = min(times)
    worst = max(times)
    print(label)
    print(f"median={median*1000:.3f}ms | min={best*1000:.3f}ms | max={worst*1000:.3f}ms")
    return r

# 50M float32 elements per vector (about 200 MB each).
NUM_ELEMENTS = 50_000_000
a = mx.random.normal(shape=(NUM_ELEMENTS,), dtype=mx.float32)
b = mx.random.normal(shape=(NUM_ELEMENTS,), dtype=mx.float32)

# np.asarray avoids a copy when it can. Unlike np.array(..., copy=False),
# it does not raise under NumPy >= 2.0 when a copy turns out to be required.
a_np = np.asarray(a)
b_np = np.asarray(b)

c = bench(lambda: __dot(a, b), label="Custom MLX kernel")
ccc = bench(lambda: mx.inner(a, b), label="MLX native")
cc = bench(lambda: np.dot(a_np, b_np), label="NumPy")

print(f"custom : {float(c)}")
print(f"mx.inner : {float(ccc)}")
print(f"numpy : {float(cc)}")

Kernels

// reduce_add: sums `a` within each threadgroup and writes one partial sum
// per threadgroup to output[threadgroup_position_in_grid.x].
// Body for mx.fast.metal_kernel with input_names=["a"]; template params F32, n.
// Assumes 32-lane simdgroups and at most 32 simdgroups per threadgroup
// (threadgroup size <= 1024), because smem holds one slot per simdgroup.
uint gid = thread_position_in_grid.x;
uint tid = thread_position_in_threadgroup.x;
uint lane = thread_index_in_simdgroup;
uint simd_id = simdgroup_index_in_threadgroup;
uint tg_size = threads_per_threadgroup.x;
uint warps = (tg_size + 31) / 32;  // simdgroups in this threadgroup

// Threads past the end contribute 0 so the last, partial block sums correctly.
F32 c = (gid < uint(n)) ? F32(a[gid]) : F32(0);

threadgroup F32 smem[32];

// Stage 1: reduce within each simdgroup; lane 0 stores that simdgroup's sum.
c = simd_sum(c);

if (lane == 0) smem[simd_id] = c;

threadgroup_barrier(mem_flags::mem_threadgroup);

// Stage 2: the first 32 threads reduce the per-simdgroup sums.
if (tid < 32) {
    c = (tid < warps) ? smem[tid] : F32(0);
    c = simd_sum(c);
    if (tid == 0) output[threadgroup_position_in_grid.x] = c;
}
// dot_product: sums a[i] * b[i] within each threadgroup and writes one partial
// sum per threadgroup to output[threadgroup_position_in_grid.x].
// Body for mx.fast.metal_kernel with input_names=["a", "b"]; template params F32, n.
// Assumes 32-lane simdgroups and at most 32 simdgroups per threadgroup
// (threadgroup size <= 1024), because smem holds one slot per simdgroup.
uint gid = thread_position_in_grid.x;
uint tid = thread_position_in_threadgroup.x;
uint lane = thread_index_in_simdgroup;
uint simd_id = simdgroup_index_in_threadgroup;
uint tg_size = threads_per_threadgroup.x;
uint warps = (tg_size + 31) / 32;  // simdgroups in this threadgroup

// Multiply in F32 so lower-precision inputs are accumulated at template
// precision; threads past the end contribute 0.
F32 c = (gid < uint(n)) ? F32(a[gid]) * F32(b[gid]) : F32(0);

threadgroup F32 smem[32];

// Stage 1: reduce within each simdgroup; lane 0 stores that simdgroup's sum.
c = simd_sum(c);

if (lane == 0) smem[simd_id] = c;

threadgroup_barrier(mem_flags::mem_threadgroup);

// Stage 2: the first 32 threads reduce the per-simdgroup sums.
if (tid < 32) {
    c = (tid < warps) ? smem[tid] : F32(0);
    c = simd_sum(c);
    if (tid == 0) output[threadgroup_position_in_grid.x] = c;
}

Results

Custom MLX kernel
median=2.547ms | min=2.492ms | max=2.580ms
MLX native
median=18.859ms | min=18.741ms | max=19.403ms
NumPy
median=6.327ms | min=6.278ms | max=6.400ms
custom : -1250.99365234375
mx.inner : -1251.0242919921875
numpy : -1251.072509765625

Machine: MacBook Pro (M2)

Can this approach be added to the existing `mx.inner` implementation?

Metadata

Metadata

Assignees

No one assigned

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions