Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion fastmcp_remote/fastmcp_remote/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,6 @@ async def run(config: RemoteConfig) -> None:
client,
name="fastmcp-remote",
provider_error_strategy="raise",
validate_on_initialize=True,
)
if config.ignore_tools:
server.add_transform(IgnoreTools(config.ignore_tools))
Expand Down
57 changes: 48 additions & 9 deletions fastmcp_slim/fastmcp/server/providers/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@
from pydantic.networks import AnyUrl

from fastmcp.client.client import Client, FastMCP1Server
from fastmcp.client.elicitation import ElicitResult
from fastmcp.client.logging import LogMessage
from fastmcp.client.roots import RootsList
from fastmcp.client.elicitation import ElicitResult, create_elicitation_callback
from fastmcp.client.logging import LogMessage, create_log_callback
from fastmcp.client.roots import RootsList, create_roots_callback
from fastmcp.client.sampling import create_sampling_callback
from fastmcp.client.telemetry import client_span
from fastmcp.client.transports import ClientTransportT
from fastmcp.exceptions import ResourceError
Expand Down Expand Up @@ -88,8 +89,15 @@ async def on_initialize(
) -> mcp.types.InitializeResult | None:
client = await self.proxy._get_client()
try:
if isinstance(client, StatefulProxyClient):
ctx = context.fastmcp_context
if ctx is not None:
client._proxy_rc_ref[0] = (
ctx.request_context,
ctx._fastmcp,
)
async with client:
pass
await client.initialize()
except McpError:
raise
except (
Expand Down Expand Up @@ -881,7 +889,6 @@ def __init__(
*,
client_factory: ClientFactoryT,
provider_error_strategy: ProviderErrorStrategy = "warn",
validate_on_initialize: bool = False,
**kwargs,
):
"""Initialize the proxy server.
Expand All @@ -896,17 +903,14 @@ def __init__(
provider_error_strategy: How provider errors should affect aggregate
operations. Defaults to ``"warn"`` for compatibility; use
``"raise"`` when the proxy should surface upstream failures.
validate_on_initialize: If true, connect to the upstream server during
the incoming MCP initialize request.
**kwargs: Additional settings for the FastMCP server.
"""
super().__init__(**kwargs)
self.provider_error_strategy = provider_error_strategy
self.client_factory = client_factory
provider: Provider = ProxyProvider(client_factory)
self.add_provider(provider)
if validate_on_initialize:
self.middleware.append(ProxyInitializeMiddleware(self))
self.middleware.append(ProxyInitializeMiddleware(self))
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Restore opt-out for initialize forwarding in FastMCPProxy

This change makes FastMCPProxy always install ProxyInitializeMiddleware, so every downstream initialize now forces an upstream connect/initialize handshake with no way to disable it. That regresses the prior lazy behavior for direct FastMCPProxy(...) usage (and any caller that previously used validate_on_initialize=False), causing startup to fail immediately when upstream is temporarily unavailable instead of deferring failure to the first proxied operation; this is especially disruptive for stateful/custom client_factory mounting flows that rely on lazy or degraded startup.

Useful? React with 👍 / 👎.

self._setup_proxy_ping_handler()

async def _get_client(self) -> Client:
Expand Down Expand Up @@ -1140,11 +1144,13 @@ class StatefulProxyClient(ProxyClient[ClientTransportT]):
# would resolve stale values in the receive loop. The restore helper
# constructs a fresh Context from the weakref after setting request_ctx.
_proxy_rc_ref: list[Any]
_proxy_restoring_handler_keys: set[str]

def __init__(self, *args: Any, **kwargs: Any):
# Install context-restoring handler wrappers BEFORE super().__init__
# registers them with the Client's session kwargs.
self._proxy_rc_ref = [None]
self._proxy_restoring_handler_keys = set()
for key, default_fn in (
("roots", default_proxy_roots_handler),
("sampling_handler", default_proxy_sampling_handler),
Expand All @@ -1154,10 +1160,43 @@ def __init__(self, *args: Any, **kwargs: Any):
):
if key not in kwargs:
kwargs[key] = _make_restoring_handler(default_fn, self._proxy_rc_ref)
self._proxy_restoring_handler_keys.add(key)

super().__init__(*args, **kwargs)
self._caches: dict[ServerSession, Client[ClientTransportT]] = {}

def _bind_restoring_handlers(self) -> None:
if "roots" in self._proxy_restoring_handler_keys:
self._session_kwargs["list_roots_callback"] = create_roots_callback(
_make_restoring_handler(default_proxy_roots_handler, self._proxy_rc_ref)
)
Comment on lines +1169 to +1172
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Preserve overridden callbacks when cloning stateful proxy clients

StatefulProxyClient.new() now calls _bind_restoring_handlers(), which unconditionally rewrites callbacks for any handler key that was defaulted at construction time. That means a caller who later customizes behavior (for example via set_roots, set_sampling_callback, or set_elicitation_callback) will have those overrides silently discarded on new_stateful() clones, and the proxy falls back to the built-in forwarding handlers instead. This is a regression from prior behavior where Client.new() preserved post-init callback updates, and it can break custom routing/auth/context logic in stateful proxy deployments.

Useful? React with 👍 / 👎.

if "sampling_handler" in self._proxy_restoring_handler_keys:
self._session_kwargs["sampling_callback"] = create_sampling_callback(
_make_restoring_handler(
default_proxy_sampling_handler, self._proxy_rc_ref
)
)
if "elicitation_handler" in self._proxy_restoring_handler_keys:
self._session_kwargs["elicitation_callback"] = create_elicitation_callback(
_make_restoring_handler(
default_proxy_elicitation_handler, self._proxy_rc_ref
)
)
if "log_handler" in self._proxy_restoring_handler_keys:
self._session_kwargs["logging_callback"] = create_log_callback(
_make_restoring_handler(default_proxy_log_handler, self._proxy_rc_ref)
)
if "progress_handler" in self._proxy_restoring_handler_keys:
self._progress_handler = _make_restoring_handler(
default_proxy_progress_handler, self._proxy_rc_ref
)

def new(self) -> StatefulProxyClient[ClientTransportT]:
new_client = cast(StatefulProxyClient[ClientTransportT], super().new())
new_client._proxy_rc_ref = [None]
new_client._bind_restoring_handlers()
return new_client

async def __aexit__(self, exc_type, exc_value, traceback) -> None: # type: ignore[override] # ty:ignore[invalid-method-override]
"""The stateful proxy client will be forced disconnected when the session is exited.

Expand Down
10 changes: 5 additions & 5 deletions tests/server/providers/proxy/test_proxy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,16 +259,16 @@ async def test_proxy_ping_surfaces_wrong_remote_path():
async with run_server_async(remote, transport="http") as url:
proxy = create_proxy(StreamableHttpTransport(url.removesuffix("/mcp")))

async with Client(proxy) as client:
with pytest.raises(McpError, match="Session terminated"):
await client.ping()
with pytest.raises(McpError, match="Session terminated"):
async with Client(proxy):
pass


async def test_proxy_initialize_surfaces_remote_connection_error():
async def test_proxy_initialize_forwards_remote_connection_error():
port = find_available_port()
proxy = create_proxy(
StreamableHttpTransport(f"http://127.0.0.1:{port}/mcp"),
validate_on_initialize=True,
provider_error_strategy="raise",
)

with pytest.raises(McpError, match="Client failed to connect"):
Expand Down
Loading