Skip to content

pass sse_read_timeout to MCP ClientSession read_timeout_seconds #2240

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
90 changes: 84 additions & 6 deletions pydantic_ai_slim/pydantic_ai/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

import base64
import functools
import warnings
from abc import ABC, abstractmethod
from asyncio import Lock
from collections.abc import AsyncIterator, Awaitable, Sequence
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
from dataclasses import dataclass, field, replace
from datetime import timedelta
from pathlib import Path
from typing import Any, Callable

Expand Down Expand Up @@ -59,6 +61,7 @@ class MCPServer(AbstractToolset[Any], ABC):
log_level: mcp_types.LoggingLevel | None = None
log_handler: LoggingFnT | None = None
timeout: float = 5
read_timeout: float = 5 * 60
process_tool_call: ProcessToolCallback | None = None
allow_sampling: bool = True
max_retries: int = 1
Expand Down Expand Up @@ -208,6 +211,7 @@ async def __aenter__(self) -> Self:
write_stream=self._write_stream,
sampling_callback=self._sampling_callback if self.allow_sampling else None,
logging_callback=self.log_handler,
read_timeout_seconds=timedelta(seconds=self.read_timeout),
)
self._client = await self._exit_stack.enter_async_context(client)

Expand Down Expand Up @@ -401,7 +405,7 @@ def __repr__(self) -> str:
return f'MCPServerStdio(command={self.command!r}, args={self.args!r}, tool_prefix={self.tool_prefix!r})'


@dataclass
@dataclass(init=False)
class _MCPServerHTTP(MCPServer):
url: str
"""The URL of the endpoint on the MCP server."""
Expand Down Expand Up @@ -438,10 +442,10 @@ class _MCPServerHTTP(MCPServer):
```
"""

sse_read_timeout: float = 5 * 60
"""Maximum time in seconds to wait for new SSE messages before timing out.
read_timeout: float = 5 * 60
"""Maximum time in seconds to wait for new messages before timing out.

This timeout applies to the long-lived SSE connection after it's established.
This timeout applies to the long-lived connection after it's established.
If no new messages are received within this time, the connection will be considered stale
and may be closed. Defaults to 5 minutes (300 seconds).
"""
Expand Down Expand Up @@ -485,6 +489,35 @@ class _MCPServerHTTP(MCPServer):
sampling_model: models.Model | None = None
"""The model to use for sampling."""

def __init__(
self,
*,
url: str,
headers: dict[str, str] | None = None,
http_client: httpx.AsyncClient | None = None,
read_timeout: float = 5 * 60,
tool_prefix: str | None = None,
log_level: mcp_types.LoggingLevel | None = None,
log_handler: LoggingFnT | None = None,
timeout: float = 5,
process_tool_call: ProcessToolCallback | None = None,
allow_sampling: bool = True,
max_retries: int = 1,
sampling_model: models.Model | None = None,
):
self.url = url
self.headers = headers
self.http_client = http_client
self.tool_prefix = tool_prefix
self.log_level = log_level
self.log_handler = log_handler
self.timeout = timeout
self.process_tool_call = process_tool_call
self.allow_sampling = allow_sampling
self.max_retries = max_retries
self.sampling_model = sampling_model
self.read_timeout = read_timeout

@property
@abstractmethod
def _transport_client(
Expand Down Expand Up @@ -522,7 +555,7 @@ async def client_streams(
self._transport_client,
url=self.url,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
sse_read_timeout=self.read_timeout,
)

if self.http_client is not None:
Expand All @@ -549,7 +582,7 @@ def __repr__(self) -> str: # pragma: no cover
return f'{self.__class__.__name__}(url={self.url!r}, tool_prefix={self.tool_prefix!r})'


@dataclass
@dataclass(init=False)
class MCPServerSSE(_MCPServerHTTP):
"""An MCP server that connects over streamable HTTP connections.

Expand Down Expand Up @@ -577,6 +610,51 @@ async def main():
2. This will connect to a server running on `localhost:3001`.
"""

def __init__(
self,
*,
url: str,
headers: dict[str, str] | None = None,
http_client: httpx.AsyncClient | None = None,
read_timeout: float | None = None,
tool_prefix: str | None = None,
log_level: mcp_types.LoggingLevel | None = None,
log_handler: LoggingFnT | None = None,
timeout: float = 5,
process_tool_call: ProcessToolCallback | None = None,
allow_sampling: bool = True,
max_retries: int = 1,
sampling_model: models.Model | None = None,
**kwargs: Any,
) -> None:
# Handle deprecated sse_read_timeout parameter
if 'sse_read_timeout' in kwargs:
if read_timeout is not None:
raise TypeError("'read_timeout' and 'sse_read_timeout' cannot be set at the same time.")

warnings.warn(
"'sse_read_timeout' is deprecated, use 'read_timeout' instead.", DeprecationWarning, stacklevel=2
)
read_timeout = kwargs.pop('sse_read_timeout')

if read_timeout is None:
read_timeout = 5 * 60

super().__init__(
url=url,
headers=headers,
http_client=http_client,
read_timeout=read_timeout,
tool_prefix=tool_prefix,
log_level=log_level,
log_handler=log_handler,
timeout=timeout,
process_tool_call=process_tool_call,
allow_sampling=allow_sampling,
max_retries=max_retries,
sampling_model=sampling_model,
)

@property
def _transport_client(self):
return sse_client # pragma: no cover
Expand Down
26 changes: 18 additions & 8 deletions tests/test_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,20 +140,30 @@ def test_sse_server():


def test_sse_server_with_header_and_timeout():
sse_server = MCPServerSSE(
url='http://localhost:8000/sse',
headers={'my-custom-header': 'my-header-value'},
timeout=10,
sse_read_timeout=100,
log_level='info',
)
with pytest.warns(DeprecationWarning, match="'sse_read_timeout' is deprecated, use 'read_timeout' instead."):
sse_server = MCPServerSSE(
url='http://localhost:8000/sse',
headers={'my-custom-header': 'my-header-value'},
timeout=10,
sse_read_timeout=100,
log_level='info',
)
assert sse_server.url == 'http://localhost:8000/sse'
assert sse_server.headers is not None and sse_server.headers['my-custom-header'] == 'my-header-value'
assert sse_server.timeout == 10
assert sse_server.sse_read_timeout == 100
assert sse_server.read_timeout == 100
assert sse_server.log_level == 'info'


def test_sse_server_conflicting_timeout_params():
with pytest.raises(TypeError, match="'read_timeout' and 'sse_read_timeout' cannot be set at the same time."):
MCPServerSSE(
url='http://localhost:8000/sse',
read_timeout=50,
sse_read_timeout=100,
)


@pytest.mark.vcr()
async def test_agent_with_stdio_server(allow_model_requests: None, agent: Agent):
async with agent:
Expand Down