diff --git a/docs/deployment.md b/docs/deployment.md index dac984f..f5bd718 100644 --- a/docs/deployment.md +++ b/docs/deployment.md @@ -396,8 +396,14 @@ Application-level safeguards: `GET /v1/tasks/{task_id}:subscribe` - service subscribes to OpenCode `/event` stream and forwards filtered per-session updates -- stream emits incremental `TaskArtifactUpdateEvent` with channel metadata - (`reasoning` / `tool_call` / `final_answer`) +- stream emits incremental `TaskArtifactUpdateEvent` on a single artifact + with `opencode.block_type` metadata + (`text` / `reasoning` / `tool_call`) and monotonic `opencode.sequence` +- routing is schema-first via OpenCode `part.type` + `part_id` state, not + inline marker parsing +- `message.part.delta` may arrive before `message.part.updated`; the service + buffers those deltas and replays them when the part state is available +- structured `tool` parts are emitted as `tool_call` block updates - events without `message_id` are discarded to avoid ambiguous correlation - final snapshot is emitted only when stream chunks did not already produce - the same final answer; stream then closes with `TaskStatusUpdateEvent(final=true)` + the same final text; stream then closes with `TaskStatusUpdateEvent(final=true)` diff --git a/docs/guide.md b/docs/guide.md index 391af27..0384598 100644 --- a/docs/guide.md +++ b/docs/guide.md @@ -57,10 +57,18 @@ This guide covers configuration, authentication, API behavior, streaming re-subs - Streaming (`/v1/message:stream`) emits incremental `TaskArtifactUpdateEvent` and then `TaskStatusUpdateEvent(final=true)`. Stream artifacts carry - `artifact.metadata.opencode.channel` with values - `reasoning` / `tool_call` / `final_answer`. Events without + `artifact.metadata.opencode.block_type` with values + `text` / `reasoning` / `tool_call`. All chunks share one stream + artifact ID and preserve original timeline via + `artifact.metadata.opencode.sequence`. Events without `message_id` are dropped. A final snapshot is only emitted when stream - chunks did not already produce the same final answer text. + chunks did not already produce the same final text. + Stream routing is schema-first: the service classifies chunks primarily by + OpenCode `part.type` (plus `part_id` state) rather than inline text markers. + `message.part.delta` and `message.part.updated` are merged per `part_id`; + out-of-order deltas are buffered and replayed when the corresponding + `part.updated` arrives. Structured `tool` parts are emitted as `tool_call` + blocks with normalized state payload. Non-streaming requests return a `Task` directly. - Requests require `Authorization: Bearer `; otherwise `401` is returned. Agent Card endpoints are public. diff --git a/src/opencode_a2a_serve/agent.py b/src/opencode_a2a_serve/agent.py index 8fd50f9..f1b2522 100644 --- a/src/opencode_a2a_serve/agent.py +++ b/src/opencode_a2a_serve/agent.py @@ -1,13 +1,16 @@ from __future__ import annotations import asyncio +import json import logging import os import time import uuid +from collections import defaultdict from collections.abc import Mapping from contextlib import suppress from dataclasses import dataclass, field +from enum import Enum from pathlib import Path from typing import Any @@ -29,77 +32,99 @@ logger = logging.getLogger(__name__) -_STREAM_CHANNEL_REASONING = "reasoning" -_STREAM_CHANNEL_TOOL_CALL = "tool_call" -_STREAM_CHANNEL_FINAL_ANSWER = "final_answer" + +class BlockType(str, Enum): + TEXT = "text" + REASONING = "reasoning" + TOOL_CALL = "tool_call" @dataclass(frozen=True) class _NormalizedStreamChunk: text: str append: bool - channel: str + block_type: BlockType source: str event_type: str message_id: str | None role: str | None +@dataclass(frozen=True) +class _PendingDelta: + field: str + delta: str + message_id: str | None + + +@dataclass +class _StreamPartState: + block_type: BlockType + message_id: str | None + role: str | None + buffer: str = "" + saw_delta: bool = False + + @dataclass class _StreamOutputState: user_text: str - response_message_id: str | None = None - channel_buffers: dict[str, str] = field(default_factory=dict) - saw_final_answer_chunk: bool = False + observed_message_ids: set[str] = field(default_factory=set) + content_buffers: dict[BlockType, str] = field(default_factory=dict) saw_any_chunk: bool = False - - def set_response_message_id(self, message_id: str | None) -> None: - if not isinstance(message_id, str): - self.response_message_id = None - return - value = message_id.strip() - self.response_message_id = value or None + emitted_stream_chunk: bool = False + sequence: int = 0 def matches_expected_message(self, message_id: str | None) -> bool: if not message_id: return False - if not self.response_message_id: - return True - return message_id == self.response_message_id + self.observed_message_ids.add(message_id) + return True - def should_drop_initial_user_echo(self, text: str, *, channel: str, role: str | None) -> bool: + def should_drop_initial_user_echo( + self, + text: str, + *, + block_type: BlockType, + role: str | None, + ) -> bool: if role is not None: return False - if channel != _STREAM_CHANNEL_FINAL_ANSWER: + if block_type != BlockType.TEXT: return False if self.saw_any_chunk: return False user_text = self.user_text.strip() return bool(user_text) and text.strip() == user_text - def register_chunk(self, *, channel: str, text: str, append: bool) -> tuple[bool, bool]: - previous = self.channel_buffers.get(channel, "") - effective_append = append if previous else False - next_value = f"{previous}{text}" if effective_append else text + def register_chunk( + self, *, block_type: BlockType, text: str, append: bool + ) -> tuple[bool, bool]: + previous = self.content_buffers.get(block_type, "") + next_value = f"{previous}{text}" if append else text if next_value == previous: - return False, effective_append - self.channel_buffers[channel] = next_value + return False, False + self.content_buffers[block_type] = next_value self.saw_any_chunk = True - if channel == _STREAM_CHANNEL_FINAL_ANSWER and next_value.strip(): - self.saw_final_answer_chunk = True + # Single-artifact stream must stay append-only after the first emitted chunk. + effective_append = self.emitted_stream_chunk + self.emitted_stream_chunk = True return True, effective_append def should_emit_final_snapshot(self, text: str) -> bool: if not text.strip(): return False - existing = self.channel_buffers.get(_STREAM_CHANNEL_FINAL_ANSWER, "") + existing = self.content_buffers.get(BlockType.TEXT, "") if existing.strip() == text.strip(): return False - self.channel_buffers[_STREAM_CHANNEL_FINAL_ANSWER] = text + self.content_buffers[BlockType.TEXT] = text self.saw_any_chunk = True - self.saw_final_answer_chunk = True return True + def next_sequence(self) -> int: + self.sequence += 1 + return self.sequence + class _TTLCache: """Bounded TTL cache for hashable key -> string value. @@ -190,6 +215,9 @@ def __init__( self._lock = asyncio.Lock() self._inflight_session_creates: dict[tuple[str, str], asyncio.Task[str]] = {} self._session_locks: dict[str, asyncio.Lock] = {} + self._running_requests: dict[tuple[str, str], asyncio.Task[Any]] = {} + self._running_stop_events: dict[tuple[str, str], asyncio.Event] = {} + self._running_identities: dict[tuple[str, str], str] = {} def _resolve_and_validate_directory(self, requested: str | None) -> str | None: """Normalizes and validates the directory parameter against workspace boundaries. @@ -310,6 +338,13 @@ async def execute(self, context: RequestContext, event_queue: EventQueue) -> Non pending_preferred_claim = False session_lock: asyncio.Lock | None = None session_id = "" + execution_key = (task_id, context_id) + current_task = asyncio.current_task() + if current_task is not None: + async with self._lock: + self._running_requests[execution_key] = current_task + self._running_stop_events[execution_key] = stop_event + self._running_identities[execution_key] = identity try: session_id, pending_preferred_claim = await self._get_or_create_session( @@ -370,23 +405,21 @@ async def execute(self, context: RequestContext, event_queue: EventQueue) -> Non response_text, ) if streaming_request: - stream_state.set_response_message_id(response.message_id) if stream_state.should_emit_final_snapshot(response_text): await _enqueue_artifact_update( event_queue=event_queue, task_id=task_id, context_id=context_id, - artifact_id=_artifact_id_for_stream_channel( - stream_artifact_id, _STREAM_CHANNEL_FINAL_ANSWER - ), + artifact_id=stream_artifact_id, text=response_text, - append=False, + append=stream_state.emitted_stream_chunk, last_chunk=True, artifact_metadata=_build_stream_artifact_metadata( - channel=_STREAM_CHANNEL_FINAL_ANSWER, + block_type=BlockType.TEXT, source="final_snapshot", event_type="message.finalized", message_id=response.message_id, + sequence=stream_state.next_sequence(), ), ) await event_queue.enqueue_event( @@ -458,6 +491,10 @@ async def execute(self, context: RequestContext, event_queue: EventQueue) -> Non await stream_task if session_lock and session_lock.locked(): session_lock.release() + async with self._lock: + self._running_requests.pop(execution_key, None) + self._running_stop_events.pop(execution_key, None) + self._running_identities.pop(execution_key, None) async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None: task_id = context.task_id @@ -484,9 +521,21 @@ async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None ) await event_queue.enqueue_event(event) + execution_key = (task_id, context_id) async with self._lock: - self._sessions.pop((identity, context_id)) - inflight = self._inflight_session_creates.pop((identity, context_id), None) + running_identity = self._running_identities.get(execution_key, identity) + running_task = self._running_requests.get(execution_key) + stop_event = self._running_stop_events.get(execution_key) + self._sessions.pop((running_identity, context_id)) + inflight = self._inflight_session_creates.pop((running_identity, context_id), None) + if stop_event: + stop_event.set() + if ( + running_task + and running_task is not asyncio.current_task() + and not running_task.done() + ): + running_task.cancel() if inflight: inflight.cancel() with suppress(asyncio.CancelledError): @@ -502,8 +551,6 @@ async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None message=f"Cancel failed: {exc}", streaming_request=False, ) - finally: - await event_queue.close() async def _get_or_create_session( self, @@ -681,9 +728,197 @@ async def _consume_opencode_stream( stop_event: asyncio.Event, directory: str | None = None, ) -> None: - buffered_text: dict[str, str] = {} + part_states: dict[str, _StreamPartState] = {} + pending_deltas: defaultdict[str, list[_PendingDelta]] = defaultdict(list) backoff = 0.5 max_backoff = 5.0 + + async def _emit_chunks(chunks: list[_NormalizedStreamChunk]) -> None: + for chunk in chunks: + if not stream_state.matches_expected_message(chunk.message_id): + continue + if stream_state.should_drop_initial_user_echo( + chunk.text, + block_type=chunk.block_type, + role=chunk.role, + ): + continue + should_emit, effective_append = stream_state.register_chunk( + block_type=chunk.block_type, + text=chunk.text, + append=chunk.append, + ) + if not should_emit: + continue + await _enqueue_artifact_update( + event_queue=event_queue, + task_id=task_id, + context_id=context_id, + artifact_id=artifact_id, + text=chunk.text, + append=effective_append, + last_chunk=False, + artifact_metadata=_build_stream_artifact_metadata( + block_type=chunk.block_type, + source=chunk.source, + event_type=chunk.event_type, + message_id=chunk.message_id, + role=chunk.role, + sequence=stream_state.next_sequence(), + ), + ) + logger.debug( + "Stream chunk task_id=%s session_id=%s block_type=%s append=%s text=%s", + task_id, + session_id, + chunk.block_type, + effective_append, + chunk.text, + ) + + def _new_chunk( + *, + text: str, + append: bool, + block_type: BlockType, + source: str, + event_type: str, + message_id: str | None, + role: str | None, + ) -> _NormalizedStreamChunk: + return _NormalizedStreamChunk( + text=text, + append=append, + block_type=block_type, + source=source, + event_type=event_type, + message_id=message_id, + role=role, + ) + + def _upsert_part_state( + *, + part_id: str, + part: Mapping[str, Any], + props: Mapping[str, Any], + role: str | None, + message_id: str | None, + ) -> _StreamPartState | None: + block_type = _resolve_stream_block_type(part, props) + if block_type is None: + return None + state = part_states.get(part_id) + if state is None: + state = _StreamPartState( + block_type=block_type, + message_id=message_id, + role=role, + ) + part_states[part_id] = state + return state + state.block_type = block_type + if role is not None: + state.role = role + if message_id: + state.message_id = message_id + return state + + def _delta_chunks( + *, + state: _StreamPartState, + delta_text: str, + message_id: str | None, + event_type: str, + source: str, + ) -> list[_NormalizedStreamChunk]: + if not delta_text: + return [] + if message_id: + state.message_id = message_id + state.buffer = f"{state.buffer}{delta_text}" + state.saw_delta = True + return [ + _new_chunk( + text=delta_text, + append=True, + block_type=state.block_type, + source=source, + event_type=event_type, + message_id=state.message_id, + role=state.role, + ) + ] + + def _snapshot_chunks( + *, + state: _StreamPartState, + snapshot: str, + message_id: str | None, + event_type: str, + part_id: str, + ) -> list[_NormalizedStreamChunk]: + if message_id: + state.message_id = message_id + previous = state.buffer + if snapshot == previous: + return [] + if snapshot.startswith(previous): + delta_text = snapshot[len(previous) :] + state.buffer = snapshot + if not delta_text: + return [] + return [ + _new_chunk( + text=delta_text, + append=True, + block_type=state.block_type, + source="part_text_diff", + event_type=event_type, + message_id=state.message_id, + role=state.role, + ) + ] + state.buffer = snapshot + logger.warning( + "Suppressing non-prefix snapshot rewrite " + "task_id=%s session_id=%s part_id=%s block_type=%s had_delta=%s", + task_id, + session_id, + part_id, + state.block_type.value, + state.saw_delta, + ) + return [] + + def _tool_chunks( + *, + state: _StreamPartState, + part: Mapping[str, Any], + message_id: str | None, + event_type: str, + ) -> list[_NormalizedStreamChunk]: + tool_chunk = _serialize_tool_part(part) + if not tool_chunk: + return [] + if message_id: + state.message_id = message_id + previous = state.buffer + if tool_chunk == previous: + return [] + state.buffer = tool_chunk + text = tool_chunk if not previous else f"\n{tool_chunk}" + return [ + _new_chunk( + text=text, + append=bool(previous), + block_type=state.block_type, + source="tool_part_update", + event_type=event_type, + message_id=state.message_id, + role=state.role, + ) + ] + try: while not stop_event.is_set(): try: @@ -693,7 +928,7 @@ async def _consume_opencode_stream( if stop_event.is_set(): break event_type = event.get("type") - if event_type != "message.part.updated": + if event_type not in {"message.part.updated", "message.part.delta"}: continue props = event.get("properties") if not isinstance(props, Mapping): @@ -703,80 +938,105 @@ async def _consume_opencode_stream( part = {} if _extract_stream_session_id(part, props) != session_id: continue - role = _extract_stream_role(part, props) - if role in {"user", "system"}: - continue - channel = _classify_stream_channel(part, props) - delta = props.get("delta") - chunk_text: str | None = None - append = True - source = "delta" message_id = _extract_stream_message_id(part, props) - buffer_key = f"{channel}:{message_id or 'unknown'}" - previous = buffered_text.get(buffer_key, "") - if isinstance(delta, str) and delta: - chunk_text = delta - buffered_text[buffer_key] = f"{previous}{delta}" - elif part.get("type") == "text" and isinstance(part.get("text"), str): - next_text = part["text"] - if next_text != previous: - if next_text.startswith(previous): - chunk_text = next_text[len(previous) :] - append = True - source = "part_text_diff" - else: - chunk_text = next_text - append = False - source = "part_text_reset" - buffered_text[buffer_key] = next_text - if not chunk_text: + part_id = _extract_stream_part_id(part, props) + if not part_id and event_type == "message.part.updated": + part_id = _build_fallback_part_id(part, props, message_id=message_id) + if not part_id: continue - chunk = _NormalizedStreamChunk( - text=chunk_text, - append=append, - channel=channel, - source=source, - event_type=event_type, - message_id=message_id, + + if event_type == "message.part.delta": + field = props.get("field") + delta = props.get("delta") + if field != "text" or not isinstance(delta, str) or not delta: + continue + state = part_states.get(part_id) + if state is None: + pending_deltas[part_id].append( + _PendingDelta( + field=field, + delta=delta, + message_id=message_id, + ) + ) + continue + if state.role in {"user", "system"}: + continue + chunks = _delta_chunks( + state=state, + delta_text=delta, + message_id=message_id, + event_type=event_type, + source="delta_event", + ) + if chunks: + await _emit_chunks(chunks) + continue + + role = _extract_stream_role(part, props) + state = _upsert_part_state( + part_id=part_id, + part=part, + props=props, role=role, + message_id=message_id, ) - if not stream_state.matches_expected_message(chunk.message_id): - continue - if stream_state.should_drop_initial_user_echo( - chunk.text, channel=chunk.channel, role=chunk.role - ): + if state is None: + pending_deltas.pop(part_id, None) continue - should_emit, effective_append = stream_state.register_chunk( - channel=chunk.channel, - text=chunk.text, - append=chunk.append, - ) - if not should_emit: + if state.role in {"user", "system"}: + pending_deltas.pop(part_id, None) continue - await _enqueue_artifact_update( - event_queue=event_queue, - task_id=task_id, - context_id=context_id, - artifact_id=_artifact_id_for_stream_channel(artifact_id, chunk.channel), - text=chunk.text, - append=effective_append, - last_chunk=False, - artifact_metadata=_build_stream_artifact_metadata( - channel=chunk.channel, - source=chunk.source, - event_type=chunk.event_type, - message_id=chunk.message_id, - role=chunk.role, - ), - ) - logger.debug( - "Stream chunk task_id=%s session_id=%s channel=%s append=%s text=%s", - task_id, - session_id, - chunk.channel, - effective_append, - chunk.text, - ) + + chunks: list[_NormalizedStreamChunk] = [] + pending = pending_deltas.pop(part_id, []) + for buffered in pending: + if buffered.field != "text": + continue + chunks.extend( + _delta_chunks( + state=state, + delta_text=buffered.delta, + message_id=buffered.message_id, + event_type="message.part.delta", + source="delta_event_buffered", + ) + ) + + delta = props.get("delta") + if isinstance(delta, str) and delta: + chunks.extend( + _delta_chunks( + state=state, + delta_text=delta, + message_id=message_id, + event_type=event_type, + source="delta", + ) + ) + elif state.block_type == BlockType.TOOL_CALL: + chunks.extend( + _tool_chunks( + state=state, + part=part, + message_id=message_id, + event_type=event_type, + ) + ) + elif isinstance(part.get("text"), str): + chunks.extend( + _snapshot_chunks( + state=state, + snapshot=part["text"], + message_id=message_id, + event_type=event_type, + part_id=part_id, + ) + ) + + if chunks: + await _emit_chunks(chunks) + break except Exception: if stop_event.is_set(): @@ -833,22 +1093,17 @@ async def _enqueue_artifact_update( ) -def _artifact_id_for_stream_channel(base_artifact_id: str, channel: str) -> str: - if channel == _STREAM_CHANNEL_FINAL_ANSWER: - return base_artifact_id - return f"{base_artifact_id}:{channel}" - - def _build_stream_artifact_metadata( *, - channel: str, + block_type: BlockType, source: str, event_type: str, message_id: str | None = None, role: str | None = None, + sequence: int | None = None, ) -> dict[str, Any]: opencode_meta: dict[str, Any] = { - "channel": channel, + "block_type": block_type, "source": source, "event_type": event_type, } @@ -856,6 +1111,8 @@ def _build_stream_artifact_metadata( opencode_meta["message_id"] = message_id if role: opencode_meta["role"] = role + if sequence is not None: + opencode_meta["sequence"] = sequence return {"opencode": opencode_meta} @@ -891,6 +1148,10 @@ def _extract_stream_session_id(part: Mapping[str, Any], props: Mapping[str, Any] value = part.get(key) if isinstance(value, str) and value: return value + for key in session_keys: + value = props.get(key) + if isinstance(value, str) and value: + return value message = props.get("message") if isinstance(message, Mapping): for key in session_keys: @@ -933,31 +1194,101 @@ def _extract_stream_message_id(part: Mapping[str, Any], props: Mapping[str, Any] return None -def _classify_stream_channel(part: Mapping[str, Any], props: Mapping[str, Any]) -> str: - def _iter_candidates() -> list[str]: - candidates: list[str] = [] - for value in ( - part.get("channel"), - props.get("channel"), - part.get("kind"), - props.get("kind"), - part.get("type"), - props.get("type"), - props.get("deltaType"), - props.get("contentType"), - props.get("phase"), - props.get("name"), - ): - if isinstance(value, str) and value.strip(): - candidates.append(value.strip().lower()) - return candidates - - candidates = _iter_candidates() +def _extract_stream_part_id(part: Mapping[str, Any], props: Mapping[str, Any]) -> str | None: + part_keys = ("partID", "partId", "part_id", "id") + for key in part_keys: + value = part.get(key) + if isinstance(value, str): + normalized = value.strip() + if normalized: + return normalized + for key in part_keys: + value = props.get(key) + if isinstance(value, str): + normalized = value.strip() + if normalized: + return normalized + return None + + +def _build_fallback_part_id( + part: Mapping[str, Any], + props: Mapping[str, Any], + *, + message_id: str | None, +) -> str | None: + if not message_id: + return None + part_type = _extract_stream_part_type(part, props) or "unknown" + return f"fallback:{message_id}:{part_type}" + + +def _extract_stream_part_type(part: Mapping[str, Any], props: Mapping[str, Any]) -> str | None: + for value in ( + part.get("type"), + part.get("kind"), + props.get("partType"), + props.get("part_type"), + ): + if isinstance(value, str): + normalized = value.strip().lower() + if normalized: + return normalized + return None + + +def _map_part_type_to_block_type(part_type: str | None) -> BlockType | None: + if not part_type: + return None + if part_type == "text": + return BlockType.TEXT + if part_type in {"reasoning", "thinking", "thought"}: + return BlockType.REASONING + if part_type in { + "tool", + "tool_call", + "toolcall", + "function_call", + "functioncall", + "action", + }: + return BlockType.TOOL_CALL + return None + + +def _resolve_stream_block_type( + part: Mapping[str, Any], props: Mapping[str, Any] +) -> BlockType | None: + explicit_part_type = _extract_stream_part_type(part, props) + if explicit_part_type is not None: + return _map_part_type_to_block_type(explicit_part_type) + return _classify_stream_block_type(part, props) + + +def _classify_stream_block_type( + part: Mapping[str, Any], props: Mapping[str, Any] +) -> BlockType | None: + candidates: list[str] = [] + for value in ( + part.get("block_type"), + props.get("block_type"), + part.get("channel"), + props.get("channel"), + part.get("kind"), + props.get("kind"), + props.get("type"), + props.get("deltaType"), + props.get("phase"), + props.get("name"), + ): + if isinstance(value, str) and value.strip(): + candidates.append(value.strip().lower()) + if any( any(keyword in candidate for keyword in ("reason", "thinking", "thought")) for candidate in candidates ): - return _STREAM_CHANNEL_REASONING + return BlockType.REASONING if any( any( keyword in candidate @@ -972,8 +1303,45 @@ def _iter_candidates() -> list[str]: ) for candidate in candidates ): - return _STREAM_CHANNEL_TOOL_CALL - return _STREAM_CHANNEL_FINAL_ANSWER + return BlockType.TOOL_CALL + if any( + any(keyword in candidate for keyword in ("text", "answer", "final")) + for candidate in candidates + ): + return BlockType.TEXT + return None + + +def _serialize_tool_part(part: Mapping[str, Any]) -> str | None: + payload: dict[str, Any] = {} + for source_key in ("callID", "callId", "call_id"): + value = part.get(source_key) + if isinstance(value, str): + normalized = value.strip() + if normalized: + payload["call_id"] = normalized + break + for source_key in ("tool", "name"): + value = part.get(source_key) + if isinstance(value, str): + normalized = value.strip() + if normalized: + payload["tool"] = normalized + break + state = part.get("state") + if isinstance(state, Mapping): + status = state.get("status") + if isinstance(status, str): + normalized = status.strip() + if normalized: + payload["status"] = normalized + for key in ("title", "subtitle", "input", "output", "error"): + value = state.get(key) + if value is not None: + payload[key] = value + if not payload: + return None + return json.dumps(payload, ensure_ascii=False, sort_keys=True, separators=(",", ":")) def _build_history(context: RequestContext) -> list[Message]: diff --git a/tests/test_agent_errors.py b/tests/test_agent_errors.py index 5a67b75..28c3fee 100644 --- a/tests/test_agent_errors.py +++ b/tests/test_agent_errors.py @@ -49,9 +49,9 @@ async def test_cancel_missing_ids(): # This should no longer raise RuntimeError await executor.cancel(context, event_queue) - # Verify that an event was enqueued and queue was closed + # Verify that an event was enqueued and queue is not force-closed by executor.cancel event_queue.enqueue_event.assert_called() - event_queue.close.assert_called() + event_queue.close.assert_not_called() @pytest.mark.asyncio diff --git a/tests/test_cancellation.py b/tests/test_cancellation.py new file mode 100644 index 0000000..bbf4738 --- /dev/null +++ b/tests/test_cancellation.py @@ -0,0 +1,101 @@ +import asyncio +from unittest.mock import AsyncMock, MagicMock, PropertyMock + +import pytest +from a2a.server.agent_execution import RequestContext +from a2a.server.context import ServerCallContext +from a2a.server.events.event_queue import EventQueue +from a2a.types import TaskState, TaskStatusUpdateEvent + +from opencode_a2a_serve.agent import OpencodeAgentExecutor +from opencode_a2a_serve.config import Settings +from opencode_a2a_serve.opencode_client import OpencodeClient + + +@pytest.mark.asyncio +async def test_cancel_interrupts_running_execute_and_keeps_queue_open(): + client = AsyncMock(spec=OpencodeClient) + send_started = asyncio.Event() + send_cancelled = asyncio.Event() + + async def send_message( + session_id, + _text, + *, + directory=None, # noqa: ARG001 + timeout_override=None, # noqa: ARG001 + ): + send_started.set() + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + send_cancelled.set() + raise + response = MagicMock() + response.text = "OpenCode response" + response.session_id = session_id + response.message_id = "msg-1" + return response + + client.create_session.return_value = "session-1" + client.send_message.side_effect = send_message + type(client).directory = PropertyMock(return_value="/tmp/workspace") + type(client).settings = PropertyMock( + return_value=Settings( + A2A_BEARER_TOKEN="test", + OPENCODE_BASE_URL="http://localhost", + A2A_ALLOW_DIRECTORY_OVERRIDE=True, + ) + ) + + executor = OpencodeAgentExecutor(client, streaming_enabled=False) + + execute_context = MagicMock(spec=RequestContext) + execute_context.task_id = "task-1" + execute_context.context_id = "context-A" + execute_context.call_context = MagicMock(spec=ServerCallContext) + execute_context.call_context.state = {"identity": "user-1"} + execute_context.get_user_input.return_value = "hello" + execute_context.current_task = None + execute_context.message = None + execute_context.metadata = None + execute_queue = AsyncMock(spec=EventQueue) + + execute_task = asyncio.create_task(executor.execute(execute_context, execute_queue)) + await asyncio.wait_for(send_started.wait(), timeout=1.0) + + cancel_context = MagicMock(spec=RequestContext) + cancel_context.task_id = "task-1" + cancel_context.context_id = "context-A" + cancel_context.call_context = None + cancel_queue = AsyncMock(spec=EventQueue) + + await asyncio.wait_for(executor.cancel(cancel_context, cancel_queue), timeout=1.0) + + cancel_events = [call.args[0] for call in cancel_queue.enqueue_event.call_args_list] + assert any( + isinstance(event, TaskStatusUpdateEvent) and event.status.state == TaskState.canceled + for event in cancel_events + ) + cancel_queue.close.assert_not_called() + + with pytest.raises(asyncio.CancelledError): + await asyncio.wait_for(execute_task, timeout=1.0) + + assert send_cancelled.is_set() + assert executor._sessions.get(("user-1", "context-A")) is None + assert ("task-1", "context-A") not in executor._running_requests + assert ("task-1", "context-A") not in executor._running_stop_events + assert ("task-1", "context-A") not in executor._running_identities + + +@pytest.mark.asyncio +async def test_cancel_does_not_block_with_real_event_queue() -> None: + executor = OpencodeAgentExecutor(MagicMock(), streaming_enabled=False) + context = MagicMock(spec=RequestContext) + context.task_id = None + context.context_id = None + context.call_context = None + queue = EventQueue() + + await asyncio.wait_for(executor.cancel(context, queue), timeout=0.5) diff --git a/tests/test_streaming_output_contract.py b/tests/test_streaming_output_contract.py index 55a9822..3ebdf56 100644 --- a/tests/test_streaming_output_contract.py +++ b/tests/test_streaming_output_contract.py @@ -99,6 +99,9 @@ def _event( part_type: str, delta: str, message_id: str | None = "msg-1", + part_id: str | None = None, + text: str | None = None, + part_overrides: dict | None = None, ) -> dict: properties: dict = { "part": { @@ -111,12 +114,39 @@ def _event( properties["part"]["role"] = role if message_id is not None: properties["part"]["messageID"] = message_id + if part_id is not None: + properties["part"]["id"] = part_id + if text is not None: + properties["part"]["text"] = text + if part_overrides: + properties["part"].update(part_overrides) return { "type": "message.part.updated", "properties": properties, } +def _delta_event( + *, + session_id: str, + part_id: str, + delta: str, + message_id: str | None = "msg-1", +) -> dict: + properties: dict = { + "sessionID": session_id, + "partID": part_id, + "field": "text", + "delta": delta, + } + if message_id is not None: + properties["messageID"] = message_id + return { + "type": "message.part.delta", + "properties": properties, + } + + def _artifact_updates(queue: DummyEventQueue) -> list[TaskArtifactUpdateEvent]: return [event for event in queue.events if isinstance(event, TaskArtifactUpdateEvent)] @@ -127,7 +157,7 @@ def _part_text(event: TaskArtifactUpdateEvent) -> str: @pytest.mark.asyncio -async def test_streaming_filters_user_echo_and_emits_structured_channels() -> None: +async def test_streaming_filters_user_echo_and_emits_single_artifact_block_types() -> None: user_text = "who are you" client = DummyStreamingClient( stream_events_payload=[ @@ -153,8 +183,12 @@ async def test_streaming_filters_user_echo_and_emits_structured_channels() -> No assert updates texts = [_part_text(event) for event in updates] assert user_text not in texts - channels = [event.artifact.metadata["opencode"]["channel"] for event in updates] - assert _unique(channels) == ["reasoning", "tool_call", "final_answer"] + block_types = [event.artifact.metadata["opencode"]["block_type"] for event in updates] + assert _unique(block_types) == ["reasoning", "tool_call", "text"] + artifact_ids = [event.artifact.artifact_id for event in updates] + assert len(set(artifact_ids)) == 1 + sequences = [event.artifact.metadata["opencode"]["sequence"] for event in updates] + assert sequences == list(range(1, len(updates) + 1)) @pytest.mark.asyncio @@ -179,7 +213,7 @@ async def test_streaming_does_not_send_duplicate_final_snapshot_when_chunks_exis final_updates = [ event for event in _artifact_updates(queue) - if event.artifact.metadata["opencode"]["channel"] == "final_answer" + if event.artifact.metadata["opencode"]["block_type"] == "text" ] assert len(final_updates) == 1 assert _part_text(final_updates[0]) == "stable final answer" @@ -203,13 +237,13 @@ async def test_streaming_emits_final_snapshot_only_when_stream_has_no_final_answ final_updates = [ event for event in _artifact_updates(queue) - if event.artifact.metadata["opencode"]["channel"] == "final_answer" + if event.artifact.metadata["opencode"]["block_type"] == "text" ] assert len(final_updates) == 1 final_event = final_updates[0] assert _part_text(final_event) == "final answer from send_message" assert final_event.artifact.metadata["opencode"]["source"] == "final_snapshot" - assert final_event.append is False + assert final_event.append is True assert final_event.last_chunk is True @@ -262,7 +296,7 @@ async def test_streaming_drops_events_without_message_id_and_falls_back_to_snaps update = updates[0] assert _part_text(update) == "final answer from send_message" assert update.artifact.metadata["opencode"]["source"] == "final_snapshot" - assert update.artifact.metadata["opencode"]["channel"] == "final_answer" + assert update.artifact.metadata["opencode"]["block_type"] == "text" def _unique(items: list[str]) -> list[str]: @@ -274,3 +308,338 @@ def _unique(items: list[str]) -> list[str]: seen.add(item) ordered.append(item) return ordered + + +@pytest.mark.asyncio +async def test_streaming_treats_embedded_markers_as_plain_text_without_typed_parts() -> None: + client = DummyStreamingClient( + stream_events_payload=[ + _event(session_id="ses-1", role="assistant", part_type="text", delta="start "), + _event(session_id="ses-1", role="assistant", part_type="text", delta=" str: + parts = [] + for ev in updates: + if ev.artifact.metadata["opencode"]["block_type"] == block_type: + if not ev.append: + parts = [_part_text(ev)] + else: + parts.append(_part_text(ev)) + return "".join(parts) + + assert _final_state("text") == 'start thinking middle [tool_call: {"foo":1}] end' + assert _final_state("reasoning") == "" + assert _final_state("tool_call") == "" + + +@pytest.mark.asyncio +async def test_streaming_emits_structured_tool_part_updates() -> None: + client = DummyStreamingClient( + stream_events_payload=[ + _event( + session_id="ses-1", + role="assistant", + part_type="tool", + delta="", + part_id="prt-tool-1", + part_overrides={ + "callID": "call-1", + "tool": "bash", + "state": {"status": "pending"}, + }, + ), + _event( + session_id="ses-1", + role="assistant", + part_type="tool", + delta="", + part_id="prt-tool-1", + part_overrides={ + "callID": "call-1", + "tool": "bash", + "state": {"status": "running"}, + }, + ), + _event( + session_id="ses-1", + role="assistant", + part_type="tool", + delta="", + part_id="prt-tool-1", + part_overrides={ + "callID": "call-1", + "tool": "bash", + "state": {"status": "completed"}, + }, + ), + ], + response_text="done", + ) + executor = OpencodeAgentExecutor(client, streaming_enabled=True) + executor._should_stream = lambda context: True # type: ignore[method-assign] + queue = DummyEventQueue() + + await executor.execute( + _context(task_id="task-tool-bracket", context_id="ctx-tool-bracket", text="go"), + queue, + ) + + updates = _artifact_updates(queue) + tool_updates = [ + ev for ev in updates if ev.artifact.metadata["opencode"]["block_type"] == "tool_call" + ] + assert len(tool_updates) == 3 + merged = "".join(_part_text(ev) for ev in tool_updates) + assert '"status":"pending"' in merged + assert '"status":"running"' in merged + assert '"status":"completed"' in merged + + +@pytest.mark.asyncio +async def test_streaming_flushes_partial_marker_on_eof_as_current_block_type() -> None: + client = DummyStreamingClient( + stream_events_payload=[ + _event(session_id="ses-1", role="assistant", part_type="text", delta="hello None: + client = DummyStreamingClient( + stream_events_payload=[ + { + "type": "message.part.updated", + "properties": { + "part": { + "sessionID": "ses-1", + "type": "text", + "role": "assistant", + "messageID": "msg-1", + "text": "hello", + }, + "delta": "", + }, + }, + { + "type": "message.part.updated", + "properties": { + "part": { + "sessionID": "ses-1", + "type": "text", + "role": "assistant", + "messageID": "msg-1", + "text": "HELLO", + }, + "delta": "", + }, + }, + ], + response_text="HELLO", + ) + executor = OpencodeAgentExecutor(client, streaming_enabled=True) + executor._should_stream = lambda context: True # type: ignore[method-assign] + queue = DummyEventQueue() + + await executor.execute( + _context(task_id="task-no-reset", context_id="ctx-no-reset", text="go"), + queue, + ) + + updates = _artifact_updates(queue) + assert len(updates) >= 2 + assert updates[0].append is False + assert all(ev.append is True for ev in updates[1:]) + + +@pytest.mark.asyncio +async def test_streaming_suppresses_reasoning_snapshot_reset_after_delta() -> None: + client = DummyStreamingClient( + stream_events_payload=[ + _event( + session_id="ses-1", + role="assistant", + part_type="reasoning", + delta="", + part_id="prt-r1", + text="", + ), + _event( + session_id="ses-1", + role="assistant", + part_type="reasoning", + delta="reasoning line\n\n", + part_id="prt-r1", + ), + _event( + session_id="ses-1", + role="assistant", + part_type="reasoning", + delta="", + part_id="prt-r1", + text="reasoning line", + ), + ], + response_text="answer", + ) + executor = OpencodeAgentExecutor(client, streaming_enabled=True) + executor._should_stream = lambda context: True # type: ignore[method-assign] + queue = DummyEventQueue() + + await executor.execute( + _context(task_id="task-reason-reset", context_id="ctx-reason-reset", text="go"), + queue, + ) + + reasoning_updates = [ + event + for event in _artifact_updates(queue) + if event.artifact.metadata["opencode"]["block_type"] == "reasoning" + ] + assert len(reasoning_updates) == 1 + assert _part_text(reasoning_updates[0]) == "reasoning line\n\n" + + +@pytest.mark.asyncio +async def test_streaming_supports_message_part_delta_events() -> None: + client = DummyStreamingClient( + stream_events_payload=[ + _event( + session_id="ses-1", + role="assistant", + part_type="reasoning", + delta="", + part_id="prt-r2", + text="", + ), + _delta_event(session_id="ses-1", part_id="prt-r2", delta="first "), + _delta_event(session_id="ses-1", part_id="prt-r2", delta="second"), + ], + response_text="answer", + ) + executor = OpencodeAgentExecutor(client, streaming_enabled=True) + executor._should_stream = lambda context: True # type: ignore[method-assign] + queue = DummyEventQueue() + + await executor.execute( + _context(task_id="task-delta", context_id="ctx-delta", text="go"), + queue, + ) + + reasoning_updates = [ + event + for event in _artifact_updates(queue) + if event.artifact.metadata["opencode"]["block_type"] == "reasoning" + ] + assert reasoning_updates + merged = "".join(_part_text(ev) for ev in reasoning_updates) + assert merged == "first second" + + +@pytest.mark.asyncio +async def test_streaming_buffers_delta_until_part_updated_arrives() -> None: + client = DummyStreamingClient( + stream_events_payload=[ + _delta_event(session_id="ses-1", part_id="prt-late", delta="first "), + _delta_event(session_id="ses-1", part_id="prt-late", delta="second"), + _event( + session_id="ses-1", + role="assistant", + part_type="reasoning", + delta="", + part_id="prt-late", + text="first second", + ), + ], + response_text="answer", + ) + executor = OpencodeAgentExecutor(client, streaming_enabled=True) + executor._should_stream = lambda context: True # type: ignore[method-assign] + queue = DummyEventQueue() + + await executor.execute( + _context(task_id="task-buffered-delta", context_id="ctx-buffered-delta", text="go"), + queue, + ) + + reasoning_updates = [ + event + for event in _artifact_updates(queue) + if event.artifact.metadata["opencode"]["block_type"] == "reasoning" + ] + assert reasoning_updates + merged = "".join(_part_text(ev) for ev in reasoning_updates) + assert merged == "first second" + + +@pytest.mark.asyncio +async def test_streaming_keeps_multiple_message_ids_in_same_request_window() -> None: + client = DummyStreamingClient( + stream_events_payload=[ + _event( + session_id="ses-1", + role="assistant", + part_type="reasoning", + part_id="prt-m1", + message_id="msg-a", + delta="step one ", + ), + _event( + session_id="ses-1", + role="assistant", + part_type="text", + part_id="prt-m2", + message_id="msg-b", + delta="final answer", + ), + ], + response_text="final answer", + response_message_id="msg-b", + ) + executor = OpencodeAgentExecutor(client, streaming_enabled=True) + executor._should_stream = lambda context: True # type: ignore[method-assign] + queue = DummyEventQueue() + + await executor.execute( + _context(task_id="task-multi-mid", context_id="ctx-multi-mid", text="go"), + queue, + ) + + updates = _artifact_updates(queue) + message_ids = [ev.artifact.metadata["opencode"].get("message_id") for ev in updates] + assert "msg-a" in message_ids + assert "msg-b" in message_ids