Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion libs/arcade-mcp-server/arcade_mcp_server/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ async def list(self) -> list[Any]:
return cast(list[Any], prompts)

async def get(self, name: str, arguments: dict[str, str] | None = None) -> Any:
return await self._ctx.server._prompt_manager.get_prompt(name, arguments)
return await self._ctx.server._prompt_manager.get_prompt(name, arguments, self._ctx)


class Sampling(_ContextComponent):
Expand Down
91 changes: 80 additions & 11 deletions libs/arcade-mcp-server/arcade_mcp_server/managers/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,85 @@

from __future__ import annotations

import inspect
import logging
from typing import Callable
from typing import TYPE_CHECKING, Any, Callable, Union, cast

from arcade_mcp_server.exceptions import NotFoundError, PromptError
from arcade_mcp_server.managers.base import ComponentManager
from arcade_mcp_server.types import GetPromptResult, Prompt, PromptMessage

if TYPE_CHECKING:
from arcade_mcp_server.context import Context

logger = logging.getLogger("arcade.mcp.managers.prompt")

# Type aliases for prompt handler signatures
PromptHandlerLegacy = Callable[[dict[str, str]], list[PromptMessage]]
PromptHandlerWithContext = Callable[["Context", dict[str, str]], list[PromptMessage]]
PromptHandlerType = Union[PromptHandlerLegacy, PromptHandlerWithContext]


class PromptHandler:
"""Handler for generating prompt messages."""
"""Handler for generating prompt messages.

Supports two handler signatures:
1. Legacy: handler(args: dict[str, str]) -> list[PromptMessage]
2. New (with context): handler(context: Context, args: dict[str, str]) -> list[PromptMessage]

The handler signature is detected automatically using introspection.
"""

def __init__(
self,
prompt: Prompt,
handler: Callable[[dict[str, str]], list[PromptMessage]] | None = None,
handler: PromptHandlerType | None = None,
) -> None:
self.prompt = prompt
self.handler = handler or self._default_handler
self.handler: Any = handler or self._default_handler
self._accepts_context = self._check_context_signature(self.handler)

def __eq__(self, other: object) -> bool: # pragma: no cover - simple comparison
if not isinstance(other, PromptHandler):
return False
return self.prompt == other.prompt and self.handler == other.handler

def _check_context_signature(self, handler: Any) -> bool:
"""Check if handler accepts context parameter.

Returns True if the first parameter is type-annotated as Context or named "context"
without a conflicting type annotation. Returns False for legacy signatures.

Examples:
- handler(context: Context, args) -> True (typed context)
- handler(context, args) -> True (untyped context)
- handler(context: dict[str, str]) -> False (legacy with misleading name)
- handler(args) -> False (legacy)
"""
try:
sig = inspect.signature(handler)
params = list(sig.parameters.values())
# Filter out 'self' parameter for bound methods
params = [p for p in params if p.name != "self"]

if not params:
return False

first_param = params[0]

# Check if first parameter is type-annotated
if first_param.annotation != inspect.Parameter.empty:
annotation_str = str(first_param.annotation)
# Only return True if the type annotation contains "Context"
# This handles Context, arcade_mcp_server.context.Context, etc.
return "Context" in annotation_str
else:
# No type annotation - check if named "context" (untyped context handler)
return first_param.name == "context"
except (ValueError, TypeError):
# If we can't inspect, assume legacy signature
return False

def _default_handler(self, arguments: dict[str, str]) -> list[PromptMessage]:
return [
PromptMessage(
Expand All @@ -43,7 +96,11 @@ def _default_handler(self, arguments: dict[str, str]) -> list[PromptMessage]:
)
]

async def get_messages(self, arguments: dict[str, str] | None = None) -> list[PromptMessage]:
async def get_messages(
self,
arguments: dict[str, str] | None = None,
context: Context | None = None,
) -> list[PromptMessage]:
args = arguments or {}

# Validate required arguments
Expand All @@ -52,11 +109,20 @@ async def get_messages(self, arguments: dict[str, str] | None = None) -> list[Pr
if arg.required and arg.name not in args:
raise PromptError(f"Required argument '{arg.name}' not provided")

result = self.handler(args)
# Call handler with appropriate signature
result: Any
if self._accepts_context:
if context is None:
raise PromptError("Handler requires context but none was provided")
result = self.handler(context, args)
else:
result = self.handler(args)

if hasattr(result, "__await__"):
result = await result

return result
# Cast result to expected type after dynamic invocation
return cast(list[PromptMessage], result)


class PromptManager(ComponentManager[str, PromptHandler]):
Expand All @@ -72,15 +138,18 @@ async def list_prompts(self) -> list[Prompt]:
return [h.prompt for h in handlers]

async def get_prompt(
self, name: str, arguments: dict[str, str] | None = None
self,
name: str,
arguments: dict[str, str] | None = None,
context: Context | None = None,
) -> GetPromptResult:
try:
handler = await self.registry.get(name)
except KeyError:
raise NotFoundError(f"Prompt '{name}' not found")

try:
messages = await handler.get_messages(arguments)
messages = await handler.get_messages(arguments, context)
return GetPromptResult(
description=handler.prompt.description,
messages=messages,
Expand All @@ -93,7 +162,7 @@ async def get_prompt(
async def add_prompt(
self,
prompt: Prompt,
handler: Callable[[dict[str, str]], list[PromptMessage]] | None = None,
handler: PromptHandlerType | None = None,
) -> None:
prompt_handler = PromptHandler(prompt, handler)
await self.registry.upsert(prompt.name, prompt_handler)
Expand All @@ -109,7 +178,7 @@ async def update_prompt(
self,
name: str,
prompt: Prompt,
handler: Callable[[dict[str, str]], list[PromptMessage]] | None = None,
handler: PromptHandlerType | None = None,
) -> Prompt:
# Ensure exists
try:
Expand Down
2 changes: 2 additions & 0 deletions libs/arcade-mcp-server/arcade_mcp_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1337,9 +1337,11 @@ async def _handle_get_prompt(
) -> JSONRPCResponse[GetPromptResult] | JSONRPCError:
"""Handle get prompt request."""
try:
context = get_current_model_context()
result = await self._prompt_manager.get_prompt(
message.params.name,
message.params.arguments if hasattr(message.params, "arguments") else None,
context,
)
return JSONRPCResponse(id=message.id, result=result)
except NotFoundError:
Expand Down
2 changes: 1 addition & 1 deletion libs/arcade-mcp-server/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "arcade-mcp-server"
version = "1.13.0"
version = "1.14.0"
description = "Model Context Protocol (MCP) server framework for Arcade.dev"
readme = "README.md"
authors = [{ name = "Arcade.dev" }]
Expand Down
156 changes: 155 additions & 1 deletion libs/tests/arcade_mcp_server/test_prompt.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Tests for Prompt Manager implementation."""

import asyncio
from unittest.mock import Mock

import pytest
from arcade_mcp_server.context import Context
from arcade_mcp_server.exceptions import NotFoundError, PromptError
from arcade_mcp_server.managers.prompt import PromptManager
from arcade_mcp_server.types import (
Expand Down Expand Up @@ -37,7 +39,7 @@ def sample_prompt(self):

@pytest.fixture
def prompt_function(self):
"""Create a prompt function."""
"""Create a prompt function (legacy signature without context)."""

async def greeting_prompt(args: dict[str, str]) -> list[PromptMessage]:
name = args.get("name", "")
Expand All @@ -53,6 +55,45 @@ async def greeting_prompt(args: dict[str, str]) -> list[PromptMessage]:

return greeting_prompt

@pytest.fixture
def prompt_function_with_context(self):
"""Create a prompt function with context parameter (new signature)."""

async def greeting_prompt_with_context(
context: Context, args: dict[str, str]
) -> list[PromptMessage]:
name = args.get("name", "")
formal_arg = args.get("formal", "false")
formal = str(formal_arg).lower() == "true"

# Access context (e.g., for logging)
if hasattr(context, "log"):
await context.log.info(f"Generating greeting for {name}")

if formal:
text = f"Good day, {name}. How may I assist you?"
else:
text = f"Hey {name}! What's up?"

return [PromptMessage(role="assistant", content={"type": "text", "text": text})]

return greeting_prompt_with_context

@pytest.fixture
def mock_context(self):
"""Create a mock context."""
mock_server = Mock()
context = Context(mock_server)
# Mock the log interface with async methods
mock_log = Mock()

async def async_info(*args, **kwargs):
pass

mock_log.info = async_info
context._log = mock_log
return context

def test_manager_initialization(self):
"""Test prompt manager initialization."""
manager = PromptManager()
Expand Down Expand Up @@ -239,3 +280,116 @@ async def error_prompt(args: dict[str, str]):

with pytest.raises(PromptError):
await manager.get_prompt("error_prompt", {})

@pytest.mark.asyncio
async def test_prompt_with_context_parameter(
self, prompt_manager, sample_prompt, prompt_function_with_context, mock_context
):
"""Test prompt with new context parameter signature."""
await prompt_manager.add_prompt(sample_prompt, prompt_function_with_context)

result = await prompt_manager.get_prompt(
"greeting", {"name": "Alice", "formal": "true"}, mock_context
)

assert isinstance(result, GetPromptResult)
assert len(result.messages) == 1
assert result.messages[0].role == "assistant"
assert "Good day, Alice" in result.messages[0].content["text"]

@pytest.mark.asyncio
async def test_prompt_with_context_logging(
self, prompt_manager, sample_prompt, prompt_function_with_context, mock_context
):
"""Test that prompt with context can use logging."""
await prompt_manager.add_prompt(sample_prompt, prompt_function_with_context)

await prompt_manager.get_prompt("greeting", {"name": "Bob"}, mock_context)

# Verify logging was called (if mock was set up properly)
# This would require a more sophisticated mock setup

@pytest.mark.asyncio
async def test_prompt_context_required_but_not_provided(
self, prompt_manager, sample_prompt, prompt_function_with_context
):
"""Test that error is raised when context-requiring prompt doesn't get context."""
await prompt_manager.add_prompt(sample_prompt, prompt_function_with_context)

with pytest.raises(PromptError, match="Handler requires context"):
await prompt_manager.get_prompt("greeting", {"name": "Alice"}, None)

@pytest.mark.asyncio
async def test_backward_compatibility_legacy_signature(
self, prompt_manager, sample_prompt, prompt_function
):
"""Test backward compatibility with legacy signature (no context)."""
await prompt_manager.add_prompt(sample_prompt, prompt_function)

# Should work without context
result = await prompt_manager.get_prompt("greeting", {"name": "Charlie"}, None)

assert isinstance(result, GetPromptResult)
assert len(result.messages) == 1
assert "Hey Charlie!" in result.messages[0].content["text"]

@pytest.mark.asyncio
async def test_mixed_signatures(
self, prompt_manager, prompt_function, prompt_function_with_context, mock_context
):
"""Test that both signatures can coexist."""
prompt1 = Prompt(name="legacy", description="Legacy prompt")
prompt2 = Prompt(name="new", description="New prompt with context")

await prompt_manager.add_prompt(prompt1, prompt_function)
await prompt_manager.add_prompt(prompt2, prompt_function_with_context)

# Legacy prompt works without context
result1 = await prompt_manager.get_prompt(
"legacy", {"name": "Dave", "formal": "false"}, None
)
assert "Hey Dave!" in result1.messages[0].content["text"]

# New prompt works with context
result2 = await prompt_manager.get_prompt(
"new", {"name": "Eve", "formal": "true"}, mock_context
)
assert "Good day, Eve" in result2.messages[0].content["text"]

@pytest.mark.asyncio
async def test_sync_prompt_function_with_context(self, prompt_manager, mock_context):
"""Test synchronous prompt function with context parameter."""
prompt = Prompt(name="sync_prompt", description="Sync prompt with context")

def sync_prompt(context: Context, args: dict[str, str]) -> list[PromptMessage]:
name = args.get("name", "User")
return [
PromptMessage(role="user", content={"type": "text", "text": f"Hello {name}!"})
]

await prompt_manager.add_prompt(prompt, sync_prompt)

result = await prompt_manager.get_prompt("sync_prompt", {"name": "Frank"}, mock_context)

assert isinstance(result, GetPromptResult)
assert len(result.messages) == 1
assert "Hello Frank!" in result.messages[0].content["text"]

@pytest.mark.asyncio
async def test_sync_prompt_function_without_context(self, prompt_manager):
"""Test synchronous prompt function without context (legacy)."""
prompt = Prompt(name="sync_legacy", description="Sync legacy prompt")

def sync_legacy_prompt(args: dict[str, str]) -> list[PromptMessage]:
name = args.get("name", "User")
return [
PromptMessage(role="user", content={"type": "text", "text": f"Hi {name}!"})
]

await prompt_manager.add_prompt(prompt, sync_legacy_prompt)

result = await prompt_manager.get_prompt("sync_legacy", {"name": "Grace"}, None)

assert isinstance(result, GetPromptResult)
assert len(result.messages) == 1
assert "Hi Grace!" in result.messages[0].content["text"]