Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions engine/packages/guard-core/src/proxy_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ impl ProxyService {
.and_then(|ip_str| ip_str.parse::<std::net::IpAddr>().ok())
.unwrap_or_else(|| self.remote_addr.ip());

let is_websocket = hyper_tungstenite::is_upgrade_request(&req);
let mut req_ctx = RequestContext::new(
self.remote_addr,
request_ids.ray_id,
Expand All @@ -388,6 +389,7 @@ impl ProxyService {
path,
req.method().clone(),
req.headers().clone(),
is_websocket,
client_ip,
start_time,
);
Expand All @@ -409,8 +411,6 @@ impl ProxyService {
"Request received"
);

let is_websocket = hyper_tungstenite::is_upgrade_request(&req);

// Used for ws error proxying later
let mut mock_req_builder = Request::builder()
.method(req.method().clone())
Expand Down
7 changes: 7 additions & 0 deletions engine/packages/guard-core/src/request_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub struct RequestContext {
pub(crate) path: String,
pub(crate) method: Method,
pub(crate) headers: HeaderMap,
pub(crate) is_websocket: bool,
pub(crate) client_ip: IpAddr,
pub(crate) start_time: Instant,

Expand All @@ -41,6 +42,7 @@ impl RequestContext {
path: String,
method: Method,
headers: HeaderMap,
is_websocket: bool,
client_ip: IpAddr,
start_time: Instant,
) -> Self {
Expand All @@ -55,6 +57,7 @@ impl RequestContext {
path,
method,
headers,
is_websocket,
client_ip,
start_time,

Expand Down Expand Up @@ -106,6 +109,10 @@ impl RequestContext {
&self.headers
}

pub fn is_websocket(&self) -> bool {
self.is_websocket
}

pub fn in_flight_request_id(&self) -> Result<protocol::RequestId> {
self.in_flight_request_id
.context("no in flight request id acquired")
Expand Down
41 changes: 0 additions & 41 deletions engine/packages/guard/src/cache/actor.rs

This file was deleted.

76 changes: 43 additions & 33 deletions engine/packages/guard/src/cache/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,57 +4,67 @@ use std::{
sync::Arc,
};

use anyhow::Result;
use gas::prelude::*;
use rivet_guard_core::{CacheKeyFn, request_context::RequestContext};

pub mod actor;
pub mod pegboard_gateway;

use crate::routing::X_RIVET_TARGET;
use crate::routing::{
SEC_WEBSOCKET_PROTOCOL, WS_PROTOCOL_TARGET, X_RIVET_TARGET, parse_actor_path,
};

/// Creates the main cache key function that handles all incoming requests
#[tracing::instrument(skip_all)]
pub fn create_cache_key_function() -> CacheKeyFn {
Arc::new(move |req_ctx| {
tracing::debug!("building cache key");

let target = match read_target(req_ctx.headers()) {
Ok(target) => target,
Err(err) => {
tracing::debug!(?err, "failed parsing target for cache key");
// MARK: Path-based cache key
// Check for path-based actor routing
if let Some(actor_path_info) = parse_actor_path(req_ctx.path()) {
tracing::debug!("using path-based cache key for actor");

return Ok(host_path_method_cache_key(req_ctx));
if let Ok(cache_key) =
pegboard_gateway::build_cache_key_path_based(req_ctx, &actor_path_info)
{
return Ok(cache_key);
}
};

let cache_key = match actor::build_cache_key(req_ctx, target) {
Ok(key) => Some(key),
Err(err) => {
tracing::debug!(?err, "failed to create actor cache key");

None
}
};
}

// Fallback to hostname + path + method hash if actor did not work
if let Some(cache_key) = cache_key {
Ok(cache_key)
// MARK: Header- & protocol-based cache key (X-Rivet-Target)
// Determine target
let target = if req_ctx.is_websocket() {
// For WebSocket, parse the sec-websocket-protocol header
req_ctx
.headers()
.get(SEC_WEBSOCKET_PROTOCOL)
.and_then(|protocols| protocols.to_str().ok())
.and_then(|protocols| {
// Parse protocols to find target.{value}
protocols
.split(',')
.map(|p| p.trim())
.find_map(|p| p.strip_prefix(WS_PROTOCOL_TARGET))
})
} else {
Ok(host_path_method_cache_key(req_ctx))
}
})
}
// For HTTP, use the x-rivet-target header
req_ctx
.headers()
.get(X_RIVET_TARGET)
.and_then(|x| x.to_str().ok())
};

fn read_target(headers: &hyper::HeaderMap) -> Result<&str> {
// Read target
let target = headers.get(X_RIVET_TARGET).ok_or_else(|| {
crate::errors::MissingHeader {
header: X_RIVET_TARGET.to_string(),
// Check target-based cache functions
if let Some(target) = target {
if let Ok(cache_key) = pegboard_gateway::build_cache_key_target_based(req_ctx, target) {
return Ok(cache_key);
}
}
.build()
})?;

Ok(target.to_str()?)
// MARK: Fallback
tracing::debug!("using fallback cache key");
Ok(host_path_method_cache_key(req_ctx))
})
}

fn host_path_method_cache_key(req_ctx: &RequestContext) -> u64 {
Expand Down
100 changes: 100 additions & 0 deletions engine/packages/guard/src/cache/pegboard_gateway.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
use std::{
collections::hash_map::DefaultHasher,
hash::{Hash, Hasher},
};

use anyhow::Result;
use gas::prelude::*;
use rivet_guard_core::request_context::RequestContext;

use crate::routing::{
ActorPathInfo, SEC_WEBSOCKET_PROTOCOL, WS_PROTOCOL_ACTOR, pegboard_gateway::X_RIVET_ACTOR,
};

/// Build cache key for path-based actor routing
#[tracing::instrument(skip_all)]
pub fn build_cache_key_path_based(
req_ctx: &RequestContext,
actor_path_info: &ActorPathInfo,
) -> Result<u64> {
let target = "actor";

// Parse actor ID from path
let actor_id = Id::parse(&actor_path_info.actor_id).context("invalid actor id in path")?;

// Create a hash using actor_id, stripped path, and method
let mut hasher = DefaultHasher::new();
target.hash(&mut hasher);
actor_id.hash(&mut hasher);
// TODO: Should this exclude query for cache key?
actor_path_info.stripped_path.hash(&mut hasher);
req_ctx.method().as_str().hash(&mut hasher);
let hash = hasher.finish();

Ok(hash)
}

/// Build cache key for target-based actor routing (header or WebSocket protocol)
#[tracing::instrument(skip_all)]
pub fn build_cache_key_target_based(req_ctx: &RequestContext, target: &str) -> Result<u64> {
// Check target
ensure!(target == "actor", "wrong target");

// Extract actor ID from WebSocket protocol or HTTP headers
let actor_id_str = if req_ctx.is_websocket() {
// For WebSocket, parse the sec-websocket-protocol header
let protocols_header = req_ctx
.headers()
.get(SEC_WEBSOCKET_PROTOCOL)
.and_then(|protocols| protocols.to_str().ok())
.ok_or_else(|| {
crate::errors::MissingHeader {
header: "sec-websocket-protocol".to_string(),
}
.build()
})?;

let protocols: Vec<&str> = protocols_header.split(',').map(|p| p.trim()).collect();

let actor_id_raw = protocols
.iter()
.find_map(|p| p.strip_prefix(WS_PROTOCOL_ACTOR))
.ok_or_else(|| {
crate::errors::MissingHeader {
header: "`rivet_actor.*` protocol in sec-websocket-protocol".to_string(),
}
.build()
})?;

urlencoding::decode(actor_id_raw)
.context("invalid url encoding in actor id")?
.to_string()
} else {
// For HTTP, use the x-rivet-actor header
req_ctx
.headers()
.get(X_RIVET_ACTOR)
.ok_or_else(|| {
crate::errors::MissingHeader {
header: X_RIVET_ACTOR.to_string(),
}
.build()
})?
.to_str()
.context("invalid x-rivet-actor header")?
.to_string()
};

let actor_id = Id::parse(&actor_id_str).context("invalid actor id")?;

// Create a hash using target, actor_id, path, and method
let mut hasher = DefaultHasher::new();
target.hash(&mut hasher);
actor_id.hash(&mut hasher);
// TODO: Should this exclude query for cache key?
req_ctx.path().hash(&mut hasher);
req_ctx.method().as_str().hash(&mut hasher);
let hash = hasher.finish();

Ok(hash)
}
22 changes: 4 additions & 18 deletions engine/packages/guard/src/routing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,6 @@ pub fn create_routing_function(ctx: &StandaloneCtx, shared_state: SharedState) -
async move {
tracing::debug!(hostname=%req_ctx.hostname(), path=%req_ctx.path(), "Routing request");

// Check if this is a WebSocket upgrade request
let is_websocket = req_ctx
.headers()
.get("upgrade")
.and_then(|v| v.to_str().ok())
.map(|v| v.eq_ignore_ascii_case("websocket"))
.unwrap_or(false);

// MARK: Path-based routing
// Route actor
if let Some(actor_path_info) = parse_actor_path(req_ctx.path()) {
Expand All @@ -60,7 +52,6 @@ pub fn create_routing_function(ctx: &StandaloneCtx, shared_state: SharedState) -
&actor_path_info.actor_id,
actor_path_info.token.as_deref(),
&actor_path_info.stripped_path,
is_websocket,
)
.await?
{
Expand All @@ -81,7 +72,7 @@ pub fn create_routing_function(ctx: &StandaloneCtx, shared_state: SharedState) -

// MARK: Header- & protocol-based routing (X-Rivet-Target)
// Determine target
let target = if is_websocket {
let target = if req_ctx.is_websocket() {
// For WebSocket, parse the sec-websocket-protocol header
req_ctx
.headers()
Expand Down Expand Up @@ -112,14 +103,9 @@ pub fn create_routing_function(ctx: &StandaloneCtx, shared_state: SharedState) -
return Ok(routing_output);
}

if let Some(routing_output) = pegboard_gateway::route_request(
&ctx,
&shared_state,
req_ctx,
target,
is_websocket,
)
.await?
if let Some(routing_output) =
pegboard_gateway::route_request(&ctx, &shared_state, req_ctx, target)
.await?
{
metrics::ROUTE_TOTAL.with_label_values(&["gateway"]).inc();

Expand Down
6 changes: 2 additions & 4 deletions engine/packages/guard/src/routing/pegboard_gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,14 @@ pub async fn route_request_path_based(
actor_id_str: &str,
token_from_path: Option<&str>,
stripped_path: &str,
is_websocket: bool,
) -> Result<Option<RoutingOutput>> {
// Parse actor ID
let actor_id = Id::parse(actor_id_str).context("invalid actor id in path")?;

// Prefer token from path, otherwise read headers
let token = if let Some(token) = token_from_path {
Some(token)
} else if is_websocket {
} else if req_ctx.is_websocket() {
// For WebSocket, parse the sec-websocket-protocol header
let protocols_header = req_ctx
.headers()
Expand Down Expand Up @@ -75,15 +74,14 @@ pub async fn route_request(
shared_state: &SharedState,
req_ctx: &RequestContext,
target: &str,
is_websocket: bool,
) -> Result<Option<RoutingOutput>> {
// Check target
if target != "actor" {
return Ok(None);
}

// Extract actor ID and token from WebSocket protocol or HTTP headers
let (actor_id_str, token) = if is_websocket {
let (actor_id_str, token) = if req_ctx.is_websocket() {
// For WebSocket, parse the sec-websocket-protocol header
let protocols_header = req_ctx
.headers()
Expand Down
Loading
Loading