Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 51 additions & 8 deletions src/claude_agent_sdk/_internal/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ def __init__(
self._closed = False
self._initialization_result: dict[str, Any] | None = None

# Owner task pattern: events for coordinating lifecycle
self._owner_stop_event: anyio.Event | None = None
self._owner_started_event: anyio.Event | None = None
self._outer_tg: anyio.abc.TaskGroup | None = None

async def initialize(self) -> dict[str, Any] | None:
"""Initialize control protocol if in streaming mode.

Expand Down Expand Up @@ -152,11 +157,43 @@ async def initialize(self) -> dict[str, Any] | None:
return response

async def start(self) -> None:
"""Start reading messages from transport."""
"""Start reading messages from transport.

Uses the owner task pattern to ensure the inner task group is properly
managed by a single task, which is required for trio compatibility.
"""
if self._tg is None:
self._tg = anyio.create_task_group()
await self._tg.__aenter__()
self._tg.start_soon(self._read_messages)
self._owner_stop_event = anyio.Event()
self._owner_started_event = anyio.Event()

# Outer task group spawns the owner task
self._outer_tg = anyio.create_task_group()
await self._outer_tg.__aenter__()
self._outer_tg.start_soon(self._task_group_owner)

# Wait for owner to signal it's ready
await self._owner_started_event.wait()

async def _task_group_owner(self) -> None:
"""Owner task that manages the inner task group.

This task owns the task group for its entire lifetime, ensuring that
the same task that enters the cancel scope also exits it. This is
required for trio compatibility.
"""
try:
async with anyio.create_task_group() as tg:
self._tg = tg
tg.start_soon(self._read_messages)
self._owner_started_event.set() # type: ignore[union-attr]

# Wait until close() signals us to stop
await self._owner_stop_event.wait() # type: ignore[union-attr]

# Cancel child tasks
tg.cancel_scope.cancel()
finally:
self._tg = None

async def _read_messages(self) -> None:
"""Read messages from transport and route them."""
Expand Down Expand Up @@ -550,11 +587,17 @@ async def receive_messages(self) -> AsyncIterator[dict[str, Any]]:
async def close(self) -> None:
"""Close the query and transport."""
self._closed = True
if self._tg:
self._tg.cancel_scope.cancel()
# Wait for task group to complete cancellation

# Signal owner task to stop
if self._owner_stop_event:
self._owner_stop_event.set()

# Wait for outer task group to finish (owner will exit after stop event)
if self._outer_tg:
with suppress(anyio.get_cancelled_exc_class()):
await self._tg.__aexit__(None, None, None)
await self._outer_tg.__aexit__(None, None, None)
self._outer_tg = None

await self.transport.close()

# Make Query an async iterator
Expand Down
36 changes: 32 additions & 4 deletions src/claude_agent_sdk/_internal/transport/subprocess_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def __init__(
self._stdin_stream: TextSendStream | None = None
self._stderr_stream: TextReceiveStream | None = None
self._stderr_task_group: anyio.abc.TaskGroup | None = None
self._stderr_stop_event: anyio.Event | None = None
self._stderr_started_event: anyio.Event | None = None
self._ready = False
self._exit_error: Exception | None = None # Track process exit errors
self._max_buffer_size = (
Expand Down Expand Up @@ -340,10 +342,13 @@ async def connect(self) -> None:
# Setup stderr stream if piped
if should_pipe_stderr and self._process.stderr:
self._stderr_stream = TextReceiveStream(self._process.stderr)
# Start async task to read stderr
# Start async task to read stderr using owner task pattern
self._stderr_stop_event = anyio.Event()
self._stderr_started_event = anyio.Event()
self._stderr_task_group = anyio.create_task_group()
await self._stderr_task_group.__aenter__()
self._stderr_task_group.start_soon(self._handle_stderr)
self._stderr_task_group.start_soon(self._stderr_owner_task)
await self._stderr_started_event.wait()

# Setup stdin for streaming mode
if self._is_streaming and self._process.stdin:
Expand All @@ -370,6 +375,26 @@ async def connect(self) -> None:
self._exit_error = error
raise error from e

async def _stderr_owner_task(self) -> None:
"""Owner task that manages the stderr reader task group.

This task owns the task group for its entire lifetime, ensuring that
the same task that enters the cancel scope also exits it. This is
required for trio compatibility.
"""
try:
async with anyio.create_task_group() as tg:
tg.start_soon(self._handle_stderr)
self._stderr_started_event.set() # type: ignore[union-attr]

# Wait until close() signals us to stop
await self._stderr_stop_event.wait() # type: ignore[union-attr]

# Cancel child tasks
tg.cancel_scope.cancel()
except Exception:
pass # Ignore errors during stderr task cleanup

async def _handle_stderr(self) -> None:
"""Handle stderr stream - read and invoke callbacks."""
if not self._stderr_stream:
Expand Down Expand Up @@ -411,10 +436,13 @@ async def close(self) -> None:
if not self._process:
return

# Close stderr task group if active
# Signal stderr owner task to stop
if self._stderr_stop_event:
self._stderr_stop_event.set()

# Wait for stderr task group to finish (owner will exit after stop event)
if self._stderr_task_group:
with suppress(Exception):
self._stderr_task_group.cancel_scope.cancel()
await self._stderr_task_group.__aexit__(None, None, None)
self._stderr_task_group = None

Expand Down
Loading
Loading