diff --git a/README.md b/README.md index 6b20977..05ae0e2 100644 --- a/README.md +++ b/README.md @@ -75,12 +75,37 @@ Default local address: `http://127.0.0.1:8000` - A2A HTTP+JSON endpoints such as `/v1/message:send` and `/v1/message:stream` - A2A JSON-RPC support on `POST /` +- Peering capabilities: can act as a client via `opencode-a2a call` +- Autonomous tool execution: supports `a2a_call` tool for outbound agent-to-agent communication - SSE streaming with normalized `text`, `reasoning`, and `tool_call` blocks - Explicit REST SSE keepalive configurable through `A2A_STREAM_SSE_PING_SECONDS` - Session continuity through `metadata.shared.session.id` - Request-scoped model selection through `metadata.shared.model` - OpenCode-oriented JSON-RPC extensions for session and model/provider queries +## Peering Node + +`opencode-a2a` supports a "Peering Node" architecture where a single process handles both inbound (Server) and outbound (Client) A2A traffic. + +### CLI Client +Interact with other A2A agents directly from the command line: + +```bash +opencode-a2a call http://other-agent:8000 "How are you?" --token your-outbound-token +``` + +### Outbound Agent Calls (Tools) +The server can autonomously execute `a2a_call(url, message)` tool calls emitted by the OpenCode runtime. Results are fetched via A2A and returned to the model as tool results, enabling multi-agent orchestration. + +When the target peer requires bearer auth, configure `A2A_CLIENT_BEARER_TOKEN` +for server-side outbound calls. CLI calls can continue using `--token` or +`A2A_CLIENT_BEARER_TOKEN`. + +Server-side outbound client settings are fully wired through runtime config: +`A2A_CLIENT_TIMEOUT_SECONDS`, `A2A_CLIENT_CARD_FETCH_TIMEOUT_SECONDS`, +`A2A_CLIENT_USE_CLIENT_PREFERENCE`, `A2A_CLIENT_BEARER_TOKEN`, and +`A2A_CLIENT_SUPPORTED_TRANSPORTS`. + Detailed protocol contracts, examples, and extension docs live in [`docs/guide.md`](docs/guide.md). @@ -107,6 +132,8 @@ This repository improves the service boundary around OpenCode, but it does not turn OpenCode into a hardened multi-tenant platform. - `A2A_BEARER_TOKEN` protects the A2A surface. +- `A2A_CLIENT_BEARER_TOKEN` is used for outbound peer calls initiated by the + server-side `a2a_call` tool. - Provider auth and default model configuration remain on the OpenCode side. - Deployment supervision is intentionally BYO. Use `systemd`, Docker, Kubernetes, or another supervisor if you need long-running operation. diff --git a/docs/guide.md b/docs/guide.md index 34f5406..8e832c0 100644 --- a/docs/guide.md +++ b/docs/guide.md @@ -69,7 +69,47 @@ Key variables to understand protocol behavior: `session.abort` in cancel flow. - `OPENCODE_TIMEOUT` / `OPENCODE_TIMEOUT_STREAM`: upstream request timeout and optional stream timeout override. +- `A2A_CLIENT_TIMEOUT_SECONDS`: outbound client timeout. Default: `30` seconds. +- `A2A_CLIENT_CARD_FETCH_TIMEOUT_SECONDS`: outbound Agent Card fetch timeout. + Default: `5` seconds. +- `A2A_CLIENT_USE_CLIENT_PREFERENCE`: whether the outbound client prefers its own transport choices. +- `A2A_CLIENT_BEARER_TOKEN`: optional bearer token attached to outbound peer + calls made by the embedded A2A client and `a2a_call` tool path. +- `A2A_CLIENT_SUPPORTED_TRANSPORTS`: ordered outbound transport preference list. - Runtime authentication is bearer-token only via `A2A_BEARER_TOKEN`. +- The same outbound client flags are also honored by the server-side embedded + A2A client used for peer calls and `a2a_call` tool execution: + - `A2A_CLIENT_TIMEOUT_SECONDS` + - `A2A_CLIENT_CARD_FETCH_TIMEOUT_SECONDS` + - `A2A_CLIENT_USE_CLIENT_PREFERENCE` + - `A2A_CLIENT_BEARER_TOKEN` + - `A2A_CLIENT_SUPPORTED_TRANSPORTS` + +## Client Initialization Facade (Preview) + +`opencode-a2a` now includes a minimal client bootstrap module in +`src/opencode_a2a/client/` to support downstream consumer usage while keeping +server and client concerns separate. + +Boundary separation: + +- Server code owns runtime request handling, transport orchestration, stream + behavior, and public compatibility profile exposure. +- Client code owns peer card discovery, SDK client construction, operation call + helpers, and protocol error normalization. + +Current client facade API: + +- `A2AClient.get_agent_card()` +- `A2AClient.send()` / `A2AClient.send_message()` +- `A2AClient.get_task()` +- `A2AClient.cancel_task()` +- `A2AClient.resubscribe_task()` + +Server-side outbound peer calls use bearer auth only for now. Configure +`A2A_CLIENT_BEARER_TOKEN` when the remote agent protects its runtime surface. +CLI outbound calls may pass `--token` explicitly or use +`A2A_CLIENT_BEARER_TOKEN`. Execution-boundary metadata is intentionally declarative deployment metadata: it is published through `RuntimeProfile`, Agent Card, OpenAPI, and `/health`, diff --git a/src/opencode_a2a/cli.py b/src/opencode_a2a/cli.py index 7a23724..68147d3 100644 --- a/src/opencode_a2a/cli.py +++ b/src/opencode_a2a/cli.py @@ -1,6 +1,8 @@ from __future__ import annotations import argparse +import asyncio +import os import sys from collections.abc import Sequence @@ -8,6 +10,45 @@ from .server.application import main as serve_main +async def run_call(agent_url: str, text: str, token: str | None = None) -> int: + from a2a.types import Message, TaskArtifactUpdateEvent, TaskStatusUpdateEvent + + from .client import A2AClient + + client = A2AClient(agent_url) + metadata = {} + if token: + # Use Authorization header for bearer auth. + metadata["authorization"] = f"Bearer {token}" + + try: + async for event in client.send_message(text, metadata=metadata): + if isinstance(event, tuple): + _, update = event + if isinstance(update, TaskArtifactUpdateEvent): + artifact = update.artifact + if artifact and artifact.parts: + for part in artifact.parts: + text_val = getattr(part.root, "text", None) + if isinstance(text_val, str): + print(text_val, end="", flush=True) + elif isinstance(update, TaskStatusUpdateEvent): + if update.status and update.status.state == "failed": + print(f"\n[Failed] {update.status.message or ''}") + elif isinstance(event, Message): + for part in event.parts: + text_val = getattr(part.root, "text", None) + if isinstance(text_val, str): + print(text_val, end="", flush=True) + print() # New line after completion + except Exception as exc: + print(f"\n[Error] {exc}", file=sys.stderr) + return 1 + finally: + await client.close() + return 0 + + def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( prog="opencode-a2a", @@ -28,6 +69,20 @@ def build_parser() -> argparse.ArgumentParser: help="Start the OpenCode A2A runtime using environment-based settings.", description="Start the OpenCode A2A runtime using environment-based settings.", ) + + call_parser = subparsers.add_parser( + "call", + help="Call an A2A agent.", + description="Call an A2A agent using the A2A protocol.", + ) + call_parser.add_argument("agent_url", help="URL of the agent to call.") + call_parser.add_argument("text", help="Text message to send.") + call_parser.add_argument( + "--token", + help="Bearer token for authentication (can also use A2A_CLIENT_BEARER_TOKEN env).", + default=os.environ.get("A2A_CLIENT_BEARER_TOKEN"), + ) + return parser @@ -40,7 +95,14 @@ def main(argv: Sequence[str] | None = None) -> int: return 0 namespace = parser.parse_args(args) - if namespace.command in {None, "serve"}: + if namespace.command == "serve": + serve_main() + return 0 + + if namespace.command == "call": + return asyncio.run(run_call(namespace.agent_url, namespace.text, namespace.token)) + + if namespace.command is None: serve_main() return 0 diff --git a/src/opencode_a2a/client/__init__.py b/src/opencode_a2a/client/__init__.py new file mode 100644 index 0000000..28f83bb --- /dev/null +++ b/src/opencode_a2a/client/__init__.py @@ -0,0 +1,28 @@ +"""Reusable A2A client utilities and facade types.""" + +from .client import A2AClient +from .config import A2AClientSettings, load_settings +from .errors import ( + A2AAgentUnavailableError, + A2AClientError, + A2AClientResetRequiredError, + A2APeerProtocolError, + A2AUnsupportedBindingError, + A2AUnsupportedOperationError, +) +from .types import A2AClientEvent, A2AClientEventStream, A2AClientMetadata + +__all__ = [ + "A2AClient", + "A2AClientError", + "A2AAgentUnavailableError", + "A2AClientResetRequiredError", + "A2APeerProtocolError", + "A2AUnsupportedBindingError", + "A2AUnsupportedOperationError", + "A2AClientSettings", + "A2AClientEvent", + "A2AClientEventStream", + "A2AClientMetadata", + "load_settings", +] diff --git a/src/opencode_a2a/client/client.py b/src/opencode_a2a/client/client.py new file mode 100644 index 0000000..ea28b6d --- /dev/null +++ b/src/opencode_a2a/client/client.py @@ -0,0 +1,608 @@ +"""A2A client initialization and facade utilities for opencode-a2a consumers.""" + +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterator, Mapping +from typing import Any, cast +from urllib.parse import urlsplit, urlunsplit +from uuid import uuid4 + +import httpx +from a2a.client import Client, ClientConfig, ClientFactory +from a2a.client.card_resolver import A2ACardResolver +from a2a.client.errors import ( + A2AClientHTTPError, + A2AClientJSONError, + A2AClientJSONRPCError, +) +from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.types import ( + Message, + Part, + Role, + Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskQueryParams, + TaskStatusUpdateEvent, + TextPart, +) +from a2a.utils.constants import ( + AGENT_CARD_WELL_KNOWN_PATH, + EXTENDED_AGENT_CARD_PATH, + PREV_AGENT_CARD_WELL_KNOWN_PATH, +) + +from .config import A2AClientSettings, load_settings +from .errors import ( + A2AAgentUnavailableError, + A2AClientResetRequiredError, + A2APeerProtocolError, + A2AUnsupportedBindingError, + A2AUnsupportedOperationError, +) +from .types import A2AClientEvent + + +class _HeaderInterceptor(ClientCallInterceptor): + def __init__(self, default_headers: Mapping[str, str] | None = None) -> None: + self._default_headers = { + key: value for key, value in dict(default_headers or {}).items() if value is not None + } + + async def intercept( + self, + method_name: str, + request_payload: dict[str, Any], + http_kwargs: dict[str, Any], + agent_card: object | None, + context: ClientCallContext | None, + ) -> tuple[dict[str, Any], dict[str, Any]]: + del method_name, agent_card + headers = dict(http_kwargs.get("headers") or {}) + headers.update(self._default_headers) + if context is not None: + dynamic_headers = context.state.get("headers") + if isinstance(dynamic_headers, Mapping): + for key, value in dynamic_headers.items(): + if isinstance(key, str) and value is not None: + headers[key] = str(value) + if headers: + http_kwargs["headers"] = headers + return request_payload, http_kwargs + + +class A2AClient: + """Factory-style facade for lightweight A2A client bootstrap and calls.""" + + def __init__( + self, + agent_url: str, + *, + settings: A2AClientSettings | None = None, + httpx_client: httpx.AsyncClient | None = None, + ) -> None: + if not agent_url or not agent_url.strip(): + raise ValueError("agent_url must be non-empty") + self.agent_url = agent_url.rstrip("/") + self._settings = settings or load_settings({}) + self._owns_httpx_client = httpx_client is None + self._httpx_client = httpx_client + self._client: Client | None = None + self._agent_card: object | None = None + self._lock = asyncio.Lock() + + async def close(self) -> None: + """Close cached client resources and owned HTTP transport.""" + self._client = None + if self._httpx_client is not None and self._owns_httpx_client: + await self._httpx_client.aclose() + + async def get_agent_card(self) -> Any: + """Fetch and cache peer Agent Card.""" + if self._agent_card is not None: + return self._agent_card + + resolver = await self._build_card_resolver() + try: + card = await resolver.get_agent_card(http_kwargs=self._build_resolver_http_kwargs()) + except A2AClientHTTPError as exc: + raise A2AAgentUnavailableError(str(exc)) from exc + except A2AClientJSONError as exc: + raise A2APeerProtocolError( + str(exc), + error_code="invalid_agent_card", + ) from exc + self._agent_card = card + return card + + async def send_message( + self, + text: str, + *, + context_id: str | None = None, + task_id: str | None = None, + message_id: str | None = None, + metadata: Mapping[str, Any] | None = None, + extensions: list[str] | None = None, + ) -> AsyncIterator[A2AClientEvent]: + """Send one user message and stream protocol events.""" + client = await self._ensure_client() + request_metadata, extra_headers = self._split_request_metadata(metadata) + request = self._build_user_message( + text=text, + context_id=context_id, + task_id=task_id, + message_id=message_id, + ) + try: + async for event in client.send_message( + request, + context=self._build_call_context(extra_headers), + request_metadata=request_metadata, + extensions=extensions, + ): + yield event + except A2AClientHTTPError as exc: + raise self._map_http_error("message/send", exc) from exc + except A2AClientJSONRPCError as exc: + raise self._map_jsonrpc_error(exc) from exc + + async def send( + self, + text: str, + *, + context_id: str | None = None, + task_id: str | None = None, + message_id: str | None = None, + metadata: Mapping[str, Any] | None = None, + extensions: list[str] | None = None, + ) -> A2AClientEvent: + """Send a message and return the terminal response/event.""" + last_event: A2AClientEvent = None + async for event in self.send_message( + text, + context_id=context_id, + task_id=task_id, + message_id=message_id, + metadata=metadata, + extensions=extensions, + ): + last_event = event + return last_event + + async def get_task( + self, + task_id: str, + *, + history_length: int | None = None, + metadata: Mapping[str, Any] | None = None, + ) -> Task: + """Fetch one task by id.""" + client = await self._ensure_client() + request_metadata, extra_headers = self._split_request_metadata(metadata) + try: + return await client.get_task( + TaskQueryParams( + id=task_id, + history_length=history_length, + metadata=request_metadata or {}, + ), + context=self._build_call_context(extra_headers), + ) + except A2AClientHTTPError as exc: + raise self._map_http_error("tasks/get", exc) from exc + except A2AClientJSONRPCError as exc: + raise self._map_jsonrpc_error(exc) from exc + + async def cancel_task( + self, + task_id: str, + *, + metadata: Mapping[str, Any] | None = None, + ) -> Task: + """Cancel one task by id.""" + client = await self._ensure_client() + request_metadata, extra_headers = self._split_request_metadata(metadata) + try: + return await client.cancel_task( + TaskIdParams(id=task_id, metadata=request_metadata or {}), + context=self._build_call_context(extra_headers), + ) + except A2AClientHTTPError as exc: + raise self._map_http_error("tasks/cancel", exc) from exc + except A2AClientJSONRPCError as exc: + raise self._map_jsonrpc_error(exc) from exc + + async def resubscribe_task( + self, + task_id: str, + *, + metadata: Mapping[str, Any] | None = None, + ) -> AsyncIterator[tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None]]: + """Resubscribe to task updates.""" + client = await self._ensure_client() + request_metadata, extra_headers = self._split_request_metadata(metadata) + try: + async for event in client.resubscribe( + TaskIdParams(id=task_id, metadata=request_metadata or {}), + context=self._build_call_context(extra_headers), + ): + yield event + except A2AClientHTTPError as exc: + raise self._map_http_error("tasks/resubscribe", exc) from exc + except A2AClientJSONRPCError as exc: + raise self._map_jsonrpc_error(exc) from exc + + async def _ensure_client(self) -> Client: + async with self._lock: + if self._client is not None: + return self._client + return await self._build_client() + + async def _build_client(self) -> Client: + card = await self.get_agent_card() + config = ClientConfig( + streaming=True, + polling=False, + httpx_client=await self._get_httpx_client(), + supported_transports=list(self._settings.supported_transports), + use_client_preference=self._settings.use_client_preference, + ) + try: + factory = ClientFactory(config, consumers=None) + client = factory.create(card, interceptors=self._build_interceptors()) + except ValueError as exc: + raise A2AUnsupportedBindingError( + f"No supported transport found for {self.agent_url}" + ) from exc + self._client = client + return client + + async def _get_httpx_client(self) -> httpx.AsyncClient: + if self._httpx_client is not None: + return self._httpx_client + self._httpx_client = httpx.AsyncClient(timeout=self._settings.default_timeout) + return self._httpx_client + + async def _build_card_resolver(self) -> A2ACardResolver: + parsed_url = urlsplit(self.agent_url) + if not parsed_url.scheme or not parsed_url.netloc: + raise ValueError(f"agent_url must be absolute URL: {self.agent_url}") + + path = parsed_url.path or "" + normalized_no_leading = path.rstrip("/").lstrip("/") + candidate_paths = ( + AGENT_CARD_WELL_KNOWN_PATH, + PREV_AGENT_CARD_WELL_KNOWN_PATH, + EXTENDED_AGENT_CARD_PATH, + ) + + base_path = normalized_no_leading + agent_card_path = AGENT_CARD_WELL_KNOWN_PATH + for candidate_path in candidate_paths: + card_suffix = candidate_path.lstrip("/") + if normalized_no_leading.endswith(card_suffix): + base_path = normalized_no_leading[: -len(card_suffix)].rstrip("/") + agent_card_path = candidate_path + break + + base_url = urlunsplit( + ( + parsed_url.scheme, + parsed_url.netloc, + f"/{base_path}" if base_path else "", + "", + "", + ) + ).rstrip("/") + + return A2ACardResolver( + httpx_client=await self._get_httpx_client(), + base_url=base_url, + agent_card_path=agent_card_path, + ) + + def _build_user_message( + self, + *, + text: str, + context_id: str | None, + task_id: str | None, + message_id: str | None, + ) -> Message: + return Message( + role=Role.user, + message_id=message_id or str(uuid4()), + context_id=context_id, + task_id=task_id, + parts=self._normalize_parts(text), + metadata=None, + ) + + def _split_request_metadata( + self, + metadata: Mapping[str, Any] | None, + ) -> tuple[dict[str, Any] | None, dict[str, str] | None]: + request_metadata: dict[str, Any] = {} + extra_headers: dict[str, str] = {} + for key, value in dict(metadata or {}).items(): + if isinstance(key, str) and key.lower() == "authorization": + if value is not None: + extra_headers["Authorization"] = str(value) + continue + request_metadata[key] = value + return request_metadata or None, extra_headers or None + + def _build_call_context( + self, + extra_headers: Mapping[str, str] | None, + ) -> ClientCallContext | None: + default_headers = self._build_default_headers() + merged_headers = dict(default_headers) + if extra_headers: + merged_headers.update(extra_headers) + if not merged_headers: + return None + return ClientCallContext( + state={ + "headers": dict(merged_headers), + "http_kwargs": {"headers": dict(merged_headers)}, + } + ) + + def _build_default_headers(self) -> dict[str, str]: + if not self._settings.bearer_token: + return {} + return {"Authorization": f"Bearer {self._settings.bearer_token}"} + + def _build_interceptors(self) -> list[ClientCallInterceptor] | None: + default_headers = self._build_default_headers() + return [_HeaderInterceptor(default_headers)] + + def _build_resolver_http_kwargs(self) -> dict[str, Any]: + http_kwargs: dict[str, Any] = {"timeout": self._settings.card_fetch_timeout} + default_headers = self._build_default_headers() + if default_headers: + http_kwargs["headers"] = default_headers + return http_kwargs + + @classmethod + def extract_text(cls, payload: Any) -> str | None: + return cls._extract_text_from_payload(payload) + + @classmethod + def _extract_text_from_payload(cls, payload: Any) -> str | None: + def extract_from_iterable(items: Any) -> str | None: + if not isinstance(items, (list, tuple)): + return None + for item in items: + extracted = cls._extract_text_from_payload(item) + if extracted: + return extracted + return None + + def extract_from_parts(parts: Any) -> str | None: + if not isinstance(parts, (list, tuple)): + return None + collected: list[str] = [] + for part in parts: + text_part = None + if isinstance(part, TextPart): + text_part = part + else: + root = getattr(part, "root", None) + if isinstance(root, TextPart): + text_part = root + elif isinstance(part, Mapping): + text_value = part.get("text") + if isinstance(text_value, str) and text_value.strip(): + collected.append(text_value) + continue + mapped_root = part.get("root") + if isinstance(mapped_root, TextPart): + text_part = mapped_root + elif isinstance(part.get("role"), str): + nested = cls._extract_text_from_payload(part) + if nested: + collected.append(nested) + continue + if text_part and getattr(text_part, "text", None): + collected.append(text_part.text) + if collected: + return "\n".join(collected) + return None + + def extract_from_mapping(payload_map: Mapping[str, Any]) -> str | None: + for key in ( + "content", + "message", + "messages", + "result", + "status", + "text", + "parts", + "artifact", + "artifacts", + "history", + "events", + "root", + ): + if key not in payload_map: + continue + value = payload_map[key] + if value in (None, ""): + continue + if key == "text" and isinstance(value, (str, int, float, bool)): + text_value = str(value).strip() + if text_value: + return text_value + if key == "parts": + parts_text = extract_from_parts(value) + if parts_text: + return parts_text + if key == "artifact": + artifact_text = cls._extract_text_from_payload(value) + if artifact_text: + return artifact_text + if isinstance(value, (list, tuple)) and key in ( + "messages", + "artifacts", + "history", + "events", + ): + iterable_text = extract_from_iterable(value) + if iterable_text: + return iterable_text + nested_text = cls._extract_text_from_payload(value) + if nested_text: + return nested_text + return None + + if isinstance(payload, (list, tuple)): + return extract_from_iterable(payload) + + if isinstance(payload, Message): + return extract_from_parts(payload.parts) + + if isinstance(payload, str): + return payload.strip() or None + + status_payload = getattr(payload, "status", None) + if status_payload is not None: + text = cls._extract_text_from_payload(status_payload) + if text: + return text + + message_payload = getattr(payload, "message", None) + if message_payload is not None: + text = cls._extract_text_from_payload(message_payload) + if text: + return text + + artifact_payload = getattr(payload, "artifact", None) + if artifact_payload is not None: + text = cls._extract_text_from_payload(artifact_payload) + if text: + return text + + result_payload = getattr(payload, "result", None) + if result_payload is not None: + text = cls._extract_text_from_payload(result_payload) + if text: + return text + + history = getattr(payload, "history", None) + if isinstance(history, (list, tuple)) and history: + for item in reversed(history): + text = cls._extract_text_from_payload(item) + if text: + return text + + artifacts = getattr(payload, "artifacts", None) + if isinstance(artifacts, (list, tuple)): + for artifact in artifacts: + artifact_parts = getattr(artifact, "parts", None) + if isinstance(artifact_parts, (list, tuple)): + text = extract_from_parts(artifact_parts) + if text: + return text + + text = extract_from_parts(getattr(payload, "parts", None)) + if text: + return text + + event_text = extract_from_iterable(getattr(payload, "events", None)) + if event_text: + return event_text + + if isinstance(payload, Mapping): + mapped_text = extract_from_mapping(payload) + if mapped_text: + return mapped_text + + mapping_payload = None + if hasattr(payload, "model_dump") and callable(payload.model_dump): + payload_dict = payload.model_dump() + if isinstance(payload_dict, Mapping): + mapping_payload = payload_dict + elif hasattr(payload, "dict") and callable(payload.dict): + payload_dict = payload.dict() + if isinstance(payload_dict, Mapping): + mapping_payload = payload_dict + elif isinstance(getattr(payload, "__dict__", None), Mapping): + mapping_payload = dict(payload.__dict__) + + if mapping_payload is not None: + mapped_text = extract_from_mapping(mapping_payload) + if mapped_text: + return mapped_text + + return None + + @staticmethod + def _extract_jsonrpc_error_payload( + exc: A2AClientJSONRPCError, + ) -> tuple[str, int | None, object]: + error = getattr(exc, "error", None) + if error is None: + return str(exc), None, None + return ( + str(getattr(error, "message", str(exc))), + getattr(error, "code", None), + getattr(error, "data", None), + ) + + def _map_jsonrpc_error( + self, + exc: A2AClientJSONRPCError, + ) -> A2AUnsupportedOperationError | A2APeerProtocolError | A2AClientResetRequiredError: + message, code, data = self._extract_jsonrpc_error_payload(exc) + if code == -32601: + parsed_error = A2AUnsupportedOperationError(message) + parsed_error.error_code = "method_not_supported" + parsed_error.code = code + parsed_error.data = data + return parsed_error + if code == -32602: + return A2APeerProtocolError( + message, + error_code="invalid_params", + rpc_code=code, + data=data, + ) + if code == -32603: + return A2AClientResetRequiredError( + message, + ) + return A2APeerProtocolError( + message, + error_code="peer_protocol_error", + rpc_code=code, + data=data, + ) + + def _map_http_error( + self, + operation: str, + exc: A2AClientHTTPError, + ) -> A2AClientResetRequiredError | A2AUnsupportedOperationError | A2AAgentUnavailableError: + if exc.status_code in {404, 405, 409, 501}: + parsed_error = A2AUnsupportedOperationError(f"{operation} is not supported by peer") + parsed_error.http_status = exc.status_code + return parsed_error + if exc.status_code in {502, 503, 504}: + reset_error = A2AClientResetRequiredError( + f"{operation} failed with upstream instability" + ) + reset_error.http_status = exc.status_code + return reset_error + return A2AAgentUnavailableError(str(exc)) + + # keep parts construction explicitly typed for mypy compatibility in older stubs + def _normalize_parts(self, text: str) -> list[Part]: + return [cast(Part, TextPart(text=text))] + + +__all__ = ["A2AClient"] diff --git a/src/opencode_a2a/client/config.py b/src/opencode_a2a/client/config.py new file mode 100644 index 0000000..ec63c89 --- /dev/null +++ b/src/opencode_a2a/client/config.py @@ -0,0 +1,177 @@ +"""Configuration helpers for the opencode-a2a A2A client initialization layer.""" + +from __future__ import annotations + +from collections.abc import Iterable, Mapping +from dataclasses import dataclass +from typing import Any + + +def _read_setting( + source: Any, + keys: Iterable[str], + *, + default: Any = None, +) -> Any: + if source is None: + return default + if isinstance(source, Mapping): + for key in keys: + if key in source: + return source[key] + return default + for key in keys: + if hasattr(source, key): + return getattr(source, key) + return default + + +def _coerce_float(name: str, value: Any, *, default: float) -> float: + if value is None: + return default + if isinstance(value, (int, float)): + return float(value) + if isinstance(value, str): + normalized = value.strip() + if not normalized: + return default + try: + return float(normalized) + except ValueError as exc: + raise ValueError(f"{name} must be a number, got {value!r}") from exc + raise ValueError(f"{name} must be a number, got {value!r}") + + +def _coerce_bool(name: str, value: Any, *, default: bool) -> bool: + if value is None: + return default + if isinstance(value, bool): + return value + if isinstance(value, str): + lowered = value.strip().lower() + if lowered in {"1", "true", "yes", "on"}: + return True + if lowered in {"0", "false", "no", "off"}: + return False + if lowered in {"t", "f"}: + return lowered == "t" + raise ValueError(f"{name} must be a boolean-like value, got {value!r}") + + +def _coerce_optional_str(name: str, value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + normalized = value.strip() + return normalized or None + raise ValueError(f"{name} must be a string, got {value!r}") + + +def _normalize_transport(value: str) -> str: + normalized = value.strip().lower() + if normalized in {"jsonrpc", "json-rpc", "json_rpc"}: + return "JSONRPC" + if normalized in {"http+json", "http_json", "http-json", "httpjson", "http+json+"}: + return "HTTP+JSON" + if normalized in {"grpc"}: + return "GRPC" + if not normalized: + return "JSONRPC" + return value.strip() + + +def _parse_transports( + raw_value: Any, + *, + default: tuple[str, ...], +) -> tuple[str, ...]: + if raw_value is None: + return default + if isinstance(raw_value, str): + items = [part for part in raw_value.split(",") if part.strip()] + elif isinstance(raw_value, (list, tuple, set)): + items = [str(part) for part in raw_value] + else: + raise ValueError("supported_transports must be a comma-separated string or list") + + normalized = tuple(_normalize_transport(item) for item in items if str(item).strip()) + return normalized or default + + +@dataclass(frozen=True) +class A2AClientSettings: + """Runtime settings used by opencode-a2a client wrappers.""" + + default_timeout: float = 30.0 + use_client_preference: bool = False + card_fetch_timeout: float = 5.0 + bearer_token: str | None = None + supported_transports: tuple[str, ...] = ( + "JSONRPC", + "HTTP+JSON", + ) + + +def load_settings(raw_settings: Any) -> A2AClientSettings: + """Load client settings from an object or mapping.""" + + default_timeout = _coerce_float( + "A2A_CLIENT_TIMEOUT_SECONDS", + _read_setting( + raw_settings, + keys=("A2A_CLIENT_TIMEOUT_SECONDS", "a2a_client_timeout_seconds"), + default=30.0, + ), + default=30.0, + ) + card_fetch_timeout = _coerce_float( + "A2A_CLIENT_CARD_FETCH_TIMEOUT_SECONDS", + _read_setting( + raw_settings, + keys=( + "A2A_CLIENT_CARD_FETCH_TIMEOUT_SECONDS", + "a2a_client_card_fetch_timeout_seconds", + ), + default=5.0, + ), + default=5.0, + ) + use_client_preference = _coerce_bool( + "A2A_CLIENT_USE_CLIENT_PREFERENCE", + _read_setting( + raw_settings, + keys=("A2A_CLIENT_USE_CLIENT_PREFERENCE", "a2a_client_use_client_preference"), + default=False, + ), + default=False, + ) + bearer_token = _coerce_optional_str( + "A2A_CLIENT_BEARER_TOKEN", + _read_setting( + raw_settings, + keys=("A2A_CLIENT_BEARER_TOKEN", "a2a_client_bearer_token"), + default=None, + ), + ) + supported_transports = _parse_transports( + _read_setting( + raw_settings, + keys=( + "A2A_CLIENT_SUPPORTED_TRANSPORTS", + "a2a_client_supported_transports", + ), + default=("JSONRPC", "HTTP+JSON"), + ), + default=("JSONRPC", "HTTP+JSON"), + ) + + return A2AClientSettings( + default_timeout=default_timeout, + use_client_preference=use_client_preference, + card_fetch_timeout=card_fetch_timeout, + bearer_token=bearer_token, + supported_transports=supported_transports, + ) + + +__all__ = ["A2AClientSettings", "load_settings"] diff --git a/src/opencode_a2a/client/errors.py b/src/opencode_a2a/client/errors.py new file mode 100644 index 0000000..47a8b4b --- /dev/null +++ b/src/opencode_a2a/client/errors.py @@ -0,0 +1,65 @@ +"""Error definitions for client initialization and runtime delegation.""" + +from __future__ import annotations + + +class A2AClientError(RuntimeError): + """Base error for opencode-a2a A2A client wrapper.""" + + error_code = "client_error" + code: int | None = None + data: object | None = None + http_status: int | None = None + + +class A2AAgentUnavailableError(A2AClientError): + """Raised when a remote A2A peer cannot be reached.""" + + error_code = "agent_unavailable" + + +class A2AClientResetRequiredError(A2AAgentUnavailableError): + """Raised when the cached transport should be rebuilt.""" + + error_code = "reset_required" + + +class A2AUnsupportedBindingError(A2AClientError): + """Raised when local and remote transport configuration has no overlap.""" + + error_code = "unsupported_binding" + + +class A2AUnsupportedOperationError(A2AClientError): + """Raised when peer does not support an attempted operation.""" + + error_code = "unsupported_operation" + + +class A2APeerProtocolError(A2AClientError): + """Raised when peer response violates JSON-RPC / task contract.""" + + def __init__( + self, + message: str, + *, + error_code: str = "peer_protocol_error", + rpc_code: int | None = None, + http_status: int | None = None, + data: object | None = None, + ) -> None: + super().__init__(message) + self.error_code = error_code + self.code = rpc_code + self.http_status = http_status + self.data = data + + +__all__ = [ + "A2AClientError", + "A2AAgentUnavailableError", + "A2AClientResetRequiredError", + "A2AUnsupportedBindingError", + "A2AUnsupportedOperationError", + "A2APeerProtocolError", +] diff --git a/src/opencode_a2a/client/types.py b/src/opencode_a2a/client/types.py new file mode 100644 index 0000000..3028433 --- /dev/null +++ b/src/opencode_a2a/client/types.py @@ -0,0 +1,21 @@ +"""Public type hints for the lightweight opencode-a2a client facade.""" + +from __future__ import annotations + +from collections.abc import AsyncIterator, Mapping +from typing import Any + +from a2a.types import ( + Message, + Task, + TaskArtifactUpdateEvent, + TaskStatusUpdateEvent, +) + +A2AClientEvent = ( + Message | tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | None +) +A2AClientEventStream = AsyncIterator[A2AClientEvent] +A2AClientMetadata = Mapping[str, Any] + +__all__ = ["A2AClientEvent", "A2AClientEventStream", "A2AClientMetadata"] diff --git a/src/opencode_a2a/config.py b/src/opencode_a2a/config.py index b57d426..4699aad 100644 --- a/src/opencode_a2a/config.py +++ b/src/opencode_a2a/config.py @@ -148,7 +148,27 @@ class Settings(BaseSettings): alias="A2A_CANCEL_ABORT_TIMEOUT_SECONDS", ) - @field_validator("a2a_sandbox_writable_roots", "a2a_network_allowed_domains", mode="before") + # Outbound A2A Client settings + a2a_client_timeout_seconds: float = Field(default=30.0, alias="A2A_CLIENT_TIMEOUT_SECONDS") + a2a_client_card_fetch_timeout_seconds: float = Field( + default=5.0, + alias="A2A_CLIENT_CARD_FETCH_TIMEOUT_SECONDS", + ) + a2a_client_use_client_preference: bool = Field( + default=False, alias="A2A_CLIENT_USE_CLIENT_PREFERENCE" + ) + a2a_client_bearer_token: str | None = Field(default=None, alias="A2A_CLIENT_BEARER_TOKEN") + a2a_client_supported_transports: DeclaredStringList = Field( + default=("JSONRPC", "HTTP+JSON"), + alias="A2A_CLIENT_SUPPORTED_TRANSPORTS", + ) + + @field_validator( + "a2a_sandbox_writable_roots", + "a2a_network_allowed_domains", + "a2a_client_supported_transports", + mode="before", + ) @classmethod def _normalize_declared_lists(cls, value: Any) -> tuple[str, ...]: return _parse_declared_list(value) diff --git a/src/opencode_a2a/execution/executor.py b/src/opencode_a2a/execution/executor.py index b773787..0b91e7e 100644 --- a/src/opencode_a2a/execution/executor.py +++ b/src/opencode_a2a/execution/executor.py @@ -11,7 +11,10 @@ from contextlib import suppress from dataclasses import dataclass from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from ..server.application import A2AClientManager import httpx from a2a.server.agent_execution import AgentExecutor, RequestContext @@ -184,15 +187,61 @@ async def run(self) -> None: try: await self._bind_session() await self._enqueue_working_status() - response = await self._send_message() - if self._pending_preferred_claim: - await self._executor._finalize_preferred_session_binding( - identity=self._prepared.identity, - context_id=self._context_id, - session_id=self._session_id, - ) - self._pending_preferred_claim = False - await self._handle_response(response) + + turn_request_parts = list(self._prepared.request_parts) + user_text = self._prepared.user_text + + while True: + send_kwargs: dict[str, Any] = { + "directory": self._prepared.directory, + "model_override": self._prepared.model_override, + } + if self._prepared.streaming_request: + send_kwargs["timeout_override"] = self._executor._client.stream_timeout + + if not self._prepared.use_structured_parts and not turn_request_parts: + response = await self._executor._client.send_message( + self._session_id, + user_text, + **send_kwargs, + ) + else: + response = await self._executor._client.send_message( + self._session_id, + user_text or None, + parts=turn_request_parts, + **send_kwargs, + ) + + if self._pending_preferred_claim: + await self._executor._finalize_preferred_session_binding( + identity=self._prepared.identity, + context_id=self._context_id, + session_id=self._session_id, + ) + self._pending_preferred_claim = False + + # Check for tool calls that we should handle + tool_results = await self._executor._maybe_handle_tools(response.raw) + if tool_results: + # Clear user_text/parts for the next turn with tool results. + user_text = "" + turn_request_parts = [ + { + "type": "tool", + "tool": res["tool"], + "call_id": res["call_id"], + "output": res.get("output"), + "error": res.get("error"), + } + for res in tool_results + ] + # Loop back to send tool results + continue + + await self._handle_response(response) + break + except httpx.HTTPStatusError as exc: logger.exception("OpenCode request failed with HTTP error") error_type, state, message = _format_upstream_error( @@ -508,10 +557,12 @@ def __init__( cancel_abort_timeout_seconds: float = 2.0, session_cache_ttl_seconds: int = 3600, session_cache_maxsize: int = 10_000, + a2a_client_manager: A2AClientManager | None = None, ) -> None: self._client = client self._streaming_enabled = streaming_enabled self._cancel_abort_timeout_seconds = max(0.0, float(cancel_abort_timeout_seconds)) + self._a2a_client_manager = a2a_client_manager self._sessions = _TTLCache( ttl_seconds=session_cache_ttl_seconds, maxsize=session_cache_maxsize, @@ -604,6 +655,122 @@ async def release_session_for_control(self, *, identity: str, session_id: str) - """Release pending control-session ownership on failure.""" await self._release_preferred_session_claim(identity=identity, session_id=session_id) + async def _maybe_handle_tools( + self, raw_response: dict[str, Any] + ) -> list[dict[str, Any]] | None: + """Heuristically detect and execute A2A tool calls from upstream OpenCode.""" + parts = raw_response.get("parts", []) + if not isinstance(parts, list): + return None + + results: list[dict[str, Any]] = [] + for part in parts: + if not isinstance(part, dict) or part.get("type") != "tool": + continue + + state = part.get("state") + if not isinstance(state, dict) or state.get("status") != "calling": + continue + + tool_name = part.get("tool") + if tool_name == "a2a_call": + result = await self._handle_a2a_call_tool(part) + if result: + results.append(result) + + return results if results else None + + async def _handle_a2a_call_tool(self, part: dict[str, Any]) -> dict[str, Any]: + call_id = part.get("callID") or str(uuid.uuid4()) + tool_name = part.get("tool") or "a2a_call" + state = part.get("state", {}) + inputs = state.get("input", {}) + + if not isinstance(inputs, dict): + return {"call_id": call_id, "tool": tool_name, "error": "Invalid input format"} + + agent_url = inputs.get("url") + message = inputs.get("message") + if not agent_url or not message: + return {"call_id": call_id, "tool": tool_name, "error": "Missing url or message"} + + mgr = self._a2a_client_manager + if mgr is None: + return { + "call_id": call_id, + "tool": tool_name, + "error": "A2A client manager not available", + } + + try: + client = await mgr.get_client(agent_url) + event = None + result_text = "" + async for current_event in client.send_message(message): + event = current_event + extracted = client.extract_text(current_event) + if extracted: + result_text = self._merge_streamed_tool_output(result_text, extracted) + + from a2a.types import Task + + if result_text: + return { + "call_id": call_id, + "tool": tool_name, + "output": result_text, + } + + if isinstance(event, Task): + result_text = "" + # Extract text from Task's assistant message if available + if event.status and event.status.message: + for part_obj in event.status.message.parts: + # Use dict-style access if available or getattr to satisfy type checkers + root = getattr(part_obj, "root", part_obj) + text_val = getattr(root, "text", "") + if text_val: + result_text += str(text_val) + return { + "call_id": call_id, + "tool": tool_name, + "output": result_text or "Task completed.", + } + + # Handle case where event is a tuple (Task, Update) + if isinstance(event, tuple) and len(event) > 0 and isinstance(event[0], Task): + return { + "call_id": call_id, + "tool": tool_name, + "output": "Task completed (streaming).", + } + + return { + "call_id": call_id, + "tool": tool_name, + "error": f"Unexpected agent response type: {type(event).__name__}", + } + except Exception as exc: + logger.exception("A2A tool call failed") + return {"call_id": call_id, "tool": tool_name, "error": str(exc)} + + @staticmethod + def _merge_streamed_tool_output(current: str, incoming: str) -> str: + if not current: + return incoming + if incoming == current or incoming in current: + return current + if incoming.startswith(current): + return incoming + if current.startswith(incoming): + return current + separator = ( + "" + if current.endswith(("\n", " ", "\t")) or incoming.startswith(("\n", " ", "\t")) + else "\n" + ) + return f"{current}{separator}{incoming}" + async def execute(self, context: RequestContext, event_queue: EventQueue) -> None: task_id = context.task_id context_id = context.context_id diff --git a/src/opencode_a2a/parts/mapping.py b/src/opencode_a2a/parts/mapping.py index 6f15d57..6a6f0a0 100644 --- a/src/opencode_a2a/parts/mapping.py +++ b/src/opencode_a2a/parts/mapping.py @@ -19,7 +19,15 @@ class OpencodeFileInputPart(TypedDict, total=False): filename: str -OpencodeInputPart = OpencodeTextInputPart | OpencodeFileInputPart +class OpencodeToolResultPart(TypedDict, total=False): + type: Literal["tool"] + tool: str + call_id: str + output: str + error: str + + +OpencodeInputPart = OpencodeTextInputPart | OpencodeFileInputPart | OpencodeToolResultPart def extract_text_from_a2a_parts(parts: Any) -> str: diff --git a/src/opencode_a2a/server/application.py b/src/opencode_a2a/server/application.py index 1803df2..ce35ef9 100644 --- a/src/opencode_a2a/server/application.py +++ b/src/opencode_a2a/server/application.py @@ -30,6 +30,7 @@ from fastapi.responses import JSONResponse from starlette.responses import StreamingResponse +from ..client import A2AClient from ..config import Settings from ..contracts.extensions import ( COMPATIBILITY_PROFILE_EXTENSION_URI, @@ -337,14 +338,48 @@ async def bearer_auth(request: Request, call_next): return await call_next(request) +class A2AClientManager: + def __init__(self, settings: Settings) -> None: + from ..client.config import load_settings as load_client_settings + + self.client_settings = load_client_settings( + { + "A2A_CLIENT_TIMEOUT_SECONDS": settings.a2a_client_timeout_seconds, + "A2A_CLIENT_CARD_FETCH_TIMEOUT_SECONDS": ( + settings.a2a_client_card_fetch_timeout_seconds + ), + "A2A_CLIENT_USE_CLIENT_PREFERENCE": settings.a2a_client_use_client_preference, + "A2A_CLIENT_BEARER_TOKEN": settings.a2a_client_bearer_token, + "A2A_CLIENT_SUPPORTED_TRANSPORTS": settings.a2a_client_supported_transports, + } + ) + self.clients: dict[str, A2AClient] = {} + self._lock = asyncio.Lock() + + async def get_client(self, agent_url: str) -> A2AClient: + async with self._lock: + url = agent_url.rstrip("/") + if url not in self.clients: + self.clients[url] = A2AClient(url, settings=self.client_settings) + return self.clients[url] + + async def close_all(self) -> None: + async with self._lock: + for client in self.clients.values(): + await client.close() + self.clients.clear() + + def create_app(settings: Settings) -> FastAPI: upstream_client = OpencodeUpstreamClient(settings) + client_manager = A2AClientManager(settings) executor = OpencodeAgentExecutor( upstream_client, streaming_enabled=True, cancel_abort_timeout_seconds=settings.a2a_cancel_abort_timeout_seconds, session_cache_ttl_seconds=settings.a2a_session_cache_ttl_seconds, session_cache_maxsize=settings.a2a_session_cache_maxsize, + a2a_client_manager=client_manager, ) task_store = InMemoryTaskStore() handler = OpencodeRequestHandler( @@ -352,11 +387,6 @@ def create_app(settings: Settings) -> FastAPI: task_store=task_store, ) - @asynccontextmanager - async def lifespan(_app: FastAPI): - yield - await upstream_client.close() - agent_card = build_agent_card(settings) context_builder = IdentityAwareCallContextBuilder() runtime_profile = build_runtime_profile(settings) @@ -388,6 +418,12 @@ async def lifespan(_app: FastAPI): context_builder=context_builder, ) + @asynccontextmanager + async def lifespan(_app: FastAPI): + yield + await client_manager.close_all() + await upstream_client.close() + app = A2AFastAPI( title=settings.a2a_title, version=settings.a2a_version, @@ -397,6 +433,7 @@ async def lifespan(_app: FastAPI): for route, callback in rest_adapter.routes().items(): app.add_api_route(route[0], callback, methods=[route[1]]) app.state.opencode_agent_executor = executor + app.state.a2a_client_manager = client_manager _patch_jsonrpc_openapi_contract(app, settings, runtime_profile=runtime_profile) @app.get("/health") diff --git a/tests/client/__init__.py b/tests/client/__init__.py new file mode 100644 index 0000000..475f2fe --- /dev/null +++ b/tests/client/__init__.py @@ -0,0 +1 @@ +"""Tests for client-side initialization scaffold.""" diff --git a/tests/client/test_client_config.py b/tests/client/test_client_config.py new file mode 100644 index 0000000..067e20c --- /dev/null +++ b/tests/client/test_client_config.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +import pytest + +from opencode_a2a.client.config import A2AClientSettings, load_settings + + +def test_load_settings_default() -> None: + settings = load_settings({}) + + assert settings == A2AClientSettings() + + +def test_load_settings_from_mapping() -> None: + raw = { + "A2A_CLIENT_TIMEOUT_SECONDS": "42", + "A2A_CLIENT_CARD_FETCH_TIMEOUT_SECONDS": 6, + "A2A_CLIENT_USE_CLIENT_PREFERENCE": "true", + "A2A_CLIENT_BEARER_TOKEN": "peer-token", + "A2A_CLIENT_SUPPORTED_TRANSPORTS": "json-rpc,http-json", + } + + settings = load_settings(raw) + + assert settings.default_timeout == 42.0 + assert settings.card_fetch_timeout == 6.0 + assert settings.use_client_preference is True + assert settings.bearer_token == "peer-token" + assert settings.supported_transports == ("JSONRPC", "HTTP+JSON") + + +def test_load_settings_invalid_transport_raises() -> None: + with pytest.raises(ValueError, match="supported_transports"): + load_settings({"A2A_CLIENT_SUPPORTED_TRANSPORTS": 1}) + + +def test_load_settings_invalid_bool_raises() -> None: + with pytest.raises(ValueError, match="boolean-like"): + load_settings({"A2A_CLIENT_USE_CLIENT_PREFERENCE": "maybe"}) + + +def test_load_settings_invalid_bearer_token_type_raises() -> None: + with pytest.raises(ValueError, match="must be a string"): + load_settings({"A2A_CLIENT_BEARER_TOKEN": 123}) diff --git a/tests/client/test_client_facade.py b/tests/client/test_client_facade.py new file mode 100644 index 0000000..1f44ca2 --- /dev/null +++ b/tests/client/test_client_facade.py @@ -0,0 +1,656 @@ +from __future__ import annotations + +from collections.abc import AsyncIterator +from unittest.mock import AsyncMock + +import httpx +import pytest +from a2a.client import ClientConfig +from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError, A2AClientJSONRPCError +from a2a.types import ( + Artifact, + JSONRPCError, + JSONRPCErrorResponse, + Message, + Part, + Role, + Task, + TaskArtifactUpdateEvent, + TaskState, + TaskStatus, + TextPart, +) + +from opencode_a2a.client import A2AClient +from opencode_a2a.client import client as client_module +from opencode_a2a.client.config import A2AClientSettings +from opencode_a2a.client.errors import ( + A2AAgentUnavailableError, + A2AClientResetRequiredError, + A2APeerProtocolError, + A2AUnsupportedOperationError, +) + + +class _FakeCardResolver: + def __init__(self, card: object) -> None: + self._card = card + + self.get_calls = 0 + + async def get_agent_card(self, **_kwargs: object) -> object: + self.get_calls += 1 + return self._card + + +class _FakeClient: + def __init__( + self, + events: list[object] | None = None, + *, + fail: BaseException | None = None, + ): + self._events = list(events or []) + self._fail = fail + self.send_message_inputs: list[tuple[object, object, object]] = [] + self.task_inputs: list[tuple[object, object]] = [] + self.cancel_inputs: list[tuple[object, object]] = [] + self.resubscribe_inputs: list[tuple[object, object]] = [] + + async def send_message(self, message, *args: object, **kwargs: object) -> AsyncIterator[object]: + self.send_message_inputs.append((message, args, kwargs)) + if self._fail: + raise self._fail + for event in self._events: + yield event + + async def get_task(self, params, *args: object, **kwargs: object) -> object: + self.task_inputs.append((params, kwargs)) + if self._fail: + raise self._fail + return {"task_id": params.id} + + async def cancel_task(self, params, *args: object, **kwargs: object) -> object: + self.cancel_inputs.append((params, kwargs)) + if self._fail: + raise self._fail + return {"task_id": params.id, "status": "canceled"} + + async def resubscribe(self, params, *args: object, **kwargs: object) -> AsyncIterator[object]: + self.resubscribe_inputs.append((params, kwargs)) + if self._fail: + raise self._fail + for event in self._events: + yield event + + +@pytest.mark.asyncio +async def test_get_agent_card_cached_and_reused(monkeypatch: pytest.MonkeyPatch) -> None: + resolver = _FakeCardResolver("agent-card") + + async def _build_card_resolver(self: A2AClient) -> _FakeCardResolver: + return resolver + + client = A2AClient("http://agent.example.com") + monkeypatch.setattr(A2AClient, "_build_card_resolver", _build_card_resolver) + first = await client.get_agent_card() + second = await client.get_agent_card() + assert first == second == "agent-card" + assert resolver.get_calls == 1 + + +@pytest.mark.asyncio +async def test_build_card_resolver_strips_explicit_well_known_path( + monkeypatch: pytest.MonkeyPatch, +) -> None: + captured: dict[str, str] = {} + + class _FakeResolver: + def __init__( + self, + *, + base_url: str, + agent_card_path: str, + httpx_client: object, + ) -> None: + captured["base_url"] = base_url + captured["agent_card_path"] = agent_card_path + + async def get_agent_card(self, **kwargs: object) -> str: + return "agent-card" + + monkeypatch.setattr(client_module, "A2ACardResolver", _FakeResolver) + + client = A2AClient("https://ops.example.com/tenant/.well-known/agent-card.json") + await client.get_agent_card() + + assert captured["base_url"] == "https://ops.example.com/tenant" + assert captured["agent_card_path"] == "/.well-known/agent-card.json" + + +@pytest.mark.asyncio +async def test_build_client_uses_settings_and_transport_config( + monkeypatch: pytest.MonkeyPatch, +) -> None: + fake_http_client = AsyncMock(spec=httpx.AsyncClient) + client = A2AClient( + "http://agent.example.com", + settings=A2AClientSettings( + default_timeout=10, + use_client_preference=True, + card_fetch_timeout=3, + bearer_token="peer-token", + supported_transports=("HTTP+JSON",), + ), + httpx_client=fake_http_client, + ) + + fake_sdk_client = _FakeClient() + factory_calls: dict[str, object] = {} + + class _FakeFactory: + def __init__(self, config: ClientConfig, consumers: list[object] | None = None): + factory_calls["config"] = config + factory_calls["consumers"] = consumers + + def create( + self, + _card: object, + consumers: list[object] | None = None, + interceptors: list[object] | None = None, + extensions: list[str] | None = None, + ) -> _FakeClient: + factory_calls["create_consumers"] = consumers + factory_calls["interceptors"] = interceptors + factory_calls["extensions"] = extensions + return fake_sdk_client + + async def _build_card_resolver(self: A2AClient) -> _FakeCardResolver: + return _FakeCardResolver("agent-card") + + monkeypatch.setattr(client_module, "ClientFactory", _FakeFactory) + monkeypatch.setattr(A2AClient, "_build_card_resolver", _build_card_resolver) + actual = await client._build_client() + + config = factory_calls["config"] + assert isinstance(config, ClientConfig) + assert config.streaming is True + assert config.polling is False + assert config.use_client_preference is True + assert config.supported_transports == ["HTTP+JSON"] + assert factory_calls["interceptors"] is not None + assert len(factory_calls["interceptors"]) == 1 + assert actual is fake_sdk_client + + +@pytest.mark.asyncio +async def test_send_returns_last_event(monkeypatch: pytest.MonkeyPatch) -> None: + client = A2AClient("http://agent.example.com") + fake_client = _FakeClient(events=["a", "b", "last"]) + monkeypatch.setattr(A2AClient, "_build_client", AsyncMock(return_value=fake_client)) + monkeypatch.setattr( + A2AClient, + "_build_card_resolver", + AsyncMock(return_value=_FakeCardResolver("card")), + ) + response = await client.send("hello") + assert response == "last" + + +@pytest.mark.asyncio +async def test_send_message_adds_bearer_token_from_settings( + monkeypatch: pytest.MonkeyPatch, +) -> None: + client = A2AClient( + "http://agent.example.com", + settings=A2AClientSettings(bearer_token="peer-token"), + ) + fake_client = _FakeClient(events=["ok"]) + monkeypatch.setattr(A2AClient, "_build_client", AsyncMock(return_value=fake_client)) + monkeypatch.setattr( + A2AClient, + "_build_card_resolver", + AsyncMock(return_value=_FakeCardResolver("card")), + ) + + result = [event async for event in client.send_message("hello")] + + assert result == ["ok"] + _, _, kwargs = fake_client.send_message_inputs[0] + assert kwargs["request_metadata"] is None + assert kwargs["context"] is not None + assert ( + kwargs["context"].state["headers"]["Authorization"] + == "Bearer peer-token" + ) + + +@pytest.mark.asyncio +async def test_send_message_preserves_explicit_authorization_metadata( + monkeypatch: pytest.MonkeyPatch, +) -> None: + client = A2AClient( + "http://agent.example.com", + settings=A2AClientSettings(bearer_token="peer-token"), + ) + fake_client = _FakeClient(events=["ok"]) + monkeypatch.setattr(A2AClient, "_build_client", AsyncMock(return_value=fake_client)) + monkeypatch.setattr( + A2AClient, + "_build_card_resolver", + AsyncMock(return_value=_FakeCardResolver("card")), + ) + + result = [ + event + async for event in client.send_message( + "hello", + metadata={"authorization": "Bearer explicit-token", "trace_id": "trace-1"}, + ) + ] + + assert result == ["ok"] + _, _, kwargs = fake_client.send_message_inputs[0] + assert kwargs["request_metadata"] == {"trace_id": "trace-1"} + assert kwargs["context"].state["headers"]["Authorization"] == "Bearer explicit-token" + + +@pytest.mark.asyncio +async def test_send_message_prefers_explicit_authorization_without_default_token( + monkeypatch: pytest.MonkeyPatch, +) -> None: + client = A2AClient("http://agent.example.com") + fake_client = _FakeClient(events=["ok"]) + monkeypatch.setattr(A2AClient, "_build_client", AsyncMock(return_value=fake_client)) + monkeypatch.setattr( + A2AClient, + "_build_card_resolver", + AsyncMock(return_value=_FakeCardResolver("card")), + ) + + result = [ + event + async for event in client.send_message( + "hello", metadata={"authorization": "Bearer explicit-token"} + ) + ] + + assert result == ["ok"] + _, _, kwargs = fake_client.send_message_inputs[0] + assert kwargs["request_metadata"] is None + assert kwargs["context"].state["headers"]["Authorization"] == "Bearer explicit-token" + + +@pytest.mark.asyncio +async def test_send_message_maps_jsonrpc_not_supported( + monkeypatch: pytest.MonkeyPatch, +) -> None: + rpc_error = JSONRPCErrorResponse( + error=JSONRPCError(code=-32601, message="Unsupported method: message/send"), + id="req-1", + ) + client = A2AClient("http://agent.example.com") + fake_client = _FakeClient(fail=A2AClientJSONRPCError(rpc_error)) + monkeypatch.setattr(A2AClient, "_build_client", AsyncMock(return_value=fake_client)) + monkeypatch.setattr( + A2AClient, + "_build_card_resolver", + AsyncMock(return_value=_FakeCardResolver("card")), + ) + with pytest.raises( + A2AUnsupportedOperationError, + match="Unsupported method", + ): + async for _event in client.send_message("hello"): + raise AssertionError + + +def test_extract_text_prefers_stream_artifact_payload() -> None: + task = Task( + id="remote-task", + context_id="remote-context", + status=TaskStatus(state=TaskState.working), + ) + update = TaskArtifactUpdateEvent( + task_id="remote-task", + context_id="remote-context", + artifact=Artifact( + artifact_id="artifact-1", + name="response", + parts=[Part(root=TextPart(text="streamed remote text"))], + ), + ) + + assert A2AClient.extract_text((task, update)) == "streamed remote text" + + +def test_extract_text_reads_task_status_message() -> None: + task = Task( + id="remote-task", + context_id="remote-context", + status=TaskStatus( + state=TaskState.completed, + message=Message( + role=Role.agent, + message_id="m1", + parts=[Part(root=TextPart(text="status message text"))], + ), + ), + ) + + assert A2AClient.extract_text(task) == "status message text" + + +def test_extract_text_reads_nested_mapping_payload() -> None: + payload = { + "result": { + "history": [ + {"parts": [{"text": "mapped nested text"}]}, + ] + } + } + + assert A2AClient.extract_text(payload) == "mapped nested text" + + +def test_extract_text_reads_model_dump_payload() -> None: + class _Payload: + def model_dump(self) -> dict[str, object]: + return {"artifacts": [{"parts": [{"text": "model dump text"}]}]} + + assert A2AClient.extract_text(_Payload()) == "model dump text" + + +def test_extract_text_reads_direct_string_payload() -> None: + assert A2AClient.extract_text(" string payload ") == "string payload" + + +def test_extract_text_reads_message_and_artifact_attributes() -> None: + class _ArtifactHolder: + artifact = {"parts": [{"text": "artifact attribute text"}]} + + class _MessageHolder: + message = {"parts": [{"text": "message attribute text"}]} + + assert A2AClient.extract_text(_ArtifactHolder()) == "artifact attribute text" + assert A2AClient.extract_text(_MessageHolder()) == "message attribute text" + + +def test_extract_text_reads_result_history_and_artifacts_attributes() -> None: + class _ResultHolder: + result = {"parts": [{"text": "result attribute text"}]} + + class _HistoryHolder: + history = [{"parts": [{"text": "history attribute text"}]}] + + class _Artifact: + parts = [{"text": "artifacts attribute text"}] + + class _ArtifactsHolder: + artifacts = [_Artifact()] + + assert A2AClient.extract_text(_ResultHolder()) == "result attribute text" + assert A2AClient.extract_text(_HistoryHolder()) == "history attribute text" + assert A2AClient.extract_text(_ArtifactsHolder()) == "artifacts attribute text" + + +@pytest.mark.asyncio +async def test_get_agent_card_maps_json_error(monkeypatch: pytest.MonkeyPatch) -> None: + class _BrokenResolver: + async def get_agent_card(self, **_kwargs: object) -> object: + raise A2AClientJSONError("invalid json") + + async def _build_card_resolver(self: A2AClient) -> _BrokenResolver: + return _BrokenResolver() + + client = A2AClient("http://agent.example.com") + monkeypatch.setattr(A2AClient, "_build_card_resolver", _build_card_resolver) + + with pytest.raises(A2APeerProtocolError, match="invalid json"): + await client.get_agent_card() + + +@pytest.mark.asyncio +async def test_cancel_task_adds_bearer_token_from_settings( + monkeypatch: pytest.MonkeyPatch, +) -> None: + client = A2AClient( + "http://agent.example.com", + settings=A2AClientSettings(bearer_token="peer-token"), + ) + fake_client = _FakeClient() + monkeypatch.setattr(A2AClient, "_build_client", AsyncMock(return_value=fake_client)) + monkeypatch.setattr( + A2AClient, + "_build_card_resolver", + AsyncMock(return_value=_FakeCardResolver("card")), + ) + + await client.cancel_task("task-id") + + params, _ = fake_client.cancel_inputs[0] + assert params.metadata == {} + + +@pytest.mark.asyncio +async def test_get_task_uses_authorization_header_context( + monkeypatch: pytest.MonkeyPatch, +) -> None: + client = A2AClient("http://agent.example.com") + fake_client = _FakeClient() + monkeypatch.setattr(A2AClient, "_build_client", AsyncMock(return_value=fake_client)) + monkeypatch.setattr( + A2AClient, + "_build_card_resolver", + AsyncMock(return_value=_FakeCardResolver("card")), + ) + + await client.get_task( + "task-id", + metadata={"authorization": "Bearer explicit-token", "trace_id": "trace-1"}, + ) + + params, kwargs = fake_client.task_inputs[0] + assert params.metadata == {"trace_id": "trace-1"} + assert kwargs["context"].state["headers"]["Authorization"] == "Bearer explicit-token" + + +@pytest.mark.asyncio +async def test_cancel_task_uses_authorization_header_context( + monkeypatch: pytest.MonkeyPatch, +) -> None: + client = A2AClient("http://agent.example.com") + fake_client = _FakeClient() + monkeypatch.setattr(A2AClient, "_build_client", AsyncMock(return_value=fake_client)) + monkeypatch.setattr( + A2AClient, + "_build_card_resolver", + AsyncMock(return_value=_FakeCardResolver("card")), + ) + + await client.cancel_task( + "task-id", + metadata={"authorization": "Bearer explicit-token", "trace_id": "trace-1"}, + ) + + params, kwargs = fake_client.cancel_inputs[0] + assert params.metadata == {"trace_id": "trace-1"} + assert kwargs["context"].state["headers"]["Authorization"] == "Bearer explicit-token" + + +def test_map_jsonrpc_error_variants() -> None: + client = A2AClient("http://agent.example.com") + + invalid_params_error = A2AClientJSONRPCError( + JSONRPCErrorResponse( + error=JSONRPCError(code=-32602, message="bad params"), + id="req-1", + ) + ) + internal_error = A2AClientJSONRPCError( + JSONRPCErrorResponse( + error=JSONRPCError(code=-32603, message="internal"), + id="req-2", + ) + ) + generic_error = A2AClientJSONRPCError( + JSONRPCErrorResponse( + error=JSONRPCError(code=-32000, message="generic"), + id="req-3", + ) + ) + + mapped_invalid = client._map_jsonrpc_error(invalid_params_error) + mapped_internal = client._map_jsonrpc_error(internal_error) + mapped_generic = client._map_jsonrpc_error(generic_error) + + assert isinstance(mapped_invalid, A2APeerProtocolError) + assert mapped_invalid.error_code == "invalid_params" + assert isinstance(mapped_internal, A2AClientResetRequiredError) + assert isinstance(mapped_generic, A2APeerProtocolError) + assert mapped_generic.error_code == "peer_protocol_error" + + +def test_map_http_error_variants() -> None: + client = A2AClient("http://agent.example.com") + + unsupported = client._map_http_error("message/send", A2AClientHTTPError(405, "nope")) + reset = client._map_http_error("message/send", A2AClientHTTPError(503, "busy")) + unavailable = client._map_http_error("message/send", A2AClientHTTPError(500, "boom")) + + assert isinstance(unsupported, A2AUnsupportedOperationError) + assert unsupported.http_status == 405 + assert isinstance(reset, A2AClientResetRequiredError) + assert reset.http_status == 503 + assert isinstance(unavailable, A2AAgentUnavailableError) + + +@pytest.mark.asyncio +async def test_build_card_resolver_requires_absolute_url() -> None: + client = A2AClient("/relative/path") + + with pytest.raises(ValueError, match="absolute URL"): + await client._build_card_resolver() + + +def test_split_request_metadata_and_resolver_headers() -> None: + client = A2AClient( + "http://agent.example.com", + settings=A2AClientSettings(bearer_token="peer-token", card_fetch_timeout=7), + ) + + request_metadata, extra_headers = client._split_request_metadata( + {"authorization": "Bearer explicit-token", "trace_id": "trace-1"} + ) + + assert request_metadata == {"trace_id": "trace-1"} + assert extra_headers == {"Authorization": "Bearer explicit-token"} + assert client._build_default_headers() == {"Authorization": "Bearer peer-token"} + assert client._build_resolver_http_kwargs() == { + "timeout": 7, + "headers": {"Authorization": "Bearer peer-token"}, + } + + +@pytest.mark.asyncio +async def test_header_interceptor_merges_static_and_dynamic_headers() -> None: + interceptor = client_module._HeaderInterceptor({"Authorization": "Bearer peer-token"}) + context = client_module.ClientCallContext(state={"headers": {"X-Trace-Id": "trace-1"}}) + + request_payload, http_kwargs = await interceptor.intercept( + "message/send", + {"jsonrpc": "2.0"}, + {"headers": {"Accept": "application/json"}}, + agent_card=None, + context=context, + ) + + assert request_payload == {"jsonrpc": "2.0"} + assert http_kwargs["headers"] == { + "Accept": "application/json", + "Authorization": "Bearer peer-token", + "X-Trace-Id": "trace-1", + } + + +@pytest.mark.asyncio +async def test_get_task_maps_transport_http_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + client = A2AClient("http://agent.example.com") + fake_client = _FakeClient(fail=A2AClientHTTPError(404, "gone")) + monkeypatch.setattr(A2AClient, "_build_client", AsyncMock(return_value=fake_client)) + monkeypatch.setattr( + A2AClient, + "_build_card_resolver", + AsyncMock(return_value=_FakeCardResolver("card")), + ) + + with pytest.raises(A2AUnsupportedOperationError, match="not supported"): + await client.get_task("task-id") + + +@pytest.mark.asyncio +async def test_resubscribe_forward_events(monkeypatch: pytest.MonkeyPatch) -> None: + client = A2AClient("http://agent.example.com") + fake_client = _FakeClient(events=[1, 2]) + monkeypatch.setattr(A2AClient, "_build_client", AsyncMock(return_value=fake_client)) + monkeypatch.setattr( + A2AClient, + "_build_card_resolver", + AsyncMock(return_value=_FakeCardResolver("card")), + ) + result = [event async for event in client.resubscribe_task("task-id")] + assert result == [1, 2] + + +@pytest.mark.asyncio +async def test_resubscribe_uses_authorization_header_context( + monkeypatch: pytest.MonkeyPatch, +) -> None: + client = A2AClient("http://agent.example.com") + fake_client = _FakeClient(events=[1]) + monkeypatch.setattr(A2AClient, "_build_client", AsyncMock(return_value=fake_client)) + monkeypatch.setattr( + A2AClient, + "_build_card_resolver", + AsyncMock(return_value=_FakeCardResolver("card")), + ) + + result = [ + event + async for event in client.resubscribe_task( + "task-id", + metadata={"authorization": "Bearer explicit-token", "trace_id": "trace-1"}, + ) + ] + + assert result == [1] + params, kwargs = fake_client.resubscribe_inputs[0] + assert params.metadata == {"trace_id": "trace-1"} + assert kwargs["context"].state["headers"]["Authorization"] == "Bearer explicit-token" + + +@pytest.mark.asyncio +async def test_close_releases_owned_http_client() -> None: + owned_http_client = AsyncMock(spec=httpx.AsyncClient) + client = A2AClient("http://agent.example.com") + client._httpx_client = owned_http_client + client._owns_httpx_client = True + client._client = object() + await client.close() + + owned_http_client.aclose.assert_awaited_once() + assert client._client is None + + +@pytest.mark.asyncio +async def test_close_preserves_borrowed_http_client() -> None: + borrowed_http_client = AsyncMock(spec=httpx.AsyncClient) + client = A2AClient("http://agent.example.com", httpx_client=borrowed_http_client) + client._client = object() + + await client.close() + + borrowed_http_client.aclose.assert_not_awaited() + assert client._client is None diff --git a/tests/execution/test_opencode_agent_session_binding.py b/tests/execution/test_opencode_agent_session_binding.py index abe320f..2fd828b 100644 --- a/tests/execution/test_opencode_agent_session_binding.py +++ b/tests/execution/test_opencode_agent_session_binding.py @@ -1,4 +1,5 @@ import asyncio +from typing import Any import pytest from a2a.types import Task @@ -131,13 +132,15 @@ class MissingMessageIdClient(DummyChatOpencodeUpstreamClient): async def send_message( self, session_id: str, - text: str, + text: str | None = None, *, + parts: list[dict[str, Any]] | None = None, directory: str | None = None, model_override: dict[str, str] | None = None, - timeout_override=None, # noqa: ANN001 + timeout_override: float | None = None, + **kwargs: Any, ) -> OpencodeMessage: - del text, directory, model_override, timeout_override + del text, parts, directory, model_override, timeout_override, kwargs self.sent_session_ids.append(session_id) return OpencodeMessage( text="echo:hello", @@ -166,13 +169,15 @@ class UsageClient(DummyChatOpencodeUpstreamClient): async def send_message( self, session_id: str, - text: str, + text: str | None = None, *, + parts: list[dict[str, Any]] | None = None, directory: str | None = None, model_override: dict[str, str] | None = None, - timeout_override=None, # noqa: ANN001 + timeout_override: float | None = None, + **kwargs: Any, ) -> OpencodeMessage: - del text, directory, model_override, timeout_override + del text, parts, directory, model_override, timeout_override, kwargs self.sent_session_ids.append(session_id) return OpencodeMessage( text="echo:hello", @@ -206,3 +211,199 @@ async def send_message( assert usage["output_tokens"] == 3 assert usage["total_tokens"] == 10 assert "raw" not in usage + + +@pytest.mark.asyncio +async def test_agent_handles_a2a_call_tool(monkeypatch) -> None: + from a2a.types import ( + Artifact, + Part, + Task, + TaskArtifactUpdateEvent, + TaskState, + TaskStatus, + TextPart, + ) + + from opencode_a2a.client import A2AClient + + class MockA2AClient: + extract_text = staticmethod(A2AClient.extract_text) + + async def send_message(self, text: str): + task = Task( + id="remote-task", + context_id="remote-ctx", + status=TaskStatus(state=TaskState.working), + ) + yield ( + task, + TaskArtifactUpdateEvent( + task_id="remote-task", + context_id="remote-ctx", + artifact=Artifact( + artifact_id="artifact-1", + name="response", + parts=[Part(root=TextPart(text=f"remote response to {text}"))], + ), + ), + ) + + async def close(self): + pass + + class MockManager: + async def get_client(self, url: str): + return MockA2AClient() + + client = DummyChatOpencodeUpstreamClient() + manager = MockManager() + executor = OpencodeAgentExecutor(client, streaming_enabled=False, a2a_client_manager=manager) + + raw_response = { + "parts": [ + { + "type": "tool", + "tool": "a2a_call", + "callID": "call-1", + "state": { + "status": "calling", + "input": {"url": "http://remote", "message": "hello remote"}, + }, + } + ] + } + + results = await executor._maybe_handle_tools(raw_response) + assert results is not None + assert len(results) == 1 + assert results[0]["call_id"] == "call-1" + assert "remote response to hello remote" in results[0]["output"] + + +@pytest.mark.asyncio +async def test_execution_coordinator_handles_tool_loop() -> None: + class ToolLoopClient(DummyChatOpencodeUpstreamClient): + def __init__(self): + super().__init__() + self.call_count = 0 + + async def send_message(self, *args, **kwargs) -> OpencodeMessage: + self.call_count += 1 + if self.call_count == 1: + return OpencodeMessage( + text="call tool", + session_id="s1", + message_id="m1", + raw={ + "parts": [ + { + "type": "tool", + "tool": "a2a_call", + "callID": "c1", + "state": { + "status": "calling", + "input": {"url": "http://x", "message": "y"}, + }, + } + ] + }, + ) + return OpencodeMessage(text="done", session_id="s1", message_id="m2", raw={}) + + class MockManager: + async def get_client(self, url: str): + mock_client = MagicMock() + + async def _send_message(_text: str): + task = Task(id="t", context_id="c", status=TaskStatus(state=TaskState.working)) + yield ( + task, + TaskArtifactUpdateEvent( + task_id="t", + context_id="c", + artifact=Artifact( + artifact_id="artifact-1", + name="response", + parts=[Part(root=TextPart(text="streamed tool output"))], + ), + ), + ) + + mock_client.send_message = _send_message + mock_client.extract_text = A2AClient.extract_text + return mock_client + + from unittest.mock import MagicMock + + from a2a.types import ( + Artifact, + Part, + Task, + TaskArtifactUpdateEvent, + TaskState, + TaskStatus, + TextPart, + ) + + from opencode_a2a.client import A2AClient + + client = ToolLoopClient() + manager = MockManager() + executor = OpencodeAgentExecutor(client, streaming_enabled=False, a2a_client_manager=manager) + q = DummyEventQueue() + + await executor.execute(make_request_context(task_id="t1", context_id="c1", text="start"), q) + + assert client.call_count == 2 + task = next(event for event in q.events if isinstance(event, Task)) + assert task.status.message.parts[0].root.text == "done" + + +@pytest.mark.asyncio +async def test_agent_merges_streamed_a2a_tool_output() -> None: + merged = OpencodeAgentExecutor._merge_streamed_tool_output("hello", "hello world") + distinct = OpencodeAgentExecutor._merge_streamed_tool_output("hello world", "from peer") + duplicate = OpencodeAgentExecutor._merge_streamed_tool_output("hello world", "world") + + assert merged == "hello world" + assert distinct == "hello world\nfrom peer" + assert duplicate == "hello world" + + +@pytest.mark.asyncio +async def test_agent_handles_a2a_call_tool_errors() -> None: + from unittest.mock import MagicMock + + client = DummyChatOpencodeUpstreamClient() + # No manager + executor = OpencodeAgentExecutor(client, streaming_enabled=False, a2a_client_manager=None) + + raw_response = { + "parts": [ + { + "type": "tool", + "tool": "a2a_call", + "callID": "c1", + "state": {"status": "calling", "input": {"url": "h", "message": "m"}}, + } + ] + } + results = await executor._maybe_handle_tools(raw_response) + assert results is not None + assert "not available" in results[0]["error"] + + # Invalid input + executor = OpencodeAgentExecutor( + client, streaming_enabled=False, a2a_client_manager=MagicMock() + ) + raw_response["parts"][0]["state"]["input"] = "invalid" + results = await executor._maybe_handle_tools(raw_response) + assert results is not None + assert "Invalid input" in results[0]["error"] + + # Missing message + raw_response["parts"][0]["state"]["input"] = {"url": "http://x"} + results = await executor._maybe_handle_tools(raw_response) + assert results is not None + assert "Missing url or message" in results[0]["error"] diff --git a/tests/server/test_cli.py b/tests/server/test_cli.py index 9e74460..614bd81 100644 --- a/tests/server/test_cli.py +++ b/tests/server/test_cli.py @@ -56,3 +56,29 @@ def test_cli_serve_subcommand_invokes_runtime() -> None: assert cli.main(["serve"]) == 0 serve_mock.assert_called_once_with() + + +def test_cli_call_uses_outbound_bearer_env_default() -> None: + with mock.patch.dict( + "os.environ", + {"A2A_CLIENT_BEARER_TOKEN": "peer-token"}, + clear=False, + ): + parser = cli.build_parser() + + namespace = parser.parse_args(["call", "http://agent.example.com", "hello"]) + + assert namespace.token == "peer-token" + + +def test_cli_call_does_not_fall_back_to_inbound_bearer_env() -> None: + with mock.patch.dict( + "os.environ", + {"A2A_BEARER_TOKEN": "inbound-token"}, + clear=True, + ): + parser = cli.build_parser() + + namespace = parser.parse_args(["call", "http://agent.example.com", "hello"]) + + assert namespace.token is None diff --git a/tests/server/test_transport_contract.py b/tests/server/test_transport_contract.py index a73d5db..5a74ce8 100644 --- a/tests/server/test_transport_contract.py +++ b/tests/server/test_transport_contract.py @@ -484,11 +484,13 @@ def __init__( cancel_abort_timeout_seconds: float, session_cache_ttl_seconds: int, session_cache_maxsize: int, + a2a_client_manager: object = None, ) -> None: captured["streaming_enabled"] = streaming_enabled captured["cancel_abort_timeout_seconds"] = cancel_abort_timeout_seconds captured["session_cache_ttl_seconds"] = session_cache_ttl_seconds captured["session_cache_maxsize"] = session_cache_maxsize + captured["a2a_client_manager"] = a2a_client_manager async def execute(self, _context, _event_queue) -> None: # noqa: ANN001 raise NotImplementedError @@ -526,6 +528,30 @@ async def release_session_for_control(self, *, identity: str, session_id: str) - assert captured["session_cache_maxsize"] == 22 +def test_create_app_propagates_outbound_client_settings(monkeypatch) -> None: + import opencode_a2a.server.application as app_module + + monkeypatch.setattr(app_module, "OpencodeUpstreamClient", DummyChatOpencodeUpstreamClient) + app = app_module.create_app( + make_settings( + a2a_bearer_token="test-token", + a2a_client_timeout_seconds=41.0, + a2a_client_card_fetch_timeout_seconds=7.0, + a2a_client_use_client_preference=True, + a2a_client_bearer_token="peer-token", + a2a_client_supported_transports=("http-json", "json-rpc"), + ) + ) + + client_manager = app.state.a2a_client_manager + settings = client_manager.client_settings + assert settings.default_timeout == 41.0 + assert settings.card_fetch_timeout == 7.0 + assert settings.use_client_preference is True + assert settings.bearer_token == "peer-token" + assert settings.supported_transports == ("HTTP+JSON", "JSONRPC") + + def test_create_app_requires_control_guard_hooks(monkeypatch) -> None: import opencode_a2a.server.application as app_module @@ -538,12 +564,14 @@ def __init__( cancel_abort_timeout_seconds: float, session_cache_ttl_seconds: int, session_cache_maxsize: int, + a2a_client_manager: object = None, ) -> None: del ( streaming_enabled, cancel_abort_timeout_seconds, session_cache_ttl_seconds, session_cache_maxsize, + a2a_client_manager, ) self.claim_session_for_control = None