Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent stdio connection hang for missing server path. #401

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions src/mcp/client/stdio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import anyio
import anyio.lowlevel
from anyio.abc import Process
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from anyio.streams.text import TextReceiveStream
from pydantic import BaseModel, Field
Expand Down Expand Up @@ -38,6 +39,10 @@
)


class ProcessTerminatedEarlyError(Exception):
"""Raised when a process terminates unexpectedly."""


def get_default_environment() -> dict[str, str]:
"""
Returns a default environment object including only environment variables deemed
Expand Down Expand Up @@ -110,7 +115,7 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
command = _get_executable_command(server.command)

# Open process with stderr piped for capture
process = await _create_platform_compatible_process(
process: Process = await _create_platform_compatible_process(
command=command,
args=server.args,
env=(
Expand Down Expand Up @@ -163,20 +168,36 @@ async def stdin_writer():
except anyio.ClosedResourceError:
await anyio.lowlevel.checkpoint()

process_error: str | None = None

async with (
anyio.create_task_group() as tg,
process,
):
tg.start_soon(stdout_reader)
tg.start_soon(stdin_writer)
# tg.start_soon(monitor_process, tg.cancel_scope)
try:
yield read_stream, write_stream
finally:
# Clean up process to prevent any dangling orphaned processes
if sys.platform == "win32":
await terminate_windows_process(process)
await read_stream.aclose()
await write_stream.aclose()
await read_stream_writer.aclose()
await write_stream_reader.aclose()

if process.returncode is not None and process.returncode != 0:
process_error = f"Process exited with code {process.returncode}."
else:
process.terminate()
# Clean up process to prevent any dangling orphaned processes
if sys.platform == "win32":
await terminate_windows_process(process)
else:
process.terminate()

if process_error:
# Raise outside the task group so that the error is not wrapped in an
# ExceptionGroup
raise ProcessTerminatedEarlyError(process_error)


def _get_executable_command(command: str) -> str:
Expand Down
35 changes: 34 additions & 1 deletion tests/client/test_stdio.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import shutil

import pytest
from anyio import fail_after

from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.client.stdio import (
ProcessTerminatedEarlyError,
StdioServerParameters,
stdio_client,
)
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse

tee: str = shutil.which("tee") # type: ignore
uv: str = shutil.which("uv") # type: ignore


@pytest.mark.anyio
Expand Down Expand Up @@ -41,3 +47,30 @@ async def test_stdio_client():
assert read_messages[1] == JSONRPCMessage(
root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})
)


@pytest.mark.anyio
@pytest.mark.skipif(uv is None, reason="could not find uv command")
async def test_stdio_client_bad_path():
"""Check that the connection doesn't hang if process errors."""
server_parameters = StdioServerParameters(
command="uv", args=["run", "non-existent-file.py"]
)

with pytest.raises(ProcessTerminatedEarlyError):
try:
with fail_after(1):
async with stdio_client(server_parameters) as (
read_stream,
_,
):
# Try waiting for read_stream so that we don't exit before the
# process fails.
async with read_stream:
async for message in read_stream:
if isinstance(message, Exception):
raise message

pass
except TimeoutError:
pytest.fail("The connection hung.")
Loading