Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,36 @@ public class FrameworkConfigKeys implements Serializable {
.noDefaultValue()
.description("infer env conda url");

/** Model path used by infer-environment hot reload. The key declares no default value. */
public static final ConfigKey INFER_ENV_HOT_RELOAD_MODEL_PATH = ConfigKeys
.key("geaflow.infer.env.hot.reload.model.path")
.noDefaultValue()
.description("infer env hot reload model path");

/** Path of the model version manifest used by hot reload. The key declares no default value. */
public static final ConfigKey INFER_ENV_HOT_RELOAD_MODEL_VERSION_FILE = ConfigKeys
.key("geaflow.infer.env.hot.reload.model.version.file")
.noDefaultValue()
.description("infer env hot reload model version manifest path");

/** Hot-reload poll interval, in seconds. Defaults to 1.0. */
public static final ConfigKey INFER_ENV_HOT_RELOAD_POLL_INTERVAL_SEC = ConfigKeys
.key("geaflow.infer.env.hot.reload.poll.interval.sec")
.defaultValue(1.0)
.description("infer env hot reload poll interval seconds");

/** Hot-reload backoff after a failure, in seconds. Defaults to 10.0. */
public static final ConfigKey INFER_ENV_HOT_RELOAD_BACKOFF_SEC = ConfigKeys
.key("geaflow.infer.env.hot.reload.backoff.sec")
.defaultValue(10.0)
.description("infer env hot reload backoff seconds after failure");

/** Whether warmup is enabled for hot reload. Defaults to true. */
public static final ConfigKey INFER_ENV_HOT_RELOAD_WARMUP_ENABLE = ConfigKeys
.key("geaflow.infer.env.hot.reload.warmup.enable")
.defaultValue(true)
.description("infer env hot reload warmup enable");

/** Whether hot reload is enabled. Defaults to true. */
public static final ConfigKey INFER_ENV_HOT_RELOAD_ENABLE = ConfigKeys
.key("geaflow.infer.env.hot.reload.enable")
.defaultValue(true)
.description("infer env hot reload enable");

public static final ConfigKey ASP_ENABLE = ConfigKeys
.key("geaflow.iteration.asp.enable")
.defaultValue(false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@
*/
package org.apache.geaflow.infer;

import static org.apache.geaflow.common.config.keys.FrameworkConfigKeys.INFER_ENV_HOT_RELOAD_BACKOFF_SEC;
import static org.apache.geaflow.common.config.keys.FrameworkConfigKeys.INFER_ENV_HOT_RELOAD_ENABLE;
import static org.apache.geaflow.common.config.keys.FrameworkConfigKeys.INFER_ENV_HOT_RELOAD_MODEL_PATH;
import static org.apache.geaflow.common.config.keys.FrameworkConfigKeys.INFER_ENV_HOT_RELOAD_MODEL_VERSION_FILE;
import static org.apache.geaflow.common.config.keys.FrameworkConfigKeys.INFER_ENV_HOT_RELOAD_POLL_INTERVAL_SEC;
import static org.apache.geaflow.common.config.keys.FrameworkConfigKeys.INFER_ENV_HOT_RELOAD_WARMUP_ENABLE;
import static org.apache.geaflow.common.config.keys.FrameworkConfigKeys.INFER_ENV_USER_TRANSFORM_CLASSNAME;

import com.google.common.base.Preconditions;
Expand Down Expand Up @@ -90,6 +96,22 @@ private void runInferTask(InferEnvironmentContext inferEnvironmentContext) {
runCommands.add(inferEnvironmentContext.getInferTFClassNameParam(this.userDataTransformClass));
runCommands.add(inferEnvironmentContext.getInferShareMemoryInputParam(receiveQueueKey));
runCommands.add(inferEnvironmentContext.getInferShareMemoryOutputParam(sendQueueKey));

Configuration config = inferEnvironmentContext.getJobConfig();
String modelPath = config.getString(INFER_ENV_HOT_RELOAD_MODEL_PATH, "model.pt");
String modelVersionFile = config.getString(INFER_ENV_HOT_RELOAD_MODEL_VERSION_FILE,
"model.version");
double pollIntervalSec = config.getDouble(INFER_ENV_HOT_RELOAD_POLL_INTERVAL_SEC);
double backoffSec = config.getDouble(INFER_ENV_HOT_RELOAD_BACKOFF_SEC);
boolean warmupEnabled = config.getBoolean(INFER_ENV_HOT_RELOAD_WARMUP_ENABLE);
boolean hotReloadEnabled = config.getBoolean(INFER_ENV_HOT_RELOAD_ENABLE);

runCommands.add(inferEnvironmentContext.getInferModelPathParam(modelPath));
runCommands.add(inferEnvironmentContext.getInferModelVersionFileParam(modelVersionFile));
runCommands.add(inferEnvironmentContext.getInferPollIntervalSecParam(pollIntervalSec));
runCommands.add(inferEnvironmentContext.getInferBackoffSecParam(backoffSec));
runCommands.add(inferEnvironmentContext.getInferWarmupEnabledParam(warmupEnabled));
runCommands.add(inferEnvironmentContext.getInferHotReloadEnabledParam(hotReloadEnabled));
inferTaskRunner.run(runCommands);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,18 @@ public class InferEnvironmentContext {
// Start infer process parameter.
private static final String TF_CLASSNAME_KEY = "--tfClassName=";

// Flag prefix for the model path argument; the value is appended directly after '='.
private static final String MODEL_PATH_KEY = "--model_path=";

// Flag prefix for the model version file argument.
private static final String MODEL_VERSION_FILE_KEY = "--model_version_file=";

// Flag prefix for the poll interval argument (seconds).
private static final String POLL_INTERVAL_SEC_KEY = "--poll_interval_sec=";

// Flag prefix for the backoff argument (seconds).
private static final String BACKOFF_SEC_KEY = "--backoff_sec=";

// Flag prefix for the warmup switch; the getter renders the value as "True" or "False".
private static final String WARMUP_ENABLED_KEY = "--warmup_enabled=";

// Flag prefix for the hot reload switch; the getter renders the value as "True" or "False".
private static final String HOT_RELOAD_ENABLED_KEY = "--hot_reload_enabled=";

private static final String SHARE_MEMORY_INPUT_KEY = "--input_queue_shm_id=";

private static final String SHARE_MEMORY_OUTPUT_KEY = "--output_queue_shm_id=";
Expand Down Expand Up @@ -138,6 +150,30 @@ public String getInferShareMemoryOutputParam(String shareMemoryOutputKey) {
return SHARE_MEMORY_OUTPUT_KEY + shareMemoryOutputKey;
}

/**
 * Builds the infer-process argument that carries the model path.
 *
 * @param modelPath value placed directly after the {@code --model_path=} prefix
 * @return the prefix followed by {@code modelPath}
 */
public String getInferModelPathParam(String modelPath) {
    StringBuilder argument = new StringBuilder(MODEL_PATH_KEY);
    argument.append(modelPath);
    return argument.toString();
}

/**
 * Builds the infer-process argument that carries the model version file.
 *
 * @param modelVersionFile value placed directly after the {@code --model_version_file=} prefix
 * @return the prefix followed by {@code modelVersionFile}
 */
public String getInferModelVersionFileParam(String modelVersionFile) {
    StringBuilder argument = new StringBuilder(MODEL_VERSION_FILE_KEY);
    argument.append(modelVersionFile);
    return argument.toString();
}

/**
 * Builds the infer-process argument that carries the poll interval.
 *
 * @param pollIntervalSec poll interval in seconds, rendered with Java's default double formatting
 * @return the {@code --poll_interval_sec=} prefix followed by the rendered value
 */
public String getInferPollIntervalSecParam(double pollIntervalSec) {
    final String renderedInterval = Double.toString(pollIntervalSec);
    return POLL_INTERVAL_SEC_KEY + renderedInterval;
}

/**
 * Builds the infer-process argument that carries the backoff duration.
 *
 * @param backoffSec backoff in seconds, rendered with Java's default double formatting
 * @return the {@code --backoff_sec=} prefix followed by the rendered value
 */
public String getInferBackoffSecParam(double backoffSec) {
    final String renderedBackoff = Double.toString(backoffSec);
    return BACKOFF_SEC_KEY + renderedBackoff;
}

/**
 * Builds the infer-process argument for the warmup switch.
 *
 * @param warmupEnabled whether warmup is enabled
 * @return the {@code --warmup_enabled=} prefix followed by {@code "True"} or {@code "False"}
 */
public String getInferWarmupEnabledParam(boolean warmupEnabled) {
    if (warmupEnabled) {
        return WARMUP_ENABLED_KEY + "True";
    }
    return WARMUP_ENABLED_KEY + "False";
}

/**
 * Builds the infer-process argument for the hot reload switch.
 *
 * @param hotReloadEnabled whether hot reload is enabled
 * @return the {@code --hot_reload_enabled=} prefix followed by {@code "True"} or {@code "False"}
 */
public String getInferHotReloadEnabledParam(boolean hotReloadEnabled) {
    if (hotReloadEnabled) {
        return HOT_RELOAD_ENABLED_KEY + "True";
    }
    return HOT_RELOAD_ENABLED_KEY + "False";
}

public String getInferScript() {
return inferScript;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.io.Closeable;
import java.io.File;
import java.io.IOException;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.commons.io.FileUtils;
import org.apache.geaflow.common.config.Configuration;
import org.apache.geaflow.common.exception.GeaflowRuntimeException;
Expand All @@ -48,6 +49,9 @@ public class DataExchangeContext implements Closeable {

private final File receiveQueueFile;
private final File sendQueueFile;
private final Thread releaseQueueEndpointHook;
private final AtomicBoolean closed;
private final AtomicBoolean queueEndpointReleased;
private String receivePath;
private String sendPath;

Expand All @@ -62,7 +66,10 @@ public DataExchangeContext(Configuration config) {
int queueCapacity = config.getInteger(INFER_ENV_SHARE_MEMORY_QUEUE_SIZE);
this.receiveQueue = new DataExchangeQueue(receivePath, queueCapacity, true);
this.sendQueue = new DataExchangeQueue(sendPath, queueCapacity, true);
Runtime.getRuntime().addShutdownHook(new Thread(() -> UnSafeUtils.UNSAFE.freeMemory(queueEndpoint)));
this.closed = new AtomicBoolean(false);
this.queueEndpointReleased = new AtomicBoolean(false);
this.releaseQueueEndpointHook = new Thread(this::releaseQueueEndpoint);
Runtime.getRuntime().addShutdownHook(releaseQueueEndpointHook);
}

public String getReceiveQueueKey() {
Expand All @@ -75,6 +82,9 @@ public String getSendQueueKey() {

@Override
public synchronized void close() throws IOException {
if (!closed.compareAndSet(false, true)) {
return;
}
if (receiveQueue != null) {
receiveQueue.close();
}
Expand All @@ -87,8 +97,13 @@ public synchronized void close() throws IOException {
if (sendQueueFile != null) {
sendQueueFile.delete();
}
UnSafeUtils.UNSAFE.freeMemory(this.queueEndpoint);
releaseQueueEndpoint();
FileUtils.deleteQuietly(localDirectory);
try {
Runtime.getRuntime().removeShutdownHook(releaseQueueEndpointHook);
} catch (IllegalStateException ignored) {
// JVM shutdown is in progress, the hook may already be running.
}
}

public DataExchangeQueue getReceiveQueue() {
Expand All @@ -109,4 +124,10 @@ private File createTempFile(String prefix, String suffix) {
throw new GeaflowRuntimeException("create temp file on infer directory failed ", e);
}
}

/**
 * Frees the native memory behind {@code queueEndpoint} at most once.
 *
 * <p>The {@code queueEndpointReleased} flag is flipped with compare-and-set, so only the
 * first caller performs the free and every later call returns without touching memory.
 */
private void releaseQueueEndpoint() {
    boolean firstRelease = queueEndpointReleased.compareAndSet(false, true);
    if (!firstRelease) {
        return;
    }
    UnSafeUtils.UNSAFE.freeMemory(queueEndpoint);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

public final class DataExchangeQueue implements Closeable {

private static final AtomicBoolean CLOSED = new AtomicBoolean(false);
private final AtomicBoolean closed = new AtomicBoolean(false);
private final long outputNextAddress;
private final long capacityAddress;
private final long outputAddress;
Expand Down Expand Up @@ -66,11 +66,9 @@ public DataExchangeQueue(String mapKey, int capacity, boolean reset) {

@Override
public synchronized void close() {
CLOSED.set(true);
if (memoryMapper != null) {
if (closed.compareAndSet(false, true) && memoryMapper != null) {
memoryMapper.close();
}
UnSafeUtils.UNSAFE.freeMemory(mapAddress);
}

public long getMemoryMapSize() {
Expand Down Expand Up @@ -133,7 +131,7 @@ public boolean enableFinished() {
}

public synchronized void markFinished() {
if (!CLOSED.get()) {
if (!closed.get()) {
UnSafeUtils.UNSAFE.putOrderedLong(null, endPointAddress, -1);
}
}
Expand Down Expand Up @@ -165,4 +163,4 @@ public static long getNextPointIndex(long v, int capacity) {
}
return Pow2.align(v, capacity);
}
}
}
Loading