Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""add_agent_protocol_field

Revision ID: a064de6df78e
Revises: 4a9b7787ccd7
Create Date: 2026-04-07 15:38:00.000000

"""
from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "a064de6df78e"
down_revision: Union[str, None] = "4a9b7787ccd7"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
op.add_column(
"agents",
sa.Column("protocol", sa.String(), nullable=False, server_default="acp"),
)


def downgrade() -> None:
op.drop_column("agents", "protocol")
1 change: 1 addition & 0 deletions agentex/src/adapters/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class AgentORM(BaseORM):
acp_url = Column(String, nullable=True) # URL of the agent's ACP server
# TODO: make this a SQLAlchemyEnum rather than a string
acp_type = Column(String, nullable=False, server_default="async")
protocol = Column(String, nullable=False, server_default="acp")
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
Expand Down
1 change: 1 addition & 0 deletions agentex/src/api/routes/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ async def register_agent(
acp_type=request.acp_type,
registration_metadata=request.registration_metadata,
agent_input_type=request.agent_input_type,
protocol=request.protocol,
)
await authorization_service.grant(
AgentexResource.agent(agent_entity.id),
Expand Down
16 changes: 10 additions & 6 deletions agentex/src/api/routes/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

from fastapi import APIRouter, Response

from src.api.schemas.authorization_types import (
AgentexResourceType,
AuthorizedOperationType,
)
from src.api.schemas.checkpoints import (
BlobResponse,
CheckpointListItem,
Expand All @@ -14,10 +18,6 @@
PutWritesRequest,
WriteResponse,
)
from src.api.schemas.authorization_types import (
AgentexResourceType,
AuthorizedOperationType,
)
from src.domain.use_cases.checkpoints_use_case import DCheckpointsUseCase
from src.utils.authorization_shortcuts import DAuthorizedBodyId
from src.utils.logging import make_logger
Expand Down Expand Up @@ -95,7 +95,9 @@ async def put_checkpoint(
request: PutCheckpointRequest,
checkpoints_use_case: DCheckpointsUseCase,
_authorized_task_id: DAuthorizedBodyId(
AgentexResourceType.task, AuthorizedOperationType.execute, field_name="thread_id"
AgentexResourceType.task,
AuthorizedOperationType.execute,
field_name="thread_id",
),
) -> PutCheckpointResponse:
blobs = [
Expand Down Expand Up @@ -133,7 +135,9 @@ async def put_writes(
request: PutWritesRequest,
checkpoints_use_case: DCheckpointsUseCase,
_authorized_task_id: DAuthorizedBodyId(
AgentexResourceType.task, AuthorizedOperationType.execute, field_name="thread_id"
AgentexResourceType.task,
AuthorizedOperationType.execute,
field_name="thread_id",
),
) -> Response:
writes = [
Expand Down
4 changes: 2 additions & 2 deletions agentex/src/api/routes/deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ async def _handle_deployment_sync_rpc(
method=request.method,
params=request.params,
request_headers=request_headers,
acp_url_override=acp_url,
service_url_override=acp_url,
)

if isinstance(result_entity, AsyncIterator):
Expand Down Expand Up @@ -231,7 +231,7 @@ async def rpc_response_generator():
method=request.method,
params=request.params,
request_headers=request_headers,
acp_url_override=acp_url,
service_url_override=acp_url,
)

if not isinstance(result_entity_async_iterator, AsyncIterator):
Expand Down
12 changes: 12 additions & 0 deletions agentex/src/api/schemas/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ class AgentInputType(str, Enum):
JSON = "json"


class AgentProtocol(str, Enum):
ACP = "acp"


class Agent(BaseModel):
id: str = Field(..., description="The unique identifier of the agent.")
name: str = Field(..., description="The unique name of the agent.")
Expand Down Expand Up @@ -66,6 +70,10 @@ class Agent(BaseModel):
agent_input_type: AgentInputType | None = Field(
default=None, description="The type of input the agent expects."
)
protocol: AgentProtocol = Field(
AgentProtocol.ACP,
description="The communication protocol used by this agent.",
)
production_deployment_id: str | None = Field(
default=None, description="ID of the current production deployment."
)
Expand Down Expand Up @@ -95,6 +103,10 @@ class RegisterAgentRequest(BaseModel):
agent_input_type: AgentInputType | None = Field(
default=None, description="The type of input the agent expects."
)
protocol: AgentProtocol = Field(
AgentProtocol.ACP,
description="The communication protocol used by this agent.",
)


class RegisterAgentResponse(Agent):
Expand Down
8 changes: 8 additions & 0 deletions agentex/src/domain/entities/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ class AgentInputType(str, Enum):
JSON = "json"


class AgentProtocol(str, Enum):
ACP = "acp"


class AgentEntity(BaseModel):
id: str = Field(..., description="The unique identifier of the agent.")
docker_image: str | None = Field(
Expand Down Expand Up @@ -65,6 +69,10 @@ class AgentEntity(BaseModel):
agent_input_type: AgentInputType | None = Field(
None, description="The type of input the agent expects."
)
protocol: AgentProtocol = Field(
AgentProtocol.ACP,
description="The communication protocol used by this agent.",
)
production_deployment_id: str | None = Field(
None, description="ID of the current production deployment."
)
2 changes: 2 additions & 0 deletions agentex/src/domain/entities/agents_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ class CancelTaskParams(BaseModel):
task: TaskEntity = Field(..., description="The task that was cancelled")


# Deprecated: canonical source is AgentACPService.get_allowed_methods().
# Kept for backward compatibility with existing tests.
ACP_TYPE_TO_ALLOWED_RPC_METHODS = {
ACPType.SYNC: [AgentRPCMethod.MESSAGE_SEND, AgentRPCMethod.TASK_CREATE],
ACPType.AGENTIC: [
Expand Down
4 changes: 1 addition & 3 deletions agentex/src/domain/repositories/checkpoint_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,7 @@ async def list_checkpoints(
self.async_ro_session_maker() as session,
async_sql_exception_handler(),
):
query = select(CheckpointORM).where(
CheckpointORM.thread_id == thread_id
)
query = select(CheckpointORM).where(CheckpointORM.thread_id == thread_id)

if checkpoint_ns is not None:
query = query.where(CheckpointORM.checkpoint_ns == checkpoint_ns)
Expand Down
75 changes: 62 additions & 13 deletions agentex/src/domain/services/agent_acp_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pydantic import BaseModel

from src.adapters.http.adapter_httpx import DHttpxGateway
from src.domain.entities.agents import AgentEntity
from src.domain.entities.agents import ACPType, AgentEntity
from src.domain.entities.agents_rpc import (
AgentRPCMethod,
CancelTaskParams,
Expand Down Expand Up @@ -39,6 +39,7 @@
from src.domain.mixins.task_messages.task_message_mixin import TaskMessageMixin
from src.domain.repositories.agent_api_key_repository import DAgentAPIKeyRepository
from src.domain.repositories.agent_repository import DAgentRepository
from src.domain.services.agent_protocol_gateway import AgentProtocolGateway
from src.utils.logging import ctx_var_request_id, make_logger

logger = make_logger(__name__)
Expand Down Expand Up @@ -107,7 +108,7 @@ def filter_request_headers(headers: dict[str, str] | None) -> dict[str, str]:
}


class AgentACPService(TaskMessageMixin):
class AgentACPService(AgentProtocolGateway, TaskMessageMixin):
"""
Client service for communicating with downstream ACP servers.
Handles JSON-RPC 2.0 communication with agent ACP servers.
Expand Down Expand Up @@ -279,7 +280,7 @@ async def create_task(
self,
agent: AgentEntity,
task: TaskEntity,
acp_url: str,
service_url: str,
params: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Create a new task"""
Expand All @@ -290,7 +291,7 @@ async def create_task(
)
headers = await self.get_headers(agent)
return await self._call_jsonrpc(
url=acp_url,
url=service_url,
method=AgentRPCMethod.TASK_CREATE,
params=params,
request_id=f"{AgentRPCMethod.TASK_CREATE}-{task.id}", # Use create-specific request ID
Expand All @@ -302,7 +303,7 @@ async def send_message(
agent: AgentEntity,
task: TaskEntity,
content: TaskMessageContentEntity,
acp_url: str,
service_url: str,
) -> TaskMessageContentEntity:
"""Send a message to a running task"""
params = SendMessageParams(
Expand All @@ -319,7 +320,7 @@ async def send_message(
f"Agent {agent.id} already processing message send for task {task.id}"
)
result = await self._call_jsonrpc(
url=acp_url,
url=service_url,
method=AgentRPCMethod.MESSAGE_SEND,
params=params,
request_id=f"{AgentRPCMethod.MESSAGE_SEND}-{task.id}", # Use message-specific request ID
Expand All @@ -333,7 +334,7 @@ async def send_message_stream(
agent: AgentEntity,
task: TaskEntity,
content: TaskMessageContentEntity,
acp_url: str,
service_url: str,
) -> AsyncIterator[TaskMessageUpdateEntity]:
"""Send a message to a running task and stream the response"""
params = SendMessageParams(
Expand All @@ -357,7 +358,7 @@ async def send_message_stream(
f"Agent {agent.id} already processing message send for task {task.id}"
)
async for chunk in self._call_jsonrpc_stream(
url=acp_url,
url=service_url,
method=AgentRPCMethod.MESSAGE_SEND,
params=params,
request_id=f"{AgentRPCMethod.MESSAGE_SEND}-{task.id}",
Expand All @@ -366,7 +367,7 @@ async def send_message_stream(
yield self._parse_task_message_update(chunk)
else:
async for chunk in self._call_jsonrpc_stream(
url=acp_url,
url=service_url,
method=AgentRPCMethod.MESSAGE_SEND,
params=params,
request_id=f"{AgentRPCMethod.MESSAGE_SEND}-{task.id}",
Expand All @@ -375,13 +376,13 @@ async def send_message_stream(
yield self._parse_task_message_update(chunk)

async def cancel_task(
self, agent: AgentEntity, task: TaskEntity, acp_url: str
self, agent: AgentEntity, task: TaskEntity, service_url: str
) -> dict[str, Any]:
"""Cancel a running task"""
params = CancelTaskParams(agent=agent, task=task)
headers = await self.get_headers(agent)
return await self._call_jsonrpc(
url=acp_url,
url=service_url,
method=AgentRPCMethod.TASK_CANCEL,
params=params,
request_id=f"{AgentRPCMethod.TASK_CANCEL}-{task.id}", # Use cancel-specific request ID
Expand All @@ -393,7 +394,7 @@ async def send_event(
agent: AgentEntity,
event: EventEntity,
task: TaskEntity,
acp_url: str,
service_url: str,
request_headers: dict[str, str] | None = None,
) -> dict[str, Any]:
"""Send an event to a running task"""
Expand All @@ -418,12 +419,60 @@ async def send_event(
headers.update(auth_headers)

return await self._call_jsonrpc(
url=acp_url,
url=service_url,
method=AgentRPCMethod.EVENT_SEND,
params=params,
request_id=f"{AgentRPCMethod.EVENT_SEND}-{task.id}", # Use event-specific request ID
default_headers=headers,
)

async def check_health(
self,
agent_id: str,
service_url: str,
) -> bool:
"""Check if the agent server is healthy via its /healthz endpoint."""
try:
response = await self._http_gateway.async_call(
method="GET",
url=f"{service_url}/healthz",
timeout=5,
)
if response.get("status") != "healthy":
logger.error(
f"Agent {agent_id} returned non-healthy status: {response.get('status')}"
)
return False
response_agent_id = response.get("agent_id")
if response_agent_id and response_agent_id != agent_id:
logger.error(
f"Agent {agent_id} returned unexpected agent ID: {response_agent_id}"
)
return False
return True
except Exception as e:
logger.error(f"Failed to check health of agent {agent_id}: {e}")
return False

# ACP-specific: maps agent type to allowed RPC methods
ACP_ALLOWED_METHODS: dict[ACPType, list[AgentRPCMethod]] = {
ACPType.SYNC: [AgentRPCMethod.MESSAGE_SEND, AgentRPCMethod.TASK_CREATE],
ACPType.AGENTIC: [
AgentRPCMethod.TASK_CREATE,
AgentRPCMethod.TASK_CANCEL,
AgentRPCMethod.EVENT_SEND,
],
ACPType.ASYNC: [
AgentRPCMethod.TASK_CREATE,
AgentRPCMethod.TASK_CANCEL,
AgentRPCMethod.EVENT_SEND,
],
}

def get_allowed_methods(self, acp_type: ACPType) -> list[AgentRPCMethod]:
"""Return the list of RPC methods allowed for the given ACP type."""
return self.ACP_ALLOWED_METHODS.get(acp_type, [])


DAgentACPService = Annotated[AgentACPService, Depends(AgentACPService)]
DAgentProtocolGateway = Annotated[AgentProtocolGateway, Depends(AgentACPService)]
Loading
Loading