diff --git a/src/claude_agent_sdk/_internal/query.py b/src/claude_agent_sdk/_internal/query.py index 5d21dd1f..a0c003c5 100644 --- a/src/claude_agent_sdk/_internal/query.py +++ b/src/claude_agent_sdk/_internal/query.py @@ -274,6 +274,7 @@ async def _handle_control_request(self, request: SDKControlRequest) -> None: }, } await self.transport.write(json.dumps(success_response) + "\n") + await self.transport.flush_stdin() except Exception as e: # Send error response @@ -286,6 +287,7 @@ async def _handle_control_request(self, request: SDKControlRequest) -> None: }, } await self.transport.write(json.dumps(error_response) + "\n") + await self.transport.flush_stdin() async def _send_control_request(self, request: dict[str, Any]) -> dict[str, Any]: """Send control request to CLI and wait for response.""" @@ -309,6 +311,11 @@ async def _send_control_request(self, request: dict[str, Any]) -> dict[str, Any] await self.transport.write(json.dumps(control_request) + "\n") + # Flush stdin to ensure the request is sent immediately + # This is critical on Windows where buffering can prevent the subprocess + # from receiving the data + await self.transport.flush_stdin() + # Wait for response try: with anyio.fail_after(60.0): diff --git a/src/claude_agent_sdk/_internal/transport/__init__.py b/src/claude_agent_sdk/_internal/transport/__init__.py index 6dedef61..ec514ee5 100644 --- a/src/claude_agent_sdk/_internal/transport/__init__.py +++ b/src/claude_agent_sdk/_internal/transport/__init__.py @@ -64,5 +64,17 @@ async def end_input(self) -> None: """End the input stream (close stdin for process transports).""" pass + async def flush_stdin(self) -> None: + """Flush the stdin stream to ensure data is sent immediately. + + This is primarily needed on Windows where subprocess stdin buffering + can prevent data from being sent to the child process immediately. + + Default implementation does nothing. Transports that support stdin + flushing should override this method. + """ + # Default implementation - subclasses can override for platform-specific flushing + return None + __all__ = ["Transport"] diff --git a/src/claude_agent_sdk/_internal/transport/subprocess_cli.py b/src/claude_agent_sdk/_internal/transport/subprocess_cli.py index bd9fc59d..2c0237c5 100644 --- a/src/claude_agent_sdk/_internal/transport/subprocess_cli.py +++ b/src/claude_agent_sdk/_internal/transport/subprocess_cli.py @@ -3,6 +3,7 @@ import json import logging import os +import platform import re import shutil import sys @@ -496,3 +497,46 @@ async def _check_claude_version(self) -> None: def is_ready(self) -> bool: """Check if transport is ready for communication.""" return self._ready + + async def flush_stdin(self) -> None: + """Flush stdin to ensure data is sent immediately to the subprocess. + + This is particularly important on Windows where subprocess stdin buffering + can prevent data from reaching the child process immediately. + + This method attempts to drain the stdin stream if using asyncio backend, + which is the primary fix for Windows subprocess communication issues. + """ + # Only flush if we have a process and stdin stream + if not self._process or not self._process.stdin: + return + + # On Windows, we need to explicitly flush/drain the stdin stream + # to ensure data reaches the subprocess immediately + if platform.system() == "Windows": + try: + # anyio wraps subprocess stdin in a ByteSendStream + # When using asyncio backend, the underlying stream is a StreamWriter + # which has a drain() method that we need to call + stdin_stream = self._process.stdin + + # Check if this is an asyncio StreamWriter (has drain method) + if hasattr(stdin_stream, "drain") and callable(stdin_stream.drain): + await stdin_stream.drain() + logger.debug("Flushed stdin stream on Windows") + else: + # If not a StreamWriter, try to access wrapped/inner stream + # anyio may wrap the stream in various ways depending on backend + for attr in ["_stream", "_transport_stream", "transport_stream"]: + if hasattr(stdin_stream, attr): + inner = getattr(stdin_stream, attr) + if hasattr(inner, "drain") and callable(inner.drain): + await inner.drain() + logger.debug( + f"Flushed stdin inner stream via {attr} on Windows" + ) + break + except Exception as e: + # Log but don't fail - flushing is a best-effort optimization + logger.debug(f"Could not flush stdin on Windows: {e}") + pass diff --git a/tests/test_streaming_client.py b/tests/test_streaming_client.py index 29294419..5e282c29 100644 --- a/tests/test_streaming_client.py +++ b/tests/test_streaming_client.py @@ -15,11 +15,14 @@ ClaudeAgentOptions, ClaudeSDKClient, CLIConnectionError, + PermissionResultAllow, ResultMessage, TextBlock, UserMessage, query, ) +from claude_agent_sdk._internal.query import Query +from claude_agent_sdk._internal.transport import Transport from claude_agent_sdk._internal.transport.subprocess_cli import SubprocessCLITransport @@ -34,6 +37,7 @@ def create_mock_transport(with_init_response=True): mock_transport.close = AsyncMock() mock_transport.end_input = AsyncMock() mock_transport.write = AsyncMock() + mock_transport.flush_stdin = AsyncMock() mock_transport.is_ready = Mock(return_value=True) # Track written messages to simulate control protocol responses @@ -571,6 +575,82 @@ async def get_next_message(): anyio.run(_test) + def test_flush_stdin_called_after_control_responses(self): + """Test that flush_stdin is called after responding to control requests (issue #208).""" + + async def _test(): + # Create a mock transport + mock_transport = AsyncMock(spec=Transport) + mock_transport.is_ready = Mock(return_value=True) + + # Track write and flush calls + write_calls = [] + flush_calls = [] + + async def mock_write(data): + write_calls.append(data) + + async def mock_flush(): + flush_calls.append(True) + + mock_transport.write = AsyncMock(side_effect=mock_write) + mock_transport.flush_stdin = AsyncMock(side_effect=mock_flush) + + # Create mock read_messages that doesn't yield anything + async def mock_read_messages(): + # Just wait forever (test will complete before this matters) + await asyncio.sleep(1000) + yield {} + + mock_transport.read_messages = mock_read_messages + + # Create Query with streaming mode + query = Query(transport=mock_transport, is_streaming_mode=True) + await query.start() + + # Simulate an incoming tool permission request + permission_request = { + "type": "control_request", + "request_id": "test_req_123", + "request": { + "subtype": "can_use_tool", + "tool_name": "Read", + "input": {"file_path": "/test.txt"}, + "permission_suggestions": [], + }, + } + + # Set up a permission callback that allows the tool + async def mock_can_use_tool(tool_name, input_data, context): + return PermissionResultAllow() + + query.can_use_tool = mock_can_use_tool + + # Clear previous calls + write_calls.clear() + flush_calls.clear() + + # Handle the control request + await query._handle_control_request(permission_request) + + # Give it a moment to complete + await asyncio.sleep(0.01) + + # Verify that flush_stdin was called after writing the response + assert len(write_calls) == 1, "Should have written one control response" + assert len(flush_calls) == 1, ( + "flush_stdin should be called after writing response" + ) + + # Verify the response was a success + response_data = json.loads(write_calls[0]) + assert response_data["type"] == "control_response" + assert response_data["response"]["subtype"] == "success" + + await query.close() + + anyio.run(_test) + class TestQueryWithAsyncIterable: """Test query() function with async iterable inputs.""" @@ -833,3 +913,44 @@ async def mock_receive(): assert isinstance(messages[-1], ResultMessage) anyio.run(_test) + + def test_flush_stdin_called_after_control_requests(self): + """Test that flush_stdin is called after sending control requests (issue #208).""" + + async def _test(): + with patch( + "claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = create_mock_transport() + + # Add flush_stdin mock and tracking + flush_calls = [] + + async def mock_flush(): + flush_calls.append(True) + + mock_transport.flush_stdin = AsyncMock(side_effect=mock_flush) + + mock_transport_class.return_value = mock_transport + + async with ClaudeSDKClient() as client: + # Initialization should call flush_stdin + # Wait a bit for initialization to complete + await asyncio.sleep(0.05) + + # Verify flush_stdin was called at least once (for initialization) + assert len(flush_calls) >= 1, ( + "flush_stdin should be called during initialization" + ) + initial_flush_count = len(flush_calls) + + # Send interrupt control request + await client.interrupt() + await asyncio.sleep(0.05) + + # Verify flush_stdin was called again (for interrupt request) + assert len(flush_calls) > initial_flush_count, ( + "flush_stdin should be called after interrupt" + ) + + anyio.run(_test) diff --git a/tests/test_transport.py b/tests/test_transport.py index 93538f4a..49e9d729 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -486,3 +486,110 @@ async def _test(): assert user_passed == "claude" anyio.run(_test) + + def test_flush_stdin_on_windows(self): + """Test that flush_stdin calls drain() on Windows (issue #208).""" + + async def _test(): + # Mock platform.system to return Windows + with patch("platform.system", return_value="Windows"): + transport = SubprocessCLITransport( + prompt="test", + options=ClaudeAgentOptions(), + cli_path="/usr/bin/claude", + ) + + # Create a mock process with stdin that has drain method + mock_process = MagicMock() + mock_stdin = AsyncMock() + mock_stdin.drain = AsyncMock() + mock_process.stdin = mock_stdin + transport._process = mock_process + + # Call flush_stdin + await transport.flush_stdin() + + # Verify drain was called on Windows + mock_stdin.drain.assert_called_once() + + anyio.run(_test) + + def test_flush_stdin_on_non_windows(self): + """Test that flush_stdin does nothing on non-Windows platforms.""" + + async def _test(): + # Mock platform.system to return Linux + with patch("platform.system", return_value="Linux"): + transport = SubprocessCLITransport( + prompt="test", + options=ClaudeAgentOptions(), + cli_path="/usr/bin/claude", + ) + + # Create a mock process with stdin + mock_process = MagicMock() + mock_stdin = AsyncMock() + mock_stdin.drain = AsyncMock() + mock_process.stdin = mock_stdin + transport._process = mock_process + + # Call flush_stdin + await transport.flush_stdin() + + # Verify drain was NOT called on non-Windows + mock_stdin.drain.assert_not_called() + + anyio.run(_test) + + def test_flush_stdin_without_process(self): + """Test that flush_stdin handles missing process gracefully.""" + + async def _test(): + transport = SubprocessCLITransport( + prompt="test", + options=ClaudeAgentOptions(), + cli_path="/usr/bin/claude", + ) + + # Don't set up a process + transport._process = None + + # Should not raise an error + await transport.flush_stdin() + + anyio.run(_test) + + def test_flush_stdin_fallback_to_inner_stream(self): + """Test that flush_stdin tries to find drain() in wrapped streams.""" + + async def _test(): + # Mock platform.system to return Windows + with patch("platform.system", return_value="Windows"): + transport = SubprocessCLITransport( + prompt="test", + options=ClaudeAgentOptions(), + cli_path="/usr/bin/claude", + ) + + # Create a mock process with stdin that doesn't have drain, + # but has an inner _stream that does + mock_process = MagicMock() + mock_stdin = MagicMock() + # Remove drain from stdin itself + del mock_stdin.drain + + # Add inner stream with drain + mock_inner_stream = AsyncMock() + mock_inner_stream.drain = AsyncMock() + mock_stdin._stream = mock_inner_stream + + mock_process.stdin = mock_stdin + transport._process = mock_process + + # Call flush_stdin + await transport.flush_stdin() + + # Verify drain was called on the inner stream + mock_inner_stream.drain.assert_called_once() + + anyio.run(_test)