From 61f1b588965d1d0b5a202f585c9522b9e7f170b3 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 17 Jun 2026 18:05:35 -0700 Subject: [PATCH 1/9] up --- Makefile | 12 ++- examples/models/qwen3_5_moe/CMakeLists.txt | 11 ++- examples/models/qwen3_5_moe/CMakePresets.json | 31 +++++++ examples/models/qwen3_5_moe/README.md | 33 ++++++- .../models/qwen3_5_moe/qwen35_moe_engine.cpp | 91 +++++++++++++------ 5 files changed, 147 insertions(+), 31 deletions(-) diff --git a/Makefile b/Makefile index c93085115aa..552bbf89bd7 100644 --- a/Makefile +++ b/Makefile @@ -91,7 +91,7 @@ # # ============================================================================== -.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu lfm_2_5-mlx llava-cpu gemma3-cuda gemma3-cpu gemma4_31b-cuda gemma4_31b-mlx qwen3_5_moe-cuda qwen3_5_moe-metal clean help +.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu lfm_2_5-mlx llava-cpu gemma3-cuda gemma3-cpu gemma4_31b-cuda gemma4_31b-mlx qwen3_5_moe-cuda qwen3_5_moe-metal qwen3_5_moe-mlx clean help help: @echo "This Makefile adds targets to build runners for various models on various backends. Run using \`make \`. 
Available targets:" @@ -131,6 +131,7 @@ help: @echo " gemma4_31b-mlx - Build Gemma 4 31B runner with MLX backend" @echo " qwen3_5_moe-cuda - Build Qwen3.5 MoE runner with CUDA backend" @echo " qwen3_5_moe-metal - Build Qwen3.5 MoE runner with Metal backend" + @echo " qwen3_5_moe-mlx - Build Qwen3.5 MoE runner with MLX backend" @echo " clean - Clean build artifacts" voxtral-cuda: @@ -467,6 +468,15 @@ qwen3_5_moe-metal: @echo "✓ Build complete!" @echo " Binary: cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner" +qwen3_5_moe-mlx: + @echo "==> Building and installing ExecuTorch with MLX..." + cmake --workflow --preset mlx-release + @echo "==> Building Qwen3.5 MoE runner with MLX..." + cd examples/models/qwen3_5_moe && cmake --workflow --preset qwen3-5-moe-mlx + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner" + clean: rm -rf cmake-out \ extension/llm/tokenizers/build \ diff --git a/examples/models/qwen3_5_moe/CMakeLists.txt b/examples/models/qwen3_5_moe/CMakeLists.txt index 6a753a538f1..b7a3dce14fa 100644 --- a/examples/models/qwen3_5_moe/CMakeLists.txt +++ b/examples/models/qwen3_5_moe/CMakeLists.txt @@ -54,9 +54,14 @@ elseif(EXECUTORCH_BUILD_CUDA) list(APPEND link_libraries aoti_cuda_backend) executorch_target_link_options_shared_lib(aoti_cuda_backend) add_compile_definitions(EXECUTORCH_BUILD_CUDA) +elseif(TARGET mlxdelegate) + list(APPEND link_libraries mlxdelegate mlx) + executorch_target_link_options_shared_lib(mlxdelegate) + add_compile_definitions(EXECUTORCH_BUILD_MLX) else() message( - FATAL_ERROR "Set EXECUTORCH_BUILD_CUDA=ON or EXECUTORCH_BUILD_METAL=ON" + FATAL_ERROR + "Set EXECUTORCH_BUILD_CUDA=ON, EXECUTORCH_BUILD_METAL=ON, or EXECUTORCH_BUILD_MLX=ON" ) endif() @@ -82,6 +87,10 @@ if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") target_link_options(qwen3_5_moe_worker PRIVATE "LINKER:-s") endif() +if(TARGET mlxdelegate) + executorch_target_copy_mlx_metallib(qwen3_5_moe_runner) +endif() + 
if(EXECUTORCH_BUILD_CUDA) enable_testing() add_executable( diff --git a/examples/models/qwen3_5_moe/CMakePresets.json b/examples/models/qwen3_5_moe/CMakePresets.json index 36eea8aa3ad..99786f424de 100644 --- a/examples/models/qwen3_5_moe/CMakePresets.json +++ b/examples/models/qwen3_5_moe/CMakePresets.json @@ -36,6 +36,17 @@ "type": "equals", "rhs": "Darwin" } + }, + { + "name": "qwen3-5-moe-mlx", + "displayName": "Qwen3.5 MoE runner (MLX)", + "inherits": ["qwen3-5-moe-base"], + "cacheVariables": {}, + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Darwin" + } } ], "buildPresets": [ @@ -54,6 +65,12 @@ "displayName": "Build Qwen3.5 MoE runner and worker (Metal)", "configurePreset": "qwen3-5-moe-metal", "targets": ["qwen3_5_moe_runner", "qwen3_5_moe_worker"] + }, + { + "name": "qwen3-5-moe-mlx", + "displayName": "Build Qwen3.5 MoE runner (MLX)", + "configurePreset": "qwen3-5-moe-mlx", + "targets": ["qwen3_5_moe_runner"] } ], "workflowPresets": [ @@ -84,6 +101,20 @@ "name": "qwen3-5-moe-metal" } ] + }, + { + "name": "qwen3-5-moe-mlx", + "displayName": "Configure and build Qwen3.5 MoE runner (MLX)", + "steps": [ + { + "type": "configure", + "name": "qwen3-5-moe-mlx" + }, + { + "type": "build", + "name": "qwen3-5-moe-mlx" + } + ] } ] } diff --git a/examples/models/qwen3_5_moe/README.md b/examples/models/qwen3_5_moe/README.md index 4c9b533207b..c275641bfd7 100644 --- a/examples/models/qwen3_5_moe/README.md +++ b/examples/models/qwen3_5_moe/README.md @@ -261,7 +261,38 @@ python export.py \ | `--qembedding` | (none) | Embedding quantization: `8w` | | `--tiny-test` | off | Build tiny model with random weights for CI testing | -### Run (MLX) +### Build (MLX) + +Like the CUDA/Metal builds, the `make` target builds ExecuTorch core with the +MLX backend and the runner binary. Requires Apple Silicon (Darwin). 
+ +```bash +make qwen3_5_moe-mlx +``` + +This builds ExecuTorch with MLX support, then the runner binary at +`cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner` (with `mlx.metallib` +copied next to it). Unlike CUDA, the MLX `.pte` is self-contained — no `.ptd` +data file is produced or needed. + +### Run (MLX, C++ runner) + +The C++ runner needs only the MLX `.pte` and a local HuggingFace +`tokenizer.json`; no `--data_path` is required: + +```bash +cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner \ + --model_path ./qwen35_moe_mlx/model.pte \ + --tokenizer_path ~/models/Qwen3.5-35B-A3B/tokenizer.json \ + --prompt "What is the capital of France?" \ + --max_new_tokens 50 +``` + +The MLX export emits a single dynamic-seq `forward` method; the runner loads and +calls it for both prefill and decode (sampling on host), matching the Python +runner. See the [Run](#run) section above for the full flag list. + +### Run (MLX, Python) + ```bash python -m executorch.examples.models.qwen3_5_moe.run \ diff --git a/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp b/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp index 3c5b2eec439..316b4c7e16a 100644 --- a/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp +++ b/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp @@ -19,6 +19,8 @@ #include #include +#include + #ifdef EXECUTORCH_BUILD_CUDA #include #include @@ -39,6 +41,20 @@ using SizesType = executorch::aten::SizesType; namespace { +#ifdef EXECUTORCH_BUILD_MLX +// The MLX export emits a single dynamic-seq `forward` method that handles both +// prefill (T>=2) and decode (T=1). Mirror gemma4_31b's MLX runner, which loads +// and calls `forward` for both phases. +constexpr const char* kPrefillMethod = "forward"; +constexpr const char* kDecodeMethod = "forward"; +// Prefill is chunked on MLX to cap peak memory and the compiled prefill shape. +constexpr int64_t kPrefillChunkSize = 1024; +#else +// CUDA/Metal exports emit two separate methods. 
+constexpr const char* kPrefillMethod = "prefill"; +constexpr const char* kDecodeMethod = "decode"; +#endif + Result read_sampled_token( const executorch::aten::Tensor& output, float temperature) { @@ -98,8 +114,10 @@ Result> build_qwen_module( } #endif - ET_CHECK_OK_OR_RETURN_ERROR(module->load_method("prefill")); - ET_CHECK_OK_OR_RETURN_ERROR(module->load_method("decode")); + ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(kPrefillMethod)); + if (std::string(kDecodeMethod) != std::string(kPrefillMethod)) { + ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(kDecodeMethod)); + } return module; } @@ -240,34 +258,51 @@ class Qwen35MoESession : public LLMSession { } stop_.store(false, std::memory_order_relaxed); - std::vector token_data(tokens.begin(), tokens.end()); - std::vector pos_data(T); - for (int64_t i = 0; i < T; ++i) { - pos_data[i] = pos_ + i; - } - auto tokens_tensor = from_blob( - token_data.data(), - {1, static_cast(T)}, - executorch::aten::ScalarType::Long); - auto pos_tensor = from_blob( - pos_data.data(), - {static_cast(T)}, - executorch::aten::ScalarType::Long); - - const char* method = (T >= 2) ? "prefill" : "decode"; - std::vector inputs; - inputs.push_back(tokens_tensor); - inputs.push_back(pos_tensor); + + // On MLX, run prefill in fixed-size chunks (caps peak memory and the + // compiled prefill shape). Other backends prefill the whole prompt in one + // pass. Only the final chunk's sampled token is kept; the recurrence/KV + // state from earlier chunks persists via pos_ advancement. 
+#ifdef EXECUTORCH_BUILD_MLX + const int64_t chunk_size = kPrefillChunkSize; +#else + const int64_t chunk_size = T; +#endif + + uint64_t sampled_token = 0; + for (int64_t off = 0; off < T; off += chunk_size) { + const int64_t len = std::min(chunk_size, T - off); + std::vector token_data( + tokens.begin() + off, tokens.begin() + off + len); + std::vector pos_data(len); + for (int64_t i = 0; i < len; ++i) { + pos_data[i] = pos_ + i; + } + auto tokens_tensor = from_blob( + token_data.data(), + {1, static_cast(len)}, + executorch::aten::ScalarType::Long); + auto pos_tensor = from_blob( + pos_data.data(), + {static_cast(len)}, + executorch::aten::ScalarType::Long); + + const char* method = (len >= 2) ? kPrefillMethod : kDecodeMethod; + std::vector inputs; + inputs.push_back(tokens_tensor); + inputs.push_back(pos_tensor); #ifdef EXECUTORCH_BUILD_CUDA - set_temp(first_token_temp); - inputs.push_back(EValue(temp_tensor_)); + set_temp(first_token_temp); + inputs.push_back(EValue(temp_tensor_)); #endif - auto sampled = - run_locked(method, inputs, first_token_temp, /*sync_after=*/true); - ET_CHECK_OK_OR_RETURN_ERROR(sampled.error()); - pending_ = sampled.get(); + auto sampled = + run_locked(method, inputs, first_token_temp, /*sync_after=*/true); + ET_CHECK_OK_OR_RETURN_ERROR(sampled.error()); + sampled_token = sampled.get(); + pos_ += len; + } + pending_ = sampled_token; prev_decode_token_.reset(); - pos_ += T; return Error::Ok; } @@ -334,7 +369,7 @@ class Qwen35MoESession : public LLMSession { inputs.push_back(EValue(temp_tensor_)); #endif auto sampled = - run_locked("decode", inputs, temperature_, /*sync_after=*/false); + run_locked(kDecodeMethod, inputs, temperature_, /*sync_after=*/false); ET_CHECK_OK_OR_RETURN_ERROR(sampled.error()); pending_ = sampled.get(); prev_decode_token_ = token; From d30dc293900a552ade668671caa10689f24f2a93 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Thu, 18 Jun 2026 11:02:10 -0700 Subject: [PATCH 2/9] up --- .github/workflows/mlx.yml | 17 
+++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index acc6b4840cf..54a128be4e6 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -161,6 +161,23 @@ jobs: fi echo "::endgroup::" + echo "::group::Build Qwen 3.5 MoE MLX C++ runner" + # Validates the MLX C++ runner build wiring (compile + link + metallib). + # The tiny model has no compatible tokenizer (vocab 256, random weights), + # so we don't run C++ inference here — only confirm it builds. + ${CONDA_RUN} make qwen3_5_moe-mlx + RUNNER=cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner + if [ ! -x "$RUNNER" ]; then + echo "Failed: runner not found at $RUNNER" + exit 1 + fi + if [ ! -f "$(dirname "$RUNNER")/mlx.metallib" ]; then + echo "Failed: mlx.metallib not copied next to runner" + exit 1 + fi + echo "Success: built $RUNNER" + echo "::endgroup::" + backend-tester: needs: run-decision if: | From 359bf10fc597e436970127b003d0753a24a62972 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Thu, 18 Jun 2026 11:17:37 -0700 Subject: [PATCH 3/9] up --- examples/models/qwen3_5_moe/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/models/qwen3_5_moe/CMakeLists.txt b/examples/models/qwen3_5_moe/CMakeLists.txt index b7a3dce14fa..726657a3779 100644 --- a/examples/models/qwen3_5_moe/CMakeLists.txt +++ b/examples/models/qwen3_5_moe/CMakeLists.txt @@ -61,7 +61,7 @@ elseif(TARGET mlxdelegate) else() message( FATAL_ERROR - "Set EXECUTORCH_BUILD_CUDA=ON, EXECUTORCH_BUILD_METAL=ON, or EXECUTORCH_BUILD_MLX=ON" + "Set EXECUTORCH_BUILD_CUDA=ON, EXECUTORCH_BUILD_METAL=ON, or EXECUTORCH_BUILD_MLX=ON" ) endif() From 5af5b195f2409e40e881e7654036a7c01b140d78 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Thu, 18 Jun 2026 11:47:27 -0700 Subject: [PATCH 4/9] up --- .github/workflows/mlx.yml | 6 ++++++ examples/models/qwen3_5_moe/export.py | 6 ++++++ .../models/qwen3_5_moe/qwen35_moe_engine.cpp | 20 
++++++++++++++++++- 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index 54a128be4e6..5a4ccbb4952 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -161,6 +161,12 @@ jobs: fi echo "::endgroup::" + echo "::group::Verify chunked == unchunked prefill" + QWEN_TINY_PTE=/tmp/qwen35_moe_mlx_tiny/model.pte \ + ${CONDA_RUN} python -m pytest \ + examples/models/qwen3_5_moe/test_chunked_prefill.py -v + echo "::endgroup::" + echo "::group::Build Qwen 3.5 MoE MLX C++ runner" # Validates the MLX C++ runner build wiring (compile + link + metallib). # The tiny model has no compatible tokenizer (vocab 256, random weights), diff --git a/examples/models/qwen3_5_moe/export.py b/examples/models/qwen3_5_moe/export.py index d7e7d9ca293..566d61e6cfc 100644 --- a/examples/models/qwen3_5_moe/export.py +++ b/examples/models/qwen3_5_moe/export.py @@ -768,10 +768,16 @@ def _export_mlx(model, config, args): gc.collect() print("Lowering to ExecuTorch with MLX backend...") + # Largest prefill chunk the runner may submit in one forward call. The MLX + # runner chunks long prompts to cap peak memory; bound it by the compiled + # dynamic max (max_seq_len - 1) so a chunk can never exceed what `forward` + # was compiled for. 
+ max_prefill_chunk = min(1024, config.max_seq_len - 1) metadata = { "get_max_seq_len": config.max_seq_len, "get_vocab_size": config.vocab_size, "get_n_layers": config.num_hidden_layers, + "get_max_prefill_chunk": max_prefill_chunk, "use_kv_cache": True, "use_sdpa_with_kv_cache": False, "enable_dynamic_shape": True, diff --git a/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp b/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp index 316b4c7e16a..8c7b2e23bc2 100644 --- a/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp +++ b/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp @@ -55,6 +55,10 @@ constexpr const char* kPrefillMethod = "prefill"; constexpr const char* kDecodeMethod = "decode"; #endif +// Constant method exported by the MLX .pte giving the largest prefill chunk the +// `forward` method was compiled for. Read into the metadata map in create(). +constexpr const char* kMaxPrefillChunk = "get_max_prefill_chunk"; + Result read_sampled_token( const executorch::aten::Tensor& output, float temperature) { @@ -264,7 +268,13 @@ class Qwen35MoESession : public LLMSession { // pass. Only the final chunk's sampled token is kept; the recurrence/KV // state from earlier chunks persists via pos_ advancement. #ifdef EXECUTORCH_BUILD_MLX - const int64_t chunk_size = kPrefillChunkSize; + // Chunk size = compiled max prefill chunk from model metadata, falling back + // to the default if the model didn't export it. Clamp to >= 1. 
+ int64_t chunk_size = kPrefillChunkSize; + if (auto it = metadata_.find(kMaxPrefillChunk); + it != metadata_.end() && it->second > 0) { + chunk_size = it->second; + } #else const int64_t chunk_size = T; #endif @@ -492,6 +502,14 @@ Result> Qwen35MoEEngine::create( ET_LOG(Error, "Qwen35MoEEngine: failed to read metadata"); return metadata_result.error(); } +#ifdef EXECUTORCH_BUILD_MLX + // Surface the compiled max prefill chunk (a constant method get_llm_metadata + // doesn't harvest) into the metadata map so the session can chunk long + // prompts within the shape `forward` was compiled for. + if (auto mpc = meta_module->get(kMaxPrefillChunk); mpc.ok()) { + metadata_result.get()[kMaxPrefillChunk] = mpc->toScalar().to(); + } +#endif auto eos_ids = get_eos_ids(tokenizer.get(), meta_module.get()); // This export's metadata doesn't carry the chat-turn EOS (config.json has no // eos_token_id and the .pte exports no get_eos_ids method), so get_eos_ids() From b808e26481e52f0d5df03b2d189dd32f99bd075c Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Thu, 18 Jun 2026 11:58:11 -0700 Subject: [PATCH 5/9] up --- .../models/qwen3_5_moe/qwen35_moe_engine.cpp | 14 +- .../qwen3_5_moe/test_chunked_prefill.py | 121 ++++++++++++++++++ 2 files changed, 130 insertions(+), 5 deletions(-) create mode 100644 examples/models/qwen3_5_moe/test_chunked_prefill.py diff --git a/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp b/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp index 8c7b2e23bc2..713f6211330 100644 --- a/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp +++ b/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp @@ -47,8 +47,6 @@ namespace { // and calls `forward` for both phases. constexpr const char* kPrefillMethod = "forward"; constexpr const char* kDecodeMethod = "forward"; -// Prefill is chunked on MLX to cap peak memory and the compiled prefill shape. -constexpr int64_t kPrefillChunkSize = 1024; #else // CUDA/Metal exports emit two separate methods. 
constexpr const char* kPrefillMethod = "prefill"; @@ -268,9 +266,15 @@ class Qwen35MoESession : public LLMSession { // pass. Only the final chunk's sampled token is kept; the recurrence/KV // state from earlier chunks persists via pos_ advancement. #ifdef EXECUTORCH_BUILD_MLX - // Chunk size = compiled max prefill chunk from model metadata, falling back - // to the default if the model didn't export it. Clamp to >= 1. - int64_t chunk_size = kPrefillChunkSize; + // Chunk size: default to the compiled max (kMaxSeqLen - 1), overridden by + // the exported get_max_prefill_chunk constant when present (mirrors + // gemma4_31b). Falls back to T (single pass) if no metadata is available at + // all. + int64_t chunk_size = T; + if (auto it = metadata_.find(kMaxSeqLen); + it != metadata_.end() && it->second > 1) { + chunk_size = it->second - 1; + } if (auto it = metadata_.find(kMaxPrefillChunk); it != metadata_.end() && it->second > 0) { chunk_size = it->second; diff --git a/examples/models/qwen3_5_moe/test_chunked_prefill.py b/examples/models/qwen3_5_moe/test_chunked_prefill.py new file mode 100644 index 00000000000..0206b9f4649 --- /dev/null +++ b/examples/models/qwen3_5_moe/test_chunked_prefill.py @@ -0,0 +1,121 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Chunked-vs-unchunked prefill equivalence test for the MLX qwen3.5 MoE .pte. + +The MLX C++ runner chunks long prompts and carries the recurrent/conv state and +KV cache across chunk boundaries (qwen35_moe_engine.cpp prefill_tokens). Chunk +boundaries are easy to get subtly wrong, so this test asserts that feeding a +prompt as several sequential `forward` calls produces the same final-position +logits (and same greedy first token) as a single `forward` call. 
+ +It runs against an already-exported tiny MLX .pte (no tokenizer needed: random +token ids). Point it at the .pte via the QWEN_TINY_PTE env var, e.g.: + + python -m executorch.examples.models.qwen3_5_moe.export \ + --tiny-test --backend mlx --qlinear 4w --qlinear-group-size 32 \ + --output-dir /tmp/qwen35_moe_mlx_tiny + QWEN_TINY_PTE=/tmp/qwen35_moe_mlx_tiny/model.pte \ + python -m pytest examples/models/qwen3_5_moe/test_chunked_prefill.py -v + +The test skips (rather than fails) when the .pte env var is unset or the MLX +runtime is unavailable, so it is a no-op on non-MLX machines. +""" + +import os +import unittest + +import torch + +PTE_ENV = "QWEN_TINY_PTE" + + +def _load_forward(pte_path): + """Load a fresh program instance so mutable state starts zeroed.""" + from executorch.runtime import Runtime, Verification + + runtime = Runtime.get() + program = runtime.load_program(pte_path, verification=Verification.Minimal) + return program, program.load_method("forward") + + +def _scalar_metadata(program, name, default): + try: + result = program.load_method(name).execute([]) + except Exception: + return default + v = result[0] + return int(v) if isinstance(v, int) else int(v.item()) + + +def _last_logits(outputs): + # forward returns logits shaped (1, T, vocab); take the final position. + return outputs[0][0, -1, :] + + +class TestChunkedPrefill(unittest.TestCase): + def setUp(self): + self.pte_path = os.environ.get(PTE_ENV) + if not self.pte_path: + self.skipTest(f"{PTE_ENV} not set; export a tiny MLX .pte first") + if not os.path.exists(self.pte_path): + self.skipTest(f"{PTE_ENV}={self.pte_path} does not exist") + try: + import executorch.runtime # noqa: F401 + except Exception as e: # pragma: no cover - environment dependent + self.skipTest(f"executorch.runtime unavailable: {e}") + + def test_chunked_prefill_matches_unchunked(self): + # Read shapes from the model's constant methods. 
+ program, _ = _load_forward(self.pte_path) + vocab_size = _scalar_metadata(program, "get_vocab_size", 256) + max_seq_len = _scalar_metadata(program, "get_max_seq_len", 64) + del program + + prompt_len = min(40, max_seq_len - 1) + chunk = 8 + self.assertGreater( + prompt_len, + chunk, + "prompt must exceed chunk size to exercise multiple chunks", + ) + + torch.manual_seed(0) + tokens = torch.randint(0, vocab_size, (1, prompt_len), dtype=torch.long) + + # Unchunked: one forward over the whole prompt (fresh program/state). + _, forward_full = _load_forward(self.pte_path) + pos_full = torch.arange(prompt_len, dtype=torch.long) + logits_full = _last_logits(forward_full.execute([tokens, pos_full])) + + # Chunked: sequential forwards advancing input_pos, carrying state across + # boundaries (fresh program/state). + _, forward_chunk = _load_forward(self.pte_path) + logits_chunk = None + for off in range(0, prompt_len, chunk): + end = min(off + chunk, prompt_len) + chunk_tokens = tokens[:, off:end] + chunk_pos = torch.arange(off, end, dtype=torch.long) + logits_chunk = _last_logits( + forward_chunk.execute([chunk_tokens, chunk_pos]) + ) + + # Same greedy first token, and logits numerically close. 
+ self.assertEqual( + int(torch.argmax(logits_full)), + int(torch.argmax(logits_chunk)), + "chunked prefill produced a different first token than unchunked", + ) + torch.testing.assert_close( + logits_chunk.to(torch.float32), + logits_full.to(torch.float32), + rtol=1e-2, + atol=1e-2, + ) + + +if __name__ == "__main__": + unittest.main() From c806c001de309144837b3cb8d6eaead2b3b9736e Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Thu, 18 Jun 2026 13:50:29 -0700 Subject: [PATCH 6/9] up --- examples/models/qwen3_5_moe/CMakePresets.json | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/models/qwen3_5_moe/CMakePresets.json b/examples/models/qwen3_5_moe/CMakePresets.json index 99786f424de..276c2116148 100644 --- a/examples/models/qwen3_5_moe/CMakePresets.json +++ b/examples/models/qwen3_5_moe/CMakePresets.json @@ -41,7 +41,9 @@ "name": "qwen3-5-moe-mlx", "displayName": "Qwen3.5 MoE runner (MLX)", "inherits": ["qwen3-5-moe-base"], - "cacheVariables": {}, + "cacheVariables": { + "EXECUTORCH_BUILD_MLX": "ON" + }, "condition": { "type": "equals", "lhs": "${hostSystemName}", From 2bb145ab959ee248166f76b26fe88e84dfe19e55 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Thu, 18 Jun 2026 16:47:34 -0700 Subject: [PATCH 7/9] up --- .github/workflows/mlx.yml | 6 +- backends/mlx/CMakeLists.txt | 6 +- backends/mlx/runtime/MLXBackend.cpp | 16 ++ backends/mlx/runtime/mlx_mutable_state.cpp | 268 ++++++++++++++++++ backends/mlx/runtime/mlx_mutable_state.h | 190 +++++++++++++ backends/mlx/test/CMakeLists.txt | 19 ++ backends/mlx/test/mlx_mutable_state_test.cpp | 131 +++++++++ examples/models/qwen3_5_moe/CMakeLists.txt | 1 + examples/models/qwen3_5_moe/CMakePresets.json | 4 +- examples/models/qwen3_5_moe/README.md | 57 ++++ .../models/qwen3_5_moe/qwen35_moe_engine.cpp | 60 ++-- .../models/qwen3_5_moe/qwen35_moe_engine.h | 30 +- 12 files changed, 747 insertions(+), 41 deletions(-) create mode 100644 backends/mlx/runtime/mlx_mutable_state.cpp create mode 100644 
backends/mlx/runtime/mlx_mutable_state.h create mode 100644 backends/mlx/test/mlx_mutable_state_test.cpp diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index 5a4ccbb4952..167ceb7da83 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -66,7 +66,11 @@ jobs: echo "::endgroup::" echo "::group::Build test runners" - ${CONDA_RUN} cmake --build cmake-out --target op_test_runner multi_thread_test_runner -j$(( $(sysctl -n hw.ncpu) - 1 )) + ${CONDA_RUN} cmake --build cmake-out --target op_test_runner multi_thread_test_runner mlx_mutable_state_test -j$(( $(sysctl -n hw.ncpu) - 1 )) + echo "::endgroup::" + + echo "::group::Run mutable-state (multi-session) unit test" + ./cmake-out/backends/mlx/test/mlx_mutable_state_test echo "::endgroup::" echo "::group::Run op unit tests" diff --git a/backends/mlx/CMakeLists.txt b/backends/mlx/CMakeLists.txt index 43968d09b5d..acb96fb1ed9 100644 --- a/backends/mlx/CMakeLists.txt +++ b/backends/mlx/CMakeLists.txt @@ -255,8 +255,10 @@ option(ET_MLX_ALLOW_CUSTOM_KERNEL_EXECUTION ON ) -set(_mlx_backend__srcs ${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXLoader.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXBackend.cpp +set(_mlx_backend__srcs + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXLoader.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXBackend.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/mlx_mutable_state.cpp ) add_library(mlxdelegate ${_mlx_backend__srcs}) diff --git a/backends/mlx/runtime/MLXBackend.cpp b/backends/mlx/runtime/MLXBackend.cpp index 5bd3bf263d1..0dbdec22436 100644 --- a/backends/mlx/runtime/MLXBackend.cpp +++ b/backends/mlx/runtime/MLXBackend.cpp @@ -9,6 +9,7 @@ #include "MLXExecutor.h" #include "MLXInterpreter.h" #include "MLXLoader.h" +#include "mlx_mutable_state.h" #include #include @@ -277,6 +278,12 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface { eval(handle->constants.tensors); } + // Register the handle with the per-session mutable-state manager. 
This is + // a no-op unless a multi-session owner is active for this load (see + // mlx_mutable_state.h); single-session execution is unaffected. + mutable_state_note_handle( + handle, &handle->program, &handle->mutable_buffers); + } catch (const std::exception& e) { ET_LOG(Error, "Failed to load MLX program: %s", e.what()); handle->~MLXHandle(); @@ -366,6 +373,14 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface { } } + // Select the active session's mutable buffers (KV cache, recurrent/conv + // state) before running. No-op for single-session handles; weights stay + // shared via ExecutionState::constants. + if (Error rebind_err = mutable_state_rebind_for_execute(h, h->state); + rebind_err != Error::Ok) { + return rebind_err; + } + // Run the MLX program (builds lazy computation graph) h->interpreter.run(program, h->state, h->stream); @@ -431,6 +446,7 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface { void destroy(DelegateHandle* handle) const override { std::lock_guard lock(mlx_global_mutex()); if (handle != nullptr) { + mutable_state_forget_handle(handle); auto* mlx_handle = static_cast(handle); mlx_handle->~MLXHandle(); } diff --git a/backends/mlx/runtime/mlx_mutable_state.cpp b/backends/mlx/runtime/mlx_mutable_state.cpp new file mode 100644 index 00000000000..429f3fea5da --- /dev/null +++ b/backends/mlx/runtime/mlx_mutable_state.cpp @@ -0,0 +1,268 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "mlx_mutable_state.h" + +#include "MLXExecutor.h" +#include "MLXLoader.h" + +#include + +#include +#include + +namespace executorch { +namespace backends { +namespace mlx { + +using ::executorch::runtime::Error; +using ::executorch::runtime::Result; + +namespace { + +struct HandleInfo { + const MLXProgram* program{nullptr}; + MutableBufferData* default_buffers{nullptr}; +}; + +struct Context { + // Delegate handles associated with this loaded program (one per loaded + // method). Keyed by opaque MLXHandle pointer. + std::unordered_map handles; + // Per-session mutable buffers: token -> (handle -> buffers). Allocated lazily + // on first execute for a given (session, handle). + std::unordered_map> + sessions; + int next_token{0}; +}; + +// Process-global registry. MLX serializes execution via its own global mutex and +// the engine serializes per session, but the registry itself is guarded here so +// context/session lifecycle calls from other threads are safe. +std::mutex& registry_mutex() { + static std::mutex m; + return m; +} + +std::unordered_map& contexts() { + static std::unordered_map c; + return c; +} + +std::unordered_map& handle_ctx() { + static std::unordered_map m; + return m; +} + +MutableStateContext g_next_ctx = 1; // 0 is reserved as invalid. + +// Thread-local load scope and active (ctx, session) selection. 
+thread_local MutableStateContext tl_loading_ctx = kInvalidMutableContext; +thread_local MutableStateContext tl_active_ctx = kInvalidMutableContext; +thread_local int tl_active_token = kNoMutableSession; + +} // namespace + +namespace detail { + +MutableStateContext mutable_state_create_context() { + std::lock_guard g(registry_mutex()); + MutableStateContext ctx = g_next_ctx++; + if (ctx == kInvalidMutableContext) { + ctx = g_next_ctx++; + } + contexts()[ctx]; + return ctx; +} + +void mutable_state_destroy_context(MutableStateContext ctx) { + std::lock_guard g(registry_mutex()); + auto it = contexts().find(ctx); + if (it == contexts().end()) { + return; + } + for (const auto& kv : it->second.handles) { + handle_ctx().erase(kv.first); + } + contexts().erase(it); +} + +void mutable_state_begin_load(MutableStateContext ctx) { + tl_loading_ctx = ctx; +} + +void mutable_state_end_load() { + tl_loading_ctx = kInvalidMutableContext; +} + +bool mutable_state_available(MutableStateContext ctx) { + if (ctx == kInvalidMutableContext) { + return false; + } + std::lock_guard g(registry_mutex()); + return contexts().count(ctx) != 0; +} + +int64_t mutable_state_bytes_per_session(MutableStateContext ctx) { + std::lock_guard g(registry_mutex()); + auto it = contexts().find(ctx); + if (it == contexts().end()) { + return 0; + } + int64_t total = 0; + for (const auto& kv : it->second.handles) { + const MutableBufferData* bufs = kv.second.default_buffers; + if (bufs == nullptr) { + continue; + } + for (const auto& t : bufs->tensors) { + if (t.has_value()) { + total += static_cast(t->nbytes()); + } + } + } + return total; +} + +Error mutable_state_validate_coverage(MutableStateContext ctx) { + // MLX clones all mutable buffers by tid; there is no FQN coverage to verify. 
+ (void)ctx; + return Error::Ok; +} + +Result mutable_state_create_session(MutableStateContext ctx) { + std::lock_guard g(registry_mutex()); + auto it = contexts().find(ctx); + if (it == contexts().end()) { + ET_LOG(Error, "mutable_state_create_session: unknown context %d", ctx); + return Error::InvalidState; + } + int token = it->second.next_token++; + // Per-handle buffers are allocated lazily on first execute. + it->second.sessions[token]; + return token; +} + +void mutable_state_destroy_session(MutableStateContext ctx, int token) { + std::lock_guard g(registry_mutex()); + auto it = contexts().find(ctx); + if (it == contexts().end()) { + return; + } + it->second.sessions.erase(token); +} + +void mutable_state_set_active(MutableStateContext ctx, int token) { + tl_active_ctx = ctx; + tl_active_token = token; +} + +} // namespace detail + +void mutable_state_note_handle( + const void* handle, + const MLXProgram* program, + MutableBufferData* default_buffers) { + if (tl_loading_ctx == kInvalidMutableContext) { + return; // No multi-session owner active during this load: single-session. 
+ } + std::lock_guard g(registry_mutex()); + auto it = contexts().find(tl_loading_ctx); + if (it == contexts().end()) { + return; + } + it->second.handles[handle] = HandleInfo{program, default_buffers}; + handle_ctx()[handle] = tl_loading_ctx; +} + +void mutable_state_forget_handle(const void* handle) { + std::lock_guard g(registry_mutex()); + auto hit = handle_ctx().find(handle); + if (hit == handle_ctx().end()) { + return; + } + auto cit = contexts().find(hit->second); + if (cit != contexts().end()) { + cit->second.handles.erase(handle); + for (auto& session : cit->second.sessions) { + session.second.erase(handle); + } + } + handle_ctx().erase(hit); +} + +Error mutable_state_rebind_for_execute( + const void* handle, + ExecutionState& state) { + std::lock_guard g(registry_mutex()); + auto hit = handle_ctx().find(handle); + if (hit == handle_ctx().end()) { + // Handle was not loaded under a multi-session owner: keep default buffers. + return Error::Ok; + } + auto cit = contexts().find(hit->second); + if (cit == contexts().end()) { + return Error::Ok; + } + Context& ctx = cit->second; + HandleInfo& info = ctx.handles[handle]; + + const bool active_for_this_ctx = + tl_active_token != kNoMutableSession && tl_active_ctx == hit->second; + + if (!active_for_this_ctx) { + // No session selected. Refuse if sessions exist (running against the default + // buffers here would not isolate state from created sessions). 
+ if (!ctx.sessions.empty()) { + ET_LOG( + Error, + "mutable_state_rebind_for_execute: no active session selected but " + "sessions exist for this program"); + return Error::InvalidState; + } + state.mutable_buffers = info.default_buffers; + return Error::Ok; + } + + auto sit = ctx.sessions.find(tl_active_token); + if (sit == ctx.sessions.end()) { + ET_LOG( + Error, + "mutable_state_rebind_for_execute: unknown session token %d", + tl_active_token); + return Error::InvalidState; + } + + auto& per_handle = sit->second; + auto bit = per_handle.find(handle); + if (bit == per_handle.end()) { + // First execute for this (session, handle): allocate fresh zeroed buffers. + // Constants/weights stay shared (ExecutionState::constants is untouched); + // only the mutable buffers are per-session. + MutableBufferData buffers; + try { + load_mutable_buffers(*info.program, buffers); + } catch (const std::exception& e) { + ET_LOG( + Error, + "mutable_state_rebind_for_execute: failed to allocate session " + "buffers: %s", + e.what()); + return Error::MemoryAllocationFailed; + } + bit = per_handle.emplace(handle, std::move(buffers)).first; + } + // unordered_map keeps element pointers stable across rehash, so this remains + // valid for the duration of the execute. + state.mutable_buffers = &bit->second; + return Error::Ok; +} + +} // namespace mlx +} // namespace backends +} // namespace executorch diff --git a/backends/mlx/runtime/mlx_mutable_state.h b/backends/mlx/runtime/mlx_mutable_state.h new file mode 100644 index 00000000000..250b70c5a6b --- /dev/null +++ b/backends/mlx/runtime/mlx_mutable_state.h @@ -0,0 +1,190 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +#include +#include +#include + +// MLX-private support for running one loaded MLX program with multiple isolated +// instances of its mutable buffers (KV cache, conv/recurrent state). Callers +// create sessions and execute with one active session selected. +// +// Unlike the CUDA backend, the MLX runtime owns mutable buffers directly in a +// swappable container (ExecutionState::mutable_buffers is a MutableBufferData*), +// so per-session isolation is a pointer swap to a freshly zero-allocated +// MutableBufferData — no FQN registration / constant-repoint hook is needed. + +namespace executorch { +namespace backends { +namespace mlx { + +// Forward declarations (defined in MLXLoader.h / MLXExecutor.h). +struct MLXProgram; +struct MutableBufferData; +struct ExecutionState; + +// Opaque per-loaded-program context id (0 = invalid). +using MutableStateContext = int; +constexpr MutableStateContext kInvalidMutableContext = 0; + +// Sentinel for execution without per-session rebinding. +constexpr int kNoMutableSession = -1; + +// Implementation entry points. Callers should use MutableStateContextOwner. +namespace detail { + +MutableStateContext mutable_state_create_context(); +void mutable_state_destroy_context(MutableStateContext ctx); +void mutable_state_begin_load(MutableStateContext ctx); +void mutable_state_end_load(); +bool mutable_state_available(MutableStateContext ctx); +int64_t mutable_state_bytes_per_session(MutableStateContext ctx); +::executorch::runtime::Error mutable_state_validate_coverage( + MutableStateContext ctx); +::executorch::runtime::Result mutable_state_create_session( + MutableStateContext ctx); +void mutable_state_destroy_session(MutableStateContext ctx, int token); +void mutable_state_set_active(MutableStateContext ctx, int token); + +} // namespace detail + +// Caller-facing owner for one mutable-state context. 
Mirrors the CUDA backend's +// MutableStateContextOwner so the example engine can use a symmetric API. +class ET_EXPERIMENTAL MutableStateContextOwner final { + class LoadScope final { + public: + explicit LoadScope(MutableStateContext ctx) { + detail::mutable_state_begin_load(ctx); + } + + ~LoadScope() { + detail::mutable_state_end_load(); + } + + LoadScope(const LoadScope&) = delete; + LoadScope& operator=(const LoadScope&) = delete; + }; + + class ActiveSessionScope final { + public: + ActiveSessionScope(MutableStateContext ctx, int token) { + detail::mutable_state_set_active(ctx, token); + } + + ~ActiveSessionScope() { + detail::mutable_state_set_active( + kInvalidMutableContext, kNoMutableSession); + } + + ActiveSessionScope(const ActiveSessionScope&) = delete; + ActiveSessionScope& operator=(const ActiveSessionScope&) = delete; + }; + + public: + MutableStateContextOwner() : ctx_(detail::mutable_state_create_context()) {} + + ~MutableStateContextOwner() { + destroy(); + } + + MutableStateContextOwner(const MutableStateContextOwner&) = delete; + MutableStateContextOwner& operator=(const MutableStateContextOwner&) = delete; + + MutableStateContextOwner(MutableStateContextOwner&& other) noexcept + : ctx_(std::exchange(other.ctx_, kInvalidMutableContext)) {} + + MutableStateContextOwner& operator=( + MutableStateContextOwner&& other) noexcept { + if (this != &other) { + destroy(); + ctx_ = std::exchange(other.ctx_, kInvalidMutableContext); + } + return *this; + } + + MutableStateContext get() const { + return ctx_; + } + + explicit operator bool() const { + return ctx_ != kInvalidMutableContext; + } + + // Associates delegate handles created by `fn` with this context. + template + auto with_load_scope(Fn&& fn) const -> decltype(std::forward(fn)()) { + LoadScope scope(ctx_); + return std::forward(fn)(); + } + + // Selects this context/session while `fn` executes. The caller is responsible + // for serializing execution that touches the same loaded program. 
+ template + auto with_active_session(int token, Fn&& fn) const + -> decltype(std::forward(fn)()) { + ActiveSessionScope scope(ctx_, token); + return std::forward(fn)(); + } + + bool available() const { + return detail::mutable_state_available(ctx_); + } + + int64_t bytes_per_session() const { + return detail::mutable_state_bytes_per_session(ctx_); + } + + ::executorch::runtime::Error validate_coverage() const { + return detail::mutable_state_validate_coverage(ctx_); + } + + ::executorch::runtime::Result create_session() const { + return detail::mutable_state_create_session(ctx_); + } + + void destroy_session(int token) const { + detail::mutable_state_destroy_session(ctx_, token); + } + + private: + void destroy() { + if (ctx_ != kInvalidMutableContext) { + detail::mutable_state_destroy_context(ctx_); + ctx_ = kInvalidMutableContext; + } + } + + MutableStateContext ctx_ = kInvalidMutableContext; +}; + +// --- MLXBackend hooks -------------------------------------------------------- +// +// Called from MLXBackend init/execute/destroy. `handle` is an opaque key (the +// MLXHandle pointer). `program` and `default_buffers` are the handle's own +// program and (init-time) mutable buffers; the manager swaps in per-session +// buffers (or restores the default) by re-pointing `state.mutable_buffers`. 
+ +void mutable_state_note_handle( + const void* handle, + const MLXProgram* program, + MutableBufferData* default_buffers); + +void mutable_state_forget_handle(const void* handle); + +::executorch::runtime::Error mutable_state_rebind_for_execute( + const void* handle, + ExecutionState& state); + +} // namespace mlx +} // namespace backends +} // namespace executorch diff --git a/backends/mlx/test/CMakeLists.txt b/backends/mlx/test/CMakeLists.txt index 39024639d1d..c518b2a232d 100644 --- a/backends/mlx/test/CMakeLists.txt +++ b/backends/mlx/test/CMakeLists.txt @@ -69,3 +69,22 @@ if(EXECUTORCH_MLX_ENABLE_SANITIZERS) multi_thread_test_runner PRIVATE ${_mlx_sanitizer_link_options} ) endif() + +# Per-session mutable-state manager unit test (no model/tokenizer needed). +add_executable(mlx_mutable_state_test mlx_mutable_state_test.cpp) +target_include_directories( + mlx_mutable_state_test PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../runtime +) +target_link_libraries( + mlx_mutable_state_test PRIVATE mlxdelegate mlx executorch_core +) +if(EXECUTORCH_MLX_ENABLE_SANITIZERS) + target_compile_options( + mlx_mutable_state_test PRIVATE -fsanitize=address,undefined + -fno-omit-frame-pointer + ) + target_link_options( + mlx_mutable_state_test PRIVATE ${_mlx_sanitizer_link_options} + ) +endif() +add_test(NAME mlx_mutable_state COMMAND mlx_mutable_state_test) diff --git a/backends/mlx/test/mlx_mutable_state_test.cpp b/backends/mlx/test/mlx_mutable_state_test.cpp new file mode 100644 index 00000000000..ef34962c998 --- /dev/null +++ b/backends/mlx/test/mlx_mutable_state_test.cpp @@ -0,0 +1,131 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Unit test for the MLX per-session mutable-state manager +// (backends/mlx/runtime/mlx_mutable_state.{h,cpp}). 
+// +// Verifies that two sessions created on one loaded program get independent +// mutable buffers: writing into session A's buffer does not leak into session +// B's, and A's value persists across a rebind to B and back. This is the MLX +// analogue of the CUDA "no-bleed" guarantee, exercised directly on the manager +// (no model or tokenizer needed). + +#include "MLXExecutor.h" +#include "MLXLoader.h" +#include "mlx_mutable_state.h" + +#include + +#include + +using namespace ::executorch::backends::mlx; + +namespace { + +int g_failures = 0; + +#define CHECK(cond) \ + do { \ + if (!(cond)) { \ + std::printf("FAIL: %s (line %d)\n", #cond, __LINE__); \ + ++g_failures; \ + } \ + } while (0) + +// Build a minimal program with a single 1-element float mutable buffer at tid 0. +MLXProgram make_program() { + MLXProgram program; + program.num_mutable_buffer_tensors = 1; + program.mutable_buffer_map.push_back(SlotVariant{0, SlotType::TensorSlot}); + TensorMeta meta; + meta.shape.push_back(ShapeDim{/*value=*/1}); + meta.scalar_type = ScalarType::Float; + program.tensor_meta.resize(1); + program.tensor_meta[0] = meta; + return program; +} + +float read0(const MutableBufferData& bufs) { + auto arr = bufs.get(Tid{0}); + ::mlx::core::eval(arr); + return arr.item(); +} + +} // namespace + +int main() { + MLXProgram program = make_program(); + + // Handle's default (init-time) mutable buffers. + MutableBufferData default_bufs; + load_mutable_buffers(program, default_bufs); + + int dummy = 0; + const void* handle = &dummy; + + MutableStateContextOwner owner; + CHECK(static_cast(owner)); + + // Associate the handle with the context (as MLXBackend::init would). 
+ owner.with_load_scope( + [&]() { mutable_state_note_handle(handle, &program, &default_bufs); }); + + CHECK(owner.available()); + CHECK(owner.bytes_per_session() == static_cast(sizeof(float))); + + auto tokA = owner.create_session(); + auto tokB = owner.create_session(); + CHECK(tokA.ok()); + CHECK(tokB.ok()); + CHECK(tokA.get() != tokB.get()); + + ExecutionState state; + + // Session A: rebind, then write a marker (7.0) into its buffer. + owner.with_active_session(tokA.get(), [&]() { + auto err = mutable_state_rebind_for_execute(handle, state); + CHECK(err == ::executorch::runtime::Error::Ok); + state.mutable_buffers->set( + Tid{0}, ::mlx::core::full({1}, 7.0f, ::mlx::core::float32)); + return err; + }); + + // Session B: a fresh rebind must see zeros, not A's marker. + owner.with_active_session(tokB.get(), [&]() { + auto err = mutable_state_rebind_for_execute(handle, state); + CHECK(err == ::executorch::runtime::Error::Ok); + CHECK(read0(*state.mutable_buffers) == 0.0f); + return err; + }); + + // Back to session A: the marker must persist (isolation, no bleed). + owner.with_active_session(tokA.get(), [&]() { + auto err = mutable_state_rebind_for_execute(handle, state); + CHECK(err == ::executorch::runtime::Error::Ok); + CHECK(read0(*state.mutable_buffers) == 7.0f); + return err; + }); + + // With sessions present, executing without an active session is refused + // (prevents running against unmanaged/shared state). 
+ { + auto err = mutable_state_rebind_for_execute(handle, state); + CHECK(err == ::executorch::runtime::Error::InvalidState); + } + + owner.destroy_session(tokA.get()); + owner.destroy_session(tokB.get()); + mutable_state_forget_handle(handle); + + if (g_failures == 0) { + std::printf("OK: mlx_mutable_state isolation test passed\n"); + return 0; + } + std::printf("FAILED: %d checks\n", g_failures); + return 1; +} diff --git a/examples/models/qwen3_5_moe/CMakeLists.txt b/examples/models/qwen3_5_moe/CMakeLists.txt index 726657a3779..aeb97f76ab7 100644 --- a/examples/models/qwen3_5_moe/CMakeLists.txt +++ b/examples/models/qwen3_5_moe/CMakeLists.txt @@ -89,6 +89,7 @@ endif() if(TARGET mlxdelegate) executorch_target_copy_mlx_metallib(qwen3_5_moe_runner) + executorch_target_copy_mlx_metallib(qwen3_5_moe_worker) endif() if(EXECUTORCH_BUILD_CUDA) diff --git a/examples/models/qwen3_5_moe/CMakePresets.json b/examples/models/qwen3_5_moe/CMakePresets.json index 276c2116148..6adcb8aa9cb 100644 --- a/examples/models/qwen3_5_moe/CMakePresets.json +++ b/examples/models/qwen3_5_moe/CMakePresets.json @@ -70,9 +70,9 @@ }, { "name": "qwen3-5-moe-mlx", - "displayName": "Build Qwen3.5 MoE runner (MLX)", + "displayName": "Build Qwen3.5 MoE runner and worker (MLX)", "configurePreset": "qwen3-5-moe-mlx", - "targets": ["qwen3_5_moe_runner"] + "targets": ["qwen3_5_moe_runner", "qwen3_5_moe_worker"] } ], "workflowPresets": [ diff --git a/examples/models/qwen3_5_moe/README.md b/examples/models/qwen3_5_moe/README.md index c275641bfd7..65e3d3c38f1 100644 --- a/examples/models/qwen3_5_moe/README.md +++ b/examples/models/qwen3_5_moe/README.md @@ -302,6 +302,63 @@ python -m executorch.examples.models.qwen3_5_moe.run \ --max-new-tokens 50 ``` +### Serving (MLX, multi-session) + +The MLX worker hosts multiple isolated sessions on **one** weight load, so an +OpenAI-compatible server can serve concurrent conversations without duplicating +the weights.
`make qwen3_5_moe-mlx` builds both `qwen3_5_moe_runner` and +`qwen3_5_moe_worker` (each with `mlx.metallib` copied alongside). + +Start the server (it auto-locates the worker binary): + +```bash +# tokenizer.json the C++ worker opens (resolve from the HF cache) +TOKENIZER_JSON=$(ls "${HF_HOME:-$HOME/.cache/huggingface}"/hub/models--Qwen--Qwen3.5-35B-A3B/snapshots/*/tokenizer.json | head -n1) + +python -m executorch.examples.models.qwen3_5_moe.serve \ + --model-path ./qwen35_moe_mlx/model.pte \ + --tokenizer-path "$TOKENIZER_JSON" \ + --hf-tokenizer Qwen/Qwen3.5-35B-A3B \ + --max-sessions 4 \ + --host 127.0.0.1 \ + --port 8000 +``` + +- `--tokenizer-path` is the raw `tokenizer.json` **file** the worker loads; + `--hf-tokenizer` (HF id or local dir) supplies the chat template on the Python + side. No `--data-path` (the MLX `.pte` is self-contained). +- `--max-sessions N` caps physical sessions on the single weight load. One slot + is reserved for anonymous requests (requests sent without a session id), so + `N` allows `N-1` concurrent named sessions. + +Query it (OpenAI-compatible) from another terminal. Route each conversation to a +session with the `session_id` header: + +```bash +curl http://127.0.0.1:8000/v1/chat/completions \ + -H "Content-Type: application/json" -H "session_id: alice" \ + -d '{"model":"qwen3.5-moe", + "messages":[{"role":"user","content":"What is the capital of France?"}], + "max_tokens":50,"chat_template_kwargs":{"enable_thinking":false}}' +``` + +Endpoints: `GET /health`, `GET /v1/models`, `POST /v1/chat/completions`, +`DELETE /v1/sessions/{id}` (free a session + its slot), `POST /v1/sessions/{id}/reset`. + +Session/memory semantics on MLX: +- This server uses the standard **stateless** OpenAI contract — send the full + `messages` history each request. `session_id` + warm-resume is a KV-cache reuse + optimization for the shared prefix, not server-side memory.
+- Each session adds **one** set of mutable buffers (KV + recurrent/conv state) on + top of the shared weights; per-session cost scales with `max_seq_len`. Weights + are never duplicated. +- KV persists across requests for a live session and is **released on close** + (`DELETE`/reset). Named sessions are not auto-closed — close them to free slots. + MLX's Metal allocator pools freed buffers (so RSS may not shrink immediately), + but they are reused by later sessions, keeping memory bounded. +- Sessions interleave rather than run in parallel (MLX serializes GPU dispatch via + a global mutex). + ### Tiny Model Test For CI or quick pipeline validation (no model download needed): diff --git a/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp b/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp index 713f6211330..6a6f03918b1 100644 --- a/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp +++ b/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp @@ -183,9 +183,9 @@ class Qwen35MoESession : public LLMSession { ::tokenizers::Tokenizer* tokenizer, std::unordered_map metadata, std::unordered_set eos_ids -#ifdef EXECUTORCH_BUILD_CUDA +#ifdef QWEN_HAS_MUTABLE_STATE , - ::executorch::backends::cuda::MutableStateContextOwner* mutable_state, + MutableStateContextOwner* mutable_state, int session_token #endif ) @@ -195,7 +195,7 @@ class Qwen35MoESession : public LLMSession { tokenizer_(tokenizer), metadata_(std::move(metadata)), eos_ids_(std::move(eos_ids)) -#ifdef EXECUTORCH_BUILD_CUDA +#ifdef QWEN_HAS_MUTABLE_STATE , mutable_state_(mutable_state), session_token_(session_token) @@ -212,9 +212,8 @@ class Qwen35MoESession : public LLMSession { } ~Qwen35MoESession() override { -#ifdef EXECUTORCH_BUILD_CUDA - if (mutable_state_ != nullptr && - session_token_ != ::executorch::backends::cuda::kNoMutableSession) { +#ifdef QWEN_HAS_MUTABLE_STATE + if (mutable_state_ != nullptr && session_token_ != kNoMutableSession) { mutable_state_->destroy_session(session_token_); } #endif @@ -425,8 +424,8 
@@ class Qwen35MoESession : public LLMSession { float temperature, bool sync_after) { std::lock_guard guard(*exec_mutex_); -#ifdef EXECUTORCH_BUILD_CUDA - Result> res = mutable_state_ != nullptr +#ifdef QWEN_HAS_MUTABLE_STATE + auto res = mutable_state_ != nullptr ? mutable_state_->with_active_session( session_token_, [&]() { return module_->execute(method, inputs); }) @@ -465,10 +464,11 @@ class Qwen35MoESession : public LLMSession { int64_t decode_pos_data_[1] = {0}; TensorPtr decode_tokens_; TensorPtr decode_pos_; +#ifdef QWEN_HAS_MUTABLE_STATE + MutableStateContextOwner* mutable_state_ = nullptr; + int session_token_ = kNoMutableSession; +#endif #ifdef EXECUTORCH_BUILD_CUDA - ::executorch::backends::cuda::MutableStateContextOwner* mutable_state_ = - nullptr; - int session_token_ = ::executorch::backends::cuda::kNoMutableSession; float temp_val_ = 1e-6f; TensorPtr temp_tensor_; #endif @@ -529,17 +529,17 @@ Result> Qwen35MoEEngine::create( "not stop at end of turn"); } +#ifdef QWEN_HAS_MUTABLE_STATE + std::unique_ptr mutable_state; +#endif #ifdef EXECUTORCH_BUILD_CUDA - std::unique_ptr<::executorch::backends::cuda::MutableStateContextOwner> - mutable_state; if (config.enable_cuda_graph) { ET_LOG( Info, "Qwen35MoEEngine: CUDA graph requested; per-session rebinding disabled " "and serving capacity clamped to 1 session."); } else { - auto candidate = std::make_unique< - ::executorch::backends::cuda::MutableStateContextOwner>(); + auto candidate = std::make_unique(); if (Error e = register_mutable_fqns(meta_module.get(), *candidate); e == Error::Ok) { mutable_state = std::move(candidate); @@ -550,9 +550,13 @@ Result> Qwen35MoEEngine::create( "serving capacity clamped to 1 session."); } } +#elif defined(EXECUTORCH_BUILD_MLX) + // MLX owns mutable buffers directly and clones them per session; no FQN + // registration or coverage check is required. 
+ mutable_state = std::make_unique(); #endif -#ifdef EXECUTORCH_BUILD_CUDA +#ifdef QWEN_HAS_MUTABLE_STATE auto module_res = mutable_state != nullptr ? mutable_state->with_load_scope( [&]() { return build_qwen_module(config); }) @@ -566,16 +570,14 @@ Result> Qwen35MoEEngine::create( std::unique_ptr shared_module = std::move(module_res.get()); bool rebind_available = false; -#ifdef EXECUTORCH_BUILD_CUDA +#ifdef QWEN_HAS_MUTABLE_STATE rebind_available = mutable_state != nullptr && mutable_state->available(); - if (rebind_available) { - if (mutable_state->validate_coverage() != Error::Ok) { - ET_LOG( - Error, - "Qwen35MoEEngine: mutable-buffer coverage check failed; disabling " - "multi-session (capacity clamped to 1)."); - rebind_available = false; - } + if (rebind_available && mutable_state->validate_coverage() != Error::Ok) { + ET_LOG( + Error, + "Qwen35MoEEngine: mutable-buffer coverage check failed; disabling " + "multi-session (capacity clamped to 1)."); + rebind_available = false; } if (!rebind_available) { ET_LOG( @@ -592,7 +594,7 @@ Result> Qwen35MoEEngine::create( std::move(eos_ids), std::move(shared_module), rebind_available -#ifdef EXECUTORCH_BUILD_CUDA +#ifdef QWEN_HAS_MUTABLE_STATE , std::move(mutable_state) #endif @@ -621,7 +623,7 @@ Result> Qwen35MoEEngine::create_session() { } int token = -1; // kNoMutableSession: single-session / no rebind -#ifdef EXECUTORCH_BUILD_CUDA +#ifdef QWEN_HAS_MUTABLE_STATE if (rebind_available_) { auto t = mutable_state_->create_session(); if (t.error() != Error::Ok) { @@ -638,7 +640,7 @@ Result> Qwen35MoEEngine::create_session() { tokenizer_.get(), metadata_, eos_ids_ -#ifdef EXECUTORCH_BUILD_CUDA +#ifdef QWEN_HAS_MUTABLE_STATE , mutable_state_.get(), token @@ -648,7 +650,7 @@ Result> Qwen35MoEEngine::create_session() { LLMServingCapacity Qwen35MoEEngine::serving_capacity() const { LLMServingCapacity cap; // default: 1 session, 0 bytes (unknown) -#ifdef EXECUTORCH_BUILD_CUDA +#ifdef QWEN_HAS_MUTABLE_STATE if 
(rebind_available_) { cap.max_physical_sessions_without_weight_duplication = config_.max_sessions > 1 ? config_.max_sessions : 1; diff --git a/examples/models/qwen3_5_moe/qwen35_moe_engine.h b/examples/models/qwen3_5_moe/qwen35_moe_engine.h index c7ea53115b8..683e797ea68 100644 --- a/examples/models/qwen3_5_moe/qwen35_moe_engine.h +++ b/examples/models/qwen3_5_moe/qwen35_moe_engine.h @@ -28,10 +28,28 @@ #ifdef EXECUTORCH_BUILD_CUDA #include +#elif defined(EXECUTORCH_BUILD_MLX) +#include +#endif + +#if defined(EXECUTORCH_BUILD_CUDA) || defined(EXECUTORCH_BUILD_MLX) +#define QWEN_HAS_MUTABLE_STATE 1 #endif namespace executorch::extension::llm { +#if defined(EXECUTORCH_BUILD_CUDA) +using MutableStateContextOwner = + ::executorch::backends::cuda::MutableStateContextOwner; +constexpr int kNoMutableSession = + ::executorch::backends::cuda::kNoMutableSession; +#elif defined(EXECUTORCH_BUILD_MLX) +using MutableStateContextOwner = + ::executorch::backends::mlx::MutableStateContextOwner; +constexpr int kNoMutableSession = + ::executorch::backends::mlx::kNoMutableSession; +#endif + /// Immutable configuration for a Qwen3.5 MoE engine. 
struct Qwen35MoEConfig { std::string model_path; // .pte @@ -77,10 +95,9 @@ class ET_EXPERIMENTAL Qwen35MoEEngine : public LLMEngine { std::unordered_set eos_ids, std::unique_ptr shared_module, bool rebind_available -#ifdef EXECUTORCH_BUILD_CUDA +#ifdef QWEN_HAS_MUTABLE_STATE , - std::unique_ptr<::executorch::backends::cuda::MutableStateContextOwner> - mutable_state + std::unique_ptr mutable_state #endif ) : config_(std::move(config)), @@ -89,7 +106,7 @@ class ET_EXPERIMENTAL Qwen35MoEEngine : public LLMEngine { eos_ids_(std::move(eos_ids)), shared_module_(std::move(shared_module)), rebind_available_(rebind_available) -#ifdef EXECUTORCH_BUILD_CUDA +#ifdef QWEN_HAS_MUTABLE_STATE , mutable_state_(std::move(mutable_state)) #endif @@ -104,9 +121,8 @@ class ET_EXPERIMENTAL Qwen35MoEEngine : public LLMEngine { std::unique_ptr shared_module_; std::mutex exec_mutex_; bool rebind_available_ = false; -#ifdef EXECUTORCH_BUILD_CUDA - std::unique_ptr<::executorch::backends::cuda::MutableStateContextOwner> - mutable_state_; +#ifdef QWEN_HAS_MUTABLE_STATE + std::unique_ptr mutable_state_; #endif std::atomic live_sessions_{0}; }; From 55ad66bb9da4a769f59d38404ef6522a33000b29 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Tue, 23 Jun 2026 11:10:12 -0700 Subject: [PATCH 8/9] up --- .../mlx/custom_kernel_ops/gated_delta_rule.py | 37 +++++++++++++++++++ .../test/test_gated_delta_rule.py | 5 +-- backends/mlx/test/CMakeLists.txt | 2 +- .../qwen3_5_moe/mlx_source_transformations.py | 26 ++++++------- 4 files changed, 51 insertions(+), 19 deletions(-) diff --git a/backends/mlx/custom_kernel_ops/gated_delta_rule.py b/backends/mlx/custom_kernel_ops/gated_delta_rule.py index 423ffd0b034..41eb8ce7b98 100644 --- a/backends/mlx/custom_kernel_ops/gated_delta_rule.py +++ b/backends/mlx/custom_kernel_ops/gated_delta_rule.py @@ -53,6 +53,15 @@ def gated_delta_rule( B, T_len, Hk, Dk = q.shape Hv, Dv = v.shape[-2:] + # The Metal kernel maps each v-head to its k-head group + # (hk_idx = hv_idx / (Hv / 
Hk)); mirror that here so the eager reference also + # supports Hk != Hv (GQA) instead of relying on broadcasting, which requires + # Hk == Hv. repeat_interleave on the head dim reproduces that index mapping. + if Hk != Hv: + q = q.repeat_interleave(Hv // Hk, dim=2) + k = k.repeat_interleave(Hv // Hk, dim=2) + Hk = Hv + s = state.clone() ys = [] @@ -101,6 +110,7 @@ def gated_delta_rule_fake( IntOrVid, MetalKernelNode, MultiplyNode, + RepeatNode, ScanNode, SubtractNode, SumNode, @@ -450,6 +460,33 @@ def _emit_scan(self, P: MLXProgramBuilder, n: Node) -> Slot: ] ) + # GQA: q/k carry Hk heads but the recurrence state/v have Hv heads. Expand + # q/k to Hv (repeat_interleave on the head axis) so the per-step broadcasts + # match, mirroring the Metal kernel's hk_idx = hv_idx / (Hv / Hk). + Hk = int(self.q_node.meta["val"].shape[-2]) + Hv = int(self.v_node.meta["val"].shape[-2]) + if Hk != Hv: + rep = IntOrVid.from_literal(Hv // Hk) + _, q_exp = P.make_tmp_slot() + P.emit( + RepeatNode( + x=P.slot_to_tid(q_slot), + out=P.slot_to_tid(q_exp), + repeats=rep, + axis=2, + ) + ) + _, k_exp = P.make_tmp_slot() + P.emit( + RepeatNode( + x=P.slot_to_tid(k_slot), + out=P.slot_to_tid(k_exp), + repeats=rep, + axis=2, + ) + ) + q_slot, k_slot = q_exp, k_exp + # Carry needs a writable slot. 
This is node n's persistent output (the # mutated state), so it must be a node-owned slot — not a temp slot, whose # id is reclaimed on tmp_scope exit and would be read as dead by a later diff --git a/backends/mlx/custom_kernel_ops/test/test_gated_delta_rule.py b/backends/mlx/custom_kernel_ops/test/test_gated_delta_rule.py index 0a7e6a687f9..dfee111e74b 100644 --- a/backends/mlx/custom_kernel_ops/test/test_gated_delta_rule.py +++ b/backends/mlx/custom_kernel_ops/test/test_gated_delta_rule.py @@ -96,9 +96,8 @@ def forward( g: torch.Tensor, # [B, T, Hv] beta: torch.Tensor, # [B, T, Hv] ) -> torch.Tensor: - if self.head_repeat > 1: - q = q.repeat_interleave(self.head_repeat, dim=2) - k = k.repeat_interleave(self.head_repeat, dim=2) + # Pass native Hk (no repeat_interleave): the op itself must handle + # GQA head expansion (kernel via hk_idx mapping, scan/eager internally). return torch.ops.mlx.gated_delta_rule( q, k, v, g, beta, self.state, use_custom_kernel=self.use_custom_kernel ) diff --git a/backends/mlx/test/CMakeLists.txt b/backends/mlx/test/CMakeLists.txt index c518b2a232d..2d494652138 100644 --- a/backends/mlx/test/CMakeLists.txt +++ b/backends/mlx/test/CMakeLists.txt @@ -76,7 +76,7 @@ target_include_directories( mlx_mutable_state_test PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../runtime ) target_link_libraries( - mlx_mutable_state_test PRIVATE mlxdelegate mlx executorch_core + mlx_mutable_state_test PRIVATE mlxdelegate mlx_schema mlx executorch_core ) if(EXECUTORCH_MLX_ENABLE_SANITIZERS) target_compile_options( diff --git a/examples/models/qwen3_5_moe/mlx_source_transformations.py b/examples/models/qwen3_5_moe/mlx_source_transformations.py index 9a49f8a84f6..3c460fc9c54 100644 --- a/examples/models/qwen3_5_moe/mlx_source_transformations.py +++ b/examples/models/qwen3_5_moe/mlx_source_transformations.py @@ -113,12 +113,14 @@ def _full_attention_forward(self, x, input_pos): k, v = self.kv_cache.update(input_pos, k, v) - if self.n_kv_groups > 1: - k = 
k.repeat_interleave(self.n_kv_groups, dim=1) - v = v.repeat_interleave(self.n_kv_groups, dim=1) - - attn_mask = self.mask[input_pos].unsqueeze(0).unsqueeze(0) - y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + y = torch.ops.mlx.custom_sdpa( + q, + k, + v, + start_pos=pos, + dropout_p=0.0, + is_causal=True, + ) y = y.transpose(1, 2).contiguous().view(B, T, -1) @@ -184,10 +186,8 @@ def _exportable_gated_delta_net_forward(self, x, input_pos): k, (self.head_k_dim,), self._qk_rms_weight, eps=1e-6 ) - # head_repeat for k_heads != v_heads - if self.head_repeat > 1: - q = q.repeat_interleave(self.head_repeat, dim=2) - k = k.repeat_interleave(self.head_repeat, dim=2) + # GQA head expansion (k_heads != v_heads) is handled inside + # mlx::gated_delta_rule # Mamba-style gating beta = b.sigmoid() @@ -278,17 +278,13 @@ def _swap_gated_delta_net(model, model_dtype): def _swap_full_attention(model, config): - """FullAttention → mlx::rope custom op + causal mask.""" + """FullAttention → mlx::rope custom op""" rope_theta = config.rope_theta if config else 10000.0 - max_seq_len = config.max_seq_len if config else 4096 count = 0 for _name, module in model.named_modules(): if isinstance(module, FullAttention): module._rope_dims = module.rotary_emb.rotary_dim module._rope_base = rope_theta - mask = torch.full((max_seq_len, max_seq_len), float("-inf")) - mask = torch.triu(mask, diagonal=1) - module.register_buffer("mask", mask) module.forward = types.MethodType(_full_attention_forward, module) count += 1 return count From b05bdae43074b6468a9115530ad25a8f23a9ebd9 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Tue, 23 Jun 2026 11:24:55 -0700 Subject: [PATCH 9/9] up --- backends/mlx/runtime/mlx_mutable_state.cpp | 28 ++++++++++++++++++++-- backends/mlx/runtime/mlx_mutable_state.h | 12 +++++++++- 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/backends/mlx/runtime/mlx_mutable_state.cpp b/backends/mlx/runtime/mlx_mutable_state.cpp index 
79a88ea9bfa..2f00d917136 100644 --- a/backends/mlx/runtime/mlx_mutable_state.cpp +++ b/backends/mlx/runtime/mlx_mutable_state.cpp @@ -257,10 +257,34 @@ Error mutable_state_rebind_for_execute( if (ctx.build_error != Error::Ok) { return ctx.build_error; } - HandleInfo& info = ctx.handles[handle]; + // Invariant: a handle present in handle_ctx() is present in ctx.handles. Look + // it up explicitly (not operator[]) so a broken invariant fails loudly + // instead of inserting a {nullptr, nullptr} entry that later null-derefs in + // load_mutable_buffers(*info.program, ...). + auto info_it = ctx.handles.find(handle); + if (info_it == ctx.handles.end()) { + ET_LOG( + Error, + "mutable_state_rebind_for_execute: handle has a context but no " + "registered HandleInfo (invariant broken)"); + return Error::Internal; + } + HandleInfo& info = info_it->second; + const bool has_active_session = tl_active_token != kNoMutableSession; const bool active_for_this_ctx = - tl_active_token != kNoMutableSession && tl_active_ctx == hit->second; + has_active_session && tl_active_ctx == hit->second; + + // A session is active, but for a different context than the one this handle + // belongs to. Falling back to default buffers would silently execute with the + // wrong model/session state, so refuse instead. + if (has_active_session && !active_for_this_ctx) { + ET_LOG( + Error, + "mutable_state_rebind_for_execute: active context mismatch (a session " + "is active for a different loaded program than the one executing)"); + return Error::Internal; + } if (!active_for_this_ctx) { // No session selected. 
Refuse if sessions exist (running against the diff --git a/backends/mlx/runtime/mlx_mutable_state.h b/backends/mlx/runtime/mlx_mutable_state.h index 92fd24a3fa6..84420812360 100644 --- a/backends/mlx/runtime/mlx_mutable_state.h +++ b/backends/mlx/runtime/mlx_mutable_state.h @@ -129,6 +129,16 @@ class ET_EXPERIMENTAL MutableStateContextOwner final { // Selects this context/session while `fn` executes. The caller is responsible // for serializing execution that touches the same loaded program. + // + // Thread-safety contract: destroy_session()/forget_handle() only take the + // registry mutex, while rebind (under with_active_session) hands execute a + // raw pointer into Context::sessions that is dereferenced after the lock is + // released. The caller must therefore guarantee a session is never destroyed + // while it is the active session mid-execute (the engine upholds this: a + // session's buffers are freed only when its owning LLMSession drops, never + // concurrently with its own execute). Destroying *other* sessions + // concurrently is safe — unordered_map keeps element pointers stable across + // rehash. template auto with_active_session(int token, Fn&& fn) const -> decltype(std::forward(fn)()) { @@ -137,7 +147,7 @@ class ET_EXPERIMENTAL MutableStateContextOwner final { } // True only after this context has been associated with at least one loaded - // MLX backend handle and can create isolated mutable-buffer sessions. + // MLX backend handle and can create isolated mutable-buffer sessions. bool available() const { return detail::mutable_state_available(ctx_); }