diff --git a/engine/packages/guard-core/src/proxy_service.rs b/engine/packages/guard-core/src/proxy_service.rs index bbbcc71dc0..96b3c216b2 100644 --- a/engine/packages/guard-core/src/proxy_service.rs +++ b/engine/packages/guard-core/src/proxy_service.rs @@ -380,6 +380,7 @@ impl ProxyService { .and_then(|ip_str| ip_str.parse::().ok()) .unwrap_or_else(|| self.remote_addr.ip()); + let is_websocket = hyper_tungstenite::is_upgrade_request(&req); let mut req_ctx = RequestContext::new( self.remote_addr, request_ids.ray_id, @@ -388,6 +389,7 @@ impl ProxyService { path, req.method().clone(), req.headers().clone(), + is_websocket, client_ip, start_time, ); @@ -409,8 +411,6 @@ impl ProxyService { "Request received" ); - let is_websocket = hyper_tungstenite::is_upgrade_request(&req); - // Used for ws error proxying later let mut mock_req_builder = Request::builder() .method(req.method().clone()) diff --git a/engine/packages/guard-core/src/request_context.rs b/engine/packages/guard-core/src/request_context.rs index 196c107dc4..63c9b8a4bd 100644 --- a/engine/packages/guard-core/src/request_context.rs +++ b/engine/packages/guard-core/src/request_context.rs @@ -20,6 +20,7 @@ pub struct RequestContext { pub(crate) path: String, pub(crate) method: Method, pub(crate) headers: HeaderMap, + pub(crate) is_websocket: bool, pub(crate) client_ip: IpAddr, pub(crate) start_time: Instant, @@ -41,6 +42,7 @@ impl RequestContext { path: String, method: Method, headers: HeaderMap, + is_websocket: bool, client_ip: IpAddr, start_time: Instant, ) -> Self { @@ -55,6 +57,7 @@ impl RequestContext { path, method, headers, + is_websocket, client_ip, start_time, @@ -106,6 +109,10 @@ impl RequestContext { &self.headers } + pub fn is_websocket(&self) -> bool { + self.is_websocket + } + pub fn in_flight_request_id(&self) -> Result { self.in_flight_request_id .context("no in flight request id acquired") diff --git a/engine/packages/guard/src/cache/actor.rs b/engine/packages/guard/src/cache/actor.rs deleted file mode 100644 index a9de2eb0c1..0000000000 --- a/engine/packages/guard/src/cache/actor.rs +++ /dev/null @@ -1,41 +0,0 @@ -use std::{ - collections::hash_map::DefaultHasher, - hash::{Hash, Hasher}, -}; - -use anyhow::Result; -use gas::prelude::*; -use rivet_guard_core::request_context::RequestContext; - -use crate::routing::pegboard_gateway::X_RIVET_ACTOR; - -#[tracing::instrument(skip_all)] -pub fn build_cache_key(req_ctx: &RequestContext, target: &str) -> Result { - // Check target - ensure!(target == "actor", "wrong target"); - - // Find actor to route to - let actor_id_str = req_ctx - .headers() - .get(X_RIVET_ACTOR) - .ok_or_else(|| { - crate::errors::MissingHeader { - header: X_RIVET_ACTOR.to_string(), - } - .build() - })? - .to_str() - .context("invalid x-rivet-actor header")?; - let actor_id = Id::parse(actor_id_str).context("invalid x-rivet-actor header")?; - - // Create a hash using target, actor_id, path, and method - let mut hasher = DefaultHasher::new(); - target.hash(&mut hasher); - actor_id.hash(&mut hasher); - // TODO: Should this include query for cache key? - req_ctx.path().hash(&mut hasher); - req_ctx.method().as_str().hash(&mut hasher); - let hash = hasher.finish(); - - Ok(hash) -} diff --git a/engine/packages/guard/src/cache/mod.rs b/engine/packages/guard/src/cache/mod.rs index cd97aee4ba..d1117f4a88 100644 --- a/engine/packages/guard/src/cache/mod.rs +++ b/engine/packages/guard/src/cache/mod.rs @@ -4,13 +4,14 @@ use std::{ sync::Arc, }; -use anyhow::Result; use gas::prelude::*; use rivet_guard_core::{CacheKeyFn, request_context::RequestContext}; -pub mod actor; +pub mod pegboard_gateway; -use crate::routing::X_RIVET_TARGET; +use crate::routing::{ + SEC_WEBSOCKET_PROTOCOL, WS_PROTOCOL_TARGET, X_RIVET_TARGET, parse_actor_path, +}; /// Creates the main cache key function that handles all incoming requests #[tracing::instrument(skip_all)] @@ -18,43 +19,52 @@ pub fn create_cache_key_function() -> CacheKeyFn { Arc::new(move |req_ctx| { tracing::debug!("building cache key"); - let target = match read_target(req_ctx.headers()) { - Ok(target) => target, - Err(err) => { - tracing::debug!(?err, "failed parsing target for cache key"); + // MARK: Path-based cache key + // Check for path-based actor routing + if let Some(actor_path_info) = parse_actor_path(req_ctx.path()) { + tracing::debug!("using path-based cache key for actor"); - return Ok(host_path_method_cache_key(req_ctx)); + if let Ok(cache_key) = + pegboard_gateway::build_cache_key_path_based(req_ctx, &actor_path_info) + { + return Ok(cache_key); } - }; - - let cache_key = match actor::build_cache_key(req_ctx, target) { - Ok(key) => Some(key), - Err(err) => { - tracing::debug!(?err, "failed to create actor cache key"); - - None - } - }; + } - // Fallback to hostname + path + method hash if actor did not work - if let Some(cache_key) = cache_key { - Ok(cache_key) + // MARK: Header- & protocol-based cache key (X-Rivet-Target) + // Determine target + let target = if req_ctx.is_websocket() { + // For WebSocket, parse the sec-websocket-protocol header + req_ctx + .headers() + .get(SEC_WEBSOCKET_PROTOCOL) + .and_then(|protocols| protocols.to_str().ok()) + .and_then(|protocols| { + // Parse protocols to find target.{value} + protocols + .split(',') + .map(|p| p.trim()) + .find_map(|p| p.strip_prefix(WS_PROTOCOL_TARGET)) + }) } else { - Ok(host_path_method_cache_key(req_ctx)) - } - }) -} + // For HTTP, use the x-rivet-target header + req_ctx + .headers() + .get(X_RIVET_TARGET) + .and_then(|x| x.to_str().ok()) + }; -fn read_target(headers: &hyper::HeaderMap) -> Result<&str> { - // Read target - let target = headers.get(X_RIVET_TARGET).ok_or_else(|| { - crate::errors::MissingHeader { - header: X_RIVET_TARGET.to_string(), + // Check target-based cache functions + if let Some(target) = target { + if let Ok(cache_key) = pegboard_gateway::build_cache_key_target_based(req_ctx, target) { + return Ok(cache_key); + } } - .build() - })?; - Ok(target.to_str()?) + // MARK: Fallback + tracing::debug!("using fallback cache key"); + Ok(host_path_method_cache_key(req_ctx)) + }) } fn host_path_method_cache_key(req_ctx: &RequestContext) -> u64 { diff --git a/engine/packages/guard/src/cache/pegboard_gateway.rs b/engine/packages/guard/src/cache/pegboard_gateway.rs new file mode 100644 index 0000000000..6804f5beb2 --- /dev/null +++ b/engine/packages/guard/src/cache/pegboard_gateway.rs @@ -0,0 +1,100 @@ +use std::{ + collections::hash_map::DefaultHasher, + hash::{Hash, Hasher}, +}; + +use anyhow::Result; +use gas::prelude::*; +use rivet_guard_core::request_context::RequestContext; + +use crate::routing::{ + ActorPathInfo, SEC_WEBSOCKET_PROTOCOL, WS_PROTOCOL_ACTOR, pegboard_gateway::X_RIVET_ACTOR, +}; + +/// Build cache key for path-based actor routing +#[tracing::instrument(skip_all)] +pub fn build_cache_key_path_based( + req_ctx: &RequestContext, + actor_path_info: &ActorPathInfo, +) -> Result { + let target = "actor"; + + // Parse actor ID from path + let actor_id = Id::parse(&actor_path_info.actor_id).context("invalid actor id in path")?; + + // Create a hash using actor_id, stripped path, and method + let mut hasher = DefaultHasher::new(); + target.hash(&mut hasher); + actor_id.hash(&mut hasher); + // TODO: Should this exclude query for cache key? + actor_path_info.stripped_path.hash(&mut hasher); + req_ctx.method().as_str().hash(&mut hasher); + let hash = hasher.finish(); + + Ok(hash) +} + +/// Build cache key for target-based actor routing (header or WebSocket protocol) +#[tracing::instrument(skip_all)] +pub fn build_cache_key_target_based(req_ctx: &RequestContext, target: &str) -> Result { + // Check target + ensure!(target == "actor", "wrong target"); + + // Extract actor ID from WebSocket protocol or HTTP headers + let actor_id_str = if req_ctx.is_websocket() { + // For WebSocket, parse the sec-websocket-protocol header + let protocols_header = req_ctx + .headers() + .get(SEC_WEBSOCKET_PROTOCOL) + .and_then(|protocols| protocols.to_str().ok()) + .ok_or_else(|| { + crate::errors::MissingHeader { + header: "sec-websocket-protocol".to_string(), + } + .build() + })?; + + let protocols: Vec<&str> = protocols_header.split(',').map(|p| p.trim()).collect(); + + let actor_id_raw = protocols + .iter() + .find_map(|p| p.strip_prefix(WS_PROTOCOL_ACTOR)) + .ok_or_else(|| { + crate::errors::MissingHeader { + header: "`rivet_actor.*` protocol in sec-websocket-protocol".to_string(), + } + .build() + })?; + + urlencoding::decode(actor_id_raw) + .context("invalid url encoding in actor id")? + .to_string() + } else { + // For HTTP, use the x-rivet-actor header + req_ctx + .headers() + .get(X_RIVET_ACTOR) + .ok_or_else(|| { + crate::errors::MissingHeader { + header: X_RIVET_ACTOR.to_string(), + } + .build() + })? + .to_str() + .context("invalid x-rivet-actor header")? + .to_string() + }; + + let actor_id = Id::parse(&actor_id_str).context("invalid actor id")?; + + // Create a hash using target, actor_id, path, and method + let mut hasher = DefaultHasher::new(); + target.hash(&mut hasher); + actor_id.hash(&mut hasher); + // TODO: Should this exclude query for cache key? + req_ctx.path().hash(&mut hasher); + req_ctx.method().as_str().hash(&mut hasher); + let hash = hasher.finish(); + + Ok(hash) +} diff --git a/engine/packages/guard/src/routing/mod.rs b/engine/packages/guard/src/routing/mod.rs index e7d7f8a34c..97c80158b7 100644 --- a/engine/packages/guard/src/routing/mod.rs +++ b/engine/packages/guard/src/routing/mod.rs @@ -39,14 +39,6 @@ pub fn create_routing_function(ctx: &StandaloneCtx, shared_state: SharedState) - async move { tracing::debug!(hostname=%req_ctx.hostname(), path=%req_ctx.path(), "Routing request"); - // Check if this is a WebSocket upgrade request - let is_websocket = req_ctx - .headers() - .get("upgrade") - .and_then(|v| v.to_str().ok()) - .map(|v| v.eq_ignore_ascii_case("websocket")) - .unwrap_or(false); - // MARK: Path-based routing // Route actor if let Some(actor_path_info) = parse_actor_path(req_ctx.path()) { @@ -60,7 +52,6 @@ pub fn create_routing_function(ctx: &StandaloneCtx, shared_state: SharedState) - &actor_path_info.actor_id, actor_path_info.token.as_deref(), &actor_path_info.stripped_path, - is_websocket, ) .await? { @@ -81,7 +72,7 @@ pub fn create_routing_function(ctx: &StandaloneCtx, shared_state: SharedState) - // MARK: Header- & protocol-based routing (X-Rivet-Target) // Determine target - let target = if is_websocket { + let target = if req_ctx.is_websocket() { // For WebSocket, parse the sec-websocket-protocol header req_ctx .headers() @@ -112,14 +103,9 @@ pub fn create_routing_function(ctx: &StandaloneCtx, shared_state: SharedState) - return Ok(routing_output); } - if let Some(routing_output) = pegboard_gateway::route_request( - &ctx, - &shared_state, - req_ctx, - target, - is_websocket, - ) - .await? + if let Some(routing_output) = + pegboard_gateway::route_request(&ctx, &shared_state, req_ctx, target) + .await? { metrics::ROUTE_TOTAL.with_label_values(&["gateway"]).inc(); diff --git a/engine/packages/guard/src/routing/pegboard_gateway.rs b/engine/packages/guard/src/routing/pegboard_gateway.rs index 6676e6f55c..7925e0d298 100644 --- a/engine/packages/guard/src/routing/pegboard_gateway.rs +++ b/engine/packages/guard/src/routing/pegboard_gateway.rs @@ -27,7 +27,6 @@ pub async fn route_request_path_based( actor_id_str: &str, token_from_path: Option<&str>, stripped_path: &str, - is_websocket: bool, ) -> Result> { // Parse actor ID let actor_id = Id::parse(actor_id_str).context("invalid actor id in path")?; @@ -35,7 +34,7 @@ pub async fn route_request_path_based( // Prefer token from path, otherwise read headers let token = if let Some(token) = token_from_path { Some(token) - } else if is_websocket { + } else if req_ctx.is_websocket() { // For WebSocket, parse the sec-websocket-protocol header let protocols_header = req_ctx .headers() @@ -75,7 +74,6 @@ pub async fn route_request( shared_state: &SharedState, req_ctx: &RequestContext, target: &str, - is_websocket: bool, ) -> Result> { // Check target if target != "actor" { @@ -83,7 +81,7 @@ pub async fn route_request( } // Extract actor ID and token from WebSocket protocol or HTTP headers - let (actor_id_str, token) = if is_websocket { + let (actor_id_str, token) = if req_ctx.is_websocket() { // For WebSocket, parse the sec-websocket-protocol header let protocols_header = req_ctx .headers() diff --git a/engine/packages/guard/src/routing/runner.rs b/engine/packages/guard/src/routing/runner.rs index 9d7daf736b..d09198ab46 100644 --- a/engine/packages/guard/src/routing/runner.rs +++ b/engine/packages/guard/src/routing/runner.rs @@ -72,19 +72,10 @@ async fn route_runner_internal( tracing::debug!(datacenter = ?current_dc.name, "validated host for datacenter"); - let is_websocket = req_ctx - .headers() - .get("upgrade") - .and_then(|v| v.to_str().ok()) - .map(|v| v.eq_ignore_ascii_case("websocket")) - .unwrap_or(false); - - tracing::debug!(is_websocket, "connection type"); - // Check auth (if enabled) if let Some(auth) = &ctx.config().auth { // Extract token from protocol or header - let token = if is_websocket { + let token = if req_ctx.is_websocket() { req_ctx .headers() .get(SEC_WEBSOCKET_PROTOCOL)