diff --git a/README.md b/README.md index 9d94cc5..aed6bc4 100644 --- a/README.md +++ b/README.md @@ -27,11 +27,11 @@ flowchart TD PeerA2A --> PeerRuntime end - External -->|message/send,\nmessage:stream| Ingress + External -->|SendMessage,\nSendStreamingMessage| Ingress Ingress -->|tool call| OpenCode OpenCode -->|model/tool result events| Ingress Ingress -->|a2a_call| Outbound - Outbound -->|message/send,\nmessage:stream| PeerA2A + Outbound -->|SendMessage,\nSendStreamingMessage| PeerA2A PeerA2A -->|tool result| Outbound PeerRuntime -->|task session\nexecution| PeerA2A ``` @@ -94,11 +94,10 @@ curl http://127.0.0.1:8000/.well-known/agent-card.json ## A2A Protocol Support -- Default protocol line: `0.3` -- Declared supported protocol lines: `0.3`, `1.0` -- `0.3` is the stable interoperability baseline for the current runtime surface. -- `1.0` currently covers version negotiation plus protocol-aware JSON-RPC and REST error shaping, while transport payloads, enums, pagination, signatures, and interface-level protocol declarations still follow the shipped SDK baseline. -- The detailed compatibility matrix and machine-readable support boundary are documented in [`docs/guide.md`](docs/guide.md) and [`docs/compatibility.md`](docs/compatibility.md). +- Supported A2A protocol line: `1.0` +- The runtime is now v1-only across HTTP+JSON, JSON-RPC, Agent Card discovery, and protocol-aware error contracts. +- Legacy `0.3` method aliases and payload shapes are rejected instead of being normalized at runtime. +- The detailed runtime contract and machine-readable support boundary are documented in [`docs/guide.md`](docs/guide.md) and [`docs/compatibility.md`](docs/compatibility.md). ## Peering Node / Outbound Access diff --git a/docs/compatibility.md b/docs/compatibility.md index 38ef7f0..f09ef53 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -1,15 +1,14 @@ # Compatibility Guide -This document explains the compatibility promises `opencode-a2a` currently tries to uphold for A2A consumers, operators, and maintainers. +This document defines the compatibility promises `opencode-a2a` currently upholds for A2A consumers, operators, and maintainers. ## Runtime Support - Python versions: 3.11, 3.12, 3.13 -- A2A SDK line: `0.3.x` -- Default advertised protocol line: `0.3` -- Declared supported protocol lines: `0.3`, `1.0` +- A2A SDK line: `1.x.y` +- Supported A2A protocol line: `1.0` -The repository pins the SDK version in `pyproject.toml`. Upgrade the SDK deliberately rather than relying on floating dependency resolution. +The repository currently pins one concrete SDK release in `pyproject.toml` within that v1 line. Upgrade the SDK deliberately rather than relying on floating dependency resolution. ## Contract Honesty @@ -21,11 +20,11 @@ Machine-readable discovery surfaces must reflect actual runtime behavior: - JSON-RPC wire contract - compatibility profile -If runtime support is not actually implemented, do not publish it as a supported machine-readable capability. +If runtime support is not implemented, do not publish it as a supported machine-readable capability. Consumer guidance: -- Treat the core A2A send / stream / task methods as the portable baseline. +- Treat the v1 core A2A methods (`SendMessage`, `SendStreamingMessage`, `GetTask`, `CancelTask`, `SubscribeToTask`) as the portable baseline. - Treat `urn:a2a:*` entries in this repository as shared repo-family conventions, not as a claim that they are part of the A2A core baseline. - Treat `opencode.*` methods and `metadata.opencode.*` fields as provider-private OpenCode control and discovery surfaces layered on top of the portable A2A baseline. - Treat [extension-specifications.md](./extension-specifications.md) as the stable URI/spec index, not as the main usage guide. @@ -45,7 +44,8 @@ External TCK runs and local conformance experiments are investigation inputs. Th This repository still ships as an alpha project. Within that alpha line, these declared surfaces should not drift silently: - core A2A send / stream / task methods -- version negotiation and protocol-aware error shaping +- v1-only request/response payloads and enum values +- v1 protocol-aware JSON-RPC and REST error shaping - shared session-binding metadata - shared model-selection metadata - shared streaming metadata @@ -56,9 +56,9 @@ Changes to those surfaces should be treated as compatibility-sensitive and shoul Service-level behavior layered on top of those core methods should also be declared explicitly when interoperability depends on it. Current examples: -- `tasks/resubscribe` replay-once behavior for terminal updates +- `SubscribeToTask` replay-once behavior for terminal updates - first-terminal-state-wins task persistence policy -- task-scoped `acceptedOutputModes` negotiation persistence across send / stream / get / resubscribe +- task-scoped `acceptedOutputModes` negotiation persistence across send / stream / get / subscribe - request-body rejection behavior for oversized transport payloads ## Deployment Profile @@ -98,7 +98,7 @@ The default SQLite-first profile is intended for local or controlled single-inst - `opencode.sessions.shell` is compatibility-sensitive as a deployment-conditional shell snapshot surface. It should not silently widen into a general interactive shell API. - `opencode.workspaces.*` and `opencode.worktrees.*` are boundary-sensitive and should remain explicitly provider-private, operator-scoped, and deployment-conditional where applicable. - Interrupt callback and recovery methods are compatibility-sensitive because clients may depend on request ID lifecycle, expiry semantics, and identity scoping. -- Agent Card media modes and `acceptedOutputModes` handling are compatibility-sensitive. Changes to declared chat modes, to task-scoped negotiation persistence, or to `DataPart` -> `TextPart` downgrade behavior should be treated as wire-level changes. +- Agent Card media modes and `acceptedOutputModes` handling are compatibility-sensitive. Changes to declared chat modes, to task-scoped negotiation persistence, or to output filtering of structured tool payloads should be treated as wire-level changes. - Agent Card and OpenAPI publication of `protocol_compatibility`, `service_behaviors`, and runtime feature toggles is compatibility-sensitive discoverability surface. ## Extension Boundary Governance @@ -142,6 +142,6 @@ This repository does not currently promise: - hard multi-tenant isolation inside one instance - generic provider-auth orchestration on behalf of OpenCode -- a claim that all declared `1.0` protocol surfaces are fully implemented beyond the documented compatibility matrix +- compatibility with legacy A2A `0.3` method aliases or payload shapes Those areas may evolve later, but they should not be implied by current machine-readable discovery output. diff --git a/docs/conformance-triage.md b/docs/conformance-triage.md index c55544d..4a1811b 100644 --- a/docs/conformance-triage.md +++ b/docs/conformance-triage.md @@ -1,89 +1,32 @@ # External Conformance Triage -This document records the first local `./scripts/conformance.sh mandatory` run against the official `a2aproject/a2a-tck` using the repository's dummy-backed SUT. +This document summarizes the current interpretation rules for external TCK runs after the repository's A2A v1 migration. -## Standards Used For Triage +## Current Runtime Baseline -- `a2a-sdk==0.3.25` as installed in this repository: - - `AgentCard` uses `additionalInterfaces`, not `supportedInterfaces`. - - JSON-RPC request models use `message/send`, `tasks/get`, `tasks/cancel`, and `agent/getAuthenticatedExtendedCard`. - - The installed SDK does not expose a JSON-RPC `ListTasks` request model. -- A2A v0.3.0 specification: - - JSON-RPC methods use the `{category}/{action}` pattern such as `message/send` and `tasks/get`. - - Transport declarations use `preferredTransport` plus `additionalInterfaces`. - - The method mapping table lists `tasks/list` as gRPC/REST only. -- Repository compatibility policy: - - `A2A-Version` negotiation supports both `0.3` and `1.0`. - - Payloads still follow the shipped `0.3` SDK baseline. - - `1.0` compatibility is currently documented as partial rather than complete. +- `opencode-a2a` now targets the `a2a-sdk 1.x.y` line +- the runtime is v1-only +- canonical JSON-RPC core methods are `SendMessage`, `SendStreamingMessage`, `GetTask`, `CancelTask`, and `SubscribeToTask` +- legacy `0.3` aliases and payload shapes are intentionally rejected rather than normalized -## Classification Labels +## How To Read TCK Failures -- `TCK issue`: the failing expectation conflicts with `a2a-sdk==0.3.25` and the v0.3.0 baseline used by this repository. -- `TCK issue; also a repo v1.0 gap`: the exact failure is caused by a TCK mismatch, but the same area would still need extra work for stronger `1.0` compatibility. -- `TCK issue / local experiment artifact`: the failure comes from an aggressive heuristic or from local dummy-run characteristics and should not be treated as a runtime protocol bug. +When a TCK run fails, classify the result before changing the runtime: -## Per-Test Triage +- `Runtime gap` + - the failure reproduces against the current v1-only runtime and contradicts the repository's declared machine-readable contract +- `TCK assumption mismatch` + - the failure depends on method names, payload shapes, or schema expectations that do not match the current A2A v1 SDK/runtime contract +- `Local experiment artifact` + - the failure depends on dummy-backed local behavior, environment heuristics, or unrelated tooling/setup issues -- `tests/mandatory/authentication/test_auth_compliance_v030.py::test_security_scheme_structure_compliance`: `TCK issue`. The TCK expects each `securitySchemes` entry to be wrapped as `{httpAuthSecurityScheme: {...}}`, but `a2a-sdk==0.3.25` exposes the flattened OpenAPI-shaped object with fields like `type`, `scheme`, `description`, and `bearerFormat`. -- `tests/mandatory/authentication/test_auth_enforcement.py::test_authentication_scheme_consistency`: `TCK issue`. Same root cause as the previous test: the TCK validates a non-SDK wrapper shape instead of the installed SDK schema. -- `tests/mandatory/jsonrpc/test_a2a_error_codes_enhanced.py::test_push_notification_not_supported_error_32003_enhanced`: `TCK issue`. The failure is a TCK helper bug: `transport_create_task_push_notification_config()` is called with the wrong positional signature before the runtime behavior is even exercised. -- `tests/mandatory/jsonrpc/test_json_rpc_compliance.py::test_rejects_invalid_json_rpc_requests[invalid_request4--32602]`: `TCK issue`. The test sends JSON-RPC method `SendMessage`; under the v0.3.0 / SDK 0.3.25 baseline the correct method is `message/send`, so the runtime correctly returns `-32601` for an unknown method instead of `-32602`. -- `tests/mandatory/jsonrpc/test_json_rpc_compliance.py::test_rejects_invalid_params`: `TCK issue`. Same method-name mismatch as above; with the correct `message/send` method the runtime returns `-32602` for invalid parameters. -- `tests/mandatory/jsonrpc/test_protocol_violations.py::test_duplicate_request_ids`: `TCK issue`. The first request already fails because the TCK uses `SendMessage` instead of `message/send`, so the duplicate-ID assertion never reaches the actual duplicate-ID behavior. -- `tests/mandatory/protocol/test_a2a_v030_new_methods.py::TestMethodMappingCompliance::test_core_method_mapping_compliance`: `TCK issue; also a repo v1.0 gap`. The JSON-RPC client uses PascalCase methods (`SendMessage`, `GetTask`, `CancelTask`) that do not match the v0.3.0 JSON-RPC mapping, but the repository also does not currently provide PascalCase aliases even when `A2A-Version: 1.0` is negotiated. -- `tests/mandatory/protocol/test_message_send_method.py::test_message_send_valid_text`: `TCK issue; also a repo v1.0 gap`. The failing request uses `SendMessage` over JSON-RPC; the repository correctly supports `message/send` for the current SDK baseline, but not the PascalCase alias. -- `tests/mandatory/protocol/test_message_send_method.py::test_message_send_invalid_params`: `TCK issue; also a repo v1.0 gap`. Direct cause is the same PascalCase JSON-RPC method mismatch. -- `tests/mandatory/protocol/test_message_send_method.py::test_message_send_continue_task`: `TCK issue; also a repo v1.0 gap`. Direct cause is again `SendMessage` instead of `message/send`. -- `tests/mandatory/protocol/test_state_transitions.py::test_task_history_length`: `TCK issue; also a repo v1.0 gap`. Task creation fails only because the TCK uses `SendMessage` on JSON-RPC. -- `tests/mandatory/protocol/test_tasks_cancel_method.py::test_tasks_cancel_valid`: `TCK issue; also a repo v1.0 gap`. The fixture cannot create a task because the TCK uses `SendMessage`; the runtime's `tasks/cancel` behavior is not the direct failing cause in this run. -- `tests/mandatory/protocol/test_tasks_cancel_method.py::test_tasks_cancel_nonexistent`: `TCK issue; also a repo v1.0 gap`. The TCK calls JSON-RPC `CancelTask`; under the v0.3.0 baseline the method is `tasks/cancel`. With the correct method, the runtime returns `Task not found` / `-32001`. -- `tests/mandatory/protocol/test_tasks_get_method.py::test_tasks_get_valid`: `TCK issue; also a repo v1.0 gap`. The task-creation fixture fails first because the TCK uses `SendMessage`. -- `tests/mandatory/protocol/test_tasks_get_method.py::test_tasks_get_with_history_length`: `TCK issue; also a repo v1.0 gap`. Same fixture failure via `SendMessage`. -- `tests/mandatory/protocol/test_tasks_get_method.py::test_tasks_get_nonexistent`: `TCK issue; also a repo v1.0 gap`. The TCK calls JSON-RPC `GetTask`; under the v0.3.0 baseline the method is `tasks/get`. With the correct method, the runtime returns `Task not found` / `-32001`. -- `tests/mandatory/protocol/test_tasks_list_method.py::TestBasicListing::test_list_all_tasks`: `TCK issue; also a repo v1.0 gap`. The test suite uses JSON-RPC `ListTasks`, which is outside the `a2a-sdk==0.3.25` JSON-RPC surface and outside the v0.3.0 JSON-RPC mapping. -- `tests/mandatory/protocol/test_tasks_list_method.py::TestBasicListing::test_list_tasks_empty_when_none_exist`: `TCK issue; also a repo v1.0 gap`. Same JSON-RPC `ListTasks` mismatch. -- `tests/mandatory/protocol/test_tasks_list_method.py::TestBasicListing::test_list_tasks_validates_required_fields`: `TCK issue; also a repo v1.0 gap`. Same JSON-RPC `ListTasks` mismatch. -- `tests/mandatory/protocol/test_tasks_list_method.py::TestBasicListing::test_list_tasks_sorted_by_timestamp_descending`: `TCK issue; also a repo v1.0 gap`. Same JSON-RPC `ListTasks` mismatch. -- `tests/mandatory/protocol/test_tasks_list_method.py::TestFiltering::test_filter_by_context_id`: `TCK issue; also a repo v1.0 gap`. Same JSON-RPC `ListTasks` mismatch. -- `tests/mandatory/protocol/test_tasks_list_method.py::TestFiltering::test_filter_by_status`: `TCK issue; also a repo v1.0 gap`. Same JSON-RPC `ListTasks` mismatch. -- `tests/mandatory/protocol/test_tasks_list_method.py::TestFiltering::test_filter_by_last_updated_after`: `TCK issue; also a repo v1.0 gap`. Same JSON-RPC `ListTasks` mismatch. -- `tests/mandatory/protocol/test_tasks_list_method.py::TestFiltering::test_combined_filters`: `TCK issue; also a repo v1.0 gap`. Same JSON-RPC `ListTasks` mismatch. -- `tests/mandatory/protocol/test_tasks_list_method.py::TestPagination::test_default_page_size`: `TCK issue; also a repo v1.0 gap`. Same JSON-RPC `ListTasks` mismatch. -- `tests/mandatory/protocol/test_tasks_list_method.py::TestPagination::test_custom_page_size`: `TCK issue; also a repo v1.0 gap`. Same JSON-RPC `ListTasks` mismatch. -- `tests/mandatory/protocol/test_tasks_list_method.py::TestPagination::test_page_token_navigation`: `TCK issue; also a repo v1.0 gap`. Same JSON-RPC `ListTasks` mismatch. -- `tests/mandatory/protocol/test_tasks_list_method.py::TestPagination::test_last_page_detection`: `TCK issue; also a repo v1.0 gap`. Same JSON-RPC `ListTasks` mismatch. -- `tests/mandatory/protocol/test_tasks_list_method.py::TestPagination::test_total_size_accuracy`: `TCK issue; also a repo v1.0 gap`. Same JSON-RPC `ListTasks` mismatch. -- `tests/mandatory/protocol/test_tasks_list_method.py::TestHistoryLimiting::test_history_length_zero`: `TCK issue; also a repo v1.0 gap`. Same JSON-RPC `ListTasks` mismatch. -- `tests/mandatory/protocol/test_tasks_list_method.py::TestHistoryLimiting::test_history_length_custom`: `TCK issue; also a repo v1.0 gap`. Same JSON-RPC `ListTasks` mismatch. -- `tests/mandatory/protocol/test_tasks_list_method.py::TestHistoryLimiting::test_history_length_exceeds_actual`: `TCK issue; also a repo v1.0 gap`. Same JSON-RPC `ListTasks` mismatch. -- `tests/mandatory/protocol/test_tasks_list_method.py::TestArtifactInclusion::test_artifacts_excluded_by_default`: `TCK issue; also a repo v1.0 gap`. Same JSON-RPC `ListTasks` mismatch. -- `tests/mandatory/protocol/test_tasks_list_method.py::TestArtifactInclusion::test_artifacts_included_when_requested`: `TCK issue; also a repo v1.0 gap`. Same JSON-RPC `ListTasks` mismatch. -- `tests/mandatory/protocol/test_tasks_list_method.py::TestEdgeCasesAndErrors::test_invalid_page_token_error`: `TCK issue; also a repo v1.0 gap`. The assertion expects JSON-RPC param validation on `ListTasks`, but the direct failure is still that `ListTasks` is not a supported JSON-RPC method in the current SDK baseline. -- `tests/mandatory/protocol/test_tasks_list_method.py::TestEdgeCasesAndErrors::test_invalid_status_error`: `TCK issue; also a repo v1.0 gap`. Same JSON-RPC `ListTasks` mismatch. -- `tests/mandatory/protocol/test_tasks_list_method.py::TestEdgeCasesAndErrors::test_negative_page_size_error`: `TCK issue; also a repo v1.0 gap`. Same JSON-RPC `ListTasks` mismatch. -- `tests/mandatory/protocol/test_tasks_list_method.py::TestEdgeCasesAndErrors::test_zero_page_size_error`: `TCK issue; also a repo v1.0 gap`. Same JSON-RPC `ListTasks` mismatch. -- `tests/mandatory/protocol/test_tasks_list_method.py::TestEdgeCasesAndErrors::test_out_of_range_page_size_error`: `TCK issue; also a repo v1.0 gap`. Same JSON-RPC `ListTasks` mismatch. -- `tests/mandatory/protocol/test_tasks_list_method.py::TestEdgeCasesAndErrors::test_default_page_size_is_50`: `TCK issue; also a repo v1.0 gap`. Same JSON-RPC `ListTasks` mismatch. -- `tests/mandatory/protocol/test_tasks_list_method.py::TestEdgeCasesAndErrors::test_negative_history_length_error`: `TCK issue; also a repo v1.0 gap`. Same JSON-RPC `ListTasks` mismatch. -- `tests/mandatory/protocol/test_tasks_list_method.py::TestEdgeCasesAndErrors::test_invalid_timestamp_error`: `TCK issue; also a repo v1.0 gap`. Same JSON-RPC `ListTasks` mismatch. -- `tests/mandatory/security/test_agent_card_security.py::test_public_agent_card_access_control`: `TCK issue`. The TCK requires `supportedInterfaces`, but `a2a-sdk==0.3.25` and the v0.3.0 specification use `additionalInterfaces`. -- `tests/mandatory/security/test_agent_card_security.py::test_sensitive_information_protection`: `TCK issue / local experiment artifact`. The failure is driven by heuristic keyword scanning (`token`, `private`, `127.0.0.1`, non-standard port) against a local dummy-backed run. That is not a reliable indicator of protocol non-compliance. -- `tests/mandatory/security/test_agent_card_security.py::test_security_scheme_consistency`: `TCK issue`. Same schema mismatch as the earlier authentication tests: the TCK expects wrapped security scheme objects instead of the installed SDK shape. -- `tests/mandatory/transport/test_multi_transport_equivalence.py::test_message_sending_equivalence`: `TCK issue; also a repo v1.0 gap`. The transport client uses JSON-RPC `SendMessage`; under the v0.3.0 baseline the method is `message/send`, but stronger `1.0` compatibility would still require additional alias handling. -- `tests/mandatory/transport/test_multi_transport_equivalence.py::test_concurrent_operation_equivalence`: `TCK issue; also a repo v1.0 gap`. Same direct cause as the previous test: the JSON-RPC client sends `SendMessage`. +## Current Guidance -## Adjacent Repository Gaps Found During Triage +- Re-run conformance against the current runtime before using any historical triage note. +- Treat Agent Card, authenticated extended card, OpenAPI, and runtime tests as the repository's declared source of truth. +- Do not reopen removed `0.3` compatibility behavior just to satisfy an outdated TCK assumption. +- If a TCK gap is real, document it against the current v1 contract with the exact request/response payloads that failed. -These did not directly cause the exact failed node IDs above, but they are real repository-side gaps revealed during follow-up probes: +## Historical Note -- `A2A-Version: 1.0` still returns `-32601` for JSON-RPC `SendMessage` and `GetExtendedAgentCard`. That means current `1.0` support is still limited to negotiation and error-shaping rather than full method-surface compatibility. -- `GET /v1/tasks` currently returns `500 NotImplementedError` in a local probe, even though the route exists and repository docs describe the SDK-owned REST surface as including task listing. That behavior should be treated as a repository issue independent from the TCK's incorrect JSON-RPC `ListTasks` expectation. - -## Summary - -For the exact 47 failed/error cases in the first mandatory run: - -- No failure is a clean `a2a-sdk==0.3.25` / v0.3.0 conformance bug in the current runtime. -- Most failures come from TCK method/schema assumptions that do not match the shipped SDK baseline. -- Several failures also highlight future repository work if stronger `1.0` compatibility becomes a goal. +Earlier repository-local triage notes were written before the v1 migration and described a mixed `0.3` / partial `1.0` state. Those notes are no longer normative and were removed to avoid stale guidance. diff --git a/docs/conformance.md b/docs/conformance.md index 87dc3c6..b252469 100644 --- a/docs/conformance.md +++ b/docs/conformance.md @@ -59,8 +59,8 @@ Each run keeps the following artifacts in the selected output directory: When a TCK run fails, inspect the raw report before changing the runtime: - Some failures may point to real runtime gaps. -- Some failures may come from TCK assumptions that do not match `a2a-sdk==0.3.25`. -- Some failures may come from A2A v0.3 versus v1.0 naming or schema drift. +- Some failures may come from TCK assumptions that do not match the current `a2a-sdk 1.x.y` contract. +- Some failures may come from local dummy-backed experiment behavior rather than a wire-level runtime defect. The experiment is useful only if those categories stay separate during triage. diff --git a/docs/extension-specifications.md b/docs/extension-specifications.md index dc5e102..c99cabc 100644 --- a/docs/extension-specifications.md +++ b/docs/extension-specifications.md @@ -89,7 +89,7 @@ URI: `https://github.com/Intelligent-Internet/opencode-a2a/blob/main/docs/extens URI: `https://github.com/Intelligent-Internet/opencode-a2a/blob/main/docs/extension-specifications.md#a2a-compatibility-profile-v1` - Scope: compatibility profile describing core baselines, extension retention, and service behaviors -- Includes machine-readable protocol compatibility summary for the currently declared `0.3` / `1.0` support boundary +- Includes machine-readable protocol compatibility summary for the current v1-only runtime boundary - Public Agent Card: capability declaration only - Authenticated extended card: full compatibility profile payload - Transport: Agent Card extension params diff --git a/docs/guide.md b/docs/guide.md index db5b730..9dd70f5 100644 --- a/docs/guide.md +++ b/docs/guide.md @@ -7,15 +7,16 @@ This guide covers configuration, authentication, API behavior, streaming re-subs - The service supports both transports: - HTTP+JSON (REST endpoints such as `/v1/message:send`) - JSON-RPC (`POST /`) -- Agent Card keeps `preferredTransport=HTTP+JSON` and also exposes JSON-RPC in `additional_interfaces`. +- Agent Card exposes both HTTP+JSON and JSON-RPC through `supportedInterfaces`. - The public Agent Card is intentionally slimmed to the minimum discovery surface; per-extension disclosure policy is defined in [`extension-specifications.md`](./extension-specifications.md). - Detailed provider-private contracts are served through the authenticated extended card endpoint `/agent/authenticatedExtendedCard`. - Agent Card responses emit weak `ETag` and `Cache-Control`; clients should revalidate cached cards instead of repeatedly fetching full payloads. - Global HTTP gzip compression is enabled for eligible non-streaming HTTP responses larger than `A2A_HTTP_GZIP_MINIMUM_SIZE` bytes when clients send `Accept-Encoding: gzip`; the default threshold is `8192`, so the main benefit currently lands on larger responses such as the authenticated extended card. - The current A2A prose specification may refer to `AgentCard.capabilities.extendedAgentCard`, but the official JSON schema and SDK types use the top-level `supportsAuthenticatedExtendedCard` field. This service follows the shipped schema/SDK surface. - Payload schema is transport-specific and should not be mixed: - - REST send payload usually uses `message.content` and role values like `ROLE_USER` - - JSON-RPC `message/send` payload uses `params.message.parts` and role values `user` / `agent` + - REST and JSON-RPC both use v1 `message.parts` payloads and enum values such as `ROLE_USER` + - JSON-RPC uses canonical PascalCase core methods such as `SendMessage` and `SubscribeToTask` + - legacy `message.content`, lowercase roles, `{kind: ...}` wrappers, and `message/send` aliases are rejected ## Runtime Environment Variables @@ -69,7 +70,7 @@ Key variables to understand protocol behavior: - `opencode.worktrees.create` - `opencode.worktrees.remove` - `opencode.worktrees.reset` -- Runtime authentication also applies to `/health`; the public unauthenticated discovery surface remains `/.well-known/agent-card.json` and `/.well-known/agent.json`. +- Runtime authentication also applies to `/health`; the public unauthenticated discovery surface remains `/.well-known/agent-card.json`. - The authenticated extended card endpoint `/agent/authenticatedExtendedCard` accepts the same configured bearer/basic auth modes. - The same outbound client flags are also honored by the server-side embedded A2A client used for peer calls and `a2a_call` tool execution: - `A2A_CLIENT_TIMEOUT_SECONDS` @@ -94,11 +95,11 @@ Current client facade API: - `A2AClient.send()` / `A2AClient.send_message()` - `A2AClient.get_task()` - `A2AClient.cancel_task()` -- `A2AClient.resubscribe_task()` +- `A2AClient.subscribe_to_task()` Server-side outbound peer calls read outbound credentials from environment variables. Configure `A2A_CLIENT_BEARER_TOKEN` or `A2A_CLIENT_BASIC_AUTH` when the remote agent protects its runtime surface. CLI outbound calls follow the same environment-only model. -`A2AClient.send()` returns the latest response event and keeps the default stream-first behavior. If a peer returns a non-terminal task snapshot and expects follow-up `tasks/get` polling, enable the optional facade fallback with: +`A2AClient.send()` returns the latest response event and keeps the default stream-first behavior. If a peer returns a non-terminal task snapshot and expects follow-up `GetTask` polling, enable the optional facade fallback with: - `A2A_CLIENT_POLLING_FALLBACK_ENABLED=true` - `A2A_CLIENT_POLLING_FALLBACK_INITIAL_INTERVAL_SECONDS` @@ -106,7 +107,7 @@ Server-side outbound peer calls read outbound credentials from environment varia - `A2A_CLIENT_POLLING_FALLBACK_BACKOFF_MULTIPLIER` - `A2A_CLIENT_POLLING_FALLBACK_TIMEOUT_SECONDS` -The fallback only applies to `send()`, keeps `send_message()` as a thin event stream wrapper, and stops polling once the task reaches a terminal state or a caller-intervention state such as `input-required` or `auth-required`. +The fallback only applies to `send()`, keeps `send_message()` and `subscribe_to_task()` as thin raw `StreamResponse` wrappers, and stops polling once the task reaches a terminal state or a caller-intervention state such as `input-required` or `auth-required`. Execution-boundary metadata is intentionally declarative deployment metadata: it is published through `RuntimeProfile`, Agent Card, OpenAPI, and `/health`, and should not be interpreted as a live per-request privilege snapshot or a runtime CLI self-inspection result. @@ -193,19 +194,19 @@ If one deployment works while another fails against the same upstream provider, ## Core Behavior -- The service forwards A2A `message:send` to OpenCode session/message calls. +- The service forwards A2A `SendMessage` / `SendStreamingMessage` traffic to OpenCode session/message calls. - Main chat requests may override the upstream model for one request through `metadata.shared.model`. - Provider/model catalog discovery is available through `opencode.providers.list` and `opencode.models.list`. - Main chat requests that explicitly send `configuration.acceptedOutputModes` must stay compatible with the declared chat output modes. - Current main chat requests must continue accepting `text/plain`; requests that only accept `application/json` or other incompatible modes are rejected before execution starts. - `application/json` is additive structured-output support for incremental `tool_call` payloads. It does not guarantee that ordinary assistant prose can always be losslessly represented as JSON, so consumers that expect normal chat text should keep accepting `text/plain`. - When a client accepts `text/plain` but not `application/json`, structured `tool_call` payloads are downgraded to compact JSON text instead of being silently dropped. -- Accepted output-mode negotiation is persisted as task-scoped metadata so later `tasks/get` and `tasks/resubscribe` reads keep the same filtered response contract as the original `message/send` or `message:stream` request. -- Main chat input supports structured A2A `parts` passthrough: - - `TextPart` is forwarded as an OpenCode text part. - - `FilePart(FileWithBytes)` is forwarded as a `file` part with a `data:` URL. - - `FilePart(FileWithUri)` is forwarded as a `file` part with the original URI. - - `DataPart` is currently rejected explicitly; it is not silently downgraded. +- Accepted output-mode negotiation is persisted as task-scoped metadata so later `GetTask` and `SubscribeToTask` reads keep the same filtered response contract as the original send/stream request. +- Main chat input supports v1 `message.parts[]` passthrough: + - `{"text": ...}` is forwarded as an OpenCode text part. + - `{"raw": ..., "mediaType": ..., "filename": ...}` is forwarded as a `file` part with a `data:` URL. + - `{"url": ..., "mediaType": ..., "filename": ...}` is forwarded as a `file` part with the original URI. + - structured data-only input parts are rejected explicitly; they are not silently downgraded. - Task state defaults to `completed` for successful turns. - The deployment profile is single-tenant and shared-workspace. For detailed isolation principles and security boundaries, see [SECURITY.md](../SECURITY.md). @@ -220,7 +221,7 @@ If one deployment works while another fails against the same upstream provider, - A final snapshot is emitted only when streaming chunks did not already produce the same final text. - Stream routing is schema-first: the service classifies chunks primarily by OpenCode `part.type` and `part_id` state rather than inline text markers. - `message.part.delta` and `message.part.updated` are merged per `part_id`; out-of-order deltas are buffered and replayed when the corresponding `part.updated` arrives. -- Structured `tool` parts are emitted as `tool_call` blocks backed by `DataPart(data={...})`, while `text` and `reasoning` continue to use `TextPart`. +- Structured `tool` parts are emitted as `tool_call` blocks using structured v1 part payloads, while `text` and `reasoning` continue to use text parts. - `tool_call` block payloads are normalized structured objects that may expose fields such as `call_id`, `tool`, `status`, `title`, `subtitle`, `input`, `output`, and `error`. - If `application/json` is not accepted but `text/plain` is still accepted, those `tool_call` blocks are downgraded to stable compact JSON text so text-only clients retain the same observable state transitions. - When a request restricts `acceptedOutputModes`, the stream applies the same output filtering before persistence so later task snapshots do not re-expose filtered structured blocks. @@ -280,38 +281,38 @@ Unsupported method contract: - Error data fields: - `type=METHOD_NOT_SUPPORTED` - `method` - - `supported_methods` - - `protocol_version` + - `supportedMethods` + - `protocolVersion` Consumer guidance: - Discover custom JSON-RPC methods from Agent Card / OpenAPI before calling them. -- Treat `supported_methods` in `error.data` as the runtime truth for the current deployment, especially when a deployment-conditional method is disabled. +- Treat `supportedMethods` in `error.data` as the runtime truth for the current deployment, especially when a deployment-conditional method is disabled. ## Protocol Version Negotiation - The runtime accepts `A2A-Version` from either the HTTP header or the query parameter of A2A transport requests. -- If both are omitted, the runtime falls back to the configured default protocol version. -- Current defaults declare `default_protocol_version=0.3` and `supported_protocol_versions=["0.3", "1.0"]`. +- If both are omitted, the runtime uses the fixed v1 protocol version `1.0`. +- Machine-readable discovery still declares `default_protocol_version=1.0` and `supported_protocol_versions=["1.0"]`, but those values are runtime constants rather than operator-configurable settings. - Unsupported or invalid versions are rejected before request routing: - JSON-RPC returns a unified `VERSION_NOT_SUPPORTED` error envelope. - REST returns HTTP `400` with the same contract fields. -- Error shaping now follows the negotiated major line: - - `0.3` keeps the existing legacy `error.data={...}` and flat REST error payloads. - - `1.0` keeps standard JSON-RPC error codes for standard failures, but moves A2A-specific JSON-RPC errors to `google.rpc.ErrorInfo`-style `error.data[]` details and REST errors to AIP-193 `error.details[]`. -- The current transport payloads still follow the SDK-owned request/response shapes; version negotiation is introduced first so later issues can evolve error and payload compatibility without scattering version checks across handlers. +- Error shaping follows the v1 contract: + - JSON-RPC keeps standard JSON-RPC error codes for standard failures and uses `google.rpc.ErrorInfo`-style `error.data[]` details for A2A-specific failures. + - REST uses AIP-193 style `error.details[]`. +- The runtime does not normalize legacy `0.3` method aliases or payload shapes. Current compatibility matrix: -| Area | `0.3` | `1.0` | Current note | -| --- | --- | --- | --- | -| Version negotiation | Supported | Supported | The runtime accepts `A2A-Version` and routes requests before handler dispatch. | -| Agent Card / interface version discovery | Default card protocol only | Partial | The service publishes `default_protocol_version` and `supported_protocol_versions`, but `AgentInterface.protocolVersion` cannot yet be declared with `a2a-sdk==0.3.25`. | -| Transport payloads and enums | Supported | Partial | Request/response payloads, enums, and schema details still follow the SDK-owned `0.3` baseline. | -| Error model | Supported | Partial | `0.3` keeps legacy `error.data={...}` / flat REST payloads; `1.0` uses protocol-aware JSON-RPC details and AIP-193-style REST errors. | -| Pagination and list semantics | Supported | Partial | Cursor/list behavior is stable, but the declared shape still follows the `0.3` SDK baseline. | -| Push notification surfaces | Unsupported | Unsupported | SDK-owned task push-notification routes are still exposed, but this runtime does not enable push sender/config-store support. REST routes return HTTP `501`, while JSON-RPC methods remain unsupported via SDK-owned error envelopes. | -| Signatures and authenticated data | Supported | Partial | Security schemes and authenticated extended card discovery follow the shipped SDK schema rather than a dedicated `1.0` compatibility layer. | +| Area | `1.0` | Current note | +| --- | --- | --- | +| Version negotiation | Supported | The runtime accepts `A2A-Version` and routes requests before handler dispatch. | +| Agent Card / interface version discovery | Supported | Agent Card publishes v1 `supportedInterfaces` entries for HTTP+JSON and JSON-RPC. | +| Transport payloads and enums | Supported | Request/response payloads, enums, and schema details follow the current SDK-owned v1 baseline. | +| Error model | Supported | JSON-RPC and REST both use the v1 protocol-aware error shapes. | +| Pagination and list semantics | Supported | Cursor/list behavior follows the current SDK baseline. | +| Push notification surfaces | Unsupported | SDK-owned task push-notification routes are still exposed, but this runtime does not enable push sender/config-store support. REST routes return HTTP `501`, while JSON-RPC methods remain unsupported via SDK-owned error envelopes. | +| Signatures and authenticated data | Supported | Security schemes and authenticated extended card discovery follow the shipped SDK schema. | ## Compatibility Profile @@ -330,8 +331,7 @@ Current profile shape: - `default_protocol_version` - `supported_protocol_versions` - `protocol_compatibility` - - `versions["0.3"].status=supported` - - `versions["1.0"].status=partial` + - `versions["1.0"].status=supported` - `versions[*].supported_features[]` - `versions[*].known_gaps[]` - Deployment semantics are declared under `deployment`: @@ -396,23 +396,19 @@ curl -sS http://127.0.0.1:8000/ \ -d '{ "jsonrpc": "2.0", "id": "req-1", - "method": "message/send", + "method": "SendMessage", "params": { "message": { "messageId": "msg-multipart-1", - "role": "user", + "role": "ROLE_USER", "parts": [ { - "kind": "text", "text": "Please summarize this file." }, { - "kind": "file", - "file": { - "name": "report.pdf", - "mimeType": "application/pdf", - "uri": "file:///workspace/report.pdf" - } + "url": "file:///workspace/report.pdf", + "filename": "report.pdf", + "mediaType": "application/pdf" } ] } @@ -420,10 +416,10 @@ curl -sS http://127.0.0.1:8000/ \ }' ``` -Current compatibility note: +Current input note: -- `TextPart` and `FilePart` are supported. -- `DataPart` input is not supported and is rejected with an explicit error. +- text parts and file/url/raw parts are supported. +- structured data-only input parts are not supported and are rejected with an explicit error. ## Extension Capability Overview @@ -574,7 +570,7 @@ Consumer guidance: Minimal stream semantics summary: - `text`, `reasoning`, and `tool_call` are emitted as canonical block types -- `text` and `reasoning` blocks use `TextPart`, while `tool_call` uses `DataPart` +- `text` and `reasoning` blocks use text parts, while `tool_call` uses structured v1 part payloads - `message_id` and `event_id` preserve stable timeline identity where possible - `sequence` is the per-request canonical stream sequence - final task/status metadata may repeat normalized usage and interrupt context even after the streaming phase ends @@ -586,7 +582,7 @@ This service exposes OpenCode session read, mutation, and control methods via A2 - Trigger: call extension methods through A2A JSON-RPC - Auth: same runtime auth as the main endpoint (`Bearer` or configured `Basic`) - Privacy guard: when `A2A_LOG_PAYLOADS=true`, request/response bodies are still suppressed for `method=opencode.sessions.*` -- Endpoint discovery: prefer `additional_interfaces[]` with `transport=jsonrpc` from Agent Card +- Endpoint discovery: prefer `supportedInterfaces[]` with `protocolBinding=JSONRPC` from Agent Card - The runtime still delegates SDK-owned JSON-RPC methods such as `agent/getAuthenticatedExtendedCard` and `tasks/pushNotificationConfig/*` to the base A2A implementation; they are not OpenCode-specific extensions. - Push notification config methods remain effectively unsupported in the current runtime because no push config store or push sender is configured; REST routes return HTTP `501`, while JSON-RPC methods stay on SDK-owned unsupported error handling. - Notification behavior: for `opencode.sessions.*`, requests without `id` return HTTP `204 No Content` @@ -595,8 +591,8 @@ This service exposes OpenCode session read, mutation, and control methods via A2 - `opencode.sessions.list` / `opencode.sessions.children` => A2A `Task[]` - `opencode.sessions.get` => A2A `Task` - `opencode.sessions.todo` / `opencode.sessions.diff` => provider-private summaries in `result.items` - - `opencode.sessions.messages.list` => A2A `Message[]` - - `opencode.sessions.messages.get` => A2A `Message` + - `opencode.sessions.messages.list` => adapter-normalized A2A `Message` projections + - `opencode.sessions.messages.get` => adapter-normalized A2A `Message` projection - `opencode.sessions.fork` / `opencode.sessions.share` / `opencode.sessions.unshare` => provider-private session summary in `result.item` - `opencode.sessions.summarize` => provider-private completion result in `result.ok` plus `result.session_id` - `opencode.sessions.revert` / `opencode.sessions.unrevert` => provider-private session summary in `result.item` @@ -678,7 +674,9 @@ curl -sS http://127.0.0.1:8000/ \ Message history responses include: -- `result.items`: normalized A2A `Message[]` +- `result.items`: adapter-normalized A2A `Message[]` +- `role`: canonical v1 enum values such as `ROLE_USER` / `ROLE_AGENT` +- `parts`: current projection is text-focused; text parts are aggregated into a single `Part(text=...)` rather than preserving arbitrary upstream part structure - `result.next_cursor`: opaque cursor for the next older page, or `null` when no older page is available ### Session Get / Children / Todo / Diff / Message Get @@ -687,7 +685,7 @@ Message history responses include: - `opencode.sessions.children` => read child sessions and map them to A2A `Task[]` - `opencode.sessions.todo` => read provider-private todo summaries - `opencode.sessions.diff` => read provider-private diff summaries; optional `message_id` -- `opencode.sessions.messages.get` => read one message and map it to A2A `Message` +- `opencode.sessions.messages.get` => read one message and map it to the same adapter-normalized A2A `Message` projection Example (`opencode.sessions.messages.get`): @@ -1224,28 +1222,28 @@ curl -sS http://127.0.0.1:8000/ \ -d '{ "jsonrpc": "2.0", "id": 101, - "method": "message/send", + "method": "SendMessage", "params": { "message": { "messageId": "msg-1", - "role": "user", - "parts": [{"kind": "text", "text": "Explain what this repository does."}] + "role": "ROLE_USER", + "parts": [{"text": "Explain what this repository does."}] } } }' ``` -## Streaming Re-Subscription (`subscribe`) +## Streaming Re-Subscription (`SubscribeToTask`) If an SSE connection drops, use `GET /v1/tasks/{task_id}:subscribe` to re-subscribe while the task is still non-terminal. -## Cancellation Semantics (`tasks/cancel`) +## Cancellation Semantics (`CancelTask`) - The service first marks the A2A task as `canceled` and keeps cancel requests responsive. - For running tasks, the service attempts upstream OpenCode `POST /session/{sessionID}/abort` to stop generation. -- Upstream interruption is best-effort: if upstream returns 404, network errors, or other HTTP errors, A2A cancellation still completes with `TaskState.canceled`. -- Idempotency contract: repeated `tasks/cancel` on an already `canceled` task returns the current terminal task state without error. -- Terminal subscribe contract: calling `subscribe` on a terminal task replays one terminal `Task` snapshot and then closes the stream. +- Upstream interruption is best-effort: if upstream returns 404, network errors, or other HTTP errors, A2A cancellation still completes with `TaskState.TASK_STATE_CANCELED`. +- Idempotency contract: repeated `CancelTask` on an already `canceled` task returns the current terminal task state without error. +- Terminal subscribe contract: calling `SubscribeToTask` or `GET /v1/tasks/{task_id}:subscribe` on a terminal task replays one terminal `Task` snapshot and then closes the stream. - These two semantics are also declared as machine-readable `service_behaviors` in the compatibility profile and wire contract extensions. - At `A2A_LOG_LEVEL=DEBUG`, the service emits lightweight metric log records (`logger=opencode_a2a.execution.executor`): diff --git a/src/opencode_a2a/a2a_protocol.py b/src/opencode_a2a/a2a_protocol.py index 6a19005..4d72320 100644 --- a/src/opencode_a2a/a2a_protocol.py +++ b/src/opencode_a2a/a2a_protocol.py @@ -1,32 +1,8 @@ from __future__ import annotations -from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH +from a2a.server.routes.jsonrpc_dispatcher import JsonRpcDispatcher +from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH as SDK_AGENT_CARD_WELL_KNOWN_PATH -PREV_AGENT_CARD_WELL_KNOWN_PATH = "/.well-known/agent.json" +AGENT_CARD_WELL_KNOWN_PATH = SDK_AGENT_CARD_WELL_KNOWN_PATH EXTENDED_AGENT_CARD_PATH = "/agent/authenticatedExtendedCard" - -V1_JSONRPC_METHOD_TO_LEGACY_METHOD: dict[str, str] = { - "CancelTask": "tasks/cancel", - "CreateTaskPushNotificationConfig": "tasks/pushNotificationConfig/set", - "DeleteTaskPushNotificationConfig": "tasks/pushNotificationConfig/delete", - "GetExtendedAgentCard": "agent/getAuthenticatedExtendedCard", - "GetTask": "tasks/get", - "GetTaskPushNotificationConfig": "tasks/pushNotificationConfig/get", - "ListTasks": "tasks/list", - "ListTaskPushNotificationConfigs": "tasks/pushNotificationConfig/list", - "SendMessage": "message/send", - "SendStreamingMessage": "message/stream", - "SubscribeToTask": "tasks/resubscribe", -} - -LEGACY_JSONRPC_METHOD_TO_V1_METHOD = { - legacy: method for method, legacy in V1_JSONRPC_METHOD_TO_LEGACY_METHOD.items() -} - -__all__ = [ - "AGENT_CARD_WELL_KNOWN_PATH", - "EXTENDED_AGENT_CARD_PATH", - "LEGACY_JSONRPC_METHOD_TO_V1_METHOD", - "PREV_AGENT_CARD_WELL_KNOWN_PATH", - "V1_JSONRPC_METHOD_TO_LEGACY_METHOD", -] +CORE_JSONRPC_METHODS = tuple(JsonRpcDispatcher.METHOD_TO_MODEL) diff --git a/src/opencode_a2a/a2a_utils.py b/src/opencode_a2a/a2a_utils.py index d0dff67..b4d43ba 100644 --- a/src/opencode_a2a/a2a_utils.py +++ b/src/opencode_a2a/a2a_utils.py @@ -1,11 +1,9 @@ from __future__ import annotations -import json from collections.abc import Mapping, Sequence from typing import Any, TypeVar, cast -from a2a.types import Artifact, Message, Part, Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent -from google.protobuf.json_format import MessageToDict +from a2a.types import Artifact, Message, Part, TaskArtifactUpdateEvent, TaskStatusUpdateEvent from google.protobuf.message import Message as ProtoMessage from google.protobuf.struct_pb2 import ListValue, Struct, Value @@ -18,76 +16,8 @@ def clone_proto(message: ProtoT) -> ProtoT: return cloned -def proto_to_dict( - message: ProtoMessage, - *, - preserving_proto_field_name: bool = False, -) -> dict[str, Any]: - return cast( - dict[str, Any], - MessageToDict( - message, - preserving_proto_field_name=preserving_proto_field_name, - ), - ) - - def proto_equals(left: ProtoMessage, right: ProtoMessage) -> bool: - return proto_to_dict(left, preserving_proto_field_name=True) == proto_to_dict( - right, - preserving_proto_field_name=True, - ) - - -def make_text_part( - text: str, - *, - metadata: Mapping[str, Any] | None = None, - filename: str | None = None, - media_type: str | None = None, -) -> Part: - part = Part(text=text) - if metadata: - part.metadata.update(dict(metadata)) - if filename: - part.filename = filename - if media_type: - part.media_type = media_type - return part - - -def make_raw_part( - raw: bytes, - *, - filename: str | None = None, - media_type: str | None = None, - metadata: Mapping[str, Any] | None = None, -) -> Part: - part = Part(raw=raw) - if metadata: - part.metadata.update(dict(metadata)) - if filename: - part.filename = filename - if media_type: - part.media_type = media_type - return part - - -def make_url_part( - url: str, - *, - filename: str | None = None, - media_type: str | None = None, - metadata: Mapping[str, Any] | None = None, -) -> Part: - part = Part(url=url) - if metadata: - part.metadata.update(dict(metadata)) - if filename: - part.filename = filename - if media_type: - part.media_type = media_type - return part + return bool(left == right) def _to_proto_value(value: Any) -> Value: @@ -131,53 +61,6 @@ def make_data_part( return part -def part_is_text(part: Part) -> bool: - return cast(bool, part.HasField("text")) - - -def part_is_data(part: Part) -> bool: - return cast(bool, part.HasField("data")) - - -def part_is_file(part: Part) -> bool: - return cast(bool, part.HasField("raw")) or cast(bool, part.HasField("url")) - - -def part_kind(part: Part) -> str | None: - if part_is_text(part): - return "text" - if part_is_data(part): - return "data" - if part_is_file(part): - return "file" - return None - - -def part_text(part: Part) -> str | None: - if part.HasField("text"): - return part.text - return None - - -def part_data_to_python(part: Part) -> Any: - if not part.HasField("data"): - return None - return MessageToDict(part.data) - - -def part_text_fallback(part: Part) -> str | None: - if part.HasField("text"): - return part.text - if part.HasField("data"): - return json.dumps( - part_data_to_python(part), - ensure_ascii=True, - sort_keys=True, - separators=(",", ":"), - ) - return None - - def replace_message_parts(message: Message, parts: Sequence[Part]) -> Message: updated = clone_proto(message) del updated.parts[:] @@ -192,37 +75,6 @@ def replace_artifact_parts(artifact: Artifact, parts: Sequence[Part]) -> Artifac return updated -def replace_task_status_message(task: Task, message: Message | None) -> Task: - updated = clone_proto(task) - if message is None: - updated.status.ClearField("message") - else: - updated.status.message.CopyFrom(message) - return updated - - -def replace_task_history(task: Task, history: Sequence[Message]) -> Task: - updated = clone_proto(task) - del updated.history[:] - updated.history.extend(history) - return updated - - -def replace_task_artifacts(task: Task, artifacts: Sequence[Artifact]) -> Task: - updated = clone_proto(task) - del updated.artifacts[:] - updated.artifacts.extend(artifacts) - return updated - - -def replace_task_metadata(task: Task, metadata: Mapping[str, Any] | None) -> Task: - updated = clone_proto(task) - updated.ClearField("metadata") - if metadata: - updated.metadata.update(dict(metadata)) - return updated - - def replace_status_event_message( event: TaskStatusUpdateEvent, message: Message | None, diff --git a/src/opencode_a2a/cli.py b/src/opencode_a2a/cli.py index 0ae6e60..6193a73 100644 --- a/src/opencode_a2a/cli.py +++ b/src/opencode_a2a/cli.py @@ -6,10 +6,9 @@ import sys from collections.abc import Sequence -from a2a.types import Message, TaskArtifactUpdateEvent, TaskState, TaskStatusUpdateEvent +from a2a.types import TaskState from . import __version__ -from .a2a_utils import part_text from .client import A2AClient, load_settings from .server.application import main as serve_main @@ -20,23 +19,24 @@ async def run_call(agent_url: str, text: str) -> int: try: async for event in client.send_message(text): - if isinstance(event, tuple): - _, update = event - if isinstance(update, TaskArtifactUpdateEvent): - artifact = update.artifact - if artifact and artifact.parts: - for part in artifact.parts: - text_val = part_text(part) - if isinstance(text_val, str): - print(text_val, end="", flush=True) - elif isinstance(update, TaskStatusUpdateEvent): - if update.status and update.status.state == TaskState.TASK_STATE_FAILED: - print(f"\n[Failed] {update.status.message or ''}") - elif isinstance(event, Message): - for part in event.parts: - text_val = part_text(part) + if event.HasField("message"): + for part in event.message.parts: + text_val = part.text if part.HasField("text") else None if isinstance(text_val, str): print(text_val, end="", flush=True) + elif event.HasField("artifact_update"): + artifact = event.artifact_update.artifact + if artifact and artifact.parts: + for part in artifact.parts: + text_val = part.text if part.HasField("text") else None + if isinstance(text_val, str): + print(text_val, end="", flush=True) + elif ( + event.HasField("status_update") + and event.status_update.status + and event.status_update.status.state == TaskState.TASK_STATE_FAILED + ): + print(f"\n[Failed] {event.status_update.status.message or ''}") print() # New line after completion except Exception as exc: print(f"\n[Error] {exc}", file=sys.stderr) diff --git a/src/opencode_a2a/client/agent_card.py b/src/opencode_a2a/client/agent_card.py index 8e659e4..c96cad5 100644 --- a/src/opencode_a2a/client/agent_card.py +++ b/src/opencode_a2a/client/agent_card.py @@ -11,7 +11,6 @@ from ..a2a_protocol import ( AGENT_CARD_WELL_KNOWN_PATH, EXTENDED_AGENT_CARD_PATH, - PREV_AGENT_CARD_WELL_KNOWN_PATH, ) from ..trace_context import current_trace_headers from .request_context import build_default_headers @@ -26,7 +25,6 @@ def normalize_agent_card_endpoint(agent_url: str) -> tuple[str, str]: normalized_no_leading = path.rstrip("/").lstrip("/") candidate_paths = ( AGENT_CARD_WELL_KNOWN_PATH, - PREV_AGENT_CARD_WELL_KNOWN_PATH, EXTENDED_AGENT_CARD_PATH, ) @@ -77,10 +75,3 @@ def build_resolver_http_kwargs( if default_headers: http_kwargs["headers"] = default_headers return http_kwargs - - -__all__ = [ - "build_agent_card_resolver", - "build_resolver_http_kwargs", - "normalize_agent_card_endpoint", -] diff --git a/src/opencode_a2a/client/client.py b/src/opencode_a2a/client/client.py index 9d702b5..873408a 100644 --- a/src/opencode_a2a/client/client.py +++ b/src/opencode_a2a/client/client.py @@ -19,18 +19,16 @@ CancelTaskRequest, GetTaskRequest, Message, + Part, Role, SendMessageConfiguration, SendMessageRequest, StreamResponse, SubscribeToTaskRequest, Task, - TaskArtifactUpdateEvent, - TaskStatusUpdateEvent, ) from a2a.utils.errors import A2AError -from ..a2a_utils import make_text_part from ..invocation import call_with_supported_kwargs from .agent_card import build_agent_card_resolver, build_resolver_http_kwargs from .config import A2AClientSettings, load_settings @@ -42,7 +40,18 @@ from .polling import PollingFallbackPolicy from .request_context import build_call_context, split_request_metadata -ClientFactory = None + +def _merge_requested_extensions( + explicit_extensions: list[str] | None, + metadata_extensions: tuple[str, ...] | None, +) -> tuple[str, ...] | None: + merged: list[str] = [] + for value in list(explicit_extensions or []) + list(metadata_extensions or ()): + if isinstance(value, str): + normalized = value.strip() + if normalized and normalized not in merged: + merged.append(normalized) + return tuple(merged) or None class A2AClient: @@ -120,19 +129,18 @@ async def send_message( message_id: str | None = None, metadata: Mapping[str, Any] | None = None, extensions: list[str] | None = None, - ) -> AsyncIterator[ - Message | tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | None - ]: - """Send one user message and stream protocol events.""" + ) -> AsyncIterator[StreamResponse]: + """Send one user message and stream raw protocol events.""" await self._acquire_operation() try: client = await self._ensure_client() - request_metadata, extra_headers = split_request_metadata(metadata) + request_metadata, extra_headers, metadata_extensions = split_request_metadata(metadata) + requested_extensions = _merge_requested_extensions(extensions, metadata_extensions) call_context = build_call_context( self._settings.bearer_token, extra_headers, + requested_extensions, self._settings.basic_auth, - self._settings.protocol_version, ) try: async for event in call_with_supported_kwargs( @@ -143,8 +151,7 @@ async def send_message( message_id=message_id or str(uuid4()), context_id=context_id, task_id=task_id, - parts=[make_text_part(text)], - extensions=list(extensions or []), + parts=[Part(text=text)], ), configuration=SendMessageConfiguration(), metadata=request_metadata or {}, @@ -153,9 +160,9 @@ async def send_message( call_context=call_context, request_metadata=request_metadata, ): - yield self._adapt_stream_response(event) + yield event except (A2AError, SDKClientError, httpx.TimeoutException, httpx.TransportError) as exc: - raise map_operation_error("message/send", exc) from exc + raise map_operation_error("SendMessage", exc) from exc finally: await self._release_operation() @@ -168,16 +175,14 @@ async def send( message_id: str | None = None, metadata: Mapping[str, Any] | None = None, extensions: list[str] | None = None, - ) -> Message | tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | None: + ) -> StreamResponse | None: """Send a message and return the latest response event. - When polling fallback is enabled, a non-terminal `(Task, None)` result may - be followed by bounded `tasks/get` polling until a terminal task snapshot + When polling fallback is enabled, a non-terminal task snapshot may be + followed by bounded `GetTask` polling until a terminal task snapshot is observed. """ - last_event: ( - Message | tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | None - ) = None + last_event: StreamResponse | None = None async for event in self.send_message( text, context_id=context_id, @@ -190,13 +195,11 @@ async def send( if not self._should_poll_after_send(last_event): return last_event terminal_task = await self._poll_task_until_terminal( - cast( - tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None], - last_event, - )[0], + cast(StreamResponse, last_event).task, metadata=metadata, + extensions=extensions, ) - return (terminal_task, None) + return StreamResponse(task=terminal_task) async def get_task( self, @@ -204,17 +207,19 @@ async def get_task( *, history_length: int | None = None, metadata: Mapping[str, Any] | None = None, + extensions: list[str] | None = None, ) -> Task: """Fetch one task by id.""" await self._acquire_operation() try: client = await self._ensure_client() - request_metadata, extra_headers = split_request_metadata(metadata) + request_metadata, extra_headers, metadata_extensions = split_request_metadata(metadata) + requested_extensions = _merge_requested_extensions(extensions, metadata_extensions) call_context = build_call_context( self._settings.bearer_token, extra_headers, + requested_extensions, self._settings.basic_auth, - self._settings.protocol_version, ) try: return cast( @@ -228,7 +233,7 @@ async def get_task( ), ) except (A2AError, SDKClientError, httpx.TimeoutException, httpx.TransportError) as exc: - raise map_operation_error("tasks/get", exc) from exc + raise map_operation_error("GetTask", exc) from exc finally: await self._release_operation() @@ -237,17 +242,19 @@ async def cancel_task( task_id: str, *, metadata: Mapping[str, Any] | None = None, + extensions: list[str] | None = None, ) -> Task: """Cancel one task by id.""" await self._acquire_operation() try: client = await self._ensure_client() - request_metadata, extra_headers = split_request_metadata(metadata) + request_metadata, extra_headers, metadata_extensions = split_request_metadata(metadata) + requested_extensions = _merge_requested_extensions(extensions, metadata_extensions) call_context = build_call_context( self._settings.bearer_token, extra_headers, + requested_extensions, self._settings.basic_auth, - self._settings.protocol_version, ) try: return cast( @@ -261,48 +268,40 @@ async def cancel_task( ), ) except (A2AError, SDKClientError, httpx.TimeoutException, httpx.TransportError) as exc: - raise map_operation_error("tasks/cancel", exc) from exc + raise map_operation_error("CancelTask", exc) from exc finally: await self._release_operation() - async def resubscribe_task( + async def subscribe_to_task( self, task_id: str, *, metadata: Mapping[str, Any] | None = None, - ) -> AsyncIterator[tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None]]: - """Resubscribe to task updates.""" + extensions: list[str] | None = None, + ) -> AsyncIterator[StreamResponse]: + """Subscribe to task updates.""" await self._acquire_operation() try: client = await self._ensure_client() - request_metadata, extra_headers = split_request_metadata(metadata) + request_metadata, extra_headers, metadata_extensions = split_request_metadata(metadata) + requested_extensions = _merge_requested_extensions(extensions, metadata_extensions) call_context = build_call_context( self._settings.bearer_token, extra_headers, + requested_extensions, self._settings.basic_auth, - self._settings.protocol_version, ) try: - subscribe = getattr(client, "subscribe", None) - if subscribe is None: - subscribe = cast(Any, client).resubscribe async for event in call_with_supported_kwargs( - subscribe, + client.subscribe, SubscribeToTaskRequest(id=task_id), context=call_context, call_context=call_context, request_metadata=request_metadata, ): - adapted = self._adapt_stream_response(event) - if isinstance(adapted, tuple): - yield adapted - elif adapted is not None and not isinstance(adapted, Message): - yield cast( - tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None], - adapted, - ) + yield event except (A2AError, SDKClientError, httpx.TimeoutException, httpx.TransportError) as exc: - raise map_operation_error("tasks/resubscribe", exc) from exc + raise map_operation_error("SubscribeToTask", exc) from exc finally: await self._release_operation() @@ -320,13 +319,6 @@ async def _build_client(self) -> Client: supported_protocol_bindings=list(self._settings.supported_transports), use_client_preference=self._settings.use_client_preference, ) - factory_cls = globals().get("ClientFactory") - if factory_cls is not None: - card = await self.get_agent_card() - factory = factory_cls(config) - client = cast(Client, factory.create(card, interceptors=None)) - self._client = client - return client try: client = await create_client( self.agent_url, @@ -361,22 +353,22 @@ async def _release_operation(self) -> None: def _should_poll_after_send( self, - event: Message | tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | None, + event: StreamResponse | None, ) -> bool: if not self._polling_fallback_policy.enabled: return False - if event is None or isinstance(event, Message) or not isinstance(event, tuple): + if event is None or not event.HasField("task"): return False - task, update = event - if update is not None: + if not event.task.HasField("status"): return False - return self._polling_fallback_policy.should_poll_state(task.status.state) + return self._polling_fallback_policy.should_poll_state(event.task.status.state) async def _poll_task_until_terminal( self, task: Task, *, metadata: Mapping[str, Any] | None = None, + extensions: list[str] | None = None, ) -> Task: deadline = self._current_time() + self._polling_fallback_policy.timeout_seconds interval = self._polling_fallback_policy.initial_interval_seconds @@ -396,7 +388,11 @@ async def _poll_task_until_terminal( ) await self._sleep(min(interval, remaining)) - current_task = await self.get_task(current_task.id, metadata=metadata) + current_task = await self.get_task( + current_task.id, + metadata=metadata, + extensions=extensions, + ) interval = self._polling_fallback_policy.next_interval_seconds(interval) def _current_time(self) -> float: @@ -404,36 +400,3 @@ def _current_time(self) -> float: async def _sleep(self, delay_seconds: float) -> None: await asyncio.sleep(delay_seconds) - - def _adapt_stream_response( - self, - response: StreamResponse, - ) -> Message | tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | None: - if not hasattr(response, "HasField"): - return cast( - Message - | tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] - | None, - response, - ) - if response.HasField("message"): - return response.message - if response.HasField("task"): - return (response.task, None) - if response.HasField("status_update"): - task = Task( - id=response.status_update.task_id, - context_id=response.status_update.context_id, - status=response.status_update.status, - ) - return (task, response.status_update) - if response.HasField("artifact_update"): - task = Task( - id=response.artifact_update.task_id, - context_id=response.artifact_update.context_id, - ) - return (task, response.artifact_update) - return None - - -__all__ = ["A2AClient"] diff --git a/src/opencode_a2a/client/config.py b/src/opencode_a2a/client/config.py index 3b9e70a..891d519 100644 --- a/src/opencode_a2a/client/config.py +++ b/src/opencode_a2a/client/config.py @@ -6,7 +6,6 @@ from dataclasses import dataclass from typing import Any -from ..protocol_versions import normalize_protocol_version from .auth import validate_basic_auth from .polling import PollingFallbackPolicy, validate_polling_fallback_policy @@ -71,13 +70,6 @@ def _coerce_optional_str(name: str, value: Any) -> str | None: raise ValueError(f"{name} must be a string, got {value!r}") -def _coerce_optional_protocol_version(name: str, value: Any) -> str | None: - normalized = _coerce_optional_str(name, value) - if normalized is None: - return None - return normalize_protocol_version(normalized) - - def _normalize_transport(value: str) -> str: normalized = value.strip().lower() if normalized in {"jsonrpc", "json-rpc", "json_rpc"}: @@ -118,7 +110,6 @@ class A2AClientSettings: card_fetch_timeout: float = 5.0 bearer_token: str | None = None basic_auth: str | None = None - protocol_version: str | None = None supported_transports: tuple[str, ...] = ( "JSONRPC", "HTTP+JSON", @@ -181,19 +172,6 @@ def load_settings(raw_settings: Any) -> A2AClientSettings: ) if basic_auth is not None: validate_basic_auth(basic_auth) - protocol_version = _coerce_optional_protocol_version( - "A2A_CLIENT_PROTOCOL_VERSION", - _read_setting( - raw_settings, - keys=( - "A2A_CLIENT_PROTOCOL_VERSION", - "a2a_client_protocol_version", - "A2A_PROTOCOL_VERSION", - "a2a_protocol_version", - ), - default=None, - ), - ) supported_transports = _parse_transports( _read_setting( raw_settings, @@ -282,7 +260,6 @@ def load_settings(raw_settings: Any) -> A2AClientSettings: card_fetch_timeout=card_fetch_timeout, bearer_token=bearer_token, basic_auth=basic_auth, - protocol_version=protocol_version, supported_transports=supported_transports, polling_fallback_enabled=polling_fallback_enabled, polling_fallback_initial_interval_seconds=polling_fallback_initial_interval_seconds, @@ -290,6 +267,3 @@ def load_settings(raw_settings: Any) -> A2AClientSettings: polling_fallback_backoff_multiplier=polling_fallback_backoff_multiplier, polling_fallback_timeout_seconds=polling_fallback_timeout_seconds, ) - - -__all__ = ["A2AClientSettings", "load_settings"] diff --git a/src/opencode_a2a/client/error_mapping.py b/src/opencode_a2a/client/error_mapping.py index a99f819..52bac2c 100644 --- a/src/opencode_a2a/client/error_mapping.py +++ b/src/opencode_a2a/client/error_mapping.py @@ -259,14 +259,3 @@ def map_agent_card_error( if isinstance(exc, (httpx.TimeoutException, httpx.TransportError)): return map_transport_error("agent-card/fetch", exc) return A2AAgentUnavailableError("Remote A2A peer is unreachable for agent-card/fetch") - - -__all__ = [ - "map_agent_card_error", - "map_a2a_error", - "map_client_error", - "map_http_error", - "map_jsonrpc_error", - "map_operation_error", - "map_transport_error", -] diff --git a/src/opencode_a2a/client/errors.py b/src/opencode_a2a/client/errors.py index d096c66..cf928d7 100644 --- a/src/opencode_a2a/client/errors.py +++ b/src/opencode_a2a/client/errors.py @@ -71,16 +71,3 @@ def __init__( self.code = rpc_code self.http_status = http_status self.data = data - - -__all__ = [ - "A2AClientError", - "A2AAgentUnavailableError", - "A2AAuthenticationError", - "A2APermissionDeniedError", - "A2AClientResetRequiredError", - "A2ATimeoutError", - "A2AUnsupportedBindingError", - "A2AUnsupportedOperationError", - "A2APeerProtocolError", -] diff --git a/src/opencode_a2a/client/payload_text.py b/src/opencode_a2a/client/payload_text.py index 67cce29..7501be1 100644 --- a/src/opencode_a2a/client/payload_text.py +++ b/src/opencode_a2a/client/payload_text.py @@ -5,11 +5,10 @@ from collections.abc import Mapping from typing import Any -from a2a.types import Message, Part +from a2a.types import Message, Part, StreamResponse +from google.protobuf.json_format import MessageToDict from google.protobuf.message import Message as ProtoMessage -from ..a2a_utils import part_text, proto_to_dict - def extract_text(payload: Any) -> str | None: def extract_from_iterable(items: Any) -> str | None: @@ -27,7 +26,7 @@ def extract_from_parts(parts: Any) -> str | None: collected: list[str] = [] for part in parts: if isinstance(part, Part): - text_value = part_text(part) + text_value = part.text if part.HasField("text") else None if text_value: collected.append(text_value) continue @@ -97,6 +96,17 @@ def extract_from_mapping(payload_map: Mapping[str, Any]) -> str | None: if isinstance(payload, Message): return extract_from_parts(payload.parts) + if isinstance(payload, StreamResponse): + if payload.HasField("artifact_update"): + return extract_text(payload.artifact_update.artifact) + if payload.HasField("status_update"): + return extract_text(payload.status_update.status) + if payload.HasField("message"): + return extract_text(payload.message) + if payload.HasField("task"): + return extract_text(payload.task) + return None + if isinstance(payload, str): return payload.strip() or None @@ -155,7 +165,7 @@ def extract_from_mapping(payload_map: Mapping[str, Any]) -> str | None: mapping_payload: Mapping[str, Any] | None = None if isinstance(payload, ProtoMessage): - payload_dict = proto_to_dict(payload) + payload_dict = MessageToDict(payload) if isinstance(payload_dict, Mapping): mapping_payload = payload_dict elif hasattr(payload, "model_dump") and callable(payload.model_dump): @@ -175,6 +185,3 @@ def extract_from_mapping(payload_map: Mapping[str, Any]) -> str | None: return mapped_text return None - - -__all__ = ["extract_text"] diff --git a/src/opencode_a2a/client/polling.py b/src/opencode_a2a/client/polling.py index 35cd262..4e37549 100644 --- a/src/opencode_a2a/client/polling.py +++ b/src/opencode_a2a/client/polling.py @@ -57,6 +57,3 @@ def validate_polling_fallback_policy(policy: PollingFallbackPolicy) -> None: "A2A_CLIENT_POLLING_FALLBACK_MAX_INTERVAL_SECONDS must be greater than or " "equal to A2A_CLIENT_POLLING_FALLBACK_INITIAL_INTERVAL_SECONDS" ) - - -__all__ = ["PollingFallbackPolicy", "validate_polling_fallback_policy"] diff --git a/src/opencode_a2a/client/request_context.py b/src/opencode_a2a/client/request_context.py index 8e5702c..570b685 100644 --- a/src/opencode_a2a/client/request_context.py +++ b/src/opencode_a2a/client/request_context.py @@ -6,8 +6,9 @@ from typing import Any from a2a.client.client import ClientCallContext +from a2a.client.service_parameters import ServiceParametersFactory, with_a2a_extensions -from ..protocol_versions import normalize_protocol_version +from ..protocol_versions import A2A_PROTOCOL_VERSION from ..trace_context import current_trace_headers from .auth import encode_basic_auth @@ -15,66 +16,79 @@ def build_default_headers( bearer_token: str | None, basic_auth: str | None = None, - protocol_version: str | None = None, ) -> dict[str, str]: - headers: dict[str, str] = {} + headers: dict[str, str] = {"A2A-Version": A2A_PROTOCOL_VERSION} if bearer_token: headers["Authorization"] = f"Bearer {bearer_token}" elif basic_auth: headers["Authorization"] = f"Basic {encode_basic_auth(basic_auth)}" - if protocol_version: - headers["A2A-Version"] = normalize_protocol_version(protocol_version) return headers def split_request_metadata( metadata: Mapping[str, Any] | None, -) -> tuple[dict[str, Any] | None, dict[str, str] | None]: +) -> tuple[dict[str, Any] | None, dict[str, str] | None, tuple[str, ...] | None]: request_metadata: dict[str, Any] = {} extra_headers: dict[str, str] = {} - for key, value in dict(metadata or {}).items(): - if isinstance(key, str) and key.lower() == "authorization": + requested_extensions: list[str] = [] + for key, value in (metadata or {}).items(): + normalized_key = key.lower() + if normalized_key == "authorization": if value is not None: - extra_headers["Authorization"] = str(value) + if not isinstance(value, str): + raise ValueError("Authorization metadata header must be a string.") + extra_headers["Authorization"] = value continue - if isinstance(key, str) and key.lower() == "a2a-version": + if normalized_key == "a2a-version": + raise ValueError("A2A-Version is fixed to 1.0 and must not be overridden.") + if normalized_key == "traceparent": if value is not None: - extra_headers["A2A-Version"] = normalize_protocol_version(str(value)) + if not isinstance(value, str): + raise ValueError("traceparent metadata header must be a string.") + extra_headers["traceparent"] = value continue - if isinstance(key, str) and key.lower() == "traceparent": + if normalized_key == "tracestate": if value is not None: - extra_headers["traceparent"] = str(value) + if not isinstance(value, str): + raise ValueError("tracestate metadata header must be a string.") + extra_headers["tracestate"] = value continue - if isinstance(key, str) and key.lower() == "tracestate": - if value is not None: - extra_headers["tracestate"] = str(value) + if normalized_key == "a2a-extensions": + if value is None: + continue + if not isinstance(value, str): + raise ValueError("A2A-Extensions metadata header must be a string.") + requested_extensions.extend(item.strip() for item in value.split(",") if item.strip()) continue request_metadata[key] = value - return request_metadata or None, extra_headers or None + return ( + request_metadata or None, + extra_headers or None, + tuple(requested_extensions) or None, + ) def build_call_context( bearer_token: str | None, extra_headers: Mapping[str, str] | None, + extensions: tuple[str, ...] | None = None, basic_auth: str | None = None, - protocol_version: str | None = None, -) -> ClientCallContext | None: - merged_headers = build_default_headers(bearer_token, basic_auth, protocol_version) +) -> ClientCallContext: + merged_headers = build_default_headers(bearer_token, basic_auth) merged_headers.update(current_trace_headers()) if extra_headers: merged_headers.update(extra_headers) - if not merged_headers: - return None + normalized_extensions = [value for value in (extensions or ()) if value] + service_parameters = None + if normalized_extensions: + service_parameters = ServiceParametersFactory.create_from( + None, + [with_a2a_extensions(normalized_extensions)], + ) return ClientCallContext( state={ "headers": dict(merged_headers), "http_kwargs": {"headers": dict(merged_headers)}, - } + }, + service_parameters=service_parameters, ) - - -__all__ = [ - "build_call_context", - "build_default_headers", - "split_request_metadata", -] diff --git a/src/opencode_a2a/config.py b/src/opencode_a2a/config.py index 3d0bbbe..ddcf664 100644 --- a/src/opencode_a2a/config.py +++ b/src/opencode_a2a/config.py @@ -3,14 +3,10 @@ import json from typing import Annotated, Any, Literal -from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, field_validator, model_validator +from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, model_validator from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict from opencode_a2a import __version__ -from opencode_a2a.protocol_versions import ( - normalize_protocol_version, - normalize_protocol_versions, -) from opencode_a2a.sandbox_policy import SandboxPolicy SandboxMode = Literal[ @@ -176,11 +172,6 @@ class Settings(BaseSettings): a2a_title: str = Field(default="OpenCode A2A", alias="A2A_TITLE") a2a_description: str = Field(default="OpenCode A2A runtime", alias="A2A_DESCRIPTION") a2a_version: str = Field(default=__version__, alias="A2A_VERSION") - a2a_protocol_version: str = Field(default="0.3", alias="A2A_PROTOCOL_VERSION") - a2a_supported_protocol_versions: DeclaredStringList = Field( - default=("0.3", "1.0"), - alias="A2A_SUPPORTED_PROTOCOL_VERSIONS", - ) a2a_log_level: str = Field(default="WARNING", alias="A2A_LOG_LEVEL") a2a_log_payloads: bool = Field(default=False, alias="A2A_LOG_PAYLOADS") a2a_log_body_limit: int = Field(default=0, alias="A2A_LOG_BODY_LIMIT") @@ -270,10 +261,6 @@ class Settings(BaseSettings): ) a2a_client_bearer_token: str | None = Field(default=None, alias="A2A_CLIENT_BEARER_TOKEN") a2a_client_basic_auth: str | None = Field(default=None, alias="A2A_CLIENT_BASIC_AUTH") - a2a_client_protocol_version: str | None = Field( - default=None, - alias="A2A_CLIENT_PROTOCOL_VERSION", - ) a2a_client_cache_ttl_seconds: float = Field( default=900.0, ge=0.0, @@ -288,7 +275,6 @@ class Settings(BaseSettings): default=("JSONRPC", "HTTP+JSON"), alias="A2A_CLIENT_SUPPORTED_TRANSPORTS", ) - # Task store settings a2a_task_store_backend: TaskStoreBackend = Field( default="database", @@ -306,12 +292,6 @@ def _validate_sandbox_policy(self) -> Settings: raise ValueError( "A2A_TASK_STORE_DATABASE_URL is required when A2A_TASK_STORE_BACKEND=database" ) - if self.a2a_protocol_version not in self.a2a_supported_protocol_versions: - supported_display = ", ".join(self.a2a_supported_protocol_versions) - raise ValueError( - "A2A_PROTOCOL_VERSION must be present in A2A_SUPPORTED_PROTOCOL_VERSIONS. " - f"Declared supported versions: {supported_display}" - ) if self.a2a_static_auth_credentials: if not any(credential.enabled for credential in self.a2a_static_auth_credentials): raise ValueError( @@ -320,30 +300,3 @@ def _validate_sandbox_policy(self) -> Settings: else: raise ValueError("Configure runtime authentication via A2A_STATIC_AUTH_CREDENTIALS") return self - - @field_validator("a2a_protocol_version", mode="before") - @classmethod - def _normalize_a2a_protocol_version(cls, value: Any) -> str: - if not isinstance(value, str): - raise TypeError("A2A_PROTOCOL_VERSION must be a string.") - return normalize_protocol_version(value) - - @field_validator("a2a_client_protocol_version", mode="before") - @classmethod - def _normalize_a2a_client_protocol_version(cls, value: Any) -> str | None: - if value is None: - return None - if not isinstance(value, str): - raise TypeError("A2A_CLIENT_PROTOCOL_VERSION must be a string.") - normalized = value.strip() - if not normalized: - return None - return normalize_protocol_version(normalized) - - @field_validator("a2a_supported_protocol_versions") - @classmethod - def _normalize_supported_protocol_versions( - cls, - value: tuple[str, ...], - ) -> tuple[str, ...]: - return normalize_protocol_versions(value) diff --git a/src/opencode_a2a/contracts/extensions.py b/src/opencode_a2a/contracts/extensions.py index fe07b59..67ffb24 100644 --- a/src/opencode_a2a/contracts/extensions.py +++ b/src/opencode_a2a/contracts/extensions.py @@ -3,9 +3,7 @@ from dataclasses import dataclass from typing import Any -from a2a.server.routes.jsonrpc_dispatcher import JsonRpcDispatcher - -from ..a2a_protocol import V1_JSONRPC_METHOD_TO_LEGACY_METHOD +from ..a2a_protocol import CORE_JSONRPC_METHODS as DECLARED_CORE_JSONRPC_METHODS from ..profile.runtime import ( SESSION_SHELL_TOGGLE, WORKSPACE_MUTATIONS_TOGGLE, @@ -45,13 +43,6 @@ def _extension_spec_uri(fragment: str) -> str: SERVICE_BEHAVIOR_CLASSIFICATION = "service-level-semantic-enhancement" CANCEL_IDEMPOTENCY_BEHAVIOR = "return_current_terminal_task" TERMINAL_RESUBSCRIBE_BEHAVIOR = "replay_terminal_task_once_then_close" -V1_PARTIAL_COMPATIBILITY_GAPS: tuple[str, ...] = ( - "AgentInterface.protocolVersion cannot be declared with a2a-sdk==0.3.25.", - ( - "Transport payloads, enums, pagination, signatures, and push-notification " - "surfaces still follow the SDK-owned 0.3 baseline." - ), -) @dataclass(frozen=True) @@ -421,10 +412,7 @@ class WorkspaceControlMethodContract: key: SESSION_METHODS[key] for key in SESSION_MUTATION_METHOD_KEYS } -CORE_JSONRPC_METHODS: tuple[str, ...] = tuple( - V1_JSONRPC_METHOD_TO_LEGACY_METHOD.get(method, method) - for method in JsonRpcDispatcher.METHOD_TO_MODEL -) +CORE_JSONRPC_METHODS: tuple[str, ...] = tuple(DECLARED_CORE_JSONRPC_METHODS) CORE_HTTP_ENDPOINTS: tuple[str, ...] = ( "POST /v1/message:send", "POST /v1/message:stream", @@ -955,7 +943,7 @@ def build_model_selection_extension_params( return { "metadata_field": SHARED_MODEL_SELECTION_FIELD, "behavior": "prefer_metadata_model_else_upstream_default", - "applies_to_methods": ["message/send", "message/stream"], + "applies_to_methods": ["SendMessage", "SendStreamingMessage"], "supported_metadata": [ "shared.model.providerID", "shared.model.modelID", @@ -1608,7 +1596,7 @@ def build_compatibility_profile_params( ), ( "Treat protocol_compatibility as the runtime truth for which major line " - "is fully supported versus partially adapted." + "is fully supported by the current deployment." ), ], } @@ -1621,31 +1609,17 @@ def build_protocol_compatibility_params( ) -> dict[str, Any]: declared_supported_versions = list(supported_protocol_versions) versions: dict[str, dict[str, Any]] = { - "0.3": { - "enabled": "0.3" in declared_supported_versions, - "default": default_protocol_version == "0.3", - "status": "supported", - "supported_features": [ - "Default compatibility line for the current deployment.", - "A2A-Version negotiation fallback and explicit 0.3 routing.", - "Legacy JSON-RPC and REST error envelopes.", - ( - "SDK-owned transport payloads, enums, pagination, signatures, and " - "push-notification surfaces." - ), - ], - "known_gaps": [], - }, "1.0": { "enabled": "1.0" in declared_supported_versions, "default": default_protocol_version == "1.0", - "status": "partial", + "status": "supported", "supported_features": [ - "A2A-Version negotiation and request routing.", - "Protocol-aware JSON-RPC error shaping.", - "Protocol-aware REST error shaping.", + "Proto-first transport payloads and enum naming.", + "Canonical A2A v1.0 JSON-RPC method names.", + "Protocol-aware JSON-RPC and REST error shaping.", + "Agent Card and OpenAPI discovery aligned to the v1.0 surface.", ], - "known_gaps": list(V1_PARTIAL_COMPATIBILITY_GAPS), + "known_gaps": [], }, } @@ -1732,7 +1706,7 @@ def build_service_behavior_contract_params() -> dict[str, Any]: return { "classification": SERVICE_BEHAVIOR_CLASSIFICATION, "methods": { - "tasks/cancel": { + "CancelTask": { "baseline": "core", "retention": "stable", "idempotency": { @@ -1743,7 +1717,7 @@ def build_service_behavior_contract_params() -> dict[str, Any]: } }, }, - "tasks/resubscribe": { + "SubscribeToTask": { "baseline": "core", "retention": "stable", "terminal_state_behavior": { diff --git a/src/opencode_a2a/execution/coordinator.py b/src/opencode_a2a/execution/coordinator.py index 36f2079..8aefe6f 100644 --- a/src/opencode_a2a/execution/coordinator.py +++ b/src/opencode_a2a/execution/coordinator.py @@ -14,6 +14,7 @@ from a2a.types import ( Artifact, Message, + Part, Role, Task, TaskState, @@ -21,7 +22,6 @@ TaskStatusUpdateEvent, ) -from ..a2a_utils import make_text_part from ..invocation import call_with_supported_kwargs from ..opencode_upstream_client import UpstreamConcurrencyLimitError, UpstreamContractError from .event_helpers import _enqueue_artifact_update @@ -59,6 +59,8 @@ class PreparedExecution: workspace_id: str | None session_binding_context_id: str allow_structured_output: bool + emit_session_metadata: bool + emit_streaming_metadata: bool def build_session_binding_context_id( @@ -288,6 +290,8 @@ async def _bind_session(self) -> None: workspace_id=self._prepared.workspace_id, terminal_signal=self._stream_terminal_signal, allow_structured_output=self._prepared.allow_structured_output, + emit_session_metadata=self._prepared.emit_session_metadata, + emit_streaming_metadata=self._prepared.emit_streaming_metadata, ) ) @@ -398,7 +402,7 @@ async def _handle_streaming_response( task_id=self._task_id, context_id=self._context_id, artifact_id=self._stream_artifact_id, - part=make_text_part(response_text), + part=Part(text=response_text), append=self._stream_state.emitted_stream_chunk, last_chunk=True, artifact_metadata=_build_stream_artifact_metadata( @@ -407,6 +411,7 @@ async def _handle_streaming_response( message_id=resolved_message_id, event_id=self._stream_state.build_event_id(sequence), sequence=sequence, + include_shared_stream_metadata=self._prepared.emit_streaming_metadata, ), ) @@ -423,6 +428,8 @@ async def _handle_streaming_response( "event_id": f"{self._stream_state.event_id_namespace}:status", "source": "status", }, + include_session_metadata=self._prepared.emit_session_metadata, + include_streaming_metadata=self._prepared.emit_streaming_metadata, ), ) ) @@ -445,7 +452,7 @@ async def _handle_non_streaming_response( artifact = Artifact( artifact_id=str(uuid.uuid4()), name="response", - parts=[make_text_part(response_text)], + parts=[Part(text=response_text)], ) from .request_context import _build_history @@ -459,6 +466,8 @@ async def _handle_non_streaming_response( metadata=_build_output_metadata( session_id=response.session_id, usage=resolved_token_usage, + include_session_metadata=self._prepared.emit_session_metadata, + include_streaming_metadata=self._prepared.emit_streaming_metadata, ), ) task.status.message.CopyFrom(assistant_message) @@ -498,7 +507,7 @@ def build_assistant_message( return Message( message_id=message_id or str(uuid.uuid4()), role=Role.ROLE_AGENT, - parts=[make_text_part(text)], + parts=[Part(text=text)], task_id=task_id, context_id=context_id, ) diff --git a/src/opencode_a2a/execution/executor.py b/src/opencode_a2a/execution/executor.py index be255ba..c921347 100644 --- a/src/opencode_a2a/execution/executor.py +++ b/src/opencode_a2a/execution/executor.py @@ -17,6 +17,7 @@ from a2a.server.events.event_queue import EventQueue from a2a.types import ( Message, + Part, Role, Task, TaskState, @@ -24,7 +25,11 @@ TaskStatusUpdateEvent, ) -from ..a2a_utils import make_text_part +from ..contracts.extensions import ( + SESSION_BINDING_EXTENSION_URI, + STREAMING_EXTENSION_URI, +) +from ..extension_negotiation import requested_extensions_from_call_context from ..invocation import call_with_supported_kwargs from ..opencode_upstream_client import OpencodeUpstreamClient from ..output_modes import accepts_output_mode, normalize_accepted_output_modes @@ -45,40 +50,11 @@ _extract_shared_session_id, ) from .session_manager import SessionManager -from .stream_events import ( - BlockType, - _build_progress_identity, - _coerce_number, - _extract_event_session_id, - _extract_interrupt_asked_event, - _extract_interrupt_resolved_event, - _extract_progress_metadata, - _extract_stream_session_id, - _extract_stream_snapshot_text, - _extract_stream_terminal_signal, - _extract_token_usage, - _extract_tool_part_payload, - _extract_upstream_error_from_event, - _extract_upstream_error_from_response, - _normalize_interrupt_question_options, - _normalize_interrupt_questions, - _normalize_role, - _preview_log_value, -) from .stream_runtime import StreamRuntime from .stream_state import ( - _build_output_metadata, - _merge_token_usage, _StreamOutputState, - _TTLCache, ) from .upstream_error_translator import ( - _await_stream_terminal_signal, - _extract_upstream_error_detail, - _format_inband_upstream_error, - _format_stream_terminal_error, - _format_upstream_error, - _resolve_upstream_error_profile, _StreamTerminalSignal, ) @@ -86,37 +62,6 @@ _TEXT_PLAIN_MEDIA_TYPE = "text/plain" _APPLICATION_JSON_MEDIA_TYPE = "application/json" -__all__ = [ - "BlockType", - "_build_output_metadata", - "_build_progress_identity", - "_coerce_number", - "_extract_event_session_id", - "_extract_interrupt_asked_event", - "_extract_interrupt_resolved_event", - "_extract_progress_metadata", - "_extract_stream_session_id", - "_extract_stream_snapshot_text", - "_extract_stream_terminal_signal", - "_extract_token_usage", - "_extract_tool_part_payload", - "_extract_upstream_error_detail", - "_extract_upstream_error_from_event", - "_extract_upstream_error_from_response", - "_format_inband_upstream_error", - "_format_stream_terminal_error", - "_format_upstream_error", - "_merge_token_usage", - "_normalize_interrupt_question_options", - "_normalize_interrupt_questions", - "_normalize_role", - "_preview_log_value", - "_resolve_upstream_error_profile", - "_TTLCache", -] - -_EXPORTED_COMPAT_SYMBOLS = (BlockType, _await_stream_terminal_signal) - def _emit_metric( name: str, @@ -197,6 +142,7 @@ async def execute(self, context: RequestContext, event_queue: EventQueue) -> Non trace_id = call_context.state.get("trace_id") if call_context else None streaming_request = self._should_stream(context) + requested_extensions = requested_extensions_from_call_context(context.call_context) accepted_output_modes = normalize_accepted_output_modes(context.configuration) message_parts = ( getattr(context.message, "parts", None) if context.message is not None else None @@ -291,6 +237,14 @@ async def execute(self, context: RequestContext, event_queue: EventQueue) -> Non accepted_output_modes, _APPLICATION_JSON_MEDIA_TYPE, ) + emit_session_metadata = bool( + { + SESSION_BINDING_EXTENSION_URI, + STREAMING_EXTENSION_URI, + } + & set(requested_extensions) + ) + emit_streaming_metadata = STREAMING_EXTENSION_URI in requested_extensions logger.debug( ( @@ -321,6 +275,8 @@ async def execute(self, context: RequestContext, event_queue: EventQueue) -> Non workspace_id=workspace_id, session_binding_context_id=session_binding_context_id, allow_structured_output=allow_structured_output, + emit_session_metadata=emit_session_metadata, + emit_streaming_metadata=emit_streaming_metadata, ) coordinator = ExecutionCoordinator( self, @@ -473,7 +429,7 @@ async def _emit_error( error_message = Message( message_id=str(uuid.uuid4()), role=Role.ROLE_AGENT, - parts=[make_text_part(message)], + parts=[Part(text=message)], task_id=task_id, context_id=context_id, ) @@ -491,7 +447,7 @@ async def _emit_error( task_id=task_id, context_id=context_id, artifact_id=f"{task_id}:error", - part=make_text_part(message), + part=Part(text=message), append=False, last_chunk=True, ) @@ -523,7 +479,7 @@ def _should_stream(self, context: RequestContext) -> bool: return True # JSON-RPC transport sets method in call context state. method = call_context.state.get("method") - return method == "message/stream" + return method == "SendStreamingMessage" async def _consume_opencode_stream( self, @@ -540,6 +496,8 @@ async def _consume_opencode_stream( directory: str | None = None, workspace_id: str | None = None, allow_structured_output: bool = True, + emit_session_metadata: bool = True, + emit_streaming_metadata: bool = True, ) -> None: await self._stream_runtime.consume( session_id=session_id, @@ -554,4 +512,6 @@ async def _consume_opencode_stream( directory=directory, workspace_id=workspace_id, allow_structured_output=allow_structured_output, + emit_session_metadata=emit_session_metadata, + emit_streaming_metadata=emit_streaming_metadata, ) diff --git a/src/opencode_a2a/execution/stream_runtime.py b/src/opencode_a2a/execution/stream_runtime.py index 2cd8c91..39dcb49 100644 --- a/src/opencode_a2a/execution/stream_runtime.py +++ b/src/opencode_a2a/execution/stream_runtime.py @@ -9,12 +9,13 @@ from a2a.server.events.event_queue import EventQueue from a2a.types import ( + Part, TaskState, TaskStatus, TaskStatusUpdateEvent, ) -from ..a2a_utils import make_data_part, make_text_part, part_kind, part_text, part_text_fallback +from ..a2a_utils import make_data_part from ..invocation import call_with_supported_kwargs from .event_helpers import _enqueue_artifact_update from .stream_events import ( @@ -76,6 +77,8 @@ async def consume( directory: str | None = None, workspace_id: str | None = None, allow_structured_output: bool = True, + emit_session_metadata: bool = True, + emit_streaming_metadata: bool = True, ) -> None: part_states: dict[str, _StreamPartState] = {} pending_deltas: defaultdict[str, list[_PendingDelta]] = defaultdict(list) @@ -84,23 +87,8 @@ async def consume( async def _emit_chunks(chunks: list[_NormalizedStreamChunk]) -> None: for chunk in chunks: - if not allow_structured_output and part_kind(chunk.part) == "data": - fallback_text = part_text_fallback(chunk.part) - if fallback_text is None: - continue - chunk = _NormalizedStreamChunk( - part=make_text_part(fallback_text), - content_key=fallback_text, - accumulate_content=False, - append=chunk.append, - block_type=chunk.block_type, - internal_source=chunk.internal_source, - shared_source=chunk.shared_source, - message_id=chunk.message_id, - role=chunk.role, - ) resolved_message_id = stream_state.resolve_message_id(chunk.message_id) - chunk_text = part_text(chunk.part) or "" + chunk_text = chunk.part.text if chunk.part.HasField("text") else "" if stream_state.should_drop_initial_user_echo( chunk_text, block_type=chunk.block_type, @@ -131,6 +119,7 @@ async def _emit_chunks(chunks: list[_NormalizedStreamChunk]) -> None: role=chunk.role, event_id=stream_state.build_event_id(sequence), sequence=sequence, + include_shared_stream_metadata=emit_streaming_metadata, ), ) logger.debug( @@ -180,6 +169,8 @@ async def _emit_interrupt_status( "sequence": sequence, }, interrupt=interrupt_metadata, + include_session_metadata=emit_session_metadata, + include_streaming_metadata=emit_streaming_metadata, ), ) ) @@ -208,6 +199,8 @@ async def _emit_progress_status( "sequence": sequence, }, progress=dict(progress), + include_session_metadata=emit_session_metadata, + include_streaming_metadata=emit_streaming_metadata, ), ) ) @@ -223,7 +216,7 @@ def _new_text_chunk( role: str | None, ) -> _NormalizedStreamChunk: return _NormalizedStreamChunk( - part=make_text_part(text), + part=Part(text=text), content_key=text, accumulate_content=True, append=append, @@ -245,8 +238,18 @@ def _new_data_chunk( message_id: str | None, role: str | None, ) -> _NormalizedStreamChunk: + part = make_data_part(data) + if not allow_structured_output: + fallback_text = json.dumps( + data, + ensure_ascii=True, + sort_keys=True, + separators=(",", ":"), + ) + part = Part(text=fallback_text) + content_key = fallback_text return _NormalizedStreamChunk( - part=make_data_part(dict(data)), + part=part, content_key=content_key, accumulate_content=False, append=append, diff --git a/src/opencode_a2a/execution/stream_state.py b/src/opencode_a2a/execution/stream_state.py index 0cbf671..c1ab899 100644 --- a/src/opencode_a2a/execution/stream_state.py +++ b/src/opencode_a2a/execution/stream_state.py @@ -230,7 +230,10 @@ def _build_stream_artifact_metadata( role: str | None = None, event_id: str | None = None, sequence: int | None = None, -) -> dict[str, Any]: + include_shared_stream_metadata: bool = True, +) -> dict[str, Any] | None: + if not include_shared_stream_metadata: + return None stream_meta: dict[str, Any] = { "block_type": block_type.value, "source": shared_source, @@ -255,22 +258,24 @@ def _build_output_metadata( progress: Mapping[str, Any] | None = None, interrupt: Mapping[str, Any] | None = None, opencode_private: Mapping[str, Any] | None = None, + include_session_metadata: bool = True, + include_streaming_metadata: bool = True, ) -> dict[str, Any] | None: metadata: dict[str, Any] = {} shared_meta: dict[str, Any] = {} - if session_id: + if include_session_metadata and session_id: session_meta: dict[str, Any] = {"id": session_id} if session_title is not None: session_meta["title"] = session_title shared_meta["session"] = session_meta - if usage is not None: + if include_streaming_metadata and usage is not None: shared_meta["usage"] = dict(usage) - if stream is not None: + if include_streaming_metadata and stream is not None: shared_meta["stream"] = dict(stream) - if progress is not None: + if include_streaming_metadata and progress is not None: shared_meta["progress"] = dict(progress) - if interrupt is not None: + if include_streaming_metadata and interrupt is not None: shared_meta["interrupt"] = dict(interrupt) if shared_meta: metadata["shared"] = shared_meta diff --git a/src/opencode_a2a/execution/tool_error_mapping.py b/src/opencode_a2a/execution/tool_error_mapping.py index 6e79e15..b8ce481 100644 --- a/src/opencode_a2a/execution/tool_error_mapping.py +++ b/src/opencode_a2a/execution/tool_error_mapping.py @@ -49,7 +49,7 @@ def build_tool_error( def map_a2a_tool_exception(exc: Exception) -> dict[str, Any]: if isinstance(exc, (A2AError, SDKClientError, httpx.TimeoutException, httpx.TransportError)): - return map_a2a_tool_exception(map_operation_error("message/send", exc)) + return map_a2a_tool_exception(map_operation_error("SendMessage", exc)) if isinstance(exc, A2AAuthenticationError): return _build_client_error_payload( exc, @@ -135,6 +135,3 @@ def _build_client_error_meta(exc: A2AClientError) -> dict[str, Any] | None: if exc.code is not None: error_meta["rpc_code"] = exc.code return error_meta or None - - -__all__ = ["build_tool_error", "map_a2a_tool_exception"] diff --git a/src/opencode_a2a/execution/tool_orchestration.py b/src/opencode_a2a/execution/tool_orchestration.py index da1f05e..67b6cf9 100644 --- a/src/opencode_a2a/execution/tool_orchestration.py +++ b/src/opencode_a2a/execution/tool_orchestration.py @@ -4,6 +4,8 @@ import uuid from typing import Any +from a2a.types import StreamResponse, TaskState + from ..client.payload_text import extract_text from .tool_error_mapping import build_tool_error, map_a2a_tool_exception @@ -80,7 +82,7 @@ async def handle_a2a_call_tool( } try: - event = None + event: StreamResponse | None = None result_text = "" async with a2a_client_manager.borrow_client(agent_url) as client: async for current_event in client.send_message(message): @@ -89,8 +91,6 @@ async def handle_a2a_call_tool( if extracted: result_text = merge_streamed_tool_output(result_text, extracted) - from a2a.types import Task - if result_text: return { "call_id": call_id, @@ -98,25 +98,16 @@ async def handle_a2a_call_tool( "output": result_text, } - if isinstance(event, Task): - result_text = "" - if event.status and event.status.message: - for part_obj in event.status.message.parts: - root = getattr(part_obj, "root", part_obj) - text_val = getattr(root, "text", "") - if text_val: - result_text += str(text_val) - return { - "call_id": call_id, - "tool": tool_name, - "output": result_text or "Task completed.", - } - - if isinstance(event, tuple) and len(event) > 0 and isinstance(event[0], Task): + if ( + event is not None + and event.HasField("task") + and event.task.HasField("status") + and event.task.status.state == TaskState.TASK_STATE_COMPLETED + ): return { "call_id": call_id, "tool": tool_name, - "output": "Task completed (streaming).", + "output": "Task completed.", } return { diff --git a/src/opencode_a2a/extension_negotiation.py b/src/opencode_a2a/extension_negotiation.py new file mode 100644 index 0000000..33a144e --- /dev/null +++ b/src/opencode_a2a/extension_negotiation.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +from collections.abc import Iterable, Mapping +from dataclasses import dataclass +from typing import Any + +from a2a.server.context import ServerCallContext +from a2a.types import Artifact, Message, Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent +from google.protobuf.json_format import MessageToDict +from google.protobuf.message import Message as ProtoMessage + +from .a2a_utils import clone_proto +from .contracts.extensions import ( + INTERRUPT_CALLBACK_EXTENSION_URI, + INTERRUPT_CALLBACK_METHODS, + INTERRUPT_RECOVERY_EXTENSION_URI, + INTERRUPT_RECOVERY_METHODS, + MODEL_SELECTION_EXTENSION_URI, + PROVIDER_DISCOVERY_EXTENSION_URI, + PROVIDER_DISCOVERY_METHODS, + SESSION_BINDING_EXTENSION_URI, + SESSION_MANAGEMENT_EXTENSION_URI, + SESSION_METHODS, + STREAMING_EXTENSION_URI, + WORKSPACE_CONTROL_EXTENSION_URI, + WORKSPACE_CONTROL_METHODS, +) + +_STREAMING_SHARED_METADATA_KEYS = frozenset({"stream", "progress", "interrupt", "usage"}) + +JSONRPC_EXTENSION_URI_BY_METHOD: dict[str, str] = { + **{method: SESSION_MANAGEMENT_EXTENSION_URI for method in SESSION_METHODS.values()}, + **{method: PROVIDER_DISCOVERY_EXTENSION_URI for method in PROVIDER_DISCOVERY_METHODS.values()}, + **{method: INTERRUPT_RECOVERY_EXTENSION_URI for method in INTERRUPT_RECOVERY_METHODS.values()}, + **{method: INTERRUPT_CALLBACK_EXTENSION_URI for method in INTERRUPT_CALLBACK_METHODS.values()}, + **{method: WORKSPACE_CONTROL_EXTENSION_URI for method in WORKSPACE_CONTROL_METHODS.values()}, +} + + +@dataclass(frozen=True) +class ExtensionRequirement: + extension_uri: str + field: str + + +def requested_extensions_from_call_context( + call_context: ServerCallContext | None, +) -> frozenset[str]: + if call_context is None: + return frozenset() + return frozenset(value.strip() for value in call_context.requested_extensions if value.strip()) + + +def filter_negotiated_extensions_from_payload( + payload: Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent, + requested_extensions: Iterable[str], +) -> Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent: + requested = frozenset(value for value in requested_extensions if value) + if isinstance(payload, Task): + updated = clone_proto(payload) + _set_filtered_metadata(updated, requested) + if updated.status.HasField("message"): + _set_filtered_metadata(updated.status.message, requested) + for history_item in updated.history: + _set_filtered_metadata(history_item, requested) + for artifact in updated.artifacts: + _set_filtered_metadata(artifact, requested) + return updated + if isinstance(payload, TaskStatusUpdateEvent): + updated = clone_proto(payload) + _set_filtered_metadata(updated, requested) + if updated.status.HasField("message"): + _set_filtered_metadata(updated.status.message, requested) + return updated + if isinstance(payload, TaskArtifactUpdateEvent): + updated = clone_proto(payload) + _set_filtered_metadata(updated.artifact, requested) + return updated + return payload + + +def _set_filtered_metadata( + proto: Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Artifact | Message, + requested_extensions: frozenset[str], +) -> None: + metadata = getattr(proto, "metadata", None) + metadata_dict: dict[str, Any] | None + if isinstance(metadata, ProtoMessage): + metadata_dict = MessageToDict(metadata, preserving_proto_field_name=True) + elif isinstance(metadata, Mapping): + metadata_dict = dict(metadata) + else: + metadata_dict = None + if not metadata_dict: + proto.ClearField("metadata") + return + filtered_metadata: dict[str, Any] = dict(metadata_dict) + shared_metadata = filtered_metadata.get("shared") + if isinstance(shared_metadata, Mapping): + filtered_shared = dict(shared_metadata) + if ( + SESSION_BINDING_EXTENSION_URI not in requested_extensions + and STREAMING_EXTENSION_URI not in requested_extensions + ): + filtered_shared.pop("session", None) + if MODEL_SELECTION_EXTENSION_URI not in requested_extensions: + filtered_shared.pop("model", None) + if STREAMING_EXTENSION_URI not in requested_extensions: + for key in _STREAMING_SHARED_METADATA_KEYS: + filtered_shared.pop(key, None) + if filtered_shared: + filtered_metadata["shared"] = filtered_shared + else: + filtered_metadata.pop("shared", None) + proto.ClearField("metadata") + if filtered_metadata: + proto.metadata.update(filtered_metadata) diff --git a/src/opencode_a2a/jsonrpc/application.py b/src/opencode_a2a/jsonrpc/application.py index 41d384f..650e92e 100644 --- a/src/opencode_a2a/jsonrpc/application.py +++ b/src/opencode_a2a/jsonrpc/application.py @@ -1,9 +1,8 @@ from __future__ import annotations import logging -from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping +from collections.abc import AsyncGenerator, Awaitable, Callable from dataclasses import replace -from functools import partial from typing import Any, cast from a2a.server.events import Event @@ -18,30 +17,18 @@ from starlette.requests import Request from starlette.responses import Response -from ..a2a_protocol import LEGACY_JSONRPC_METHOD_TO_V1_METHOD, V1_JSONRPC_METHOD_TO_LEGACY_METHOD +from ..extension_negotiation import ( + requested_extensions_from_call_context, +) from ..opencode_upstream_client import OpencodeUpstreamClient from .dispatch import ( ExtensionHandlerContext, build_extension_method_registry, ) from .error_responses import ( - adapt_jsonrpc_error_for_protocol, + adapt_jsonrpc_error, invalid_params_error, method_not_supported_error, - protocol_uses_v1_error_format, -) -from .methods import ( - SESSION_CONTEXT_PREFIX, - _extract_provider_catalog, - _normalize_model_summaries, - _normalize_permission_reply, - _normalize_provider_summaries, - _parse_question_answers, - _PromptAsyncValidationError, - _validate_command_request_payload, - _validate_prompt_async_format, - _validate_prompt_async_part, - _validate_shell_request_payload, ) from .models import JSONRPCError, JSONRPCRequest @@ -55,87 +42,6 @@ } ) -__all__ = [ - "SESSION_CONTEXT_PREFIX", - "_extract_provider_catalog", - "_normalize_model_summaries", - "_normalize_permission_reply", - "_normalize_provider_summaries", - "_parse_question_answers", - "_PromptAsyncValidationError", - "_validate_command_request_payload", - "_validate_prompt_async_format", - "_validate_prompt_async_part", - "_validate_shell_request_payload", -] - - -def _normalize_core_message_role(value: Any) -> Any: - if not isinstance(value, str): - return value - normalized = value.strip().lower() - if normalized == "user": - return "ROLE_USER" - if normalized == "agent": - return "ROLE_AGENT" - return value - - -def _normalize_core_message_part(part: Any) -> Any: - if not isinstance(part, Mapping): - return part - - normalized = dict(part) - kind = normalized.pop("kind", normalized.pop("type", None)) - if kind in {None, "text", "data"}: - return normalized - - if kind != "file": - return normalized - - file_value = normalized.pop("file", None) - if isinstance(file_value, Mapping): - mapped = dict(normalized) - raw_value = file_value.get("bytes") - url_value = file_value.get("uri") - if isinstance(raw_value, str) and raw_value: - mapped["raw"] = raw_value - elif isinstance(url_value, str) and url_value: - mapped["url"] = url_value - filename = file_value.get("name") - if isinstance(filename, str) and filename.strip(): - mapped["filename"] = filename - media_type = ( - file_value.get("mimeType") or file_value.get("mime_type") or file_value.get("mediaType") - ) - if isinstance(media_type, str) and media_type.strip(): - mapped["mediaType"] = media_type - return mapped - - return normalized - - -def _normalize_core_message_payload(message: Any) -> Any: - if not isinstance(message, Mapping): - return message - - normalized = dict(message) - normalized["role"] = _normalize_core_message_role(normalized.get("role")) - parts = normalized.get("parts") - if isinstance(parts, list): - normalized["parts"] = [_normalize_core_message_part(part) for part in parts] - return normalized - - -def _normalize_core_request_params(method: str, params: Any) -> Any: - if not isinstance(params, Mapping): - return params - - normalized = dict(params) - if method in {"SendMessage", "SendStreamingMessage"}: - normalized["message"] = _normalize_core_message_payload(normalized.get("message")) - return normalized - class OpencodeSessionManagementJSONRPCApplication(JsonRpcDispatcher): """Dispatch OpenCode extension methods on top of the SDK JSON-RPC surface.""" @@ -146,7 +52,6 @@ def __init__( http_handler, upstream_client: OpencodeUpstreamClient, methods: dict[str, str], - protocol_version: str, supported_methods: list[str], directory_resolver: Callable[[str | None], str | None] | None = None, session_claim: Callable[..., Awaitable[bool]] | None = None, @@ -190,7 +95,6 @@ def __init__( self._method_reply_permission = methods["reply_permission"] self._method_reply_question = methods["reply_question"] self._method_reject_question = methods["reject_question"] - self._protocol_version = protocol_version self._supported_methods = list(supported_methods) missing_control_hooks = [ name @@ -245,7 +149,6 @@ def __init__( method_reply_permission=self._method_reply_permission, method_reply_question=self._method_reply_question, method_reject_question=self._method_reject_question, - protocol_version=self._protocol_version, supported_methods=tuple(self._supported_methods), directory_resolver=self._directory_resolver, session_claim=self._session_claim, @@ -271,10 +174,8 @@ def _generate_protocol_error_response( self, request_id: str | int | None, error: JSONRPCError | A2AError, - *, - protocol_version: str, ) -> JSONResponse: - adapted = adapt_jsonrpc_error_for_protocol(protocol_version, error) + adapted = adapt_jsonrpc_error(error) if isinstance(adapted, A2AError): error_payload = { "code": JSON_RPC_ERROR_CODE_MAP.get(type(adapted), -32603), @@ -370,33 +271,12 @@ async def _handle_core_request( request: Request, body: dict[str, Any], base_request: JSONRPCRequest, - *, - protocol_version: str, ) -> Response: - if ( - base_request.method in V1_JSONRPC_METHOD_TO_LEGACY_METHOD - and not protocol_uses_v1_error_format(protocol_version) - ): - if base_request.id is None: - return Response(status_code=204) - return self._generate_protocol_error_response( - base_request.id, - method_not_supported_error( - method=base_request.method, - supported_methods=self._supported_methods, - protocol_version=protocol_version, - ), - protocol_version=protocol_version, - ) - canonical_method = LEGACY_JSONRPC_METHOD_TO_V1_METHOD.get( - base_request.method, - base_request.method, - ) + canonical_method = base_request.method if canonical_method in _PUSH_NOTIFICATION_METHODS: return self._generate_protocol_error_response( base_request.id, UnsupportedOperationError(), - protocol_version=protocol_version, ) if canonical_method == "GetExtendedAgentCard": if base_request.id is None: @@ -408,7 +288,6 @@ async def _handle_core_request( UnsupportedOperationError( message="The agent does not support authenticated extended cards" ), - protocol_version=protocol_version, ) return self._jsonrpc_success_response( base_request.id, @@ -423,20 +302,16 @@ async def _handle_core_request( method_not_supported_error( method=base_request.method, supported_methods=self._supported_methods, - protocol_version=protocol_version, ), - protocol_version=protocol_version, ) try: params = body.get("params", {}) - normalized_params = _normalize_core_request_params(canonical_method, params) - specific_request = ParseDict(normalized_params, model_class()) + specific_request = ParseDict(params, model_class()) except Exception as exc: return self._generate_protocol_error_response( base_request.id, invalid_params_error(str(exc)), - protocol_version=protocol_version, ) call_context = self._context_builder.build(request) @@ -460,16 +335,10 @@ async def _handle_core_request( return self._generate_protocol_error_response( base_request.id, exc, - protocol_version=protocol_version, ) async def handle_requests(self, request: Request) -> Response: request_id: str | int | None = None - negotiated_protocol_version = getattr( - request.state, - "a2a_protocol_version", - self._protocol_version, - ) try: body = await request.json() if isinstance(body, dict): @@ -486,7 +355,26 @@ async def handle_requests(self, request: Request) -> Response: request, body, base_request, - protocol_version=negotiated_protocol_version, + ) + + call_context = self._context_builder.build(request) + requested_extensions = requested_extensions_from_call_context(call_context) + if extension_spec.extension_uri not in requested_extensions: + return self._generate_protocol_error_response( + base_request.id, + UnsupportedOperationError( + message=( + f"Method {base_request.method} requires explicit A2A extension " + "negotiation via the A2A-Extensions header." + ), + data={ + "type": "EXTENSION_NEGOTIATION_REQUIRED", + "method": base_request.method, + "required_extensions": [extension_spec.extension_uri], + "requested_extensions": sorted(requested_extensions), + "header": "A2A-Extensions", + }, + ), ) params = base_request.params or {} @@ -494,15 +382,10 @@ async def handle_requests(self, request: Request) -> Response: return self._generate_protocol_error_response( base_request.id, invalid_params_error("params must be an object"), - protocol_version=negotiated_protocol_version, ) request_context = replace( self._extension_handler_context, - protocol_version=negotiated_protocol_version, - error_response=partial( - self._generate_protocol_error_response, - protocol_version=negotiated_protocol_version, - ), + error_response=self._generate_protocol_error_response, ) return await extension_spec.handler( request_context, diff --git a/src/opencode_a2a/jsonrpc/dispatch.py b/src/opencode_a2a/jsonrpc/dispatch.py index 075dbca..91e9872 100644 --- a/src/opencode_a2a/jsonrpc/dispatch.py +++ b/src/opencode_a2a/jsonrpc/dispatch.py @@ -4,22 +4,25 @@ from dataclasses import dataclass from typing import Any, TypeAlias -from a2a.server.routes.jsonrpc_dispatcher import JsonRpcDispatcher from a2a.utils.errors import A2AError from fastapi.responses import JSONResponse from starlette.requests import Request from starlette.responses import Response -from ..a2a_protocol import V1_JSONRPC_METHOD_TO_LEGACY_METHOD +from ..a2a_protocol import CORE_JSONRPC_METHODS as DECLARED_CORE_JSONRPC_METHODS +from ..contracts.extensions import ( + INTERRUPT_CALLBACK_EXTENSION_URI, + INTERRUPT_RECOVERY_EXTENSION_URI, + PROVIDER_DISCOVERY_EXTENSION_URI, + SESSION_MANAGEMENT_EXTENSION_URI, + WORKSPACE_CONTROL_EXTENSION_URI, +) from ..opencode_upstream_client import OpencodeUpstreamClient from .models import JSONRPCError, JSONRPCRequest # Delegate all SDK-owned JSON-RPC methods to the base app, then let the local # extension registry override only the OpenCode-specific methods. -CORE_JSONRPC_METHODS = frozenset( - V1_JSONRPC_METHOD_TO_LEGACY_METHOD.get(method, method) - for method in JsonRpcDispatcher.METHOD_TO_MODEL -) +CORE_JSONRPC_METHODS = frozenset(DECLARED_CORE_JSONRPC_METHODS) ErrorResponseFactory: TypeAlias = Callable[[str | int | None, JSONRPCError | A2AError], Response] SuccessResponseFactory: TypeAlias = Callable[[str | int, Any], JSONResponse] @@ -68,7 +71,6 @@ class ExtensionHandlerContext: method_reply_permission: str method_reply_question: str method_reject_question: str - protocol_version: str supported_methods: tuple[str, ...] directory_resolver: Callable[[str | None], str | None] session_claim: SessionClaimFunc @@ -82,6 +84,7 @@ class ExtensionHandlerContext: class ExtensionMethodSpec: name: str methods: frozenset[str] + extension_uri: str handler: ExtensionHandlerFunc @@ -163,6 +166,7 @@ def build_extension_method_registry( ExtensionMethodSpec( name="session_lifecycle", methods=frozenset(session_item_methods), + extension_uri=SESSION_MANAGEMENT_EXTENSION_URI, handler=handle_session_lifecycle_request, ), ExtensionMethodSpec( @@ -173,6 +177,7 @@ def build_extension_method_registry( context.method_get_session_messages, } ), + extension_uri=SESSION_MANAGEMENT_EXTENSION_URI, handler=handle_session_query_request, ), ExtensionMethodSpec( @@ -183,6 +188,7 @@ def build_extension_method_registry( context.method_list_models, } ), + extension_uri=PROVIDER_DISCOVERY_EXTENSION_URI, handler=handle_provider_discovery_request, ), ExtensionMethodSpec( @@ -193,16 +199,19 @@ def build_extension_method_registry( context.method_list_questions, } ), + extension_uri=INTERRUPT_RECOVERY_EXTENSION_URI, handler=handle_interrupt_query_request, ), ExtensionMethodSpec( name="workspace_control", methods=frozenset(workspace_control_methods), + extension_uri=WORKSPACE_CONTROL_EXTENSION_URI, handler=handle_workspace_control_request, ), ExtensionMethodSpec( name="session_control", methods=frozenset(session_control_methods), + extension_uri=SESSION_MANAGEMENT_EXTENSION_URI, handler=handle_session_control_request, ), ExtensionMethodSpec( @@ -214,6 +223,7 @@ def build_extension_method_registry( context.method_reject_question, } ), + extension_uri=INTERRUPT_CALLBACK_EXTENSION_URI, handler=handle_interrupt_callback_request, ), ) diff --git a/src/opencode_a2a/jsonrpc/error_responses.py b/src/opencode_a2a/jsonrpc/error_responses.py index 0fc328a..df8a736 100644 --- a/src/opencode_a2a/jsonrpc/error_responses.py +++ b/src/opencode_a2a/jsonrpc/error_responses.py @@ -4,9 +4,13 @@ from collections.abc import Mapping from typing import Any -from a2a.utils.errors import JSON_RPC_ERROR_CODE_MAP, A2AError, InvalidParamsError +from a2a.utils.errors import ( + JSON_RPC_ERROR_CODE_MAP, + A2AError, + InvalidParamsError, +) -from ..protocol_versions import normalize_protocol_version +from ..protocol_versions import A2A_PROTOCOL_VERSION from .models import JSONRPCError A2A_ERROR_DOMAIN = "a2a-protocol.org" @@ -21,12 +25,6 @@ STANDARD_JSONRPC_ERROR_CODES = frozenset(STANDARD_JSONRPC_ERROR_MESSAGES) -def protocol_uses_v1_error_format(protocol_version: str | None) -> bool: - if protocol_version is None: - return False - return normalize_protocol_version(protocol_version).startswith("1.") - - def _to_upper_snake_case(name: str) -> str: normalized: list[str] = [] previous_was_lower = False @@ -110,13 +108,7 @@ def _metadata_from_error(error: object) -> dict[str, Any]: return {str(key): value for key, value in data.items() if key != "type"} -def adapt_jsonrpc_error_for_protocol( - protocol_version: str, - error: JSONRPCError | A2AError, -) -> JSONRPCError | A2AError: - if not protocol_uses_v1_error_format(protocol_version): - return error - +def adapt_jsonrpc_error(error: JSONRPCError | A2AError) -> JSONRPCError | A2AError: reason_source: object = error if isinstance(error, A2AError): root_error = JSONRPCError( @@ -164,17 +156,12 @@ def adapt_jsonrpc_error_for_protocol( def build_http_error_body( *, - protocol_version: str, status_code: int, status: str, message: str, - legacy_payload: dict[str, Any], reason: str | None = None, metadata: Mapping[str, Any] | None = None, ) -> dict[str, Any]: - if not protocol_uses_v1_error_format(protocol_version): - return legacy_payload - details: list[dict[str, Any]] = [] if reason is not None: details.append(_build_error_info_detail(reason=reason, metadata=metadata)) @@ -203,7 +190,6 @@ def method_not_supported_error( *, method: str, supported_methods: list[str], - protocol_version: str, ) -> JSONRPCError: return JSONRPCError( code=-32601, @@ -212,7 +198,7 @@ def method_not_supported_error( "type": "METHOD_NOT_SUPPORTED", "method": method, "supported_methods": supported_methods, - "protocol_version": protocol_version, + "protocol_version": A2A_PROTOCOL_VERSION, }, ) @@ -370,23 +356,3 @@ def upstream_payload_error( if request_id is not None: data["request_id"] = request_id return JSONRPCError(code=code, message="Upstream OpenCode payload mismatch", data=data) - - -__all__ = [ - "A2A_ERROR_DOMAIN", - "GOOGLE_RPC_ERROR_INFO_TYPE", - "adapt_jsonrpc_error_for_protocol", - "authorization_forbidden_error", - "build_http_error_body", - "interrupt_not_found_error", - "interrupt_type_mismatch_error", - "invalid_params_error", - "method_not_supported_error", - "protocol_uses_v1_error_format", - "session_forbidden_error", - "session_not_found_error", - "upstream_http_error", - "upstream_payload_error", - "upstream_unreachable_error", - "version_not_supported_error", -] diff --git a/src/opencode_a2a/jsonrpc/methods.py b/src/opencode_a2a/jsonrpc/methods.py index 251e1ab..fb538c4 100644 --- a/src/opencode_a2a/jsonrpc/methods.py +++ b/src/opencode_a2a/jsonrpc/methods.py @@ -2,9 +2,9 @@ from typing import Any, cast -from a2a.types import Message, Role, Task, TaskState, TaskStatus +from a2a.types import Message, Part, Role, Task, TaskState, TaskStatus +from google.protobuf.json_format import MessageToDict -from ..a2a_utils import make_text_part, proto_to_dict from ..contracts.extensions import ( COMMAND_REQUEST_ALLOWED_FIELDS, PROMPT_ASYNC_REQUEST_ALLOWED_FIELDS, @@ -15,12 +15,6 @@ SESSION_CONTEXT_PREFIX = "ctx:opencode-session:" -def _jsonrpc_role_name(role: Role) -> str: - if role == Role.ROLE_USER: - return "user" - return "agent" - - class _PromptAsyncValidationError(ValueError): def __init__(self, *, field: str, message: str) -> None: super().__init__(message) @@ -350,7 +344,7 @@ def _as_a2a_session_task(session: Any) -> dict[str, Any] | None: status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), metadata={"shared": {"session": {"id": session_id, "title": title}}}, ) - return proto_to_dict(task) + return cast(dict[str, Any], MessageToDict(task)) def _as_a2a_message(session_id: str, item: Any) -> dict[str, Any] | None: @@ -383,13 +377,11 @@ def _as_a2a_message(session_id: str, item: Any) -> dict[str, Any] | None: msg = Message( message_id=message_id, role=role, - parts=[make_text_part(text)], + parts=[Part(text=text)], context_id=context_id, metadata={"shared": {"session": {"id": session_id}}}, ) - message = proto_to_dict(msg) - message["role"] = _jsonrpc_role_name(role) - return message + return cast(dict[str, Any], MessageToDict(msg)) def _extract_raw_items(raw_result: Any, *, kind: str) -> list[Any]: diff --git a/src/opencode_a2a/metadata_access.py b/src/opencode_a2a/metadata_access.py index 9804fe9..6ea65b2 100644 --- a/src/opencode_a2a/metadata_access.py +++ b/src/opencode_a2a/metadata_access.py @@ -3,10 +3,9 @@ from collections.abc import Iterable, Mapping from typing import Any +from google.protobuf.json_format import MessageToDict from google.protobuf.message import Message as ProtoMessage -from .a2a_utils import proto_to_dict - def extract_namespaced_value( source: Mapping[str, Any] | None, @@ -49,7 +48,7 @@ def extract_first_namespaced_string( def _normalize_mapping(value: Any) -> Mapping[str, Any] | None: if isinstance(value, ProtoMessage): - normalized = proto_to_dict(value) + normalized = MessageToDict(value) return normalized if isinstance(normalized, Mapping) else None if isinstance(value, Mapping): try: diff --git a/src/opencode_a2a/output_modes.py b/src/opencode_a2a/output_modes.py index 844b338..b5ef97a 100644 --- a/src/opencode_a2a/output_modes.py +++ b/src/opencode_a2a/output_modes.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import json from collections.abc import Collection, Iterable, Mapping from typing import Any, cast @@ -15,23 +16,16 @@ TaskState, TaskStatusUpdateEvent, ) +from google.protobuf.json_format import MessageToDict from google.protobuf.message import Message as ProtoMessage from .a2a_utils import ( clone_proto, - make_text_part, - part_is_data, - part_is_file, - part_is_text, - proto_to_dict, replace_artifact_event_artifact, replace_artifact_parts, replace_message_parts, replace_status_event_message, ) -from .a2a_utils import ( - part_text_fallback as _part_text_fallback, -) OUTPUT_NEGOTIATION_METADATA_KEY = "output_negotiation" OUTPUT_NEGOTIATION_ACCEPTED_OUTPUT_MODES_FIELD = "accepted_output_modes" @@ -82,12 +76,6 @@ def accepts_output_mode( return accepted_output_modes is None or media_type in accepted_output_modes -def part_text_fallback(part: Any) -> str | None: - if isinstance(part, Part): - return _part_text_fallback(part) - return None - - def build_output_negotiation_metadata( accepted_output_modes: Iterable[str] | None, ) -> dict[str, Any] | None: @@ -105,7 +93,7 @@ def build_output_negotiation_metadata( def _normalize_metadata_mapping(metadata: Any) -> dict[str, Any]: if isinstance(metadata, ProtoMessage): - normalized = proto_to_dict(metadata) + normalized = MessageToDict(metadata) return normalized if isinstance(normalized, dict) else {} if isinstance(metadata, Mapping): return dict(metadata) @@ -394,17 +382,28 @@ def _filter_parts( filtered.append(part) continue if accepts_output_mode(accepted_output_modes, _TEXT_PLAIN_MEDIA_TYPE): - fallback_text = part_text_fallback(part) - if fallback_text is not None: - filtered.append(make_text_part(fallback_text)) + if part.HasField("text"): + filtered.append(Part(text=part.text)) + continue + if part.HasField("data"): + filtered.append( + Part( + text=json.dumps( + MessageToDict(part.data), + ensure_ascii=True, + sort_keys=True, + separators=(",", ":"), + ) + ) + ) return filtered def _part_media_type(part: Part) -> str | None: - if part_is_text(part): + if part.HasField("text"): return _TEXT_PLAIN_MEDIA_TYPE - if part_is_data(part): + if part.HasField("data"): return _APPLICATION_JSON_MEDIA_TYPE - if part_is_file(part): + if part.HasField("raw") or part.HasField("url"): return part.media_type or "application/octet-stream" return None diff --git a/src/opencode_a2a/parts/mapping.py b/src/opencode_a2a/parts/mapping.py index ced4a44..c60226f 100644 --- a/src/opencode_a2a/parts/mapping.py +++ b/src/opencode_a2a/parts/mapping.py @@ -2,7 +2,9 @@ import base64 from collections.abc import Sequence -from typing import Any, Literal, TypedDict +from typing import Literal, TypedDict + +from a2a.types import Part class UnsupportedA2AInputError(ValueError): @@ -24,37 +26,32 @@ class OpencodeFileInputPart(TypedDict, total=False): OpencodeInputPart = OpencodeTextInputPart | OpencodeFileInputPart -def extract_text_from_a2a_parts(parts: Any) -> str: - normalized_parts = _normalize_parts(parts) - if normalized_parts is None: +def extract_text_from_a2a_parts(parts: Sequence[Part] | None) -> str: + if not parts: return "" texts: list[str] = [] - for part in normalized_parts: - if _part_kind(part) != "text": - continue - text = _part_text_value(part) - if isinstance(text, str): - texts.append(text) + for part in parts: + if part.HasField("text") and part.text: + texts.append(part.text) return "\n".join(texts).strip() -def summarize_a2a_parts(parts: Any) -> str | None: +def summarize_a2a_parts(parts: Sequence[Part] | None) -> str | None: text = extract_text_from_a2a_parts(parts) if text: return text[:80] - normalized_parts = _normalize_parts(parts) - if normalized_parts is None: + if not parts: return None filenames: list[str] = [] - for part in normalized_parts: - if _part_kind(part) != "file": + for part in parts: + if not _is_file_part(part): continue - name = _part_filename(part) - if isinstance(name, str) and name.strip(): - filenames.append(name.strip()) + name = _normalize_string(part.filename) + if name: + filenames.append(name) else: filenames.append("file") @@ -65,54 +62,50 @@ def summarize_a2a_parts(parts: Any) -> str | None: return ", ".join(filenames[:3])[:80] -def map_a2a_parts_to_opencode_parts(parts: Any) -> list[OpencodeInputPart]: - normalized_parts = _normalize_parts(parts) - if normalized_parts is None: +def map_a2a_parts_to_opencode_parts(parts: Sequence[Part] | None) -> list[OpencodeInputPart]: + if not parts: return [] mapped: list[OpencodeInputPart] = [] - for index, part in enumerate(normalized_parts): - kind = _part_kind(part) - - if kind == "text": - text = _part_text_value(part) - if isinstance(text, str): - mapped.append({"type": "text", "text": text}) + for index, part in enumerate(parts): + if part.HasField("text") and part.text: + mapped.append({"type": "text", "text": part.text}) continue - if kind == "file": + if _is_file_part(part): mapped.append(_map_file_part(part, index=index)) continue - if kind == "data": + if part.HasField("data"): raise UnsupportedA2AInputError( - f"request.parts[{index}] DataPart input is not supported; use TextPart or FilePart." + "request.parts[" + f"{index}" + "] structured data is not supported; use text, raw, or url parts." ) raise UnsupportedA2AInputError( - f"request.parts[{index}] is not supported; only TextPart and FilePart are accepted." + f"request.parts[{index}] is not supported; only text, raw, or url parts are accepted." ) return mapped -def _map_file_part(part: Any, *, index: int) -> OpencodeFileInputPart: - raw_bytes = getattr(part, "raw", None) - url = _normalize_string(getattr(part, "url", None)) - if isinstance(raw_bytes, bytes) and raw_bytes: - mime = _normalize_string(getattr(part, "media_type", None)) or "application/octet-stream" - name = _normalize_string(getattr(part, "filename", None)) +def _map_file_part(part: Part, *, index: int) -> OpencodeFileInputPart: + url = _normalize_string(part.url) if part.HasField("url") else None + if part.HasField("raw") and part.raw: + mime = _normalize_string(part.media_type) or "application/octet-stream" + name = _normalize_string(part.filename) mapped: OpencodeFileInputPart = { "type": "file", - "url": f"data:{mime};base64,{base64.b64encode(raw_bytes).decode('ascii')}", + "url": f"data:{mime};base64,{base64.b64encode(part.raw).decode('ascii')}", "mime": mime, } if name: mapped["filename"] = name return mapped if url: - mime = _normalize_string(getattr(part, "media_type", None)) or "application/octet-stream" - name = _normalize_string(getattr(part, "filename", None)) + mime = _normalize_string(part.media_type) or "application/octet-stream" + name = _normalize_string(part.filename) mapped = { "type": "file", "url": url, @@ -122,108 +115,19 @@ def _map_file_part(part: Any, *, index: int) -> OpencodeFileInputPart: mapped["filename"] = name return mapped - root = _unwrap_part_root(part) - file_value = getattr(root, "file", None) - if file_value is None: - raise UnsupportedA2AInputError( - f"request.parts[{index}] FilePart is missing the file payload." - ) - - mime = ( - _normalize_string( - getattr(file_value, "mime_type", None) or getattr(file_value, "mimeType", None) - ) - or "application/octet-stream" + raise UnsupportedA2AInputError( + f"request.parts[{index}] file input must contain either raw bytes or a url." ) - name = _normalize_string(getattr(file_value, "name", None)) - - bytes_value = _normalize_string(getattr(file_value, "bytes", None)) - if bytes_value: - mapped_from_bytes: OpencodeFileInputPart = { - "type": "file", - "url": f"data:{mime};base64,{bytes_value}", - "mime": mime, - } - if name: - mapped_from_bytes["filename"] = name - return mapped_from_bytes - uri = _normalize_string(getattr(file_value, "uri", None)) - if uri: - mapped = { - "type": "file", - "url": uri, - "mime": mime, - } - if name: - mapped["filename"] = name - return mapped - raise UnsupportedA2AInputError( - f"request.parts[{index}] FilePart must contain either bytes or uri." - ) +def _is_file_part(part: Part) -> bool: + if part.HasField("raw") and part.raw: + return True + return part.HasField("url") and _normalize_string(part.url) is not None -def _unwrap_part_root(part: Any) -> Any: - root = getattr(part, "root", None) - if root is not None: - return root - return part - - -def _part_kind(part: Any) -> str | None: - if isinstance(getattr(part, "text", None), str) and getattr(part, "text", None): - return "text" - if isinstance(getattr(part, "raw", None), bytes) and getattr(part, "raw", None): - return "file" - if _normalize_string(getattr(part, "url", None)): - return "file" - data = getattr(part, "data", None) - which_oneof = getattr(data, "WhichOneof", None) - if callable(which_oneof) and which_oneof("kind") is not None: - return "data" - - root = _unwrap_part_root(part) - kind = getattr(root, "kind", None) - if isinstance(kind, str): - return kind - if isinstance(getattr(root, "text", None), str): - return "text" - if getattr(root, "file", None) is not None: - return "file" - if getattr(root, "data", None) is not None: - return "data" - return None - - -def _part_text_value(part: Any) -> str | None: - text = getattr(part, "text", None) - if isinstance(text, str): - return text - root = _unwrap_part_root(part) - root_text = getattr(root, "text", None) - if isinstance(root_text, str): - return root_text - return None - - -def _part_filename(part: Any) -> str | None: - filename = _normalize_string(getattr(part, "filename", None)) - if filename: - return filename - root = _unwrap_part_root(part) - file_value = getattr(root, "file", None) - return _normalize_string(getattr(file_value, "name", None)) - - -def _normalize_string(value: Any) -> str | None: +def _normalize_string(value: object) -> str | None: if not isinstance(value, str): return None normalized = value.strip() return normalized if normalized else None - - -def _normalize_parts(parts: Any) -> list[Any] | None: - if not isinstance(parts, Sequence) or isinstance(parts, str | bytes | bytearray): - return None - return list(parts) diff --git a/src/opencode_a2a/protocol_versions.py b/src/opencode_a2a/protocol_versions.py index 9991d2e..ec15355 100644 --- a/src/opencode_a2a/protocol_versions.py +++ b/src/opencode_a2a/protocol_versions.py @@ -1,37 +1,23 @@ from __future__ import annotations import re -from collections.abc import Iterable -from dataclasses import dataclass _PROTOCOL_VERSION_PATTERN = re.compile(r"^(?P\d+)\.(?P\d+)(?:\.\d+)?$") +A2A_PROTOCOL_VERSION = "1.0" +A2A_SUPPORTED_PROTOCOL_VERSIONS = (A2A_PROTOCOL_VERSION,) class UnsupportedProtocolVersionError(ValueError): - def __init__( - self, - requested_version: str, - *, - supported_protocol_versions: tuple[str, ...], - default_protocol_version: str, - ) -> None: + def __init__(self, requested_version: str) -> None: self.requested_version = requested_version - self.supported_protocol_versions = supported_protocol_versions - self.default_protocol_version = default_protocol_version - supported_display = ", ".join(supported_protocol_versions) + self.supported_protocol_versions = A2A_SUPPORTED_PROTOCOL_VERSIONS + self.default_protocol_version = A2A_PROTOCOL_VERSION super().__init__( f"Unsupported A2A protocol version {requested_version!r}. " - f"Supported versions: {supported_display}." + f"Supported versions: {A2A_PROTOCOL_VERSION}." ) -@dataclass(frozen=True) -class NegotiatedProtocolVersion: - requested_version: str - negotiated_version: str - explicit: bool - - def normalize_protocol_version(value: str) -> str: normalized = value.strip() if not normalized: @@ -42,62 +28,21 @@ def normalize_protocol_version(value: str) -> str: return f"{match.group('major')}.{match.group('minor')}" -def normalize_protocol_versions(values: Iterable[str]) -> tuple[str, ...]: - normalized_versions: list[str] = [] - seen: set[str] = set() - for value in values: - normalized = normalize_protocol_version(str(value)) - if normalized in seen: - continue - seen.add(normalized) - normalized_versions.append(normalized) - if not normalized_versions: - raise ValueError("At least one supported protocol version must be declared.") - return tuple(normalized_versions) - - def negotiate_protocol_version( *, header_value: str | None, query_value: str | None, - default_protocol_version: str, - supported_protocol_versions: Iterable[str], -) -> NegotiatedProtocolVersion: - normalized_default = normalize_protocol_version(default_protocol_version) - normalized_supported = normalize_protocol_versions(supported_protocol_versions) - +) -> str: raw_header = (header_value or "").strip() raw_query = (query_value or "").strip() - explicit = bool(raw_header or raw_query) - raw_requested = raw_header or raw_query or normalized_default + raw_requested = raw_header or raw_query or A2A_PROTOCOL_VERSION try: normalized_requested = normalize_protocol_version(raw_requested) except ValueError as exc: - raise UnsupportedProtocolVersionError( - raw_requested, - supported_protocol_versions=normalized_supported, - default_protocol_version=normalized_default, - ) from exc - - if normalized_requested not in normalized_supported: - raise UnsupportedProtocolVersionError( - normalized_requested, - supported_protocol_versions=normalized_supported, - default_protocol_version=normalized_default, - ) - - return NegotiatedProtocolVersion( - requested_version=normalized_requested, - negotiated_version=normalized_requested, - explicit=explicit, - ) + raise UnsupportedProtocolVersionError(raw_requested) from exc + if normalized_requested != A2A_PROTOCOL_VERSION: + raise UnsupportedProtocolVersionError(normalized_requested) -__all__ = [ - "NegotiatedProtocolVersion", - "UnsupportedProtocolVersionError", - "negotiate_protocol_version", - "normalize_protocol_version", - "normalize_protocol_versions", -] + return normalized_requested diff --git a/src/opencode_a2a/server/agent_card.py b/src/opencode_a2a/server/agent_card.py index 8df0662..1009891 100644 --- a/src/opencode_a2a/server/agent_card.py +++ b/src/opencode_a2a/server/agent_card.py @@ -40,8 +40,9 @@ build_wire_contract_params, build_workspace_control_extension_params, ) -from ..jsonrpc.application import SESSION_CONTEXT_PREFIX +from ..jsonrpc.methods import SESSION_CONTEXT_PREFIX from ..profile.runtime import RuntimeProfile, build_runtime_profile +from ..protocol_versions import A2A_PROTOCOL_VERSION _CHAT_INPUT_MODES = ["text/plain", "application/octet-stream"] _CHAT_OUTPUT_MODES = ["text/plain", "application/json"] @@ -119,9 +120,9 @@ def _build_agent_card_description( summary = ( "Supports HTTP+JSON and JSON-RPC transports, streaming-first A2A messaging " - "(message/send, message/stream), authenticated extended Agent Card " - "(agent/getAuthenticatedExtendedCard), task APIs (tasks/get, tasks/cancel, " - "tasks/resubscribe; SDK-owned push notification config surfaces remain " + "(SendMessage, SendStreamingMessage), authenticated extended Agent Card " + "(GetExtendedAgentCard), task APIs (GetTask, CancelTask, " + "SubscribeToTask; SDK-owned push notification config surfaces remain " "exposed but currently return unsupported; REST mappings include GET " "/v1/tasks and GET /v1/tasks/{id}:subscribe), shared " "session-binding/model-selection/streaming contracts, provider-private " @@ -211,7 +212,6 @@ def _build_workspace_control_skill_examples(*, capability_snapshot) -> list[str] def _build_agent_extensions( *, - settings: Settings, runtime_profile: RuntimeProfile, include_detailed_contracts: bool, ) -> list[AgentExtension]: @@ -239,16 +239,12 @@ def _build_agent_extensions( runtime_profile=runtime_profile, ) compatibility_profile_params = build_compatibility_profile_params( - protocol_version=settings.a2a_protocol_version, + protocol_version=A2A_PROTOCOL_VERSION, runtime_profile=runtime_profile, - supported_protocol_versions=settings.a2a_supported_protocol_versions, - default_protocol_version=settings.a2a_protocol_version, ) wire_contract_params = build_wire_contract_params( - protocol_version=settings.a2a_protocol_version, + protocol_version=A2A_PROTOCOL_VERSION, runtime_profile=runtime_profile, - supported_protocol_versions=settings.a2a_supported_protocol_versions, - default_protocol_version=settings.a2a_protocol_version, ) return [ @@ -477,11 +473,11 @@ def _build_agent_skills( id="opencode.chat", name="OpenCode Chat", description=( - "Handle core A2A message/send and message/stream requests by routing " - "TextPart and FilePart inputs to OpenCode sessions with shared session " - "binding and optional request-scoped model selection. Chat clients " - "should continue accepting text/plain responses; application/json is " - "additive structured-output support." + "Handle core A2A SendMessage and SendStreamingMessage requests by routing " + "Part.text, Part.raw, and Part.url inputs to OpenCode sessions with " + "shared session binding and optional request-scoped model selection. " + "Chat clients should continue accepting text/plain responses; " + "application/json is additive structured-output support." ), input_modes=list(_CHAT_INPUT_MODES), output_modes=list(_CHAT_OUTPUT_MODES), @@ -605,12 +601,12 @@ def _build_agent_card( AgentInterface( url=public_url, protocol_binding="HTTP+JSON", - protocol_version=settings.a2a_protocol_version, + protocol_version=A2A_PROTOCOL_VERSION, ), AgentInterface( url=public_url, protocol_binding="JSONRPC", - protocol_version=settings.a2a_protocol_version, + protocol_version=A2A_PROTOCOL_VERSION, ), ], default_input_modes=list(_CHAT_INPUT_MODES), @@ -619,7 +615,6 @@ def _build_agent_card( streaming=True, extended_agent_card=True, extensions=_build_agent_extensions( - settings=settings, runtime_profile=runtime_profile, include_detailed_contracts=include_detailed_contracts, ), diff --git a/src/opencode_a2a/server/application.py b/src/opencode_a2a/server/application.py index 164ea12..d305603 100644 --- a/src/opencode_a2a/server/application.py +++ b/src/opencode_a2a/server/application.py @@ -25,6 +25,7 @@ InternalError, InvalidRequestError, Message, + Part, Role, SendMessageRequest, SendMessageResponse, @@ -48,41 +49,32 @@ from fastapi import FastAPI, Request from fastapi.responses import JSONResponse from google.protobuf.json_format import MessageToDict, ParseDict, ParseError +from google.protobuf.message import Message as ProtoMessage from pydantic_settings import BaseSettings from starlette.middleware.gzip import GZipMiddleware from ..a2a_protocol import ( AGENT_CARD_WELL_KNOWN_PATH, EXTENDED_AGENT_CARD_PATH, - PREV_AGENT_CARD_WELL_KNOWN_PATH, ) -from ..a2a_utils import make_text_part from ..config import Settings from ..contracts.extensions import ( - COMPATIBILITY_PROFILE_EXTENSION_URI, - INTERRUPT_CALLBACK_EXTENSION_URI, - INTERRUPT_CALLBACK_METHODS, - INTERRUPT_RECOVERY_EXTENSION_URI, - INTERRUPT_RECOVERY_METHODS, MODEL_SELECTION_EXTENSION_URI, - PROVIDER_DISCOVERY_EXTENSION_URI, - PROVIDER_DISCOVERY_METHODS, SESSION_BINDING_EXTENSION_URI, - SESSION_CONTROL_METHODS, - SESSION_MANAGEMENT_EXTENSION_URI, - SESSION_METHODS, - STREAMING_EXTENSION_URI, - WIRE_CONTRACT_EXTENSION_URI, - WORKSPACE_CONTROL_EXTENSION_URI, - WORKSPACE_CONTROL_METHODS, build_capability_snapshot, ) from ..execution.executor import OpencodeAgentExecutor +from ..extension_negotiation import ( + ExtensionRequirement, + filter_negotiated_extensions_from_payload, + requested_extensions_from_call_context, +) from ..invocation import call_with_supported_kwargs from ..jsonrpc.application import ( OpencodeSessionManagementJSONRPCApplication, ) from ..jsonrpc.error_responses import build_http_error_body +from ..metadata_access import extract_namespaced_value from ..opencode_upstream_client import OpencodeUpstreamClient from ..output_modes import ( NegotiatingResultAggregator, @@ -91,42 +83,25 @@ normalize_accepted_output_modes, ) from ..profile.runtime import build_runtime_profile +from ..protocol_versions import A2A_PROTOCOL_VERSION from ..trace_context import install_log_record_factory from .agent_card import ( _CHAT_OUTPUT_MODES, - _build_agent_card_description, - _build_chat_examples, - _build_session_management_skill_examples, build_agent_card, build_authenticated_extended_agent_card, ) from .client_manager import A2AClientManager from .lifespan import build_lifespan from .middleware import ( - AUTHENTICATED_EXTENDED_CARD_CACHE_CONTROL, - PUBLIC_AGENT_CARD_CACHE_CONTROL, build_agent_card_etag, emit_stream_request_metrics, install_runtime_middlewares, ) from .openapi import ( - _build_jsonrpc_extension_openapi_description, - _build_jsonrpc_extension_openapi_examples, - _build_rest_message_openapi_examples, _patch_jsonrpc_openapi_contract, ) from .request_parsing import ( - _decode_payload_preview, - _detect_sensitive_extension_method, - _is_json_content_type, - _looks_like_jsonrpc_envelope, - _looks_like_jsonrpc_message_payload, - _normalize_content_type, - _normalize_v1_jsonrpc_method_alias, - _parse_content_length, _parse_json_body, - _request_body_too_large_response, - _RequestBodyTooLargeError, ) from .rest_tasks import build_list_tasks_route from .state_store import ( @@ -143,12 +118,6 @@ logger = logging.getLogger(__name__) TASK_STORE_ERROR_TYPE = "TASK_STORE_UNAVAILABLE" PUSH_NOTIFICATIONS_UNSUPPORTED_MESSAGE = "Push notifications are not supported by the agent" -_REST_MESSAGE_CONFIGURATION_FIELDS = ( - "acceptedOutputModes", - "historyLength", - "returnImmediately", - "taskPushNotificationConfig", -) def _are_modalities_compatible( @@ -158,31 +127,12 @@ def _are_modalities_compatible( return bool(set(supported_output_modes) & set(accepted_output_modes)) -def _build_rest_legacy_error_payload( - *, - message: str, - reason: str | None = None, - metadata: Mapping[str, Any] | None = None, -) -> dict[str, Any]: - payload: dict[str, Any] = {"error": message} - if reason: - payload["type"] = reason - if metadata: - payload.update(dict(metadata)) - return payload - - def _rest_error_response( *, request: Request, - default_protocol_version: str, error: Exception, ) -> JSONResponse: - protocol_version = getattr( - request.state, - "a2a_protocol_version", - default_protocol_version, - ) + del request logger_fn = logger.exception logger_message = "Unexpected REST message route failure" @@ -201,15 +151,9 @@ def _rest_error_response( logger_fn(logger_message) return JSONResponse( build_http_error_body( - protocol_version=protocol_version, status_code=mapping.http_code, status=mapping.grpc_status, message=message, - legacy_payload=_build_rest_legacy_error_payload( - message=message, - reason=mapping.reason, - metadata=metadata, - ), reason=mapping.reason, metadata=metadata, ), @@ -223,14 +167,9 @@ def _rest_error_response( logger_fn(logger_message) return JSONResponse( build_http_error_body( - protocol_version=protocol_version, status_code=400, status="INVALID_ARGUMENT", message=message, - legacy_payload=_build_rest_legacy_error_payload( - message=message, - reason="INVALID_REQUEST", - ), reason="INVALID_REQUEST", ), status_code=400, @@ -239,172 +178,42 @@ def _rest_error_response( logger_fn(logger_message) return JSONResponse( build_http_error_body( - protocol_version=protocol_version, status_code=500, status="INTERNAL", message="unknown exception", - legacy_payload=_build_rest_legacy_error_payload( - message="unknown exception", - reason="INTERNAL_ERROR", - ), reason="INTERNAL_ERROR", ), status_code=500, ) -def _normalize_rest_content_part( - value: Any, - *, - field: str, -) -> dict[str, Any]: - if not isinstance(value, Mapping): - raise InvalidRequestError(message=f"{field} must be an object.") - - normalized: dict[str, Any] = {} - metadata = value.get("metadata") - if metadata is not None: - normalized["metadata"] = metadata - - text_value = value.get("text") - if isinstance(text_value, str): - normalized["text"] = text_value - return normalized - - if "data" in value: - normalized["data"] = value.get("data") - return normalized - - file_value = value.get("file") - if isinstance(file_value, Mapping): - raw_value = file_value.get("bytes") - url_value = file_value.get("uri") - if isinstance(raw_value, str) and raw_value: - normalized["raw"] = raw_value - elif isinstance(url_value, str) and url_value: - normalized["url"] = url_value - else: - raise InvalidRequestError(message=f"{field}.file must contain uri or bytes.") - filename = file_value.get("name") - if isinstance(filename, str) and filename.strip(): - normalized["filename"] = filename - media_type = ( - file_value.get("mimeType") or file_value.get("mime_type") or file_value.get("mediaType") - ) - if isinstance(media_type, str) and media_type.strip(): - normalized["mediaType"] = media_type - return normalized - - raw_value = value.get("raw") - if isinstance(raw_value, str) and raw_value: - normalized["raw"] = raw_value - filename = value.get("filename") - if isinstance(filename, str) and filename.strip(): - normalized["filename"] = filename - media_type = value.get("mediaType") or value.get("media_type") - if isinstance(media_type, str) and media_type.strip(): - normalized["mediaType"] = media_type - return normalized - - url_value = value.get("url") - if isinstance(url_value, str) and url_value: - normalized["url"] = url_value - filename = value.get("filename") - if isinstance(filename, str) and filename.strip(): - normalized["filename"] = filename - media_type = value.get("mediaType") or value.get("media_type") - if isinstance(media_type, str) and media_type.strip(): - normalized["mediaType"] = media_type - return normalized - - raise InvalidRequestError(message=f"{field} must contain text, data, or file.") - - -def _normalize_rest_send_message_payload(payload: dict[str, Any]) -> dict[str, Any]: - normalized = dict(payload) - message = normalized.get("message") - if not isinstance(message, Mapping): - raise InvalidRequestError(message="message must be an object.") - - normalized_message = dict(message) - content = normalized_message.pop("content", None) - if not isinstance(content, list): - raise InvalidRequestError(message="message.content must be an array.") - normalized_message["parts"] = [ - _normalize_rest_content_part(item, field=f"message.content[{index}]") - for index, item in enumerate(content) - ] - normalized["message"] = normalized_message - - configuration_updates: dict[str, Any] = {} - for field in _REST_MESSAGE_CONFIGURATION_FIELDS: - if field in normalized: - configuration_updates[field] = normalized.pop(field) - if configuration_updates: - configuration = normalized.get("configuration") - if configuration is None: - normalized["configuration"] = configuration_updates - elif isinstance(configuration, Mapping): - merged_configuration = dict(configuration) - merged_configuration.update(configuration_updates) - normalized["configuration"] = merged_configuration - else: - raise InvalidRequestError(message="configuration must be an object.") - - return normalized - - def _parse_rest_send_message_request(body: bytes): payload = _parse_json_body(body) if payload is None: raise InvalidRequestError(message="REST message payload must be a JSON object.") - return ParseDict( - _normalize_rest_send_message_payload(payload), - SendMessageRequest(), - ) - + message = payload.get("message") + if isinstance(message, dict): + if "content" in message: + raise InvalidRequestError( + message="REST message payload must use message.parts, not message.content." + ) + role = message.get("role") + if isinstance(role, str) and role in {"user", "agent"}: + raise InvalidRequestError( + message="REST message payload must use ROLE_* values for message.role." + ) + parts = message.get("parts") + if isinstance(parts, list): + for index, part in enumerate(parts): + if isinstance(part, dict) and ("kind" in part or "type" in part or "file" in part): + raise InvalidRequestError( + message=( + f"message.parts[{index}] must use direct Part fields " + "such as text, raw, url, or data." + ) + ) + return ParseDict(payload, SendMessageRequest()) -__all__ = [ - "_RequestBodyTooLargeError", - "COMPATIBILITY_PROFILE_EXTENSION_URI", - "INTERRUPT_CALLBACK_EXTENSION_URI", - "INTERRUPT_CALLBACK_METHODS", - "INTERRUPT_RECOVERY_EXTENSION_URI", - "INTERRUPT_RECOVERY_METHODS", - "MODEL_SELECTION_EXTENSION_URI", - "PUBLIC_AGENT_CARD_CACHE_CONTROL", - "AUTHENTICATED_EXTENDED_CARD_CACHE_CONTROL", - "PROVIDER_DISCOVERY_EXTENSION_URI", - "PROVIDER_DISCOVERY_METHODS", - "SESSION_MANAGEMENT_EXTENSION_URI", - "SESSION_BINDING_EXTENSION_URI", - "SESSION_CONTROL_METHODS", - "SESSION_METHODS", - "STREAMING_EXTENSION_URI", - "WIRE_CONTRACT_EXTENSION_URI", - "WORKSPACE_CONTROL_EXTENSION_URI", - "WORKSPACE_CONTROL_METHODS", - "_build_agent_card_description", - "_build_chat_examples", - "_build_jsonrpc_extension_openapi_description", - "_build_jsonrpc_extension_openapi_examples", - "_build_rest_message_openapi_examples", - "_build_session_management_skill_examples", - "build_authenticated_extended_agent_card", - "_configure_logging", - "_decode_payload_preview", - "_detect_sensitive_extension_method", - "_is_json_content_type", - "_looks_like_jsonrpc_envelope", - "_looks_like_jsonrpc_message_payload", - "_normalize_v1_jsonrpc_method_alias", - "_normalize_content_type", - "_normalize_log_level", - "_parse_content_length", - "_parse_json_body", - "_request_body_too_large_response", - "build_agent_card", -] if TYPE_CHECKING: from collections.abc import Awaitable, Callable @@ -425,7 +234,7 @@ def __init__( # noqa: PLR0913 self, agent_executor: AgentExecutor, task_store: TaskStore, - agent_card: AgentCard | None = None, + agent_card: AgentCard, queue_manager: Any | None = None, push_config_store: PushNotificationConfigStore | None = None, push_sender: PushNotificationSender | None = None, @@ -437,7 +246,7 @@ def __init__( # noqa: PLR0913 super().__init__( agent_executor=agent_executor, task_store=task_store, - agent_card=agent_card or AgentCard(name="opencode-a2a"), + agent_card=agent_card, queue_manager=queue_manager, push_config_store=push_config_store, push_sender=push_sender, @@ -483,7 +292,7 @@ def _task_store_failure_task( error_message = Message( message_id=f"{task_id}:task-store-error", role=Role.ROLE_AGENT, - parts=[make_text_part(message_text)], + parts=[Part(text=message_text)], task_id=task_id, context_id=context_id, ) @@ -510,7 +319,7 @@ def _task_store_failure_events( context_id=context_id, artifact=Artifact( artifact_id=f"{task_id}:error", - parts=[make_text_part(message_text)], + parts=[Part(text=message_text)], ), append=False, last_chunk=True, @@ -537,14 +346,108 @@ def _extract_accepted_output_modes(params) -> list[str] | None: # noqa: ANN001 return list(normalized) if normalized is not None else None @staticmethod - def _apply_task_output_negotiation(task: Task) -> Task: + def _apply_task_output_negotiation(task: Task, context=None) -> Task: negotiated = apply_accepted_output_modes( task, extract_accepted_output_modes_from_metadata(task.metadata), ) - if isinstance(negotiated, Task): - return negotiated - return task + resolved = negotiated if isinstance(negotiated, Task) else task + return cast( + Task, + filter_negotiated_extensions_from_payload( + resolved, + requested_extensions_from_call_context(context), + ), + ) + + @staticmethod + def _validate_shared_extension_negotiation(params, context=None) -> None: # noqa: ANN001 + requested_extensions = requested_extensions_from_call_context(context) + sources: list[dict[str, Any]] = [] + params_metadata = getattr(params, "metadata", None) + if isinstance(params_metadata, ProtoMessage): + params_metadata = MessageToDict(params_metadata, preserving_proto_field_name=True) + elif isinstance(params_metadata, Mapping): + params_metadata = dict(params_metadata) + else: + params_metadata = None + if params_metadata: + sources.append(params_metadata) + message = getattr(params, "message", None) + message_metadata = getattr(message, "metadata", None) + if isinstance(message_metadata, ProtoMessage): + message_metadata = MessageToDict(message_metadata, preserving_proto_field_name=True) + elif isinstance(message_metadata, Mapping): + message_metadata = dict(message_metadata) + else: + message_metadata = None + if message_metadata: + sources.append(message_metadata) + + requirements: list[ExtensionRequirement] = [] + if any( + extract_namespaced_value(source, namespace="shared", path=("session", "id")) is not None + for source in sources + ): + requirements.append( + ExtensionRequirement( + extension_uri=SESSION_BINDING_EXTENSION_URI, + field="metadata.shared.session.id", + ) + ) + if any( + extract_namespaced_value(source, namespace="opencode", path=("directory",)) is not None + for source in sources + ): + requirements.append( + ExtensionRequirement( + extension_uri=SESSION_BINDING_EXTENSION_URI, + field="metadata.opencode.directory", + ) + ) + if any( + extract_namespaced_value(source, namespace="opencode", path=("workspace", "id")) + is not None + for source in sources + ): + requirements.append( + ExtensionRequirement( + extension_uri=SESSION_BINDING_EXTENSION_URI, + field="metadata.opencode.workspace.id", + ) + ) + if any( + extract_namespaced_value(source, namespace="shared", path=("model", "providerID")) + is not None + or extract_namespaced_value(source, namespace="shared", path=("model", "modelID")) + is not None + for source in sources + ): + requirements.append( + ExtensionRequirement( + extension_uri=MODEL_SELECTION_EXTENSION_URI, + field="metadata.shared.model", + ) + ) + missing_requirements = [ + requirement + for requirement in requirements + if requirement.extension_uri not in requested_extensions + ] + if not missing_requirements: + return + raise UnsupportedOperationError( + message="Request requires explicit A2A extension negotiation via A2A-Extensions.", + data={ + "type": "EXTENSION_NEGOTIATION_REQUIRED", + "fields": [requirement.field for requirement in missing_requirements], + "required_extensions": sorted( + {requirement.extension_uri for requirement in missing_requirements} + ), + "requested_extensions": sorted(requested_extensions), + "header": "A2A-Extensions", + }, + ) async def _setup_message_execution(self, params, context=None): # noqa: ANN001 ( @@ -600,7 +503,7 @@ async def on_get_task( task = await self.task_store.get(params.id, context) if not task: raise TaskNotFoundError() - return self._apply_task_output_negotiation(apply_history_length(task, params)) + return self._apply_task_output_negotiation(apply_history_length(task, params), context) except TaskStoreOperationError as exc: raise self._task_store_server_error(exc) from exc @@ -677,10 +580,10 @@ async def on_subscribe_to_task( # Subscribe contract: terminal tasks replay once and then close stream. if task.status.state in TERMINAL_TASK_STATES: - yield self._apply_task_output_negotiation(task) + yield self._apply_task_output_negotiation(task, context) return - yield self._apply_task_output_negotiation(task) + yield self._apply_task_output_negotiation(task, context) task_manager = TaskManager( task_id=task.id, @@ -700,23 +603,16 @@ async def on_subscribe_to_task( extract_accepted_output_modes_from_metadata(getattr(event, "metadata", None)), ) if negotiated is not None: - yield negotiated + yield filter_negotiated_extensions_from_payload( + negotiated, + requested_extensions_from_call_context(context), + ) except TaskStoreOperationError as exc: raise self._task_store_server_error(exc) from exc - async def on_resubscribe_to_task( - self, - params, - context=None, - ): - subscribe_params = params - if not isinstance(params, SubscribeToTaskRequest): - subscribe_params = SubscribeToTaskRequest(id=params.id) - async for event in self.on_subscribe_to_task(subscribe_params, context): - yield event - async def on_message_send_stream(self, params, context=None): self._validate_chat_output_modes(params) + self._validate_shared_extension_negotiation(params, context) ( _task_manager, task_id, @@ -767,6 +663,7 @@ async def on_message_send_stream(self, params, context=None): async def on_message_send(self, params, context=None): self._validate_chat_output_modes(params) + self._validate_shared_extension_negotiation(params, context) ( _task_manager, task_id, @@ -803,7 +700,7 @@ async def push_notification_callback() -> None: self._track_background_task(bg_consume_task) except TaskStoreOperationError as exc: logger.exception( - "Task store operation failed during message/send task_id=%s operation=%s", + "Task store operation failed during SendMessage task_id=%s operation=%s", task_id, exc.operation, ) @@ -882,12 +779,6 @@ def build(self, request: Request) -> ServerCallContext: trace_id = getattr(request.state, "trace_id", None) if trace_id: context.state["trace_id"] = trace_id - negotiated_protocol_version = getattr(request.state, "a2a_protocol_version", None) - if negotiated_protocol_version: - context.state["a2a_protocol_version"] = negotiated_protocol_version - requested_protocol_version = getattr(request.state, "a2a_requested_protocol_version", None) - if requested_protocol_version: - context.state["a2a_requested_protocol_version"] = requested_protocol_version return context @@ -948,7 +839,6 @@ def create_app(settings: Settings) -> FastAPI: http_handler=handler, context_builder=context_builder, upstream_client=upstream_client, - protocol_version=settings.a2a_protocol_version, supported_methods=capability_snapshot.supported_jsonrpc_methods(), directory_resolver=( partial( @@ -1013,7 +903,6 @@ async def _handler(context) -> SendMessageResponse: # noqa: ANN001 except Exception as error: # noqa: BLE001 return _rest_error_response( request=request, - default_protocol_version=settings.a2a_protocol_version, error=error, ) @@ -1029,12 +918,10 @@ async def _handler(context): # noqa: ANN001 except Exception as error: # noqa: BLE001 return _rest_error_response( request=request, - default_protocol_version=settings.a2a_protocol_version, error=error, ) app.add_api_route(AGENT_CARD_WELL_KNOWN_PATH, public_agent_card_route, methods=["GET"]) - app.add_api_route(PREV_AGENT_CARD_WELL_KNOWN_PATH, public_agent_card_route, methods=["GET"]) app.add_api_route("/v1/message:send", rest_message_send_route, methods=["POST"]) app.add_api_route("/v1/message:stream", rest_message_send_stream_route, methods=["POST"]) app.add_api_route("/v1/tasks/{id}:cancel", rest_dispatcher.on_cancel_task, methods=["POST"]) @@ -1053,18 +940,12 @@ async def _handler(context): # noqa: ANN001 app.add_api_route("/v1/tasks/{id}", rest_dispatcher.on_get_task, methods=["GET"]) async def push_notifications_unsupported_route(request: Request) -> JSONResponse: - protocol_version = getattr( - request.state, - "a2a_protocol_version", - settings.a2a_protocol_version, - ) + del request return JSONResponse( build_http_error_body( - protocol_version=protocol_version, status_code=501, status="UNIMPLEMENTED", message=PUSH_NOTIFICATIONS_UNSUPPORTED_MESSAGE, - legacy_payload={"message": PUSH_NOTIFICATIONS_UNSUPPORTED_MESSAGE}, reason="PUSH_NOTIFICATIONS_UNSUPPORTED", ), status_code=501, @@ -1094,7 +975,6 @@ async def push_notifications_unsupported_route(request: Request) -> JSONResponse "/v1/tasks", build_list_tasks_route( task_store=task_store, - default_protocol_version=settings.a2a_protocol_version, ), methods=["GET"], ) @@ -1122,7 +1002,7 @@ async def health_check(): return runtime_profile.health_payload( service="opencode-a2a", version=settings.a2a_version, - protocol_version=settings.a2a_protocol_version, + protocol_version=A2A_PROTOCOL_VERSION, ) return app diff --git a/src/opencode_a2a/server/client_manager.py b/src/opencode_a2a/server/client_manager.py index f987f93..d496e74 100644 --- a/src/opencode_a2a/server/client_manager.py +++ b/src/opencode_a2a/server/client_manager.py @@ -21,9 +21,6 @@ def __init__(self, settings) -> None: # noqa: ANN001 "A2A_CLIENT_USE_CLIENT_PREFERENCE": settings.a2a_client_use_client_preference, "A2A_CLIENT_BEARER_TOKEN": settings.a2a_client_bearer_token, "A2A_CLIENT_BASIC_AUTH": settings.a2a_client_basic_auth, - "A2A_CLIENT_PROTOCOL_VERSION": ( - settings.a2a_client_protocol_version or settings.a2a_protocol_version - ), "A2A_CLIENT_SUPPORTED_TRANSPORTS": settings.a2a_client_supported_transports, } ) diff --git a/src/opencode_a2a/server/middleware.py b/src/opencode_a2a/server/middleware.py index c38cb8b..a9b39a8 100644 --- a/src/opencode_a2a/server/middleware.py +++ b/src/opencode_a2a/server/middleware.py @@ -8,21 +8,20 @@ from fastapi import FastAPI, Request from fastapi.responses import JSONResponse, Response +from google.protobuf.json_format import MessageToDict from starlette.responses import StreamingResponse from ..a2a_protocol import ( AGENT_CARD_WELL_KNOWN_PATH, EXTENDED_AGENT_CARD_PATH, - PREV_AGENT_CARD_WELL_KNOWN_PATH, ) -from ..a2a_utils import proto_to_dict from ..auth import ( authenticate_static_credential, build_static_auth_credentials, ) from ..execution.metrics import emit_metric from ..jsonrpc.error_responses import ( - adapt_jsonrpc_error_for_protocol, + adapt_jsonrpc_error, build_http_error_body, version_not_supported_error, ) @@ -30,7 +29,6 @@ from ..protocol_versions import ( UnsupportedProtocolVersionError, negotiate_protocol_version, - normalize_protocol_version, ) from ..trace_context import ( TRACEPARENT_HEADER, @@ -44,9 +42,7 @@ _detect_sensitive_extension_method, _is_json_content_type, _looks_like_jsonrpc_envelope, - _looks_like_jsonrpc_message_payload, _normalize_content_type, - _normalize_v1_jsonrpc_method_alias, _parse_content_length, _parse_json_body, _request_body_too_large_response, @@ -82,7 +78,6 @@ def _unauthorized_response() -> JSONResponse: async def bearer_auth(request: Request, call_next): if request.method == "OPTIONS" or request.url.path in { AGENT_CARD_WELL_KNOWN_PATH, - PREV_AGENT_CARD_WELL_KNOWN_PATH, }: return await call_next(request) @@ -110,7 +105,7 @@ async def bearer_auth(request: Request, call_next): def build_agent_card_etag(card) -> str: # noqa: ANN001 - payload = proto_to_dict(card) + payload = MessageToDict(card) content = json.dumps( payload, ensure_ascii=False, @@ -142,31 +137,6 @@ def _extract_jsonrpc_request_id(payload: object) -> str | int | None: return request_id return None - def _error_protocol_version(request: Request) -> str: - negotiated = getattr(request.state, "a2a_protocol_version", None) - if isinstance(negotiated, str) and negotiated.strip(): - return negotiated - raw_value = request.headers.get("A2A-Version") or request.query_params.get("A2A-Version") - if isinstance(raw_value, str) and raw_value.strip(): - try: - return normalize_protocol_version(raw_value) - except ValueError: - return raw_value.strip() - return cast(str, settings.a2a_protocol_version) - - def _uses_v1_jsonrpc_aliases(request: Request) -> bool: - negotiated = getattr(request.state, "a2a_protocol_version", None) - if negotiated == "1.0": - return True - - raw_value = request.headers.get("A2A-Version") or request.query_params.get("A2A-Version") - if not isinstance(raw_value, str) or not raw_value.strip(): - return False - try: - return normalize_protocol_version(raw_value) == "1.0" - except ValueError: - return False - @app.middleware("http") async def bind_trace_context(request: Request, call_next): trace_context = resolve_trace_context( @@ -192,11 +162,9 @@ async def negotiate_a2a_protocol_version(request: Request, call_next): return await call_next(request) try: - negotiated = negotiate_protocol_version( + negotiated_version = negotiate_protocol_version( header_value=request.headers.get("A2A-Version"), query_value=request.query_params.get("A2A-Version"), - default_protocol_version=settings.a2a_protocol_version, - supported_protocol_versions=settings.a2a_supported_protocol_versions, ) except UnsupportedProtocolVersionError as error: if request.url.path == "/" and request.method == "POST": @@ -208,7 +176,6 @@ async def negotiate_a2a_protocol_version(request: Request, call_next): path=request.url.path, method=request.method, error=request_error, - protocol_version=_error_protocol_version(request), ) return JSONResponse( { @@ -216,8 +183,7 @@ async def negotiate_a2a_protocol_version(request: Request, call_next): "id": _extract_jsonrpc_request_id(payload), "error": cast( JSONRPCError, - adapt_jsonrpc_error_for_protocol( - error.requested_version, + adapt_jsonrpc_error( version_not_supported_error( requested_version=error.requested_version, supported_protocol_versions=list( @@ -232,17 +198,9 @@ async def negotiate_a2a_protocol_version(request: Request, call_next): ) return JSONResponse( build_http_error_body( - protocol_version=error.requested_version, status_code=400, status="INVALID_ARGUMENT", message="Unsupported A2A version", - legacy_payload={ - "error": "Unsupported A2A version", - "type": "VERSION_NOT_SUPPORTED", - "requested_version": error.requested_version, - "supported_protocol_versions": list(error.supported_protocol_versions), - "default_protocol_version": error.default_protocol_version, - }, reason="VERSION_NOT_SUPPORTED", metadata={ "requested_version": error.requested_version, @@ -256,11 +214,9 @@ async def negotiate_a2a_protocol_version(request: Request, call_next): if token is not None: _REQUEST_BODY_BYTES.reset(token) - request.state.a2a_protocol_version = negotiated.negotiated_version - request.state.a2a_requested_protocol_version = negotiated.requested_version - request.state.a2a_protocol_version_explicit = negotiated.explicit + request.state.a2a_protocol_version = negotiated_version response = await call_next(request) - response.headers["A2A-Version"] = negotiated.negotiated_version + response.headers["A2A-Version"] = negotiated_version return response async def _get_request_body(request: Request) -> tuple[bytes, Token | None]: @@ -322,10 +278,7 @@ async def cache_agent_card_responses(request: Request, call_next): return await call_next(request) path = request.url.path - is_public_card = path in { - AGENT_CARD_WELL_KNOWN_PATH, - PREV_AGENT_CARD_WELL_KNOWN_PATH, - } + is_public_card = path == AGENT_CARD_WELL_KNOWN_PATH is_extended_card = path == EXTENDED_AGENT_CARD_PATH if not is_public_card and not is_extended_card: return await call_next(request) @@ -379,7 +332,6 @@ async def enforce_request_body_limit(request: Request, call_next): path=request.url.path, method=request.method, error=error, - protocol_version=_error_protocol_version(request), ) finally: if token is not None: @@ -397,26 +349,17 @@ async def guard_rest_payload_shape(request: Request, call_next): try: body, token = await _get_request_body(request) payload = _parse_json_body(body) - if _looks_like_jsonrpc_envelope(payload) or _looks_like_jsonrpc_message_payload( - payload - ): + if _looks_like_jsonrpc_envelope(payload): return JSONResponse( build_http_error_body( - protocol_version=_error_protocol_version(request), status_code=400, status="INVALID_ARGUMENT", message=( - "Invalid HTTP+JSON payload for REST endpoint. " - "Use message.content with ROLE_* role values, or call " - "POST / with method=message/send or method=message/stream." + "Invalid JSON-RPC payload for REST endpoint. " + "Call POST / for JSON-RPC methods such as SendMessage " + "or SendStreamingMessage, or send ProtoJSON " + "SendMessageRequest payloads to the REST endpoint." ), - legacy_payload={ - "error": ( - "Invalid HTTP+JSON payload for REST endpoint. " - "Use message.content with ROLE_* role values, or call " - "POST / with method=message/send or method=message/stream." - ) - }, reason="INVALID_HTTP_JSON_PAYLOAD", metadata={"path": request.url.path}, ), @@ -428,49 +371,8 @@ async def guard_rest_payload_shape(request: Request, call_next): path=request.url.path, method=request.method, error=error, - protocol_version=_error_protocol_version(request), - ) - finally: - if token is not None: - _REQUEST_BODY_BYTES.reset(token) - - @app.middleware("http") - async def normalize_v1_jsonrpc_method_aliases(request: Request, call_next): - token: Token | None = None - rewrite_token: Token | None = None - if ( - request.method != "POST" - or request.url.path != "/" - or not _uses_v1_jsonrpc_aliases(request) - ): - return await call_next(request) - - try: - body, token = await _get_request_body(request) - payload = _parse_json_body(body) - normalized_payload = _normalize_v1_jsonrpc_method_alias( - payload, - protocol_version="1.0", - ) - if normalized_payload is not None and normalized_payload is not payload: - normalized_body = json.dumps( - normalized_payload, - ensure_ascii=False, - separators=(",", ":"), - ).encode("utf-8") - request._body = normalized_body - rewrite_token = _REQUEST_BODY_BYTES.set(normalized_body) - return await call_next(request) - except _RequestBodyTooLargeError as error: - return _request_body_too_large_response( - path=request.url.path, - method=request.method, - error=error, - protocol_version=_error_protocol_version(request), ) finally: - if rewrite_token is not None: - _REQUEST_BODY_BYTES.reset(rewrite_token) if token is not None: _REQUEST_BODY_BYTES.reset(token) @@ -574,7 +476,6 @@ async def log_payloads(request: Request, call_next): path=request.url.path, method=request.method, error=error, - protocol_version=_error_protocol_version(request), ) finally: if token is not None: diff --git a/src/opencode_a2a/server/openapi.py b/src/opencode_a2a/server/openapi.py index bc3e247..8b83c5a 100644 --- a/src/opencode_a2a/server/openapi.py +++ b/src/opencode_a2a/server/openapi.py @@ -25,9 +25,10 @@ build_wire_contract_params, build_workspace_control_extension_params, ) -from ..jsonrpc.application import SESSION_CONTEXT_PREFIX +from ..jsonrpc.methods import SESSION_CONTEXT_PREFIX from ..jsonrpc.models import JSONRPCRequest from ..profile.runtime import RuntimeProfile +from ..protocol_versions import A2A_PROTOCOL_VERSION def _build_jsonrpc_extension_openapi_description( @@ -41,7 +42,7 @@ def _build_jsonrpc_extension_openapi_description( interrupt_methods = ", ".join(sorted(INTERRUPT_CALLBACK_METHODS.values())) return ( "A2A JSON-RPC entrypoint. Supports core A2A methods " - "(message/send, message/stream, tasks/get, tasks/cancel, tasks/resubscribe) " + "(SendMessage, SendStreamingMessage, GetTask, CancelTask, SubscribeToTask) " "plus shared model-selection metadata, OpenCode session/provider extensions, " "interrupt recovery extensions, and shared interrupt callback methods.\n\n" f"OpenCode session read/mutation/control methods: {', '.join(session_methods)}.\n" @@ -67,12 +68,12 @@ def _build_jsonrpc_extension_openapi_examples( "value": { "jsonrpc": "2.0", "id": 101, - "method": "message/send", + "method": "SendMessage", "params": { "message": { "messageId": "msg-1", - "role": "user", - "parts": [{"kind": "text", "text": "Explain what this repository does."}], + "role": "ROLE_USER", + "parts": [{"text": "Explain what this repository does."}], } }, }, @@ -82,17 +83,12 @@ def _build_jsonrpc_extension_openapi_examples( "value": { "jsonrpc": "2.0", "id": 102, - "method": "message/stream", + "method": "SendStreamingMessage", "params": { "message": { "messageId": "msg-stream-1", - "role": "user", - "parts": [ - { - "kind": "text", - "text": "Stream the answer and highlight key conclusions.", - } - ], + "role": "ROLE_USER", + "parts": [{"text": "Stream the answer and highlight key conclusions."}], } }, }, @@ -102,12 +98,12 @@ def _build_jsonrpc_extension_openapi_examples( "value": { "jsonrpc": "2.0", "id": 103, - "method": "message/send", + "method": "SendMessage", "params": { "message": { "messageId": "msg-model-1", - "role": "user", - "parts": [{"kind": "text", "text": "Answer with the faster model."}], + "role": "ROLE_USER", + "parts": [{"text": "Answer with the faster model."}], }, "metadata": { "shared": { @@ -125,23 +121,17 @@ def _build_jsonrpc_extension_openapi_examples( "value": { "jsonrpc": "2.0", "id": 104, - "method": "message/send", + "method": "SendMessage", "params": { "message": { "messageId": "msg-file-1", - "role": "user", + "role": "ROLE_USER", "parts": [ + {"text": "Review the attached file and summarize the main risks."}, { - "kind": "text", - "text": "Review the attached file and summarize the main risks.", - }, - { - "kind": "file", - "file": { - "name": "report.pdf", - "mimeType": "application/pdf", - "uri": "file:///workspace/report.pdf", - }, + "url": "file:///workspace/report.pdf", + "filename": "report.pdf", + "mediaType": "application/pdf", }, ], } @@ -525,24 +515,22 @@ def _build_rest_message_openapi_examples() -> dict[str, Any]: "message": { "messageId": "msg-rest-1", "role": "ROLE_USER", - "content": [{"text": "Explain what this repository does."}], + "parts": [{"text": "Explain what this repository does."}], } }, }, "message_with_file_input": { - "summary": "Send message with FilePart input (HTTP+JSON)", + "summary": "Send message with file input (HTTP+JSON)", "value": { "message": { "messageId": "msg-rest-file-1", "role": "ROLE_USER", - "content": [ + "parts": [ {"text": "Review the attached file and summarize the main risks."}, { - "file": { - "name": "report.pdf", - "mimeType": "application/pdf", - "uri": "file:///workspace/report.pdf", - } + "url": "file:///workspace/report.pdf", + "filename": "report.pdf", + "mediaType": "application/pdf", }, ], } @@ -554,7 +542,7 @@ def _build_rest_message_openapi_examples() -> dict[str, Any]: "message": { "messageId": "msg-rest-continue-1", "role": "ROLE_USER", - "content": [{"text": "Continue previous work and summarize next steps."}], + "parts": [{"text": "Continue previous work and summarize next steps."}], }, "metadata": { "shared": { @@ -569,7 +557,7 @@ def _build_rest_message_openapi_examples() -> dict[str, Any]: "message": { "messageId": "msg-rest-model-1", "role": "ROLE_USER", - "content": [{"text": "Answer with the faster model."}], + "parts": [{"text": "Answer with the faster model."}], }, "metadata": { "shared": { @@ -614,16 +602,12 @@ def _patch_jsonrpc_openapi_contract( runtime_profile=runtime_profile, ) compatibility_profile = build_compatibility_profile_params( - protocol_version=settings.a2a_protocol_version, + protocol_version=A2A_PROTOCOL_VERSION, runtime_profile=runtime_profile, - supported_protocol_versions=settings.a2a_supported_protocol_versions, - default_protocol_version=settings.a2a_protocol_version, ) wire_contract = build_wire_contract_params( - protocol_version=settings.a2a_protocol_version, + protocol_version=A2A_PROTOCOL_VERSION, runtime_profile=runtime_profile, - supported_protocol_versions=settings.a2a_supported_protocol_versions, - default_protocol_version=settings.a2a_protocol_version, ) capability_snapshot = build_capability_snapshot(runtime_profile=runtime_profile) original_openapi = app.openapi @@ -680,7 +664,8 @@ def custom_openapi() -> dict[str, Any]: "summary": "Send Message (HTTP+JSON)", "description": ( "A2A HTTP+JSON message send endpoint. " - "Use REST payload shape with message.content and ROLE_* roles." + "Use ProtoJSON SendMessageRequest payloads with message.parts " + "and ROLE_* roles." ), "schema_ref": "#/components/schemas/SendMessageRequest", }, @@ -688,7 +673,8 @@ def custom_openapi() -> dict[str, Any]: "summary": "Stream Message (HTTP+JSON)", "description": ( "A2A HTTP+JSON streaming endpoint. " - "Use REST payload shape with message.content and ROLE_* roles." + "Use ProtoJSON SendMessageRequest payloads with message.parts " + "and ROLE_* roles." ), "schema_ref": "#/components/schemas/SendStreamingMessageRequest", }, diff --git a/src/opencode_a2a/server/request_parsing.py b/src/opencode_a2a/server/request_parsing.py index d5dc528..5ac927e 100644 --- a/src/opencode_a2a/server/request_parsing.py +++ b/src/opencode_a2a/server/request_parsing.py @@ -5,7 +5,6 @@ from fastapi.responses import JSONResponse -from ..a2a_protocol import V1_JSONRPC_METHOD_TO_LEGACY_METHOD from ..contracts.extensions import ( INTERRUPT_CALLBACK_METHODS, INTERRUPT_RECOVERY_METHODS, @@ -16,8 +15,6 @@ logger = logging.getLogger(__name__) -_V1_JSONRPC_METHOD_ALIASES = dict(V1_JSONRPC_METHOD_TO_LEGACY_METHOD) - def _parse_json_body(body_bytes: bytes) -> dict | None: try: @@ -75,18 +72,6 @@ def _decode_payload_preview(body: bytes, *, limit: int) -> str: return body.decode("utf-8", errors="replace") -def _looks_like_jsonrpc_message_payload(payload: dict | None) -> bool: - if payload is None: - return False - message = payload.get("message") - if not isinstance(message, dict): - return False - if "parts" in message: - return True - role = message.get("role") - return isinstance(role, str) and role in {"user", "agent"} - - def _looks_like_jsonrpc_envelope(payload: dict | None) -> bool: if payload is None: return False @@ -95,22 +80,6 @@ def _looks_like_jsonrpc_envelope(payload: dict | None) -> bool: return isinstance(method, str) and isinstance(version, str) -def _normalize_v1_jsonrpc_method_alias( - payload: dict | None, *, protocol_version: str -) -> dict | None: - if payload is None or protocol_version != "1.0": - return payload - method = payload.get("method") - if not isinstance(method, str): - return payload - canonical_method = _V1_JSONRPC_METHOD_ALIASES.get(method) - if canonical_method is None or canonical_method == method: - return payload - normalized_payload = dict(payload) - normalized_payload["method"] = canonical_method - return normalized_payload - - class _RequestBodyTooLargeError(Exception): def __init__(self, *, limit: int, actual_size: int) -> None: super().__init__("Request body too large") @@ -123,7 +92,6 @@ def _request_body_too_large_response( path: str, method: str, error: _RequestBodyTooLargeError, - protocol_version: str = "0.3", ) -> JSONResponse: logger.warning( "A2A request %s %s rejected: body_size=%s exceeds max_request_body_bytes=%s", @@ -134,11 +102,9 @@ def _request_body_too_large_response( ) return JSONResponse( build_http_error_body( - protocol_version=protocol_version, status_code=413, status="RESOURCE_EXHAUSTED", message="Request body too large", - legacy_payload={"error": "Request body too large", "max_bytes": error.limit}, reason="REQUEST_BODY_TOO_LARGE", metadata={"max_bytes": error.limit, "actual_size": error.actual_size}, ), diff --git a/src/opencode_a2a/server/rest_tasks.py b/src/opencode_a2a/server/rest_tasks.py index 038aba4..16435e3 100644 --- a/src/opencode_a2a/server/rest_tasks.py +++ b/src/opencode_a2a/server/rest_tasks.py @@ -5,14 +5,18 @@ import logging from dataclasses import dataclass from datetime import UTC, datetime -from typing import cast +from typing import Any, cast +from a2a.extensions.common import HTTP_EXTENSION_HEADER, get_requested_extensions from a2a.server.tasks.task_store import TaskStore from a2a.types import Task, TaskState from fastapi import Request from fastapi.responses import JSONResponse +from google.protobuf.json_format import MessageToDict -from ..a2a_utils import proto_to_dict +from ..extension_negotiation import ( + filter_negotiated_extensions_from_payload, +) from ..jsonrpc.error_responses import build_http_error_body from ..output_modes import ( apply_accepted_output_modes, @@ -66,14 +70,8 @@ def _validation_error(field: str, message: str) -> _ListTasksValidationError: def build_list_tasks_route( *, task_store: TaskStore, - default_protocol_version: str, ): async def list_tasks_route(request: Request) -> JSONResponse: - protocol_version = getattr( - request.state, - "a2a_protocol_version", - default_protocol_version, - ) try: query = _parse_list_tasks_query(request) tasks = await list_stored_tasks(task_store) @@ -81,19 +79,13 @@ async def list_tasks_route(request: Request) -> JSONResponse: return _invalid_argument_response( field=error.field, message=error.message, - protocol_version=protocol_version, ) except TaskStoreOperationError as error: return JSONResponse( build_http_error_body( - protocol_version=protocol_version, status_code=500, status="INTERNAL", message="Task store unavailable while listing tasks.", - legacy_payload={ - "error": "Task store unavailable while listing tasks.", - "operation": error.operation, - }, reason="TASK_STORE_UNAVAILABLE", metadata={"operation": error.operation}, ), @@ -115,6 +107,9 @@ async def list_tasks_route(request: Request) -> JSONResponse: task, history_length=query.history_length, include_artifacts=query.include_artifacts, + requested_extensions=frozenset( + get_requested_extensions(request.headers.getlist(HTTP_EXTENSION_HEADER)) + ), ) for task in page_tasks ], @@ -159,6 +154,7 @@ def _serialize_task( *, history_length: int, include_artifacts: bool, + requested_extensions: frozenset[str], ) -> dict: negotiated = apply_accepted_output_modes( task, @@ -166,8 +162,9 @@ def _serialize_task( ) if isinstance(negotiated, Task): task = negotiated + task = filter_negotiated_extensions_from_payload(task, requested_extensions) - payload = proto_to_dict(task) + payload = cast(dict[str, Any], MessageToDict(task)) history = payload.get("history") if history_length <= 0: @@ -215,9 +212,7 @@ def _parse_list_tasks_query(request: Request) -> _ListTasksQuery: status = None if status_value is not None: try: - normalized_status = status_value.strip().upper() - if normalized_status and not normalized_status.startswith("TASK_STATE_"): - normalized_status = f"TASK_STATE_{normalized_status}" + normalized_status = status_value.strip() status = TaskState.Value(normalized_status) except ValueError as exc: raise _ListTasksValidationError( @@ -347,15 +342,12 @@ def _invalid_argument_response( *, field: str, message: str, - protocol_version: str, ) -> JSONResponse: return JSONResponse( build_http_error_body( - protocol_version=protocol_version, status_code=400, status="INVALID_ARGUMENT", message=message, - legacy_payload={"error": message, "field": field}, reason="INVALID_LIST_TASKS_REQUEST", metadata={"field": field}, ), diff --git a/src/opencode_a2a/server/state_store.py b/src/opencode_a2a/server/state_store.py index 5941088..3c42b48 100644 --- a/src/opencode_a2a/server/state_store.py +++ b/src/opencode_a2a/server/state_store.py @@ -125,7 +125,7 @@ def _initialize_state_store_schema(connection) -> None: # noqa: ANN001 def _pending_claim_expires_at( row: Mapping[str, Any], *, - legacy_ttl_seconds: float, + ttl_seconds: float, ) -> float | None: expires_at = row.get("expires_at") if expires_at is not None: @@ -133,7 +133,7 @@ def _pending_claim_expires_at( updated_at = row.get("updated_at") if updated_at is None: return None - return float(updated_at) + max(0.0, legacy_ttl_seconds) + return float(updated_at) + max(0.0, ttl_seconds) class SessionStateRepository(ABC): @@ -394,7 +394,7 @@ async def get_pending_claim(self, *, session_id: str) -> str | None: return None expires_at = _pending_claim_expires_at( row, - legacy_ttl_seconds=self._pending_claim_ttl_seconds, + ttl_seconds=self._pending_claim_ttl_seconds, ) if expires_at is None or expires_at <= now: await session.execute( diff --git a/src/opencode_a2a/server/task_store.py b/src/opencode_a2a/server/task_store.py index 0b04bcd..8aacb26 100644 --- a/src/opencode_a2a/server/task_store.py +++ b/src/opencode_a2a/server/task_store.py @@ -12,12 +12,13 @@ from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore from a2a.server.tasks.task_store import TaskStore from a2a.types import ListTasksRequest, ListTasksResponse, Task, TaskState +from google.protobuf.json_format import MessageToDict from sqlalchemy import event, or_, select from sqlalchemy.dialects.postgresql import insert as postgresql_insert from sqlalchemy.dialects.sqlite import insert as sqlite_insert from sqlalchemy.engine import make_url -from ..a2a_utils import proto_equals, proto_to_dict +from ..a2a_utils import proto_equals from ..config import Settings from ..task_states import TERMINAL_TASK_STATES @@ -462,10 +463,10 @@ def _task_row_values(task: Task, *, owner: str | None) -> dict[str, Any]: "last_updated": ( task.status.timestamp.ToDatetime() if task.status.HasField("timestamp") else None ), - "status": proto_to_dict(task.status), - "artifacts": [proto_to_dict(artifact) for artifact in task.artifacts], - "history": [proto_to_dict(message) for message in task.history], - "metadata": proto_to_dict(task.metadata), + "status": MessageToDict(task.status), + "artifacts": [MessageToDict(artifact) for artifact in task.artifacts], + "history": [MessageToDict(message) for message in task.history], + "metadata": MessageToDict(task.metadata), "protocol_version": "1.0", } diff --git a/src/opencode_a2a/upstream_taxonomy.py b/src/opencode_a2a/upstream_taxonomy.py index a55bc93..22c4011 100644 --- a/src/opencode_a2a/upstream_taxonomy.py +++ b/src/opencode_a2a/upstream_taxonomy.py @@ -86,10 +86,3 @@ def extract_upstream_error_detail(response: httpx.Response | None) -> str | None if text: return text[:512] return None - - -__all__ = [ - "UpstreamHTTPErrorProfile", - "extract_upstream_error_detail", - "resolve_upstream_http_error_profile", -] diff --git a/tests/client/test_agent_card.py b/tests/client/test_agent_card.py index e837515..cd8dc21 100644 --- a/tests/client/test_agent_card.py +++ b/tests/client/test_agent_card.py @@ -5,7 +5,6 @@ import httpx import pytest -from a2a.client.errors import A2AClientHTTPError from opencode_a2a.client.agent_card import ( build_agent_card_resolver, @@ -14,6 +13,7 @@ ) from opencode_a2a.client.error_mapping import map_agent_card_error from opencode_a2a.client.errors import A2AAuthenticationError +from tests.support.fake_client_errors import FakeA2AClientHTTPError @pytest.mark.asyncio @@ -58,7 +58,7 @@ def test_normalize_agent_card_endpoint_requires_absolute_url() -> None: def test_build_resolver_http_kwargs_uses_bearer_token() -> None: assert build_resolver_http_kwargs(bearer_token="peer-token", timeout=7) == { "timeout": 7, - "headers": {"Authorization": "Bearer peer-token"}, + "headers": {"A2A-Version": "1.0", "Authorization": "Bearer peer-token"}, } @@ -71,12 +71,12 @@ def test_build_resolver_http_kwargs_uses_basic_auth() -> None: timeout=7, ) == { "timeout": 7, - "headers": {"Authorization": f"Basic {encoded}"}, + "headers": {"A2A-Version": "1.0", "Authorization": f"Basic {encoded}"}, } def test_map_agent_card_error_http_variant() -> None: - mapped = map_agent_card_error(A2AClientHTTPError(401, "unauthorized")) + mapped = map_agent_card_error(FakeA2AClientHTTPError(401, "unauthorized")) assert isinstance(mapped, A2AAuthenticationError) assert mapped.http_status == 401 diff --git a/tests/client/test_client_config.py b/tests/client/test_client_config.py index 5e9963b..04daaf6 100644 --- a/tests/client/test_client_config.py +++ b/tests/client/test_client_config.py @@ -20,7 +20,6 @@ def test_load_settings_from_mapping() -> None: "A2A_CLIENT_USE_CLIENT_PREFERENCE": "true", "A2A_CLIENT_BEARER_TOKEN": "peer-token", "A2A_CLIENT_BASIC_AUTH": "user:pass", - "A2A_CLIENT_PROTOCOL_VERSION": "1.0.0", "A2A_CLIENT_SUPPORTED_TRANSPORTS": "json-rpc,http-json", "A2A_CLIENT_POLLING_FALLBACK_ENABLED": "true", "A2A_CLIENT_POLLING_FALLBACK_INITIAL_INTERVAL_SECONDS": "0.75", @@ -36,7 +35,6 @@ def test_load_settings_from_mapping() -> None: assert settings.use_client_preference is True assert settings.bearer_token == "peer-token" assert settings.basic_auth == "user:pass" - assert settings.protocol_version == "1.0" assert settings.supported_transports == ("JSONRPC", "HTTP+JSON") assert settings.polling_fallback_enabled is True assert settings.polling_fallback_initial_interval_seconds == 0.75 @@ -68,12 +66,6 @@ def test_load_settings_accepts_base64_basic_auth() -> None: assert settings == A2AClientSettings(basic_auth=encoded) -def test_load_settings_can_fallback_to_general_protocol_version() -> None: - settings = load_settings({"A2A_PROTOCOL_VERSION": "0.3.0"}) - - assert settings.protocol_version == "0.3" - - def test_load_settings_invalid_basic_auth_raises() -> None: with pytest.raises(ValueError, match="username:password"): load_settings({"A2A_CLIENT_BASIC_AUTH": "not-basic-auth"}) diff --git a/tests/client/test_client_facade.py b/tests/client/test_client_facade.py index 1a5ed3e..4be5b74 100644 --- a/tests/client/test_client_facade.py +++ b/tests/client/test_client_facade.py @@ -7,8 +7,7 @@ import httpx import pytest from a2a.client import ClientConfig -from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError, A2AClientJSONRPCError -from a2a.types import JSONRPCError, JSONRPCErrorResponse, Task, TaskState, TaskStatus +from a2a.types import Message, Part, Role, StreamResponse, Task, TaskState, TaskStatus from opencode_a2a.client import A2AClient from opencode_a2a.client import client as client_module @@ -18,6 +17,12 @@ A2ATimeoutError, A2AUnsupportedOperationError, ) +from opencode_a2a.jsonrpc.models import JSONRPCError, JSONRPCErrorResponse +from tests.support.fake_client_errors import ( + FakeA2AClientHTTPError, + FakeA2AClientJSONError, + FakeA2AClientJSONRPCError, +) class _FakeCardResolver: @@ -33,12 +38,12 @@ async def get_agent_card(self, **_kwargs: object) -> object: class _FakeClient: def __init__( self, - events: list[object] | None = None, + events: list[StreamResponse] | None = None, *, fail: BaseException | None = None, - task_results: list[object] | None = None, + task_results: list[Task] | None = None, task_fail: BaseException | None = None, - ): + ) -> None: self._events = list(events or []) self._fail = fail self._task_results = list(task_results or []) @@ -46,16 +51,18 @@ def __init__( self.send_message_inputs: list[tuple[object, object, object]] = [] self.task_inputs: list[tuple[object, object]] = [] self.cancel_inputs: list[tuple[object, object]] = [] - self.resubscribe_inputs: list[tuple[object, object]] = [] + self.subscribe_inputs: list[tuple[object, object]] = [] - async def send_message(self, message, *args: object, **kwargs: object) -> AsyncIterator[object]: + async def send_message( + self, message, *args: object, **kwargs: object + ) -> AsyncIterator[StreamResponse]: self.send_message_inputs.append((message, args, kwargs)) if self._fail: raise self._fail for event in self._events: yield event - async def get_task(self, params, *args: object, **kwargs: object) -> object: + async def get_task(self, params, *args: object, **kwargs: object) -> Task: self.task_inputs.append((params, kwargs)) if self._task_fail: raise self._task_fail @@ -63,16 +70,18 @@ async def get_task(self, params, *args: object, **kwargs: object) -> object: raise self._fail if self._task_results: return self._task_results.pop(0) - return {"task_id": params.id} + return _task(params.id, TaskState.TASK_STATE_COMPLETED) - async def cancel_task(self, params, *args: object, **kwargs: object) -> object: + async def cancel_task(self, params, *args: object, **kwargs: object) -> Task: self.cancel_inputs.append((params, kwargs)) if self._fail: raise self._fail - return {"task_id": params.id, "status": "canceled"} + return _task(params.id, TaskState.TASK_STATE_CANCELED) - async def resubscribe(self, params, *args: object, **kwargs: object) -> AsyncIterator[object]: - self.resubscribe_inputs.append((params, kwargs)) + async def subscribe( + self, params, *args: object, **kwargs: object + ) -> AsyncIterator[StreamResponse]: + self.subscribe_inputs.append((params, kwargs)) if self._fail: raise self._fail for event in self._events: @@ -87,6 +96,20 @@ def _task(task_id: str, state: TaskState) -> Task: ) +def _stream_message(text: str) -> StreamResponse: + return StreamResponse( + message=Message( + message_id=f"msg-{text}", + role=Role.ROLE_AGENT, + parts=[Part(text=text)], + ) + ) + + +def _stream_task(task_id: str, state: TaskState) -> StreamResponse: + return StreamResponse(task=_task(task_id, state)) + + @pytest.mark.asyncio async def test_get_agent_card_cached_and_reused(monkeypatch: pytest.MonkeyPatch) -> None: resolver = _FakeCardResolver("agent-card") @@ -117,40 +140,28 @@ async def test_build_client_uses_settings_and_transport_config( ) fake_sdk_client = _FakeClient() - factory_calls: dict[str, object] = {} - - class _FakeFactory: - def __init__(self, config: ClientConfig, consumers: list[object] | None = None): - factory_calls["config"] = config - factory_calls["consumers"] = consumers - - def create( - self, - _card: object, - consumers: list[object] | None = None, - interceptors: list[object] | None = None, - extensions: list[str] | None = None, - ) -> _FakeClient: - factory_calls["create_consumers"] = consumers - factory_calls["interceptors"] = interceptors - factory_calls["extensions"] = extensions - return fake_sdk_client - - monkeypatch.setattr(client_module, "ClientFactory", _FakeFactory) - monkeypatch.setattr( - client_module, - "build_agent_card_resolver", - lambda *_args: _FakeCardResolver("agent-card"), - ) + create_calls: dict[str, object] = {} + + async def _fake_create_client(agent, *, client_config, resolver_http_kwargs, **kwargs): # noqa: ANN001 + create_calls["agent"] = agent + create_calls["config"] = client_config + create_calls["resolver_http_kwargs"] = resolver_http_kwargs + create_calls["extra_kwargs"] = kwargs + return fake_sdk_client + + monkeypatch.setattr(client_module, "create_client", _fake_create_client) actual = await client._build_client() - config = factory_calls["config"] + config = create_calls["config"] assert isinstance(config, ClientConfig) assert config.streaming is True assert config.polling is False assert config.use_client_preference is True assert config.supported_protocol_bindings == ["HTTP+JSON"] - assert factory_calls["interceptors"] is None + assert create_calls["resolver_http_kwargs"] == { + "timeout": 3, + "headers": {"A2A-Version": "1.0", "Authorization": "Bearer peer-token"}, + } assert actual is fake_sdk_client @@ -163,32 +174,18 @@ async def test_build_client_enables_sdk_polling_when_polling_fallback_enabled( settings=A2AClientSettings(polling_fallback_enabled=True), ) fake_sdk_client = _FakeClient() - factory_calls: dict[str, object] = {} - - class _FakeFactory: - def __init__(self, config: ClientConfig, consumers: list[object] | None = None): - factory_calls["config"] = config - factory_calls["consumers"] = consumers - - def create( - self, - _card: object, - consumers: list[object] | None = None, - interceptors: list[object] | None = None, - extensions: list[str] | None = None, - ) -> _FakeClient: - return fake_sdk_client - - monkeypatch.setattr(client_module, "ClientFactory", _FakeFactory) - monkeypatch.setattr( - client_module, - "build_agent_card_resolver", - lambda *_args: _FakeCardResolver("agent-card"), - ) + create_calls: dict[str, object] = {} + + async def _fake_create_client(agent, *, client_config, resolver_http_kwargs, **kwargs): # noqa: ANN001 + del agent, resolver_http_kwargs, kwargs + create_calls["config"] = client_config + return fake_sdk_client + + monkeypatch.setattr(client_module, "create_client", _fake_create_client) actual = await client._build_client() - config = factory_calls["config"] + config = create_calls["config"] assert isinstance(config, ClientConfig) assert config.polling is True assert actual is fake_sdk_client @@ -197,10 +194,14 @@ def create( @pytest.mark.asyncio async def test_send_returns_last_event(monkeypatch: pytest.MonkeyPatch) -> None: client = A2AClient("http://agent.example.com") - fake_client = _FakeClient(events=["a", "b", "last"]) + fake_client = _FakeClient( + events=[_stream_message("a"), _stream_message("b"), _stream_message("last")] + ) monkeypatch.setattr(A2AClient, "_build_client", AsyncMock(return_value=fake_client)) response = await client.send("hello") - assert response == "last" + assert isinstance(response, StreamResponse) + assert response.HasField("message") + assert response.message.parts[0].text == "last" @pytest.mark.asyncio @@ -216,7 +217,7 @@ async def test_send_polling_fallback_returns_terminal_task(monkeypatch: pytest.M ), ) fake_client = _FakeClient( - events=[(_task("task-1", TaskState.TASK_STATE_WORKING), None)], + events=[_stream_task("task-1", TaskState.TASK_STATE_WORKING)], task_results=[ _task("task-1", TaskState.TASK_STATE_WORKING), _task("task-1", TaskState.TASK_STATE_COMPLETED), @@ -232,7 +233,9 @@ async def _fake_sleep(delay: float) -> None: response = await client.send("hello") - assert response == (_task("task-1", TaskState.TASK_STATE_COMPLETED), None) + assert isinstance(response, StreamResponse) + assert response.HasField("task") + assert response.task.status.state == TaskState.TASK_STATE_COMPLETED assert [params.id for params, _kwargs in fake_client.task_inputs] == ["task-1", "task-1"] assert sleep_calls == [0.1, 0.2] @@ -243,13 +246,14 @@ async def test_send_polling_fallback_skips_input_required(monkeypatch: pytest.Mo "http://agent.example.com", settings=A2AClientSettings(polling_fallback_enabled=True), ) - event = (_task("task-1", TaskState.TASK_STATE_INPUT_REQUIRED), None) - fake_client = _FakeClient(events=[event]) + fake_client = _FakeClient(events=[_stream_task("task-1", TaskState.TASK_STATE_INPUT_REQUIRED)]) monkeypatch.setattr(A2AClient, "_build_client", AsyncMock(return_value=fake_client)) response = await client.send("hello") - assert response == event + assert isinstance(response, StreamResponse) + assert response.HasField("task") + assert response.task.status.state == TaskState.TASK_STATE_INPUT_REQUIRED assert fake_client.task_inputs == [] @@ -266,7 +270,7 @@ async def test_send_polling_fallback_timeout_raises(monkeypatch: pytest.MonkeyPa ), ) fake_client = _FakeClient( - events=[(_task("task-1", TaskState.TASK_STATE_WORKING), None)], + events=[_stream_task("task-1", TaskState.TASK_STATE_WORKING)], task_results=[_task("task-1", TaskState.TASK_STATE_WORKING)], ) now_values = iter([0.0, 0.0, 0.3]) @@ -297,8 +301,8 @@ async def test_send_polling_fallback_maps_get_task_error( ), ) fake_client = _FakeClient( - events=[(_task("task-1", TaskState.TASK_STATE_WORKING), None)], - task_fail=A2AClientHTTPError(404, "gone"), + events=[_stream_task("task-1", TaskState.TASK_STATE_WORKING)], + task_fail=FakeA2AClientHTTPError(404, "gone"), ) async def _fake_sleep(_delay: float) -> None: @@ -307,7 +311,7 @@ async def _fake_sleep(_delay: float) -> None: monkeypatch.setattr(A2AClient, "_build_client", AsyncMock(return_value=fake_client)) monkeypatch.setattr(client, "_sleep", _fake_sleep) - with pytest.raises(A2AUnsupportedOperationError, match="does not support tasks/get"): + with pytest.raises(A2AUnsupportedOperationError, match="does not support GetTask"): await client.send("hello") @@ -319,12 +323,13 @@ async def test_send_message_adds_bearer_token_from_settings( "http://agent.example.com", settings=A2AClientSettings(bearer_token="peer-token"), ) - fake_client = _FakeClient(events=["ok"]) + fake_client = _FakeClient(events=[_stream_message("ok")]) monkeypatch.setattr(A2AClient, "_build_client", AsyncMock(return_value=fake_client)) result = [event async for event in client.send_message("hello")] - assert result == ["ok"] + assert len(result) == 1 + assert result[0].HasField("message") _, _, kwargs = fake_client.send_message_inputs[0] assert kwargs["request_metadata"] is None assert kwargs["context"] is not None @@ -339,12 +344,13 @@ async def test_send_message_adds_basic_auth_from_settings( "http://agent.example.com", settings=A2AClientSettings(basic_auth="user:pass"), ) - fake_client = _FakeClient(events=["ok"]) + fake_client = _FakeClient(events=[_stream_message("ok")]) monkeypatch.setattr(A2AClient, "_build_client", AsyncMock(return_value=fake_client)) result = [event async for event in client.send_message("hello")] - assert result == ["ok"] + assert len(result) == 1 + assert result[0].HasField("message") _, _, kwargs = fake_client.send_message_inputs[0] assert kwargs["request_metadata"] is None assert kwargs["context"] is not None @@ -361,7 +367,7 @@ async def test_send_message_preserves_explicit_authorization_metadata( "http://agent.example.com", settings=A2AClientSettings(bearer_token="peer-token"), ) - fake_client = _FakeClient(events=["ok"]) + fake_client = _FakeClient(events=[_stream_message("ok")]) monkeypatch.setattr(A2AClient, "_build_client", AsyncMock(return_value=fake_client)) result = [ @@ -372,18 +378,43 @@ async def test_send_message_preserves_explicit_authorization_metadata( ) ] - assert result == ["ok"] + assert len(result) == 1 + assert result[0].HasField("message") _, _, kwargs = fake_client.send_message_inputs[0] assert kwargs["request_metadata"] == {"trace_id": "trace-1"} assert kwargs["context"].state["headers"]["Authorization"] == "Bearer explicit-token" +@pytest.mark.asyncio +async def test_send_message_negotiates_extensions_via_service_parameters( + monkeypatch: pytest.MonkeyPatch, +) -> None: + client = A2AClient("http://agent.example.com") + fake_client = _FakeClient(events=[_stream_message("ok")]) + monkeypatch.setattr(A2AClient, "_build_client", AsyncMock(return_value=fake_client)) + + result = [ + event + async for event in client.send_message( + "hello", + metadata={"A2A-Extensions": "https://example.com/ext-b"}, + extensions=["https://example.com/ext-a"], + ) + ] + + assert len(result) == 1 + payload, _, kwargs = fake_client.send_message_inputs[0] + assert kwargs["context"].service_parameters == { + "A2A-Extensions": "https://example.com/ext-a,https://example.com/ext-b" + } + + @pytest.mark.asyncio async def test_send_message_prefers_explicit_authorization_without_default_token( monkeypatch: pytest.MonkeyPatch, ) -> None: client = A2AClient("http://agent.example.com") - fake_client = _FakeClient(events=["ok"]) + fake_client = _FakeClient(events=[_stream_message("ok")]) monkeypatch.setattr(A2AClient, "_build_client", AsyncMock(return_value=fake_client)) result = [ @@ -393,7 +424,8 @@ async def test_send_message_prefers_explicit_authorization_without_default_token ) ] - assert result == ["ok"] + assert len(result) == 1 + assert result[0].HasField("message") _, _, kwargs = fake_client.send_message_inputs[0] assert kwargs["request_metadata"] is None assert kwargs["context"].state["headers"]["Authorization"] == "Bearer explicit-token" @@ -404,11 +436,11 @@ async def test_send_message_maps_jsonrpc_not_supported( monkeypatch: pytest.MonkeyPatch, ) -> None: rpc_error = JSONRPCErrorResponse( - error=JSONRPCError(code=-32601, message="Unsupported method: message/send"), + error=JSONRPCError(code=-32601, message="Unsupported method: SendMessage"), id="req-1", ) client = A2AClient("http://agent.example.com") - fake_client = _FakeClient(fail=A2AClientJSONRPCError(rpc_error)) + fake_client = _FakeClient(fail=FakeA2AClientJSONRPCError(rpc_error)) monkeypatch.setattr(A2AClient, "_build_client", AsyncMock(return_value=fake_client)) with pytest.raises( A2AUnsupportedOperationError, @@ -422,7 +454,7 @@ async def test_send_message_maps_jsonrpc_not_supported( async def test_get_agent_card_maps_json_error(monkeypatch: pytest.MonkeyPatch) -> None: class _BrokenResolver: async def get_agent_card(self, **_kwargs: object) -> object: - raise A2AClientJSONError("invalid json") + raise FakeA2AClientJSONError("invalid json") client = A2AClient("http://agent.example.com") monkeypatch.setattr( @@ -462,7 +494,10 @@ async def get_agent_card(self, **kwargs: object) -> object: assert resolver_http_kwargs == { "http_kwargs": { "timeout": 7, - "headers": {"Authorization": f"Basic {b64encode(b'user:pass').decode()}"}, + "headers": { + "A2A-Version": "1.0", + "Authorization": f"Basic {b64encode(b'user:pass').decode()}", + }, } } @@ -503,6 +538,25 @@ async def test_get_task_uses_authorization_header_context( assert kwargs["request_metadata"] == {"trace_id": "trace-1"} +@pytest.mark.asyncio +async def test_get_task_negotiates_extensions_via_service_parameters( + monkeypatch: pytest.MonkeyPatch, +) -> None: + client = A2AClient("http://agent.example.com") + fake_client = _FakeClient() + monkeypatch.setattr(A2AClient, "_build_client", AsyncMock(return_value=fake_client)) + + await client.get_task( + "task-id", + extensions=["https://example.com/ext-b", "https://example.com/ext-a"], + ) + + _params, kwargs = fake_client.task_inputs[0] + assert kwargs["context"].service_parameters == { + "A2A-Extensions": "https://example.com/ext-a,https://example.com/ext-b" + } + + @pytest.mark.asyncio async def test_cancel_task_uses_authorization_header_context( monkeypatch: pytest.MonkeyPatch, @@ -526,40 +580,48 @@ async def test_get_task_maps_transport_http_error( monkeypatch: pytest.MonkeyPatch, ) -> None: client = A2AClient("http://agent.example.com") - fake_client = _FakeClient(fail=A2AClientHTTPError(404, "gone")) + fake_client = _FakeClient(fail=FakeA2AClientHTTPError(404, "gone")) monkeypatch.setattr(A2AClient, "_build_client", AsyncMock(return_value=fake_client)) - with pytest.raises(A2AUnsupportedOperationError, match="does not support tasks/get"): + with pytest.raises(A2AUnsupportedOperationError, match="does not support GetTask"): await client.get_task("task-id") @pytest.mark.asyncio -async def test_resubscribe_forward_events(monkeypatch: pytest.MonkeyPatch) -> None: +async def test_subscribe_to_task_forwards_events(monkeypatch: pytest.MonkeyPatch) -> None: client = A2AClient("http://agent.example.com") - fake_client = _FakeClient(events=[1, 2]) + fake_client = _FakeClient( + events=[ + _stream_task("task-id", TaskState.TASK_STATE_WORKING), + _stream_task("task-id", TaskState.TASK_STATE_COMPLETED), + ] + ) monkeypatch.setattr(A2AClient, "_build_client", AsyncMock(return_value=fake_client)) - result = [event async for event in client.resubscribe_task("task-id")] - assert result == [1, 2] + result = [event async for event in client.subscribe_to_task("task-id")] + assert [event.task.status.state for event in result] == [ + TaskState.TASK_STATE_WORKING, + TaskState.TASK_STATE_COMPLETED, + ] @pytest.mark.asyncio -async def test_resubscribe_uses_authorization_header_context( +async def test_subscribe_to_task_uses_authorization_header_context( monkeypatch: pytest.MonkeyPatch, ) -> None: client = A2AClient("http://agent.example.com") - fake_client = _FakeClient(events=[1]) + fake_client = _FakeClient(events=[_stream_task("task-id", TaskState.TASK_STATE_WORKING)]) monkeypatch.setattr(A2AClient, "_build_client", AsyncMock(return_value=fake_client)) result = [ event - async for event in client.resubscribe_task( + async for event in client.subscribe_to_task( "task-id", metadata={"authorization": "Bearer explicit-token", "trace_id": "trace-1"}, ) ] - assert result == [1] - params, kwargs = fake_client.resubscribe_inputs[0] + assert [event.task.status.state for event in result] == [TaskState.TASK_STATE_WORKING] + params, kwargs = fake_client.subscribe_inputs[0] assert params.id == "task-id" assert kwargs["context"].state["headers"]["Authorization"] == "Bearer explicit-token" assert kwargs["request_metadata"] == {"trace_id": "trace-1"} diff --git a/tests/client/test_error_mapping.py b/tests/client/test_error_mapping.py index 08827cd..d2039c0 100644 --- a/tests/client/test_error_mapping.py +++ b/tests/client/test_error_mapping.py @@ -1,8 +1,6 @@ from __future__ import annotations import httpx -from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError, A2AClientJSONRPCError -from a2a.types import JSONRPCError, JSONRPCErrorResponse from opencode_a2a.client.error_mapping import ( map_agent_card_error, @@ -19,22 +17,28 @@ A2ATimeoutError, A2AUnsupportedOperationError, ) +from opencode_a2a.jsonrpc.models import JSONRPCError, JSONRPCErrorResponse +from tests.support.fake_client_errors import ( + FakeA2AClientHTTPError, + FakeA2AClientJSONError, + FakeA2AClientJSONRPCError, +) def test_map_jsonrpc_error_variants() -> None: - invalid_params_error = A2AClientJSONRPCError( + invalid_params_error = FakeA2AClientJSONRPCError( JSONRPCErrorResponse( error=JSONRPCError(code=-32602, message="bad params"), id="req-1", ) ) - internal_error = A2AClientJSONRPCError( + internal_error = FakeA2AClientJSONRPCError( JSONRPCErrorResponse( error=JSONRPCError(code=-32603, message="internal"), id="req-2", ) ) - generic_error = A2AClientJSONRPCError( + generic_error = FakeA2AClientJSONRPCError( JSONRPCErrorResponse( error=JSONRPCError(code=-32000, message="generic"), id="req-3", @@ -53,11 +57,11 @@ def test_map_jsonrpc_error_variants() -> None: def test_map_http_error_variants() -> None: - auth_failed = map_http_error("message/send", A2AClientHTTPError(401, "denied")) - permission_denied = map_http_error("message/send", A2AClientHTTPError(403, "forbidden")) - unsupported = map_http_error("message/send", A2AClientHTTPError(405, "nope")) - reset = map_http_error("message/send", A2AClientHTTPError(503, "busy")) - unavailable = map_http_error("message/send", A2AClientHTTPError(500, "boom")) + auth_failed = map_http_error("SendMessage", FakeA2AClientHTTPError(401, "denied")) + permission_denied = map_http_error("SendMessage", FakeA2AClientHTTPError(403, "forbidden")) + unsupported = map_http_error("SendMessage", FakeA2AClientHTTPError(405, "nope")) + reset = map_http_error("SendMessage", FakeA2AClientHTTPError(503, "busy")) + unavailable = map_http_error("SendMessage", FakeA2AClientHTTPError(500, "boom")) assert isinstance(auth_failed, A2AAuthenticationError) assert isinstance(permission_denied, A2APermissionDeniedError) @@ -67,15 +71,15 @@ def test_map_http_error_variants() -> None: def test_map_operation_error_transport_and_timeout_variants() -> None: - timeout = map_operation_error("message/send", httpx.ReadTimeout("timed out")) - unavailable = map_operation_error("message/send", httpx.ConnectError("down")) + timeout = map_operation_error("SendMessage", httpx.ReadTimeout("timed out")) + unavailable = map_operation_error("SendMessage", httpx.ConnectError("down")) assert isinstance(timeout, A2ATimeoutError) assert isinstance(unavailable, A2AAgentUnavailableError) def test_map_agent_card_error_json_variant() -> None: - mapped = map_agent_card_error(A2AClientJSONError("invalid json")) + mapped = map_agent_card_error(FakeA2AClientJSONError("invalid json")) assert isinstance(mapped, A2APeerProtocolError) assert mapped.error_code == "invalid_agent_card" diff --git a/tests/client/test_payload_text.py b/tests/client/test_payload_text.py index a49132a..580a234 100644 --- a/tests/client/test_payload_text.py +++ b/tests/client/test_payload_text.py @@ -3,34 +3,30 @@ from a2a.types import ( Artifact, Message, + Part, Role, + StreamResponse, Task, TaskArtifactUpdateEvent, TaskState, TaskStatus, ) -from opencode_a2a.a2a_utils import make_text_part from opencode_a2a.client.payload_text import extract_text def test_extract_text_prefers_stream_artifact_payload() -> None: - task = Task( - id="remote-task", - context_id="remote-context", - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) update = TaskArtifactUpdateEvent( task_id="remote-task", context_id="remote-context", artifact=Artifact( artifact_id="artifact-1", name="response", - parts=[make_text_part("streamed remote text")], + parts=[Part(text="streamed remote text")], ), ) - assert extract_text((task, update)) == "streamed remote text" + assert extract_text(StreamResponse(artifact_update=update)) == "streamed remote text" def test_extract_text_reads_task_status_message() -> None: @@ -42,7 +38,7 @@ def test_extract_text_reads_task_status_message() -> None: message=Message( role=Role.ROLE_AGENT, message_id="m1", - parts=[make_text_part("status message text")], + parts=[Part(text="status message text")], ), ), ) diff --git a/tests/client/test_request_context.py b/tests/client/test_request_context.py index 03ff3e7..1d0ea67 100644 --- a/tests/client/test_request_context.py +++ b/tests/client/test_request_context.py @@ -14,10 +14,10 @@ def test_split_request_metadata_and_default_headers() -> None: - request_metadata, extra_headers = split_request_metadata( + request_metadata, extra_headers, requested_extensions = split_request_metadata( { "authorization": "Bearer explicit-token", - "A2A-Version": "1.0.0", + "A2A-Extensions": "https://example.com/ext-b, https://example.com/ext-a", "traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01", "tracestate": "vendor=value", "trace_id": "trace-1", @@ -27,23 +27,35 @@ def test_split_request_metadata_and_default_headers() -> None: assert request_metadata == {"trace_id": "trace-1"} assert extra_headers == { "Authorization": "Bearer explicit-token", - "A2A-Version": "1.0", "traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01", "tracestate": "vendor=value", } - assert build_default_headers("peer-token") == {"Authorization": "Bearer peer-token"} + assert requested_extensions == ( + "https://example.com/ext-b", + "https://example.com/ext-a", + ) + assert build_default_headers("peer-token") == { + "Authorization": "Bearer peer-token", + "A2A-Version": "1.0", + } def test_build_default_headers_encodes_basic_auth_credentials() -> None: encoded = b64encode(b"user:pass").decode() - assert build_default_headers(None, "user:pass") == {"Authorization": f"Basic {encoded}"} + assert build_default_headers(None, "user:pass") == { + "Authorization": f"Basic {encoded}", + "A2A-Version": "1.0", + } def test_build_default_headers_accepts_pre_encoded_basic_auth() -> None: encoded = b64encode(b"user:pass").decode() - assert build_default_headers(None, encoded) == {"Authorization": f"Basic {encoded}"} + assert build_default_headers(None, encoded) == { + "Authorization": f"Basic {encoded}", + "A2A-Version": "1.0", + } def test_build_default_headers_rejects_invalid_basic_auth() -> None: @@ -53,19 +65,24 @@ def test_build_default_headers_rejects_invalid_basic_auth() -> None: def test_build_default_headers_prefers_bearer_over_basic_auth() -> None: assert build_default_headers("peer-token", "user:pass") == { - "Authorization": "Bearer peer-token" + "Authorization": "Bearer peer-token", + "A2A-Version": "1.0", } -def test_build_default_headers_includes_protocol_version() -> None: - assert build_default_headers("peer-token", protocol_version="1.0.0") == { +def test_build_default_headers_always_include_fixed_protocol_version() -> None: + assert build_default_headers("peer-token") == { "Authorization": "Bearer peer-token", "A2A-Version": "1.0", } -def test_build_call_context_without_headers_returns_none() -> None: - assert build_call_context(None, None) is None +def test_build_call_context_includes_fixed_protocol_version() -> None: + context = build_call_context(None, None) + + assert context is not None + assert context.state["headers"] == {"A2A-Version": "1.0"} + assert context.state["http_kwargs"]["headers"] == {"A2A-Version": "1.0"} def test_build_call_context_includes_current_trace_headers() -> None: @@ -80,6 +97,7 @@ def test_build_call_context_includes_current_trace_headers() -> None: assert context is not None assert context.state["headers"] == { + "A2A-Version": "1.0", "traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01", "tracestate": "vendor=value", } @@ -104,6 +122,7 @@ def test_build_call_context_preserves_explicit_trace_headers_over_current_contex assert context is not None assert context.state["headers"] == { "Authorization": "Bearer peer-token", + "A2A-Version": "1.0", "traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01", "tracestate": "vendor=value", } @@ -115,9 +134,34 @@ def test_build_call_context_carries_default_headers_without_interceptor_layer() assert isinstance(context, ClientCallContext) assert context.state["headers"] == { "Authorization": "Bearer peer-token", + "A2A-Version": "1.0", "X-Trace-Id": "trace-1", } assert context.state["http_kwargs"]["headers"] == { "Authorization": "Bearer peer-token", + "A2A-Version": "1.0", "X-Trace-Id": "trace-1", } + + +def test_build_call_context_merges_extension_service_parameters() -> None: + context = build_call_context( + "peer-token", + {"X-Trace-Id": "trace-1"}, + ("https://example.com/ext-b", "https://example.com/ext-a"), + ) + + assert isinstance(context, ClientCallContext) + assert context.service_parameters == { + "A2A-Extensions": "https://example.com/ext-a,https://example.com/ext-b" + } + + +def test_split_request_metadata_rejects_protocol_version_override() -> None: + with pytest.raises(ValueError, match="must not be overridden"): + split_request_metadata({"A2A-Version": "1.0"}) + + +def test_split_request_metadata_rejects_non_string_extensions_header() -> None: + with pytest.raises(ValueError, match="A2A-Extensions metadata header must be a string"): + split_request_metadata({"A2A-Extensions": ["https://example.com/ext-a"]}) diff --git a/tests/config/test_settings.py b/tests/config/test_settings.py index 1e15c00..8d11e53 100644 --- a/tests/config/test_settings.py +++ b/tests/config/test_settings.py @@ -7,6 +7,10 @@ from opencode_a2a import __version__ from opencode_a2a.config import Settings +from opencode_a2a.protocol_versions import ( + A2A_PROTOCOL_VERSION, + A2A_SUPPORTED_PROTOCOL_VERSIONS, +) def test_settings_missing_required(): @@ -85,31 +89,8 @@ def test_settings_valid(): assert settings.a2a_task_store_backend == "database" assert settings.a2a_task_store_database_url == "sqlite+aiosqlite:///./opencode-a2a.db" assert settings.a2a_version == __version__ - assert settings.a2a_protocol_version == "0.3" - assert settings.a2a_supported_protocol_versions == ("0.3", "1.0") - - -def test_settings_normalize_protocol_versions() -> None: - env = { - "A2A_STATIC_AUTH_CREDENTIALS": json.dumps( - [ - { - "scheme": "bearer", - "token": "test-token", - "principal": "automation", - } - ] - ), - "A2A_PROTOCOL_VERSION": "0.3.0", - "A2A_SUPPORTED_PROTOCOL_VERSIONS": "0.3.0,1.0.0,1.0", - "A2A_CLIENT_PROTOCOL_VERSION": "1.0.0", - } - with mock.patch.dict(os.environ, env, clear=True): - settings = Settings() - - assert settings.a2a_protocol_version == "0.3" - assert settings.a2a_supported_protocol_versions == ("0.3", "1.0") - assert settings.a2a_client_protocol_version == "1.0" + assert A2A_PROTOCOL_VERSION == "1.0" + assert A2A_SUPPORTED_PROTOCOL_VERSIONS == ("1.0",) def test_settings_allow_explicit_memory_backend() -> None: diff --git a/tests/conftest.py b/tests/conftest.py index c76b42a..423b389 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,168 +1,12 @@ from __future__ import annotations import asyncio -import sys -import types from collections.abc import Generator -from dataclasses import dataclass import pytest from sqlalchemy.ext.asyncio import AsyncEngine -def _install_legacy_test_shims() -> None: - import a2a.types as legacy_types - from a2a.client import errors as client_errors - from a2a.server.request_handlers.default_request_handler import LegacyRequestHandler - from a2a.server.routes.jsonrpc_dispatcher import JsonRpcDispatcher - from a2a.server.routes.rest_dispatcher import RestDispatcher - from a2a.types import GetTaskRequest, SendMessageConfiguration, SendMessageRequest - from a2a.utils import errors as protocol_errors - from a2a.utils.constants import TransportProtocol - - from opencode_a2a.jsonrpc.models import JSONRPCError, JSONRPCErrorResponse - - class A2AClientHTTPError(client_errors.A2AClientError): - def __init__(self, status_code: int, message: str): - self.status_code = status_code - super().__init__(f"HTTP Error {status_code}: {message}") - - class A2AClientJSONError(client_errors.A2AClientError): - pass - - class A2AClientJSONRPCError(client_errors.A2AClientError): - def __init__(self, response: object): - self.response = response - error = getattr(response, "error", None) - message = getattr(error, "message", "JSON-RPC error") - super().__init__(message) - - class MessageSendConfiguration: - def __new__(cls, *args, **kwargs): - del args - accepted_output_modes = kwargs.pop("acceptedOutputModes", None) - if accepted_output_modes is not None: - kwargs["accepted_output_modes"] = accepted_output_modes - return SendMessageConfiguration(**kwargs) - - @dataclass - class TextPart: - text: str - - @dataclass - class DataPart: - data: object - - @dataclass - class FileWithBytes: - bytes: str - mimeType: str | None = None - name: str | None = None - - @dataclass - class FileWithUri: - uri: str - mimeType: str | None = None - name: str | None = None - - @dataclass - class FilePart: - file: FileWithBytes | FileWithUri - - class ServerError(Exception): - def __init__(self, error: Exception): - self.error = error - super().__init__(str(error)) - - class RESTAdapter: - def __init__(self, *, agent_card, http_handler, context_builder=None): - del agent_card - self._dispatcher = RestDispatcher( - request_handler=http_handler, - context_builder=context_builder, - ) - - def routes(self) -> dict[tuple[str, str], object]: - return { - ("/v1/message:send", "POST"): self._dispatcher.on_message_send, - ("/v1/message:stream", "POST"): self._dispatcher.on_message_send_stream, - ("/v1/tasks/{id}:cancel", "POST"): self._dispatcher.on_cancel_task, - ("/v1/tasks/{id}:subscribe", "GET"): self._dispatcher.on_subscribe_to_task, - ("/v1/tasks/{id}:subscribe", "POST"): self._dispatcher.on_subscribe_to_task, - ("/v1/tasks/{id}", "GET"): self._dispatcher.on_get_task, - ( - "/v1/tasks/{id}/pushNotificationConfigs/{push_id}", - "GET", - ): self._dispatcher.get_push_notification, - ( - "/v1/tasks/{id}/pushNotificationConfigs/{push_id}", - "DELETE", - ): self._dispatcher.delete_push_notification, - ( - "/v1/tasks/{id}/pushNotificationConfigs", - "POST", - ): self._dispatcher.set_push_notification, - ( - "/v1/tasks/{id}/pushNotificationConfigs", - "GET", - ): self._dispatcher.list_push_notifications, - ("/agent/authenticatedExtendedCard", "GET"): ( - self._dispatcher.handle_authenticated_agent_card - ), - } - - client_errors.A2AClientHTTPError = A2AClientHTTPError - client_errors.A2AClientJSONError = A2AClientJSONError - client_errors.A2AClientJSONRPCError = A2AClientJSONRPCError - - legacy_types.A2AError = protocol_errors.A2AError - legacy_types.InvalidParamsError = protocol_errors.InvalidParamsError - legacy_types.UnsupportedOperationError = protocol_errors.UnsupportedOperationError - legacy_types.JSONRPCError = JSONRPCError - legacy_types.JSONRPCErrorResponse = JSONRPCErrorResponse - legacy_types.MessageSendConfiguration = MessageSendConfiguration - legacy_types.MessageSendParams = SendMessageRequest - legacy_types.TaskIdParams = GetTaskRequest - legacy_types.TaskQueryParams = GetTaskRequest - legacy_types.TextPart = TextPart - legacy_types.DataPart = DataPart - legacy_types.FilePart = FilePart - legacy_types.FileWithBytes = FileWithBytes - legacy_types.FileWithUri = FileWithUri - legacy_types.TransportProtocol = TransportProtocol - - protocol_errors.ServerError = ServerError - - apps_module = types.ModuleType("a2a.server.apps") - jsonrpc_module = types.ModuleType("a2a.server.apps.jsonrpc") - fastapi_app_module = types.ModuleType("a2a.server.apps.jsonrpc.fastapi_app") - jsonrpc_app_module = types.ModuleType("a2a.server.apps.jsonrpc.jsonrpc_app") - rest_module = types.ModuleType("a2a.server.apps.rest") - rest_adapter_module = types.ModuleType("a2a.server.apps.rest.rest_adapter") - - fastapi_app_module.A2AFastAPIApplication = JsonRpcDispatcher - fastapi_app_module.A2AFastAPI = JsonRpcDispatcher - jsonrpc_app_module.JSONRPCApplication = JsonRpcDispatcher - jsonrpc_app_module.DefaultCallContextBuilder = object - rest_adapter_module.RESTAdapter = RESTAdapter - - sys.modules.setdefault("a2a.server.apps", apps_module) - sys.modules.setdefault("a2a.server.apps.jsonrpc", jsonrpc_module) - sys.modules.setdefault("a2a.server.apps.jsonrpc.fastapi_app", fastapi_app_module) - sys.modules.setdefault("a2a.server.apps.jsonrpc.jsonrpc_app", jsonrpc_app_module) - sys.modules.setdefault("a2a.server.apps.rest", rest_module) - sys.modules.setdefault("a2a.server.apps.rest.rest_adapter", rest_adapter_module) - - default_handler_module = sys.modules.get("a2a.server.request_handlers.default_request_handler") - if default_handler_module is not None and not hasattr( - default_handler_module, "DefaultRequestHandler" - ): - default_handler_module.DefaultRequestHandler = LegacyRequestHandler - - -_install_legacy_test_shims() - - @pytest.fixture(autouse=True) def dispose_app_database_engines(monkeypatch: pytest.MonkeyPatch) -> Generator[None]: import opencode_a2a.server.application as app_module diff --git a/tests/contracts/test_extension_contract_consistency.py b/tests/contracts/test_extension_contract_consistency.py index dbf13e2..e1a2102 100644 --- a/tests/contracts/test_extension_contract_consistency.py +++ b/tests/contracts/test_extension_contract_consistency.py @@ -2,9 +2,19 @@ import pytest from opencode_a2a.contracts.extensions import ( + COMPATIBILITY_PROFILE_EXTENSION_URI, + INTERRUPT_CALLBACK_EXTENSION_URI, INTERRUPT_CALLBACK_METHODS, + INTERRUPT_RECOVERY_EXTENSION_URI, + MODEL_SELECTION_EXTENSION_URI, + PROVIDER_DISCOVERY_EXTENSION_URI, + SESSION_BINDING_EXTENSION_URI, + SESSION_MANAGEMENT_EXTENSION_URI, SESSION_QUERY_DEFAULT_LIMIT, SESSION_QUERY_MAX_LIMIT, + STREAMING_EXTENSION_URI, + WIRE_CONTRACT_EXTENSION_URI, + WORKSPACE_CONTROL_EXTENSION_URI, build_capability_snapshot, build_compatibility_profile_params, build_interrupt_callback_extension_params, @@ -17,26 +27,16 @@ build_wire_contract_params, build_workspace_control_extension_params, ) -from opencode_a2a.jsonrpc.application import SESSION_CONTEXT_PREFIX +from opencode_a2a.jsonrpc.methods import SESSION_CONTEXT_PREFIX from opencode_a2a.profile.runtime import build_runtime_profile +from opencode_a2a.protocol_versions import A2A_PROTOCOL_VERSION from opencode_a2a.server.agent_card import build_authenticated_extended_agent_card -from opencode_a2a.server.application import ( - COMPATIBILITY_PROFILE_EXTENSION_URI, - INTERRUPT_CALLBACK_EXTENSION_URI, - INTERRUPT_RECOVERY_EXTENSION_URI, - MODEL_SELECTION_EXTENSION_URI, - PROVIDER_DISCOVERY_EXTENSION_URI, - SESSION_BINDING_EXTENSION_URI, - SESSION_MANAGEMENT_EXTENSION_URI, - STREAMING_EXTENSION_URI, - WIRE_CONTRACT_EXTENSION_URI, - WORKSPACE_CONTROL_EXTENSION_URI, - create_app, -) +from opencode_a2a.server.application import create_app from tests.support.helpers import ( DummySessionQueryOpencodeUpstreamClient as DummyOpencodeUpstreamClient, ) from tests.support.helpers import make_settings +from tests.support.session_extensions import _extension_headers def test_extension_ssot_matches_agent_card_contracts() -> None: @@ -81,16 +81,12 @@ def test_extension_ssot_matches_agent_card_contracts() -> None: runtime_profile=runtime_profile, ) expected_compatibility_profile = build_compatibility_profile_params( - protocol_version=settings.a2a_protocol_version, + protocol_version=A2A_PROTOCOL_VERSION, runtime_profile=runtime_profile, - supported_protocol_versions=settings.a2a_supported_protocol_versions, - default_protocol_version=settings.a2a_protocol_version, ) expected_wire_contract = build_wire_contract_params( - protocol_version=settings.a2a_protocol_version, + protocol_version=A2A_PROTOCOL_VERSION, runtime_profile=runtime_profile, - supported_protocol_versions=settings.a2a_supported_protocol_versions, - default_protocol_version=settings.a2a_protocol_version, ) assert session_binding.params == expected_session_binding, ( @@ -175,16 +171,12 @@ def test_openapi_jsonrpc_contract_extension_matches_ssot() -> None: runtime_profile=runtime_profile, ) expected_compatibility_profile = build_compatibility_profile_params( - protocol_version=settings.a2a_protocol_version, + protocol_version=A2A_PROTOCOL_VERSION, runtime_profile=runtime_profile, - supported_protocol_versions=settings.a2a_supported_protocol_versions, - default_protocol_version=settings.a2a_protocol_version, ) expected_wire_contract = build_wire_contract_params( - protocol_version=settings.a2a_protocol_version, + protocol_version=A2A_PROTOCOL_VERSION, runtime_profile=runtime_profile, - supported_protocol_versions=settings.a2a_supported_protocol_versions, - default_protocol_version=settings.a2a_protocol_version, ) assert session_binding == expected_session_binding, ( @@ -280,7 +272,7 @@ async def test_runtime_supported_methods_align_with_capability_snapshot( runtime_profile = build_runtime_profile(settings) capability_snapshot = build_capability_snapshot(runtime_profile=runtime_profile) wire_contract = build_wire_contract_params( - protocol_version=settings.a2a_protocol_version, + protocol_version=A2A_PROTOCOL_VERSION, runtime_profile=runtime_profile, ) transport = httpx.ASGITransport(app=app) @@ -294,8 +286,8 @@ async def test_runtime_supported_methods_align_with_capability_snapshot( assert response.status_code == 200 error = response.json()["error"] - assert error["data"]["supported_methods"] == capability_snapshot.supported_jsonrpc_methods() - assert error["data"]["supported_methods"] == wire_contract["all_jsonrpc_methods"] + assert error["data"]["supportedMethods"] == capability_snapshot.supported_jsonrpc_methods() + assert error["data"]["supportedMethods"] == wire_contract["all_jsonrpc_methods"] @pytest.mark.asyncio @@ -408,7 +400,7 @@ async def test_extension_notification_contracts_return_204( async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: response = await client.post( "/", - headers={"Authorization": "Bearer t-1"}, + headers=_extension_headers({"Authorization": "Bearer t-1"}), json={"jsonrpc": "2.0", "method": method, "params": params}, ) assert response.status_code == 204 diff --git a/tests/execution/test_agent_errors.py b/tests/execution/test_agent_errors.py index b81ec95..8d116d2 100644 --- a/tests/execution/test_agent_errors.py +++ b/tests/execution/test_agent_errors.py @@ -6,8 +6,8 @@ import pytest from a2a.server.events.event_queue import EventQueue from a2a.types import Task, TaskArtifactUpdateEvent, TaskState, TaskStatusUpdateEvent +from google.protobuf.json_format import MessageToDict -from opencode_a2a.a2a_utils import proto_to_dict from opencode_a2a.execution.executor import OpencodeAgentExecutor from opencode_a2a.opencode_upstream_client import ( OpencodeMessage, @@ -36,7 +36,7 @@ def _is_terminal_status_event(event: TaskStatusUpdateEvent) -> bool: def _metadata_dict(metadata) -> dict: # noqa: ANN001 if metadata is None: return {} - return proto_to_dict(metadata) + return MessageToDict(metadata) @pytest.mark.asyncio diff --git a/tests/execution/test_agent_helpers.py b/tests/execution/test_agent_helpers.py index 11180fa..2724c9a 100644 --- a/tests/execution/test_agent_helpers.py +++ b/tests/execution/test_agent_helpers.py @@ -6,10 +6,7 @@ import pytest from a2a.types import TaskState -from opencode_a2a.execution.executor import ( - BlockType, - _await_stream_terminal_signal, - _build_output_metadata, +from opencode_a2a.execution.stream_events import ( _build_progress_identity, _coerce_number, _extract_event_session_id, @@ -20,21 +17,28 @@ _extract_stream_snapshot_text, _extract_stream_terminal_signal, _extract_token_usage, - _extract_upstream_error_detail, _extract_upstream_error_from_event, _extract_upstream_error_from_response, - _format_inband_upstream_error, - _format_stream_terminal_error, - _format_upstream_error, - _merge_token_usage, _normalize_interrupt_question_options, _normalize_interrupt_questions, _normalize_role, _preview_log_value, - _resolve_upstream_error_profile, +) +from opencode_a2a.execution.stream_state import ( + BlockType, + _build_output_metadata, + _merge_token_usage, _StreamOutputState, _TTLCache, ) +from opencode_a2a.execution.upstream_error_translator import ( + _await_stream_terminal_signal, + _extract_upstream_error_detail, + _format_inband_upstream_error, + _format_stream_terminal_error, + _format_upstream_error, + _resolve_upstream_error_profile, +) from opencode_a2a.opencode_upstream_client import UpstreamContractError diff --git a/tests/execution/test_metrics.py b/tests/execution/test_metrics.py index 7283055..57aead6 100644 --- a/tests/execution/test_metrics.py +++ b/tests/execution/test_metrics.py @@ -6,9 +6,18 @@ import pytest from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore -from a2a.types import Message, Role, SendMessageRequest, Task, TaskState, TaskStatus +from a2a.types import ( + AgentCapabilities, + AgentCard, + Message, + Part, + Role, + SendMessageRequest, + Task, + TaskState, + TaskStatus, +) -from opencode_a2a.a2a_utils import make_text_part from opencode_a2a.execution.executor import OpencodeAgentExecutor, _StreamOutputState from opencode_a2a.server.application import OpencodeRequestHandler from tests.support.helpers import DummyEventQueue, make_settings @@ -19,11 +28,15 @@ def _make_message_send_params() -> SendMessageRequest: message=Message( message_id="msg-user-1", role=Role.ROLE_USER, - parts=[make_text_part("hello")], + parts=[Part(text="hello")], ) ) +def _agent_card() -> AgentCard: + return AgentCard(name="opencode-a2a", capabilities=AgentCapabilities(streaming=True)) + + @pytest.mark.asyncio async def test_stream_request_metrics_track_total_and_active(caplog) -> None: class _FakeAggregator: @@ -50,7 +63,11 @@ async def _cleanup_producer(self, producer_task, task_id): # noqa: ANN001 except asyncio.CancelledError: pass - handler = _TestHandler(agent_executor=MagicMock(), task_store=InMemoryTaskStore()) + handler = _TestHandler( + agent_executor=MagicMock(), + task_store=InMemoryTaskStore(), + agent_card=_agent_card(), + ) with caplog.at_level(logging.DEBUG, logger="opencode_a2a.execution.executor"): stream = handler.on_message_send_stream(_make_message_send_params()) diff --git a/tests/execution/test_multipart_input.py b/tests/execution/test_multipart_input.py index 566791a..2b38e3d 100644 --- a/tests/execution/test_multipart_input.py +++ b/tests/execution/test_multipart_input.py @@ -1,6 +1,7 @@ import pytest -from a2a.types import DataPart, FilePart, FileWithBytes, FileWithUri, TaskState, TextPart +from a2a.types import Part, TaskState +from opencode_a2a.a2a_utils import make_data_part from opencode_a2a.execution.executor import OpencodeAgentExecutor from opencode_a2a.opencode_upstream_client import OpencodeMessage from tests.support.helpers import DummyEventQueue, make_request_context_with_parts, make_settings @@ -82,14 +83,8 @@ async def test_execute_forwards_text_and_file_parts() -> None: task_id="task-1", context_id="ctx-1", parts=[ - TextPart(text="Describe this file"), - FilePart( - file=FileWithBytes( - bytes="aGVsbG8=", - mimeType="text/plain", - name="note.txt", - ) - ), + Part(text="Describe this file"), + Part(raw=b"hello", filename="note.txt", media_type="text/plain"), ], ) @@ -126,12 +121,10 @@ async def test_execute_accepts_file_only_input() -> None: task_id="task-1", context_id="ctx-1", parts=[ - FilePart( - file=FileWithUri( - uri="file:///tmp/report.pdf", - mimeType="application/pdf", - name="report.pdf", - ) + Part( + url="file:///tmp/report.pdf", + filename="report.pdf", + media_type="application/pdf", ) ], ) @@ -159,7 +152,7 @@ async def test_execute_rejects_data_parts() -> None: context = make_request_context_with_parts( task_id="task-1", context_id="ctx-1", - parts=[DataPart(data={"kind": "json", "value": 1})], + parts=[make_data_part({"kind": "json", "value": 1})], ) await executor.execute(context, queue) @@ -167,7 +160,7 @@ async def test_execute_rejects_data_parts() -> None: assert client.sent_calls == [] task = queue.events[-1] assert task.status.state == TaskState.TASK_STATE_FAILED - assert "DataPart input is not supported" in ( + assert "structured data is not supported" in ( getattr(task.status.message.parts[0], "text", None) or getattr(getattr(task.status.message.parts[0], "root", None), "text", "") ) diff --git a/tests/execution/test_opencode_agent_session_binding.py b/tests/execution/test_opencode_agent_session_binding.py index 708774b..6fb189c 100644 --- a/tests/execution/test_opencode_agent_session_binding.py +++ b/tests/execution/test_opencode_agent_session_binding.py @@ -6,18 +6,16 @@ import httpx import pytest -from a2a.client.errors import A2AClientHTTPError, A2AClientJSONRPCError from a2a.types import ( Artifact, - JSONRPCError, - JSONRPCErrorResponse, + Part, + StreamResponse, Task, TaskArtifactUpdateEvent, TaskState, TaskStatus, ) -from opencode_a2a.a2a_utils import make_text_part from opencode_a2a.client import A2AClient from opencode_a2a.client.errors import ( A2AClientResetRequiredError, @@ -33,9 +31,14 @@ maybe_handle_tools, merge_streamed_tool_output, ) +from opencode_a2a.jsonrpc.models import JSONRPCError, JSONRPCErrorResponse from opencode_a2a.opencode_upstream_client import OpencodeMessage from opencode_a2a.server.client_manager import A2AClientManager from opencode_a2a.trace_context import TraceContext, bind_trace_context +from tests.support.fake_client_errors import ( + FakeA2AClientHTTPError, + FakeA2AClientJSONRPCError, +) from tests.support.helpers import ( DummyChatOpencodeUpstreamClient, DummyEventQueue, @@ -287,8 +290,6 @@ async def test_agent_handles_a2a_call_tool(monkeypatch) -> None: Artifact, Task, TaskArtifactUpdateEvent, - TaskState, - TaskStatus, ) class MockA2AClient: @@ -308,7 +309,7 @@ async def send_message(self, text: str): artifact=Artifact( artifact_id="artifact-1", name="response", - parts=[make_text_part(f"remote response to {text}")], + parts=[Part(text=f"remote response to {text}")], ), ), ) @@ -400,7 +401,7 @@ async def _send_message(_text: str): artifact=Artifact( artifact_id="artifact-1", name="response", - parts=[make_text_part("streamed tool output")], + parts=[Part(text="streamed tool output")], ), ), ) @@ -422,8 +423,6 @@ def borrow_client(self, url: str): Artifact, Task, TaskArtifactUpdateEvent, - TaskState, - TaskStatus, ) client = ToolLoopClient() @@ -493,7 +492,7 @@ def __aiter__(self): return self async def __anext__(self): - raise A2AClientHTTPError(401, "unauthorized") + raise FakeA2AClientHTTPError(401, "unauthorized") class MockA2AClient: def send_message(self, text: str): @@ -542,21 +541,16 @@ async def test_agent_a2a_call_uses_server_side_basic_auth_headers( ) -> None: fake_sdk_client = _FakeOutboundClient( events=[ - ( - Task( - id="remote-task", - context_id="remote-ctx", - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ), - TaskArtifactUpdateEvent( + StreamResponse( + artifact_update=TaskArtifactUpdateEvent( task_id="remote-task", context_id="remote-ctx", artifact=Artifact( artifact_id="artifact-1", name="response", - parts=[make_text_part("remote response")], + parts=[Part(text="remote response")], ), - ), + ) ) ] ) @@ -569,8 +563,6 @@ async def test_agent_a2a_call_uses_server_side_basic_auth_headers( a2a_client_use_client_preference=False, a2a_client_bearer_token=None, a2a_client_basic_auth="user:pass", - a2a_client_protocol_version=None, - a2a_protocol_version="0.3", a2a_client_supported_transports=("JSONRPC", "HTTP+JSON"), a2a_client_cache_ttl_seconds=60.0, a2a_client_cache_maxsize=1, @@ -610,21 +602,16 @@ async def test_agent_a2a_call_propagates_current_trace_headers( ) -> None: fake_sdk_client = _FakeOutboundClient( events=[ - ( - Task( - id="remote-task", - context_id="remote-ctx", - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ), - TaskArtifactUpdateEvent( + StreamResponse( + artifact_update=TaskArtifactUpdateEvent( task_id="remote-task", context_id="remote-ctx", artifact=Artifact( artifact_id="artifact-1", name="response", - parts=[make_text_part("remote response")], + parts=[Part(text="remote response")], ), - ), + ) ) ] ) @@ -637,8 +624,6 @@ async def test_agent_a2a_call_propagates_current_trace_headers( a2a_client_use_client_preference=False, a2a_client_bearer_token=None, a2a_client_basic_auth=None, - a2a_client_protocol_version=None, - a2a_protocol_version="0.3", a2a_client_supported_transports=("JSONRPC", "HTTP+JSON"), a2a_client_cache_ttl_seconds=60.0, a2a_client_cache_maxsize=1, @@ -681,7 +666,7 @@ async def test_agent_a2a_call_propagates_current_trace_headers( def test_map_a2a_tool_exception_protocol_and_unavailable_variants() -> None: - rpc_error = A2AClientJSONRPCError( + rpc_error = FakeA2AClientJSONRPCError( JSONRPCErrorResponse( error=JSONRPCError(code=-32602, message="bad params"), id="req-1", @@ -717,7 +702,7 @@ def test_map_a2a_tool_exception_additional_variants() -> None: class _FakeOutboundClient: - def __init__(self, events: list[object]) -> None: + def __init__(self, events: list[StreamResponse]) -> None: self._events = list(events) self.send_message_inputs: list[tuple[object, object, object]] = [] diff --git a/tests/execution/test_session_ownership.py b/tests/execution/test_session_ownership.py index c5a0158..1ec5100 100644 --- a/tests/execution/test_session_ownership.py +++ b/tests/execution/test_session_ownership.py @@ -6,8 +6,9 @@ from a2a.server.events.event_queue import EventQueue from a2a.types import TaskState -from opencode_a2a.execution.executor import OpencodeAgentExecutor, _TTLCache +from opencode_a2a.execution.executor import OpencodeAgentExecutor from opencode_a2a.execution.session_manager import SessionManager +from opencode_a2a.execution.stream_state import _TTLCache from opencode_a2a.opencode_upstream_client import OpencodeUpstreamClient from opencode_a2a.server.state_store import ( DatabaseSessionStateRepository, diff --git a/tests/execution/test_streaming_output_contract_blocks.py b/tests/execution/test_streaming_output_contract_blocks.py index dfb3d49..cd8bb9a 100644 --- a/tests/execution/test_streaming_output_contract_blocks.py +++ b/tests/execution/test_streaming_output_contract_blocks.py @@ -1,6 +1,5 @@ import pytest -from opencode_a2a.a2a_utils import part_kind from opencode_a2a.execution.executor import ( OpencodeAgentExecutor, ) @@ -120,7 +119,7 @@ async def test_streaming_emits_structured_tool_part_updates() -> None: assert [payload["status"] for payload in payloads] == ["pending", "running", "completed"] assert all(payload["call_id"] == "call-1" for payload in payloads) assert all(payload["tool"] == "bash" for payload in payloads) - assert all(part_kind(ev.artifact.parts[0]) == "data" for ev in tool_updates) + assert all(ev.artifact.parts[0].HasField("data") for ev in tool_updates) @pytest.mark.asyncio @@ -160,7 +159,7 @@ async def test_streaming_downgrades_structured_tool_updates_when_json_output_not tool_updates = [ev for ev in updates if _artifact_stream_meta(ev)["block_type"] == "tool_call"] assert len(tool_updates) == 1 assert _part_text(tool_updates[0]) == '{"call_id":"call-1","status":"running","tool":"bash"}' - assert part_kind(tool_updates[0].artifact.parts[0]) == "text" + assert tool_updates[0].artifact.parts[0].HasField("text") assert any(_artifact_stream_meta(ev)["block_type"] == "text" for ev in updates) diff --git a/tests/execution/test_streaming_output_contract_core.py b/tests/execution/test_streaming_output_contract_core.py index f02485c..d8c82a9 100644 --- a/tests/execution/test_streaming_output_contract_core.py +++ b/tests/execution/test_streaming_output_contract_core.py @@ -3,20 +3,15 @@ import pytest from a2a.types import ( - FilePart, - FileWithUri, + Part, Task, TaskState, TaskStatusUpdateEvent, ) -from opencode_a2a.execution.executor import ( - BlockType, - OpencodeAgentExecutor, - _extract_token_usage, - _extract_tool_part_payload, - _StreamOutputState, -) +from opencode_a2a.execution.executor import OpencodeAgentExecutor +from opencode_a2a.execution.stream_events import _extract_token_usage, _extract_tool_part_payload +from opencode_a2a.execution.stream_state import BlockType, _StreamOutputState from opencode_a2a.task_states import TERMINAL_TASK_STATES from tests.support.helpers import ( DummyEventQueue, @@ -56,12 +51,10 @@ async def test_streaming_accepts_file_input_without_breaking_contract() -> None: task_id="task-1", context_id="ctx-1", parts=[ - FilePart( - file=FileWithUri( - uri="file:///tmp/report.pdf", - mimeType="application/pdf", - name="report.pdf", - ) + Part( + url="file:///tmp/report.pdf", + filename="report.pdf", + media_type="application/pdf", ) ], call_context=SimpleNamespace(state={"a2a_streaming_request": True}), diff --git a/tests/jsonrpc/test_application_dispatch.py b/tests/jsonrpc/test_application_dispatch.py new file mode 100644 index 0000000..75d03f7 --- /dev/null +++ b/tests/jsonrpc/test_application_dispatch.py @@ -0,0 +1,390 @@ +from __future__ import annotations + +from collections.abc import AsyncIterator +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +import httpx +import pytest +from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent +from a2a.utils.errors import InvalidParamsError, UnsupportedOperationError +from fastapi import FastAPI +from fastapi.responses import JSONResponse + +import opencode_a2a.server.application as app_module +from opencode_a2a.contracts.extensions import SESSION_MANAGEMENT_EXTENSION_URI +from opencode_a2a.jsonrpc.models import JSONRPCRequest +from tests.support.helpers import DummySessionQueryOpencodeUpstreamClient, make_settings +from tests.support.session_extensions import _BASE_SETTINGS, _jsonrpc_app + + +def _build_dispatcher(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr( + app_module, + "OpencodeUpstreamClient", + DummySessionQueryOpencodeUpstreamClient, + ) + app = app_module.create_app(make_settings(test_bearer_token="test-token", **_BASE_SETTINGS)) + return _jsonrpc_app(app) + + +def _request_context() -> SimpleNamespace: + return SimpleNamespace(state={}, tenant="") + + +async def _empty_stream() -> AsyncIterator[TaskStatusUpdateEvent]: + if False: + yield TaskStatusUpdateEvent( + task_id="task-0", + context_id="ctx-0", + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + + +async def _broken_stream() -> AsyncIterator[TaskStatusUpdateEvent]: + yield TaskStatusUpdateEvent( + task_id="task-1", + context_id="ctx-1", + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + raise InvalidParamsError(message="bad stream") + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("method", "handler_name", "expected"), + [ + ("CancelTask", "_handle_cancel_task", {"ok": "cancel"}), + ("GetTask", "_handle_get_task", {"ok": "get"}), + ("ListTasks", "_handle_list_tasks", {"ok": "list"}), + ( + "CreateTaskPushNotificationConfig", + "_handle_create_task_push_notification_config", + {"ok": "create-push"}, + ), + ( + "GetTaskPushNotificationConfig", + "_handle_get_task_push_notification_config", + {"ok": "get-push"}, + ), + ( + "ListTaskPushNotificationConfigs", + "_handle_list_task_push_notification_configs", + {"ok": "list-push"}, + ), + ("DeleteTaskPushNotificationConfig", "_handle_delete_task_push_notification_config", None), + ("GetExtendedAgentCard", "_handle_get_extended_agent_card", {"ok": "extended-card"}), + ], +) +async def test_process_non_streaming_request_dispatches_sdk_methods( + monkeypatch: pytest.MonkeyPatch, + method: str, + handler_name: str, + expected: dict[str, str] | None, +) -> None: + dispatcher = _build_dispatcher(monkeypatch) + handler = AsyncMock(return_value=expected) + monkeypatch.setattr(dispatcher, handler_name, handler, raising=False) + + result = await dispatcher._process_non_streaming_request( + object(), + SimpleNamespace(state={"method": method}), + ) + + assert result == expected + handler.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_process_non_streaming_request_rejects_unknown_method( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dispatcher = _build_dispatcher(monkeypatch) + + with pytest.raises(UnsupportedOperationError, match="Method UnknownMethod is not supported."): + await dispatcher._process_non_streaming_request( + object(), + SimpleNamespace(state={"method": "UnknownMethod"}), + ) + + +@pytest.mark.asyncio +async def test_process_streaming_request_supports_subscribe_and_empty_stream( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dispatcher = _build_dispatcher(monkeypatch) + monkeypatch.setattr( + dispatcher.request_handler, + "on_subscribe_to_task", + lambda _request_obj, _context: _empty_stream(), + ) + + wrapped = await dispatcher._process_streaming_request( + 9, + object(), + SimpleNamespace(state={"method": "SubscribeToTask"}), + ) + + assert [item async for item in wrapped] == [] + + +@pytest.mark.asyncio +async def test_process_streaming_request_wraps_stream_errors( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dispatcher = _build_dispatcher(monkeypatch) + monkeypatch.setattr( + dispatcher.request_handler, + "on_message_send_stream", + lambda _request_obj, _context: _broken_stream(), + ) + + wrapped = await dispatcher._process_streaming_request( + 10, + object(), + SimpleNamespace(state={"method": "SendStreamingMessage"}), + ) + payloads = [item async for item in wrapped] + + assert payloads[0]["id"] == 10 + assert payloads[0]["result"]["statusUpdate"]["taskId"] == "task-1" + assert payloads[1]["error"]["code"] == -32602 + assert payloads[1]["error"]["message"] == "bad stream" + + +@pytest.mark.asyncio +async def test_process_streaming_request_rejects_non_stream_method( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dispatcher = _build_dispatcher(monkeypatch) + + with pytest.raises(UnsupportedOperationError, match="Stream not supported"): + await dispatcher._process_streaming_request( + 11, + object(), + SimpleNamespace(state={"method": "GetTask"}), + ) + + +@pytest.mark.asyncio +async def test_generate_protocol_error_response_supports_a2a_error_payloads( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dispatcher = _build_dispatcher(monkeypatch) + monkeypatch.setattr( + "opencode_a2a.jsonrpc.application.adapt_jsonrpc_error", + lambda _error: InvalidParamsError( + message="bad request", + data={"field": "params"}, + ), + ) + + response = dispatcher._generate_protocol_error_response( + 12, + UnsupportedOperationError(), + ) + + assert response.status_code == 200 + assert response.body == ( + b'{"jsonrpc":"2.0","id":12,"error":{"code":-32602,' + b'"message":"bad request","data":{"field":"params"}}}' + ) + + +@pytest.mark.asyncio +async def test_handle_core_request_supports_extended_card_notification_and_missing_card( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dispatcher = _build_dispatcher(monkeypatch) + + notification = JSONRPCRequest.model_validate( + {"jsonrpc": "2.0", "method": "GetExtendedAgentCard", "params": {}} + ) + response = await dispatcher._handle_core_request( + MagicMock(), + {"params": {}}, + notification, + ) + assert response.status_code == 204 + + monkeypatch.setattr(dispatcher._http_handler, "extended_agent_card", None, raising=False) + request = JSONRPCRequest.model_validate( + {"jsonrpc": "2.0", "id": 13, "method": "GetExtendedAgentCard", "params": {}} + ) + error_response = await dispatcher._handle_core_request( + MagicMock(), + {"params": {}}, + request, + ) + + assert error_response.status_code == 200 + assert b"The agent does not support authenticated extended cards" in error_response.body + + +@pytest.mark.asyncio +async def test_handle_core_request_returns_204_for_unknown_notification( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dispatcher = _build_dispatcher(monkeypatch) + base_request = JSONRPCRequest.model_validate( + {"jsonrpc": "2.0", "method": "NoSuchMethod", "params": {}} + ) + + response = await dispatcher._handle_core_request( + MagicMock(), + {"params": {}}, + base_request, + ) + + assert response.status_code == 204 + + +@pytest.mark.asyncio +async def test_handle_core_request_invalid_params_and_handler_errors( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dispatcher = _build_dispatcher(monkeypatch) + monkeypatch.setattr(dispatcher._context_builder, "build", lambda _request: _request_context()) + + base_request = JSONRPCRequest.model_validate( + {"jsonrpc": "2.0", "id": 14, "method": "GetTask", "params": {}} + ) + invalid_response = await dispatcher._handle_core_request( + MagicMock(), + {"params": "bad"}, + base_request, + ) + assert invalid_response.status_code == 200 + assert b'"code":-32602' in invalid_response.body + + monkeypatch.setattr( + dispatcher, + "_process_non_streaming_request", + AsyncMock(side_effect=InvalidParamsError(message="handler failed")), + ) + error_response = await dispatcher._handle_core_request( + MagicMock(), + {"params": {"id": "task-1"}}, + base_request, + ) + assert error_response.status_code == 200 + assert error_response.body == ( + b'{"jsonrpc":"2.0","id":14,"error":{"code":-32602,"message":"Invalid parameters"}}' + ) + + +@pytest.mark.asyncio +async def test_handle_core_request_streaming_and_non_streaming_notifications( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dispatcher = _build_dispatcher(monkeypatch) + monkeypatch.setattr(dispatcher._context_builder, "build", lambda _request: _request_context()) + streaming_result = _empty_stream() + create_response = MagicMock(return_value=JSONResponse({"stream": "ok"})) + monkeypatch.setattr( + dispatcher, + "_process_streaming_request", + AsyncMock(return_value=streaming_result), + ) + monkeypatch.setattr(dispatcher, "_create_response", create_response) + + streaming_request = JSONRPCRequest.model_validate( + { + "jsonrpc": "2.0", + "id": 15, + "method": "SendStreamingMessage", + "params": { + "message": { + "messageId": "msg-1", + "role": "ROLE_USER", + "parts": [{"text": "hello"}], + } + }, + } + ) + streaming_response = await dispatcher._handle_core_request( + MagicMock(), + { + "params": { + "message": { + "messageId": "msg-1", + "role": "ROLE_USER", + "parts": [{"text": "hello"}], + } + } + }, + streaming_request, + ) + + assert streaming_response.body == b'{"stream":"ok"}' + create_response.assert_called_once() + + monkeypatch.setattr( + dispatcher, + "_process_non_streaming_request", + AsyncMock(return_value={"ignored": True}), + ) + notification = JSONRPCRequest.model_validate( + {"jsonrpc": "2.0", "method": "GetTask", "params": {"id": "task-1"}} + ) + response = await dispatcher._handle_core_request( + MagicMock(), + {"params": {"id": "task-1"}}, + notification, + ) + + assert response.status_code == 204 + + +@pytest.mark.asyncio +async def test_handle_requests_normalizes_invalid_request_id_and_extension_params( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dispatcher = _build_dispatcher(monkeypatch) + app = FastAPI() + dispatcher.add_routes_to_app(app) + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + invalid_id = await client.post( + "/", + json={"jsonrpc": "2.0", "id": {"bad": 1}, "method": "SendMessage", "params": {}}, + ) + invalid_id_payload = invalid_id.json() + assert invalid_id_payload["id"] is None + + class _Request: + def __init__(self) -> None: + self.state = SimpleNamespace() + + async def json(self) -> dict[str, object]: + return { + "jsonrpc": "2.0", + "id": 16, + "method": "opencode.sessions.list", + "params": "bad", + } + + fake_base_request = SimpleNamespace( + id=16, + method="opencode.sessions.list", + params="bad", + ) + monkeypatch.setattr( + "opencode_a2a.jsonrpc.application.JSONRPCRequest.model_validate", + lambda _body: fake_base_request, + ) + monkeypatch.setattr( + dispatcher._context_builder, + "build", + lambda _request: SimpleNamespace( + requested_extensions={SESSION_MANAGEMENT_EXTENSION_URI}, + state={}, + tenant="", + ), + ) + + invalid_extension_response = await dispatcher.handle_requests(_Request()) + invalid_extension_payload = invalid_extension_response.body + assert b'"code":-32602' in invalid_extension_payload + assert b'"message":"Invalid parameters"' in invalid_extension_payload diff --git a/tests/jsonrpc/test_dispatch_registry.py b/tests/jsonrpc/test_dispatch_registry.py index 7405703..21f3bf9 100644 --- a/tests/jsonrpc/test_dispatch_registry.py +++ b/tests/jsonrpc/test_dispatch_registry.py @@ -1,22 +1,14 @@ -import types -from unittest.mock import AsyncMock - import httpx import pytest -from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent, UnsupportedOperationError from fastapi.responses import JSONResponse import opencode_a2a.server.application as app_module -from opencode_a2a.a2a_protocol import V1_JSONRPC_METHOD_TO_LEGACY_METHOD -from opencode_a2a.jsonrpc.application import ( - OpencodeSessionManagementJSONRPCApplication, - _normalize_core_message_part, - _normalize_core_message_payload, - _normalize_core_message_role, - _normalize_core_request_params, -) +from opencode_a2a.a2a_protocol import CORE_JSONRPC_METHODS +from opencode_a2a.contracts.extensions import SESSION_MANAGEMENT_EXTENSION_URI +from opencode_a2a.jsonrpc.application import OpencodeSessionManagementJSONRPCApplication from tests.support.helpers import DummySessionQueryOpencodeUpstreamClient, make_settings -from tests.support.session_extensions import _BASE_SETTINGS, _jsonrpc_app +from tests.support.jsonrpc_error_assertions import assert_v1_error_reason, error_context_detail +from tests.support.session_extensions import _BASE_SETTINGS, _extension_headers, _jsonrpc_app @pytest.mark.asyncio @@ -49,17 +41,14 @@ async def test_extension_registry_tracks_configured_methods(monkeypatch) -> None @pytest.mark.asyncio -async def test_core_jsonrpc_methods_delegate_to_base_app(monkeypatch) -> None: - async def _fake_core_handle(self, request, body, base_request, *, protocol_version): # noqa: ANN001 - del self, request, body, protocol_version - return JSONResponse( - { - "delegated_method": V1_JSONRPC_METHOD_TO_LEGACY_METHOD.get( - base_request.method, - base_request.method, - ) - } - ) +@pytest.mark.parametrize("method", ("SendMessage", "SendStreamingMessage", "GetTask", "CancelTask")) +async def test_core_jsonrpc_methods_delegate_to_base_app( + monkeypatch, + method: str, +) -> None: + async def _fake_core_handle(self, request, body, base_request): # noqa: ANN001 + del self, request, body + return JSONResponse({"delegated_method": base_request.method}) monkeypatch.setattr( OpencodeSessionManagementJSONRPCApplication, @@ -72,26 +61,19 @@ async def _fake_core_handle(self, request, body, base_request, *, protocol_versi async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: response = await client.post( "/", - headers={"Authorization": "Bearer test-token"}, - json={"jsonrpc": "2.0", "id": 1, "method": "message/send", "params": {}}, + headers=_extension_headers({"Authorization": "Bearer test-token"}), + json={"jsonrpc": "2.0", "id": 1, "method": method, "params": {}}, ) assert response.status_code == 200 - assert response.json() == {"delegated_method": "message/send"} + assert response.json() == {"delegated_method": method} @pytest.mark.asyncio async def test_sdk_owned_non_chat_jsonrpc_methods_delegate_to_base_app(monkeypatch) -> None: - async def _fake_core_handle(self, request, body, base_request, *, protocol_version): # noqa: ANN001 - del self, request, body, protocol_version - return JSONResponse( - { - "delegated_method": V1_JSONRPC_METHOD_TO_LEGACY_METHOD.get( - base_request.method, - base_request.method, - ) - } - ) + async def _fake_core_handle(self, request, body, base_request): # noqa: ANN001 + del self, request, body + return JSONResponse({"delegated_method": base_request.method}) monkeypatch.setattr( OpencodeSessionManagementJSONRPCApplication, @@ -104,71 +86,26 @@ async def _fake_core_handle(self, request, body, base_request, *, protocol_versi async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: response = await client.post( "/", - headers={"Authorization": "Bearer test-token"}, + headers=_extension_headers({"Authorization": "Bearer test-token"}), json={ "jsonrpc": "2.0", "id": 2, - "method": "tasks/pushNotificationConfig/get", + "method": "GetTaskPushNotificationConfig", "params": {}, }, ) assert response.status_code == 200 - assert response.json() == {"delegated_method": "tasks/pushNotificationConfig/get"} + assert response.json() == {"delegated_method": "GetTaskPushNotificationConfig"} -@pytest.mark.asyncio -@pytest.mark.parametrize( - ("alias_method", "canonical_method"), - ( - ("SendMessage", "message/send"), - ("SendStreamingMessage", "message/stream"), - ("GetTask", "tasks/get"), - ("CancelTask", "tasks/cancel"), - ("GetExtendedAgentCard", "agent/getAuthenticatedExtendedCard"), - ("GetTaskPushNotificationConfig", "tasks/pushNotificationConfig/get"), - ("ListTaskPushNotificationConfigs", "tasks/pushNotificationConfig/list"), - ("CreateTaskPushNotificationConfig", "tasks/pushNotificationConfig/set"), - ("DeleteTaskPushNotificationConfig", "tasks/pushNotificationConfig/delete"), - ), -) -async def test_v1_pascalcase_jsonrpc_aliases_delegate_to_canonical_methods( - monkeypatch, - alias_method: str, - canonical_method: str, -) -> None: - async def _fake_core_handle(self, request, body, base_request, *, protocol_version): # noqa: ANN001 - del self, request, body, protocol_version - return JSONResponse( - { - "delegated_method": V1_JSONRPC_METHOD_TO_LEGACY_METHOD.get( - base_request.method, - base_request.method, - ) - } - ) - - monkeypatch.setattr( - OpencodeSessionManagementJSONRPCApplication, - "_handle_core_request", - _fake_core_handle, - ) - app = app_module.create_app(make_settings(test_bearer_token="test-token", **_BASE_SETTINGS)) - - transport = httpx.ASGITransport(app=app) - async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - response = await client.post( - "/", - headers={ - "Authorization": "Bearer test-token", - "A2A-Version": "1.0", - }, - json={"jsonrpc": "2.0", "id": 1, "method": alias_method, "params": {}}, - ) - - assert response.status_code == 200 - assert response.headers["A2A-Version"] == "1.0" - assert response.json() == {"delegated_method": canonical_method} +def test_core_jsonrpc_methods_are_canonical_pascalcase() -> None: + assert "SendMessage" in CORE_JSONRPC_METHODS + assert "SendStreamingMessage" in CORE_JSONRPC_METHODS + assert "GetTask" in CORE_JSONRPC_METHODS + assert "CancelTask" in CORE_JSONRPC_METHODS + assert "message/send" not in CORE_JSONRPC_METHODS + assert "tasks/get" not in CORE_JSONRPC_METHODS @pytest.mark.asyncio @@ -182,8 +119,8 @@ async def test_extension_methods_stay_on_local_registry(monkeypatch) -> None: ) ) - async def _unexpected_delegate(self, request, body, base_request, *, protocol_version): # noqa: ANN001 - del self, request, body, base_request, protocol_version + async def _unexpected_delegate(self, request, body, base_request): # noqa: ANN001 + del self, request, body, base_request raise AssertionError("extension method should not delegate to base JSON-RPC app") monkeypatch.setattr(app_module, "OpencodeUpstreamClient", lambda _settings: dummy) @@ -205,7 +142,7 @@ async def _unexpected_delegate(self, request, body, base_request, *, protocol_ve async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: response = await client.post( "/", - headers={"Authorization": "Bearer test-token"}, + headers=_extension_headers({"Authorization": "Bearer test-token"}), json={ "jsonrpc": "2.0", "id": 1, @@ -218,153 +155,40 @@ async def _unexpected_delegate(self, request, body, base_request, *, protocol_ve assert response.json()["result"]["items"][0]["id"] == "s-1" -def test_core_request_normalizers_cover_v1_message_shapes() -> None: - assert _normalize_core_message_role(None) is None - assert _normalize_core_message_role("user") == "ROLE_USER" - assert _normalize_core_message_role("agent") == "ROLE_AGENT" - assert _normalize_core_message_role("ROLE_USER") == "ROLE_USER" - - assert _normalize_core_message_part("plain") == "plain" - assert _normalize_core_message_part({"kind": "text", "text": "hello"}) == {"text": "hello"} - assert _normalize_core_message_part({"type": "data", "data": {"step": 1}}) == { - "data": {"step": 1} - } - assert _normalize_core_message_part({"kind": "custom", "value": 1}) == {"value": 1} - assert _normalize_core_message_part( - { - "kind": "file", - "file": {"bytes": "aGVsbG8=", "name": "report.txt", "mimeType": "text/plain"}, - } - ) == { - "raw": "aGVsbG8=", - "filename": "report.txt", - "mediaType": "text/plain", - } - assert _normalize_core_message_part( - { - "kind": "file", - "file": { - "uri": "file:///tmp/report.txt", - "name": "report.txt", - "mediaType": "text/plain", - }, - } - ) == { - "url": "file:///tmp/report.txt", - "filename": "report.txt", - "mediaType": "text/plain", - } - assert _normalize_core_message_part({"kind": "file", "url": "https://example.com"}) == { - "url": "https://example.com" - } - - assert _normalize_core_message_payload("raw-message") == "raw-message" - assert _normalize_core_message_payload( - { - "role": "user", - "parts": [{"kind": "text", "text": "hello"}], - } - ) == { - "role": "ROLE_USER", - "parts": [{"text": "hello"}], - } - assert _normalize_core_request_params("GetTask", {"id": "task-1"}) == {"id": "task-1"} - assert _normalize_core_request_params( - "SendMessage", - { - "message": { - "role": "agent", - "parts": [{"kind": "text", "text": "hello"}], - } - }, - ) == { - "message": { - "role": "ROLE_AGENT", - "parts": [{"text": "hello"}], - } - } - - @pytest.mark.asyncio -async def test_local_core_request_processors_cover_custom_v1_bypass(monkeypatch) -> None: +async def test_extension_methods_require_explicit_a2a_extensions_header(monkeypatch) -> None: monkeypatch.setattr( app_module, "OpencodeUpstreamClient", DummySessionQueryOpencodeUpstreamClient, ) - app = app_module.create_app(make_settings(test_bearer_token="test-token", **_BASE_SETTINGS)) - jsonrpc_app = _jsonrpc_app(app) - - handlers = { - "SendMessage": "_handle_send_message", - "CancelTask": "_handle_cancel_task", - "GetTask": "_handle_get_task", - "ListTasks": "_handle_list_tasks", - "CreateTaskPushNotificationConfig": "_handle_create_task_push_notification_config", - "GetTaskPushNotificationConfig": "_handle_get_task_push_notification_config", - "ListTaskPushNotificationConfigs": "_handle_list_task_push_notification_configs", - "GetExtendedAgentCard": "_handle_get_extended_agent_card", - } - for method, attr in handlers.items(): - mock = AsyncMock(return_value={"method": method}) - monkeypatch.setattr(jsonrpc_app, attr, mock) - result = await jsonrpc_app._process_non_streaming_request( # noqa: SLF001 - object(), - types.SimpleNamespace(state={"method": method}), + app = app_module.create_app( + make_settings( + test_bearer_token="test-token", + a2a_log_payloads=False, + **_BASE_SETTINGS, ) - assert result == {"method": method} - mock.assert_awaited_once() - - delete_mock = AsyncMock(return_value=None) - monkeypatch.setattr(jsonrpc_app, "_handle_delete_task_push_notification_config", delete_mock) - result = await jsonrpc_app._process_non_streaming_request( # noqa: SLF001 - object(), - types.SimpleNamespace(state={"method": "DeleteTaskPushNotificationConfig"}), ) - assert result is None - delete_mock.assert_awaited_once() - - with pytest.raises(UnsupportedOperationError, match="Method MissingMethod is not supported"): - await jsonrpc_app._process_non_streaming_request( # noqa: SLF001 - object(), - types.SimpleNamespace(state={"method": "MissingMethod"}), - ) - async def _stream_then_error(_request_obj, _context): # noqa: ANN001 - yield TaskStatusUpdateEvent( - task_id="task-1", - context_id="ctx-1", - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.post( + "/", + headers={"Authorization": "Bearer test-token"}, + json={ + "jsonrpc": "2.0", + "id": 2, + "method": "opencode.sessions.list", + "params": {"limit": 1}, + }, ) - raise UnsupportedOperationError(message="stream failed") - async def _empty_stream(_request_obj, _context): # noqa: ANN001 - if False: # pragma: no cover - yield None - - jsonrpc_app.request_handler.on_message_send_stream = _stream_then_error - send_stream = await jsonrpc_app._process_streaming_request( # noqa: SLF001 - 7, - object(), - types.SimpleNamespace(state={"method": "SendStreamingMessage"}), - ) - send_items = [item async for item in send_stream] - assert send_items[0]["id"] == 7 - assert send_items[0]["jsonrpc"] == "2.0" - assert "result" in send_items[0] - assert send_items[1]["error"]["code"] == -32004 - - jsonrpc_app.request_handler.on_subscribe_to_task = _empty_stream - subscribe_stream = await jsonrpc_app._process_streaming_request( # noqa: SLF001 - 8, - object(), - types.SimpleNamespace(state={"method": "SubscribeToTask"}), - ) - assert [item async for item in subscribe_stream] == [] - - with pytest.raises(UnsupportedOperationError, match="Stream not supported"): - await jsonrpc_app._process_streaming_request( # noqa: SLF001 - 9, - object(), - types.SimpleNamespace(state={"method": "UnknownStream"}), - ) + assert response.status_code == 200 + error = response.json()["error"] + assert_v1_error_reason(error, reason="EXTENSION_NEGOTIATION_REQUIRED") + context = error_context_detail(error) + assert context is not None + assert context["method"] == "opencode.sessions.list" + assert context["requiredExtensions"] == [SESSION_MANAGEMENT_EXTENSION_URI] # noqa: RUF005 + assert context["requestedExtensions"] == [] + assert context["header"] == "A2A-Extensions" diff --git a/tests/jsonrpc/test_error_responses.py b/tests/jsonrpc/test_error_responses.py index fe61e8b..a7c71ae 100644 --- a/tests/jsonrpc/test_error_responses.py +++ b/tests/jsonrpc/test_error_responses.py @@ -4,7 +4,7 @@ from opencode_a2a.jsonrpc.error_responses import ( GOOGLE_RPC_ERROR_INFO_TYPE, - adapt_jsonrpc_error_for_protocol, + adapt_jsonrpc_error, authorization_forbidden_error, interrupt_not_found_error, interrupt_type_mismatch_error, @@ -22,8 +22,7 @@ def test_jsonrpc_error_mapping_helpers_preserve_business_contract_fields() -> None: unsupported = method_not_supported_error( method="unsupported.method", - supported_methods=["message/send", "tasks/get"], - protocol_version="0.3", + supported_methods=["SendMessage", "GetTask"], ) assert unsupported.code == -32601 assert unsupported.data["type"] == "METHOD_NOT_SUPPORTED" @@ -122,8 +121,8 @@ def test_invalid_error_helper_wraps_a2a_error() -> None: def test_version_not_supported_error_includes_supported_versions() -> None: error = version_not_supported_error( requested_version="2.0", - supported_protocol_versions=["0.3", "1.0"], - default_protocol_version="0.3", + supported_protocol_versions=["1.0"], + default_protocol_version="1.0", ) assert error.code == -32001 @@ -131,36 +130,33 @@ def test_version_not_supported_error_includes_supported_versions() -> None: assert error.data == { "type": "VERSION_NOT_SUPPORTED", "requested_version": "2.0", - "supported_protocol_versions": ["0.3", "1.0"], - "default_protocol_version": "0.3", + "supported_protocol_versions": ["1.0"], + "default_protocol_version": "1.0", } def test_adapt_standard_jsonrpc_error_for_v1_uses_standard_message_and_camel_case_data() -> None: - adapted = adapt_jsonrpc_error_for_protocol( - "1.0", + adapted = adapt_jsonrpc_error( method_not_supported_error( method="unsupported.method", - supported_methods=["message/send", "tasks/get"], - protocol_version="1.0", + supported_methods=["SendMessage", "GetTask"], ), ) assert adapted.message == "Method not found" assert adapted.data == { "method": "unsupported.method", - "supportedMethods": ["message/send", "tasks/get"], + "supportedMethods": ["SendMessage", "GetTask"], "protocolVersion": "1.0", } def test_adapt_a2a_specific_error_for_v1_uses_error_info_details() -> None: - adapted = adapt_jsonrpc_error_for_protocol( - "1.0", + adapted = adapt_jsonrpc_error( version_not_supported_error( requested_version="1.1", - supported_protocol_versions=["0.3", "1.0"], - default_protocol_version="0.3", + supported_protocol_versions=["1.0"], + default_protocol_version="1.0", ), ) @@ -171,23 +167,20 @@ def test_adapt_a2a_specific_error_for_v1_uses_error_info_details() -> None: "domain": "a2a-protocol.org", "metadata": { "requestedVersion": "1.1", - "supportedProtocolVersions": '["0.3","1.0"]', - "defaultProtocolVersion": "0.3", + "supportedProtocolVersions": '["1.0"]', + "defaultProtocolVersion": "1.0", }, } assert adapted.data[1] == { "@type": "type.googleapis.com/opencode_a2a.ErrorContext", "requestedVersion": "1.1", - "supportedProtocolVersions": ["0.3", "1.0"], - "defaultProtocolVersion": "0.3", + "supportedProtocolVersions": ["1.0"], + "defaultProtocolVersion": "1.0", } def test_adapt_a2a_root_error_for_v1_uses_error_type_reason() -> None: - adapted = adapt_jsonrpc_error_for_protocol( - "1.0", - UnsupportedOperationError(), - ) + adapted = adapt_jsonrpc_error(UnsupportedOperationError()) assert adapted.code == -32004 assert adapted.message == "This operation is not supported" diff --git a/tests/jsonrpc/test_jsonrpc_methods.py b/tests/jsonrpc/test_jsonrpc_methods.py index 9bb018a..4d6cebc 100644 --- a/tests/jsonrpc/test_jsonrpc_methods.py +++ b/tests/jsonrpc/test_jsonrpc_methods.py @@ -1,6 +1,6 @@ import pytest -from opencode_a2a.jsonrpc.application import ( +from opencode_a2a.jsonrpc.methods import ( _extract_provider_catalog, _normalize_model_summaries, _normalize_permission_reply, diff --git a/tests/jsonrpc/test_jsonrpc_unsupported_method.py b/tests/jsonrpc/test_jsonrpc_unsupported_method.py index 1ba9e4e..8827182 100644 --- a/tests/jsonrpc/test_jsonrpc_unsupported_method.py +++ b/tests/jsonrpc/test_jsonrpc_unsupported_method.py @@ -1,6 +1,7 @@ import httpx import pytest +from opencode_a2a.protocol_versions import A2A_PROTOCOL_VERSION from opencode_a2a.server.application import create_app from tests.support.helpers import make_settings @@ -25,15 +26,14 @@ async def test_unsupported_method_returns_unified_error() -> None: assert "error" in body error = body["error"] assert error["code"] == -32601 - assert "Unsupported method" in error["message"] + assert error["message"] == "Method not found" data = error["data"] - assert data["type"] == "METHOD_NOT_SUPPORTED" assert data["method"] == "unsupported.method" - assert "supported_methods" in data - assert "message/send" in data["supported_methods"] - assert "opencode.sessions.list" in data["supported_methods"] - assert data["protocol_version"] == settings.a2a_protocol_version + assert "supportedMethods" in data + assert "SendMessage" in data["supportedMethods"] + assert "opencode.sessions.list" in data["supportedMethods"] + assert data["protocolVersion"] == A2A_PROTOCOL_VERSION @pytest.mark.asyncio @@ -61,11 +61,11 @@ async def test_unsupported_method_uses_requested_protocol_version() -> None: "supportedMethods": body["error"]["data"]["supportedMethods"], "protocolVersion": "1.0", } - assert "message/send" in body["error"]["data"]["supportedMethods"] + assert "SendMessage" in body["error"]["data"]["supportedMethods"] @pytest.mark.asyncio -async def test_pascalcase_jsonrpc_aliases_remain_unsupported_on_v03() -> None: +async def test_sendmessage_uses_canonical_v1_method_dispatch() -> None: settings = make_settings(test_bearer_token="test-token") app = create_app(settings) transport = httpx.ASGITransport(app=app) @@ -79,9 +79,10 @@ async def test_pascalcase_jsonrpc_aliases_remain_unsupported_on_v03() -> None: assert response.status_code == 200 body = response.json() - assert body["error"]["code"] == -32601 - assert body["error"]["data"]["method"] == "SendMessage" - assert "message/send" in body["error"]["data"]["supported_methods"] + assert body["jsonrpc"] == "2.0" + assert body["id"] == 123 + assert body.get("error") is None + assert body["result"]["task"]["status"]["state"] == "TASK_STATE_FAILED" @pytest.mark.asyncio @@ -94,7 +95,7 @@ async def test_unsupported_v1_minor_version_returns_v1_error_details() -> None: response = await client.post( "/?A2A-Version=1.1", headers={"Authorization": "Bearer test-token"}, - json={"jsonrpc": "2.0", "id": 124, "method": "message/send", "params": {}}, + json={"jsonrpc": "2.0", "id": 124, "method": "SendMessage", "params": {}}, ) assert response.status_code == 200 @@ -106,8 +107,8 @@ async def test_unsupported_v1_minor_version_returns_v1_error_details() -> None: "domain": "a2a-protocol.org", "metadata": { "requestedVersion": "1.1", - "supportedProtocolVersions": '["0.3","1.0"]', - "defaultProtocolVersion": "0.3", + "supportedProtocolVersions": '["1.0"]', + "defaultProtocolVersion": "1.0", }, } @@ -122,7 +123,7 @@ async def test_unsupported_version_returns_version_error() -> None: response = await client.post( "/?A2A-Version=2.0", headers={"Authorization": "Bearer test-token"}, - json={"jsonrpc": "2.0", "id": 123, "method": "message/send", "params": {}}, + json={"jsonrpc": "2.0", "id": 123, "method": "SendMessage", "params": {}}, ) assert response.status_code == 200 @@ -130,11 +131,15 @@ async def test_unsupported_version_returns_version_error() -> None: assert body["jsonrpc"] == "2.0" assert body["id"] == 123 assert body["error"]["code"] == -32001 - assert body["error"]["data"] == { - "type": "VERSION_NOT_SUPPORTED", - "requested_version": "2.0", - "supported_protocol_versions": ["0.3", "1.0"], - "default_protocol_version": "0.3", + assert body["error"]["data"][0] == { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "VERSION_NOT_SUPPORTED", + "domain": "a2a-protocol.org", + "metadata": { + "requestedVersion": "2.0", + "supportedProtocolVersions": '["1.0"]', + "defaultProtocolVersion": "1.0", + }, } @@ -183,9 +188,8 @@ async def test_disabled_shell_reports_current_supported_methods() -> None: body = response.json() error = body["error"] assert error["code"] == -32601 - assert error["data"]["type"] == "METHOD_NOT_SUPPORTED" assert error["data"]["method"] == "opencode.sessions.shell" - assert "opencode.sessions.shell" not in error["data"]["supported_methods"] + assert "opencode.sessions.shell" not in error["data"]["supportedMethods"] @pytest.mark.asyncio @@ -218,9 +222,8 @@ async def test_policy_disabled_shell_reports_current_supported_methods() -> None body = response.json() error = body["error"] assert error["code"] == -32601 - assert error["data"]["type"] == "METHOD_NOT_SUPPORTED" assert error["data"]["method"] == "opencode.sessions.shell" - assert "opencode.sessions.shell" not in error["data"]["supported_methods"] + assert "opencode.sessions.shell" not in error["data"]["supportedMethods"] @pytest.mark.asyncio @@ -245,6 +248,5 @@ async def test_disabled_workspace_mutation_reports_current_supported_methods() - body = response.json() error = body["error"] assert error["code"] == -32601 - assert error["data"]["type"] == "METHOD_NOT_SUPPORTED" assert error["data"]["method"] == "opencode.workspaces.create" - assert "opencode.workspaces.create" not in error["data"]["supported_methods"] + assert "opencode.workspaces.create" not in error["data"]["supportedMethods"] diff --git a/tests/jsonrpc/test_opencode_session_extension_commands.py b/tests/jsonrpc/test_opencode_session_extension_commands.py index a3653c1..9a4688d 100644 --- a/tests/jsonrpc/test_opencode_session_extension_commands.py +++ b/tests/jsonrpc/test_opencode_session_extension_commands.py @@ -5,7 +5,17 @@ DummySessionQueryOpencodeUpstreamClient as DummyOpencodeUpstreamClient, ) from tests.support.helpers import make_basic_auth_header, make_settings -from tests.support.session_extensions import _BASE_SETTINGS, _jsonrpc_app, _session_meta +from tests.support.jsonrpc_error_assertions import ( + assert_v1_error_context, + assert_v1_error_metadata_contains, + assert_v1_error_reason, +) +from tests.support.session_extensions import ( + _BASE_SETTINGS, + _extension_headers, + _jsonrpc_app, + _session_meta, +) @pytest.mark.asyncio @@ -32,7 +42,7 @@ async def test_session_command_extension_success(monkeypatch): transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -110,7 +120,7 @@ async def test_session_command_extension_uses_registry_bearer_principal(monkeypa async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: owned = await client.post( "/", - headers={"Authorization": "Bearer token-build"}, + headers=_extension_headers({"Authorization": "Bearer token-build"}), json={ "jsonrpc": "2.0", "id": 32001, @@ -123,7 +133,7 @@ async def test_session_command_extension_uses_registry_bearer_principal(monkeypa ) foreign = await client.post( "/", - headers={"Authorization": "Bearer token-other"}, + headers=_extension_headers({"Authorization": "Bearer token-other"}), json={ "jsonrpc": "2.0", "id": 32002, @@ -137,7 +147,11 @@ async def test_session_command_extension_uses_registry_bearer_principal(monkeypa assert owned.status_code == 200 assert owned.json().get("error") is None - assert foreign.json()["error"]["data"]["type"] == "SESSION_FORBIDDEN" + assert_v1_error_reason( + foreign.json()["error"], + reason="SESSION_FORBIDDEN", + metadata={"session_id": "s-1"}, + ) @pytest.mark.asyncio @@ -156,7 +170,7 @@ async def test_session_command_extension_prefers_workspace_metadata(monkeypatch) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: response = await client.post( "/", - headers={"Authorization": "Bearer t-1"}, + headers=_extension_headers({"Authorization": "Bearer t-1"}), json={ "jsonrpc": "2.0", "id": 3202, @@ -188,7 +202,7 @@ async def test_session_command_extension_accepts_request_model(monkeypatch): transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -231,7 +245,7 @@ async def test_session_command_extension_rejects_invalid_params(monkeypatch): transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) missing_command = await client.post( "/", headers=headers, @@ -306,7 +320,7 @@ async def session_command(self, session_id: str, request: dict, *, directory=Non transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -322,7 +336,11 @@ async def session_command(self, session_id: str, request: dict, *, directory=Non ) payload = resp.json() assert payload["error"]["code"] == -32001 - assert payload["error"]["data"]["type"] == "SESSION_NOT_FOUND" + assert_v1_error_reason( + payload["error"], + reason="SESSION_NOT_FOUND", + metadata={"session_id": "s-404"}, + ) @pytest.mark.asyncio @@ -339,7 +357,7 @@ async def test_session_shell_extension_disabled_by_default(monkeypatch): transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -355,8 +373,8 @@ async def test_session_shell_extension_disabled_by_default(monkeypatch): ) payload = resp.json() assert payload["error"]["code"] == -32601 - assert payload["error"]["data"]["type"] == "METHOD_NOT_SUPPORTED" - assert "opencode.sessions.shell" not in payload["error"]["data"]["supported_methods"] + assert payload["error"]["data"]["method"] == "opencode.sessions.shell" + assert "opencode.sessions.shell" not in payload["error"]["data"]["supportedMethods"] assert dummy.shell_calls == [] @@ -388,7 +406,7 @@ async def test_session_shell_extension_success_when_enabled(monkeypatch): transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = make_basic_auth_header("operator", "op-pass") + headers = _extension_headers(make_basic_auth_header("operator", "op-pass")) resp = await client.post( "/", headers=headers, @@ -439,7 +457,7 @@ async def test_session_shell_extension_rejects_invalid_params(monkeypatch): transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = make_basic_auth_header("operator", "op-pass") + headers = _extension_headers(make_basic_auth_header("operator", "op-pass")) missing_agent = await client.post( "/", headers=headers, @@ -509,7 +527,7 @@ async def test_session_shell_extension_rejects_owner_mismatch(monkeypatch): transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = make_basic_auth_header("operator", "op-pass") + headers = _extension_headers(make_basic_auth_header("operator", "op-pass")) resp = await client.post( "/", headers=headers, @@ -525,7 +543,11 @@ async def test_session_shell_extension_rejects_owner_mismatch(monkeypatch): ) payload = resp.json() assert payload["error"]["code"] == -32006 - assert payload["error"]["data"]["type"] == "SESSION_FORBIDDEN" + assert_v1_error_reason( + payload["error"], + reason="SESSION_FORBIDDEN", + metadata={"session_id": "s-1"}, + ) assert dummy.shell_calls == [] @@ -581,7 +603,7 @@ async def test_session_shell_extension_requires_session_shell_capability(monkeyp async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: resp = await client.post( "/", - headers={"Authorization": "Bearer t-1"}, + headers=_extension_headers({"Authorization": "Bearer t-1"}), json={ "jsonrpc": "2.0", "id": 3341, @@ -595,12 +617,15 @@ async def test_session_shell_extension_requires_session_shell_capability(monkeyp payload = resp.json() assert payload["error"]["code"] == -32007 - assert payload["error"]["data"] == { - "type": "AUTHORIZATION_FORBIDDEN", - "method": "opencode.sessions.shell", - "capability": "session_shell", - "credential_id": "cred-bearer", - } + assert_v1_error_reason( + payload["error"], + reason="AUTHORIZATION_FORBIDDEN", + metadata={ + "method": "opencode.sessions.shell", + "capability": "session_shell", + "credential_id": "cred-bearer", + }, + ) assert dummy.shell_calls == [] @@ -652,7 +677,7 @@ async def test_session_shell_extension_accepts_registry_bearer_with_explicit_cap async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: resp = await client.post( "/", - headers={"Authorization": "Bearer token-ops"}, + headers=_extension_headers({"Authorization": "Bearer token-ops"}), json={ "jsonrpc": "2.0", "id": 3342, @@ -687,7 +712,7 @@ async def session_command(self, session_id: str, request: dict, *, directory=Non transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -703,8 +728,11 @@ async def session_command(self, session_id: str, request: dict, *, directory=Non ) payload = resp.json() assert payload["error"]["code"] == -32003 - assert payload["error"]["data"]["type"] == "UPSTREAM_HTTP_ERROR" - assert payload["error"]["data"]["upstream_status"] == 500 + assert_v1_error_metadata_contains( + payload["error"], + reason="UPSTREAM_HTTP_ERROR", + metadata={"upstream_status": 500}, + ) @pytest.mark.asyncio @@ -731,7 +759,7 @@ async def session_shell(self, session_id: str, request: dict, *, directory=None) transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = make_basic_auth_header("operator", "op-pass") + headers = _extension_headers(make_basic_auth_header("operator", "op-pass")) resp = await client.post( "/", headers=headers, @@ -747,4 +775,12 @@ async def session_shell(self, session_id: str, request: dict, *, directory=None) ) payload = resp.json() assert payload["error"]["code"] == -32002 - assert payload["error"]["data"]["type"] == "UPSTREAM_UNREACHABLE" + assert_v1_error_reason( + payload["error"], + reason="UPSTREAM_UNREACHABLE", + metadata={"method": "opencode.sessions.shell", "session_id": "s-1"}, + ) + assert_v1_error_context( + payload["error"], + metadata={"method": "opencode.sessions.shell", "session_id": "s-1"}, + ) diff --git a/tests/jsonrpc/test_opencode_session_extension_interrupts.py b/tests/jsonrpc/test_opencode_session_extension_interrupts.py index 6ec8529..7f24896 100644 --- a/tests/jsonrpc/test_opencode_session_extension_interrupts.py +++ b/tests/jsonrpc/test_opencode_session_extension_interrupts.py @@ -7,7 +7,11 @@ DummySessionQueryOpencodeUpstreamClient as DummyOpencodeUpstreamClient, ) from tests.support.helpers import make_settings -from tests.support.session_extensions import _BASE_SETTINGS +from tests.support.jsonrpc_error_assertions import ( + assert_v1_error_reason, + error_context_detail, +) +from tests.support.session_extensions import _BASE_SETTINGS, _extension_headers @pytest.mark.asyncio @@ -56,7 +60,7 @@ async def permission_reply( transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -103,7 +107,7 @@ async def test_interrupt_callback_extension_rejects_legacy_permission_fields(mon transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -132,7 +136,7 @@ async def test_interrupt_callback_extension_rejects_legacy_metadata_directory(mo transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -218,7 +222,7 @@ async def question_reject( transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) reply_resp = await client.post( "/", headers=headers, @@ -300,7 +304,7 @@ async def permission_reply( transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -313,7 +317,11 @@ async def permission_reply( ) payload = resp.json() assert payload["error"]["code"] == -32004 - assert payload["error"]["data"]["type"] == "INTERRUPT_REQUEST_NOT_FOUND" + assert_v1_error_reason( + payload["error"], + reason="INTERRUPT_REQUEST_NOT_FOUND", + metadata={"request_id": "perm-404"}, + ) @pytest.mark.asyncio @@ -332,7 +340,7 @@ async def resolve_interrupt_request(self, request_id: str): transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -345,7 +353,11 @@ async def resolve_interrupt_request(self, request_id: str): ) payload = resp.json() assert payload["error"]["code"] == -32007 - assert payload["error"]["data"]["type"] == "INTERRUPT_REQUEST_EXPIRED" + assert_v1_error_reason( + payload["error"], + reason="INTERRUPT_REQUEST_EXPIRED", + metadata={"request_id": "perm-expired"}, + ) @pytest.mark.asyncio @@ -380,7 +392,7 @@ async def permission_reply( transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -393,7 +405,11 @@ async def permission_reply( ) payload = resp.json() assert payload["error"]["code"] == -32004 - assert payload["error"]["data"]["type"] == "INTERRUPT_REQUEST_NOT_FOUND" + assert_v1_error_reason( + payload["error"], + reason="INTERRUPT_REQUEST_NOT_FOUND", + metadata={"request_id": "perm-unknown"}, + ) assert dummy.permission_reply_calls == [] @@ -419,7 +435,7 @@ class InterruptClient(DummyOpencodeUpstreamClient): transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -432,9 +448,15 @@ class InterruptClient(DummyOpencodeUpstreamClient): ) payload = resp.json() assert payload["error"]["code"] == -32008 - assert payload["error"]["data"]["type"] == "INTERRUPT_TYPE_MISMATCH" - assert payload["error"]["data"]["expected_interrupt_type"] == "permission" - assert payload["error"]["data"]["actual_interrupt_type"] == "question" + assert_v1_error_reason( + payload["error"], + reason="INTERRUPT_TYPE_MISMATCH", + metadata={ + "request_id": "q-only", + "expected_interrupt_type": "permission", + "actual_interrupt_type": "question", + }, + ) @pytest.mark.asyncio @@ -460,7 +482,7 @@ class InterruptClient(DummyOpencodeUpstreamClient): transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -473,7 +495,11 @@ class InterruptClient(DummyOpencodeUpstreamClient): ) payload = resp.json() assert payload["error"]["code"] == -32004 - assert payload["error"]["data"]["type"] == "INTERRUPT_REQUEST_NOT_FOUND" + assert_v1_error_reason( + payload["error"], + reason="INTERRUPT_REQUEST_NOT_FOUND", + metadata={"request_id": "perm-owned"}, + ) @pytest.mark.asyncio @@ -524,7 +550,7 @@ class InterruptClient(DummyOpencodeUpstreamClient): transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -537,7 +563,11 @@ class InterruptClient(DummyOpencodeUpstreamClient): ) payload = resp.json() assert payload["error"]["code"] == -32004 - assert payload["error"]["data"]["type"] == "INTERRUPT_REQUEST_NOT_FOUND" + assert_v1_error_reason( + payload["error"], + reason="INTERRUPT_REQUEST_NOT_FOUND", + metadata={"request_id": "perm-owned"}, + ) @pytest.mark.asyncio @@ -576,7 +606,7 @@ async def permission_reply( transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -589,5 +619,7 @@ async def permission_reply( ) payload = resp.json() assert payload["error"]["code"] == -32002 - assert payload["error"]["data"]["type"] == "UPSTREAM_UNREACHABLE" - assert "concurrency limit exceeded" in payload["error"]["data"]["detail"] + assert_v1_error_reason(payload["error"], reason="UPSTREAM_UNREACHABLE") + context = error_context_detail(payload["error"]) + assert context is not None + assert "concurrency limit exceeded" in context["detail"] diff --git a/tests/jsonrpc/test_opencode_session_extension_lifecycle.py b/tests/jsonrpc/test_opencode_session_extension_lifecycle.py index 2b3a750..c1788f4 100644 --- a/tests/jsonrpc/test_opencode_session_extension_lifecycle.py +++ b/tests/jsonrpc/test_opencode_session_extension_lifecycle.py @@ -7,7 +7,13 @@ DummySessionQueryOpencodeUpstreamClient as DummyOpencodeUpstreamClient, ) from tests.support.helpers import make_settings -from tests.support.session_extensions import _BASE_SETTINGS, _jsonrpc_app, _session_meta +from tests.support.jsonrpc_error_assertions import assert_v1_error_reason +from tests.support.session_extensions import ( + _BASE_SETTINGS, + _extension_headers, + _jsonrpc_app, + _session_meta, +) def _identity_for_token(token: str) -> str: @@ -29,7 +35,7 @@ async def test_session_lifecycle_status_get_and_children_success(monkeypatch): transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) status_response = await client.post( "/", headers=headers, @@ -99,7 +105,7 @@ async def test_session_lifecycle_todo_diff_and_message_get_success(monkeypatch): transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) todo_response = await client.post( "/", headers=headers, @@ -152,6 +158,7 @@ async def test_session_lifecycle_todo_diff_and_message_get_success(monkeypatch): message_item = message_response.json()["result"]["item"] assert message_item["messageId"] == "m-1" + assert message_item["role"] == "ROLE_AGENT" assert message_item["parts"][0]["text"] == "One message payload" assert _session_meta(message_item)["id"] == "s-1" @@ -193,7 +200,7 @@ async def test_session_lifecycle_mutations_succeed_and_claim_owner( async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: response = await client.post( "/", - headers={"Authorization": "Bearer t-1"}, + headers=_extension_headers({"Authorization": "Bearer t-1"}), json={"jsonrpc": "2.0", "id": 407, "method": method, "params": params}, ) @@ -234,7 +241,7 @@ async def test_session_lifecycle_summarize_succeeds_and_claims_owner(monkeypatch async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: response = await client.post( "/", - headers={"Authorization": "Bearer t-1"}, + headers=_extension_headers({"Authorization": "Bearer t-1"}), json={ "jsonrpc": "2.0", "id": 4071, @@ -289,7 +296,7 @@ async def test_session_lifecycle_mutation_rejects_owner_mismatch(monkeypatch): async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: response = await client.post( "/", - headers={"Authorization": "Bearer t-1"}, + headers=_extension_headers({"Authorization": "Bearer t-1"}), json={ "jsonrpc": "2.0", "id": 408, @@ -300,7 +307,11 @@ async def test_session_lifecycle_mutation_rejects_owner_mismatch(monkeypatch): payload = response.json() assert payload["error"]["code"] == -32006 - assert payload["error"]["data"]["type"] == "SESSION_FORBIDDEN" + assert_v1_error_reason( + payload["error"], + reason="SESSION_FORBIDDEN", + metadata={"session_id": "s-1"}, + ) assert dummy.lifecycle_calls == [] @@ -324,7 +335,7 @@ async def get_session(self, session_id: str, *, directory=None, workspace_id=Non async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: missing_message_id = await client.post( "/", - headers={"Authorization": "Bearer t-1"}, + headers=_extension_headers({"Authorization": "Bearer t-1"}), json={ "jsonrpc": "2.0", "id": 409, @@ -334,7 +345,7 @@ async def get_session(self, session_id: str, *, directory=None, workspace_id=Non ) invalid_summarize = await client.post( "/", - headers={"Authorization": "Bearer t-1"}, + headers=_extension_headers({"Authorization": "Bearer t-1"}), json={ "jsonrpc": "2.0", "id": 4091, @@ -347,7 +358,7 @@ async def get_session(self, session_id: str, *, directory=None, workspace_id=Non ) missing_revert_message_id = await client.post( "/", - headers={"Authorization": "Bearer t-1"}, + headers=_extension_headers({"Authorization": "Bearer t-1"}), json={ "jsonrpc": "2.0", "id": 4092, @@ -357,7 +368,7 @@ async def get_session(self, session_id: str, *, directory=None, workspace_id=Non ) not_found = await client.post( "/", - headers={"Authorization": "Bearer t-1"}, + headers=_extension_headers({"Authorization": "Bearer t-1"}), json={ "jsonrpc": "2.0", "id": 410, @@ -370,4 +381,8 @@ async def get_session(self, session_id: str, *, directory=None, workspace_id=Non assert invalid_summarize.json()["error"]["data"]["field"] == "request.auto" assert missing_revert_message_id.json()["error"]["data"]["field"] == "request.messageID" assert not_found.json()["error"]["code"] == -32001 - assert not_found.json()["error"]["data"]["session_id"] == "s-404" + assert_v1_error_reason( + not_found.json()["error"], + reason="SESSION_NOT_FOUND", + metadata={"session_id": "s-404"}, + ) diff --git a/tests/jsonrpc/test_opencode_session_extension_prompt_async.py b/tests/jsonrpc/test_opencode_session_extension_prompt_async.py index 65bdb79..4a5b5d8 100644 --- a/tests/jsonrpc/test_opencode_session_extension_prompt_async.py +++ b/tests/jsonrpc/test_opencode_session_extension_prompt_async.py @@ -11,7 +11,12 @@ DummySessionQueryOpencodeUpstreamClient as DummyOpencodeUpstreamClient, ) from tests.support.helpers import make_settings -from tests.support.session_extensions import _BASE_SETTINGS, _jsonrpc_app +from tests.support.jsonrpc_error_assertions import ( + assert_v1_error_metadata_contains, + assert_v1_error_reason, + error_context_detail, +) +from tests.support.session_extensions import _BASE_SETTINGS, _extension_headers, _jsonrpc_app @pytest.mark.asyncio @@ -38,7 +43,7 @@ async def test_session_prompt_async_extension_success(monkeypatch): transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -92,7 +97,7 @@ async def test_session_prompt_async_extension_accepts_subtask_parts(monkeypatch) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: response = await client.post( "/", - headers={"Authorization": "Bearer t-1"}, + headers=_extension_headers({"Authorization": "Bearer t-1"}), json={ "jsonrpc": "2.0", "id": 3012, @@ -153,7 +158,7 @@ async def test_session_prompt_async_extension_prefers_workspace_metadata(monkeyp async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: response = await client.post( "/", - headers={"Authorization": "Bearer t-1"}, + headers=_extension_headers({"Authorization": "Bearer t-1"}), json={ "jsonrpc": "2.0", "id": 3011, @@ -190,7 +195,7 @@ async def test_session_prompt_async_extension_rejects_invalid_params(monkeypatch transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) missing_session_id = await client.post( "/", @@ -293,7 +298,7 @@ async def test_session_prompt_async_extension_rejects_owner_mismatch(monkeypatch transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -309,7 +314,11 @@ async def test_session_prompt_async_extension_rejects_owner_mismatch(monkeypatch ) payload = resp.json() assert payload["error"]["code"] == -32006 - assert payload["error"]["data"]["type"] == "SESSION_FORBIDDEN" + assert_v1_error_reason( + payload["error"], + reason="SESSION_FORBIDDEN", + metadata={"session_id": "s-1"}, + ) assert dummy.prompt_async_calls == [] @@ -337,7 +346,7 @@ async def test_session_prompt_async_extension_reuses_directory_boundary_validati transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -384,7 +393,7 @@ async def test_session_prompt_async_extension_honors_directory_override_switch(m transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -423,7 +432,7 @@ async def session_prompt_async(self, session_id: str, request: dict, *, director transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -439,7 +448,11 @@ async def session_prompt_async(self, session_id: str, request: dict, *, director ) payload = resp.json() assert payload["error"]["code"] == -32001 - assert payload["error"]["data"]["type"] == "SESSION_NOT_FOUND" + assert_v1_error_reason( + payload["error"], + reason="SESSION_NOT_FOUND", + metadata={"session_id": "s-404"}, + ) @pytest.mark.asyncio @@ -460,7 +473,7 @@ async def session_prompt_async(self, session_id: str, request: dict, *, director transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -476,7 +489,7 @@ async def session_prompt_async(self, session_id: str, request: dict, *, director ) payload = resp.json() assert payload["error"]["code"] == -32005 - assert payload["error"]["data"]["type"] == "UPSTREAM_PAYLOAD_ERROR" + assert_v1_error_reason(payload["error"], reason="UPSTREAM_PAYLOAD_ERROR") @pytest.mark.asyncio @@ -497,7 +510,7 @@ async def session_prompt_async(self, session_id: str, request: dict, *, director transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -513,8 +526,11 @@ async def session_prompt_async(self, session_id: str, request: dict, *, director ) payload = resp.json() assert payload["error"]["code"] == -32003 - assert payload["error"]["data"]["type"] == "UPSTREAM_HTTP_ERROR" - assert payload["error"]["data"]["upstream_status"] == 500 + assert_v1_error_metadata_contains( + payload["error"], + reason="UPSTREAM_HTTP_ERROR", + metadata={"upstream_status": 500}, + ) @pytest.mark.asyncio @@ -534,7 +550,7 @@ async def session_prompt_async(self, session_id: str, request: dict, *, director transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -550,7 +566,7 @@ async def session_prompt_async(self, session_id: str, request: dict, *, director ) payload = resp.json() assert payload["error"]["code"] == -32002 - assert payload["error"]["data"]["type"] == "UPSTREAM_UNREACHABLE" + assert_v1_error_reason(payload["error"], reason="UPSTREAM_UNREACHABLE") @pytest.mark.asyncio @@ -577,7 +593,7 @@ async def _release_raises(self: SessionManager, *, identity: str, session_id: st transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -593,7 +609,7 @@ async def _release_raises(self: SessionManager, *, identity: str, session_id: st ) payload = resp.json() assert payload["error"]["code"] == -32002 - assert payload["error"]["data"]["type"] == "UPSTREAM_UNREACHABLE" + assert_v1_error_reason(payload["error"], reason="UPSTREAM_UNREACHABLE") assert any( "Failed to release pending session claim" in record.message for record in caplog.records @@ -620,7 +636,7 @@ async def session_prompt_async(self, session_id: str, request: dict, *, director transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -636,8 +652,10 @@ async def session_prompt_async(self, session_id: str, request: dict, *, director ) payload = resp.json() assert payload["error"]["code"] == -32002 - assert payload["error"]["data"]["type"] == "UPSTREAM_UNREACHABLE" - assert "concurrency limit exceeded" in payload["error"]["data"]["detail"] + assert_v1_error_reason(payload["error"], reason="UPSTREAM_UNREACHABLE") + context = error_context_detail(payload["error"]) + assert context is not None + assert "concurrency limit exceeded" in context["detail"] @pytest.mark.asyncio @@ -654,7 +672,7 @@ async def test_session_prompt_async_extension_notification_returns_204(monkeypat transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, diff --git a/tests/jsonrpc/test_opencode_session_extension_queries.py b/tests/jsonrpc/test_opencode_session_extension_queries.py index 8c1a5d9..64199a3 100644 --- a/tests/jsonrpc/test_opencode_session_extension_queries.py +++ b/tests/jsonrpc/test_opencode_session_extension_queries.py @@ -13,7 +13,11 @@ DummySessionQueryOpencodeUpstreamClient as DummyOpencodeUpstreamClient, ) from tests.support.helpers import make_settings -from tests.support.session_extensions import _BASE_SETTINGS, _session_meta +from tests.support.jsonrpc_error_assertions import ( + assert_v1_error_reason, + error_context_detail, +) +from tests.support.session_extensions import _BASE_SETTINGS, _extension_headers, _session_meta def _identity_for_token(token: str) -> str: @@ -74,7 +78,7 @@ async def test_session_query_extension_returns_jsonrpc_result(monkeypatch): transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -149,7 +153,7 @@ async def test_session_query_extension_supports_session_filters_and_message_curs transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) sessions_resp = await client.post( "/", headers=headers, @@ -219,7 +223,7 @@ async def test_session_query_extension_prefers_workspace_metadata_for_routing(mo transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) await client.post( "/", headers=headers, @@ -278,7 +282,7 @@ async def test_session_query_extension_rejects_directory_outside_workspace(monke transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -322,7 +326,7 @@ async def test_session_query_extension_applies_default_limit(monkeypatch): transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -378,7 +382,7 @@ async def test_session_query_extension_enforces_session_limit_locally(monkeypatc transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -419,7 +423,7 @@ async def test_provider_discovery_extension_returns_normalized_catalog(monkeypat transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) providers_resp = await client.post( "/", headers=headers, @@ -470,7 +474,7 @@ async def test_provider_discovery_extension_rejects_invalid_provider_id(monkeypa transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -501,7 +505,7 @@ async def test_provider_discovery_extension_maps_payload_mismatch(monkeypatch): transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -514,7 +518,7 @@ async def test_provider_discovery_extension_maps_payload_mismatch(monkeypatch): ) payload = resp.json() assert payload["error"]["code"] == -32005 - assert payload["error"]["data"]["type"] == "UPSTREAM_PAYLOAD_ERROR" + assert_v1_error_reason(payload["error"], reason="UPSTREAM_PAYLOAD_ERROR") @pytest.mark.asyncio @@ -537,7 +541,7 @@ async def list_provider_catalog(self, *, directory: str | None = None): transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -550,8 +554,10 @@ async def list_provider_catalog(self, *, directory: str | None = None): ) payload = resp.json() assert payload["error"]["code"] == -32002 - assert payload["error"]["data"]["type"] == "UPSTREAM_UNREACHABLE" - assert "concurrency limit exceeded" in payload["error"]["data"]["detail"] + assert_v1_error_reason(payload["error"], reason="UPSTREAM_UNREACHABLE") + context = error_context_detail(payload["error"]) + assert context is not None + assert "concurrency limit exceeded" in context["detail"] @pytest.mark.asyncio @@ -570,7 +576,7 @@ def __init__(self, _settings: Settings) -> None: transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -584,7 +590,7 @@ def __init__(self, _settings: Settings) -> None: assert resp.status_code == 200 payload = resp.json() assert payload["error"]["code"] == -32005 - assert payload["error"]["data"]["type"] == "UPSTREAM_PAYLOAD_ERROR" + assert_v1_error_reason(payload["error"], reason="UPSTREAM_PAYLOAD_ERROR") @pytest.mark.asyncio @@ -607,7 +613,7 @@ async def list_sessions(self, *, params=None, directory: str | None = None): transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -615,8 +621,10 @@ async def list_sessions(self, *, params=None, directory: str | None = None): ) payload = resp.json() assert payload["error"]["code"] == -32002 - assert payload["error"]["data"]["type"] == "UPSTREAM_UNREACHABLE" - assert "concurrency limit exceeded" in payload["error"]["data"]["detail"] + assert_v1_error_reason(payload["error"], reason="UPSTREAM_UNREACHABLE") + context = error_context_detail(payload["error"]) + assert context is not None + assert "concurrency limit exceeded" in context["detail"] @pytest.mark.asyncio @@ -672,7 +680,7 @@ async def test_interrupt_recovery_extension_returns_identity_scoped_items(monkey transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": f"Bearer {token}"} + headers = _extension_headers({"Authorization": f"Bearer {token}"}) permission_resp = await client.post( "/", headers=headers, @@ -719,7 +727,7 @@ async def test_interrupt_recovery_extension_rejects_unsupported_fields(monkeypat transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -752,7 +760,7 @@ async def test_interrupt_recovery_extension_notification_returns_204(monkeypatch async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: response = await client.post( "/", - headers={"Authorization": "Bearer t-1"}, + headers=_extension_headers({"Authorization": "Bearer t-1"}), json={"jsonrpc": "2.0", "method": "opencode.questions.list", "params": {}}, ) @@ -775,7 +783,7 @@ def __init__(self, _settings: Settings) -> None: transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -803,7 +811,7 @@ def __init__(self, _settings: Settings) -> None: transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -836,7 +844,7 @@ def __init__(self, _settings: Settings) -> None: transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -850,7 +858,7 @@ def __init__(self, _settings: Settings) -> None: payload = resp.json() message = payload["result"]["items"][0] assert message["messageId"] == "msg-1" - assert message["role"] == "user" + assert message["role"] == "ROLE_USER" assert message["parts"][0]["text"] == "hello" @@ -876,7 +884,7 @@ def __init__(self, _settings: Settings) -> None: transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -900,6 +908,7 @@ def __init__(self, _settings: Settings) -> None: payload = resp.json() assert payload["result"]["items"][0]["contextId"] == "ctx:opencode-session:s-1" assert _session_meta(payload["result"]["items"][0])["id"] == "s-1" + assert payload["result"]["items"][0]["role"] == "ROLE_AGENT" assert payload["result"]["items"][0]["parts"][0]["text"] == "SECRET_HISTORY" @@ -920,7 +929,7 @@ def __init__(self, _settings: Settings) -> None: transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -928,7 +937,7 @@ def __init__(self, _settings: Settings) -> None: ) payload = resp.json() assert payload["error"]["code"] == -32005 - assert payload["error"]["data"]["type"] == "UPSTREAM_PAYLOAD_ERROR" + assert_v1_error_reason(payload["error"], reason="UPSTREAM_PAYLOAD_ERROR") resp = await client.post( "/", @@ -942,7 +951,7 @@ def __init__(self, _settings: Settings) -> None: ) payload = resp.json() assert payload["error"]["code"] == -32005 - assert payload["error"]["data"]["type"] == "UPSTREAM_PAYLOAD_ERROR" + assert_v1_error_reason(payload["error"], reason="UPSTREAM_PAYLOAD_ERROR") @pytest.mark.asyncio @@ -969,7 +978,7 @@ async def test_session_query_extension_rejects_cursor_limit(monkeypatch): transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -1011,7 +1020,7 @@ async def test_session_query_extension_rejects_page_size_pagination(monkeypatch) transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -1052,7 +1061,7 @@ async def test_session_query_extension_rejects_limit_above_max(monkeypatch): transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -1098,7 +1107,7 @@ async def test_session_query_extension_accepts_equivalent_string_and_integer_lim transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -1134,7 +1143,7 @@ async def list_messages(self, session_id: str, *, params=None): transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, @@ -1149,7 +1158,11 @@ async def list_messages(self, session_id: str, *, params=None): assert payload["jsonrpc"] == "2.0" assert payload["id"] == 2 assert payload["error"]["code"] == -32001 - assert payload["error"]["data"]["type"] == "SESSION_NOT_FOUND" + assert_v1_error_reason( + payload["error"], + reason="SESSION_NOT_FOUND", + metadata={"session_id": "s-404"}, + ) @pytest.mark.asyncio @@ -1165,7 +1178,7 @@ async def test_session_query_extension_does_not_log_response_bodies(monkeypatch, transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) resp = await client.post( "/", headers=headers, diff --git a/tests/jsonrpc/test_opencode_workspace_control_extension.py b/tests/jsonrpc/test_opencode_workspace_control_extension.py index b849c19..9561425 100644 --- a/tests/jsonrpc/test_opencode_workspace_control_extension.py +++ b/tests/jsonrpc/test_opencode_workspace_control_extension.py @@ -5,7 +5,11 @@ DummySessionQueryOpencodeUpstreamClient as DummyOpencodeUpstreamClient, ) from tests.support.helpers import make_basic_auth_header, make_settings -from tests.support.session_extensions import _BASE_SETTINGS +from tests.support.jsonrpc_error_assertions import ( + assert_v1_error_metadata_contains, + assert_v1_error_reason, +) +from tests.support.session_extensions import _BASE_SETTINGS, _extension_headers @pytest.mark.asyncio @@ -22,7 +26,7 @@ async def test_workspace_control_extension_supports_read_only_methods(monkeypatc transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = {"Authorization": "Bearer t-1"} + headers = _extension_headers({"Authorization": "Bearer t-1"}) projects = await client.post( "/", headers=headers, @@ -80,7 +84,7 @@ async def test_workspace_control_extension_supports_mutating_methods(monkeypatch transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - headers = make_basic_auth_header("operator", "op-pass") + headers = _extension_headers(make_basic_auth_header("operator", "op-pass")) create_workspace = await client.post( "/", headers=headers, @@ -175,7 +179,7 @@ async def test_workspace_control_extension_validates_request_shape(monkeypatch) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: response = await client.post( "/", - headers=make_basic_auth_header("operator", "op-pass"), + headers=_extension_headers(make_basic_auth_header("operator", "op-pass")), json={ "jsonrpc": "2.0", "id": 20, @@ -223,7 +227,7 @@ async def test_workspace_control_mutations_require_workspace_mutation_capability async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: response = await client.post( "/", - headers={"Authorization": "Bearer t-1"}, + headers=_extension_headers({"Authorization": "Bearer t-1"}), json={ "jsonrpc": "2.0", "id": 201, @@ -235,12 +239,15 @@ async def test_workspace_control_mutations_require_workspace_mutation_capability assert response.status_code == 200 payload = response.json() assert payload["error"]["code"] == -32007 - assert payload["error"]["data"] == { - "type": "AUTHORIZATION_FORBIDDEN", - "method": "opencode.workspaces.create", - "capability": "workspace_mutation", - "credential_id": "cred-bearer", - } + assert_v1_error_reason( + payload["error"], + reason="AUTHORIZATION_FORBIDDEN", + metadata={ + "method": "opencode.workspaces.create", + "capability": "workspace_mutation", + "credential_id": "cred-bearer", + }, + ) @pytest.mark.asyncio @@ -256,7 +263,7 @@ async def test_workspace_control_mutations_are_disabled_by_default(monkeypatch) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: response = await client.post( "/", - headers={"Authorization": "Bearer t-1"}, + headers=_extension_headers({"Authorization": "Bearer t-1"}), json={ "jsonrpc": "2.0", "id": 22, @@ -268,7 +275,6 @@ async def test_workspace_control_mutations_are_disabled_by_default(monkeypatch) assert response.status_code == 200 payload = response.json() assert payload["error"]["code"] == -32601 - assert payload["error"]["data"]["type"] == "METHOD_NOT_SUPPORTED" assert payload["error"]["data"]["method"] == "opencode.worktrees.create" @@ -291,7 +297,7 @@ async def list_workspaces(self): async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: response = await client.post( "/", - headers={"Authorization": "Bearer t-1"}, + headers=_extension_headers({"Authorization": "Bearer t-1"}), json={ "jsonrpc": "2.0", "id": 21, @@ -302,5 +308,8 @@ async def list_workspaces(self): assert response.status_code == 200 payload = response.json() - assert payload["error"]["data"]["type"] == "UPSTREAM_HTTP_ERROR" - assert payload["error"]["data"]["upstream_status"] == 503 + assert_v1_error_metadata_contains( + payload["error"], + reason="UPSTREAM_HTTP_ERROR", + metadata={"upstream_status": 503}, + ) diff --git a/tests/profile/test_profile_runtime.py b/tests/profile/test_profile_runtime.py index efa1d3b..40fbaa5 100644 --- a/tests/profile/test_profile_runtime.py +++ b/tests/profile/test_profile_runtime.py @@ -1,4 +1,5 @@ from opencode_a2a.profile.runtime import build_runtime_profile +from opencode_a2a.protocol_versions import A2A_PROTOCOL_VERSION from tests.support.helpers import make_settings @@ -24,9 +25,9 @@ def test_profile_runtime_splits_deployment_runtime_features_and_health_payload() profile = build_runtime_profile(settings) - assert profile.summary_dict(protocol_version=settings.a2a_protocol_version) == { + assert profile.summary_dict(protocol_version=A2A_PROTOCOL_VERSION) == { "profile_id": "opencode-a2a-single-tenant-coding-v1", - "protocol_version": "0.3", + "protocol_version": "1.0", "deployment": { "id": "single_tenant_shared_workspace", "single_tenant": True, @@ -95,12 +96,12 @@ def test_profile_runtime_splits_deployment_runtime_features_and_health_payload() assert profile.health_payload( service="opencode-a2a", version=settings.a2a_version, - protocol_version=settings.a2a_protocol_version, + protocol_version=A2A_PROTOCOL_VERSION, ) == { "status": "ok", "service": "opencode-a2a", "version": settings.a2a_version, - "profile": profile.summary_dict(protocol_version=settings.a2a_protocol_version), + "profile": profile.summary_dict(protocol_version=A2A_PROTOCOL_VERSION), } diff --git a/tests/server/test_a2a_client_manager.py b/tests/server/test_a2a_client_manager.py index 89b953c..7ba8cb7 100644 --- a/tests/server/test_a2a_client_manager.py +++ b/tests/server/test_a2a_client_manager.py @@ -14,8 +14,6 @@ def _make_settings(**overrides: object) -> SimpleNamespace: "a2a_client_use_client_preference": False, "a2a_client_bearer_token": None, "a2a_client_basic_auth": None, - "a2a_client_protocol_version": None, - "a2a_protocol_version": "0.3", "a2a_client_supported_transports": ("JSONRPC", "HTTP+JSON"), "a2a_client_cache_ttl_seconds": 60.0, "a2a_client_cache_maxsize": 2, @@ -212,9 +210,3 @@ def test_client_manager_loads_basic_auth_into_client_settings() -> None: ) assert manager.client_settings.basic_auth == "user:pass" - - -def test_client_manager_defaults_protocol_version_from_runtime_setting() -> None: - manager = client_manager_module.A2AClientManager(_make_settings(a2a_protocol_version="1.0")) - - assert manager.client_settings.protocol_version == "1.0" diff --git a/tests/server/test_agent_card.py b/tests/server/test_agent_card.py index 75ca64f..c72dc7e 100644 --- a/tests/server/test_agent_card.py +++ b/tests/server/test_agent_card.py @@ -1,15 +1,8 @@ import json -from opencode_a2a.a2a_utils import proto_to_dict +from google.protobuf.json_format import MessageToDict + from opencode_a2a.contracts.extensions import ( - SESSION_QUERY_DEFAULT_LIMIT, - SESSION_QUERY_MAX_LIMIT, - build_protocol_compatibility_params, - build_service_behavior_contract_params, -) -from opencode_a2a.jsonrpc.application import SESSION_CONTEXT_PREFIX -from opencode_a2a.server.agent_card import build_authenticated_extended_agent_card -from opencode_a2a.server.application import ( COMPATIBILITY_PROFILE_EXTENSION_URI, INTERRUPT_CALLBACK_EXTENSION_URI, INTERRUPT_RECOVERY_EXTENSION_URI, @@ -17,16 +10,24 @@ PROVIDER_DISCOVERY_EXTENSION_URI, SESSION_BINDING_EXTENSION_URI, SESSION_MANAGEMENT_EXTENSION_URI, + SESSION_QUERY_DEFAULT_LIMIT, + SESSION_QUERY_MAX_LIMIT, STREAMING_EXTENSION_URI, WIRE_CONTRACT_EXTENSION_URI, WORKSPACE_CONTROL_EXTENSION_URI, + build_protocol_compatibility_params, + build_service_behavior_contract_params, +) +from opencode_a2a.jsonrpc.methods import SESSION_CONTEXT_PREFIX +from opencode_a2a.server.agent_card import ( build_agent_card, + build_authenticated_extended_agent_card, ) from tests.support.helpers import make_settings def _security_requirements(card) -> list[dict[str, dict[str, list[str]]]]: - return [proto_to_dict(requirement)["schemes"] for requirement in card.security_requirements] + return [MessageToDict(requirement)["schemes"] for requirement in card.security_requirements] def test_agent_card_description_reflects_actual_transport_capabilities() -> None: @@ -44,8 +45,8 @@ def test_agent_card_description_reflects_actual_transport_capabilities() -> None assert [ (iface.protocol_binding, iface.protocol_version) for iface in card.supported_interfaces ] == [ - ("HTTP+JSON", "0.3"), - ("JSONRPC", "0.3"), + ("HTTP+JSON", "1.0"), + ("JSONRPC", "1.0"), ] assert card.default_input_modes == ["text/plain", "application/octet-stream"] assert card.default_output_modes == ["text/plain", "application/json"] @@ -116,7 +117,7 @@ def test_public_agent_card_is_slimmed_but_keeps_core_shared_contract_hints() -> assert ext_by_uri[MODEL_SELECTION_EXTENSION_URI].params == { "metadata_field": "metadata.shared.model", "behavior": "prefer_metadata_model_else_upstream_default", - "applies_to_methods": ["message/send", "message/stream"], + "applies_to_methods": ["SendMessage", "SendStreamingMessage"], "supported_metadata": [ "shared.model.providerID", "shared.model.modelID", @@ -179,18 +180,18 @@ def test_public_agent_card_is_slimmed_but_keeps_core_shared_contract_hints() -> COMPATIBILITY_PROFILE_EXTENSION_URI, WIRE_CONTRACT_EXTENSION_URI, ): - assert proto_to_dict(ext_by_uri[uri]).get("params") in (None, {}) + assert MessageToDict(ext_by_uri[uri]).get("params") in (None, {}) public_size = len( json.dumps( - proto_to_dict(public_card), + MessageToDict(public_card), ensure_ascii=False, separators=(",", ":"), ).encode("utf-8") ) extended_size = len( json.dumps( - proto_to_dict(extended_card), + MessageToDict(extended_card), ensure_ascii=False, separators=(",", ":"), ).encode("utf-8") @@ -281,7 +282,7 @@ def test_agent_card_injects_profile_into_extensions() -> None: assert model_selection.params["metadata_field"] == "metadata.shared.model" assert model_selection.params["fields"]["providerID"] == "metadata.shared.model.providerID" assert model_selection.params["fields"]["modelID"] == "metadata.shared.model.modelID" - assert model_selection.params["applies_to_methods"] == ["message/send", "message/stream"] + assert model_selection.params["applies_to_methods"] == ["SendMessage", "SendStreamingMessage"] assert model_selection.params["behavior"] == "prefer_metadata_model_else_upstream_default" streaming = ext_by_uri[STREAMING_EXTENSION_URI] @@ -679,8 +680,8 @@ def test_agent_card_injects_profile_into_extensions() -> None: compatibility = ext_by_uri[COMPATIBILITY_PROFILE_EXTENSION_URI] expected_service_behaviors = build_service_behavior_contract_params() expected_protocol_compatibility = build_protocol_compatibility_params( - supported_protocol_versions=["0.3", "1.0"], - default_protocol_version="0.3", + supported_protocol_versions=["1.0"], + default_protocol_version="1.0", ) assert compatibility.params["extension_retention"][MODEL_SELECTION_EXTENSION_URI] == { "surface": "core-runtime-metadata", @@ -739,12 +740,12 @@ def test_agent_card_injects_profile_into_extensions() -> None: "implementation_scope": "adapter-local", "identity_scope": "current_authenticated_caller", } - assert compatibility.params["method_retention"]["agent/getAuthenticatedExtendedCard"] == { + assert compatibility.params["method_retention"]["GetExtendedAgentCard"] == { "surface": "core", "availability": "always", "retention": "required", } - assert compatibility.params["method_retention"]["tasks/pushNotificationConfig/get"] == { + assert compatibility.params["method_retention"]["GetTaskPushNotificationConfig"] == { "surface": "core", "availability": "always", "retention": "required", @@ -753,14 +754,14 @@ def test_agent_card_injects_profile_into_extensions() -> None: assert compatibility.params["service_behaviors"]["classification"] == ( "service-level-semantic-enhancement" ) - assert compatibility.params["service_behaviors"]["methods"]["tasks/cancel"]["idempotency"] == { + assert compatibility.params["service_behaviors"]["methods"]["CancelTask"]["idempotency"] == { "already_canceled": { "behavior": "return_current_terminal_task", "returns_current_state": "canceled", "error": None, } } - assert compatibility.params["service_behaviors"]["methods"]["tasks/resubscribe"][ + assert compatibility.params["service_behaviors"]["methods"]["SubscribeToTask"][ "terminal_state_behavior" ] == { "behavior": "replay_terminal_task_once_then_close", @@ -772,15 +773,15 @@ def test_agent_card_injects_profile_into_extensions() -> None: wire_contract = ext_by_uri[WIRE_CONTRACT_EXTENSION_URI] assert wire_contract.params["profile"]["profile_id"] == "opencode-a2a-single-tenant-coding-v1" - assert wire_contract.params["default_protocol_version"] == "0.3" - assert wire_contract.params["supported_protocol_versions"] == ["0.3", "1.0"] + assert wire_contract.params["default_protocol_version"] == "1.0" + assert wire_contract.params["supported_protocol_versions"] == ["1.0"] assert wire_contract.params["protocol_compatibility"] == expected_protocol_compatibility assert MODEL_SELECTION_EXTENSION_URI in wire_contract.params["extensions"]["extension_uris"] assert PROVIDER_DISCOVERY_EXTENSION_URI in wire_contract.params["extensions"]["extension_uris"] assert WORKSPACE_CONTROL_EXTENSION_URI in wire_contract.params["extensions"]["extension_uris"] assert INTERRUPT_RECOVERY_EXTENSION_URI in wire_contract.params["extensions"]["extension_uris"] - assert "agent/getAuthenticatedExtendedCard" in wire_contract.params["all_jsonrpc_methods"] - assert "tasks/pushNotificationConfig/get" in wire_contract.params["all_jsonrpc_methods"] + assert "GetExtendedAgentCard" in wire_contract.params["all_jsonrpc_methods"] + assert "GetTaskPushNotificationConfig" in wire_contract.params["all_jsonrpc_methods"] assert "GET /v1/tasks" in wire_contract.params["core"]["http_endpoints"] assert ( "GET /v1/tasks/{id}/pushNotificationConfigs" diff --git a/tests/server/test_app_behaviors.py b/tests/server/test_app_behaviors.py index 91b6eda..5f446dd 100644 --- a/tests/server/test_app_behaviors.py +++ b/tests/server/test_app_behaviors.py @@ -8,16 +8,20 @@ import httpx import pytest -from a2a.server.apps.rest.rest_adapter import RESTAdapter from a2a.server.events import EventConsumer +from a2a.server.routes.rest_dispatcher import RestDispatcher from a2a.types import ( + AgentCapabilities, + AgentCard, + CancelTaskRequest, + GetTaskRequest, InternalError, Message, + SendMessageRequest, + SubscribeToTaskRequest, Task, - TaskIdParams, TaskNotCancelableError, TaskNotFoundError, - TaskQueryParams, TaskState, TaskStatus, UnsupportedOperationError, @@ -27,36 +31,42 @@ from google.protobuf.json_format import MessageToDict, ParseError import opencode_a2a.server.application as app_module -from opencode_a2a.contracts.extensions import build_capability_snapshot +from opencode_a2a.contracts.extensions import ( + MODEL_SELECTION_EXTENSION_URI, + SESSION_BINDING_EXTENSION_URI, + SESSION_METHODS, + build_capability_snapshot, +) from opencode_a2a.profile.runtime import build_runtime_profile -from opencode_a2a.server.application import ( - OpencodeRequestHandler, +from opencode_a2a.protocol_versions import A2A_PROTOCOL_VERSION +from opencode_a2a.server.agent_card import ( _build_agent_card_description, _build_chat_examples, + _build_session_management_skill_examples, +) +from opencode_a2a.server.application import ( + OpencodeRequestHandler, + _configure_logging, + _normalize_log_level, + _parse_rest_send_message_request, + _rest_error_response, + create_app, +) +from opencode_a2a.server.openapi import ( _build_jsonrpc_extension_openapi_description, _build_jsonrpc_extension_openapi_examples, - _build_rest_legacy_error_payload, _build_rest_message_openapi_examples, - _build_session_management_skill_examples, - _configure_logging, +) +from opencode_a2a.server.request_parsing import ( _decode_payload_preview, _detect_sensitive_extension_method, _is_json_content_type, _looks_like_jsonrpc_envelope, - _looks_like_jsonrpc_message_payload, _normalize_content_type, - _normalize_log_level, - _normalize_rest_content_part, - _normalize_rest_send_message_payload, - _normalize_v1_jsonrpc_method_alias, _parse_content_length, _parse_json_body, - _parse_rest_send_message_request, _request_body_too_large_response, _RequestBodyTooLargeError, - _rest_error_response, - build_agent_card, - create_app, ) from opencode_a2a.server.task_store import TaskStoreOperationError from tests.support.helpers import ( @@ -66,6 +76,10 @@ ) +def _agent_card() -> AgentCard: + return AgentCard(name="opencode-a2a", capabilities=AgentCapabilities(streaming=True)) + + def _request( path: str, body: bytes = b"{}", @@ -105,13 +119,13 @@ async def receive() -> dict: def test_request_payload_helpers_cover_edge_cases() -> None: assert _parse_json_body(b"{") is None assert _parse_json_body(b"[]") is None - assert _parse_json_body(b'{"method":"message/send"}') == {"method": "message/send"} + assert _parse_json_body(b'{"method":"SendMessage"}') == {"method": "SendMessage"} assert _detect_sensitive_extension_method(None) is None - assert _detect_sensitive_extension_method({"method": "message/send"}) is None + assert _detect_sensitive_extension_method({"method": "SendMessage"}) is None assert ( - _detect_sensitive_extension_method({"method": app_module.SESSION_METHODS["list_sessions"]}) - == app_module.SESSION_METHODS["list_sessions"] + _detect_sensitive_extension_method({"method": SESSION_METHODS["list_sessions"]}) + == SESSION_METHODS["list_sessions"] ) assert _parse_content_length(None) is None @@ -126,35 +140,9 @@ def test_request_payload_helpers_cover_edge_cases() -> None: assert _is_json_content_type("application/problem+json") is True assert _decode_payload_preview(b"abcdef", limit=3) == "abc...[truncated]" - assert _looks_like_jsonrpc_message_payload(None) is False - assert _looks_like_jsonrpc_message_payload({"message": {"parts": []}}) is True - assert _looks_like_jsonrpc_message_payload({"message": {"role": "user"}}) is True - assert _looks_like_jsonrpc_message_payload({"message": {"role": "ROLE_USER"}}) is False assert _looks_like_jsonrpc_envelope(None) is False - assert _looks_like_jsonrpc_envelope({"jsonrpc": "2.0", "method": "message/send"}) is True - assert _looks_like_jsonrpc_envelope({"jsonrpc": 2, "method": "message/send"}) is False - assert _normalize_v1_jsonrpc_method_alias(None, protocol_version="1.0") is None - assert _normalize_v1_jsonrpc_method_alias( - {"jsonrpc": "2.0", "method": "SendMessage"}, - protocol_version="1.0", - ) == { - "jsonrpc": "2.0", - "method": "message/send", - } - assert _normalize_v1_jsonrpc_method_alias( - {"jsonrpc": "2.0", "method": "SendMessage"}, - protocol_version="0.3", - ) == { - "jsonrpc": "2.0", - "method": "SendMessage", - } - assert _normalize_v1_jsonrpc_method_alias( - {"jsonrpc": "2.0", "method": "message/send"}, - protocol_version="1.0", - ) == { - "jsonrpc": "2.0", - "method": "message/send", - } + assert _looks_like_jsonrpc_envelope({"jsonrpc": "2.0", "method": "SendMessage"}) is True + assert _looks_like_jsonrpc_envelope({"jsonrpc": 2, "method": "SendMessage"}) is False response = _request_body_too_large_response( path="/", @@ -162,163 +150,105 @@ def test_request_payload_helpers_cover_edge_cases() -> None: error=_RequestBodyTooLargeError(limit=64, actual_size=65), ) assert response.status_code == 413 - assert response.body == b'{"error":"Request body too large","max_bytes":64}' + payload = json.loads(response.body) + assert payload["error"]["code"] == 413 + assert payload["error"]["status"] == "RESOURCE_EXHAUSTED" + assert payload["error"]["message"] == "Request body too large" + assert payload["error"]["details"][0] == { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "REQUEST_BODY_TOO_LARGE", + "domain": "a2a-protocol.org", + "metadata": {"maxBytes": "64", "actualSize": "65"}, + } + assert payload["error"]["details"][1] == { + "@type": "type.googleapis.com/opencode_a2a.HttpErrorContext", + "maxBytes": 64, + "actualSize": 65, + } def test_rest_message_parsing_helpers_cover_upgrade_paths() -> None: - assert _build_rest_legacy_error_payload(message="boom") == {"error": "boom"} - assert _build_rest_legacy_error_payload( - message="boom", - reason="INVALID_REQUEST", - metadata={"path": "/v1/message:send"}, - ) == { - "error": "boom", - "type": "INVALID_REQUEST", - "path": "/v1/message:send", - } - request_v1 = _request("/v1/message:send") request_v1.state.a2a_protocol_version = "1.0" v1_error = _rest_error_response( request=request_v1, - default_protocol_version="0.3", error=InvalidRequestError(message="bad payload", data={"path": "/v1/message:send"}), ) assert v1_error.status_code == 400 - assert v1_error.body == ( - b'{"error":{"code":400,"status":"INVALID_ARGUMENT","message":"bad payload",' - b'"details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"INVALID_REQUEST",' - b'"domain":"a2a-protocol.org","metadata":{"path":"/v1/message:send"}},' - b'{"@type":"type.googleapis.com/opencode_a2a.HttpErrorContext","path":"/v1/message:send"}]}}' - ) + assert json.loads(v1_error.body) == { + "error": { + "code": 400, + "status": "INVALID_ARGUMENT", + "message": "bad payload", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "INVALID_REQUEST", + "domain": "a2a-protocol.org", + "metadata": {"path": "/v1/message:send"}, + }, + { + "@type": "type.googleapis.com/opencode_a2a.HttpErrorContext", + "path": "/v1/message:send", + }, + ], + } + } - request_v03 = _request("/v1/message:send") + request_default = _request("/v1/message:send") parse_error = _rest_error_response( - request=request_v03, - default_protocol_version="0.3", + request=request_default, error=ParseError("bad parse"), ) assert parse_error.status_code == 400 - assert parse_error.body == b'{"error":"bad parse","type":"INVALID_REQUEST"}' + assert json.loads(parse_error.body) == { + "error": { + "code": 400, + "status": "INVALID_ARGUMENT", + "message": "bad parse", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "INVALID_REQUEST", + "domain": "a2a-protocol.org", + } + ], + } + } generic_error = _rest_error_response( - request=request_v03, - default_protocol_version="0.3", + request=request_default, error=RuntimeError("boom"), ) assert generic_error.status_code == 500 - assert generic_error.body == b'{"error":"unknown exception","type":"INTERNAL_ERROR"}' - - with pytest.raises(InvalidRequestError, match="message.content\\[0\\] must be an object"): - _normalize_rest_content_part("bad", field="message.content[0]") - - assert _normalize_rest_content_part( - {"text": "hello", "metadata": {"tag": "plain"}}, - field="message.content[0]", - ) == {"metadata": {"tag": "plain"}, "text": "hello"} - assert _normalize_rest_content_part( - {"data": {"step": 1}}, - field="message.content[1]", - ) == {"data": {"step": 1}} - assert _normalize_rest_content_part( - {"file": {"bytes": "aGVsbG8=", "name": "report.txt", "mimeType": "text/plain"}}, - field="message.content[2]", - ) == { - "raw": "aGVsbG8=", - "filename": "report.txt", - "mediaType": "text/plain", - } - assert _normalize_rest_content_part( - { - "file": { - "uri": "file:///tmp/report.txt", - "name": "report.txt", - "mediaType": "text/plain", - } - }, - field="message.content[3]", - ) == { - "url": "file:///tmp/report.txt", - "filename": "report.txt", - "mediaType": "text/plain", - } - assert _normalize_rest_content_part( - {"raw": "aGVsbG8=", "filename": "inline.bin", "mediaType": "application/octet-stream"}, - field="message.content[4]", - ) == { - "raw": "aGVsbG8=", - "filename": "inline.bin", - "mediaType": "application/octet-stream", - } - assert _normalize_rest_content_part( - { - "url": "https://example.com/report.txt", - "filename": "report.txt", - "mediaType": "text/plain", - }, - field="message.content[5]", - ) == { - "url": "https://example.com/report.txt", - "filename": "report.txt", - "mediaType": "text/plain", - } - with pytest.raises( - InvalidRequestError, match="message.content\\[6\\]\\.file must contain uri or bytes" - ): - _normalize_rest_content_part({"file": {"name": "report.txt"}}, field="message.content[6]") - - normalized_payload = _normalize_rest_send_message_payload( - { - "message": { - "messageId": "msg-1", - "contextId": "ctx-1", - "role": "ROLE_USER", - "content": [{"text": "hello"}], - }, - "metadata": {"shared": {"session": {"id": "s-1"}}}, - "acceptedOutputModes": ["text/plain"], - "configuration": {"historyLength": 2}, + assert json.loads(generic_error.body) == { + "error": { + "code": 500, + "status": "INTERNAL", + "message": "unknown exception", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "INTERNAL_ERROR", + "domain": "a2a-protocol.org", + } + ], } - ) - assert normalized_payload == { - "message": { - "messageId": "msg-1", - "contextId": "ctx-1", - "role": "ROLE_USER", - "parts": [{"text": "hello"}], - }, - "metadata": {"shared": {"session": {"id": "s-1"}}}, - "configuration": { - "historyLength": 2, - "acceptedOutputModes": ["text/plain"], - }, } - with pytest.raises(InvalidRequestError, match="message must be an object"): - _normalize_rest_send_message_payload({}) - with pytest.raises(InvalidRequestError, match="message.content must be an array"): - _normalize_rest_send_message_payload({"message": {"content": "bad"}}) - with pytest.raises(InvalidRequestError, match="configuration must be an object"): - _normalize_rest_send_message_payload( - { - "message": {"role": "ROLE_USER", "content": [{"text": "hello"}]}, - "configuration": "bad", - "acceptedOutputModes": ["text/plain"], - } - ) - parsed = _parse_rest_send_message_request( json.dumps( { "message": { "messageId": "msg-2", "role": "ROLE_USER", - "content": [{"text": "hello from rest"}], + "parts": [{"text": "hello from rest"}], }, - "returnImmediately": True, + "configuration": {"returnImmediately": True}, } ).encode("utf-8") ) + assert isinstance(parsed, SendMessageRequest) assert MessageToDict(parsed) == { "message": { "messageId": "msg-2", @@ -329,6 +259,51 @@ def test_rest_message_parsing_helpers_cover_upgrade_paths() -> None: } with pytest.raises(InvalidRequestError, match="REST message payload must be a JSON object"): _parse_rest_send_message_request(b"[]") + with pytest.raises( + InvalidRequestError, + match="REST message payload must use message.parts, not message.content.", + ): + _parse_rest_send_message_request( + json.dumps( + { + "message": { + "messageId": "msg-legacy", + "role": "ROLE_USER", + "content": [{"text": "hello from rest"}], + } + } + ).encode("utf-8") + ) + with pytest.raises( + InvalidRequestError, + match="REST message payload must use ROLE_\\* values for message.role.", + ): + _parse_rest_send_message_request( + json.dumps( + { + "message": { + "messageId": "msg-legacy", + "role": "user", + "parts": [{"text": "hello from rest"}], + } + } + ).encode("utf-8") + ) + with pytest.raises( + InvalidRequestError, + match="message.parts\\[0\\] must use direct Part fields such as text, raw, url, or data.", + ): + _parse_rest_send_message_request( + json.dumps( + { + "message": { + "messageId": "msg-legacy", + "role": "ROLE_USER", + "parts": [{"file": {"uri": "file:///tmp/report.txt"}}], + } + } + ).encode("utf-8") + ) def test_agent_card_helper_builders_cover_optional_branches() -> None: @@ -560,7 +535,7 @@ async def close(self) -> None: "version": settings.a2a_version, "profile": { "profile_id": "opencode-a2a-single-tenant-coding-v1", - "protocol_version": settings.a2a_protocol_version, + "protocol_version": A2A_PROTOCOL_VERSION, "deployment": { "id": "single_tenant_shared_workspace", "single_tenant": True, @@ -645,16 +620,13 @@ async def close(self) -> None: @pytest.mark.asyncio async def test_rest_adapter_routes_and_preconsume_error() -> None: handler = MagicMock() - adapter = RESTAdapter( - agent_card=build_agent_card(make_settings(test_bearer_token="test-token")), - http_handler=handler, - ) + adapter = RestDispatcher(request_handler=handler) async def _stream(_request: Request, _context): # noqa: ANN001 yield {"id": "evt-1"} - handler.on_resubscribe_to_task = _stream - response = await adapter.routes()[("/v1/tasks/{id}:subscribe", "GET")]( + handler.on_subscribe_to_task = _stream + response = await adapter.on_subscribe_to_task( _request("/v1/tasks/x:subscribe", method="GET", path_params={"id": "x"}) ) assert response is not None @@ -666,7 +638,7 @@ async def body(self) -> bytes: with pytest.raises( InvalidRequestError, match="Failed to pre-consume request body: broken body" ): - await adapter._dispatcher._handle_streaming( # pyright: ignore[reportAttributeAccessIssue] + await adapter._handle_streaming( # pyright: ignore[reportAttributeAccessIssue] _BrokenRequest(), lambda _context: _stream(None, _context), ) @@ -690,7 +662,20 @@ async def test_push_notification_routes_are_explicitly_unsupported(monkeypatch) ) assert response.status_code == 501 - assert response.json() == {"message": "Push notifications are not supported by the agent"} + assert response.json() == { + "error": { + "code": 501, + "status": "UNIMPLEMENTED", + "message": "Push notifications are not supported by the agent", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "PUSH_NOTIFICATIONS_UNSUPPORTED", + "domain": "a2a-protocol.org", + } + ], + } + } @pytest.mark.asyncio @@ -710,7 +695,7 @@ async def test_push_notification_jsonrpc_methods_remain_unsupported(monkeypatch) json={ "jsonrpc": "2.0", "id": 1, - "method": "tasks/pushNotificationConfig/get", + "method": "GetTaskPushNotificationConfig", "params": {"id": "task-1"}, }, ) @@ -720,6 +705,13 @@ async def test_push_notification_jsonrpc_methods_remain_unsupported(monkeypatch) "error": { "code": -32004, "message": "This operation is not supported", + "data": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "UNSUPPORTED_OPERATION", + "domain": "a2a-protocol.org", + } + ], }, "id": 1, "jsonrpc": "2.0", @@ -729,10 +721,15 @@ async def test_push_notification_jsonrpc_methods_remain_unsupported(monkeypatch) @pytest.mark.asyncio async def test_on_cancel_task_and_resubscribe_cover_race_paths(monkeypatch) -> None: task_store = MagicMock() - handler = OpencodeRequestHandler(agent_executor=MagicMock(), task_store=task_store) + handler = OpencodeRequestHandler( + agent_executor=MagicMock(), + task_store=task_store, + agent_card=_agent_card(), + ) handler.agent_executor.cancel = AsyncMock() handler._queue_manager.tap = AsyncMock(return_value=MagicMock()) # noqa: SLF001 - params = TaskIdParams(id="task-1") + cancel_params = CancelTaskRequest(id="task-1") + subscribe_params = SubscribeToTaskRequest(id="task-1") canceled_task = Task( id="task-1", context_id="ctx-1", @@ -751,14 +748,14 @@ async def test_on_cancel_task_and_resubscribe_cover_race_paths(monkeypatch) -> N task_store.get = AsyncMock(return_value=None) with pytest.raises(TaskNotFoundError): - await handler.on_cancel_task(params) + await handler.on_cancel_task(cancel_params) task_store.get = AsyncMock(return_value=canceled_task) - assert await handler.on_cancel_task(params) is canceled_task + assert await handler.on_cancel_task(cancel_params) is canceled_task task_store.get = AsyncMock(return_value=completed_task) with pytest.raises(TaskNotCancelableError): - await handler.on_cancel_task(params) + await handler.on_cancel_task(cancel_params) task_store.get = AsyncMock(side_effect=[working_task, canceled_task]) @@ -766,7 +763,7 @@ async def _consume_non_canceled(_self, _consumer): # noqa: ANN001 return working_task monkeypatch.setattr(app_module.ResultAggregator, "consume_all", _consume_non_canceled) - assert await handler.on_cancel_task(params) is canceled_task + assert await handler.on_cancel_task(cancel_params) is canceled_task task_store.get = AsyncMock(return_value=working_task) @@ -774,15 +771,15 @@ async def _consume_canceled(_self, _consumer): # noqa: ANN001 return canceled_task monkeypatch.setattr(app_module.ResultAggregator, "consume_all", _consume_canceled) - assert await handler.on_cancel_task(params) is canceled_task + assert await handler.on_cancel_task(cancel_params) is canceled_task task_store.get = AsyncMock(return_value=None) with pytest.raises(TaskNotFoundError): - events = [item async for item in handler.on_resubscribe_to_task(params)] + events = [item async for item in handler.on_subscribe_to_task(subscribe_params)] assert events == [] task_store.get = AsyncMock(return_value=canceled_task) - events = [item async for item in handler.on_resubscribe_to_task(params)] + events = [item async for item in handler.on_subscribe_to_task(subscribe_params)] assert events == [canceled_task] task_store.get = AsyncMock(return_value=working_task) @@ -795,7 +792,7 @@ async def _consume_and_emit(_self, _consumer): # noqa: ANN001 "consume_and_emit", _consume_and_emit, ) - events = [item async for item in handler.on_resubscribe_to_task(params)] + events = [item async for item in handler.on_subscribe_to_task(subscribe_params)] assert events == [working_task, "evt-1"] @@ -814,7 +811,7 @@ async def _message_response(params, context=None): # noqa: ANN001 return Message( message_id="m-server", role=app_module.Role.ROLE_AGENT, - parts=[app_module.make_text_part("server reply")], + parts=[app_module.Part(text="server reply")], ) async def _stream_failure(params, context=None): # noqa: ANN001 @@ -832,7 +829,7 @@ async def _stream_failure(params, context=None): # noqa: ANN001 "message": { "messageId": "m-rest", "role": "ROLE_USER", - "content": [{"text": "hello"}], + "parts": [{"text": "hello"}], } } @@ -850,22 +847,38 @@ async def _stream_failure(params, context=None): # noqa: ANN001 } assert stream_response.status_code == 400 assert stream_response.json() == { - "error": "stream bad", - "type": "INVALID_REQUEST", + "error": { + "code": 400, + "status": "INVALID_ARGUMENT", + "message": "stream bad", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "INVALID_REQUEST", + "domain": "a2a-protocol.org", + } + ], + } } @pytest.mark.asyncio async def test_task_store_failures_map_to_stable_handler_errors() -> None: task_store = MagicMock() - handler = OpencodeRequestHandler(agent_executor=MagicMock(), task_store=task_store) + handler = OpencodeRequestHandler( + agent_executor=MagicMock(), + task_store=task_store, + agent_card=_agent_card(), + ) task_store.get = AsyncMock(side_effect=TaskStoreOperationError("get", "task-1")) with pytest.raises(InternalError, match="Task store unavailable while loading task state."): - await handler.on_get_task(TaskQueryParams(id="task-1")) + await handler.on_get_task(GetTaskRequest(id="task-1")) with pytest.raises(InternalError, match="Task store unavailable while loading task state."): - events = [item async for item in handler.on_resubscribe_to_task(TaskIdParams(id="task-1"))] + events = [ + item async for item in handler.on_subscribe_to_task(SubscribeToTaskRequest(id="task-1")) + ] assert events == [] @@ -878,7 +891,11 @@ async def consume_and_break_on_interrupt(self, _consumer, *, blocking, event_cal class _Handler(OpencodeRequestHandler): def __init__(self) -> None: - super().__init__(agent_executor=MagicMock(), task_store=MagicMock()) + super().__init__( + agent_executor=MagicMock(), + task_store=MagicMock(), + agent_card=_agent_card(), + ) self.queue = AsyncMock() self.producer = MagicMock() @@ -926,7 +943,11 @@ async def consume_and_emit(self, _consumer): class _Handler(OpencodeRequestHandler): def __init__(self) -> None: - super().__init__(agent_executor=MagicMock(), task_store=MagicMock()) + super().__init__( + agent_executor=MagicMock(), + task_store=MagicMock(), + agent_card=_agent_card(), + ) self.queue = AsyncMock() self.producer = MagicMock() self.background_tasks: list[asyncio.Task] = [] @@ -982,7 +1003,11 @@ async def consume_and_break_on_interrupt(self, _consumer, *, blocking, event_cal class _Handler(OpencodeRequestHandler): def __init__(self, aggregator: _Aggregator) -> None: - super().__init__(agent_executor=MagicMock(), task_store=MagicMock()) + super().__init__( + agent_executor=MagicMock(), + task_store=MagicMock(), + agent_card=_agent_card(), + ) self.aggregator = aggregator self.queue = AsyncMock() self.producer = MagicMock() @@ -1071,7 +1096,11 @@ async def consume_and_break_on_interrupt(self, _consumer, *, blocking, event_cal class _Handler(OpencodeRequestHandler): def __init__(self, aggregator: _Aggregator) -> None: - super().__init__(agent_executor=MagicMock(), task_store=MagicMock()) + super().__init__( + agent_executor=MagicMock(), + task_store=MagicMock(), + agent_card=_agent_card(), + ) self.aggregator = aggregator self.queue = AsyncMock() self.producer = MagicMock() @@ -1132,7 +1161,11 @@ def _apply_history_length(task: Task, configuration) -> Task: # noqa: ANN001 async def test_on_message_send_rejects_output_modes_without_text_plain() -> None: class _Handler(OpencodeRequestHandler): def __init__(self) -> None: - super().__init__(agent_executor=MagicMock(), task_store=MagicMock()) + super().__init__( + agent_executor=MagicMock(), + task_store=MagicMock(), + agent_card=_agent_card(), + ) self.setup_called = False async def _setup_message_execution(self, params, context=None): # noqa: ANN001 @@ -1161,7 +1194,11 @@ async def _setup_message_execution(self, params, context=None): # noqa: ANN001 async def test_on_message_send_stream_rejects_incompatible_output_modes_before_execution() -> None: class _Handler(OpencodeRequestHandler): def __init__(self) -> None: - super().__init__(agent_executor=MagicMock(), task_store=MagicMock()) + super().__init__( + agent_executor=MagicMock(), + task_store=MagicMock(), + agent_card=_agent_card(), + ) self.setup_called = False async def _setup_message_execution(self, params, context=None): # noqa: ANN001 @@ -1185,6 +1222,102 @@ async def _setup_message_execution(self, params, context=None): # noqa: ANN001 assert handler.setup_called is False +@pytest.mark.asyncio +async def test_on_message_send_rejects_shared_extension_metadata_without_negotiation() -> None: + class _Handler(OpencodeRequestHandler): + def __init__(self) -> None: + super().__init__( + agent_executor=MagicMock(), + task_store=MagicMock(), + agent_card=_agent_card(), + ) + self.setup_called = False + + async def _setup_message_execution(self, params, context=None): # noqa: ANN001 + del params, context + self.setup_called = True + raise AssertionError("_setup_message_execution should not be called") + + handler = _Handler() + params = types.SimpleNamespace( + message=types.SimpleNamespace( + metadata={ + "shared": { + "session": {"id": "ses-1"}, + "model": {"providerID": "openai", "modelID": "gpt-5"}, + } + } + ), + metadata={ + "opencode": { + "directory": "/workspace", + "workspace": {"id": "wrk-1"}, + } + }, + configuration=None, + ) + + with pytest.raises(UnsupportedOperationError) as exc_info: + await handler.on_message_send(params) + + assert exc_info.value.data == { + "type": "EXTENSION_NEGOTIATION_REQUIRED", + "fields": [ + "metadata.shared.session.id", + "metadata.opencode.directory", + "metadata.opencode.workspace.id", + "metadata.shared.model", + ], + "required_extensions": sorted( + [SESSION_BINDING_EXTENSION_URI, MODEL_SELECTION_EXTENSION_URI] + ), + "requested_extensions": [], + "header": "A2A-Extensions", + } + assert handler.setup_called is False + + +@pytest.mark.asyncio +async def test_on_message_send_stream_rejects_shared_extension_metadata_without_negotiation() -> ( + None +): + class _Handler(OpencodeRequestHandler): + def __init__(self) -> None: + super().__init__( + agent_executor=MagicMock(), + task_store=MagicMock(), + agent_card=_agent_card(), + ) + self.setup_called = False + + async def _setup_message_execution(self, params, context=None): # noqa: ANN001 + del params, context + self.setup_called = True + raise AssertionError("_setup_message_execution should not be called") + + handler = _Handler() + params = types.SimpleNamespace( + message=types.SimpleNamespace(metadata={"shared": {"session": {"id": "ses-1"}}}), + metadata={"opencode": {"directory": "/workspace"}}, + configuration=None, + ) + + with pytest.raises(UnsupportedOperationError) as exc_info: + await handler.on_message_send_stream(params).__anext__() + + assert exc_info.value.data == { + "type": "EXTENSION_NEGOTIATION_REQUIRED", + "fields": [ + "metadata.shared.session.id", + "metadata.opencode.directory", + ], + "required_extensions": [SESSION_BINDING_EXTENSION_URI], + "requested_extensions": [], + "header": "A2A-Extensions", + } + assert handler.setup_called is False + + def test_normalize_log_level_configure_logging_and_main(monkeypatch) -> None: assert _normalize_log_level("debug") == "DEBUG" diff --git a/tests/server/test_cancel_contract.py b/tests/server/test_cancel_contract.py index 64290a3..7c0a27d 100644 --- a/tests/server/test_cancel_contract.py +++ b/tests/server/test_cancel_contract.py @@ -7,15 +7,16 @@ from a2a.types import ( AgentCapabilities, AgentCard, + CancelTaskRequest, + Part, + SubscribeToTaskRequest, Task, - TaskIdParams, TaskNotCancelableError, TaskNotFoundError, TaskState, TaskStatus, ) -from opencode_a2a.a2a_utils import make_text_part from opencode_a2a.server.application import OpencodeRequestHandler @@ -31,10 +32,14 @@ def _store() -> InMemoryTaskStore: return InMemoryTaskStore(owner_resolver=lambda _context: "test-owner") +def _agent_card() -> AgentCard: + return AgentCard(name="opencode-a2a", capabilities=AgentCapabilities(streaming=True)) + + def _message_send_params(*, text: str = "hello") -> types.SimpleNamespace: return types.SimpleNamespace( configuration=None, - message=types.SimpleNamespace(parts=[make_text_part(text)]), + message=types.SimpleNamespace(parts=[Part(text=text)]), ) @@ -42,11 +47,15 @@ def _message_send_params(*, text: str = "hello") -> types.SimpleNamespace: async def test_cancel_is_idempotent_for_already_canceled_task() -> None: executor = AsyncMock() store = _store() - handler = OpencodeRequestHandler(agent_executor=executor, task_store=store) + handler = OpencodeRequestHandler( + agent_executor=executor, + task_store=store, + agent_card=_agent_card(), + ) task = _task(task_id="task-1", context_id="ctx-1", state=TaskState.TASK_STATE_CANCELED) await store.save(task, None) - result = await handler.on_cancel_task(TaskIdParams(id="task-1")) + result = await handler.on_cancel_task(CancelTaskRequest(id="task-1")) assert result is not None assert result.status.state == TaskState.TASK_STATE_CANCELED @@ -57,12 +66,16 @@ async def test_cancel_is_idempotent_for_already_canceled_task() -> None: async def test_cancel_rejects_completed_task() -> None: executor = AsyncMock() store = _store() - handler = OpencodeRequestHandler(agent_executor=executor, task_store=store) + handler = OpencodeRequestHandler( + agent_executor=executor, + task_store=store, + agent_card=_agent_card(), + ) task = _task(task_id="task-2", context_id="ctx-2", state=TaskState.TASK_STATE_COMPLETED) await store.save(task, None) with pytest.raises(TaskNotCancelableError): - await handler.on_cancel_task(TaskIdParams(id="task-2")) + await handler.on_cancel_task(CancelTaskRequest(id="task-2")) executor.cancel.assert_not_awaited() @@ -73,7 +86,11 @@ async def test_cancel_is_race_safe_when_task_becomes_canceled_during_super_call( ) -> None: executor = AsyncMock() store = _store() - handler = OpencodeRequestHandler(agent_executor=executor, task_store=store) + handler = OpencodeRequestHandler( + agent_executor=executor, + task_store=store, + agent_card=_agent_card(), + ) task = _task(task_id="task-race", context_id="ctx-race", state=TaskState.TASK_STATE_WORKING) await store.save(task, None) @@ -93,7 +110,7 @@ async def _consume_non_canceled(_self, _consumer): # noqa: ANN001 _consume_non_canceled, ) - result = await handler.on_cancel_task(TaskIdParams(id="task-race")) + result = await handler.on_cancel_task(CancelTaskRequest(id="task-race")) assert result is not None assert result.status.state == TaskState.TASK_STATE_CANCELED @@ -106,13 +123,13 @@ async def test_resubscribe_terminal_task_replays_final_snapshot_once() -> None: handler = OpencodeRequestHandler( agent_executor=executor, task_store=store, - agent_card=AgentCard(name="opencode-a2a", capabilities=AgentCapabilities(streaming=True)), + agent_card=_agent_card(), ) task = _task(task_id="task-3", context_id="ctx-3", state=TaskState.TASK_STATE_CANCELED) await store.save(task, None) events = [] - async for event in handler.on_resubscribe_to_task(TaskIdParams(id="task-3")): + async for event in handler.on_subscribe_to_task(SubscribeToTaskRequest(id="task-3")): events.append(event) assert len(events) == 1 @@ -127,13 +144,13 @@ async def test_resubscribe_non_terminal_without_queue_keeps_not_found_behavior() handler = OpencodeRequestHandler( agent_executor=executor, task_store=store, - agent_card=AgentCard(name="opencode-a2a", capabilities=AgentCapabilities(streaming=True)), + agent_card=_agent_card(), ) task = _task(task_id="task-4", context_id="ctx-4", state=TaskState.TASK_STATE_WORKING) await store.save(task, None) with pytest.raises(TaskNotFoundError): - async for _event in handler.on_resubscribe_to_task(TaskIdParams(id="task-4")): + async for _event in handler.on_subscribe_to_task(SubscribeToTaskRequest(id="task-4")): pass @@ -143,7 +160,11 @@ async def test_message_send_tracks_background_consumer_from_sdk_interrupt_path( ) -> None: executor = AsyncMock() store = _store() - handler = OpencodeRequestHandler(agent_executor=executor, task_store=store) + handler = OpencodeRequestHandler( + agent_executor=executor, + task_store=store, + agent_card=_agent_card(), + ) result_task = _task( task_id="task-5", context_id="ctx-5", state=TaskState.TASK_STATE_INPUT_REQUIRED diff --git a/tests/server/test_database_app_persistence.py b/tests/server/test_database_app_persistence.py index 7c9951b..8e48fcc 100644 --- a/tests/server/test_database_app_persistence.py +++ b/tests/server/test_database_app_persistence.py @@ -8,6 +8,7 @@ from opencode_a2a.opencode_upstream_client import OpencodeMessage from tests.support.helpers import make_settings +from tests.support.session_extensions import _extension_headers def _task(task_id: str, *, context_id: str = "ctx-1") -> Task: @@ -204,7 +205,7 @@ async def permission_reply( async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: query_response = await client.post( "/", - headers={"Authorization": "Bearer test-token"}, + headers=_extension_headers({"Authorization": "Bearer test-token"}), json={ "jsonrpc": "2.0", "id": 0, @@ -214,7 +215,7 @@ async def permission_reply( ) response = await client.post( "/", - headers={"Authorization": "Bearer test-token"}, + headers=_extension_headers({"Authorization": "Bearer test-token"}), json={ "jsonrpc": "2.0", "id": 1, diff --git a/tests/server/test_output_negotiation.py b/tests/server/test_output_negotiation.py index 6729b3e..94832ab 100644 --- a/tests/server/test_output_negotiation.py +++ b/tests/server/test_output_negotiation.py @@ -8,26 +8,33 @@ from a2a.server.tasks import TaskManager from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore from a2a.types import ( + AgentCapabilities, + AgentCard, Artifact, + GetTaskRequest, Message, + Part, Role, + SubscribeToTaskRequest, Task, TaskArtifactUpdateEvent, - TaskIdParams, - TaskQueryParams, TaskState, TaskStatus, TaskStatusUpdateEvent, ) -from opencode_a2a.a2a_utils import make_data_part, make_text_part, part_text +from opencode_a2a.a2a_utils import make_data_part +from opencode_a2a.contracts.extensions import ( + MODEL_SELECTION_EXTENSION_URI, + SESSION_BINDING_EXTENSION_URI, + STREAMING_EXTENSION_URI, +) from opencode_a2a.output_modes import ( NegotiatingResultAggregator, apply_accepted_output_modes, build_output_negotiation_metadata, extract_accepted_output_modes_from_metadata, normalize_accepted_output_modes, - part_text_fallback, ) from opencode_a2a.server.application import OpencodeRequestHandler @@ -36,11 +43,15 @@ def _store() -> InMemoryTaskStore: return InMemoryTaskStore(owner_resolver=lambda _context: "test-owner") +def _agent_card() -> AgentCard: + return AgentCard(name="opencode-a2a", capabilities=AgentCapabilities(streaming=True)) + + def _message(*, message_id: str, text: str, task_id: str, context_id: str) -> Message: return Message( message_id=message_id, role=Role.ROLE_AGENT, - parts=[make_text_part(text)], + parts=[Part(text=text)], task_id=task_id, context_id=context_id, ) @@ -72,7 +83,7 @@ def _task_with_negotiated_outputs(*, task_id: str, context_id: str) -> Task: artifacts=[ Artifact( artifact_id=f"{task_id}:text", - parts=[make_text_part("plain result")], + parts=[Part(text="plain result")], ), Artifact( artifact_id=f"{task_id}:json", @@ -83,6 +94,34 @@ def _task_with_negotiated_outputs(*, task_id: str, context_id: str) -> Task: ) +def _task_with_extension_metadata(*, task_id: str, context_id: str) -> Task: + metadata = build_output_negotiation_metadata(["text/plain"]) + assert metadata is not None + metadata["shared"] = { + "session": {"id": "ses-1", "title": "Alpha"}, + "model": {"providerID": "openai", "modelID": "gpt-5"}, + "usage": {"input_tokens": 12}, + } + metadata["opencode"] = { + **metadata["opencode"], + "directory": "/workspace", + } + return Task( + id=task_id, + context_id=context_id, + status=TaskStatus( + state=TaskState.TASK_STATE_COMPLETED, + message=_message( + message_id=f"{task_id}:status", + text="done", + task_id=task_id, + context_id=context_id, + ), + ), + metadata=metadata, + ) + + def test_normalize_accepted_output_modes_treats_wildcards_as_unrestricted() -> None: assert normalize_accepted_output_modes(["text/plain", "APPLICATION/JSON"]) == ( "text/plain", @@ -92,12 +131,6 @@ def test_normalize_accepted_output_modes_treats_wildcards_as_unrestricted() -> N assert normalize_accepted_output_modes(["*"]) is None -def test_part_text_fallback_serializes_data_parts_as_stable_json() -> None: - assert part_text_fallback(make_data_part({"tool": "bash", "status": "running"})) == ( - '{"status":"running","tool":"bash"}' - ) - - def test_apply_accepted_output_modes_downgrades_task_data_parts_to_text() -> None: task = Task( id="task-send", @@ -124,9 +157,11 @@ def test_apply_accepted_output_modes_downgrades_task_data_parts_to_text() -> Non assert isinstance(downgraded, Task) assert downgraded.status.message is not None - assert part_text(downgraded.status.message.parts[0]) == '{"status":"running","tool":"bash"}' + assert downgraded.status.message.parts[0].HasField("text") + assert downgraded.status.message.parts[0].text == '{"status":"running","tool":"bash"}' assert downgraded.artifacts is not None - assert part_text(downgraded.artifacts[0].parts[0]) == '{"status":"running","tool":"bash"}' + assert downgraded.artifacts[0].parts[0].HasField("text") + assert downgraded.artifacts[0].parts[0].text == '{"status":"running","tool":"bash"}' @pytest.mark.asyncio @@ -148,7 +183,7 @@ async def test_negotiating_result_aggregator_persists_metadata_for_artifact_firs context_id="ctx-artifact-first", artifact=Artifact( artifact_id="artifact-1", - parts=[make_text_part("hello")], + parts=[Part(text="hello")], ), append=False, last_chunk=False, @@ -185,9 +220,13 @@ async def test_on_get_task_applies_persisted_output_negotiation() -> None: store = _store() task = _task_with_negotiated_outputs(task_id="task-get", context_id="ctx-get") await store.save(task, ServerCallContext()) - handler = OpencodeRequestHandler(agent_executor=AsyncMock(), task_store=store) + handler = OpencodeRequestHandler( + agent_executor=AsyncMock(), + task_store=store, + agent_card=_agent_card(), + ) - result = await handler.on_get_task(TaskQueryParams(id="task-get")) + result = await handler.on_get_task(GetTaskRequest(id="task-get")) assert result is not None assert extract_accepted_output_modes_from_metadata(result.metadata) == ("text/plain",) @@ -196,7 +235,8 @@ async def test_on_get_task_applies_persisted_output_negotiation() -> None: "task-get:text", "task-get:json", ] - assert part_text(result.artifacts[1].parts[0]) == '{"status":"completed","tool":"bash"}' + assert result.artifacts[1].parts[0].HasField("text") + assert result.artifacts[1].parts[0].text == '{"status":"completed","tool":"bash"}' @pytest.mark.asyncio @@ -204,10 +244,14 @@ async def test_resubscribe_terminal_task_applies_persisted_output_negotiation() store = _store() task = _task_with_negotiated_outputs(task_id="task-resub", context_id="ctx-resub") await store.save(task, ServerCallContext()) - handler = OpencodeRequestHandler(agent_executor=AsyncMock(), task_store=store) + handler = OpencodeRequestHandler( + agent_executor=AsyncMock(), + task_store=store, + agent_card=_agent_card(), + ) events = [] - async for event in handler.on_resubscribe_to_task(TaskIdParams(id="task-resub")): + async for event in handler.on_subscribe_to_task(SubscribeToTaskRequest(id="task-resub")): events.append(event) assert len(events) == 1 @@ -217,4 +261,39 @@ async def test_resubscribe_terminal_task_applies_persisted_output_negotiation() "task-resub:text", "task-resub:json", ] - assert part_text(events[0].artifacts[1].parts[0]) == '{"status":"completed","tool":"bash"}' + assert events[0].artifacts[1].parts[0].HasField("text") + assert events[0].artifacts[1].parts[0].text == '{"status":"completed","tool":"bash"}' + + +@pytest.mark.asyncio +async def test_on_get_task_filters_unnegotiated_shared_extension_metadata() -> None: + store = _store() + task = _task_with_extension_metadata(task_id="task-filter", context_id="ctx-filter") + await store.save(task, ServerCallContext()) + handler = OpencodeRequestHandler( + agent_executor=AsyncMock(), + task_store=store, + agent_card=_agent_card(), + ) + + unnegotiated = await handler.on_get_task( + GetTaskRequest(id="task-filter"), context=ServerCallContext() + ) + assert unnegotiated is not None + assert "shared" not in unnegotiated.metadata + assert unnegotiated.metadata["opencode"]["directory"] == "/workspace" + + negotiated = await handler.on_get_task( + GetTaskRequest(id="task-filter"), + context=ServerCallContext( + requested_extensions={ + SESSION_BINDING_EXTENSION_URI, + MODEL_SELECTION_EXTENSION_URI, + STREAMING_EXTENSION_URI, + } + ), + ) + assert negotiated is not None + assert negotiated.metadata["shared"]["session"]["id"] == "ses-1" + assert negotiated.metadata["shared"]["model"]["providerID"] == "openai" + assert negotiated.metadata["shared"]["usage"]["input_tokens"] == 12 diff --git a/tests/server/test_transport_contract.py b/tests/server/test_transport_contract.py index dad8098..29348b9 100644 --- a/tests/server/test_transport_contract.py +++ b/tests/server/test_transport_contract.py @@ -5,26 +5,34 @@ import httpx import pytest -from a2a.server.apps.rest.rest_adapter import RESTAdapter +from a2a.server.routes.rest_dispatcher import RestDispatcher from a2a.types import ( Artifact, Message, + Part, Role, Task, TaskState, TaskStatus, ) -from opencode_a2a.a2a_utils import make_data_part, make_text_part +from opencode_a2a.a2a_utils import make_data_part +from opencode_a2a.contracts.extensions import ( + MODEL_SELECTION_EXTENSION_URI, + SESSION_BINDING_EXTENSION_URI, + SESSION_MANAGEMENT_EXTENSION_URI, + STREAMING_EXTENSION_URI, +) from opencode_a2a.output_modes import build_output_negotiation_metadata +from opencode_a2a.server.agent_card import build_agent_card from opencode_a2a.server.application import ( - AUTHENTICATED_EXTENDED_CARD_CACHE_CONTROL, - PUBLIC_AGENT_CARD_CACHE_CONTROL, - SESSION_MANAGEMENT_EXTENSION_URI, _normalize_log_level, - build_agent_card, create_app, ) +from opencode_a2a.server.middleware import ( + AUTHENTICATED_EXTENDED_CARD_CACHE_CONTROL, + PUBLIC_AGENT_CARD_CACHE_CONTROL, +) from opencode_a2a.trace_context import parse_traceparent from tests.support.helpers import ( DummyChatOpencodeUpstreamClient, @@ -55,7 +63,7 @@ def _task_for_listing( artifacts = [ Artifact( artifact_id=f"{task_id}-artifact", - parts=[make_text_part(f"artifact:{task_id}")], + parts=[Part(text=f"artifact:{task_id}")], ) ] history: list[Message] = [] @@ -64,7 +72,7 @@ def _task_for_listing( Message( message_id=f"{task_id}-history-{index}", role=Role.ROLE_AGENT, - parts=[make_text_part(f"history:{task_id}:{index}")], + parts=[Part(text=f"history:{task_id}:{index}")], context_id=context_id, task_id=task_id, ) @@ -85,8 +93,8 @@ def test_agent_card_declares_dual_stack_with_http_json_preferred() -> None: interfaces = { (iface.protocol_binding, iface.protocol_version) for iface in card.supported_interfaces } - assert ("HTTP+JSON", "0.3") in interfaces - assert ("JSONRPC", "0.3") in interfaces + assert ("HTTP+JSON", "1.0") in interfaces + assert ("JSONRPC", "1.0") in interfaces def test_normalize_log_level_falls_back_to_warning_for_invalid_value() -> None: @@ -101,7 +109,7 @@ async def test_public_agent_card_response_echoes_supplied_traceparent() -> None: async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: response = await client.get( - "/.well-known/agent.json", + "/.well-known/agent-card.json", headers={"traceparent": traceparent, "tracestate": "vendor=value"}, ) @@ -115,7 +123,7 @@ async def test_public_agent_card_response_generates_traceparent_when_missing() - transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - response = await client.get("/.well-known/agent.json") + response = await client.get("/.well-known/agent-card.json") assert response.status_code == 200 generated = response.headers.get("traceparent") @@ -123,6 +131,20 @@ async def test_public_agent_card_response_generates_traceparent_when_missing() - assert parse_traceparent(generated) is not None +@pytest.mark.asyncio +async def test_legacy_public_agent_card_path_is_not_exposed() -> None: + app = create_app(make_settings(test_bearer_token="test-token")) + transport = httpx.ASGITransport(app=app) + + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.get( + "/.well-known/agent.json", + headers={"Authorization": "Bearer test-token"}, + ) + + assert response.status_code == 404 + + def test_rest_subscription_route_matches_current_sdk_contract() -> None: app = create_app(make_settings(test_bearer_token="test-token")) route_paths = {route.path for route in app.router.routes if hasattr(route, "path")} @@ -146,18 +168,14 @@ def test_rest_subscription_route_registers_distinct_get_and_post_operations() -> } -def test_rest_adapter_exposes_sdk_rest_routes() -> None: - rest_adapter = RESTAdapter( - agent_card=build_agent_card(make_settings(test_bearer_token="test-token")), - http_handler=MagicMock(), - ) - route_paths = {route[0] for route in rest_adapter.routes()} +def test_rest_dispatcher_exposes_sdk_rest_handlers() -> None: + rest_dispatcher = RestDispatcher(request_handler=MagicMock()) - assert "/v1/message:send" in route_paths - assert "/v1/message:stream" in route_paths - assert "/v1/tasks/{id}" in route_paths - assert "/v1/tasks/{id}:cancel" in route_paths - assert "/v1/tasks/{id}:subscribe" in route_paths + assert callable(rest_dispatcher.on_message_send) + assert callable(rest_dispatcher.on_message_send_stream) + assert callable(rest_dispatcher.on_get_task) + assert callable(rest_dispatcher.on_cancel_task) + assert callable(rest_dispatcher.on_subscribe_to_task) @pytest.mark.asyncio @@ -351,7 +369,7 @@ async def test_list_tasks_route_supports_history_artifacts_and_filters(monkeypat headers=headers, params={ "contextId": "ctx-filtered", - "status": "completed", + "status": "TASK_STATE_COMPLETED", "historyLength": "2", "includeArtifacts": "true", "statusTimestampAfter": now.isoformat(), @@ -391,7 +409,7 @@ async def test_list_tasks_route_applies_persisted_output_negotiation(monkeypatch artifacts=[ Artifact( artifact_id="task-negotiated-list-text", - parts=[make_text_part("plain")], + parts=[Part(text="plain")], ), Artifact( artifact_id="task-negotiated-list-json", @@ -427,6 +445,69 @@ async def test_list_tasks_route_applies_persisted_output_negotiation(monkeypatch ) +@pytest.mark.asyncio +async def test_list_tasks_route_filters_unnegotiated_shared_extension_metadata(monkeypatch) -> None: + import opencode_a2a.server.application as app_module + + monkeypatch.setattr(app_module, "OpencodeUpstreamClient", DummyChatOpencodeUpstreamClient) + app = app_module.create_app( + make_settings( + test_bearer_token="test-token", + a2a_task_store_backend="memory", + ) + ) + task_store = app.state.task_store + metadata = { + "shared": { + "session": {"id": "ses-1"}, + "model": {"providerID": "openai", "modelID": "gpt-5"}, + "usage": {"input_tokens": 7}, + }, + "opencode": {"directory": "/workspace"}, + } + await task_store.save( + Task( + id="task-shared-list", + context_id="ctx-shared-list", + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED, timestamp=datetime.now(UTC)), + metadata=metadata, + ) + ) + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + without_extensions = await client.get( + "/v1/tasks", + headers={"Authorization": "Bearer test-token"}, + params={"contextId": "ctx-shared-list"}, + ) + with_extensions = await client.get( + "/v1/tasks", + headers={ + "Authorization": "Bearer test-token", + "A2A-Extensions": ",".join( + [ + MODEL_SELECTION_EXTENSION_URI, + SESSION_BINDING_EXTENSION_URI, + STREAMING_EXTENSION_URI, + ] + ), + }, + params={"contextId": "ctx-shared-list"}, + ) + + assert without_extensions.status_code == 200 + assert without_extensions.json()["tasks"][0]["metadata"] == { + "opencode": {"directory": "/workspace"} + } + assert with_extensions.status_code == 200 + assert with_extensions.json()["tasks"][0]["metadata"]["shared"] == { + "session": {"id": "ses-1"}, + "model": {"providerID": "openai", "modelID": "gpt-5"}, + "usage": {"input_tokens": 7.0}, + } + + @pytest.mark.asyncio async def test_list_tasks_route_tolerates_invalid_stored_status_timestamp(monkeypatch) -> None: import opencode_a2a.server.application as app_module @@ -495,16 +576,71 @@ async def test_list_tasks_route_validates_query_parameters(monkeypatch) -> None: headers=headers, params={"pageToken": "invalid-token"}, ) + status_error = await client.get( + "/v1/tasks", + headers=headers, + params={"status": "completed"}, + ) assert page_size_error.status_code == 400 assert page_size_error.json() == { - "error": "pageSize must be between 1 and 100.", - "field": "pageSize", + "error": { + "code": 400, + "status": "INVALID_ARGUMENT", + "message": "pageSize must be between 1 and 100.", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "INVALID_LIST_TASKS_REQUEST", + "domain": "a2a-protocol.org", + "metadata": {"field": "pageSize"}, + }, + { + "@type": "type.googleapis.com/opencode_a2a.HttpErrorContext", + "field": "pageSize", + }, + ], + } } assert page_token_error.status_code == 400 assert page_token_error.json() == { - "error": "pageToken is invalid.", - "field": "pageToken", + "error": { + "code": 400, + "status": "INVALID_ARGUMENT", + "message": "pageToken is invalid.", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "INVALID_LIST_TASKS_REQUEST", + "domain": "a2a-protocol.org", + "metadata": {"field": "pageToken"}, + }, + { + "@type": "type.googleapis.com/opencode_a2a.HttpErrorContext", + "field": "pageToken", + }, + ], + } + } + assert status_error.status_code == 400 + assert status_error.json() == { + "error": { + "code": 400, + "status": "INVALID_ARGUMENT", + "message": "Unsupported task status 'completed'.", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "INVALID_LIST_TASKS_REQUEST", + "domain": "a2a-protocol.org", + "metadata": {"field": "status"}, + }, + { + "@type": "type.googleapis.com/opencode_a2a.HttpErrorContext", + "field": "status", + }, + ], + } } @@ -541,8 +677,8 @@ def test_openapi_jsonrpc_examples_include_core_message_methods() -> None: ) example_values = examples.values() methods = {value.get("value", {}).get("method") for value in example_values} - assert "message/send" in methods - assert "message/stream" in methods + assert "SendMessage" in methods + assert "SendStreamingMessage" in methods assert "message_send_file_input" in examples @@ -563,7 +699,7 @@ async def test_agent_card_routes_split_public_and_authenticated_extended_contrac assert public_card.headers["cache-control"] == PUBLIC_AGENT_CARD_CACHE_CONTROL assert public_card.headers["etag"] assert public_card.headers["vary"] == "Accept-Encoding" - assert public_card.json()["supportsAuthenticatedExtendedCard"] is True + assert public_card.json()["capabilities"]["extendedAgentCard"] is True public_cached = await client.get( "/.well-known/agent-card.json", @@ -609,7 +745,7 @@ async def test_agent_card_routes_split_public_and_authenticated_extended_contrac json={ "jsonrpc": "2.0", "id": "card-1", - "method": "agent/getAuthenticatedExtendedCard", + "method": "GetExtendedAgentCard", "params": {}, }, ) @@ -661,18 +797,36 @@ async def test_rest_endpoints_reject_unsupported_protocol_version() -> None: "message": { "messageId": "req-1", "role": "ROLE_USER", - "content": [{"text": "hello"}], + "parts": [{"text": "hello"}], } }, ) assert response.status_code == 400 assert response.json() == { - "error": "Unsupported A2A version", - "type": "VERSION_NOT_SUPPORTED", - "requested_version": "2.0", - "supported_protocol_versions": ["0.3", "1.0"], - "default_protocol_version": "0.3", + "error": { + "code": 400, + "status": "INVALID_ARGUMENT", + "message": "Unsupported A2A version", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "VERSION_NOT_SUPPORTED", + "domain": "a2a-protocol.org", + "metadata": { + "requestedVersion": "2.0", + "supportedProtocolVersions": '["1.0"]', + "defaultProtocolVersion": "1.0", + }, + }, + { + "@type": "type.googleapis.com/opencode_a2a.HttpErrorContext", + "requestedVersion": "2.0", + "supportedProtocolVersions": ["1.0"], + "defaultProtocolVersion": "1.0", + }, + ], + } } @@ -689,7 +843,7 @@ async def test_rest_endpoints_return_v1_status_body_for_v1_protocol_errors() -> "message": { "messageId": "req-2", "role": "ROLE_USER", - "content": [{"text": "hello"}], + "parts": [{"text": "hello"}], } }, ) @@ -707,15 +861,15 @@ async def test_rest_endpoints_return_v1_status_body_for_v1_protocol_errors() -> "domain": "a2a-protocol.org", "metadata": { "requestedVersion": "1.1", - "supportedProtocolVersions": '["0.3","1.0"]', - "defaultProtocolVersion": "0.3", + "supportedProtocolVersions": '["1.0"]', + "defaultProtocolVersion": "1.0", }, }, { "@type": "type.googleapis.com/opencode_a2a.HttpErrorContext", "requestedVersion": "1.1", - "supportedProtocolVersions": ["0.3", "1.0"], - "defaultProtocolVersion": "0.3", + "supportedProtocolVersions": ["1.0"], + "defaultProtocolVersion": "1.0", }, ], } @@ -813,7 +967,7 @@ async def test_streaming_responses_remain_outside_gzip_middleware(monkeypatch) - "message": { "messageId": "gzip-stream-test", "role": "ROLE_USER", - "content": [{"text": "hello"}], + "parts": [{"text": "hello"}], } }, ) as response: @@ -835,18 +989,18 @@ async def test_dual_stack_send_accepts_transport_native_payloads(monkeypatch) -> "message": { "messageId": "m-rest", "role": "ROLE_USER", - "content": [{"text": "hello from rest"}], + "parts": [{"text": "hello from rest"}], } } rpc_payload = { "jsonrpc": "2.0", "id": 1, - "method": "message/send", + "method": "SendMessage", "params": { "message": { "messageId": "m-rpc", - "role": "user", - "parts": [{"kind": "text", "text": "hello from jsonrpc"}], + "role": "ROLE_USER", + "parts": [{"text": "hello from jsonrpc"}], } }, } @@ -878,8 +1032,8 @@ async def test_v1_pascalcase_sendmessage_alias_is_accepted(monkeypatch) -> None: "params": { "message": { "messageId": "m-rpc-v1", - "role": "user", - "parts": [{"kind": "text", "text": "hello from v1 alias"}], + "role": "ROLE_USER", + "parts": [{"text": "hello from v1 dispatch"}], } }, } @@ -911,7 +1065,7 @@ async def test_dual_stack_send_rejects_cross_transport_payload_shapes(monkeypatc full_jsonrpc_envelope = { "jsonrpc": "2.0", "id": 3, - "method": "message/send", + "method": "SendMessage", "params": { "message": { "messageId": "m-rest-cross-envelope", @@ -923,7 +1077,7 @@ async def test_dual_stack_send_rejects_cross_transport_payload_shapes(monkeypatc rpc_with_rest_shape = { "jsonrpc": "2.0", "id": 2, - "method": "message/send", + "method": "SendMessage", "params": { "message": { "messageId": "m-rpc-cross", @@ -940,7 +1094,20 @@ async def test_dual_stack_send_rejects_cross_transport_payload_shapes(monkeypatc json=rest_with_jsonrpc_shape, ) assert rest_resp.status_code == 400 - assert "Invalid HTTP+JSON payload" in rest_resp.text + assert rest_resp.json() == { + "error": { + "code": 400, + "status": "INVALID_ARGUMENT", + "message": "REST message payload must use ROLE_* values for message.role.", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "INVALID_REQUEST", + "domain": "a2a-protocol.org", + } + ], + } + } rest_envelope_resp = await client.post( "/v1/message:send", @@ -948,22 +1115,15 @@ async def test_dual_stack_send_rejects_cross_transport_payload_shapes(monkeypatc json=full_jsonrpc_envelope, ) assert rest_envelope_resp.status_code == 400 - assert "Invalid HTTP+JSON payload" in rest_envelope_resp.text - - v1_rest_resp = await client.post( - "/v1/message:send", - headers={**headers, "A2A-Version": "1.0"}, - json=rest_with_jsonrpc_shape, - ) - assert v1_rest_resp.status_code == 400 - assert v1_rest_resp.json() == { + assert rest_envelope_resp.json() == { "error": { "code": 400, "status": "INVALID_ARGUMENT", "message": ( - "Invalid HTTP+JSON payload for REST endpoint. " - "Use message.content with ROLE_* role values, or call " - "POST / with method=message/send or method=message/stream." + "Invalid JSON-RPC payload for REST endpoint. " + "Call POST / for JSON-RPC methods such as SendMessage " + "or SendStreamingMessage, or send ProtoJSON " + "SendMessageRequest payloads to the REST endpoint." ), "details": [ { @@ -980,6 +1140,27 @@ async def test_dual_stack_send_rejects_cross_transport_payload_shapes(monkeypatc } } + v1_rest_resp = await client.post( + "/v1/message:send", + headers={**headers, "A2A-Version": "1.0"}, + json=rest_with_jsonrpc_shape, + ) + assert v1_rest_resp.status_code == 400 + assert v1_rest_resp.json() == { + "error": { + "code": 400, + "status": "INVALID_ARGUMENT", + "message": "REST message payload must use ROLE_* values for message.role.", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "INVALID_REQUEST", + "domain": "a2a-protocol.org", + }, + ], + } + } + rpc_resp = await client.post("/", headers=headers, json=rpc_with_rest_shape) assert rpc_resp.status_code == 200 payload = rpc_resp.json() @@ -991,7 +1172,7 @@ def _rest_message_payload() -> dict: "message": { "messageId": "m-rest", "role": "ROLE_USER", - "content": [{"text": "hello from rest"}], + "parts": [{"text": "hello from rest"}], } } @@ -1000,12 +1181,12 @@ def _jsonrpc_message_send_payload(text: str) -> dict: return { "jsonrpc": "2.0", "id": 99, - "method": "message/send", + "method": "SendMessage", "params": { "message": { "messageId": "m-rpc", - "role": "user", - "parts": [{"kind": "text", "text": text}], + "role": "ROLE_USER", + "parts": [{"text": text}], } }, } @@ -1106,8 +1287,8 @@ async def test_log_payloads_omits_text_plain_request_body(monkeypatch, caplog) - "Content-Type": "text/plain", } body = ( - '{"jsonrpc":"2.0","id":1,"method":"message/send","params":{"message":' - '{"messageId":"m","role":"user","parts":[{"kind":"text","text":"secret"}]}}}' + '{"jsonrpc":"2.0","id":1,"method":"SendMessage","params":{"message":' + '{"messageId":"m","role":"ROLE_USER","parts":[{"text":"secret"}]}}}' ) with caplog.at_level(logging.DEBUG, logger="opencode_a2a.server.application"): @@ -1140,8 +1321,8 @@ async def test_log_payloads_omits_when_content_length_missing(monkeypatch, caplo "Content-Type": "application/json", } body = ( - b'{"jsonrpc":"2.0","id":1,"method":"message/send","params":{"message":' - b'{"messageId":"m","role":"user","parts":[{"kind":"text","text":"missing-cl"}]}}}' + b'{"jsonrpc":"2.0","id":1,"method":"SendMessage","params":{"message":' + b'{"messageId":"m","role":"ROLE_USER","parts":[{"text":"missing-cl"}]}}}' ) async def _body_stream(): @@ -1218,7 +1399,26 @@ async def test_request_body_limit_rejects_oversized_content_length(monkeypatch) ) assert resp.status_code == 413 - assert resp.json() == {"error": "Request body too large", "max_bytes": 64} + assert resp.json() == { + "error": { + "code": 413, + "status": "RESOURCE_EXHAUSTED", + "message": "Request body too large", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "REQUEST_BODY_TOO_LARGE", + "domain": "a2a-protocol.org", + "metadata": {"maxBytes": "64", "actualSize": str(len(resp.request.content))}, + }, + { + "@type": "type.googleapis.com/opencode_a2a.HttpErrorContext", + "maxBytes": 64, + "actualSize": len(resp.request.content), + }, + ], + } + } @pytest.mark.asyncio @@ -1237,10 +1437,8 @@ async def test_request_body_limit_rejects_oversized_stream_without_content_lengt "Content-Type": "application/json", } body = ( - b'{"jsonrpc":"2.0","id":1,"method":"message/send","params":{"message":' - b'{"messageId":"m","role":"user","parts":[{"kind":"text","text":"' - + (b"x" * 128) - + b'"}]}}}' + b'{"jsonrpc":"2.0","id":1,"method":"SendMessage","params":{"message":' + b'{"messageId":"m","role":"ROLE_USER","parts":[{"text":"' + (b"x" * 128) + b'"}]}}}' ) async def _body_stream(): @@ -1250,7 +1448,26 @@ async def _body_stream(): resp = await client.post("/", headers=headers, content=_body_stream()) assert resp.status_code == 413 - assert resp.json() == {"error": "Request body too large", "max_bytes": 64} + assert resp.json() == { + "error": { + "code": 413, + "status": "RESOURCE_EXHAUSTED", + "message": "Request body too large", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "REQUEST_BODY_TOO_LARGE", + "domain": "a2a-protocol.org", + "metadata": {"maxBytes": "64", "actualSize": str(len(body))}, + }, + { + "@type": "type.googleapis.com/opencode_a2a.HttpErrorContext", + "maxBytes": 64, + "actualSize": len(body), + }, + ], + } + } @pytest.mark.asyncio diff --git a/tests/support/fake_client_errors.py b/tests/support/fake_client_errors.py new file mode 100644 index 0000000..d1bb119 --- /dev/null +++ b/tests/support/fake_client_errors.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from a2a.client.errors import A2AClientError + + +class FakeA2AClientHTTPError(A2AClientError): + def __init__(self, status_code: int, message: str) -> None: + self.status_code = status_code + super().__init__(f"HTTP Error {status_code}: {message}") + + +class FakeA2AClientJSONError(A2AClientError): + pass + + +class FakeA2AClientJSONRPCError(A2AClientError): + def __init__(self, response: object) -> None: + self.response = response + error = getattr(response, "error", None) + message = getattr(error, "message", "JSON-RPC error") + super().__init__(message) diff --git a/tests/support/helpers.py b/tests/support/helpers.py index 37db41d..0fcb813 100644 --- a/tests/support/helpers.py +++ b/tests/support/helpers.py @@ -10,8 +10,12 @@ from a2a.server.context import ServerCallContext from a2a.types import Message, Part, Role, SendMessageConfiguration, SendMessageRequest -from opencode_a2a.a2a_utils import make_data_part, make_text_part, make_url_part from opencode_a2a.config import Settings +from opencode_a2a.contracts.extensions import ( + MODEL_SELECTION_EXTENSION_URI, + SESSION_BINDING_EXTENSION_URI, + STREAMING_EXTENSION_URI, +) from opencode_a2a.opencode_upstream_client import OpencodeMessage, OpencodeMessagePage @@ -88,6 +92,22 @@ async def close(self) -> None: return None +def _default_requested_extensions() -> set[str]: + return { + MODEL_SELECTION_EXTENSION_URI, + SESSION_BINDING_EXTENSION_URI, + STREAMING_EXTENSION_URI, + } + + +def _ensure_test_call_context(call_context: Any | None) -> Any: + if call_context is None: + return ServerCallContext(requested_extensions=_default_requested_extensions()) + if not hasattr(call_context, "requested_extensions"): + call_context.requested_extensions = _default_requested_extensions() + return call_context + + def make_request_context_mock( *, task_id: str | None, @@ -109,6 +129,7 @@ def make_request_context_mock( if call_context_enabled: call_context = MagicMock(spec=ServerCallContext) call_context.state = {"identity": identity} if identity else {} + call_context.requested_extensions = _default_requested_extensions() context.call_context = call_context else: context.call_context = None @@ -141,10 +162,11 @@ def make_request_context( accepted_output_modes: list[str] | None = None, call_context: Any = None, ) -> RequestContext: + call_context = _ensure_test_call_context(call_context) message = Message( message_id=message_id, role=Role.ROLE_USER, - parts=[make_text_part(text)], + parts=[Part(text=text)], ) configuration = ( SendMessageConfiguration(accepted_output_modes=accepted_output_modes) @@ -160,36 +182,6 @@ def make_request_context( ) -def _normalize_test_part(part: Any) -> Part: - if isinstance(part, Part): - return part - text = getattr(part, "text", None) - if isinstance(text, str): - return make_text_part(text) - if hasattr(part, "data"): - return make_data_part(part.data) - if hasattr(part, "file"): - file_payload = part.file - mime_type = getattr(file_payload, "mimeType", None) - filename = getattr(file_payload, "name", None) - raw_bytes = getattr(file_payload, "bytes", None) - if isinstance(raw_bytes, str): - media_type = mime_type or "application/octet-stream" - return make_url_part( - f"data:{media_type};base64,{raw_bytes}", - filename=filename, - media_type=media_type, - ) - uri = getattr(file_payload, "uri", None) - if isinstance(uri, str): - return make_url_part( - uri, - filename=filename, - media_type=mime_type, - ) - raise TypeError(f"Unsupported test part payload: {type(part)!r}") - - def make_request_context_with_parts( *, task_id: str, @@ -200,10 +192,11 @@ def make_request_context_with_parts( call_context: Any = None, accepted_output_modes: list[str] | None = None, ) -> RequestContext: + call_context = _ensure_test_call_context(call_context) message = Message( message_id=message_id, role=Role.ROLE_USER, - parts=[_normalize_test_part(part) for part in parts], + parts=parts, ) configuration = ( SendMessageConfiguration(accepted_output_modes=accepted_output_modes) diff --git a/tests/support/jsonrpc_error_assertions.py b/tests/support/jsonrpc_error_assertions.py new file mode 100644 index 0000000..c8d2620 --- /dev/null +++ b/tests/support/jsonrpc_error_assertions.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +from opencode_a2a.jsonrpc.error_responses import GOOGLE_RPC_ERROR_INFO_TYPE + + +def _camelize_key(name: str) -> str: + if "_" not in name: + return name + head, *tail = [part for part in name.split("_") if part] + return head + "".join(part[:1].upper() + part[1:] for part in tail) + + +def _stringify_metadata_value(value: Any) -> str: + if isinstance(value, str): + return value + if isinstance(value, bool | int | float): + return str(value) + raise TypeError(f"Unsupported metadata value: {value!r}") + + +def error_info_detail(error_payload: Mapping[str, Any]) -> dict[str, Any]: + data = error_payload.get("data") + if not isinstance(data, list): + raise TypeError(f"Expected list-backed error data, got {type(data)!r}") + for item in data: + if isinstance(item, dict) and item.get("@type") == GOOGLE_RPC_ERROR_INFO_TYPE: + return item + raise AssertionError("google.rpc.ErrorInfo detail not found") + + +def error_context_detail(error_payload: Mapping[str, Any]) -> dict[str, Any] | None: + data = error_payload.get("data") + if not isinstance(data, list): + return None + for item in data: + if isinstance(item, dict) and item.get("@type", "").startswith( + "type.googleapis.com/opencode_a2a." + ): + return item + return None + + +def assert_v1_error_reason( + error_payload: Mapping[str, Any], + *, + reason: str, + metadata: Mapping[str, Any] | None = None, +) -> None: + detail = error_info_detail(error_payload) + assert detail["reason"] == reason + assert detail["domain"] == "a2a-protocol.org" + if metadata is not None: + assert detail["metadata"] == { + _camelize_key(str(key)): _stringify_metadata_value(value) + for key, value in metadata.items() + } + + +def assert_v1_error_metadata_contains( + error_payload: Mapping[str, Any], + *, + reason: str, + metadata: Mapping[str, Any], +) -> None: + detail = error_info_detail(error_payload) + assert detail["reason"] == reason + actual = detail.get("metadata", {}) + for key, value in metadata.items(): + assert actual[_camelize_key(str(key))] == _stringify_metadata_value(value) + + +def assert_v1_error_context( + error_payload: Mapping[str, Any], + *, + metadata: Mapping[str, Any], +) -> None: + detail = error_context_detail(error_payload) + assert detail is not None + detail = dict(detail) + detail.pop("@type", None) + assert detail == {_camelize_key(str(key)): value for key, value in metadata.items()} diff --git a/tests/support/session_extensions.py b/tests/support/session_extensions.py index 2defc3d..d6302f3 100644 --- a/tests/support/session_extensions.py +++ b/tests/support/session_extensions.py @@ -2,10 +2,31 @@ from fastapi import FastAPI +from opencode_a2a.contracts.extensions import ( + INTERRUPT_CALLBACK_EXTENSION_URI, + INTERRUPT_RECOVERY_EXTENSION_URI, + MODEL_SELECTION_EXTENSION_URI, + PROVIDER_DISCOVERY_EXTENSION_URI, + SESSION_BINDING_EXTENSION_URI, + SESSION_MANAGEMENT_EXTENSION_URI, + STREAMING_EXTENSION_URI, + WORKSPACE_CONTROL_EXTENSION_URI, +) + _BASE_SETTINGS = { "opencode_timeout": 1.0, "a2a_log_level": "DEBUG", } +_ALL_EXTENSION_URIS = ( + INTERRUPT_CALLBACK_EXTENSION_URI, + INTERRUPT_RECOVERY_EXTENSION_URI, + MODEL_SELECTION_EXTENSION_URI, + PROVIDER_DISCOVERY_EXTENSION_URI, + SESSION_BINDING_EXTENSION_URI, + SESSION_MANAGEMENT_EXTENSION_URI, + STREAMING_EXTENSION_URI, + WORKSPACE_CONTROL_EXTENSION_URI, +) def _session_meta(payload: dict) -> dict: @@ -17,3 +38,9 @@ def _jsonrpc_app(app: FastAPI): if target is not None: return target raise AssertionError("JSON-RPC app handle not found") + + +def _extension_headers(headers: dict[str, str] | None = None) -> dict[str, str]: + merged = dict(headers or {}) + merged["A2A-Extensions"] = ",".join(sorted(_ALL_EXTENSION_URIS)) + return merged diff --git a/tests/support/streaming_output.py b/tests/support/streaming_output.py index ac8af22..c6df8ad 100644 --- a/tests/support/streaming_output.py +++ b/tests/support/streaming_output.py @@ -4,8 +4,8 @@ TaskArtifactUpdateEvent, TaskStatusUpdateEvent, ) +from google.protobuf.json_format import MessageToDict -from opencode_a2a.a2a_utils import part_data_to_python, proto_to_dict from opencode_a2a.opencode_upstream_client import OpencodeMessage from tests.support.helpers import ( DummyEventQueue, @@ -262,16 +262,16 @@ def _part_text(event: TaskArtifactUpdateEvent) -> str: def _part_data(event: TaskArtifactUpdateEvent) -> dict: part = event.artifact.parts[0] if hasattr(part, "HasField") and part.HasField("data"): - return part_data_to_python(part) or {} + return MessageToDict(part.data) or {} return getattr(part, "data", None) or getattr(getattr(part, "root", None), "data", {}) def _artifact_stream_meta(event: TaskArtifactUpdateEvent) -> dict: - return proto_to_dict(event.artifact.metadata).get("shared", {}).get("stream", {}) + return MessageToDict(event.artifact.metadata).get("shared", {}).get("stream", {}) def _status_shared_meta(event: TaskStatusUpdateEvent) -> dict: - return proto_to_dict(event.metadata).get("shared", {}) + return MessageToDict(event.metadata).get("shared", {}) def _interrupt_meta(event: TaskStatusUpdateEvent) -> dict: