From 1edb86e25a642bd769cccaf3b0935a13051e1a7d Mon Sep 17 00:00:00 2001 From: aotenjou Date: Wed, 15 Apr 2026 12:06:21 +0800 Subject: [PATCH 1/6] feat(infer): add manifest-based model hot reload for python infer runtime Implement blue-green hot swap in TorchInferSession with throttled version polling, async single-flight loading, warmup-before-switch, and rollback with backoff so workers can safely adopt newly published models without request interruption. --- .../config/keys/FrameworkConfigKeys.java | 30 ++ .../apache/geaflow/infer/InferContext.java | 22 ++ .../infer/InferEnvironmentContext.java | 36 +++ .../infer/inferRuntime/inferSession.py | 264 ++++++++++++++++-- .../infer/inferRuntime/infer_server.py | 38 ++- .../test_infer_session_hot_reload.py | 192 +++++++++++++ 6 files changed, 562 insertions(+), 20 deletions(-) create mode 100644 geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/test_infer_session_hot_reload.py diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/FrameworkConfigKeys.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/FrameworkConfigKeys.java index 441370ab5..c826ae476 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/FrameworkConfigKeys.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/FrameworkConfigKeys.java @@ -153,6 +153,36 @@ public class FrameworkConfigKeys implements Serializable { .noDefaultValue() .description("infer env conda url"); + public static final ConfigKey INFER_ENV_HOT_RELOAD_MODEL_PATH = ConfigKeys + .key("geaflow.infer.env.hot.reload.model.path") + .noDefaultValue() + .description("infer env hot reload model path"); + + public static final ConfigKey INFER_ENV_HOT_RELOAD_MODEL_VERSION_FILE = ConfigKeys + .key("geaflow.infer.env.hot.reload.model.version.file") + .noDefaultValue() + .description("infer env hot reload model version manifest path"); + + public static final ConfigKey INFER_ENV_HOT_RELOAD_POLL_INTERVAL_SEC = ConfigKeys + .key("geaflow.infer.env.hot.reload.poll.interval.sec") + .defaultValue(1.0) + .description("infer env hot reload poll interval seconds"); + + public static final ConfigKey INFER_ENV_HOT_RELOAD_BACKOFF_SEC = ConfigKeys + .key("geaflow.infer.env.hot.reload.backoff.sec") + .defaultValue(10.0) + .description("infer env hot reload backoff seconds after failure"); + + public static final ConfigKey INFER_ENV_HOT_RELOAD_WARMUP_ENABLE = ConfigKeys + .key("geaflow.infer.env.hot.reload.warmup.enable") + .defaultValue(true) + .description("infer env hot reload warmup enable"); + + public static final ConfigKey INFER_ENV_HOT_RELOAD_ENABLE = ConfigKeys + .key("geaflow.infer.env.hot.reload.enable") + .defaultValue(true) + .description("infer env hot reload enable"); + public static final ConfigKey ASP_ENABLE = ConfigKeys .key("geaflow.iteration.asp.enable") .defaultValue(false) diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferContext.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferContext.java index 0289c1985..05d196ad0 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferContext.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferContext.java @@ -18,6 +18,12 @@ */ package org.apache.geaflow.infer; +import static org.apache.geaflow.common.config.keys.FrameworkConfigKeys.INFER_ENV_HOT_RELOAD_BACKOFF_SEC; +import static org.apache.geaflow.common.config.keys.FrameworkConfigKeys.INFER_ENV_HOT_RELOAD_ENABLE; +import static org.apache.geaflow.common.config.keys.FrameworkConfigKeys.INFER_ENV_HOT_RELOAD_MODEL_PATH; +import static org.apache.geaflow.common.config.keys.FrameworkConfigKeys.INFER_ENV_HOT_RELOAD_MODEL_VERSION_FILE; +import static org.apache.geaflow.common.config.keys.FrameworkConfigKeys.INFER_ENV_HOT_RELOAD_POLL_INTERVAL_SEC; +import static org.apache.geaflow.common.config.keys.FrameworkConfigKeys.INFER_ENV_HOT_RELOAD_WARMUP_ENABLE; import static org.apache.geaflow.common.config.keys.FrameworkConfigKeys.INFER_ENV_USER_TRANSFORM_CLASSNAME; import com.google.common.base.Preconditions; @@ -90,6 +96,22 @@ private void runInferTask(InferEnvironmentContext inferEnvironmentContext) { runCommands.add(inferEnvironmentContext.getInferTFClassNameParam(this.userDataTransformClass)); runCommands.add(inferEnvironmentContext.getInferShareMemoryInputParam(receiveQueueKey)); runCommands.add(inferEnvironmentContext.getInferShareMemoryOutputParam(sendQueueKey)); + + Configuration config = inferEnvironmentContext.getJobConfig(); + String modelPath = config.getString(INFER_ENV_HOT_RELOAD_MODEL_PATH, "model.pt"); + String modelVersionFile = config.getString(INFER_ENV_HOT_RELOAD_MODEL_VERSION_FILE, + "model.version"); + double pollIntervalSec = config.getDouble(INFER_ENV_HOT_RELOAD_POLL_INTERVAL_SEC); + double backoffSec = config.getDouble(INFER_ENV_HOT_RELOAD_BACKOFF_SEC); + boolean warmupEnabled = config.getBoolean(INFER_ENV_HOT_RELOAD_WARMUP_ENABLE); + boolean hotReloadEnabled = config.getBoolean(INFER_ENV_HOT_RELOAD_ENABLE); + + runCommands.add(inferEnvironmentContext.getInferModelPathParam(modelPath)); + runCommands.add(inferEnvironmentContext.getInferModelVersionFileParam(modelVersionFile)); + runCommands.add(inferEnvironmentContext.getInferPollIntervalSecParam(pollIntervalSec)); + runCommands.add(inferEnvironmentContext.getInferBackoffSecParam(backoffSec)); + runCommands.add(inferEnvironmentContext.getInferWarmupEnabledParam(warmupEnabled)); + runCommands.add(inferEnvironmentContext.getInferHotReloadEnabledParam(hotReloadEnabled)); inferTaskRunner.run(runCommands); } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferEnvironmentContext.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferEnvironmentContext.java index 569b19ada..9b84dfff3 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferEnvironmentContext.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferEnvironmentContext.java @@ -40,6 +40,18 @@ public class InferEnvironmentContext { // Start infer process parameter. private static final String TF_CLASSNAME_KEY = "--tfClassName="; + private static final String MODEL_PATH_KEY = "--model_path="; + + private static final String MODEL_VERSION_FILE_KEY = "--model_version_file="; + + private static final String POLL_INTERVAL_SEC_KEY = "--poll_interval_sec="; + + private static final String BACKOFF_SEC_KEY = "--backoff_sec="; + + private static final String WARMUP_ENABLED_KEY = "--warmup_enabled="; + + private static final String HOT_RELOAD_ENABLED_KEY = "--hot_reload_enabled="; + private static final String SHARE_MEMORY_INPUT_KEY = "--input_queue_shm_id="; private static final String SHARE_MEMORY_OUTPUT_KEY = "--output_queue_shm_id="; @@ -138,6 +150,30 @@ public String getInferShareMemoryOutputParam(String shareMemoryOutputKey) { return SHARE_MEMORY_OUTPUT_KEY + shareMemoryOutputKey; } + public String getInferModelPathParam(String modelPath) { + return MODEL_PATH_KEY + modelPath; + } + + public String getInferModelVersionFileParam(String modelVersionFile) { + return MODEL_VERSION_FILE_KEY + modelVersionFile; + } + + public String getInferPollIntervalSecParam(double pollIntervalSec) { + return POLL_INTERVAL_SEC_KEY + pollIntervalSec; + } + + public String getInferBackoffSecParam(double backoffSec) { + return BACKOFF_SEC_KEY + backoffSec; + } + + public String getInferWarmupEnabledParam(boolean warmupEnabled) { + return WARMUP_ENABLED_KEY + warmupEnabled; + } + + public String getInferHotReloadEnabledParam(boolean hotReloadEnabled) { + return HOT_RELOAD_ENABLED_KEY + hotReloadEnabled; + } + public String getInferScript() { return inferScript; } diff --git a/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/inferSession.py b/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/inferSession.py index 63ef72ccc..c7bca47f0 100644 --- a/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/inferSession.py +++ b/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/inferSession.py @@ -15,26 +15,256 @@ # specific language governing permissions and limitations # under the License. +import copy +import logging import os -import torch -torch.set_num_threads(1) - -# class TorchInferSession(object): -# def __init__(self, transform_class) -> None: -# self._transform = transform_class -# self._model_path = os.getcwd() + "/model.pt" -# self._model = transform_class.load_model(self._model_path) -# -# def run(self, *inputs): -# feature = self._transform.transform_pre(*inputs) -# res = self._model(*feature) -# return self._transform.transform_post(res) +import threading +import time +from dataclasses import dataclass + +try: + import torch + torch.set_num_threads(1) +except ImportError: + torch = None + + +@dataclass +class _ModelSlot: + transform: object + model: object + version: str + class TorchInferSession(object): - def __init__(self, transform_class) -> None: - self._transform = transform_class + + def __init__(self, transform_class, hot_reload_options=None) -> None: + self._template_transform = transform_class + self._transform_cls = transform_class.__class__ + self._logger = logging.getLogger(__name__) + + options = hot_reload_options or {} + model_path = options.get("model_path") + if model_path: + self._model_path = model_path + else: + self._model_path = os.path.join(os.getcwd(), "model.pt") + self._model_path = os.path.abspath(self._model_path) + + model_root = os.path.dirname(self._model_path) + self._version_file = os.path.abspath( + options.get("model_version_file") or os.path.join(model_root, "model.version") + ) + + self._poll_interval_sec = max(0.1, float(options.get("poll_interval_sec", 1.0))) + self._backoff_sec = max(0.1, float(options.get("backoff_sec", 10.0))) + self._warmup_enabled = bool(options.get("warmup_enabled", True)) + self._hot_reload_enabled = bool(options.get("hot_reload_enabled", True)) + + self._state_lock = threading.Lock() + self._loader_thread = None + self._loading_version = None + self._last_failed_version = None + self._next_retry_ts = 0.0 + self._next_check_ts = 0.0 + self._standby_slot = None + + init_version = self._read_manifest_version() or "bootstrap" + init_slot = self._build_slot(init_version, reuse_template=True) + self._active_slot = init_slot + + self._logger.info( + "infer hot reload initialized active_version=%s model_path=%s version_file=%s " + "poll_interval_sec=%.3f backoff_sec=%.3f warmup_enabled=%s hot_reload_enabled=%s", + init_slot.version, + self._model_path, + self._version_file, + self._poll_interval_sec, + self._backoff_sec, + self._warmup_enabled, + self._hot_reload_enabled, + ) def run(self, *inputs): - a,b = self._transform.transform_pre(*inputs) - return self._transform.transform_post(a) + self.maybe_reload() + active_slot = self._get_active_slot() + return self._run_with_slot(active_slot, *inputs) + + def maybe_reload(self): + if not self._hot_reload_enabled: + return + + now = time.monotonic() + if now < self._next_check_ts: + return + + with self._state_lock: + now = time.monotonic() + if now < self._next_check_ts: + return + self._next_check_ts = now + self._poll_interval_sec + + active_version = self._active_slot.version + candidate_version = self._read_manifest_version() + if not candidate_version or candidate_version == active_version: + return + if self._loading_version == candidate_version: + return + + if self._last_failed_version == candidate_version and now < self._next_retry_ts: + return + + if self._loader_thread is not None and self._loader_thread.is_alive(): + return + + self._loading_version = candidate_version + self._loader_thread = threading.Thread( + target=self._load_and_swap, + args=(candidate_version,), + name="infer-hot-reload", + daemon=True, + ) + self._loader_thread.start() + self._logger.info( + "infer hot reload scheduled candidate_version=%s active_version=%s", + candidate_version, + active_version, + ) + + def _get_active_slot(self): + with self._state_lock: + return self._active_slot + + def _load_and_swap(self, candidate_version): + start_ts = time.monotonic() + warmup_ms = 0 + try: + standby_slot = self._build_slot(candidate_version) + load_done_ts = time.monotonic() + warmup_ms = self._warmup_slot(standby_slot) + + with self._state_lock: + old_version = self._active_slot.version + self._standby_slot = standby_slot + self._active_slot = standby_slot + self._standby_slot = None + self._loading_version = None + self._last_failed_version = None + self._next_retry_ts = 0.0 + + load_ms = int((load_done_ts - start_ts) * 1000) + self._logger.info( + "infer hot reload switched switch_success=true candidate_version=%s active_version=%s " + "load_ms=%s warmup_ms=%s", + candidate_version, + old_version, + load_ms, + warmup_ms, + ) + except Exception: + fail_ts = time.monotonic() + with self._state_lock: + self._standby_slot = None + self._loading_version = None + self._last_failed_version = candidate_version + self._next_retry_ts = fail_ts + self._backoff_sec + active_version = self._active_slot.version + + self._logger.exception( + "infer hot reload failed switch_success=false candidate_version=%s active_version=%s " + "next_retry_sec=%.3f warmup_ms=%s", + candidate_version, + active_version, + self._backoff_sec, + warmup_ms, + ) + + def _build_slot(self, version, reuse_template=False): + transform = self._build_transform(reuse_template) + model = self._load_model(transform) + return _ModelSlot(transform=transform, model=model, version=version) + + def _build_transform(self, reuse_template): + if reuse_template: + return self._template_transform + try: + return self._transform_cls() + except Exception: + return copy.deepcopy(self._template_transform) + + def _load_model(self, transform): + if not hasattr(transform, "load_model"): + return getattr(transform, "model", None) + + load_result = transform.load_model(self._model_path) + if callable(load_result): + return load_result + model = getattr(transform, "model", None) + if callable(model): + return model + return None + + def _warmup_slot(self, slot): + if not self._warmup_enabled: + return 0 + + warmup_inputs = self._get_warmup_inputs(slot.transform) + if warmup_inputs is None: + return 0 + + start_ts = time.monotonic() + self._run_with_slot(slot, *warmup_inputs) + return int((time.monotonic() - start_ts) * 1000) + + def _get_warmup_inputs(self, transform): + warmup_func = getattr(transform, "get_warmup_inputs", None) + if not callable(warmup_func): + return None + + inputs = warmup_func() + if inputs is None: + return None + if isinstance(inputs, tuple): + return inputs + if isinstance(inputs, list): + return tuple(inputs) + return (inputs,) + + def _run_with_slot(self, slot, *inputs): + pre_result = slot.transform.transform_pre(*inputs) + if slot.model is not None: + model_inputs = self._extract_model_inputs(pre_result) + model_result = self._invoke_model(slot.model, model_inputs) + return slot.transform.transform_post(model_result) + return slot.transform.transform_post(self._extract_post_inputs(pre_result)) + + def _extract_post_inputs(self, pre_result): + if isinstance(pre_result, (tuple, list)) and len(pre_result) == 2: + return pre_result[0] + return pre_result + + def _extract_model_inputs(self, pre_result): + if isinstance(pre_result, (tuple, list)) and len(pre_result) == 2: + return pre_result[1] + return pre_result + + def _invoke_model(self, model, model_inputs): + if isinstance(model_inputs, tuple): + return model(*model_inputs) + if isinstance(model_inputs, list): + return model(*model_inputs) + return model(model_inputs) + def _read_manifest_version(self): + try: + with open(self._version_file, "r", encoding="utf-8") as version_file: + version = version_file.read().strip() + return version or None + except FileNotFoundError: + return None + except Exception: + self._logger.exception( + "infer hot reload read manifest failed version_file=%s", + self._version_file, + ) + return None diff --git a/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/infer_server.py b/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/infer_server.py index 91d109d9f..c9f726cec 100644 --- a/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/infer_server.py +++ b/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/infer_server.py @@ -54,9 +54,20 @@ def get_user_define_class(class_name): raise ValueError("class name = {} not found".format(class_name)) -def start_infer_process(class_name, output_queue_shm_id, input_queue_shm_id): +def start_infer_process(class_name, output_queue_shm_id, input_queue_shm_id, + model_path=None, model_version_file=None, + poll_interval_sec=1.0, backoff_sec=10.0, + warmup_enabled=True, hot_reload_enabled=True): transform_class = get_user_define_class(class_name) - infer_session = TorchInferSession(transform_class) + hot_reload_options = { + "model_path": model_path, + "model_version_file": model_version_file, + "poll_interval_sec": poll_interval_sec, + "backoff_sec": backoff_sec, + "warmup_enabled": warmup_enabled, + "hot_reload_enabled": hot_reload_enabled, + } + infer_session = TorchInferSession(transform_class, hot_reload_options) input_size = transform_class.input_size data_exchange = PicklerDataBridger(input_queue_shm_id, output_queue_shm_id, input_size) check_thread = check_ppid('check_process', True) @@ -80,6 +91,10 @@ def start_infer_process(class_name, output_queue_shm_id, input_queue_shm_id): sys.exit(0) +def _str2bool(value): + return str(value).lower() in ("1", "true", "t", "yes", "y", "on") + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--tfClassName", type=str, @@ -89,6 +104,23 @@ def start_infer_process(class_name, output_queue_shm_id, input_queue_shm_id): "id") parser.add_argument("--output_queue_shm_id", type=str, help="output queue share memory id") + parser.add_argument("--model_path", type=str, + default=os.path.join(os.getcwd(), "model.pt"), + help="model file path") + parser.add_argument("--model_version_file", type=str, + default=os.path.join(os.getcwd(), "model.version"), + help="manifest file path") + parser.add_argument("--poll_interval_sec", type=float, default=1.0, + help="manifest poll interval in seconds") + parser.add_argument("--backoff_sec", type=float, default=10.0, + help="reload backoff in seconds after failure") + parser.add_argument("--warmup_enabled", type=_str2bool, default=True, + help="enable dummy warmup before switching") + parser.add_argument("--hot_reload_enabled", type=_str2bool, default=True, + help="enable model hot reload") args = parser.parse_args() start_infer_process(args.tfClassName, args.output_queue_shm_id, - args.input_queue_shm_id) + args.input_queue_shm_id, args.model_path, + args.model_version_file, args.poll_interval_sec, + args.backoff_sec, args.warmup_enabled, + args.hot_reload_enabled) diff --git a/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/test_infer_session_hot_reload.py b/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/test_infer_session_hot_reload.py new file mode 100644 index 000000000..97158dbbd --- /dev/null +++ b/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/test_infer_session_hot_reload.py @@ -0,0 +1,192 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import sys +import tempfile +import threading +import time +import unittest + +CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) +if CURRENT_DIR not in sys.path: + sys.path.insert(0, CURRENT_DIR) + +from inferSession import TorchInferSession + + +def _atomic_publish(file_path, content): + temp_path = file_path + ".tmp" + with open(temp_path, "w", encoding="utf-8") as write_file: + write_file.write(content) + os.replace(temp_path, file_path) + + +class ReloadableTransform(object): + load_count = 0 + fail_versions = set() + load_sleep_sec = 0.0 + + def __init__(self): + self.input_size = 1 + self._version = "" + + def load_model(self, model_path): + type(self).load_count += 1 + if type(self).load_sleep_sec > 0: + time.sleep(type(self).load_sleep_sec) + with open(model_path, "r", encoding="utf-8") as read_file: + self._version = read_file.read().strip() + if self._version in type(self).fail_versions: + raise RuntimeError("failed to load model version {}".format(self._version)) + + version = self._version + + def _call(data): + return "{}:{}".format(version, data) + + return _call + + def transform_pre(self, *args): + value = args[0] + if value == "slow": + time.sleep(0.2) + return value + + def transform_post(self, value): + return value + + def get_warmup_inputs(self): + return ("warmup",) + + +class TorchInferSessionHotReloadTest(unittest.TestCase): + + def setUp(self): + ReloadableTransform.load_count = 0 + ReloadableTransform.fail_versions = set() + ReloadableTransform.load_sleep_sec = 0.0 + self.temp_dir = tempfile.TemporaryDirectory() + self.model_path = os.path.join(self.temp_dir.name, "model.pt") + self.manifest_path = os.path.join(self.temp_dir.name, "model.version") + _atomic_publish(self.model_path, "v1") + _atomic_publish(self.manifest_path, "v1") + + def tearDown(self): + self.temp_dir.cleanup() + + def _build_session(self, backoff_sec=0.3): + transform = ReloadableTransform() + return TorchInferSession( + transform, + { + "model_path": self.model_path, + "model_version_file": self.manifest_path, + "poll_interval_sec": 0.1, + "backoff_sec": backoff_sec, + "warmup_enabled": True, + "hot_reload_enabled": True, + }, + ) + + def _wait_until(self, predicate, timeout_sec=3.0): + deadline = time.time() + timeout_sec + while time.time() < deadline: + if predicate(): + return True + time.sleep(0.02) + return False + + def test_manifest_change_triggers_reload(self): + session = self._build_session() + self.assertEqual("v1:request", session.run("request")) + + _atomic_publish(self.model_path, "v2") + _atomic_publish(self.manifest_path, "v2") + + switched = self._wait_until(lambda: session.run("request") == "v2:request") + self.assertTrue(switched) + + def test_single_flight_reload(self): + ReloadableTransform.load_sleep_sec = 0.2 + session = self._build_session() + self.assertEqual(1, ReloadableTransform.load_count) + + _atomic_publish(self.model_path, "v2") + _atomic_publish(self.manifest_path, "v2") + + calls = [] + + def _trigger_reload(): + session.maybe_reload() + calls.append(1) + + threads = [threading.Thread(target=_trigger_reload) for _ in range(6)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + self.assertEqual(6, len(calls)) + loaded = self._wait_until(lambda: ReloadableTransform.load_count >= 2) + self.assertTrue(loaded) + time.sleep(0.25) + self.assertEqual(2, ReloadableTransform.load_count) + + def test_switch_does_not_affect_inflight_request(self): + session = self._build_session() + + _atomic_publish(self.model_path, "v2") + _atomic_publish(self.manifest_path, "v2") + + result_holder = {} + + def _run_slow_request(): + result_holder["result"] = session.run("slow") + + slow_thread = threading.Thread(target=_run_slow_request) + slow_thread.start() + slow_thread.join() + + self.assertEqual("v1:slow", result_holder["result"]) + switched = self._wait_until(lambda: session.run("request") == "v2:request") + self.assertTrue(switched) + + def test_failed_reload_keeps_active_and_backoff(self): + session = self._build_session(backoff_sec=0.4) + + ReloadableTransform.fail_versions = {"bad"} + _atomic_publish(self.model_path, "bad") + _atomic_publish(self.manifest_path, "bad") + + failed = self._wait_until(lambda: session.run("request") == "v1:request") + self.assertTrue(failed) + first_failed_count = ReloadableTransform.load_count + self.assertGreaterEqual(first_failed_count, 2) + + time.sleep(0.1) + session.run("request") + self.assertEqual(first_failed_count, ReloadableTransform.load_count) + + time.sleep(0.45) + session.run("request") + retried = self._wait_until(lambda: ReloadableTransform.load_count > first_failed_count) + self.assertTrue(retried) + + +if __name__ == "__main__": + unittest.main() From 3f452d334f0e0092594f194ba58eba81af728b09 Mon Sep 17 00:00:00 2001 From: aotenjou Date: Thu, 16 Apr 2026 08:37:00 +0800 Subject: [PATCH 2/6] chore: modified param and test. --- .../infer/inferRuntime/infer_server.py | 15 +- .../test_infer_session_hot_reload.py | 192 ------------------ 2 files changed, 7 insertions(+), 200 deletions(-) delete mode 100644 geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/test_infer_session_hot_reload.py diff --git a/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/infer_server.py b/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/infer_server.py index c9f726cec..238aab605 100644 --- a/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/infer_server.py +++ b/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/infer_server.py @@ -91,10 +91,6 @@ def start_infer_process(class_name, output_queue_shm_id, input_queue_shm_id, sys.exit(0) -def _str2bool(value): - return str(value).lower() in ("1", "true", "t", "yes", "y", "on") - - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--tfClassName", type=str, @@ -114,13 +110,16 @@ def _str2bool(value): help="manifest poll interval in seconds") parser.add_argument("--backoff_sec", type=float, default=10.0, help="reload backoff in seconds after failure") - parser.add_argument("--warmup_enabled", type=_str2bool, default=True, + parser.add_argument("--warmup_enabled", choices=["True", "False"], + default="True", help="enable dummy warmup before switching") - parser.add_argument("--hot_reload_enabled", type=_str2bool, default=True, + parser.add_argument("--hot_reload_enabled", choices=["True", "False"], + default="True", help="enable model hot reload") args = parser.parse_args() start_infer_process(args.tfClassName, args.output_queue_shm_id, args.input_queue_shm_id, args.model_path, args.model_version_file, args.poll_interval_sec, - args.backoff_sec, args.warmup_enabled, - args.hot_reload_enabled) + args.backoff_sec, + args.warmup_enabled == "True", + args.hot_reload_enabled == "True") diff --git a/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/test_infer_session_hot_reload.py b/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/test_infer_session_hot_reload.py deleted file mode 100644 index 97158dbbd..000000000 --- a/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/test_infer_session_hot_reload.py +++ /dev/null @@ -1,192 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import os -import sys -import tempfile -import threading -import time -import unittest - -CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) -if CURRENT_DIR not in sys.path: - sys.path.insert(0, CURRENT_DIR) - -from inferSession import TorchInferSession - - -def _atomic_publish(file_path, content): - temp_path = file_path + ".tmp" - with open(temp_path, "w", encoding="utf-8") as write_file: - write_file.write(content) - os.replace(temp_path, file_path) - - -class ReloadableTransform(object): - load_count = 0 - fail_versions = set() - load_sleep_sec = 0.0 - - def __init__(self): - self.input_size = 1 - self._version = "" - - def load_model(self, model_path): - type(self).load_count += 1 - if type(self).load_sleep_sec > 0: - time.sleep(type(self).load_sleep_sec) - with open(model_path, "r", encoding="utf-8") as read_file: - self._version = read_file.read().strip() - if self._version in type(self).fail_versions: - raise RuntimeError("failed to load model version {}".format(self._version)) - - version = self._version - - def _call(data): - return "{}:{}".format(version, data) - - return _call - - def transform_pre(self, *args): - value = args[0] - if value == "slow": - time.sleep(0.2) - return value - - def transform_post(self, value): - return value - - def get_warmup_inputs(self): - return ("warmup",) - - -class TorchInferSessionHotReloadTest(unittest.TestCase): - - def setUp(self): - ReloadableTransform.load_count = 0 - ReloadableTransform.fail_versions = set() - ReloadableTransform.load_sleep_sec = 0.0 - self.temp_dir = tempfile.TemporaryDirectory() - self.model_path = os.path.join(self.temp_dir.name, "model.pt") - self.manifest_path = os.path.join(self.temp_dir.name, "model.version") - _atomic_publish(self.model_path, "v1") - _atomic_publish(self.manifest_path, "v1") - - def tearDown(self): - self.temp_dir.cleanup() - - def _build_session(self, backoff_sec=0.3): - transform = ReloadableTransform() - return TorchInferSession( - transform, - { - "model_path": self.model_path, - "model_version_file": self.manifest_path, - "poll_interval_sec": 0.1, - "backoff_sec": backoff_sec, - "warmup_enabled": True, - "hot_reload_enabled": True, - }, - ) - - def _wait_until(self, predicate, timeout_sec=3.0): - deadline = time.time() + timeout_sec - while time.time() < deadline: - if predicate(): - return True - time.sleep(0.02) - return False - - def test_manifest_change_triggers_reload(self): - session = self._build_session() - self.assertEqual("v1:request", session.run("request")) - - _atomic_publish(self.model_path, "v2") - _atomic_publish(self.manifest_path, "v2") - - switched = self._wait_until(lambda: session.run("request") == "v2:request") - self.assertTrue(switched) - - def test_single_flight_reload(self): - ReloadableTransform.load_sleep_sec = 0.2 - session = self._build_session() - self.assertEqual(1, ReloadableTransform.load_count) - - _atomic_publish(self.model_path, "v2") - _atomic_publish(self.manifest_path, "v2") - - calls = [] - - def _trigger_reload(): - session.maybe_reload() - calls.append(1) - - threads = [threading.Thread(target=_trigger_reload) for _ in range(6)] - for thread in threads: - thread.start() - for thread in threads: - thread.join() - - self.assertEqual(6, len(calls)) - loaded = self._wait_until(lambda: ReloadableTransform.load_count >= 2) - self.assertTrue(loaded) - time.sleep(0.25) - self.assertEqual(2, ReloadableTransform.load_count) - - def test_switch_does_not_affect_inflight_request(self): - session = self._build_session() - - _atomic_publish(self.model_path, "v2") - _atomic_publish(self.manifest_path, "v2") - - result_holder = {} - - def _run_slow_request(): - result_holder["result"] = session.run("slow") - - slow_thread = threading.Thread(target=_run_slow_request) - slow_thread.start() - slow_thread.join() - - self.assertEqual("v1:slow", result_holder["result"]) - switched = self._wait_until(lambda: session.run("request") == "v2:request") - self.assertTrue(switched) - - def test_failed_reload_keeps_active_and_backoff(self): - session = self._build_session(backoff_sec=0.4) - - ReloadableTransform.fail_versions = {"bad"} - _atomic_publish(self.model_path, "bad") - _atomic_publish(self.manifest_path, "bad") - - failed = self._wait_until(lambda: session.run("request") == "v1:request") - self.assertTrue(failed) - first_failed_count = ReloadableTransform.load_count - self.assertGreaterEqual(first_failed_count, 2) - - time.sleep(0.1) - session.run("request") - self.assertEqual(first_failed_count, ReloadableTransform.load_count) - - time.sleep(0.45) - session.run("request") - retried = self._wait_until(lambda: ReloadableTransform.load_count > first_failed_count) - self.assertTrue(retried) - - -if __name__ == "__main__": - unittest.main() From 3bb8e18208645ea05d179cf559bcc1d4c93dc405 Mon Sep 17 00:00:00 2001 From: aotenjou Date: Thu, 16 Apr 2026 23:19:20 +0800 Subject: [PATCH 3/6] fix(infer): make data exchange shutdown idempotent --- .../infer/exchange/DataExchangeContext.java | 25 +++++++++++++++++-- .../infer/exchange/DataExchangeQueue.java | 10 +++----- 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataExchangeContext.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataExchangeContext.java index 417e72703..9959b3f35 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataExchangeContext.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataExchangeContext.java @@ -24,6 +24,7 @@ import java.io.Closeable; import java.io.File; import java.io.IOException; +import java.util.concurrent.atomic.AtomicBoolean; import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.exception.GeaflowRuntimeException; @@ -48,6 +49,9 @@ public class DataExchangeContext implements Closeable { private final File receiveQueueFile; private final File sendQueueFile; + private final Thread releaseQueueEndpointHook; + private final AtomicBoolean closed; + private final AtomicBoolean queueEndpointReleased; private String receivePath; private String sendPath; @@ -62,7 +66,10 @@ public DataExchangeContext(Configuration config) { int queueCapacity = config.getInteger(INFER_ENV_SHARE_MEMORY_QUEUE_SIZE); this.receiveQueue = new DataExchangeQueue(receivePath, queueCapacity, true); this.sendQueue = new DataExchangeQueue(sendPath, queueCapacity, true); - Runtime.getRuntime().addShutdownHook(new Thread(() -> UnSafeUtils.UNSAFE.freeMemory(queueEndpoint))); + this.closed = new AtomicBoolean(false); + this.queueEndpointReleased = new AtomicBoolean(false); + this.releaseQueueEndpointHook = new Thread(this::releaseQueueEndpoint); + Runtime.getRuntime().addShutdownHook(releaseQueueEndpointHook); } public String getReceiveQueueKey() { @@ -75,6 +82,9 @@ public String getSendQueueKey() { @Override public synchronized void close() throws IOException { + if (!closed.compareAndSet(false, true)) { + return; + } if (receiveQueue != null) { receiveQueue.close(); } @@ -87,8 +97,13 @@ public synchronized void close() throws IOException { if (sendQueueFile != null) { sendQueueFile.delete(); } - UnSafeUtils.UNSAFE.freeMemory(this.queueEndpoint); + releaseQueueEndpoint(); FileUtils.deleteQuietly(localDirectory); + try { + Runtime.getRuntime().removeShutdownHook(releaseQueueEndpointHook); + } catch (IllegalStateException ignored) { + // JVM shutdown is in progress, the hook may already be running. + } } public DataExchangeQueue getReceiveQueue() { @@ -109,4 +124,10 @@ private File createTempFile(String prefix, String suffix) { throw new GeaflowRuntimeException("create temp file on infer directory failed ", e); } } + + private void releaseQueueEndpoint() { + if (queueEndpointReleased.compareAndSet(false, true)) { + UnSafeUtils.UNSAFE.freeMemory(queueEndpoint); + } + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataExchangeQueue.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataExchangeQueue.java index 29057f60e..e1c3be29d 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataExchangeQueue.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataExchangeQueue.java @@ -26,7 +26,7 @@ public final class DataExchangeQueue implements Closeable { - private static final AtomicBoolean CLOSED = new AtomicBoolean(false); + private final AtomicBoolean closed = new AtomicBoolean(false); private final long outputNextAddress; private final long capacityAddress; private final long outputAddress; @@ -66,11 +66,9 @@ public DataExchangeQueue(String mapKey, int capacity, boolean reset) { @Override public synchronized void close() { - CLOSED.set(true); - if (memoryMapper != null) { + if (closed.compareAndSet(false, true) && memoryMapper != null) { memoryMapper.close(); } - UnSafeUtils.UNSAFE.freeMemory(mapAddress); } public long getMemoryMapSize() { @@ -133,7 +131,7 @@ public boolean enableFinished() { } public synchronized void markFinished() { - if (!CLOSED.get()) { + if (!closed.get()) { UnSafeUtils.UNSAFE.putOrderedLong(null, endPointAddress, -1); } } @@ -165,4 +163,4 @@ public static long getNextPointIndex(long v, int capacity) { } return Pow2.align(v, capacity); } -} \ No newline at end of file +} From 9b59ca0374382345f5c40dda7ba92aa0c5755908 Mon Sep 17 00:00:00 2001 From: aotenjou Date: Thu, 16 Apr 2026 23:19:32 +0800 Subject: [PATCH 4/6] fix(infer): emit strict python boolean args --- .../org/apache/geaflow/infer/InferEnvironmentContext.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferEnvironmentContext.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferEnvironmentContext.java index 9b84dfff3..b6b668cd1 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferEnvironmentContext.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferEnvironmentContext.java @@ -167,11 +167,11 @@ public String getInferBackoffSecParam(double backoffSec) { } public String getInferWarmupEnabledParam(boolean warmupEnabled) { - return WARMUP_ENABLED_KEY + warmupEnabled; + return WARMUP_ENABLED_KEY + (warmupEnabled ? "True" : "False"); } public String getInferHotReloadEnabledParam(boolean hotReloadEnabled) { - return HOT_RELOAD_ENABLED_KEY + hotReloadEnabled; + return HOT_RELOAD_ENABLED_KEY + (hotReloadEnabled ? "True" : "False"); } public String getInferScript() { From f6daca38102ac902672987ee840cf667a90c37c9 Mon Sep 17 00:00:00 2001 From: aotenjou Date: Thu, 16 Apr 2026 23:19:46 +0800 Subject: [PATCH 5/6] feat(infer): honor configured hot reload timing --- .../src/main/resources/infer/inferRuntime/inferSession.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/inferSession.py b/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/inferSession.py index c7bca47f0..2a1076865 100644 --- a/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/inferSession.py +++ b/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/inferSession.py @@ -56,8 +56,8 @@ def __init__(self, transform_class, hot_reload_options=None) -> None: options.get("model_version_file") or os.path.join(model_root, "model.version") ) - self._poll_interval_sec = max(0.1, float(options.get("poll_interval_sec", 1.0))) - self._backoff_sec = max(0.1, float(options.get("backoff_sec", 10.0))) + self._poll_interval_sec = float(options.get("poll_interval_sec", 1.0)) + self._backoff_sec = float(options.get("backoff_sec", 10.0)) self._warmup_enabled = bool(options.get("warmup_enabled", True)) self._hot_reload_enabled = bool(options.get("hot_reload_enabled", True)) From b928adfde46c64443f41b40199fac3a5a8b67657 Mon Sep 17 00:00:00 2001 From: aotenjou Date: Sat, 18 Apr 2026 17:40:59 +0800 Subject: [PATCH 6/6] feat(infer): double-buffer hot reload --- .../infer/inferRuntime/inferSession.py | 151 +++++--- .../geaflow/infer/InferHotReloadTest.java | 360 ++++++++++++++++++ 2 files changed, 467 insertions(+), 44 deletions(-) create mode 100644 geaflow/geaflow-infer/src/test/java/org/apache/geaflow/infer/InferHotReloadTest.java diff --git a/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/inferSession.py b/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/inferSession.py index 2a1076865..78e25a2aa 100644 --- a/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/inferSession.py +++ b/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/inferSession.py @@ -1,3 +1,4 @@ + # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -15,6 +16,7 @@ # specific language governing permissions and limitations # under the License. +import atexit import copy import logging import os @@ -56,27 +58,28 @@ def __init__(self, transform_class, hot_reload_options=None) -> None: options.get("model_version_file") or os.path.join(model_root, "model.version") ) - self._poll_interval_sec = float(options.get("poll_interval_sec", 1.0)) - self._backoff_sec = float(options.get("backoff_sec", 10.0)) + self._poll_interval_sec = max(float(options.get("poll_interval_sec", 1.0)), 0.01) + self._backoff_sec = max(float(options.get("backoff_sec", 10.0)), 0.0) self._warmup_enabled = bool(options.get("warmup_enabled", True)) self._hot_reload_enabled = bool(options.get("hot_reload_enabled", True)) self._state_lock = threading.Lock() - self._loader_thread = None + self._watcher_stop_event = threading.Event() + self._watcher_thread = None self._loading_version = None self._last_failed_version = None self._next_retry_ts = 0.0 - self._next_check_ts = 0.0 - self._standby_slot = None + self._closed = False - init_version = self._read_manifest_version() or "bootstrap" - init_slot = self._build_slot(init_version, reuse_template=True) - self._active_slot = init_slot + init_version = self._read_reload_fingerprint() or "bootstrap" + self.model_active = self._build_slot(init_version, reuse_template=True) + self.model_standby = self._build_slot(init_version) self._logger.info( - "infer hot reload initialized active_version=%s model_path=%s version_file=%s " + "infer hot reload initialized active_version=%s standby_version=%s model_path=%s version_file=%s " "poll_interval_sec=%.3f backoff_sec=%.3f warmup_enabled=%s hot_reload_enabled=%s", - init_slot.version, + self.model_active.version, + self.model_standby.version if self.model_standby is not None else None, self._model_path, self._version_file, self._poll_interval_sec, @@ -84,70 +87,103 @@ def __init__(self, transform_class, hot_reload_options=None) -> None: self._warmup_enabled, self._hot_reload_enabled, ) + if self._hot_reload_enabled: + self._watcher_thread = threading.Thread( + target=self._watch_reload_loop, + name="infer-hot-reload-watcher", + daemon=True, + ) + self._watcher_thread.start() + atexit.register(self.close) def run(self, *inputs): - self.maybe_reload() + self._ensure_open() active_slot = self._get_active_slot() return self._run_with_slot(active_slot, *inputs) + def close(self): + with self._state_lock: + if self._closed: + return + self._closed = True + self._watcher_stop_event.set() + watcher_thread = self._watcher_thread + if watcher_thread is not None and watcher_thread.is_alive() and watcher_thread is not threading.current_thread(): + watcher_thread.join(timeout=1.0) + with self._state_lock: + self._watcher_thread = None + self._loading_version = None + self._last_failed_version = None + self._next_retry_ts = 0.0 + self.model_standby = None + self.model_active = None + + def _ensure_open(self): + with self._state_lock: + if self._closed: + raise RuntimeError("infer session already closed") + + def _watch_reload_loop(self): + while not self._watcher_stop_event.wait(self._poll_interval_sec): + try: + self.maybe_reload() + except Exception: + self._logger.exception("infer hot reload watcher loop failed") + def maybe_reload(self): - if not self._hot_reload_enabled: + if not self._hot_reload_enabled or self._closed: return now = time.monotonic() - if now < self._next_check_ts: + candidate_version = self._read_reload_fingerprint() + if not candidate_version: return with self._state_lock: - now = time.monotonic() - if now < self._next_check_ts: + if self._closed: return - self._next_check_ts = now + self._poll_interval_sec - - active_version = self._active_slot.version - candidate_version = self._read_manifest_version() - if not candidate_version or candidate_version == active_version: + if self._loading_version is not None: return - if self._loading_version == candidate_version: + active_slot = self.model_active + standby_slot = self.model_standby + active_version = active_slot.version if active_slot is not None else None + standby_version = standby_slot.version if standby_slot is not None else None + if active_slot is None or standby_slot is None: return - - if self._last_failed_version == candidate_version and now < self._next_retry_ts: + if candidate_version == active_version: return - - if self._loader_thread is not None and self._loader_thread.is_alive(): + if self._last_failed_version == candidate_version and now < self._next_retry_ts: return - self._loading_version = candidate_version - self._loader_thread = threading.Thread( - target=self._load_and_swap, - args=(candidate_version,), - name="infer-hot-reload", - daemon=True, - ) - self._loader_thread.start() self._logger.info( - "infer hot reload scheduled candidate_version=%s active_version=%s", + "infer hot reload loading standby candidate_version=%s active_version=%s standby_version=%s", candidate_version, active_version, + standby_version, ) + self._load_and_swap(candidate_version, active_slot, standby_slot) + def _get_active_slot(self): with self._state_lock: - return self._active_slot + return self.model_active - def _load_and_swap(self, candidate_version): + def _load_and_swap(self, candidate_version, active_slot, standby_slot): start_ts = time.monotonic() warmup_ms = 0 try: - standby_slot = self._build_slot(candidate_version) + self._load_slot(standby_slot) load_done_ts = time.monotonic() warmup_ms = self._warmup_slot(standby_slot) with self._state_lock: - old_version = self._active_slot.version - self._standby_slot = standby_slot - self._active_slot = standby_slot - self._standby_slot = None + if self._closed: + self._loading_version = None + return + old_version = active_slot.version if active_slot is not None else None + standby_slot.version = candidate_version + self.model_active = standby_slot + self.model_standby = active_slot self._loading_version = None self._last_failed_version = None self._next_retry_ts = 0.0 @@ -164,11 +200,10 @@ def _load_and_swap(self, candidate_version): except Exception: fail_ts = time.monotonic() with self._state_lock: - self._standby_slot = None self._loading_version = None self._last_failed_version = candidate_version self._next_retry_ts = fail_ts + self._backoff_sec - active_version = self._active_slot.version + active_version = self.model_active.version if self.model_active is not None else None self._logger.exception( "infer hot reload failed switch_success=false candidate_version=%s active_version=%s " @@ -184,6 +219,9 @@ def _build_slot(self, version, reuse_template=False): model = self._load_model(transform) return _ModelSlot(transform=transform, model=model, version=version) + def _load_slot(self, slot): + slot.model = self._load_model(slot.transform) + def _build_transform(self, reuse_template): if reuse_template: return self._template_transform @@ -210,7 +248,7 @@ def _warmup_slot(self, slot): warmup_inputs = self._get_warmup_inputs(slot.transform) if warmup_inputs is None: - return 0 + raise RuntimeError("infer hot reload warmup inputs missing") start_ts = time.monotonic() self._run_with_slot(slot, *warmup_inputs) @@ -255,6 +293,31 @@ def _invoke_model(self, model, model_inputs): return model(*model_inputs) return model(model_inputs) + def _read_reload_fingerprint(self): + model_signature = self._file_signature(self._model_path) + version_signature = self._file_signature(self._version_file) + if model_signature is None and version_signature is None: + return None + + fingerprint_parts = [ + "model={}".format(model_signature or "missing"), + "manifest={}".format(version_signature or "missing"), + ] + manifest_version = self._read_manifest_version() + if manifest_version: + fingerprint_parts.append("token={}".format(manifest_version)) + return "|".join(fingerprint_parts) + + def _file_signature(self, path): + try: + stat_result = os.stat(path) + except FileNotFoundError: + return None + except Exception: + self._logger.exception("infer hot reload stat failed path=%s", path) + return None + return "{}:{}".format(int(stat_result.st_mtime_ns), stat_result.st_size) + def _read_manifest_version(self): try: with open(self._version_file, "r", encoding="utf-8") as version_file: diff --git a/geaflow/geaflow-infer/src/test/java/org/apache/geaflow/infer/InferHotReloadTest.java b/geaflow/geaflow-infer/src/test/java/org/apache/geaflow/infer/InferHotReloadTest.java new file mode 100644 index 000000000..048d8dec0 --- /dev/null +++ b/geaflow/geaflow-infer/src/test/java/org/apache/geaflow/infer/InferHotReloadTest.java @@ -0,0 +1,360 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.geaflow.infer; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.testng.SkipException; +import org.testng.annotations.Test; + +public class InferHotReloadTest { + + @Test + public void testDoubleBufferHotReloadSwapsAfterWarmup() throws Exception { + ensurePythonAvailable(); + Path tempDir = Files.createTempDirectory("infer-hot-reload"); + try { + copyInferSessionResource(tempDir); + writeTorchStub(tempDir); + Files.write(tempDir.resolve("runner.py"), buildHappyPathRunnerScript().getBytes(UTF_8)); + + Map values = parseOutput(runPython(tempDir)); + String initActiveId = values.get("INIT_ACTIVE_ID"); + String initStandbyId = values.get("INIT_STANDBY_ID"); + String initActiveTransformId = values.get("INIT_ACTIVE_TRANSFORM_ID"); + String initStandbyTransformId = values.get("INIT_STANDBY_TRANSFORM_ID"); + + assertEquals(values.get("INITIAL"), "v1"); + assertEquals(values.get("BEFORE_SWAP"), "v1,v1,v1,v1"); + assertEquals(values.get("AFTER_SWAP"), "v2,v2,v2"); + assertEquals(values.get("LOAD_EVENTS"), + initActiveTransformId + "@v1," + initStandbyTransformId + "@v1," + + initStandbyTransformId + "@v2"); + assertEquals(values.get("WARMUP_EVENTS"), initStandbyTransformId + "@v2"); + assertTrue(!initActiveId.equals(initStandbyId), values.toString()); + assertEquals(values.get("FINAL_ACTIVE_ID"), initStandbyId); + assertEquals(values.get("FINAL_STANDBY_ID"), initActiveId); + assertEquals(values.get("FINAL_ACTIVE_RUN"), "v2"); + assertEquals(values.get("FINAL_STANDBY_RUN"), "v1"); + } finally { + deleteRecursively(tempDir); + } + } + + @Test + public void testMissingWarmupKeepsActiveSlot() throws Exception { + ensurePythonAvailable(); + Path tempDir = Files.createTempDirectory("infer-hot-reload"); + try { + copyInferSessionResource(tempDir); + writeTorchStub(tempDir); + Files.write(tempDir.resolve("runner.py"), buildMissingWarmupRunnerScript().getBytes(UTF_8)); + + Map values = parseOutput(runPython(tempDir)); + String initActiveId = values.get("INIT_ACTIVE_ID"); + String initStandbyId = values.get("INIT_STANDBY_ID"); + String initActiveTransformId = values.get("INIT_ACTIVE_TRANSFORM_ID"); + String initStandbyTransformId = values.get("INIT_STANDBY_TRANSFORM_ID"); + + assertEquals(values.get("INITIAL"), "v1"); + assertEquals(values.get("AFTER_FAILURE"), "v1,v1,v1,v1"); + assertEquals(values.get("LOAD_EVENTS"), + initActiveTransformId + "@v1," + initStandbyTransformId + "@v1," + + initStandbyTransformId + "@v2"); + assertTrue(values.get("WARMUP_EVENTS").isEmpty(), values.toString()); + assertTrue(!initActiveId.equals(initStandbyId), values.toString()); + assertEquals(values.get("FINAL_ACTIVE_ID"), initActiveId); + assertEquals(values.get("FINAL_STANDBY_ID"), initStandbyId); + assertEquals(values.get("FINAL_ACTIVE_RUN"), "v1"); + } finally { + deleteRecursively(tempDir); + } + } + + private String runPython(Path tempDir) throws Exception { + Process process = new ProcessBuilder("python3", "runner.py") + .directory(tempDir.toFile()) + .redirectErrorStream(true) + .start(); + if (!process.waitFor(30, TimeUnit.SECONDS)) { + process.destroyForcibly(); + throw new RuntimeException("python hot reload test timed out"); + } + + byte[] outputBytes; + try (InputStream inputStream = process.getInputStream()) { + outputBytes = inputStream.readAllBytes(); + } + String output = new String(outputBytes, UTF_8).trim(); + assertEquals(process.exitValue(), 0, output); + return output; + } + + private Map parseOutput(String output) { + Map values = new HashMap<>(); + for (String line : output.split("\\R")) { + int delimiterIndex = line.indexOf('='); + if (delimiterIndex <= 0) { + continue; + } + values.put(line.substring(0, delimiterIndex), line.substring(delimiterIndex + 1)); + } + return values; + } + + private String buildHappyPathRunnerScript() { + StringBuilder script = new StringBuilder(); + appendPythonLine(script, "import os"); + appendPythonLine(script, "import threading"); + appendPythonLine(script, "import time"); + appendPythonLine(script, "from inferSession import TorchInferSession"); + appendPythonLine(script, ""); + appendPythonLine(script, "class Transform(object):"); + appendPythonLine(script, " input_size = 1"); + appendPythonLine(script, " load_events = []"); + appendPythonLine(script, " warmup_events = []"); + appendPythonLine(script, " lock = threading.Lock()"); + appendPythonLine(script, " warmup_started = threading.Event()"); + appendPythonLine(script, " warmup_release = threading.Event()"); + appendPythonLine(script, ""); + appendPythonLine(script, " def __init__(self):"); + appendPythonLine(script, " self.version = 'unset'"); + appendPythonLine(script, ""); + appendPythonLine(script, " def load_model(self, model_path):"); + appendPythonLine(script, " with open(model_path, 'r', encoding='utf-8') as model_file:"); + appendPythonLine(script, " self.version = model_file.read().strip()"); + appendPythonLine(script, " with self.__class__.lock:"); + appendPythonLine(script, " self.__class__.load_events.append('{}@{}'.format(id(self), self.version))"); + appendPythonLine(script, " if self.version == 'v2':"); + appendPythonLine(script, " time.sleep(0.2)"); + appendPythonLine(script, " return None"); + appendPythonLine(script, ""); + appendPythonLine(script, " def get_warmup_inputs(self):"); + appendPythonLine(script, " return ('warm',)"); + appendPythonLine(script, ""); + appendPythonLine(script, " def transform_pre(self, value):"); + appendPythonLine(script, " if value == 'warm':"); + appendPythonLine(script, " with self.__class__.lock:"); + appendPythonLine(script, " self.__class__.warmup_events.append('{}@{}'.format(id(self), self.version))"); + appendPythonLine(script, " self.__class__.warmup_started.set()"); + appendPythonLine(script, " if not self.__class__.warmup_release.wait(5):"); + appendPythonLine(script, " raise RuntimeError('warmup release timed out')"); + appendPythonLine(script, " return self.version"); + appendPythonLine(script, ""); + appendPythonLine(script, " def transform_post(self, value):"); + appendPythonLine(script, " return value"); + appendPythonLine(script, ""); + appendPythonLine(script, "def write_model(version):"); + appendPythonLine(script, " with open('model.pt', 'w', encoding='utf-8') as model_file:"); + appendPythonLine(script, " model_file.write(version)"); + appendPythonLine(script, " with open('model.version', 'w', encoding='utf-8') as version_file:"); + appendPythonLine(script, " version_file.write(version)"); + appendPythonLine(script, " os.utime('model.pt', None)"); + appendPythonLine(script, " os.utime('model.version', None)"); + appendPythonLine(script, ""); + appendPythonLine(script, "write_model('v1')"); + appendPythonLine(script, "session = TorchInferSession(Transform(), {"); + appendPythonLine(script, " 'model_path': os.path.join(os.getcwd(), 'model.pt'),"); + appendPythonLine(script, " 'model_version_file': os.path.join(os.getcwd(), 'model.version'),"); + appendPythonLine(script, " 'poll_interval_sec': 0.05,"); + appendPythonLine(script, " 'backoff_sec': 1.0,"); + appendPythonLine(script, " 'warmup_enabled': True,"); + appendPythonLine(script, " 'hot_reload_enabled': True,"); + appendPythonLine(script, "})"); + appendPythonLine(script, ""); + appendPythonLine(script, "initial = session.run('req')"); + appendPythonLine(script, "init_active_id = str(id(session.model_active))"); + appendPythonLine(script, "init_standby_id = str(id(session.model_standby))"); + appendPythonLine(script, "init_active_transform_id = str(id(session.model_active.transform))"); + appendPythonLine(script, "init_standby_transform_id = str(id(session.model_standby.transform))"); + appendPythonLine(script, "write_model('v2')"); + appendPythonLine(script, "if not Transform.warmup_started.wait(5):"); + appendPythonLine(script, " raise RuntimeError('warmup did not start')"); + appendPythonLine(script, "before_swap = []"); + appendPythonLine(script, "for _ in range(4):"); + appendPythonLine(script, " before_swap.append(session.run('req'))"); + appendPythonLine(script, " time.sleep(0.05)"); + appendPythonLine(script, "Transform.warmup_release.set()"); + appendPythonLine(script, "time.sleep(0.8)"); + appendPythonLine(script, "after_swap = []"); + appendPythonLine(script, "for _ in range(3):"); + appendPythonLine(script, " after_swap.append(session.run('req'))"); + appendPythonLine(script, " time.sleep(0.05)"); + appendPythonLine(script, "final_active_id = str(id(session.model_active))"); + appendPythonLine(script, "final_standby_id = str(id(session.model_standby))"); + appendPythonLine(script, "standby_run = session._run_with_slot(session.model_standby, 'req')"); + appendPythonLine(script, "final_active_run = session.run('req')"); + appendPythonLine(script, "print('INITIAL=' + initial)"); + appendPythonLine(script, "print('BEFORE_SWAP=' + ','.join(before_swap))"); + appendPythonLine(script, "print('AFTER_SWAP=' + ','.join(after_swap))"); + appendPythonLine(script, "print('LOAD_EVENTS=' + ','.join(Transform.load_events))"); + appendPythonLine(script, "print('WARMUP_EVENTS=' + ','.join(Transform.warmup_events))"); + appendPythonLine(script, "print('INIT_ACTIVE_ID=' + init_active_id)"); + appendPythonLine(script, "print('INIT_STANDBY_ID=' + init_standby_id)"); + appendPythonLine(script, "print('INIT_ACTIVE_TRANSFORM_ID=' + init_active_transform_id)"); + appendPythonLine(script, "print('INIT_STANDBY_TRANSFORM_ID=' + init_standby_transform_id)"); + appendPythonLine(script, "print('FINAL_ACTIVE_ID=' + final_active_id)"); + appendPythonLine(script, "print('FINAL_STANDBY_ID=' + final_standby_id)"); + appendPythonLine(script, "print('FINAL_ACTIVE_RUN=' + final_active_run)"); + appendPythonLine(script, "print('FINAL_STANDBY_RUN=' + standby_run)"); + appendPythonLine(script, "session.close()"); + return script.toString(); + } + + private String buildMissingWarmupRunnerScript() { + StringBuilder script = new StringBuilder(); + appendPythonLine(script, "import os"); + appendPythonLine(script, "import threading"); + appendPythonLine(script, "import time"); + appendPythonLine(script, "from inferSession import TorchInferSession"); + appendPythonLine(script, ""); + appendPythonLine(script, "class Transform(object):"); + appendPythonLine(script, " input_size = 1"); + appendPythonLine(script, " load_events = []"); + appendPythonLine(script, " warmup_events = []"); + appendPythonLine(script, " lock = threading.Lock()"); + appendPythonLine(script, ""); + appendPythonLine(script, " def __init__(self):"); + appendPythonLine(script, " self.version = 'unset'"); + appendPythonLine(script, ""); + appendPythonLine(script, " def load_model(self, model_path):"); + appendPythonLine(script, " with open(model_path, 'r', encoding='utf-8') as model_file:"); + appendPythonLine(script, " self.version = model_file.read().strip()"); + appendPythonLine(script, " with self.__class__.lock:"); + appendPythonLine(script, " self.__class__.load_events.append('{}@{}'.format(id(self), self.version))"); + appendPythonLine(script, " if self.version == 'v2':"); + appendPythonLine(script, " time.sleep(0.2)"); + appendPythonLine(script, " return None"); + appendPythonLine(script, ""); + appendPythonLine(script, " def transform_pre(self, value):"); + appendPythonLine(script, " return self.version"); + appendPythonLine(script, ""); + appendPythonLine(script, " def transform_post(self, value):"); + appendPythonLine(script, " return value"); + appendPythonLine(script, ""); + appendPythonLine(script, "def write_model(version):"); + appendPythonLine(script, " with open('model.pt', 'w', encoding='utf-8') as model_file:"); + appendPythonLine(script, " model_file.write(version)"); + appendPythonLine(script, " with open('model.version', 'w', encoding='utf-8') as version_file:"); + appendPythonLine(script, " version_file.write(version)"); + appendPythonLine(script, " os.utime('model.pt', None)"); + appendPythonLine(script, " os.utime('model.version', None)"); + appendPythonLine(script, ""); + appendPythonLine(script, "write_model('v1')"); + appendPythonLine(script, "session = TorchInferSession(Transform(), {"); + appendPythonLine(script, " 'model_path': os.path.join(os.getcwd(), 'model.pt'),"); + appendPythonLine(script, " 'model_version_file': os.path.join(os.getcwd(), 'model.version'),"); + appendPythonLine(script, " 'poll_interval_sec': 0.05,"); + appendPythonLine(script, " 'backoff_sec': 1.0,"); + appendPythonLine(script, " 'warmup_enabled': True,"); + appendPythonLine(script, " 'hot_reload_enabled': True,"); + appendPythonLine(script, "})"); + appendPythonLine(script, ""); + appendPythonLine(script, "initial = session.run('req')"); + appendPythonLine(script, "init_active_id = str(id(session.model_active))"); + appendPythonLine(script, "init_standby_id = str(id(session.model_standby))"); + appendPythonLine(script, "init_active_transform_id = str(id(session.model_active.transform))"); + appendPythonLine(script, "init_standby_transform_id = str(id(session.model_standby.transform))"); + appendPythonLine(script, "write_model('v2')"); + appendPythonLine(script, "time.sleep(0.8)"); + appendPythonLine(script, "after_failure = []"); + appendPythonLine(script, "for _ in range(4):"); + appendPythonLine(script, " after_failure.append(session.run('req'))"); + appendPythonLine(script, " time.sleep(0.05)"); + appendPythonLine(script, "final_active_id = str(id(session.model_active))"); + appendPythonLine(script, "final_standby_id = str(id(session.model_standby))"); + appendPythonLine(script, "final_active_run = session.run('req')"); + appendPythonLine(script, "print('INITIAL=' + initial)"); + appendPythonLine(script, "print('AFTER_FAILURE=' + ','.join(after_failure))"); + appendPythonLine(script, "print('LOAD_EVENTS=' + ','.join(Transform.load_events))"); + appendPythonLine(script, "print('WARMUP_EVENTS=' + ','.join(Transform.warmup_events))"); + appendPythonLine(script, "print('INIT_ACTIVE_ID=' + init_active_id)"); + appendPythonLine(script, "print('INIT_STANDBY_ID=' + init_standby_id)"); + appendPythonLine(script, "print('INIT_ACTIVE_TRANSFORM_ID=' + init_active_transform_id)"); + appendPythonLine(script, "print('INIT_STANDBY_TRANSFORM_ID=' + init_standby_transform_id)"); + appendPythonLine(script, "print('FINAL_ACTIVE_ID=' + final_active_id)"); + appendPythonLine(script, "print('FINAL_STANDBY_ID=' + final_standby_id)"); + appendPythonLine(script, "print('FINAL_ACTIVE_RUN=' + final_active_run)"); + appendPythonLine(script, "session.close()"); + return script.toString(); + } + + private void appendPythonLine(StringBuilder script, String line) { + script.append(line).append(System.lineSeparator()); + } + + private void ensurePythonAvailable() throws Exception { + try { + Process process = new ProcessBuilder("python3", "--version") + .redirectErrorStream(true) + .start(); + process.getInputStream().readAllBytes(); + if (process.waitFor() != 0) { + throw new SkipException("python3 is unavailable for infer hot reload test"); + } + } catch (IOException e) { + throw new SkipException("python3 is unavailable for infer hot reload test", e); + } + } + + private void copyInferSessionResource(Path tempDir) throws IOException { + try (InputStream inputStream = InferHotReloadTest.class + .getResourceAsStream("/infer/inferRuntime/inferSession.py")) { + if (inputStream == null) { + throw new IOException("cannot load inferSession.py from classpath"); + } + Files.copy(inputStream, tempDir.resolve("inferSession.py")); + } + } + + private void writeTorchStub(Path tempDir) throws IOException { + Files.write(tempDir.resolve("torch.py"), Arrays.asList( + "def set_num_threads(_):", + " return None" + ), UTF_8); + } + + private void deleteRecursively(Path root) throws IOException { + if (root == null || !Files.exists(root)) { + return; + } + try (Stream stream = Files.walk(root)) { + List paths = stream.sorted(Comparator.reverseOrder()) + .collect(Collectors.toList()); + for (Path path : paths) { + Files.deleteIfExists(path); + } + } + } +}