Skip to content

Commit 20dc0fb

Browse files
authored
Allow to pass timeout as float (#941)
1 parent d69b290 commit 20dc0fb

File tree

1 file changed

+16
-28
lines changed

1 file changed

+16
-28
lines changed

src/mcp/client/streamable_http.py

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from contextlib import asynccontextmanager
1212
from dataclasses import dataclass
1313
from datetime import timedelta
14-
from typing import Any
1514

1615
import anyio
1716
import httpx
@@ -52,14 +51,10 @@
5251
class StreamableHTTPError(Exception):
5352
"""Base exception for StreamableHTTP transport errors."""
5453

55-
pass
56-
5754

5855
class ResumptionError(StreamableHTTPError):
5956
"""Raised when resumption request is invalid."""
6057

61-
pass
62-
6358

6459
@dataclass
6560
class RequestContext:
@@ -71,7 +66,7 @@ class RequestContext:
7166
session_message: SessionMessage
7267
metadata: ClientMessageMetadata | None
7368
read_stream_writer: StreamWriter
74-
sse_read_timeout: timedelta
69+
sse_read_timeout: float
7570

7671

7772
class StreamableHTTPTransport:
@@ -80,9 +75,9 @@ class StreamableHTTPTransport:
8075
def __init__(
8176
self,
8277
url: str,
83-
headers: dict[str, Any] | None = None,
84-
timeout: timedelta = timedelta(seconds=30),
85-
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
78+
headers: dict[str, str] | None = None,
79+
timeout: float | timedelta = 30,
80+
sse_read_timeout: float | timedelta = 60 * 5,
8681
auth: httpx.Auth | None = None,
8782
) -> None:
8883
"""Initialize the StreamableHTTP transport.
@@ -96,10 +91,12 @@ def __init__(
9691
"""
9792
self.url = url
9893
self.headers = headers or {}
99-
self.timeout = timeout
100-
self.sse_read_timeout = sse_read_timeout
94+
self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
95+
self.sse_read_timeout = (
96+
sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout
97+
)
10198
self.auth = auth
102-
self.session_id: str | None = None
99+
self.session_id = None
103100
self.request_headers = {
104101
ACCEPT: f"{JSON}, {SSE}",
105102
CONTENT_TYPE: JSON,
@@ -160,7 +157,7 @@ async def _handle_sse_event(
160157
return isinstance(message.root, JSONRPCResponse | JSONRPCError)
161158

162159
except Exception as exc:
163-
logger.error(f"Error parsing SSE message: {exc}")
160+
logger.exception("Error parsing SSE message")
164161
await read_stream_writer.send(exc)
165162
return False
166163
else:
@@ -184,10 +181,7 @@ async def handle_get_stream(
184181
"GET",
185182
self.url,
186183
headers=headers,
187-
timeout=httpx.Timeout(
188-
self.timeout.total_seconds(),
189-
read=self.sse_read_timeout.total_seconds(),
190-
),
184+
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
191185
) as event_source:
192186
event_source.response.raise_for_status()
193187
logger.debug("GET SSE connection established")
@@ -216,10 +210,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
216210
"GET",
217211
self.url,
218212
headers=headers,
219-
timeout=httpx.Timeout(
220-
self.timeout.total_seconds(),
221-
read=ctx.sse_read_timeout.total_seconds(),
222-
),
213+
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
223214
) as event_source:
224215
event_source.response.raise_for_status()
225216
logger.debug("Resumption GET SSE connection established")
@@ -412,9 +403,9 @@ def get_session_id(self) -> str | None:
412403
@asynccontextmanager
413404
async def streamablehttp_client(
414405
url: str,
415-
headers: dict[str, Any] | None = None,
416-
timeout: timedelta = timedelta(seconds=30),
417-
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
406+
headers: dict[str, str] | None = None,
407+
timeout: float | timedelta = 30,
408+
sse_read_timeout: float | timedelta = 60 * 5,
418409
terminate_on_close: bool = True,
419410
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
420411
auth: httpx.Auth | None = None,
@@ -449,10 +440,7 @@ async def streamablehttp_client(
449440

450441
async with httpx_client_factory(
451442
headers=transport.request_headers,
452-
timeout=httpx.Timeout(
453-
transport.timeout.total_seconds(),
454-
read=transport.sse_read_timeout.total_seconds(),
455-
),
443+
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
456444
auth=transport.auth,
457445
) as client:
458446
# Define callbacks that need access to tg

0 commit comments

Comments
 (0)