Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion src/acp/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def __init__(
self._tasks.add_error_handler(self._on_task_error)
self._queue = queue or InMemoryMessageQueue()
self._closed = False
self._disconnected = False
self._sender = (sender_factory or self._default_sender_factory)(self._writer, self._tasks)
if listening:
self._recv_task = self._tasks.create(
Expand Down Expand Up @@ -132,6 +133,7 @@ def add_observer(self, observer: StreamObserver) -> None:
self._observers.append(observer)

async def send_request(self, method: str, params: JsonValue | None = None) -> Any:
self._raise_if_unavailable()
request_id = self._next_request_id
self._next_request_id += 1
future = self._state.register_outgoing(request_id, method)
Expand All @@ -141,6 +143,7 @@ async def send_request(self, method: str, params: JsonValue | None = None) -> An
return await future

async def send_notification(self, method: str, params: JsonValue | None = None) -> None:
self._raise_if_unavailable()
payload = {"jsonrpc": "2.0", "method": method, "params": params}
await self._sender.send(payload)
self._notify_observers(StreamDirection.OUTGOING, payload)
Expand All @@ -160,6 +163,7 @@ async def _receive_loop(self) -> None:
await self._process_message(message)
except asyncio.CancelledError:
return
self._disconnect()

async def _process_message(self, message: dict[str, Any]) -> None:
method = message.get("method")
Expand Down Expand Up @@ -262,7 +266,7 @@ async def _handle_response(self, message: dict[str, Any]) -> None:

def _on_receive_error(self, task: asyncio.Task[Any], exc: BaseException) -> None:
logging.exception("Receive loop failed", exc_info=exc)
self._state.reject_all_outgoing(exc)
self._disconnect()

def _on_task_error(self, task: asyncio.Task[Any], exc: BaseException) -> None:
logging.exception("Background task failed", exc_info=exc)
Expand All @@ -285,3 +289,13 @@ def _default_dispatcher_factory(

def _default_sender_factory(self, writer: asyncio.StreamWriter, supervisor: TaskSupervisor) -> MessageSender:
return MessageSender(writer, supervisor)

def _disconnect(self) -> None:
if self._disconnected:
return
self._disconnected = True
self._state.reject_all_outgoing(ConnectionError("Connection closed"))

def _raise_if_unavailable(self) -> None:
if self._disconnected or self._closed:
raise ConnectionError("Connection closed")
30 changes: 30 additions & 0 deletions tests/test_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
update_agent_message_text,
update_tool_call,
)
from acp.connection import Connection
from acp.core import AgentSideConnection, ClientSideConnection
from acp.schema import (
AgentMessageChunk,
Expand Down Expand Up @@ -199,6 +200,35 @@ async def read_one(i: int):
assert res.content == f"Content {i}"


@pytest.mark.asyncio
async def test_pending_request_fails_when_remote_sends_eof(server):
conn = Connection(lambda method, params, is_notification: None, server.client_writer, server.client_reader)
request = asyncio.create_task(conn.send_request("ping", {"value": 1}))

await asyncio.sleep(0.05)
server.server_writer.close()
await server.server_writer.wait_closed()

with pytest.raises(ConnectionError, match="Connection closed"):
await asyncio.wait_for(request, timeout=1.0)

await conn.close()


@pytest.mark.asyncio
async def test_new_requests_fail_fast_after_remote_eof(server):
conn = Connection(lambda method, params, is_notification: None, server.client_writer, server.client_reader)

server.server_writer.close()
await server.server_writer.wait_closed()
await asyncio.sleep(0.05)

with pytest.raises(ConnectionError, match="Connection closed"):
await asyncio.wait_for(conn.send_request("ping", {"value": 1}), timeout=1.0)

await conn.close()


@pytest.mark.asyncio
async def test_invalid_params_results_in_error_response(connect, server):
# Only start agent-side (server) so we can inject raw request from client socket
Expand Down
Loading