diff --git a/src/fastmcp/client/auth/oauth.py b/src/fastmcp/client/auth/oauth.py index fd5056fdd..349db61b3 100644 --- a/src/fastmcp/client/auth/oauth.py +++ b/src/fastmcp/client/auth/oauth.py @@ -12,6 +12,7 @@ from key_value.aio.protocols import AsyncKeyValue from key_value.aio.stores.memory import MemoryStore from mcp.client.auth import OAuthClientProvider, TokenStorage +from mcp.shared._httpx_utils import McpHttpClientFactory from mcp.shared.auth import ( OAuthClientInformationFull, OAuthClientMetadata, @@ -147,6 +148,7 @@ def __init__( token_storage: AsyncKeyValue | None = None, additional_client_metadata: dict[str, Any] | None = None, callback_port: int | None = None, + httpx_client_factory: McpHttpClientFactory | None = None, ): """ Initialize OAuth client provider for an MCP server. @@ -164,6 +166,7 @@ def __init__( server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" # Setup OAuth client + self.httpx_client_factory = httpx_client_factory or httpx.AsyncClient self.redirect_port = callback_port or find_available_port() redirect_uri = f"http://localhost:{self.redirect_port}/callback" @@ -226,7 +229,7 @@ async def _initialize(self) -> None: async def redirect_handler(self, authorization_url: str) -> None: """Open browser for authorization, with pre-flight check for invalid client.""" # Pre-flight check to detect invalid client_id before opening browser - async with httpx.AsyncClient() as client: + async with self.httpx_client_factory() as client: response = await client.get(authorization_url, follow_redirects=False) # Check for client not found error (400 typically means bad client_id) diff --git a/src/fastmcp/client/transports.py b/src/fastmcp/client/transports.py index 25f81afc4..81afc9c88 100644 --- a/src/fastmcp/client/transports.py +++ b/src/fastmcp/client/transports.py @@ -177,8 +177,8 @@ def __init__( self.url = url self.headers = headers or {} - self._set_auth(auth) self.httpx_client_factory = httpx_client_factory + self._set_auth(auth) if isinstance(sse_read_timeout, int | float): sse_read_timeout = datetime.timedelta(seconds=float(sse_read_timeout)) @@ -186,7 +186,7 @@ def __init__( def _set_auth(self, auth: httpx.Auth | Literal["oauth"] | str | None): if auth == "oauth": - auth = OAuth(self.url) + auth = OAuth(self.url, httpx_client_factory=self.httpx_client_factory) elif isinstance(auth, str): auth = BearerAuth(auth) self.auth = auth @@ -247,8 +247,8 @@ def __init__( self.url = url self.headers = headers or {} - self._set_auth(auth) self.httpx_client_factory = httpx_client_factory + self._set_auth(auth) if isinstance(sse_read_timeout, int | float): sse_read_timeout = datetime.timedelta(seconds=float(sse_read_timeout)) @@ -256,7 +256,7 @@ def __init__( def _set_auth(self, auth: httpx.Auth | Literal["oauth"] | str | None): if auth == "oauth": - auth = OAuth(self.url) + auth = OAuth(self.url, httpx_client_factory=self.httpx_client_factory) elif isinstance(auth, str): auth = BearerAuth(auth) self.auth = auth diff --git a/tests/client/transports/test_transports.py b/tests/client/transports/test_transports.py new file mode 100644 index 000000000..3f84fd2a7 --- /dev/null +++ b/tests/client/transports/test_transports.py @@ -0,0 +1,37 @@ +from ssl import VerifyMode + +import httpx + +from fastmcp.client.transports import SSETransport, StreamableHttpTransport + + +async def test_oauth_uses_same_client_as_transport_streamable_http(): + transport = StreamableHttpTransport( + "https://some.fake.url/", + httpx_client_factory=lambda *args, **kwargs: httpx.AsyncClient( + verify=False, *args, **kwargs + ), + auth="oauth", + ) + + async with transport.auth.httpx_client_factory() as httpx_client: # type: ignore[attr-defined] + assert ( + httpx_client._transport._pool._ssl_context.verify_mode + == VerifyMode.CERT_NONE + ) + + +async def test_oauth_uses_same_client_as_transport_sse(): + transport = SSETransport( + "https://some.fake.url/", + httpx_client_factory=lambda *args, **kwargs: httpx.AsyncClient( + verify=False, *args, **kwargs + ), + auth="oauth", + ) + + async with transport.auth.httpx_client_factory() as httpx_client: # type: ignore[attr-defined] + assert ( + httpx_client._transport._pool._ssl_context.verify_mode + == VerifyMode.CERT_NONE + )