diff --git a/libs/arcade-mcp-server/arcade_mcp_server/context.py b/libs/arcade-mcp-server/arcade_mcp_server/context.py index 0bb6fd87d..2a51d1551 100644 --- a/libs/arcade-mcp-server/arcade_mcp_server/context.py +++ b/libs/arcade-mcp-server/arcade_mcp_server/context.py @@ -533,7 +533,7 @@ async def list(self) -> list[Any]: return cast(list[Any], prompts) async def get(self, name: str, arguments: dict[str, str] | None = None) -> Any: - return await self._ctx.server._prompt_manager.get_prompt(name, arguments) + return await self._ctx.server._prompt_manager.get_prompt(name, arguments, self._ctx) class Sampling(_ContextComponent): diff --git a/libs/arcade-mcp-server/arcade_mcp_server/managers/prompt.py b/libs/arcade-mcp-server/arcade_mcp_server/managers/prompt.py index b0c6d0241..1000c1b24 100644 --- a/libs/arcade-mcp-server/arcade_mcp_server/managers/prompt.py +++ b/libs/arcade-mcp-server/arcade_mcp_server/managers/prompt.py @@ -6,32 +6,170 @@ from __future__ import annotations +import inspect import logging -from typing import Callable +import types +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + Callable, + Union, + cast, + get_args, + get_origin, + get_type_hints, +) from arcade_mcp_server.exceptions import NotFoundError, PromptError from arcade_mcp_server.managers.base import ComponentManager from arcade_mcp_server.types import GetPromptResult, Prompt, PromptMessage +if TYPE_CHECKING: + from arcade_mcp_server.context import Context + logger = logging.getLogger("arcade.mcp.managers.prompt") +# Type aliases for prompt handler signatures +PromptHandlerLegacy = Callable[[dict[str, str]], list[PromptMessage]] +PromptHandlerWithContext = Callable[["Context", dict[str, str]], list[PromptMessage]] +PromptHandlerType = Union[PromptHandlerLegacy, PromptHandlerWithContext] + class PromptHandler: - """Handler for generating prompt messages.""" + """Handler for generating prompt messages. + + Supports two handler signatures: + 1. Legacy: handler(args: dict[str, str]) -> list[PromptMessage] + 2. New (with context): handler(context: Context, args: dict[str, str]) -> list[PromptMessage] + + The handler signature is detected automatically using introspection. + """ def __init__( self, prompt: Prompt, - handler: Callable[[dict[str, str]], list[PromptMessage]] | None = None, + handler: PromptHandlerType | None = None, ) -> None: self.prompt = prompt - self.handler = handler or self._default_handler + self.handler: Any = handler or self._default_handler + self._accepts_context = self._check_context_signature(self.handler) def __eq__(self, other: object) -> bool: # pragma: no cover - simple comparison if not isinstance(other, PromptHandler): return False return self.prompt == other.prompt and self.handler == other.handler + def _check_context_signature(self, handler: Any) -> bool: + """Check if handler accepts context parameter. + + Returns True if the first parameter is type-annotated as Context or named "context" + without a conflicting type annotation. Returns False for legacy signatures. + + Examples: + - handler(context: Context, args) -> True (typed context) + - handler(context, args) -> True (untyped context) + - handler(context: dict[str, str]) -> False (legacy with misleading name) + - handler(args) -> False (legacy) + """ + from arcade_mcp_server.context import Context as ArcadeContext + + def _is_context_annotation(ann: Any) -> bool: + """Return True only for the actual Context type (or Optional/Union/Annotated wrappers). + + Important: do NOT do substring matching on string annotations. That can produce false + positives for unrelated types like ContextManager/ExecutionContext/etc. + """ + if ann is ArcadeContext: + return True + + # Real class annotations (including subclasses). + if isinstance(ann, type) and issubclass(ann, ArcadeContext): + return True + + # Unwrap common typing wrappers. + origin = get_origin(ann) + if origin is Annotated: + args = get_args(ann) + return _is_context_annotation(args[0]) if args else False + + if origin is Union or origin is types.UnionType: + return any(_is_context_annotation(a) for a in get_args(ann)) + + # Conservative fallback for unresolved forward refs (strings). + if isinstance(ann, str): + s = ann.strip().strip("'\"") + + # Handle PEP604 unions in string form: "Context | None" + if "|" in s: + return any(_is_context_annotation(part.strip()) for part in s.split("|")) + + # Handle Optional/Union/Annotated in string form. We only unwrap these; + # we intentionally do NOT look inside arbitrary generics like ContextManager[...]. + for wrapper in ("Optional[", "Union[", "Annotated["): + if s.startswith(wrapper) and s.endswith("]"): + inner = s[len(wrapper) : -1].strip() + if wrapper == "Union[": + return any(_is_context_annotation(p.strip()) for p in inner.split(",")) + if wrapper == "Annotated[": + first = inner.split(",", 1)[0].strip() + return _is_context_annotation(first) + # Optional[ + return _is_context_annotation(inner) + + # Accept only the actual arcade_mcp_server Context name(s). + return s in {"Context", "arcade_mcp_server.context.Context", "arcade_mcp_server.Context"} + + return False + + try: + sig = inspect.signature(handler) + params = list(sig.parameters.values()) + # Filter out 'self' parameter for bound methods + params = [p for p in params if p.name != "self"] + + if not params: + return False + + first_param = params[0] + + # Check if first parameter is type-annotated + if first_param.annotation != inspect.Parameter.empty: + ann: Any = first_param.annotation + + # Prefer resolving type hints (handles forward refs, Optional/Union, etc.) + try: + import arcade_mcp_server + + globalns = getattr(handler, "__globals__", {}) or {} + if "arcade_mcp_server" not in globalns: + # Avoid mutating handler globals in-place. + globalns = dict(globalns) + globalns["arcade_mcp_server"] = arcade_mcp_server + + hints = get_type_hints( + handler, + globalns=globalns, + localns={"Context": ArcadeContext}, + include_extras=True, + ) + ann = hints.get(first_param.name, ann) + except Exception: + # Fall back to raw signature annotation. + logger.debug( + "Failed to resolve prompt handler type hints; falling back to raw signature annotations", + exc_info=True, + ) + ann = first_param.annotation + + return _is_context_annotation(ann) + else: + # No type annotation - check if named "context" (untyped context handler) + return first_param.name == "context" + except (ValueError, TypeError): + # If we can't inspect, assume legacy signature + return False + def _default_handler(self, arguments: dict[str, str]) -> list[PromptMessage]: return [ PromptMessage( @@ -43,7 +181,11 @@ def _default_handler(self, arguments: dict[str, str]) -> list[PromptMessage]: ) ] - async def get_messages(self, arguments: dict[str, str] | None = None) -> list[PromptMessage]: + async def get_messages( + self, + arguments: dict[str, str] | None = None, + context: Context | None = None, + ) -> list[PromptMessage]: args = arguments or {} # Validate required arguments @@ -52,11 +194,20 @@ async def get_messages(self, arguments: dict[str, str] | None = None) -> list[Pr if arg.required and arg.name not in args: raise PromptError(f"Required argument '{arg.name}' not provided") - result = self.handler(args) + # Call handler with appropriate signature + result: Any + if self._accepts_context: + if context is None: + raise PromptError("Handler requires context but none was provided") + result = self.handler(context, args) + else: + result = self.handler(args) + if hasattr(result, "__await__"): result = await result - return result + # Cast result to expected type after dynamic invocation + return cast(list[PromptMessage], result) class PromptManager(ComponentManager[str, PromptHandler]): @@ -72,7 +223,10 @@ async def list_prompts(self) -> list[Prompt]: return [h.prompt for h in handlers] async def get_prompt( - self, name: str, arguments: dict[str, str] | None = None + self, + name: str, + arguments: dict[str, str] | None = None, + context: Context | None = None, ) -> GetPromptResult: try: handler = await self.registry.get(name) @@ -80,7 +234,7 @@ async def get_prompt( raise NotFoundError(f"Prompt '{name}' not found") try: - messages = await handler.get_messages(arguments) + messages = await handler.get_messages(arguments, context) return GetPromptResult( description=handler.prompt.description, messages=messages, @@ -93,7 +247,7 @@ async def get_prompt( async def add_prompt( self, prompt: Prompt, - handler: Callable[[dict[str, str]], list[PromptMessage]] | None = None, + handler: PromptHandlerType | None = None, ) -> None: prompt_handler = PromptHandler(prompt, handler) await self.registry.upsert(prompt.name, prompt_handler) @@ -109,7 +263,7 @@ async def update_prompt( self, name: str, prompt: Prompt, - handler: Callable[[dict[str, str]], list[PromptMessage]] | None = None, + handler: PromptHandlerType | None = None, ) -> Prompt: # Ensure exists try: diff --git a/libs/arcade-mcp-server/arcade_mcp_server/server.py b/libs/arcade-mcp-server/arcade_mcp_server/server.py index 6ce1ccddf..7f5cb504a 100644 --- a/libs/arcade-mcp-server/arcade_mcp_server/server.py +++ b/libs/arcade-mcp-server/arcade_mcp_server/server.py @@ -1337,9 +1337,11 @@ async def _handle_get_prompt( ) -> JSONRPCResponse[GetPromptResult] | JSONRPCError: """Handle get prompt request.""" try: + context = get_current_model_context() result = await self._prompt_manager.get_prompt( message.params.name, message.params.arguments if hasattr(message.params, "arguments") else None, + context, ) return JSONRPCResponse(id=message.id, result=result) except NotFoundError: diff --git a/libs/arcade-mcp-server/pyproject.toml b/libs/arcade-mcp-server/pyproject.toml index 262b21c9b..600e217a8 100644 --- a/libs/arcade-mcp-server/pyproject.toml +++ b/libs/arcade-mcp-server/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "arcade-mcp-server" -version = "1.13.0" +version = "1.14.1" description = "Model Context Protocol (MCP) server framework for Arcade.dev" readme = "README.md" authors = [{ name = "Arcade.dev" }] diff --git a/libs/tests/arcade_mcp_server/test_prompt.py b/libs/tests/arcade_mcp_server/test_prompt.py index cf2f69814..b762ec775 100644 --- a/libs/tests/arcade_mcp_server/test_prompt.py +++ b/libs/tests/arcade_mcp_server/test_prompt.py @@ -1,8 +1,10 @@ """Tests for Prompt Manager implementation.""" import asyncio +from unittest.mock import Mock import pytest +from arcade_mcp_server.context import Context from arcade_mcp_server.exceptions import NotFoundError, PromptError from arcade_mcp_server.managers.prompt import PromptManager from arcade_mcp_server.types import ( @@ -37,7 +39,7 @@ def sample_prompt(self): @pytest.fixture def prompt_function(self): - """Create a prompt function.""" + """Create a prompt function (legacy signature without context).""" async def greeting_prompt(args: dict[str, str]) -> list[PromptMessage]: name = args.get("name", "") @@ -53,6 +55,45 @@ async def greeting_prompt(args: dict[str, str]) -> list[PromptMessage]: return greeting_prompt + @pytest.fixture + def prompt_function_with_context(self): + """Create a prompt function with context parameter (new signature).""" + + async def greeting_prompt_with_context( + context: Context, args: dict[str, str] + ) -> list[PromptMessage]: + name = args.get("name", "") + formal_arg = args.get("formal", "false") + formal = str(formal_arg).lower() == "true" + + # Access context (e.g., for logging) + if hasattr(context, "log"): + await context.log.info(f"Generating greeting for {name}") + + if formal: + text = f"Good day, {name}. How may I assist you?" + else: + text = f"Hey {name}! What's up?" + + return [PromptMessage(role="assistant", content={"type": "text", "text": text})] + + return greeting_prompt_with_context + + @pytest.fixture + def mock_context(self): + """Create a mock context.""" + mock_server = Mock() + context = Context(mock_server) + # Mock the log interface with async methods + mock_log = Mock() + + async def async_info(*args, **kwargs): + pass + + mock_log.info = async_info + context._log = mock_log + return context + def test_manager_initialization(self): """Test prompt manager initialization.""" manager = PromptManager() @@ -239,3 +280,116 @@ async def error_prompt(args: dict[str, str]): with pytest.raises(PromptError): await manager.get_prompt("error_prompt", {}) + + @pytest.mark.asyncio + async def test_prompt_with_context_parameter( + self, prompt_manager, sample_prompt, prompt_function_with_context, mock_context + ): + """Test prompt with new context parameter signature.""" + await prompt_manager.add_prompt(sample_prompt, prompt_function_with_context) + + result = await prompt_manager.get_prompt( + "greeting", {"name": "Alice", "formal": "true"}, mock_context + ) + + assert isinstance(result, GetPromptResult) + assert len(result.messages) == 1 + assert result.messages[0].role == "assistant" + assert "Good day, Alice" in result.messages[0].content["text"] + + @pytest.mark.asyncio + async def test_prompt_with_context_logging( + self, prompt_manager, sample_prompt, prompt_function_with_context, mock_context + ): + """Test that prompt with context can use logging.""" + await prompt_manager.add_prompt(sample_prompt, prompt_function_with_context) + + await prompt_manager.get_prompt("greeting", {"name": "Bob"}, mock_context) + + # Verify logging was called (if mock was set up properly) + # This would require a more sophisticated mock setup + + @pytest.mark.asyncio + async def test_prompt_context_required_but_not_provided( + self, prompt_manager, sample_prompt, prompt_function_with_context + ): + """Test that error is raised when context-requiring prompt doesn't get context.""" + await prompt_manager.add_prompt(sample_prompt, prompt_function_with_context) + + with pytest.raises(PromptError, match="Handler requires context"): + await prompt_manager.get_prompt("greeting", {"name": "Alice"}, None) + + @pytest.mark.asyncio + async def test_backward_compatibility_legacy_signature( + self, prompt_manager, sample_prompt, prompt_function + ): + """Test backward compatibility with legacy signature (no context).""" + await prompt_manager.add_prompt(sample_prompt, prompt_function) + + # Should work without context + result = await prompt_manager.get_prompt("greeting", {"name": "Charlie"}, None) + + assert isinstance(result, GetPromptResult) + assert len(result.messages) == 1 + assert "Hey Charlie!" in result.messages[0].content["text"] + + @pytest.mark.asyncio + async def test_mixed_signatures( + self, prompt_manager, prompt_function, prompt_function_with_context, mock_context + ): + """Test that both signatures can coexist.""" + prompt1 = Prompt(name="legacy", description="Legacy prompt") + prompt2 = Prompt(name="new", description="New prompt with context") + + await prompt_manager.add_prompt(prompt1, prompt_function) + await prompt_manager.add_prompt(prompt2, prompt_function_with_context) + + # Legacy prompt works without context + result1 = await prompt_manager.get_prompt( + "legacy", {"name": "Dave", "formal": "false"}, None + ) + assert "Hey Dave!" in result1.messages[0].content["text"] + + # New prompt works with context + result2 = await prompt_manager.get_prompt( + "new", {"name": "Eve", "formal": "true"}, mock_context + ) + assert "Good day, Eve" in result2.messages[0].content["text"] + + @pytest.mark.asyncio + async def test_sync_prompt_function_with_context(self, prompt_manager, mock_context): + """Test synchronous prompt function with context parameter.""" + prompt = Prompt(name="sync_prompt", description="Sync prompt with context") + + def sync_prompt(context: Context, args: dict[str, str]) -> list[PromptMessage]: + name = args.get("name", "User") + return [ + PromptMessage(role="user", content={"type": "text", "text": f"Hello {name}!"}) + ] + + await prompt_manager.add_prompt(prompt, sync_prompt) + + result = await prompt_manager.get_prompt("sync_prompt", {"name": "Frank"}, mock_context) + + assert isinstance(result, GetPromptResult) + assert len(result.messages) == 1 + assert "Hello Frank!" in result.messages[0].content["text"] + + @pytest.mark.asyncio + async def test_sync_prompt_function_without_context(self, prompt_manager): + """Test synchronous prompt function without context (legacy).""" + prompt = Prompt(name="sync_legacy", description="Sync legacy prompt") + + def sync_legacy_prompt(args: dict[str, str]) -> list[PromptMessage]: + name = args.get("name", "User") + return [ + PromptMessage(role="user", content={"type": "text", "text": f"Hi {name}!"}) + ] + + await prompt_manager.add_prompt(prompt, sync_legacy_prompt) + + result = await prompt_manager.get_prompt("sync_legacy", {"name": "Grace"}, None) + + assert isinstance(result, GetPromptResult) + assert len(result.messages) == 1 + assert "Hi Grace!" in result.messages[0].content["text"]