Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions activate.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/bin/bash
# Quick activation script for the virtual environment

if [ -f ".venv/bin/activate" ]; then
source .venv/bin/activate
echo "Virtual environment activated!"
echo "Python: $(which python) ($(python --version))"
echo "To deactivate, run: deactivate"
else
echo "Error: Virtual environment not found. Run ./uv_setup.sh first."
exit 1
fi
2 changes: 1 addition & 1 deletion libs/arcade-mcp-server/arcade_mcp_server/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ async def list(self) -> list[Any]:
return cast(list[Any], prompts)

async def get(self, name: str, arguments: dict[str, str] | None = None) -> Any:
return await self._ctx.server._prompt_manager.get_prompt(name, arguments)
return await self._ctx.server._prompt_manager.get_prompt(name, arguments, self._ctx)


class Sampling(_ContextComponent):
Expand Down
71 changes: 60 additions & 11 deletions libs/arcade-mcp-server/arcade_mcp_server/managers/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,65 @@

from __future__ import annotations

import inspect
import logging
from typing import Callable
from typing import TYPE_CHECKING, Any, Callable, Union, cast

from arcade_mcp_server.exceptions import NotFoundError, PromptError
from arcade_mcp_server.managers.base import ComponentManager
from arcade_mcp_server.types import GetPromptResult, Prompt, PromptMessage

if TYPE_CHECKING:
from arcade_mcp_server.context import Context

logger = logging.getLogger("arcade.mcp.managers.prompt")

# Type aliases for prompt handler signatures
PromptHandlerLegacy = Callable[[dict[str, str]], list[PromptMessage]]
PromptHandlerWithContext = Callable[["Context", dict[str, str]], list[PromptMessage]]
PromptHandlerType = Union[PromptHandlerLegacy, PromptHandlerWithContext]


class PromptHandler:
"""Handler for generating prompt messages."""
"""Handler for generating prompt messages.

Supports two handler signatures:
1. Legacy: handler(args: dict[str, str]) -> list[PromptMessage]
2. New (with context): handler(context: Context, args: dict[str, str]) -> list[PromptMessage]

The handler signature is detected automatically using introspection.
"""

def __init__(
self,
prompt: Prompt,
handler: Callable[[dict[str, str]], list[PromptMessage]] | None = None,
handler: PromptHandlerType | None = None,
) -> None:
self.prompt = prompt
self.handler = handler or self._default_handler
self.handler: Any = handler or self._default_handler
self._accepts_context = self._check_context_signature(self.handler)

def __eq__(self, other: object) -> bool: # pragma: no cover - simple comparison
if not isinstance(other, PromptHandler):
return False
return self.prompt == other.prompt and self.handler == other.handler

def _check_context_signature(self, handler: Any) -> bool:
"""Check if handler accepts context parameter.

Returns True if the handler signature has 2 parameters (context, args),
False if it has 1 parameter (args only).
"""
try:
sig = inspect.signature(handler)
params = list(sig.parameters.values())
# Filter out 'self' parameter for bound methods
params = [p for p in params if p.name != "self"]
return len(params) >= 2
except (ValueError, TypeError):
# If we can't inspect, assume legacy signature
return False

def _default_handler(self, arguments: dict[str, str]) -> list[PromptMessage]:
return [
PromptMessage(
Expand All @@ -43,7 +76,11 @@ def _default_handler(self, arguments: dict[str, str]) -> list[PromptMessage]:
)
]

async def get_messages(self, arguments: dict[str, str] | None = None) -> list[PromptMessage]:
async def get_messages(
self,
arguments: dict[str, str] | None = None,
context: Context | None = None,
) -> list[PromptMessage]:
args = arguments or {}

# Validate required arguments
Expand All @@ -52,11 +89,20 @@ async def get_messages(self, arguments: dict[str, str] | None = None) -> list[Pr
if arg.required and arg.name not in args:
raise PromptError(f"Required argument '{arg.name}' not provided")

result = self.handler(args)
# Call handler with appropriate signature
result: Any
if self._accepts_context:
if context is None:
raise PromptError("Handler requires context but none was provided")
result = self.handler(context, args)
else:
result = self.handler(args)

if hasattr(result, "__await__"):
result = await result

return result
# Cast result to expected type after dynamic invocation
return cast(list[PromptMessage], result)


class PromptManager(ComponentManager[str, PromptHandler]):
Expand All @@ -72,15 +118,18 @@ async def list_prompts(self) -> list[Prompt]:
return [h.prompt for h in handlers]

async def get_prompt(
self, name: str, arguments: dict[str, str] | None = None
self,
name: str,
arguments: dict[str, str] | None = None,
context: Context | None = None,
) -> GetPromptResult:
try:
handler = await self.registry.get(name)
except KeyError:
raise NotFoundError(f"Prompt '{name}' not found")

try:
messages = await handler.get_messages(arguments)
messages = await handler.get_messages(arguments, context)
return GetPromptResult(
description=handler.prompt.description,
messages=messages,
Expand All @@ -93,7 +142,7 @@ async def get_prompt(
async def add_prompt(
self,
prompt: Prompt,
handler: Callable[[dict[str, str]], list[PromptMessage]] | None = None,
handler: PromptHandlerType | None = None,
) -> None:
prompt_handler = PromptHandler(prompt, handler)
await self.registry.upsert(prompt.name, prompt_handler)
Expand All @@ -109,7 +158,7 @@ async def update_prompt(
self,
name: str,
prompt: Prompt,
handler: Callable[[dict[str, str]], list[PromptMessage]] | None = None,
handler: PromptHandlerType | None = None,
) -> Prompt:
# Ensure exists
try:
Expand Down
5 changes: 5 additions & 0 deletions libs/arcade-mcp-server/arcade_mcp_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1337,9 +1337,14 @@ async def _handle_get_prompt(
) -> JSONRPCResponse[GetPromptResult] | JSONRPCError:
"""Handle get prompt request."""
try:
# Get current context for prompt handlers that need it
from arcade_mcp_server.context import get_current_model_context

context = get_current_model_context()
result = await self._prompt_manager.get_prompt(
message.params.name,
message.params.arguments if hasattr(message.params, "arguments") else None,
context,
)
return JSONRPCResponse(id=message.id, result=result)
except NotFoundError:
Expand Down
2 changes: 1 addition & 1 deletion libs/arcade-mcp-server/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "arcade-mcp-server"
version = "1.13.0"
version = "1.14.0"
description = "Model Context Protocol (MCP) server framework for Arcade.dev"
readme = "README.md"
authors = [{ name = "Arcade.dev" }]
Expand Down
156 changes: 155 additions & 1 deletion libs/tests/arcade_mcp_server/test_prompt.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Tests for Prompt Manager implementation."""

import asyncio
from unittest.mock import Mock

import pytest
from arcade_mcp_server.context import Context
from arcade_mcp_server.exceptions import NotFoundError, PromptError
from arcade_mcp_server.managers.prompt import PromptManager
from arcade_mcp_server.types import (
Expand Down Expand Up @@ -37,7 +39,7 @@ def sample_prompt(self):

@pytest.fixture
def prompt_function(self):
"""Create a prompt function."""
"""Create a prompt function (legacy signature without context)."""

async def greeting_prompt(args: dict[str, str]) -> list[PromptMessage]:
name = args.get("name", "")
Expand All @@ -53,6 +55,45 @@ async def greeting_prompt(args: dict[str, str]) -> list[PromptMessage]:

return greeting_prompt

@pytest.fixture
def prompt_function_with_context(self):
"""Create a prompt function with context parameter (new signature)."""

async def greeting_prompt_with_context(
context: Context, args: dict[str, str]
) -> list[PromptMessage]:
name = args.get("name", "")
formal_arg = args.get("formal", "false")
formal = str(formal_arg).lower() == "true"

# Access context (e.g., for logging)
if hasattr(context, "log"):
await context.log.info(f"Generating greeting for {name}")

if formal:
text = f"Good day, {name}. How may I assist you?"
else:
text = f"Hey {name}! What's up?"

return [PromptMessage(role="assistant", content={"type": "text", "text": text})]

return greeting_prompt_with_context

@pytest.fixture
def mock_context(self):
"""Create a mock context."""
mock_server = Mock()
context = Context(mock_server)
# Mock the log interface with async methods
mock_log = Mock()

async def async_info(*args, **kwargs):
pass

mock_log.info = async_info
context._log = mock_log
return context

def test_manager_initialization(self):
"""Test prompt manager initialization."""
manager = PromptManager()
Expand Down Expand Up @@ -239,3 +280,116 @@ async def error_prompt(args: dict[str, str]):

with pytest.raises(PromptError):
await manager.get_prompt("error_prompt", {})

@pytest.mark.asyncio
async def test_prompt_with_context_parameter(
self, prompt_manager, sample_prompt, prompt_function_with_context, mock_context
):
"""Test prompt with new context parameter signature."""
await prompt_manager.add_prompt(sample_prompt, prompt_function_with_context)

result = await prompt_manager.get_prompt(
"greeting", {"name": "Alice", "formal": "true"}, mock_context
)

assert isinstance(result, GetPromptResult)
assert len(result.messages) == 1
assert result.messages[0].role == "assistant"
assert "Good day, Alice" in result.messages[0].content["text"]

@pytest.mark.asyncio
async def test_prompt_with_context_logging(
self, prompt_manager, sample_prompt, prompt_function_with_context, mock_context
):
"""Test that prompt with context can use logging."""
await prompt_manager.add_prompt(sample_prompt, prompt_function_with_context)

await prompt_manager.get_prompt("greeting", {"name": "Bob"}, mock_context)

# Verify logging was called (if mock was set up properly)
# This would require a more sophisticated mock setup

@pytest.mark.asyncio
async def test_prompt_context_required_but_not_provided(
self, prompt_manager, sample_prompt, prompt_function_with_context
):
"""Test that error is raised when context-requiring prompt doesn't get context."""
await prompt_manager.add_prompt(sample_prompt, prompt_function_with_context)

with pytest.raises(PromptError, match="Handler requires context"):
await prompt_manager.get_prompt("greeting", {"name": "Alice"}, None)

@pytest.mark.asyncio
async def test_backward_compatibility_legacy_signature(
self, prompt_manager, sample_prompt, prompt_function
):
"""Test backward compatibility with legacy signature (no context)."""
await prompt_manager.add_prompt(sample_prompt, prompt_function)

# Should work without context
result = await prompt_manager.get_prompt("greeting", {"name": "Charlie"}, None)

assert isinstance(result, GetPromptResult)
assert len(result.messages) == 1
assert "Hey Charlie!" in result.messages[0].content["text"]

@pytest.mark.asyncio
async def test_mixed_signatures(
self, prompt_manager, prompt_function, prompt_function_with_context, mock_context
):
"""Test that both signatures can coexist."""
prompt1 = Prompt(name="legacy", description="Legacy prompt")
prompt2 = Prompt(name="new", description="New prompt with context")

await prompt_manager.add_prompt(prompt1, prompt_function)
await prompt_manager.add_prompt(prompt2, prompt_function_with_context)

# Legacy prompt works without context
result1 = await prompt_manager.get_prompt(
"legacy", {"name": "Dave", "formal": "false"}, None
)
assert "Hey Dave!" in result1.messages[0].content["text"]

# New prompt works with context
result2 = await prompt_manager.get_prompt(
"new", {"name": "Eve", "formal": "true"}, mock_context
)
assert "Good day, Eve" in result2.messages[0].content["text"]

@pytest.mark.asyncio
async def test_sync_prompt_function_with_context(self, prompt_manager, mock_context):
"""Test synchronous prompt function with context parameter."""
prompt = Prompt(name="sync_prompt", description="Sync prompt with context")

def sync_prompt(context: Context, args: dict[str, str]) -> list[PromptMessage]:
name = args.get("name", "User")
return [
PromptMessage(role="user", content={"type": "text", "text": f"Hello {name}!"})
]

await prompt_manager.add_prompt(prompt, sync_prompt)

result = await prompt_manager.get_prompt("sync_prompt", {"name": "Frank"}, mock_context)

assert isinstance(result, GetPromptResult)
assert len(result.messages) == 1
assert "Hello Frank!" in result.messages[0].content["text"]

@pytest.mark.asyncio
async def test_sync_prompt_function_without_context(self, prompt_manager):
"""Test synchronous prompt function without context (legacy)."""
prompt = Prompt(name="sync_legacy", description="Sync legacy prompt")

def sync_legacy_prompt(args: dict[str, str]) -> list[PromptMessage]:
name = args.get("name", "User")
return [
PromptMessage(role="user", content={"type": "text", "text": f"Hi {name}!"})
]

await prompt_manager.add_prompt(prompt, sync_legacy_prompt)

result = await prompt_manager.get_prompt("sync_legacy", {"name": "Grace"}, None)

assert isinstance(result, GetPromptResult)
assert len(result.messages) == 1
assert "Hi Grace!" in result.messages[0].content["text"]