Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion libs/arcade-mcp-server/arcade_mcp_server/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ async def list(self) -> list[Any]:
return cast(list[Any], prompts)

async def get(self, name: str, arguments: dict[str, str] | None = None) -> Any:
return await self._ctx.server._prompt_manager.get_prompt(name, arguments)
return await self._ctx.server._prompt_manager.get_prompt(name, arguments, self._ctx)


class Sampling(_ContextComponent):
Expand Down
176 changes: 165 additions & 11 deletions libs/arcade-mcp-server/arcade_mcp_server/managers/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,170 @@

from __future__ import annotations

import inspect
import logging
from typing import Callable
import types
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Callable,
Union,
cast,
get_args,
get_origin,
get_type_hints,
)

from arcade_mcp_server.exceptions import NotFoundError, PromptError
from arcade_mcp_server.managers.base import ComponentManager
from arcade_mcp_server.types import GetPromptResult, Prompt, PromptMessage

if TYPE_CHECKING:
from arcade_mcp_server.context import Context

logger = logging.getLogger("arcade.mcp.managers.prompt")

# Type aliases for prompt handler signatures
PromptHandlerLegacy = Callable[[dict[str, str]], list[PromptMessage]]
PromptHandlerWithContext = Callable[["Context", dict[str, str]], list[PromptMessage]]
PromptHandlerType = Union[PromptHandlerLegacy, PromptHandlerWithContext]


class PromptHandler:
"""Handler for generating prompt messages."""
"""Handler for generating prompt messages.

Supports two handler signatures:
1. Legacy: handler(args: dict[str, str]) -> list[PromptMessage]
2. New (with context): handler(context: Context, args: dict[str, str]) -> list[PromptMessage]

The handler signature is detected automatically using introspection.
"""

def __init__(
self,
prompt: Prompt,
handler: Callable[[dict[str, str]], list[PromptMessage]] | None = None,
handler: PromptHandlerType | None = None,
) -> None:
self.prompt = prompt
self.handler = handler or self._default_handler
self.handler: Any = handler or self._default_handler
self._accepts_context = self._check_context_signature(self.handler)

def __eq__(self, other: object) -> bool: # pragma: no cover - simple comparison
if not isinstance(other, PromptHandler):
return False
return self.prompt == other.prompt and self.handler == other.handler

def _check_context_signature(self, handler: Any) -> bool:
"""Check if handler accepts context parameter.

Returns True if the first parameter is type-annotated as Context or named "context"
without a conflicting type annotation. Returns False for legacy signatures.

Examples:
- handler(context: Context, args) -> True (typed context)
- handler(context, args) -> True (untyped context)
- handler(context: dict[str, str]) -> False (legacy with misleading name)
- handler(args) -> False (legacy)
"""
from arcade_mcp_server.context import Context as ArcadeContext

def _is_context_annotation(ann: Any) -> bool:
"""Return True only for the actual Context type (or Optional/Union/Annotated wrappers).

Important: do NOT do substring matching on string annotations. That can produce false
positives for unrelated types like ContextManager/ExecutionContext/etc.
"""
if ann is ArcadeContext:
return True

# Real class annotations (including subclasses).
if isinstance(ann, type) and issubclass(ann, ArcadeContext):
return True

# Unwrap common typing wrappers.
origin = get_origin(ann)
if origin is Annotated:
args = get_args(ann)
return _is_context_annotation(args[0]) if args else False

if origin is Union or origin is types.UnionType:
return any(_is_context_annotation(a) for a in get_args(ann))

# Conservative fallback for unresolved forward refs (strings).
if isinstance(ann, str):
s = ann.strip().strip("'\"")

# Handle PEP604 unions in string form: "Context | None"
if "|" in s:
return any(_is_context_annotation(part.strip()) for part in s.split("|"))

# Handle Optional/Union/Annotated in string form. We only unwrap these;
# we intentionally do NOT look inside arbitrary generics like ContextManager[...].
for wrapper in ("Optional[", "Union[", "Annotated["):
if s.startswith(wrapper) and s.endswith("]"):
inner = s[len(wrapper) : -1].strip()
if wrapper == "Union[":
return any(_is_context_annotation(p.strip()) for p in inner.split(","))
if wrapper == "Annotated[":
first = inner.split(",", 1)[0].strip()
return _is_context_annotation(first)
# Optional[
return _is_context_annotation(inner)

# Accept only the actual arcade_mcp_server Context name(s).
return s in {"Context", "arcade_mcp_server.context.Context", "arcade_mcp_server.Context"}

return False

try:
sig = inspect.signature(handler)
params = list(sig.parameters.values())
# Filter out 'self' parameter for bound methods
params = [p for p in params if p.name != "self"]

if not params:
return False

first_param = params[0]

# Check if first parameter is type-annotated
if first_param.annotation != inspect.Parameter.empty:
ann: Any = first_param.annotation

# Prefer resolving type hints (handles forward refs, Optional/Union, etc.)
try:
import arcade_mcp_server

globalns = getattr(handler, "__globals__", {}) or {}
if "arcade_mcp_server" not in globalns:
# Avoid mutating handler globals in-place.
globalns = dict(globalns)
globalns["arcade_mcp_server"] = arcade_mcp_server

hints = get_type_hints(
handler,
globalns=globalns,
localns={"Context": ArcadeContext},
include_extras=True,
)
ann = hints.get(first_param.name, ann)
except Exception:
# Fall back to raw signature annotation.
logger.debug(
"Failed to resolve prompt handler type hints; falling back to raw signature annotations",
exc_info=True,
)
ann = first_param.annotation

return _is_context_annotation(ann)
else:
# No type annotation - check if named "context" (untyped context handler)
return first_param.name == "context"
except (ValueError, TypeError):
# If we can't inspect, assume legacy signature
return False

def _default_handler(self, arguments: dict[str, str]) -> list[PromptMessage]:
return [
PromptMessage(
Expand All @@ -43,7 +181,11 @@ def _default_handler(self, arguments: dict[str, str]) -> list[PromptMessage]:
)
]

async def get_messages(self, arguments: dict[str, str] | None = None) -> list[PromptMessage]:
async def get_messages(
self,
arguments: dict[str, str] | None = None,
context: Context | None = None,
) -> list[PromptMessage]:
args = arguments or {}

# Validate required arguments
Expand All @@ -52,11 +194,20 @@ async def get_messages(self, arguments: dict[str, str] | None = None) -> list[Pr
if arg.required and arg.name not in args:
raise PromptError(f"Required argument '{arg.name}' not provided")

result = self.handler(args)
# Call handler with appropriate signature
result: Any
if self._accepts_context:
if context is None:
raise PromptError("Handler requires context but none was provided")
result = self.handler(context, args)
else:
result = self.handler(args)

if hasattr(result, "__await__"):
result = await result

return result
# Cast result to expected type after dynamic invocation
return cast(list[PromptMessage], result)


class PromptManager(ComponentManager[str, PromptHandler]):
Expand All @@ -72,15 +223,18 @@ async def list_prompts(self) -> list[Prompt]:
return [h.prompt for h in handlers]

async def get_prompt(
self, name: str, arguments: dict[str, str] | None = None
self,
name: str,
arguments: dict[str, str] | None = None,
context: Context | None = None,
) -> GetPromptResult:
try:
handler = await self.registry.get(name)
except KeyError:
raise NotFoundError(f"Prompt '{name}' not found")

try:
messages = await handler.get_messages(arguments)
messages = await handler.get_messages(arguments, context)
return GetPromptResult(
description=handler.prompt.description,
messages=messages,
Expand All @@ -93,7 +247,7 @@ async def get_prompt(
async def add_prompt(
self,
prompt: Prompt,
handler: Callable[[dict[str, str]], list[PromptMessage]] | None = None,
handler: PromptHandlerType | None = None,
) -> None:
prompt_handler = PromptHandler(prompt, handler)
await self.registry.upsert(prompt.name, prompt_handler)
Expand All @@ -109,7 +263,7 @@ async def update_prompt(
self,
name: str,
prompt: Prompt,
handler: Callable[[dict[str, str]], list[PromptMessage]] | None = None,
handler: PromptHandlerType | None = None,
) -> Prompt:
# Ensure exists
try:
Expand Down
2 changes: 2 additions & 0 deletions libs/arcade-mcp-server/arcade_mcp_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1337,9 +1337,11 @@ async def _handle_get_prompt(
) -> JSONRPCResponse[GetPromptResult] | JSONRPCError:
"""Handle get prompt request."""
try:
context = get_current_model_context()
result = await self._prompt_manager.get_prompt(
message.params.name,
message.params.arguments if hasattr(message.params, "arguments") else None,
context,
)
return JSONRPCResponse(id=message.id, result=result)
except NotFoundError:
Expand Down
2 changes: 1 addition & 1 deletion libs/arcade-mcp-server/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "arcade-mcp-server"
version = "1.13.0"
version = "1.14.1"
description = "Model Context Protocol (MCP) server framework for Arcade.dev"
readme = "README.md"
authors = [{ name = "Arcade.dev" }]
Expand Down
Loading
Loading