diff --git a/CHANGELOG.md b/CHANGELOG.md index 2bd2560..0c23ef3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,58 @@ # Changelog +## [Unreleased] + +### Fixed +- **MCP cold-start resilience** — three compounding failure modes that left + the MCP server permanently hung or crashed on boot: + 1. **Sync handlers** — all 9 `@mcp.tool()` handlers were sync `def`, so one + blocked handler froze every concurrent JSON-RPC request on FastMCP's + single event-loop thread. The 8 hot-path handlers are now `async def` with + engine calls wrapped in `asyncio.to_thread` (`truememory_configure` stays sync by design). + The `@_tracked` telemetry decorator is now async-aware (detects coroutine functions and wraps accordingly). + 2. **Reranker preload hangs** — `CrossEncoder(...)` in `_preload_models` + blocked indefinitely on a corrupt HuggingFace cache, a stalled + download, or a Windows Defender ASR denial of the sentencepiece shim. + Added a 30s watchdog (`TRUEMEMORY_RERANKER_TIMEOUT_SEC` override); on + timeout, the reranker is marked degraded and rerank entrypoints fall + back to original-ordering results. `_set_reranker` also short-circuits + when degraded so search calls don't block on the stalled load's lock. + The degraded state surfaces in the F06 health payload — operators see + it in `truememory_stats` instead of digging through logs. + 3. **`os.WNOHANG` is POSIX-only** — `_reap_children` called + `os.waitpid(-1, os.WNOHANG)`, crashing every Windows user's backlog + drainer with `AttributeError` on every boot. Guarded with `hasattr`. + +- **Engine concurrent-store throughput** — `add()` now pre-computes content + and separation embeddings BEFORE acquiring `_write_lock`. Previously the + lock was held during both `model.encode()` calls (~10–50 ms each), + serializing concurrent stores. PyTorch releases the GIL inside `.encode()`, + so concurrent stores now overlap on inference; they only contend at the + INSERTs (μs). 
+ +- **`pytest` collection on Windows** — four `@pytest.mark.skipif` decorators + in `tests/ingest/test_robustness_fixes.py` referenced `os.geteuid()` at + module import time. `geteuid` is POSIX-only; pytest collection crashed on + Windows with `AttributeError`. Guarded with `not hasattr(os, "geteuid") + or os.geteuid() == 0` — skips on Windows AND on POSIX root (both cases + where `chmod` permission tests can't enforce read-only). + +### Added +- `TRUEMEMORY_RERANKER_TIMEOUT_SEC` env var (default 30 s, minimum 1 s) + bounds the reranker preload watchdog. Values ≤ 0 fall back to the default + with a warning (the legitimate "skip preload" path is + `TRUEMEMORY_LAZY_MODELS=1`, not `TIMEOUT_SEC=0`). +- `reranker.is_degraded()` / `reranker.mark_degraded(reason)` — public API + for runtime degraded-state coordination between the MCP server's watchdog + and the rerank entrypoints. +- `tests/test_cold_start_resilience.py` — 14 regression locks: WNOHANG + guard, degraded-flag lifecycle, watchdog timeout + fast-load + non-regression, `_set_reranker` short-circuit, health-payload wiring, + timeout-parser validation. +- `tests/test_concurrent_store_hang.py` — 3 regression locks for the + parallel-store hang: engine.add() concurrency, MCP handler async shape, + asyncio.gather end-to-end. 
+ ## [0.6.8] — 2026-05-11 ### Fixed diff --git a/tests/ingest/test_robustness_fixes.py b/tests/ingest/test_robustness_fixes.py index a390f05..8b3c2af 100644 --- a/tests/ingest/test_robustness_fixes.py +++ b/tests/ingest/test_robustness_fixes.py @@ -110,7 +110,10 @@ def _run_cli(args: list[str], env: dict | None = None) -> subprocess.CompletedPr # --------------------------------------------------------------------------- -@pytest.mark.skipif(os.geteuid() == 0, reason="root bypasses chmod 000") +@pytest.mark.skipif( + not hasattr(os, "geteuid") or os.geteuid() == 0, + reason="POSIX-only: chmod 000 only enforces unreadability on POSIX as non-root", +) def test_bug1_unreadable_file_returns_empty_not_fake_content(caplog): """ A file that exists but can't be read (chmod 000) must NOT be silently @@ -223,7 +226,10 @@ def test_bug2_sqlite_operational_error_is_caught_and_traced(): # --------------------------------------------------------------------------- -@pytest.mark.skipif(os.geteuid() == 0, reason="root bypasses chmod 555") +@pytest.mark.skipif( + not hasattr(os, "geteuid") or os.geteuid() == 0, + reason="POSIX-only: chmod 555 only enforces read-only on POSIX as non-root", +) def test_bug3_save_trace_does_not_raise_on_unwritable_dir(caplog): """ ``save_trace`` should log a warning and return ``False`` when its @@ -271,7 +277,10 @@ def test_bug3_save_trace_returns_true_on_success(): # --------------------------------------------------------------------------- -@pytest.mark.skipif(os.geteuid() == 0, reason="root bypasses chmod 555") +@pytest.mark.skipif( + not hasattr(os, "geteuid") or os.geteuid() == 0, + reason="POSIX-only: chmod 555 only enforces read-only on POSIX as non-root", +) def test_bug4_cli_exits_4_when_db_dir_not_writable(tmp_path): """ When the DB parent directory isn't writable, the CLI must exit with code 4 @@ -306,7 +315,10 @@ def test_bug4_cli_exits_4_when_db_dir_not_writable(tmp_path): os.chmod(locked, 0o755) -@pytest.mark.skipif(os.geteuid() == 0, 
reason="root bypasses chmod 555") +@pytest.mark.skipif( + not hasattr(os, "geteuid") or os.geteuid() == 0, + reason="POSIX-only: chmod 555 only enforces read-only on POSIX as non-root", +) def test_bug4_cli_exits_4_when_trace_dir_not_writable(tmp_path): """Same preflight but for the --trace target.""" transcript = tmp_path / "transcript.txt" diff --git a/tests/test_cold_start_resilience.py b/tests/test_cold_start_resilience.py new file mode 100644 index 0000000..7e09b9a --- /dev/null +++ b/tests/test_cold_start_resilience.py @@ -0,0 +1,350 @@ +"""Regression locks for MCP cold-start resilience. + +Three failure modes this file pins down, all of which would otherwise leave +the MCP server permanently hung or crashed at boot: + +1. **Windows os.WNOHANG missing** — `_reap_children` calls `os.waitpid(-1, + os.WNOHANG)`. WNOHANG is POSIX-only; on Windows the `os` module has no + `WNOHANG` attribute and the backlog drainer thread crashes with + AttributeError. Guard: hasattr check at top of `_reap_children`. + +2. **Reranker preload stalls forever** — corrupt HuggingFace cache, blocked + download, or Windows Defender ASR denying a sentencepiece shim leaves + `CrossEncoder(...)` blocked indefinitely. Guard: watchdog thread with + TRUEMEMORY_RERANKER_TIMEOUT_SEC (default 30s); on timeout, calls + `reranker.mark_degraded(...)`. + +3. **Rerank entrypoints block on stalled load** — once preload is hung, + every `engine.search()` calls `rerank_with_modality_fusion()` which + calls `rerank()` which calls `get_reranker()` which blocks on the same + stalled load. Guard: rerank functions check `reranker.is_degraded()` + and return original ordering instead. 
+""" +from __future__ import annotations + +import os +import threading +import time + +import pytest + + +# --------------------------------------------------------------------------- +# Bug #1: os.WNOHANG missing on Windows +# --------------------------------------------------------------------------- + + +def test_reap_children_no_crash_when_wnohang_missing(monkeypatch): + """On Windows, os has no WNOHANG attribute. _reap_children must return + cleanly instead of raising AttributeError, otherwise _backlog_drainer + crashes on every boot for every Windows user. + """ + from truememory import mcp_server as ms + + # Simulate the Windows environment: remove WNOHANG from os if present. + monkeypatch.delattr(os, "WNOHANG", raising=False) + + # Must not raise. + ms._reap_children() + + +# --------------------------------------------------------------------------- +# Bug #2 + #3: reranker degraded-mode fallback +# --------------------------------------------------------------------------- + + +@pytest.fixture +def _reset_reranker_degraded(): + """Reset the module-level degraded flag around each test so prior tests + don't pollute state.""" + from truememory import reranker as rr + original = rr._load_failed + rr._load_failed = False + yield + rr._load_failed = original + + +def test_is_degraded_starts_false(_reset_reranker_degraded): + from truememory import reranker as rr + assert rr.is_degraded() is False + + +def test_mark_degraded_sets_flag(_reset_reranker_degraded): + from truememory import reranker as rr + rr.mark_degraded("test reason") + assert rr.is_degraded() is True + + +def test_rerank_returns_original_ordering_when_degraded(_reset_reranker_degraded): + """Once degraded, rerank() must NOT call get_reranker — that would + block on the same stalled load that caused the degraded mark. It must + return the candidates in their original order (truncated to top_k). 
+ """ + from truememory import reranker as rr + + candidates = [ + {"content": f"doc {i}", "rrf_score": 1.0 / (i + 1)} + for i in range(5) + ] + rr.mark_degraded("simulated stall") + + out = rr.rerank("query", candidates, top_k=3) + + assert len(out) == 3, "top_k must be honored in degraded mode" + assert [r["content"] for r in out] == ["doc 0", "doc 1", "doc 2"], ( + "Degraded mode must preserve original input ordering — the caller's " + "RRF/vector ranking is the best signal available without a reranker." + ) + # No rerank_score key must appear — proves get_reranker was never called. + for r in out: + assert "rerank_score" not in r + + +def test_rerank_with_modality_fusion_returns_original_when_degraded( + _reset_reranker_degraded, +): + from truememory import reranker as rr + + candidates = [ + {"content": "a", "modality": "conversation", "rrf_score": 0.9}, + {"content": "b", "modality": "episode", "rrf_score": 0.5}, + {"content": "c", "modality": "fact", "rrf_score": 0.3}, + ] + rr.mark_degraded("simulated stall") + + out = rr.rerank_with_modality_fusion("why did X happen", candidates, top_k=2) + + assert len(out) == 2 + assert [r["content"] for r in out] == ["a", "b"] + + +# --------------------------------------------------------------------------- +# Bug #2: preload watchdog marks degraded on timeout +# --------------------------------------------------------------------------- + + +def test_preload_watchdog_marks_degraded_on_timeout( + monkeypatch, _reset_reranker_degraded, +): + """If get_reranker hangs longer than TRUEMEMORY_RERANKER_TIMEOUT_SEC, + the watchdog must call mark_degraded so search calls fall back instead + of blocking forever. + + Strategy: monkey-patch get_reranker with a function that sleeps past the + timeout, set the timeout to a small value, call _preload_models, then + poll is_degraded() until the watchdog fires. 
+ """ + from truememory import mcp_server as ms + from truememory import reranker as rr + + # Force a very small timeout so the test finishes fast. + monkeypatch.setattr(ms, "_RERANKER_LOAD_TIMEOUT_SEC", 1) + + # Replace get_reranker with a hang. + def _hang(*_a, **_k): + time.sleep(30) # well past the 1s timeout + + monkeypatch.setattr(rr, "get_reranker", _hang) + + # Prevent the embedding-model branch from doing real I/O. + monkeypatch.setenv("TRUEMEMORY_LAZY_MODELS", "") # ensure preload runs + + # Stub the embedding loader so it returns immediately. + import truememory.vector_search as vs + monkeypatch.setattr(vs, "get_model", lambda *_a, **_k: None) + monkeypatch.setattr(ms, "_get_memory", lambda: None) + + ms._preload_models() + + # Watchdog must fire within timeout + small margin. + deadline = time.monotonic() + 3.0 + while time.monotonic() < deadline: + if rr.is_degraded(): + break + time.sleep(0.05) + + assert rr.is_degraded(), ( + "Watchdog did not mark reranker degraded within 3s — the stalled " + "preload would have blocked every subsequent search call indefinitely." + ) + + +def test_preload_watchdog_does_not_mark_degraded_on_fast_load( + monkeypatch, _reset_reranker_degraded, +): + """If get_reranker returns quickly, the watchdog must NOT mark degraded. + Otherwise every successful boot would falsely report degraded mode. + """ + from truememory import mcp_server as ms + from truememory import reranker as rr + + monkeypatch.setattr(ms, "_RERANKER_LOAD_TIMEOUT_SEC", 5) + monkeypatch.setattr(rr, "get_reranker", lambda *_a, **_k: None) + + import truememory.vector_search as vs + monkeypatch.setattr(vs, "get_model", lambda *_a, **_k: None) + monkeypatch.setattr(ms, "_get_memory", lambda: None) + + ms._preload_models() + + # Give the watchdog time to either fire (false positive) or finish cleanly. 
+ time.sleep(0.5) + + assert not rr.is_degraded(), ( + "Watchdog fired on a fast load — this would make every boot report " + "degraded mode, defeating the purpose of the fallback." + ) + + +# --------------------------------------------------------------------------- +# Bug #2 follow-on: _set_reranker must short-circuit when degraded +# --------------------------------------------------------------------------- + + +def test_set_reranker_short_circuits_when_degraded( + monkeypatch, _reset_reranker_degraded, +): + """If the reranker is already degraded (watchdog fired), _set_reranker + must NOT call get_reranker. Otherwise every search call here would block + on the same reranker._lock that the stalled preload thread is holding, + defeating the async-handler + watchdog fix by serializing the thread pool. + """ + from truememory import mcp_server as ms + from truememory import reranker as rr + + rr.mark_degraded("simulated preload timeout") + + called = [] + + def _spy(*_a, **_k): + called.append(True) + return None + + monkeypatch.setattr(rr, "get_reranker", _spy) + + ms._set_reranker("any-model") + + assert called == [], ( + "_set_reranker called get_reranker despite degraded mode — search " + "calls will block on the preload thread's lock." + ) + + +# --------------------------------------------------------------------------- +# Bug #2 follow-on: degraded state must surface in F06 health payload +# --------------------------------------------------------------------------- + + +def test_watchdog_writes_to_health_payload_on_timeout( + monkeypatch, _reset_reranker_degraded, +): + """When the watchdog marks degraded, the F06 health payload must reflect + it. Otherwise truememory_stats lies to the operator while search is + silently falling back. The watchdog calls both mark_degraded() AND + _record_reranker_error() so the existing health payload reads it. 
+ """ + from truememory import mcp_server as ms + from truememory import reranker as rr + + # Reset health-payload state. + ms._clear_reranker_error() + + monkeypatch.setattr(ms, "_RERANKER_LOAD_TIMEOUT_SEC", 1) + monkeypatch.setattr(rr, "get_reranker", lambda *_a, **_k: time.sleep(30)) + + import truememory.vector_search as vs + monkeypatch.setattr(vs, "get_model", lambda *_a, **_k: None) + monkeypatch.setattr(ms, "_get_memory", lambda: None) + + ms._preload_models() + + # Wait for watchdog to fire. + deadline = time.monotonic() + 3.0 + while time.monotonic() < deadline: + if rr.is_degraded(): + break + time.sleep(0.05) + + # Give the watchdog's _record_reranker_error call a moment to land. + time.sleep(0.1) + + health = ms._build_health_payload() + assert health["reranker"]["status"] == "degraded", ( + f"Health payload still reports OK after watchdog timeout: {health['reranker']!r}" + ) + assert "preload exceeded" in (health["reranker"]["last_error"] or ""), ( + f"Expected timeout message in health payload, got: " + f"{health['reranker']['last_error']!r}" + ) + + +def test_load_reranker_exception_writes_to_health_payload( + monkeypatch, _reset_reranker_degraded, +): + """If get_reranker raises during preload (not a timeout but an actual + exception like ImportError or a HuggingFace-cache OSError), the exception + path must also write to the health payload, not just mark degraded. 
+ """ + from truememory import mcp_server as ms + from truememory import reranker as rr + + ms._clear_reranker_error() + + monkeypatch.setattr(ms, "_RERANKER_LOAD_TIMEOUT_SEC", 5) + + def _raise(*_a, **_k): + raise RuntimeError("simulated HF cache corruption") + + monkeypatch.setattr(rr, "get_reranker", _raise) + + import truememory.vector_search as vs + monkeypatch.setattr(vs, "get_model", lambda *_a, **_k: None) + monkeypatch.setattr(ms, "_get_memory", lambda: None) + + ms._preload_models() + + # Exception path is synchronous from the thread's perspective; give the + # thread a moment to run the except block. + time.sleep(0.2) + + health = ms._build_health_payload() + assert health["reranker"]["status"] == "degraded" + assert "simulated HF cache corruption" in (health["reranker"]["last_error"] or "") + + +# --------------------------------------------------------------------------- +# Bug #2 follow-on: timeout validation rejects invalid values +# --------------------------------------------------------------------------- + + +def test_parse_reranker_timeout_accepts_positive(): + from truememory.mcp_server import _parse_reranker_timeout + assert _parse_reranker_timeout("60", default=30) == 60 + assert _parse_reranker_timeout("1", default=30) == 1 + + +def test_parse_reranker_timeout_clamps_zero_and_negative(): + """0 and negative values are footgun inputs (e.g. a stray + `TRUEMEMORY_RERANKER_TIMEOUT_SEC=0` left in a shell script). Must + fall back to default — never silently disable the watchdog. The + legitimate "skip preload" path is TRUEMEMORY_LAZY_MODELS=1. + """ + from truememory.mcp_server import _parse_reranker_timeout + assert _parse_reranker_timeout("0", default=30) == 30 + assert _parse_reranker_timeout("-5", default=30) == 30 + + +def test_parse_reranker_timeout_rejects_non_integer(): + """Non-integer values (e.g. a user typing '30s' or 'thirty') must not + crash the import. Fall back to default with a warning. 
+ """ + from truememory.mcp_server import _parse_reranker_timeout + assert _parse_reranker_timeout("30s", default=30) == 30 + assert _parse_reranker_timeout("thirty", default=30) == 30 + assert _parse_reranker_timeout("", default=30) == 30 + + +def test_parse_reranker_timeout_handles_unset(): + from truememory.mcp_server import _parse_reranker_timeout + assert _parse_reranker_timeout(None, default=30) == 30 + assert _parse_reranker_timeout(None, default=45) == 45 diff --git a/tests/test_concurrent_store_hang.py b/tests/test_concurrent_store_hang.py new file mode 100644 index 0000000..421a9c0 --- /dev/null +++ b/tests/test_concurrent_store_hang.py @@ -0,0 +1,160 @@ +"""Regression lock for the parallel-store hang. + +Symptom (before fix): when Claude Code issues 3+ `truememory_store` MCP +calls in a single parallel tool batch, the harness UI hangs 10-15+ seconds +before any response renders, even though each individual store completes +server-side in ~60ms. Root cause was two-layer: + +1. MCP layer: all 8 `@mcp.tool()` handlers were sync `def`, causing + FastMCP to dispatch them via `return fn(**kwargs)` which blocks the + single asyncio event loop thread for the duration of each call. JSON-RPC + requests serialized at the transport layer. + +2. Engine layer: `TrueMemoryEngine.add()` acquired `_write_lock` BEFORE + calling `embed_single()` (~10-50ms of model.encode). Encoding is + thread-safe (PyTorch releases the GIL inside .encode()), so concurrent + stores could have overlapped on inference — but the lock prevented it. + +Fix: +- MCP: tools changed to `async def`, engine calls wrapped in + `await asyncio.to_thread(...)`. +- Engine: embeddings pre-computed OUTSIDE `_write_lock`; lock now only + guards the 3 INSERTs + commit. + +This test exercises the engine half — kicks 3 concurrent `engine.add()` +calls from threads and asserts the total wall-clock is well under what +serialized encodes would take. 
The MCP half is verified by a separate +asyncio.gather-based test. +""" +from __future__ import annotations + +import asyncio +import inspect +import threading +import time + +import pytest + +from truememory.client import Memory + + +@pytest.fixture +def memory_db(tmp_path, monkeypatch): + """Fresh Memory instance with a per-test DB; force Edge tier for speed.""" + db_path = tmp_path / "concurrent_store.db" + monkeypatch.setenv("TRUEMEMORY_DB", str(db_path)) + monkeypatch.setenv("TRUEMEMORY_EMBED_MODEL", "edge") + m = Memory() + yield m + + +def test_engine_add_releases_lock_during_embed(memory_db): + """Three threads each call m.add(); total wall-clock must be well under + 3 × serialized-encode time. The fix moves embedding OUTSIDE the + write lock so encodes overlap. Edge embeddings are ~5-15ms each; with + parallel encoding, 3 concurrent stores should land near a single-store + cost + small lock-contention overhead. + """ + contents = [ + "concurrent store test fact A " + "x" * 400, + "concurrent store test fact B " + "y" * 400, + "concurrent store test fact C " + "z" * 400, + ] + results: list[dict] = [] + errors: list[BaseException] = [] + lock = threading.Lock() + + def worker(content: str) -> None: + try: + r = memory_db.add(content=content, user_id="test") + with lock: + results.append(r) + except BaseException as e: # noqa: BLE001 — capture all + with lock: + errors.append(e) + + threads = [threading.Thread(target=worker, args=(c,)) for c in contents] + t0 = time.perf_counter() + for t in threads: + t.start() + for t in threads: + t.join(timeout=30) + elapsed = time.perf_counter() - t0 + + assert not errors, f"Concurrent add() raised: {errors!r}" + assert len(results) == 3, f"Expected 3 results, got {len(results)}" + + # The hard requirement: all 3 stores complete in well under 30s. The + # original bug was 10-15s harness-perceived hang on 3 concurrent stores. 
+ # After the fix, 3 concurrent Edge stores should finish in < 5s + # (typically 200-800ms, but allow generous headroom for CI variance). + assert elapsed < 5.0, ( + f"3 concurrent stores took {elapsed:.2f}s — the lock-during-embed " + f"regression may have returned. Expected < 5s." + ) + + # Each row must be queryable + have its own ID + ids = {r["id"] for r in results} + assert len(ids) == 3, f"Expected 3 distinct ids, got {ids}" + + +def test_mcp_handlers_are_async(): + """The 8 hot-path MCP tool handlers must be coroutine functions. + Sync handlers serialize concurrent MCP requests at the FastMCP layer. + + truememory_configure stays sync (called once at setup, has complex + mutable state — not on the hot path). + """ + from truememory import mcp_server as ms + + expected_async = [ + "truememory_store", + "truememory_search", + "truememory_search_deep", + "truememory_get", + "truememory_forget", + "truememory_stats", + "truememory_entity_profile", + "truememory_status", + ] + for name in expected_async: + fn = getattr(ms, name) + assert inspect.iscoroutinefunction(fn), ( + f"{name} must be `async def` so FastMCP doesn't block the " + f"event loop. If you change it back to sync, expect the " + f"parallel-store hang to return." + ) + + # Configure intentionally stays sync — assert that too so future refactors + # know it's deliberate. + assert not inspect.iscoroutinefunction(ms.truememory_configure), ( + "truememory_configure is intentionally sync — heavy state mutation, " + "called once per session at setup, not on the hot path." + ) + + +def test_mcp_store_via_asyncio_gather(memory_db, monkeypatch): + """Three concurrent truememory_store coroutines via asyncio.gather must + all complete in well under 5s. Exercises the MCP-layer fix end-to-end. 
+ """ + from truememory import mcp_server as ms + + monkeypatch.setattr(ms, "_memory", memory_db) + + async def _run() -> list[str]: + return await asyncio.gather( + ms.truememory_store(content="gather store A " + "x" * 400, user_id="test"), + ms.truememory_store(content="gather store B " + "y" * 400, user_id="test"), + ms.truememory_store(content="gather store C " + "z" * 400, user_id="test"), + ) + + t0 = time.perf_counter() + results = asyncio.run(_run()) + elapsed = time.perf_counter() - t0 + + assert len(results) == 3 + assert all(isinstance(r, str) for r in results), "All results must be JSON strings" + assert elapsed < 5.0, ( + f"3 gather()-ed truememory_store coroutines took {elapsed:.2f}s — " + f"the async-handler regression may have returned. Expected < 5s." + ) diff --git a/tests/test_health_stats.py b/tests/test_health_stats.py index 5a49d5c..080513d 100644 --- a/tests/test_health_stats.py +++ b/tests/test_health_stats.py @@ -104,7 +104,8 @@ def test_health_vectors_degraded_surfaces_engine_error(server, monkeypatch): def test_truememory_stats_includes_health(server): - result_json = server.truememory_stats() + import asyncio + result_json = asyncio.run(server.truememory_stats()) result = json.loads(result_json) assert "health" in result assert isinstance(result["health"], dict) diff --git a/truememory/engine.py b/truememory/engine.py index 3cd5e95..f803cfb 100644 --- a/truememory/engine.py +++ b/truememory/engine.py @@ -428,6 +428,33 @@ def add( self._ensure_connection() + # Pre-compute both embeddings OUTSIDE the write lock. + # + # Previously embed_single() — which makes two model.encode() calls + # (~10-50ms each) — ran inside _write_lock, serializing all + # concurrent stores through the encoding step. Encoding is + # thread-safe (PyTorch releases the GIL inside .encode()), so + # concurrent stores can encode in parallel; they only need to + # contend for the lock at the actual INSERTs, which are microseconds. 
+ # + # This is the engine half of the parallel-store-hang fix; the MCP + # half is async-ifying the @mcp.tool() handlers in mcp_server.py. + content_blob = None + sep_blob = None + if self._has_vectors: + try: + from truememory.vector_search import ( + _build_sep_text, + get_model, + serialize_f32, + ) + model = get_model() + content_blob = serialize_f32(model.encode([content])[0]) + sep_text = _build_sep_text(sender, recipient, timestamp, content) + sep_blob = serialize_f32(model.encode([sep_text])[0]) + except Exception: + logger.debug("Failed to pre-compute embeddings during add()", exc_info=True) + with self._write_lock: msg = { "content": content, @@ -439,13 +466,20 @@ def add( } new_id = insert_message(self.conn, msg) - # Embed the message for vector search - if self._has_vectors: + # Insert pre-computed vector embeddings (no encoding here — μs-level). + if self._has_vectors and content_blob is not None: try: - from truememory.vector_search import embed_single - embed_single(self.conn, new_id, content) + self.conn.execute( + "INSERT INTO vec_messages(rowid, embedding) VALUES (?, ?)", + (new_id, content_blob), + ) + if sep_blob is not None: + self.conn.execute( + "INSERT INTO vec_messages_sep(rowid, embedding) VALUES (?, ?)", + (new_id, sep_blob), + ) except Exception: - logger.debug("Failed to embed message %s during add()", new_id, exc_info=True) + logger.debug("Failed to insert pre-computed vectors for message %s", new_id, exc_info=True) # Incrementally update entity profile if self._has_personality and sender: diff --git a/truememory/mcp_server.py b/truememory/mcp_server.py index 3aa2de2..0783f6d 100644 --- a/truememory/mcp_server.py +++ b/truememory/mcp_server.py @@ -22,6 +22,7 @@ from __future__ import annotations +import asyncio import gc import json import logging @@ -431,7 +432,21 @@ def _set_reranker(model_name: str): On failure: store the error in ``_reranker_last_error`` so F07's health payload can surface the degradation; log at WARNING once per 
distinct error to avoid spamming logs on every search call. + + If the reranker has already been marked degraded (preload watchdog timed + out, or a prior load raised), short-circuit immediately. Without this, + every search call here would block on the same ``reranker._lock`` that + the stalled preload thread is holding, defeating the whole point of the + async-handler + watchdog fix by serializing the thread pool instead of + the event loop. """ + try: + from truememory.reranker import is_degraded + if is_degraded(): + return + except ImportError: + pass + try: from truememory.reranker import get_reranker get_reranker(model_name=model_name) @@ -530,7 +545,7 @@ def _run_query(q): @mcp.tool() @_tracked("tool_store") -def truememory_store( +async def truememory_store( content: str, user_id: str = "", metadata: str = "", @@ -556,13 +571,17 @@ def truememory_store( meta = json.loads(metadata) if metadata else None except (json.JSONDecodeError, ValueError): meta = None - result = m.add(content=content, user_id=user_id or None, metadata=meta) + # Run engine.add in a thread so the FastMCP event loop stays free to + # accept concurrent JSON-RPC requests (fixes parallel-store hang). 
+ result = await asyncio.to_thread( + m.add, content=content, user_id=user_id or None, metadata=meta + ) return json.dumps(result, indent=2) @mcp.tool() @_tracked("tool_search") -def truememory_search( +async def truememory_search( query: str, user_id: str = "", limit: int = 10, @@ -595,18 +614,21 @@ def truememory_search( if len(queries) == 1: m = _get_memory() - results = m.search_deep( + results = await asyncio.to_thread( + m.search_deep, queries[0], user_id=uid, limit=_SEARCH_INTERNAL_LIMIT, llm_fn=llm_fn, ) return json.dumps(results[:limit], indent=2) - results = _parallel_search(queries, uid, _SEARCH_INTERNAL_LIMIT, llm_fn, limit) + results = await asyncio.to_thread( + _parallel_search, queries, uid, _SEARCH_INTERNAL_LIMIT, llm_fn, limit + ) return json.dumps(results, indent=2) @mcp.tool() @_tracked("tool_search_deep") -def truememory_search_deep( +async def truememory_search_deep( query: str, user_id: str = "", limit: int = 10, @@ -640,25 +662,28 @@ def truememory_search_deep( if len(queries) == 1: m = _get_memory() - results = m.search_deep( + results = await asyncio.to_thread( + m.search_deep, queries[0], user_id=uid, limit=_DEEP_INTERNAL_LIMIT, llm_fn=llm_fn, ) return json.dumps(results[:limit], indent=2) - results = _parallel_search(queries, uid, _DEEP_INTERNAL_LIMIT, llm_fn, limit) + results = await asyncio.to_thread( + _parallel_search, queries, uid, _DEEP_INTERNAL_LIMIT, llm_fn, limit + ) return json.dumps(results, indent=2) @mcp.tool() @_tracked("tool_get") -def truememory_get(memory_id: int) -> str: +async def truememory_get(memory_id: int) -> str: """Get a specific memory by its ID. Args: memory_id: The integer ID of the memory to retrieve. 
""" m = _get_memory() - result = m.get(memory_id) + result = await asyncio.to_thread(m.get, memory_id) if result is None: return json.dumps({"error": f"Memory {memory_id} not found"}) return json.dumps(result, indent=2) @@ -666,26 +691,30 @@ def truememory_get(memory_id: int) -> str: @mcp.tool() @_tracked("tool_forget") -def truememory_forget(memory_id: int) -> str: +async def truememory_forget(memory_id: int) -> str: """Delete a memory by its ID. Args: memory_id: The integer ID of the memory to delete. """ m = _get_memory() - deleted = m.delete(memory_id) + deleted = await asyncio.to_thread(m.delete, memory_id) return json.dumps({"deleted": deleted, "memory_id": memory_id}) @mcp.tool() @_tracked("tool_stats") -def truememory_stats() -> str: +async def truememory_stats() -> str: """Get memory system statistics. On first run, returns a welcome message and setup instructions — present these to the user to walk them through choosing Edge, Base, or Pro tier.""" m = _get_memory() - m._engine._ensure_connection() - stats = m.stats() + # Stats touches the DB (ensure_connection + COUNT queries) — run in a + # thread so the event loop stays free for other MCP requests. + def _gather_stats(): + m._engine._ensure_connection() + return m.stats() + stats = await asyncio.to_thread(_gather_stats) config = _load_config() stats["version"] = __version__ stats["tier"] = config.get("tier", "edge") @@ -927,7 +956,7 @@ def truememory_configure( @mcp.tool() @_tracked("tool_status") -def truememory_status(status_id: int = 0) -> str: +async def truememory_status(status_id: int = 0) -> str: """Check the progress of a tier-switch re-embedding operation. Args: @@ -935,18 +964,20 @@ def truememory_status(status_id: int = 0) -> str: a re-embedding was started. Pass 0 (default) to get the most recent rebuild status. 
""" - try: - from truememory.tier_switch.manager import RebuildManager - manager = RebuildManager.get_instance() - status = manager.get_status(status_id) - return json.dumps(status, indent=2, default=str) - except Exception as e: - return json.dumps({"error": f"{type(e).__name__}: {e}"}) + def _query(): + try: + from truememory.tier_switch.manager import RebuildManager + manager = RebuildManager.get_instance() + status = manager.get_status(status_id) + return json.dumps(status, indent=2, default=str) + except Exception as e: + return json.dumps({"error": f"{type(e).__name__}: {e}"}) + return await asyncio.to_thread(_query) @mcp.tool() @_tracked("tool_entity_profile") -def truememory_entity_profile(entity: str) -> str: +async def truememory_entity_profile(entity: str) -> str: """Get the personality profile for an entity (person). Returns communication style, preferences, traits, and topics @@ -956,16 +987,19 @@ def truememory_entity_profile(entity: str) -> str: entity: Name of the person/entity to look up. 
""" m = _get_memory() - m._engine._ensure_connection() - try: - from truememory.personality import get_entity_profile - profile = get_entity_profile(m._engine.conn, entity) - if profile: - return json.dumps(profile, indent=2, default=str) - return json.dumps({"error": f"No profile found for '{entity}'"}) - except Exception as e: - return json.dumps({"error": str(e)}) + def _lookup(): + m._engine._ensure_connection() + try: + from truememory.personality import get_entity_profile + profile = get_entity_profile(m._engine.conn, entity) + if profile: + return json.dumps(profile, indent=2, default=str) + return json.dumps({"error": f"No profile found for '{entity}'"}) + except Exception as e: + return json.dumps({"error": str(e)}) + + return await asyncio.to_thread(_lookup) # --------------------------------------------------------------------------- @@ -1025,10 +1059,58 @@ def _touch_search_time() -> None: _idle_timer.start() +_RERANKER_LOAD_TIMEOUT_DEFAULT_SEC = 30 + + +def _parse_reranker_timeout(raw_value: str | None, default: int = 30) -> int: + """Parse the reranker preload timeout env value. + + Clamps non-positive values to ``default`` with a warning so a typo + (``TRUEMEMORY_RERANKER_TIMEOUT_SEC=`` in a shell script → ``0``) can + never disable the safety net. The legitimate "skip preload entirely" + path is ``TRUEMEMORY_LAZY_MODELS=1``, not ``TIMEOUT_SEC=0``. + + Non-integer values fall back to default with a warning. ``None`` + (env var unset) returns ``default`` silently. + """ + if raw_value is None: + return default + try: + value = int(raw_value) + except (TypeError, ValueError): + log.warning( + "TRUEMEMORY_RERANKER_TIMEOUT_SEC=%r is not an integer; using " + "default %ds.", raw_value, default, + ) + return default + if value <= 0: + log.warning( + "TRUEMEMORY_RERANKER_TIMEOUT_SEC=%d is invalid (minimum 1s); " + "using default %ds. 
To skip preload entirely, set " + "TRUEMEMORY_LAZY_MODELS=1.", value, default, + ) + return default + return value + + +_RERANKER_LOAD_TIMEOUT_SEC = _parse_reranker_timeout( + os.environ.get("TRUEMEMORY_RERANKER_TIMEOUT_SEC"), + _RERANKER_LOAD_TIMEOUT_DEFAULT_SEC, +) + + def _preload_models(): """Pre-load ML models in background threads so the first search is fast. Set TRUEMEMORY_LAZY_MODELS=1 to skip preloading (models load on first search). + + Reranker load is bounded by TRUEMEMORY_RERANKER_TIMEOUT_SEC (default 30s). + If CrossEncoder construction hangs — typically a corrupt HuggingFace cache, + a blocked download, or a Windows Defender ASR rule denying a sentencepiece + shim — the watchdog marks the reranker degraded and search falls back to + non-reranked results instead of blocking every subsequent MCP call. The + degraded state is also written into the F06 health payload so + truememory_stats surfaces it without operators digging through logs. """ if os.environ.get("TRUEMEMORY_LAZY_MODELS", "") == "1": log.info("Model preloading disabled (TRUEMEMORY_LAZY_MODELS=1)") @@ -1049,13 +1131,28 @@ def _load_reranker(): try: from truememory.reranker import get_reranker get_reranker(model_name=_current_reranker()) - except Exception: - pass + except Exception as e: + from truememory.reranker import mark_degraded + reason = f"preload raised {type(e).__name__}: {e}" + mark_degraded(reason) + _record_reranker_error(reason) + + def _watch_reranker(thread: threading.Thread): + thread.join(timeout=_RERANKER_LOAD_TIMEOUT_SEC) + if thread.is_alive(): + from truememory.reranker import mark_degraded + reason = ( + f"preload exceeded {_RERANKER_LOAD_TIMEOUT_SEC}s (override " + "with TRUEMEMORY_RERANKER_TIMEOUT_SEC)" + ) + mark_degraded(reason) + _record_reranker_error(reason) t1 = threading.Thread(target=_load_embedding_model_and_db, daemon=True) t2 = threading.Thread(target=_load_reranker, daemon=True) t1.start() t2.start() + threading.Thread(target=_watch_reranker, args=(t2,), 
daemon=True).start() # --------------------------------------------------------------------------- @@ -1098,7 +1195,14 @@ def _reap_children() -> None: Without this, Popen'd ingest processes become zombies after they finish, and os.kill(pid, 0) / ps still sees them as alive — permanently blocking spawn gate slots. + + POSIX-only: Windows has no equivalent zombie-process concept (terminated + children release their PID immediately), so os.WNOHANG is not exposed + on win32. Without this guard, the backlog drainer crashes on every + boot for every Windows user. """ + if not hasattr(os, "WNOHANG"): + return try: while True: pid, _ = os.waitpid(-1, os.WNOHANG) diff --git a/truememory/reranker.py b/truememory/reranker.py index fa5a42e..da58a08 100644 --- a/truememory/reranker.py +++ b/truememory/reranker.py @@ -43,6 +43,15 @@ _lock = threading.Lock() _inference_lock = threading.Lock() # Protects concurrent model.predict() calls +# Runtime health flag. Set to True if the preload watchdog in +# mcp_server._preload_models times out or if CrossEncoder construction raises +# during preload. When True, rerank() / rerank_with_fusion() / +# rerank_with_modality_fusion() return the original candidate ordering +# instead of calling get_reranker(), so search stays responsive even when +# the HuggingFace download stalls or the model cache is corrupt. Cleared +# only on process restart. +_load_failed: bool = False + # --------------------------------------------------------------------------- # Tier-aware reranker resolution (v0.4.0 paper §2.0) # --------------------------------------------------------------------------- @@ -147,6 +156,34 @@ def unload_reranker() -> None: _model = None +def is_degraded() -> bool: + """True if reranker preload timed out or raised during startup. 
+ + Callers (rerank entrypoints) use this to short-circuit and return the + original candidate ordering rather than calling get_reranker() — which + would block on the same stalled load that caused the degraded mark. + Reset only on process restart. + """ + return _load_failed + + +def mark_degraded(reason: str) -> None: + """Mark the reranker as degraded so future rerank() calls fall back. + + Called by the preload watchdog in mcp_server._preload_models when + CrossEncoder construction exceeds TRUEMEMORY_RERANKER_TIMEOUT_SEC, or + when the load thread raises. Idempotent — only logs the first time + to avoid spamming. + """ + global _load_failed + if not _load_failed: + log.warning( + "Reranker degraded: %s. Search will return non-reranked results " + "until process restart.", reason, + ) + _load_failed = True + + def get_reranker(model_name: str | None = None, device: str | None = None): """ Lazy-load the cross-encoder reranker (singleton). @@ -255,6 +292,11 @@ def rerank( if len(results) <= 1: return results[:top_k] + if _load_failed: + # Degraded mode: preload timed out or raised. Don't call + # get_reranker — that would block on the same stalled load. 
+ return results[:top_k] + model = get_reranker(model_name=model_name, device=device) # Build (query, content) pairs @@ -303,6 +345,9 @@ def rerank_with_fusion( if not results: return [] + if _load_failed: + return results[:top_k] + reranked = rerank(query, results, top_k=len(results), **kwargs) return _normalize_and_fuse(reranked, rerank_weight, rrf_weight, top_k) @@ -331,6 +376,9 @@ def rerank_with_modality_fusion( if not results: return [] + if _load_failed: + return results[:top_k] + reranked = rerank(query, results, top_k=len(results), **kwargs) question_type = _classify_question_type(query) diff --git a/truememory/telemetry.py b/truememory/telemetry.py index 719f1dc..bca6f2f 100644 --- a/truememory/telemetry.py +++ b/truememory/telemetry.py @@ -20,6 +20,7 @@ from __future__ import annotations +import asyncio import json import os import platform @@ -147,8 +148,38 @@ def identify(email: str, properties: dict | None = None) -> None: def tracked(event_name: str): - """Decorator that emits a telemetry event after a tool function runs.""" + """Decorator that emits a telemetry event after a tool function runs. + + Detects coroutine functions and returns an async wrapper for them so + FastMCP correctly sees them as `async def` and dispatches them on the + event loop. Without this branch, wrapping an `async def` MCP tool in + `@tracked` produces a sync wrapper that returns an unawaited coroutine + object, which silently breaks the tool AND defeats the entire purpose + of making the handler async in the first place. 
+ """ def decorator(fn): + if asyncio.iscoroutinefunction(fn): + @wraps(fn) + async def async_wrapper(*args, **kwargs): + if not _enabled: + return await fn(*args, **kwargs) + start = time.monotonic() + success = True + try: + return await fn(*args, **kwargs) + except Exception: + success = False + raise + finally: + try: + track(event_name, { + "latency_ms": round((time.monotonic() - start) * 1000, 1), + "success": success, + }) + except Exception: + pass + return async_wrapper + @wraps(fn) def wrapper(*args, **kwargs): if not _enabled: