Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .github/workflows/mlx.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@ jobs:
echo "::endgroup::"

echo "::group::Build test runners"
${CONDA_RUN} cmake --build cmake-out --target op_test_runner multi_thread_test_runner -j$(( $(sysctl -n hw.ncpu) - 1 ))
${CONDA_RUN} cmake --build cmake-out --target op_test_runner multi_thread_test_runner mlx_mutable_state_test -j$(( $(sysctl -n hw.ncpu) - 1 ))
echo "::endgroup::"

echo "::group::Run mutable-state (multi-session) unit test"
./cmake-out/backends/mlx/test/mlx_mutable_state_test
echo "::endgroup::"

echo "::group::Run op unit tests"
Expand Down
6 changes: 4 additions & 2 deletions backends/mlx/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,10 @@ option(ET_MLX_ALLOW_CUSTOM_KERNEL_EXECUTION
ON
)

set(_mlx_backend__srcs ${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXLoader.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXBackend.cpp
set(_mlx_backend__srcs
${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXLoader.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXBackend.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/mlx_mutable_state.cpp
)

add_library(mlxdelegate ${_mlx_backend__srcs})
Expand Down
37 changes: 37 additions & 0 deletions backends/mlx/custom_kernel_ops/gated_delta_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,15 @@ def gated_delta_rule(
B, T_len, Hk, Dk = q.shape
Hv, Dv = v.shape[-2:]

# The Metal kernel maps each v-head to its k-head group
# (hk_idx = hv_idx / (Hv / Hk)); mirror that here so the eager reference also
# supports Hk != Hv (GQA) instead of relying on broadcasting, which requires
# Hk == Hv. repeat_interleave on the head dim reproduces that index mapping.
if Hk != Hv:
q = q.repeat_interleave(Hv // Hk, dim=2)
k = k.repeat_interleave(Hv // Hk, dim=2)
Hk = Hv

s = state.clone()

ys = []
Expand Down Expand Up @@ -101,6 +110,7 @@ def gated_delta_rule_fake(
IntOrVid,
MetalKernelNode,
MultiplyNode,
RepeatNode,
ScanNode,
SubtractNode,
SumNode,
Expand Down Expand Up @@ -450,6 +460,33 @@ def _emit_scan(self, P: MLXProgramBuilder, n: Node) -> Slot:
]
)

# GQA: q/k carry Hk heads but the recurrence state/v have Hv heads. Expand
# q/k to Hv (repeat_interleave on the head axis) so the per-step broadcasts
# match, mirroring the Metal kernel's hk_idx = hv_idx / (Hv / Hk).
Hk = int(self.q_node.meta["val"].shape[-2])
Hv = int(self.v_node.meta["val"].shape[-2])
if Hk != Hv:
rep = IntOrVid.from_literal(Hv // Hk)
_, q_exp = P.make_tmp_slot()
P.emit(
RepeatNode(
x=P.slot_to_tid(q_slot),
out=P.slot_to_tid(q_exp),
repeats=rep,
axis=2,
)
)
_, k_exp = P.make_tmp_slot()
P.emit(
RepeatNode(
x=P.slot_to_tid(k_slot),
out=P.slot_to_tid(k_exp),
repeats=rep,
axis=2,
)
)
q_slot, k_slot = q_exp, k_exp

# Carry needs a writable slot. This is node n's persistent output (the
# mutated state), so it must be a node-owned slot — not a temp slot, whose
# id is reclaimed on tmp_scope exit and would be read as dead by a later
Expand Down
5 changes: 2 additions & 3 deletions backends/mlx/custom_kernel_ops/test/test_gated_delta_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,8 @@ def forward(
g: torch.Tensor, # [B, T, Hv]
beta: torch.Tensor, # [B, T, Hv]
) -> torch.Tensor:
if self.head_repeat > 1:
q = q.repeat_interleave(self.head_repeat, dim=2)
k = k.repeat_interleave(self.head_repeat, dim=2)
# Pass native Hk (no repeat_interleave): the op itself must handle
# GQA head expansion (kernel via hk_idx mapping, scan/eager internally).
return torch.ops.mlx.gated_delta_rule(
q, k, v, g, beta, self.state, use_custom_kernel=self.use_custom_kernel
)
Expand Down
16 changes: 16 additions & 0 deletions backends/mlx/runtime/MLXBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "MLXExecutor.h"
#include "MLXInterpreter.h"
#include "MLXLoader.h"
#include "mlx_mutable_state.h"

#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/core/error.h>
Expand Down Expand Up @@ -277,6 +278,12 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface {
eval(handle->constants.tensors);
}

// Register the handle with the per-session mutable-state manager. This is
// a no-op unless a multi-session owner is active for this load (see
// mlx_mutable_state.h); single-session execution is unaffected.
mutable_state_note_handle(
handle, &handle->program, &handle->mutable_buffers);

} catch (const std::exception& e) {
ET_LOG(Error, "Failed to load MLX program: %s", e.what());
handle->~MLXHandle();
Expand Down Expand Up @@ -366,6 +373,14 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface {
}
}

// Select the active session's mutable buffers (KV cache, recurrent/conv
// state) before running. No-op for single-session handles; weights stay
// shared via ExecutionState::constants.
if (Error rebind_err = mutable_state_rebind_for_execute(h, h->state);
rebind_err != Error::Ok) {
return rebind_err;
}

// Run the MLX program (builds lazy computation graph)
h->interpreter.run(program, h->state, h->stream);

Expand Down Expand Up @@ -431,6 +446,7 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface {
void destroy(DelegateHandle* handle) const override {
std::lock_guard<std::mutex> lock(mlx_global_mutex());
if (handle != nullptr) {
mutable_state_forget_handle(handle);
auto* mlx_handle = static_cast<MLXHandle*>(handle);
mlx_handle->~MLXHandle();
}
Expand Down
Loading
Loading