Skip to content

Commit da54ea0

Browse files
jkawamotoKludex
andauthored
Allow generic parameters to be passed onto Context on FastMCP tools
Co-authored-by: Marcelo Trylesinski <[email protected]>
1 parent d6e611f commit da54ea0

File tree

2 files changed

+15
-11
lines changed

2 files changed

+15
-11
lines changed

src/mcp/server/fastmcp/tools/base.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import inspect
44
from collections.abc import Callable
5-
from typing import TYPE_CHECKING, Any
5+
from typing import TYPE_CHECKING, Any, get_origin
66

77
from pydantic import BaseModel, Field
88

@@ -53,7 +53,9 @@ def from_function(
5353
if context_kwarg is None:
5454
sig = inspect.signature(fn)
5555
for param_name, param in sig.parameters.items():
56-
if param.annotation is Context:
56+
if get_origin(param.annotation) is not None:
57+
continue
58+
if issubclass(param.annotation, Context):
5759
context_kwarg = param_name
5860
break
5961

tests/server/fastmcp/test_tool_manager.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44
import pytest
55
from pydantic import BaseModel
66

7+
from mcp.server.fastmcp import Context, FastMCP
78
from mcp.server.fastmcp.exceptions import ToolError
89
from mcp.server.fastmcp.tools import ToolManager
10+
from mcp.server.session import ServerSessionT
11+
from mcp.shared.context import LifespanContextT
912

1013

1114
class TestAddTools:
@@ -194,8 +197,6 @@ def concat_strs(vals: list[str] | str) -> str:
194197

195198
@pytest.mark.anyio
196199
async def test_call_tool_with_complex_model(self):
197-
from mcp.server.fastmcp import Context
198-
199200
class MyShrimpTank(BaseModel):
200201
class Shrimp(BaseModel):
201202
name: str
@@ -223,8 +224,6 @@ def name_shrimp(tank: MyShrimpTank, ctx: Context) -> list[str]:
223224
class TestToolSchema:
224225
@pytest.mark.anyio
225226
async def test_context_arg_excluded_from_schema(self):
226-
from mcp.server.fastmcp import Context
227-
228227
def something(a: int, ctx: Context) -> int:
229228
return a
230229

@@ -241,7 +240,6 @@ class TestContextHandling:
241240
def test_context_parameter_detection(self):
242241
"""Test that context parameters are properly detected in
243242
Tool.from_function()."""
244-
from mcp.server.fastmcp import Context
245243

246244
def tool_with_context(x: int, ctx: Context) -> str:
247245
return str(x)
@@ -256,10 +254,17 @@ def tool_without_context(x: int) -> str:
256254
tool = manager.add_tool(tool_without_context)
257255
assert tool.context_kwarg is None
258256

257+
def tool_with_parametrized_context(
258+
x: int, ctx: Context[ServerSessionT, LifespanContextT]
259+
) -> str:
260+
return str(x)
261+
262+
tool = manager.add_tool(tool_with_parametrized_context)
263+
assert tool.context_kwarg == "ctx"
264+
259265
@pytest.mark.anyio
260266
async def test_context_injection(self):
261267
"""Test that context is properly injected during tool execution."""
262-
from mcp.server.fastmcp import Context, FastMCP
263268

264269
def tool_with_context(x: int, ctx: Context) -> str:
265270
assert isinstance(ctx, Context)
@@ -276,7 +281,6 @@ def tool_with_context(x: int, ctx: Context) -> str:
276281
@pytest.mark.anyio
277282
async def test_context_injection_async(self):
278283
"""Test that context is properly injected in async tools."""
279-
from mcp.server.fastmcp import Context, FastMCP
280284

281285
async def async_tool(x: int, ctx: Context) -> str:
282286
assert isinstance(ctx, Context)
@@ -293,7 +297,6 @@ async def async_tool(x: int, ctx: Context) -> str:
293297
@pytest.mark.anyio
294298
async def test_context_optional(self):
295299
"""Test that context is optional when calling tools."""
296-
from mcp.server.fastmcp import Context
297300

298301
def tool_with_context(x: int, ctx: Context | None = None) -> str:
299302
return str(x)
@@ -307,7 +310,6 @@ def tool_with_context(x: int, ctx: Context | None = None) -> str:
307310
@pytest.mark.anyio
308311
async def test_context_error_handling(self):
309312
"""Test error handling when context injection fails."""
310-
from mcp.server.fastmcp import Context, FastMCP
311313

312314
def tool_with_context(x: int, ctx: Context) -> str:
313315
raise ValueError("Test error")

0 commit comments

Comments
 (0)