From 8b5a32e4e9dc0abe07928e29f968a0324e46c58a Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Wed, 20 May 2026 14:14:22 +0800 Subject: [PATCH] fix: forward proxy server notifications --- .../fastmcp/server/providers/proxy.py | 141 +++++++++++++++++- .../providers/proxy/test_proxy_server.py | 100 +++++++++++++ 2 files changed, 237 insertions(+), 4 deletions(-) diff --git a/fastmcp_slim/fastmcp/server/providers/proxy.py b/fastmcp_slim/fastmcp/server/providers/proxy.py index 04e5fb810..1df5932a1 100644 --- a/fastmcp_slim/fastmcp/server/providers/proxy.py +++ b/fastmcp_slim/fastmcp/server/providers/proxy.py @@ -16,10 +16,11 @@ import mcp.types from mcp import ServerSession -from mcp.client.session import ClientSession +from mcp.client.session import ClientSession, MessageHandlerFnT from mcp.server.lowlevel.server import request_ctx from mcp.shared.context import LifespanContextT, RequestContext from mcp.shared.exceptions import McpError +from mcp.shared.session import RequestResponder from mcp.types import ( METHOD_NOT_FOUND, BlobResourceContents, @@ -59,6 +60,11 @@ # Type alias for client factory functions ClientFactoryT = Callable[[], Client] | Callable[[], Awaitable[Client]] +ProxyMessageT = ( + RequestResponder[mcp.types.ServerRequest, mcp.types.ClientResult] + | mcp.types.ServerNotification + | Exception +) # ----------------------------------------------------------------------------- @@ -132,8 +138,9 @@ async def run( # its receive-loop task has stale ContextVars from the first # request. Stash the current RequestContext in the shared # ref so handlers can restore it before forwarding. - if isinstance(client, StatefulProxyClient): - client._proxy_rc_ref[0] = ( + proxy_rc_ref = getattr(client, "_proxy_rc_ref", None) + if proxy_rc_ref is not None: + proxy_rc_ref[0] = ( ctx.request_context, ctx._fastmcp, # weakref to FastMCP, not the Context ) @@ -793,12 +800,17 @@ def fresh_client_factory() -> Client: ) def reuse_client_factory() -> Client: + if not isinstance(client, ProxyClient): + _install_plain_client_proxy_message_handler(client) return client return reuse_client_factory def fresh_client_factory() -> Client: - return client.new() + fresh_client = client.new() + if not isinstance(fresh_client, ProxyClient): + _install_plain_client_proxy_message_handler(fresh_client) + return fresh_client return fresh_client_factory else: @@ -980,14 +992,124 @@ def _make_restoring_handler(handler: Callable, rc_ref: list[Any]) -> Callable: ``inspect.isfunction()`` checks in handler registration paths (e.g., ``create_roots_callback``). """ + if getattr(handler, "_fastmcp_proxy_restores_context", False): + return handler async def wrapper(*args: Any, **kwargs: Any) -> Any: _restore_request_context(rc_ref) return await handler(*args, **kwargs) + wrapper_with_flags = cast(Any, wrapper) + wrapper_with_flags._fastmcp_proxy_restores_context = True + if getattr(handler, "_fastmcp_proxy_message_handler", False): + wrapper_with_flags._fastmcp_proxy_message_handler = True + wrapper_with_flags._fastmcp_proxy_forwards_logging = getattr( + handler, + "_fastmcp_proxy_forwards_logging", + False, + ) + return wrapper +async def default_proxy_message_handler( + message: object, + *, + forward_logging_and_progress: bool = False, +) -> None: + if not isinstance(message, mcp.types.ServerNotification): + return + + forwarded_types = ( + mcp.types.ToolListChangedNotification, + mcp.types.ResourceListChangedNotification, + mcp.types.PromptListChangedNotification, + mcp.types.ResourceUpdatedNotification, + ) + if forward_logging_and_progress: + forwarded_types = ( + *forwarded_types, + mcp.types.LoggingMessageNotification, + mcp.types.ProgressNotification, + ) + + if not isinstance(message.root, forwarded_types): + return + + try: + ctx = get_context() + except RuntimeError: + logger.debug( + "Dropping upstream server notification outside a proxy request: %s", + message.root.method, + ) + return + + root_data = message.root.model_dump( + by_alias=True, + mode="json", + exclude_none=True, + ) + root_data.pop("jsonrpc", None) + forwarded_root = type(message.root).model_validate(root_data) + + await ctx.session.send_notification( + mcp.types.ServerNotification(forwarded_root), + related_request_id=ctx.request_id, + ) + + +def _compose_proxy_message_handler( + user_handler: MessageHandlerFnT | None, + *, + forward_logging_and_progress: bool, +) -> MessageHandlerFnT: + if ( + user_handler is not None + and getattr(user_handler, "_fastmcp_proxy_message_handler", False) + and ( + not forward_logging_and_progress + or getattr(user_handler, "_fastmcp_proxy_forwards_logging", False) + ) + ): + return user_handler + + async def message_handler(message: ProxyMessageT) -> None: + if user_handler is not None: + await user_handler(message) + await default_proxy_message_handler( + message, + forward_logging_and_progress=forward_logging_and_progress, + ) + + message_handler_with_flags = cast(Any, message_handler) + message_handler_with_flags._fastmcp_proxy_message_handler = True + message_handler_with_flags._fastmcp_proxy_forwards_logging = ( + forward_logging_and_progress + ) + return cast(MessageHandlerFnT, message_handler) + + +def _install_plain_client_proxy_message_handler(client: Client) -> None: + proxy_rc_ref = getattr(client, "_proxy_rc_ref", None) + if proxy_rc_ref is None: + proxy_rc_ref = [None] + cast(Any, client)._proxy_rc_ref = proxy_rc_ref + + current_handler = client._session_kwargs.get("message_handler") + message_handler = _compose_proxy_message_handler( + current_handler, + forward_logging_and_progress=True, + ) + message_handler = cast( + MessageHandlerFnT, + _make_restoring_handler(message_handler, proxy_rc_ref), + ) + client._session_kwargs["message_handler"] = message_handler + if client.is_connected(): + cast(Any, client.session)._message_handler = message_handler + + class ProxyClient(Client[ClientTransportT]): """A proxy client that forwards advanced interactions between a remote MCP server and the proxy's connected clients. @@ -1018,6 +1140,10 @@ def __init__( kwargs["log_handler"] = default_proxy_log_handler if "progress_handler" not in kwargs: kwargs["progress_handler"] = default_proxy_progress_handler + kwargs["message_handler"] = _compose_proxy_message_handler( + kwargs.get("message_handler"), + forward_logging_and_progress=False, + ) super().__init__(**kwargs | {"transport": transport}) # Enable forwarding of inbound HTTP headers (e.g. authorization) to @@ -1071,6 +1197,13 @@ def __init__(self, *args: Any, **kwargs: Any): ): if key not in kwargs: kwargs[key] = _make_restoring_handler(default_fn, self._proxy_rc_ref) + kwargs["message_handler"] = _make_restoring_handler( + _compose_proxy_message_handler( + kwargs.get("message_handler"), + forward_logging_and_progress=False, + ), + self._proxy_rc_ref, + ) super().__init__(*args, **kwargs) self._caches: dict[ServerSession, Client[ClientTransportT]] = {} diff --git a/tests/server/providers/proxy/test_proxy_server.py b/tests/server/providers/proxy/test_proxy_server.py index 3e9650ec4..fa4c41f0f 100644 --- a/tests/server/providers/proxy/test_proxy_server.py +++ b/tests/server/providers/proxy/test_proxy_server.py @@ -18,6 +18,7 @@ from fastmcp.exceptions import ToolError from fastmcp.resources import ResourceContent, ResourceResult from fastmcp.server import create_proxy +from fastmcp.server.context import Context from fastmcp.server.providers.proxy import ( FastMCPProxy, ProxyClient, @@ -171,6 +172,105 @@ async def test_create_proxy_with_client(fastmcp_server): assert server.name.startswith("FastMCPProxy-") +async def test_create_proxy_forwards_upstream_server_notifications_from_client(): + backend = FastMCP("NotificationBackend") + + @backend.tool + async def emit_notifications(ctx: Context) -> dict[str, bool]: + await ctx.send_notification(mcp_types.ToolListChangedNotification()) + await ctx.send_notification(mcp_types.ResourceListChangedNotification()) + await ctx.send_notification(mcp_types.PromptListChangedNotification()) + await ctx.send_notification( + mcp_types.ResourceUpdatedNotification( + params=mcp_types.ResourceUpdatedNotificationParams( + uri=AnyUrl("resource://backend/example"), + ), + ), + ) + await ctx.info("backend emitted notifications") + return {"emitted": True} + + upstream_seen: list[str] = [] + downstream_seen: list[str] = [] + + async def upstream_handler(message: object) -> None: + if isinstance(message, mcp_types.ServerNotification): + upstream_seen.append(message.root.method) + + async def downstream_handler(message: object) -> None: + if isinstance(message, mcp_types.ServerNotification): + downstream_seen.append(message.root.method) + + upstream_client = Client( + FastMCPTransport(backend), + message_handler=upstream_handler, + ) + proxy = create_proxy(upstream_client) + + async with Client(proxy, message_handler=downstream_handler) as client: + await client.call_tool("emit_notifications", {}) + + expected = [ + "notifications/tools/list_changed", + "notifications/resources/list_changed", + "notifications/prompts/list_changed", + "notifications/resources/updated", + "notifications/message", + ] + assert upstream_seen == expected + assert downstream_seen == expected + + +async def test_create_proxy_forwards_notifications_from_connected_plain_client(): + backend = FastMCP("NotificationBackend") + + @backend.tool + async def emit_notifications(ctx: Context) -> dict[str, bool]: + await ctx.send_notification(mcp_types.ToolListChangedNotification()) + await ctx.send_notification(mcp_types.ResourceListChangedNotification()) + await ctx.send_notification(mcp_types.PromptListChangedNotification()) + await ctx.send_notification( + mcp_types.ResourceUpdatedNotification( + params=mcp_types.ResourceUpdatedNotificationParams( + uri=AnyUrl("resource://backend/example"), + ), + ), + ) + await ctx.info("backend emitted notifications") + return {"emitted": True} + + upstream_seen: list[str] = [] + downstream_seen: list[str] = [] + + async def upstream_handler(message: object) -> None: + if isinstance(message, mcp_types.ServerNotification): + upstream_seen.append(message.root.method) + + async def downstream_handler(message: object) -> None: + if isinstance(message, mcp_types.ServerNotification): + downstream_seen.append(message.root.method) + + upstream_client = Client( + FastMCPTransport(backend), + message_handler=upstream_handler, + ) + + async with upstream_client: + proxy = create_proxy(upstream_client) + async with Client(proxy, message_handler=downstream_handler) as client: + await client.call_tool("emit_notifications", {}) + + expected = [ + "notifications/tools/list_changed", + "notifications/resources/list_changed", + "notifications/prompts/list_changed", + "notifications/resources/updated", + "notifications/message", + ] + assert upstream_seen == expected + assert downstream_seen == expected + + async def test_create_proxy_with_server(fastmcp_server): """create_proxy should accept a FastMCP instance.""" proxy = create_proxy(fastmcp_server)