11
11
from contextlib import asynccontextmanager
12
12
from dataclasses import dataclass
13
13
from datetime import timedelta
14
- from typing import Any
15
14
16
15
import anyio
17
16
import httpx
52
51
class StreamableHTTPError (Exception ):
53
52
"""Base exception for StreamableHTTP transport errors."""
54
53
55
- pass
56
-
57
54
58
55
class ResumptionError (StreamableHTTPError ):
59
56
"""Raised when resumption request is invalid."""
60
57
61
- pass
62
-
63
58
64
59
@dataclass
65
60
class RequestContext :
@@ -71,7 +66,7 @@ class RequestContext:
71
66
session_message : SessionMessage
72
67
metadata : ClientMessageMetadata | None
73
68
read_stream_writer : StreamWriter
74
- sse_read_timeout : timedelta
69
+ sse_read_timeout : float
75
70
76
71
77
72
class StreamableHTTPTransport :
@@ -80,9 +75,9 @@ class StreamableHTTPTransport:
80
75
def __init__ (
81
76
self ,
82
77
url : str ,
83
- headers : dict [str , Any ] | None = None ,
84
- timeout : timedelta = timedelta ( seconds = 30 ) ,
85
- sse_read_timeout : timedelta = timedelta ( seconds = 60 * 5 ) ,
78
+ headers : dict [str , str ] | None = None ,
79
+ timeout : float | timedelta = 30 ,
80
+ sse_read_timeout : float | timedelta = 60 * 5 ,
86
81
auth : httpx .Auth | None = None ,
87
82
) -> None :
88
83
"""Initialize the StreamableHTTP transport.
@@ -96,10 +91,12 @@ def __init__(
96
91
"""
97
92
self .url = url
98
93
self .headers = headers or {}
99
- self .timeout = timeout
100
- self .sse_read_timeout = sse_read_timeout
94
+ self .timeout = timeout .total_seconds () if isinstance (timeout , timedelta ) else timeout
95
+ self .sse_read_timeout = (
96
+ sse_read_timeout .total_seconds () if isinstance (sse_read_timeout , timedelta ) else sse_read_timeout
97
+ )
101
98
self .auth = auth
102
- self .session_id : str | None = None
99
+ self .session_id = None
103
100
self .request_headers = {
104
101
ACCEPT : f"{ JSON } , { SSE } " ,
105
102
CONTENT_TYPE : JSON ,
@@ -160,7 +157,7 @@ async def _handle_sse_event(
160
157
return isinstance (message .root , JSONRPCResponse | JSONRPCError )
161
158
162
159
except Exception as exc :
163
- logger .error ( f "Error parsing SSE message: { exc } " )
160
+ logger .exception ( "Error parsing SSE message" )
164
161
await read_stream_writer .send (exc )
165
162
return False
166
163
else :
@@ -184,10 +181,7 @@ async def handle_get_stream(
184
181
"GET" ,
185
182
self .url ,
186
183
headers = headers ,
187
- timeout = httpx .Timeout (
188
- self .timeout .total_seconds (),
189
- read = self .sse_read_timeout .total_seconds (),
190
- ),
184
+ timeout = httpx .Timeout (self .timeout , read = self .sse_read_timeout ),
191
185
) as event_source :
192
186
event_source .response .raise_for_status ()
193
187
logger .debug ("GET SSE connection established" )
@@ -216,10 +210,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
216
210
"GET" ,
217
211
self .url ,
218
212
headers = headers ,
219
- timeout = httpx .Timeout (
220
- self .timeout .total_seconds (),
221
- read = ctx .sse_read_timeout .total_seconds (),
222
- ),
213
+ timeout = httpx .Timeout (self .timeout , read = self .sse_read_timeout ),
223
214
) as event_source :
224
215
event_source .response .raise_for_status ()
225
216
logger .debug ("Resumption GET SSE connection established" )
@@ -412,9 +403,9 @@ def get_session_id(self) -> str | None:
412
403
@asynccontextmanager
413
404
async def streamablehttp_client (
414
405
url : str ,
415
- headers : dict [str , Any ] | None = None ,
416
- timeout : timedelta = timedelta ( seconds = 30 ) ,
417
- sse_read_timeout : timedelta = timedelta ( seconds = 60 * 5 ) ,
406
+ headers : dict [str , str ] | None = None ,
407
+ timeout : float | timedelta = 30 ,
408
+ sse_read_timeout : float | timedelta = 60 * 5 ,
418
409
terminate_on_close : bool = True ,
419
410
httpx_client_factory : McpHttpClientFactory = create_mcp_http_client ,
420
411
auth : httpx .Auth | None = None ,
@@ -449,10 +440,7 @@ async def streamablehttp_client(
449
440
450
441
async with httpx_client_factory (
451
442
headers = transport .request_headers ,
452
- timeout = httpx .Timeout (
453
- transport .timeout .total_seconds (),
454
- read = transport .sse_read_timeout .total_seconds (),
455
- ),
443
+ timeout = httpx .Timeout (transport .timeout , read = transport .sse_read_timeout ),
456
444
auth = transport .auth ,
457
445
) as client :
458
446
# Define callbacks that need access to tg
0 commit comments