Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 137 additions & 4 deletions fastmcp_slim/fastmcp/server/providers/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@

import mcp.types
from mcp import ServerSession
from mcp.client.session import ClientSession
from mcp.client.session import ClientSession, MessageHandlerFnT
from mcp.server.lowlevel.server import request_ctx
from mcp.shared.context import LifespanContextT, RequestContext
from mcp.shared.exceptions import McpError
from mcp.shared.session import RequestResponder
from mcp.types import (
METHOD_NOT_FOUND,
BlobResourceContents,
Expand Down Expand Up @@ -59,6 +60,11 @@

# Type alias for client factory functions
ClientFactoryT = Callable[[], Client] | Callable[[], Awaitable[Client]]
ProxyMessageT = (
RequestResponder[mcp.types.ServerRequest, mcp.types.ClientResult]
| mcp.types.ServerNotification
| Exception
)


# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -132,8 +138,9 @@ async def run(
# its receive-loop task has stale ContextVars from the first
# request. Stash the current RequestContext in the shared
# ref so handlers can restore it before forwarding.
if isinstance(client, StatefulProxyClient):
client._proxy_rc_ref[0] = (
proxy_rc_ref = getattr(client, "_proxy_rc_ref", None)
if proxy_rc_ref is not None:
proxy_rc_ref[0] = (
ctx.request_context,
ctx._fastmcp, # weakref to FastMCP, not the Context
)
Expand Down Expand Up @@ -793,12 +800,17 @@ def fresh_client_factory() -> Client:
)

def reuse_client_factory() -> Client:
if not isinstance(client, ProxyClient):
_install_plain_client_proxy_message_handler(client)
return client
Comment on lines +803 to 805
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Apply proxy message handler before reusing connected client

When create_proxy() receives a connected plain Client, this branch calls _install_plain_client_proxy_message_handler(client) and then reuses the same connection, but _install_* only mutates client._session_kwargs. For an already-open session, connect_session(**_session_kwargs) has already run, so the active receive loop keeps the old handler and upstream notifications are still not forwarded in this mode. This means connected-client proxying does not get the behavior this change intends unless the client disconnects/reconnects first.

Useful? React with 👍 / 👎.


return reuse_client_factory

def fresh_client_factory() -> Client:
return client.new()
fresh_client = client.new()
if not isinstance(fresh_client, ProxyClient):
_install_plain_client_proxy_message_handler(fresh_client)
return fresh_client

return fresh_client_factory
else:
Expand Down Expand Up @@ -980,14 +992,124 @@ def _make_restoring_handler(handler: Callable, rc_ref: list[Any]) -> Callable:
``inspect.isfunction()`` checks in handler registration paths
(e.g., ``create_roots_callback``).
"""
if getattr(handler, "_fastmcp_proxy_restores_context", False):
return handler

async def wrapper(*args: Any, **kwargs: Any) -> Any:
_restore_request_context(rc_ref)
return await handler(*args, **kwargs)

wrapper_with_flags = cast(Any, wrapper)
wrapper_with_flags._fastmcp_proxy_restores_context = True
if getattr(handler, "_fastmcp_proxy_message_handler", False):
wrapper_with_flags._fastmcp_proxy_message_handler = True
wrapper_with_flags._fastmcp_proxy_forwards_logging = getattr(
handler,
"_fastmcp_proxy_forwards_logging",
False,
)

return wrapper


async def default_proxy_message_handler(
message: object,
*,
forward_logging_and_progress: bool = False,
) -> None:
if not isinstance(message, mcp.types.ServerNotification):
return

forwarded_types = (
mcp.types.ToolListChangedNotification,
mcp.types.ResourceListChangedNotification,
mcp.types.PromptListChangedNotification,
mcp.types.ResourceUpdatedNotification,
)
if forward_logging_and_progress:
forwarded_types = (
*forwarded_types,
mcp.types.LoggingMessageNotification,
mcp.types.ProgressNotification,
)

if not isinstance(message.root, forwarded_types):
return

try:
ctx = get_context()
except RuntimeError:
logger.debug(
"Dropping upstream server notification outside a proxy request: %s",
message.root.method,
)
return

root_data = message.root.model_dump(
by_alias=True,
mode="json",
exclude_none=True,
)
root_data.pop("jsonrpc", None)
forwarded_root = type(message.root).model_validate(root_data)

await ctx.session.send_notification(
mcp.types.ServerNotification(forwarded_root),
related_request_id=ctx.request_id,
)


def _compose_proxy_message_handler(
user_handler: MessageHandlerFnT | None,
*,
forward_logging_and_progress: bool,
) -> MessageHandlerFnT:
if (
user_handler is not None
and getattr(user_handler, "_fastmcp_proxy_message_handler", False)
and (
not forward_logging_and_progress
or getattr(user_handler, "_fastmcp_proxy_forwards_logging", False)
)
):
return user_handler

async def message_handler(message: ProxyMessageT) -> None:
if user_handler is not None:
await user_handler(message)
await default_proxy_message_handler(
message,
forward_logging_and_progress=forward_logging_and_progress,
)

message_handler_with_flags = cast(Any, message_handler)
message_handler_with_flags._fastmcp_proxy_message_handler = True
message_handler_with_flags._fastmcp_proxy_forwards_logging = (
forward_logging_and_progress
)
return cast(MessageHandlerFnT, message_handler)


def _install_plain_client_proxy_message_handler(client: Client) -> None:
proxy_rc_ref = getattr(client, "_proxy_rc_ref", None)
if proxy_rc_ref is None:
proxy_rc_ref = [None]
cast(Any, client)._proxy_rc_ref = proxy_rc_ref

current_handler = client._session_kwargs.get("message_handler")
message_handler = _compose_proxy_message_handler(
current_handler,
forward_logging_and_progress=True,
)
message_handler = cast(
MessageHandlerFnT,
_make_restoring_handler(message_handler, proxy_rc_ref),
)
client._session_kwargs["message_handler"] = message_handler
if client.is_connected():
cast(Any, client.session)._message_handler = message_handler


class ProxyClient(Client[ClientTransportT]):
"""A proxy client that forwards advanced interactions between a remote MCP server and the proxy's connected clients.

Expand Down Expand Up @@ -1018,6 +1140,10 @@ def __init__(
kwargs["log_handler"] = default_proxy_log_handler
if "progress_handler" not in kwargs:
kwargs["progress_handler"] = default_proxy_progress_handler
kwargs["message_handler"] = _compose_proxy_message_handler(
kwargs.get("message_handler"),
forward_logging_and_progress=False,
)
super().__init__(**kwargs | {"transport": transport})

# Enable forwarding of inbound HTTP headers (e.g. authorization) to
Expand Down Expand Up @@ -1071,6 +1197,13 @@ def __init__(self, *args: Any, **kwargs: Any):
):
if key not in kwargs:
kwargs[key] = _make_restoring_handler(default_fn, self._proxy_rc_ref)
kwargs["message_handler"] = _make_restoring_handler(
_compose_proxy_message_handler(
kwargs.get("message_handler"),
forward_logging_and_progress=False,
),
self._proxy_rc_ref,
)

super().__init__(*args, **kwargs)
self._caches: dict[ServerSession, Client[ClientTransportT]] = {}
Expand Down
100 changes: 100 additions & 0 deletions tests/server/providers/proxy/test_proxy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from fastmcp.exceptions import ToolError
from fastmcp.resources import ResourceContent, ResourceResult
from fastmcp.server import create_proxy
from fastmcp.server.context import Context
from fastmcp.server.providers.proxy import (
FastMCPProxy,
ProxyClient,
Expand Down Expand Up @@ -171,6 +172,105 @@ async def test_create_proxy_with_client(fastmcp_server):
assert server.name.startswith("FastMCPProxy-")


async def test_create_proxy_forwards_upstream_server_notifications_from_client():
backend = FastMCP("NotificationBackend")

@backend.tool
async def emit_notifications(ctx: Context) -> dict[str, bool]:
await ctx.send_notification(mcp_types.ToolListChangedNotification())
await ctx.send_notification(mcp_types.ResourceListChangedNotification())
await ctx.send_notification(mcp_types.PromptListChangedNotification())
await ctx.send_notification(
mcp_types.ResourceUpdatedNotification(
params=mcp_types.ResourceUpdatedNotificationParams(
uri=AnyUrl("resource://backend/example"),
),
),
)
await ctx.info("backend emitted notifications")
return {"emitted": True}

upstream_seen: list[str] = []
downstream_seen: list[str] = []

async def upstream_handler(message: object) -> None:
if isinstance(message, mcp_types.ServerNotification):
upstream_seen.append(message.root.method)

async def downstream_handler(message: object) -> None:
if isinstance(message, mcp_types.ServerNotification):
downstream_seen.append(message.root.method)

upstream_client = Client(
FastMCPTransport(backend),
message_handler=upstream_handler,
)
proxy = create_proxy(upstream_client)

async with Client(proxy, message_handler=downstream_handler) as client:
await client.call_tool("emit_notifications", {})

expected = [
"notifications/tools/list_changed",
"notifications/resources/list_changed",
"notifications/prompts/list_changed",
"notifications/resources/updated",
"notifications/message",
]
assert upstream_seen == expected
assert downstream_seen == expected


async def test_create_proxy_forwards_notifications_from_connected_plain_client():
backend = FastMCP("NotificationBackend")

@backend.tool
async def emit_notifications(ctx: Context) -> dict[str, bool]:
await ctx.send_notification(mcp_types.ToolListChangedNotification())
await ctx.send_notification(mcp_types.ResourceListChangedNotification())
await ctx.send_notification(mcp_types.PromptListChangedNotification())
await ctx.send_notification(
mcp_types.ResourceUpdatedNotification(
params=mcp_types.ResourceUpdatedNotificationParams(
uri=AnyUrl("resource://backend/example"),
),
),
)
await ctx.info("backend emitted notifications")
return {"emitted": True}

upstream_seen: list[str] = []
downstream_seen: list[str] = []

async def upstream_handler(message: object) -> None:
if isinstance(message, mcp_types.ServerNotification):
upstream_seen.append(message.root.method)

async def downstream_handler(message: object) -> None:
if isinstance(message, mcp_types.ServerNotification):
downstream_seen.append(message.root.method)

upstream_client = Client(
FastMCPTransport(backend),
message_handler=upstream_handler,
)

async with upstream_client:
proxy = create_proxy(upstream_client)
async with Client(proxy, message_handler=downstream_handler) as client:
await client.call_tool("emit_notifications", {})

expected = [
"notifications/tools/list_changed",
"notifications/resources/list_changed",
"notifications/prompts/list_changed",
"notifications/resources/updated",
"notifications/message",
]
assert upstream_seen == expected
assert downstream_seen == expected


async def test_create_proxy_with_server(fastmcp_server):
"""create_proxy should accept a FastMCP instance."""
proxy = create_proxy(fastmcp_server)
Expand Down