diff --git a/src/agent-client-protocol-core/src/jsonrpc/incoming_actor.rs b/src/agent-client-protocol-core/src/jsonrpc/incoming_actor.rs index 6d4375b..0aa2a7a 100644 --- a/src/agent-client-protocol-core/src/jsonrpc/incoming_actor.rs +++ b/src/agent-client-protocol-core/src/jsonrpc/incoming_actor.rs @@ -13,11 +13,13 @@ use crate::RoleId; use crate::UntypedMessage; use crate::jsonrpc::ConnectionTo; use crate::jsonrpc::HandleDispatchFrom; +use crate::jsonrpc::OutgoingMessage; use crate::jsonrpc::ReplyMessage; use crate::jsonrpc::Responder; use crate::jsonrpc::ResponseRouter; use crate::jsonrpc::dynamic_handler::DynHandleDispatchFrom; use crate::jsonrpc::dynamic_handler::DynamicHandlerMessage; +use crate::jsonrpc::outgoing_actor::send_raw_message; use crate::role::Role; @@ -93,20 +95,25 @@ pub(super) async fn incoming_protocol_actor( let mut new_pending_messages = vec![]; for pending_message in pending_messages { tracing::trace!(method = pending_message.method(), handler = ?handler.dyn_describe_chain(), "Retrying message"); + let id = pending_message.id(); match handler .dyn_handle_dispatch_from(pending_message, connection.clone()) - .await? + .await { - Handled::Yes => { + Ok(Handled::Yes) => { tracing::trace!("Message handled"); } - Handled::No { + Ok(Handled::No { message: m, retry: _, - } => { + }) => { tracing::trace!(method = m.method(), handler = ?handler.dyn_describe_chain(), "Message not handled"); new_pending_messages.push(m); } + Err(err) => { + tracing::warn!(?err, handler = ?handler.dyn_describe_chain(), "Dynamic handler errored on pending message, reporting back"); + report_handler_error(connection, id, err)?; + } } } pending_messages = new_pending_messages; @@ -253,18 +260,23 @@ async fn dispatch_dispatch( tracing::trace!(handler = ?handler.describe_chain(), "Attempting handler chain"); match handler .handle_dispatch_from(dispatch, connection.clone()) - .await? + .await { - Handled::Yes => { + Ok(Handled::Yes) => { tracing::trace!(?method, ?id, handler = ?handler.describe_chain(), "Handler accepted message"); return Ok(()); } - Handled::No { message: m, retry } => { + Ok(Handled::No { message: m, retry }) => { tracing::trace!(?method, ?id, handler = ?handler.describe_chain(), "Handler declined message"); dispatch = m; retry_any |= retry; } + + Err(err) => { + tracing::warn!(?method, ?id, ?err, handler = ?handler.describe_chain(), "Handler errored, reporting back to remote"); + return report_handler_error(connection, id, err); + } } // Next, apply any dynamic handlers. @@ -272,18 +284,23 @@ async fn dispatch_dispatch( tracing::trace!(handler = ?dynamic_handler.dyn_describe_chain(), "Attempting dynamic handler"); match dynamic_handler .dyn_handle_dispatch_from(dispatch, connection.clone()) - .await? + .await { - Handled::Yes => { + Ok(Handled::Yes) => { tracing::trace!(?method, ?id, handler = ?dynamic_handler.dyn_describe_chain(), "Dynamic handler accepted message"); return Ok(()); } - Handled::No { message: m, retry } => { + Ok(Handled::No { message: m, retry }) => { tracing::trace!(?method, ?id, handler = ?dynamic_handler.dyn_describe_chain(), "Dynamic handler declined message"); retry_any |= retry; dispatch = m; } + + Err(err) => { + tracing::warn!(?method, ?id, ?err, handler = ?dynamic_handler.dyn_describe_chain(), "Dynamic handler errored, reporting back to remote"); + return report_handler_error(connection, id, err); + } } } @@ -291,17 +308,27 @@ async fn dispatch_dispatch( tracing::trace!(role = ?counterpart, "Attempting default handler"); match counterpart .default_handle_dispatch_from(dispatch, connection.clone()) - .await? + .await { - Handled::Yes => { + Ok(Handled::Yes) => { tracing::trace!(?method, handler = "default", "Role accepted message"); return Ok(()); } - Handled::No { message: m, retry } => { + Ok(Handled::No { message: m, retry }) => { tracing::trace!(?method, handler = "default", "Role declined message"); dispatch = m; retry_any |= retry; } + Err(err) => { + tracing::warn!( + ?method, + ?id, + ?err, + handler = "default", + "Default handler errored, reporting back to remote" + ); + return report_handler_error(connection, id, err); + } } // If the message was never handled, check whether the retry flag was set. @@ -330,3 +357,33 @@ async fn dispatch_dispatch( } } } + +/// When a handler returns an error, report it back to the remote side instead +/// of propagating it and tearing down the connection. +/// +/// For requests (which have an id), sends a JSON-RPC error response. +/// For notifications (no id), sends an out-of-band error notification. +/// For responses, forwards the error to the local awaiter. +fn report_handler_error( + connection: &ConnectionTo, + id: Option, + error: crate::Error, +) -> Result<(), crate::Error> { + match id { + Some(id) => { + // Request: send error response with the original request id + let jsonrpc_id = serde_json::from_value(id).unwrap_or(jsonrpcmsg::Id::Null); + send_raw_message( + &connection.message_tx, + OutgoingMessage::Response { + id: jsonrpc_id, + response: Err(error), + }, + ) + } + None => { + // Notification or response without id: send error notification + connection.send_error_notification(error) + } + } +} diff --git a/src/agent-client-protocol-core/src/schema/mod.rs b/src/agent-client-protocol-core/src/schema/mod.rs index 30f6d93..8367031 100644 --- a/src/agent-client-protocol-core/src/schema/mod.rs +++ b/src/agent-client-protocol-core/src/schema/mod.rs @@ -36,7 +36,7 @@ macro_rules! impl_jsonrpc_request { if method != $method { return Err($crate::Error::method_not_found()); } - $crate::util::json_cast(params) + $crate::util::json_cast_params(params) } } @@ -84,7 +84,7 @@ macro_rules! impl_jsonrpc_notification { if method != $method { return Err($crate::Error::method_not_found()); } - $crate::util::json_cast(params) + $crate::util::json_cast_params(params) } } @@ -133,10 +133,10 @@ macro_rules! impl_jsonrpc_request_enum { params: &impl serde::Serialize, ) -> Result { match method { - $( $(#[$meta])* $method => $crate::util::json_cast(params).map(Self::$variant), )* + $( $(#[$meta])* $method => $crate::util::json_cast_params(params).map(Self::$variant), )* _ => { if let Some(custom_method) = method.strip_prefix('_') { - $crate::util::json_cast(params).map( + $crate::util::json_cast_params(params).map( |ext_req: $crate::schema::ExtRequest| { Self::$ext_variant($crate::schema::ExtRequest::new( custom_method.to_string(), @@ -196,10 +196,10 @@ macro_rules! impl_jsonrpc_notification_enum { params: &impl serde::Serialize, ) -> Result { match method { - $( $(#[$meta])* $method => $crate::util::json_cast(params).map(Self::$variant), )* + $( $(#[$meta])* $method => $crate::util::json_cast_params(params).map(Self::$variant), )* _ => { if let Some(custom_method) = method.strip_prefix('_') { - $crate::util::json_cast(params).map( + $crate::util::json_cast_params(params).map( |ext_notif: $crate::schema::ExtNotification| { Self::$ext_variant($crate::schema::ExtNotification::new( custom_method.to_string(), diff --git a/src/agent-client-protocol-core/src/schema/proxy_protocol.rs b/src/agent-client-protocol-core/src/schema/proxy_protocol.rs index 38b7109..aa35dd1 100644 --- a/src/agent-client-protocol-core/src/schema/proxy_protocol.rs +++ b/src/agent-client-protocol-core/src/schema/proxy_protocol.rs @@ -50,7 +50,7 @@ impl JsonRpcMessage for SuccessorMessage { if method != METHOD_SUCCESSOR_MESSAGE { return Err(crate::Error::method_not_found()); } - let outer = crate::util::json_cast::<_, SuccessorMessage>(params)?; + let outer = crate::util::json_cast_params::<_, SuccessorMessage>(params)?; if !M::matches_method(&outer.message.method) { return Err(crate::Error::method_not_found()); } @@ -161,7 +161,7 @@ impl JsonRpcMessage for McpOverAcpMessage { if method != METHOD_MCP_MESSAGE { return Err(crate::Error::method_not_found()); } - let outer = crate::util::json_cast::<_, McpOverAcpMessage>(params)?; + let outer = crate::util::json_cast_params::<_, McpOverAcpMessage>(params)?; if !M::matches_method(&outer.message.method) { return Err(crate::Error::method_not_found()); } diff --git a/src/agent-client-protocol-core/src/util.rs b/src/agent-client-protocol-core/src/util.rs index 310b02e..2428fa1 100644 --- a/src/agent-client-protocol-core/src/util.rs +++ b/src/agent-client-protocol-core/src/util.rs @@ -30,6 +30,33 @@ where Ok(m) } +/// Cast incoming request/notification params into a typed payload. +/// +/// Like [`json_cast`], but deserialization failures become +/// [`Error::invalid_params`](`crate::Error::invalid_params`) (`-32602`) +/// instead of a parse error, which is the correct JSON-RPC error code for +/// malformed method parameters. +pub fn json_cast_params(params: N) -> Result +where + N: serde::Serialize, + M: serde::de::DeserializeOwned, +{ + let json = serde_json::to_value(params).map_err(|e| { + crate::Error::internal_error().data(serde_json::json!({ + "error": e.to_string(), + "phase": "serialization" + })) + })?; + let m = serde_json::from_value(json.clone()).map_err(|e| { + crate::Error::invalid_params().data(serde_json::json!({ + "error": e.to_string(), + "json": json, + "phase": "deserialization" + })) + })?; + Ok(m) +} + /// Creates an internal error with the given message pub fn internal_error(message: impl ToString) -> crate::Error { crate::Error::internal_error().data(message.to_string()) diff --git a/src/agent-client-protocol-core/tests/jsonrpc_error_handling.rs b/src/agent-client-protocol-core/tests/jsonrpc_error_handling.rs index 11ecbfe..2d7090a 100644 --- a/src/agent-client-protocol-core/tests/jsonrpc_error_handling.rs +++ b/src/agent-client-protocol-core/tests/jsonrpc_error_handling.rs @@ -78,7 +78,7 @@ impl JsonRpcMessage for SimpleRequest { if !Self::matches_method(method) { return Err(agent_client_protocol_core::Error::method_not_found()); } - agent_client_protocol_core::util::json_cast(params) + agent_client_protocol_core::util::json_cast_params(params) } } @@ -107,6 +107,39 @@ impl JsonRpcResponse for SimpleResponse { } } +#[derive(Debug, Clone, Serialize, Deserialize)] +struct SimpleNotification { + message: String, +} + +impl JsonRpcMessage for SimpleNotification { + fn matches_method(method: &str) -> bool { + method == "simple_notification" + } + + fn method(&self) -> &'static str { + "simple_notification" + } + + fn to_untyped_message( + &self, + ) -> Result { + agent_client_protocol_core::UntypedMessage::new(self.method(), self) + } + + fn parse_message( + method: &str, + params: &impl serde::Serialize, + ) -> Result { + if !Self::matches_method(method) { + return Err(agent_client_protocol_core::Error::method_not_found()); + } + agent_client_protocol_core::util::json_cast_params(params) + } +} + +impl agent_client_protocol_core::JsonRpcNotification for SimpleNotification {} + // ============================================================================ // Test 1: Invalid JSON (complete line with parse error) // ============================================================================ @@ -282,7 +315,7 @@ impl JsonRpcMessage for ErrorRequest { if !Self::matches_method(method) { return Err(agent_client_protocol_core::Error::method_not_found()); } - agent_client_protocol_core::util::json_cast(params) + agent_client_protocol_core::util::json_cast_params(params) } } @@ -451,3 +484,299 @@ async fn test_missing_required_params() { }) .await; } + +// ============================================================================ +// Test 5: Invalid params returns error but connection stays alive (issue #131) +// ============================================================================ + +#[tokio::test(flavor = "current_thread")] +async fn test_invalid_params_keeps_connection_alive() { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let (mut client_writer, server_reader) = tokio::io::duplex(4096); + let (server_writer, mut client_reader) = tokio::io::duplex(4096); + + let server_reader = server_reader.compat(); + let server_writer = server_writer.compat_write(); + + // Register a handler for SimpleRequest (requires "message" field) + let server_transport = + agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole.builder().on_receive_request( + async |request: SimpleRequest, + responder: Responder, + _connection: ConnectionTo| { + responder.respond(SimpleResponse { + result: format!("echo: {}", request.message), + }) + }, + agent_client_protocol_core::on_receive_request!(), + ); + + tokio::task::spawn_local(async move { + drop(server.connect_to(server_transport).await); + }); + + // 1) Send a request with WRONG params (missing "message" field) + let bad_request = + b"{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"simple_method\",\"params\":{\"wrong_field\":\"hello\"}}\n"; + client_writer.write_all(bad_request).await.unwrap(); + client_writer.flush().await.unwrap(); + + // Read the error response + let mut buffer = vec![0u8; 4096]; + let n = client_reader.read(&mut buffer).await.unwrap(); + let response_str = String::from_utf8_lossy(&buffer[..n]); + let response: serde_json::Value = + serde_json::from_str(response_str.trim()).expect("Response should be valid JSON"); + + // Verify it's an error response with the correct id and error code + assert_eq!(response["id"], 1); + assert!(response["error"].is_object(), "Expected error object"); + assert_eq!( + response["error"]["code"], -32602, + "Expected invalid params (-32602)" + ); + + // 2) Send a VALID request to prove the connection is still alive + let good_request = + b"{\"jsonrpc\":\"2.0\",\"id\":2,\"method\":\"simple_method\",\"params\":{\"message\":\"hello\"}}\n"; + client_writer.write_all(good_request).await.unwrap(); + client_writer.flush().await.unwrap(); + + // Read the success response + let n = client_reader.read(&mut buffer).await.unwrap(); + let response_str = String::from_utf8_lossy(&buffer[..n]); + let response: serde_json::Value = + serde_json::from_str(response_str.trim()).expect("Response should be valid JSON"); + + // Verify it's a success response + assert_eq!(response["id"], 2); + assert_eq!(response["result"]["result"], "echo: hello"); + }) + .await; +} + +// ============================================================================ +// Helpers for raw-wire tests +// ============================================================================ + +async fn read_jsonrpc_response_line( + reader: &mut tokio::io::BufReader, +) -> serde_json::Value { + use tokio::io::AsyncBufReadExt as _; + + let mut line = String::new(); + match tokio::time::timeout( + tokio::time::Duration::from_secs(1), + reader.read_line(&mut line), + ) + .await + { + Ok(Ok(0)) | Err(_) => panic!("timed out waiting for JSON-RPC response"), + Ok(Ok(_)) => serde_json::from_str(line.trim()).expect("response should be valid JSON"), + Ok(Err(e)) => panic!("failed to read JSON-RPC response line: {e}"), + } +} + +// ============================================================================ +// Test 6: Bad request params returns -32602 and connection stays alive (from Ben's branch) +// ============================================================================ + +#[tokio::test(flavor = "current_thread")] +async fn test_bad_request_params_return_invalid_params_and_connection_stays_alive() { + use tokio::io::{AsyncWriteExt, BufReader}; + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let (mut client_writer, server_reader) = tokio::io::duplex(2048); + let (server_writer, client_reader) = tokio::io::duplex(2048); + + let server_reader = server_reader.compat(); + let server_writer = server_writer.compat_write(); + + let server_transport = + agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole.builder().on_receive_request( + async |request: SimpleRequest, + responder: Responder, + _connection: ConnectionTo| { + responder.respond(SimpleResponse { + result: format!("echo: {}", request.message), + }) + }, + agent_client_protocol_core::on_receive_request!(), + ); + + tokio::task::spawn_local(async move { + if let Err(err) = server.connect_to(server_transport).await { + panic!("server should stay alive: {err:?}"); + } + }); + + let mut client_reader = BufReader::new(client_reader); + + client_writer + .write_all( + br#"{"jsonrpc":"2.0","id":3,"method":"simple_method","params":{"content":"hello"}} +"#, + ) + .await + .unwrap(); + client_writer.flush().await.unwrap(); + + let invalid_response = read_jsonrpc_response_line(&mut client_reader).await; + expect![[r#" + { + "error": { + "code": -32602, + "data": { + "error": "missing field `message`", + "json": { + "content": "hello" + }, + "phase": "deserialization" + }, + "message": "Invalid params" + }, + "id": 3, + "jsonrpc": "2.0" + }"#]] + .assert_eq(&serde_json::to_string_pretty(&invalid_response).unwrap()); + + client_writer + .write_all( + br#"{"jsonrpc":"2.0","id":4,"method":"simple_method","params":{"message":"hello"}} +"#, + ) + .await + .unwrap(); + client_writer.flush().await.unwrap(); + + let ok_response = read_jsonrpc_response_line(&mut client_reader).await; + expect![[r#" + { + "id": 4, + "jsonrpc": "2.0", + "result": { + "result": "echo: hello" + } + }"#]] + .assert_eq(&serde_json::to_string_pretty(&ok_response).unwrap()); + }) + .await; +} + +// ============================================================================ +// Test 7: Bad notification params (from Ben's branch) +// ============================================================================ + +#[tokio::test(flavor = "current_thread")] +async fn test_bad_notification_params_send_error_notification_and_connection_stays_alive() { + use tokio::io::{AsyncWriteExt, BufReader}; + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let (mut client_writer, server_reader) = tokio::io::duplex(2048); + let (server_writer, client_reader) = tokio::io::duplex(2048); + + let server_reader = server_reader.compat(); + let server_writer = server_writer.compat_write(); + + let server_transport = + agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole + .builder() + .on_receive_notification( + async |_notif: SimpleNotification, + _connection: ConnectionTo| { + // If we get here, the notification parsed successfully. + Ok(()) + }, + agent_client_protocol_core::on_receive_notification!(), + ) + .on_receive_request( + async |request: SimpleRequest, + responder: Responder, + _connection: ConnectionTo| { + responder.respond(SimpleResponse { + result: format!("echo: {}", request.message), + }) + }, + agent_client_protocol_core::on_receive_request!(), + ); + + tokio::task::spawn_local(async move { + if let Err(err) = server.connect_to(server_transport).await { + panic!("server should stay alive: {err:?}"); + } + }); + + let mut client_reader = BufReader::new(client_reader); + + // Send a notification with bad params (wrong field name). + // Notifications have no "id", so the server sends an error + // notification (id: null) and keeps the connection alive. + client_writer + .write_all( + br#"{"jsonrpc":"2.0","method":"simple_notification","params":{"wrong_field":"hello"}} +"#, + ) + .await + .unwrap(); + client_writer.flush().await.unwrap(); + + // The server sends an error notification (id: null) for the + // malformed notification. + let error_notification = read_jsonrpc_response_line(&mut client_reader).await; + expect![[r#" + { + "error": { + "code": -32602, + "data": { + "error": "missing field `message`", + "json": { + "wrong_field": "hello" + }, + "phase": "deserialization" + }, + "message": "Invalid params" + }, + "jsonrpc": "2.0" + }"#]] + .assert_eq(&serde_json::to_string_pretty(&error_notification).unwrap()); + + // Now send a valid request to prove the connection is still alive. + client_writer + .write_all( + br#"{"jsonrpc":"2.0","id":10,"method":"simple_method","params":{"message":"after bad notification"}} +"#, + ) + .await + .unwrap(); + client_writer.flush().await.unwrap(); + + let ok_response = read_jsonrpc_response_line(&mut client_reader).await; + expect![[r#" + { + "id": 10, + "jsonrpc": "2.0", + "result": { + "result": "echo: after bad notification" + } + }"#]] + .assert_eq(&serde_json::to_string_pretty(&ok_response).unwrap()); + }) + .await; +} diff --git a/src/agent-client-protocol-derive/src/lib.rs b/src/agent-client-protocol-derive/src/lib.rs index 58814df..9275858 100644 --- a/src/agent-client-protocol-derive/src/lib.rs +++ b/src/agent-client-protocol-derive/src/lib.rs @@ -89,7 +89,7 @@ pub fn derive_json_rpc_request(input: TokenStream) -> TokenStream { if method != #method { return Err(#krate::Error::method_not_found()); } - #krate::util::json_cast(params) + #krate::util::json_cast_params(params) } } @@ -149,7 +149,7 @@ pub fn derive_json_rpc_notification(input: TokenStream) -> TokenStream { if method != #method { return Err(#krate::Error::method_not_found()); } - #krate::util::json_cast(params) + #krate::util::json_cast_params(params) } } diff --git a/src/agent-client-protocol-test/src/lib.rs b/src/agent-client-protocol-test/src/lib.rs index 843db92..8833d68 100644 --- a/src/agent-client-protocol-test/src/lib.rs +++ b/src/agent-client-protocol-test/src/lib.rs @@ -129,7 +129,7 @@ macro_rules! impl_jr_message { if !Self::matches_method(method) { return Err(crate::Error::method_not_found()); } - agent_client_protocol_core::util::json_cast(params) + agent_client_protocol_core::util::json_cast_params(params) } } };