diff --git a/src/mcp/client/stdio/__init__.py b/src/mcp/client/stdio/__init__.py index 83de57a2..d51d5c12 100644 --- a/src/mcp/client/stdio/__init__.py +++ b/src/mcp/client/stdio/__init__.py @@ -6,6 +6,7 @@ import anyio import anyio.lowlevel +from anyio.abc import Process from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from anyio.streams.text import TextReceiveStream from pydantic import BaseModel, Field @@ -38,6 +39,10 @@ ) +class ProcessTerminatedEarlyError(Exception): + """Raised when a process terminates unexpectedly.""" + + def get_default_environment() -> dict[str, str]: """ Returns a default environment object including only environment variables deemed @@ -110,7 +115,7 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder command = _get_executable_command(server.command) # Open process with stderr piped for capture - process = await _create_platform_compatible_process( + process: Process = await _create_platform_compatible_process( command=command, args=server.args, env=( @@ -122,7 +127,7 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder cwd=server.cwd, ) - async def stdout_reader(): + async def stdout_reader(done_event: anyio.Event): assert process.stdout, "Opened process is missing stdout" try: @@ -146,6 +151,7 @@ async def stdout_reader(): await read_stream_writer.send(message) except anyio.ClosedResourceError: await anyio.lowlevel.checkpoint() + done_event.set() async def stdin_writer(): assert process.stdin, "Opened process is missing stdin" @@ -163,20 +169,45 @@ async def stdin_writer(): except anyio.ClosedResourceError: await anyio.lowlevel.checkpoint() + process_error: str | None = None + async with ( anyio.create_task_group() as tg, process, ): - tg.start_soon(stdout_reader) + stdout_done_event = anyio.Event() + tg.start_soon(stdout_reader, stdout_done_event) tg.start_soon(stdin_writer) try: yield read_stream, write_stream + if stdout_done_event.is_set(): + # The stdout reader exited before the calling code stopped listening + # (e.g. because of process error) + # Give the process a chance to exit if it was the reason for crashing + # so we can get exit code + with anyio.move_on_after(0.1) as scope: + await process.wait() + process_error = f"Process exited with code {process.returncode}." + if scope.cancelled_caught: + process_error = ( + "Stdout reader exited (process did not exit immediately)." + ) finally: + await read_stream.aclose() + await write_stream.aclose() + await read_stream_writer.aclose() + await write_stream_reader.aclose() # Clean up process to prevent any dangling orphaned processes - if sys.platform == "win32": - await terminate_windows_process(process) - else: - process.terminate() + if process.returncode is None: + if sys.platform == "win32": + await terminate_windows_process(process) + else: + process.terminate() + + if process_error: + # Raise outside the task group so that the error is not wrapped in an + # ExceptionGroup + raise ProcessTerminatedEarlyError(process_error) def _get_executable_command(command: str) -> str: diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 95747ffd..ae968974 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -1,11 +1,17 @@ import shutil import pytest +from anyio import fail_after -from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.client.stdio import ( + ProcessTerminatedEarlyError, + StdioServerParameters, + stdio_client, +) from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse tee: str = shutil.which("tee") # type: ignore +uv: str = shutil.which("uv") # type: ignore @pytest.mark.anyio @@ -41,3 +47,28 @@ async def test_stdio_client(): assert read_messages[1] == JSONRPCMessage( root=JSONRPCResponse(jsonrpc="2.0", id=2, result={}) ) + + +@pytest.mark.anyio +@pytest.mark.skipif(uv is None, reason="could not find uv command") +async def test_stdio_client_bad_path(): + """Check that the connection doesn't hang if process errors.""" + server_parameters = StdioServerParameters( + command="uv", args=["run", "non-existent-file.py"] + ) + + with pytest.raises(ProcessTerminatedEarlyError): + try: + with fail_after(1): + async with stdio_client(server_parameters) as ( + read_stream, + _, + ): + # Try waiting for read_stream so that we don't exit before the + # process fails. + async with read_stream: + async for message in read_stream: + if isinstance(message, Exception): + raise message + except TimeoutError: + pytest.fail("The connection hung.")