diff --git a/examples/evals/README.md b/examples/evals/README.md new file mode 100644 index 000000000..cf1371f6a --- /dev/null +++ b/examples/evals/README.md @@ -0,0 +1,392 @@ +# Arcade Evals Examples + +This directory contains user-friendly examples demonstrating how to evaluate tools from different sources using the Arcade evals framework. + +## πŸ“‹ Table of Contents + +- [Quick Start](#quick-start) +- [Example Files](#example-files) +- [CLI Reference](#cli-reference) +- [Common Patterns](#common-patterns) +- [Troubleshooting](#troubleshooting) + +## πŸš€ Quick Start + +### What Makes These Examples Different + +These examples are designed to be: +- **Production-ready**: Include proper error handling and timeouts +- **Copy-paste friendly**: Clear configuration sections you can modify +- **Informative**: Print status messages during loading +- **Focused**: One concept per example, no unnecessary complexity +- **Pattern-based**: Follow consistent structure from real-world evals + +### Installation + +```bash +# Install with evals support +pip install 'arcade-mcp[evals]' + +# Or using uv (recommended) +uv tool install 'arcade-mcp[evals]' +``` + +### Basic Usage + +```bash +# Run an evaluation with OpenAI +arcade evals examples/evals/eval_arcade_gateway.py \ + --api-key openai:YOUR_OPENAI_KEY + +# Compare multiple models +arcade evals examples/evals/eval_stdio_mcp_server.py \ + -p "openai:gpt-4o anthropic:claude-sonnet-4-5-20250929" \ + -k openai:YOUR_OPENAI_KEY \ + -k anthropic:YOUR_ANTHROPIC_KEY + +# Output results to HTML +arcade evals examples/evals/eval_http_mcp_server.py \ + --api-key openai:YOUR_KEY \ + -o results.html -d +``` + +## πŸ“š Example Files + +### Example Structure + +All examples follow a consistent pattern: + +```python +# 1. Configuration section - Update these values +ARCADE_API_KEY = os.environ.get("ARCADE_API_KEY", "YOUR_KEY_HERE") + +# 2. Eval suite with async loading +@tool_eval() +async def eval_my_suite() -> EvalSuite: + suite = EvalSuite(name="...", system_message="...", rubric=...) + + # 3. Load tools with timeout and error handling + try: + await asyncio.wait_for( + suite.add_arcade_gateway(...), + timeout=10.0, + ) + print(" βœ“ Source loaded") + except Exception as e: + print(f" βœ— Source failed: {e}") + return suite + + # 4. Add test cases + suite.add_case(name="...", user_message="...", ...) + + return suite +``` + +This pattern ensures: +- Clear configuration at the top +- Robust error handling +- Informative output during loading +- Graceful degradation if sources fail + +### 1. `eval_arcade_gateway.py` + +Evaluates tools from Arcade Gateway (cloud-hosted toolkits). + +**What it demonstrates:** + +- Async loading from Arcade Gateway with timeout handling +- Error handling for connection failures +- Math toolkit evaluations +- BinaryCritic for parameter validation +- Conversational context with additional_messages + +**Prerequisites:** + +Before running this example, you need to set up an MCP Gateway: + +1. **Get your API key** - [API Keys Setup Guide](https://docs.arcade.dev/en/get-started/setup/api-keys) +2. **Create an MCP Gateway** at [Arcade Portal](https://portal.arcade.dev) +3. **Add toolkits** (e.g., Math, GitHub, Slack) to your gateway +4. 
**Get your credentials:**
   - `ARCADE_API_KEY` - Your Arcade API key
   - `ARCADE_USER_ID` - Your user ID (found in portal settings)

πŸ“š **Full setup guide:** [MCP Gateways Documentation](https://docs.arcade.dev/en/guides/create-tools/mcp-gateways)

**Requirements:**

- Arcade API key (get one at [arcade.dev](https://arcade.dev))
- LLM API key (OpenAI or Anthropic)

**Run it:**

```bash
# Set your Arcade API key
export ARCADE_API_KEY=your_arcade_key

arcade evals examples/evals/eval_arcade_gateway.py \
  --api-key openai:YOUR_OPENAI_KEY
```

### 2. `eval_stdio_mcp_server.py`

Evaluates tools from local MCP servers running via stdio (subprocess).

**What it demonstrates:**

- Loading from local stdio MCP servers (subprocesses)
- Using `add_mcp_stdio_server()` method
- Setting environment variables (PYTHONUNBUFFERED)
- Simple echo tool evaluations
- Async loading with timeout and error handling

**Requirements:**

- Local MCP server code
- Server dependencies installed
- LLM API key

**Run it:**

```bash
arcade evals examples/evals/eval_stdio_mcp_server.py \
  --api-key openai:YOUR_KEY
```

### 3. `eval_http_mcp_server.py`

Evaluates tools from remote MCP servers via HTTP or SSE.

**What it demonstrates:**

- Connecting to HTTP MCP endpoints
- Using SSE (Server-Sent Events) transport
- Authentication with Bearer tokens
- Error handling with timeouts

**Requirements:**

- Running HTTP/SSE MCP server
- Network connectivity
- LLM API key
- (Optional) Authentication token

**Run it:**

```bash
# Update the configuration in the file first, then run:
arcade evals examples/evals/eval_http_mcp_server.py \
  --api-key openai:YOUR_KEY
```

### 4. `eval_comprehensive_comparison.py`

Compares tool performance across multiple sources simultaneously.

**What it demonstrates:**

- Comparative evaluation across different tool sources
- Loading from multiple sources (Gateway, stdio, dict)
- Track-based evaluation (comparing same task across sources)
- Conditional test cases based on loaded sources
- Using SimilarityCritic for fuzzy matching

**Requirements:**

- Arcade API key (for Gateway)
- LLM API key
- (Optional) Local simple MCP server

**Run it:**

```bash
# Set environment variables
export ARCADE_API_KEY=your_key
export ARCADE_USER_ID=your_user_id

arcade evals examples/evals/eval_comprehensive_comparison.py \
  -p "openai:gpt-4o anthropic:claude-sonnet-4-5-20250929" \
  -k openai:YOUR_KEY \
  -k anthropic:YOUR_KEY \
  -o comparison.html -d
```

## 🎯 CLI Reference

### New v2.0.0 Flags

| Flag | Short | Description | Example |
| --------------------- | ------- | -------------------------------------------------- | ------------------------------------------------- |
| `--use-provider`      | `-p`    | Provider(s) and models (space-separated)            | `-p "openai:gpt-4o anthropic:claude-sonnet"`       |
| `--api-key`           | `-k`    | API key in `provider:key` format (repeatable)       | `-k openai:sk-... 
-k anthropic:sk-ant-...` | +| `--output` | `-o` | Output file (auto-detects format from extension) | `-o results.html` or `-o results` (all formats) | +| `--only-failed` | `-f` | Show only failed evaluations | `--only-failed` | +| `--include-context` | | Include system messages and conversation history | `--include-context` | +| `--details` | `-d` | Show detailed output | `-d` | +| `--max-concurrent` | | Max concurrent evaluations | `--max-concurrent 5` | +| `--capture` | | Capture mode (record tool calls without scoring) | `--capture` | + +### Provider & Model Selection + +**Single provider with default model:** + +```bash +arcade evals eval_file.py -p openai -k openai:YOUR_KEY +``` + +**Single provider with specific models:** + +```bash +arcade evals eval_file.py -p "openai:gpt-4o,gpt-4o-mini" -k openai:YOUR_KEY +``` + +**Multiple providers (space-separated):** + +```bash +arcade evals eval_file.py \ + -p "openai:gpt-4o anthropic:claude-sonnet-4-5-20250929" \ + -k openai:YOUR_KEY \ + -k anthropic:YOUR_KEY +``` + +### Output Formats + +**Auto-detect from extension:** + +```bash +-o results.html # HTML output +-o results.json # JSON output +-o results.md # Markdown output +-o results.txt # Text output +``` + +**Multiple formats:** + +```bash +-o results.html -o results.json # Both HTML and JSON +``` + +**All formats:** + +```bash +-o results # Generates results.txt, results.md, results.html, results.json +``` + +## πŸ”§ Common Patterns + +### Pattern 1: Compare OpenAI Models + +```bash +arcade evals examples/evals/eval_arcade_gateway.py \ + -p "openai:gpt-4o,gpt-4o-mini,gpt-3.5-turbo" \ + -k openai:YOUR_KEY \ + -o comparison.html -d +``` + +### Pattern 2: OpenAI vs Anthropic + +```bash +arcade evals examples/evals/eval_stdio_mcp_server.py \ + -p "openai:gpt-4o anthropic:claude-sonnet-4-5-20250929" \ + -k openai:YOUR_OPENAI_KEY \ + -k anthropic:YOUR_ANTHROPIC_KEY \ + -o battle.html -d +``` + +### Pattern 3: Failed Tests Only + +```bash +arcade evals examples/evals/eval_http_mcp_server.py \ + --api-key openai:YOUR_KEY \ + --only-failed -d +``` + +### Pattern 4: Comparative Evaluation + +```bash +# Compare performance across multiple tool sources +arcade evals examples/evals/eval_comprehensive_comparison.py \ + -p "openai:gpt-4o anthropic:claude-sonnet-4-5-20250929" \ + -k openai:YOUR_KEY \ + -k anthropic:YOUR_KEY \ + -o comparison.html -d +``` + +### Pattern 5: Capture Mode (No Scoring) + +```bash +# Record tool calls without evaluation +arcade evals examples/evals/eval_arcade_gateway.py \ + --capture \ + --api-key openai:YOUR_KEY \ + -o captured.json +``` + +### Pattern 6: Full Context Output + +```bash +arcade evals examples/evals/eval_stdio_mcp_server.py \ + --api-key openai:YOUR_KEY \ + --include-context \ + -o full_results.html -d +``` + +## πŸ› Troubleshooting + +### Error: "No module named 'openai'" + +**Solution:** Install evals dependencies: + +```bash +pip install 'arcade-mcp[evals]' +``` + +### Error: "API key not found for provider 'openai'" + +**Solution:** Provide API key via flag or environment variable: + +```bash +# Via flag +arcade evals eval_file.py --api-key openai:YOUR_KEY + +# Via environment variable +export OPENAI_API_KEY=your_key +arcade evals eval_file.py +``` + +### Error: "Connection refused" (HTTP server) + +**Solution:** Ensure your HTTP MCP server is running: + +```bash +# Check if server is running +curl http://localhost:8000/mcp + +# Start your server first +python server.py +``` + +### Error: "Module not found" (stdio server) + +**Solution:** Install server 
dependencies: + +```bash +cd examples/mcp_servers/simple +uv sync +``` + +### Evals run but all tests fail + +**Possible causes:** + +1. Wrong tool names - check your server's tool definitions +2. Incorrect argument names - verify expected vs actual +3. Server not responding - check server logs +4. API key issues - verify LLM provider keys + +**Debug with verbose output:** + +```bash +arcade evals eval_file.py --api-key openai:YOUR_KEY -d +``` diff --git a/examples/evals/eval_arcade_gateway.py b/examples/evals/eval_arcade_gateway.py new file mode 100644 index 000000000..1126e7aa3 --- /dev/null +++ b/examples/evals/eval_arcade_gateway.py @@ -0,0 +1,135 @@ +"""Arcade Gateway evaluation - Loading tools from cloud-hosted toolkits. + +This example demonstrates loading and evaluating tools from Arcade Gateway, +which provides access to pre-built toolkits (Math, GitHub, Slack, Linear, etc.). + +Prerequisites: + 1. Get your API key: https://docs.arcade.dev/en/get-started/setup/api-keys + 2. Create an MCP Gateway at https://portal.arcade.dev + 3. Add toolkits to your gateway (e.g., Math, GitHub, Slack) + 4. Get your ARCADE_API_KEY and ARCADE_USER_ID from the portal + + Full setup guide: https://docs.arcade.dev/en/guides/create-tools/mcp-gateways + +Run: + # Set environment variables + export ARCADE_API_KEY=your_arcade_key + export ARCADE_USER_ID=your_user_id + + # Run the evaluation + arcade evals examples/evals/eval_arcade_gateway.py \\ + -p openai:gpt-4o \\ + -k openai:YOUR_KEY \\ + -o results.html -d +""" + +import asyncio +import os + +from arcade_evals import ( + BinaryCritic, + EvalRubric, + EvalSuite, + ExpectedMCPToolCall, + tool_eval, +) + +# ============================================================================= +# CONFIGURATION +# ============================================================================= + +ARCADE_API_KEY = os.environ.get("ARCADE_API_KEY", "YOUR_ARCADE_API_KEY_HERE") +ARCADE_USER_ID = os.environ.get("ARCADE_USER_ID", "YOUR_USER_ID_HERE") + +default_rubric = EvalRubric( + fail_threshold=0.7, + warn_threshold=0.9, +) + + +# ============================================================================= +# EVAL SUITE +# ============================================================================= + + +@tool_eval() +async def eval_arcade_gateway() -> EvalSuite: + """Evaluate Math toolkit from Arcade Gateway.""" + suite = EvalSuite( + name="Arcade Gateway - Math Toolkit", + system_message="You are a helpful math assistant. 
Use tools to perform calculations.", + rubric=default_rubric, + ) + + print("\n Loading Arcade Gateway...") + + try: + await asyncio.wait_for( + suite.add_arcade_gateway( + gateway_slug="Math", + arcade_api_key=ARCADE_API_KEY, + arcade_user_id=ARCADE_USER_ID, + ), + timeout=10.0, + ) + print(" βœ“ Arcade Gateway (Math toolkit)") + except asyncio.TimeoutError: + print(" βœ— Arcade Gateway - timeout") + return suite + except Exception as e: + print(f" βœ— Arcade Gateway - {type(e).__name__}: {e}") + return suite + + # Test Case 1: Simple addition + suite.add_case( + name="Simple addition - 10 + 5", + user_message="What is 10 plus 5?", + expected_tool_calls=[ + ExpectedMCPToolCall( + tool_name="Math_Add", + args={"a": 10, "b": 5}, + ) + ], + critics=[ + BinaryCritic(critic_field="a", weight=0.5), + BinaryCritic(critic_field="b", weight=0.5), + ], + ) + + # Test Case 2: Larger numbers + suite.add_case( + name="Addition - 123 + 456", + user_message="Calculate 123 + 456", + expected_tool_calls=[ + ExpectedMCPToolCall( + tool_name="Math_Add", + args={"a": 123, "b": 456}, + ) + ], + critics=[ + BinaryCritic(critic_field="a", weight=0.5), + BinaryCritic(critic_field="b", weight=0.5), + ], + ) + + # Test Case 3: Conversational context + suite.add_case( + name="Addition with context", + user_message="Now add them together", + expected_tool_calls=[ + ExpectedMCPToolCall( + tool_name="Math_Add", + args={"a": 50, "b": 25}, + ) + ], + critics=[ + BinaryCritic(critic_field="a", weight=0.5), + BinaryCritic(critic_field="b", weight=0.5), + ], + additional_messages=[ + {"role": "user", "content": "I have two numbers: 50 and 25"}, + {"role": "assistant", "content": "Great! I'll remember those numbers."}, + ], + ) + + return suite diff --git a/examples/evals/eval_comprehensive_comparison.py b/examples/evals/eval_comprehensive_comparison.py new file mode 100644 index 000000000..6f85380ef --- /dev/null +++ b/examples/evals/eval_comprehensive_comparison.py @@ -0,0 +1,229 @@ +"""Comprehensive comparison across multiple tool sources. 
+ +This example demonstrates comparative evaluations across different sources: +- Arcade Gateway (cloud toolkits) +- Local stdio MCP servers +- Dict-based tool definitions (baseline) + +Run: + arcade evals examples/evals/eval_comprehensive_comparison.py \\ + -p "openai:gpt-4o anthropic:claude-sonnet-4-5-20250929" \\ + -k openai:YOUR_KEY -k anthropic:YOUR_KEY \\ + -o comparison.html -d +""" + +import asyncio +import os + +from arcade_evals import ( + BinaryCritic, + EvalRubric, + EvalSuite, + ExpectedMCPToolCall, + MCPToolDefinition, + SimilarityCritic, + tool_eval, +) + +# ============================================================================= +# CONFIGURATION +# ============================================================================= + +ARCADE_API_KEY = os.environ.get("ARCADE_API_KEY", "YOUR_ARCADE_API_KEY_HERE") +ARCADE_USER_ID = os.environ.get("ARCADE_USER_ID", "YOUR_USER_ID_HERE") + +EXAMPLES_DIR = os.path.dirname(os.path.dirname(__file__)) +SIMPLE_SERVER_PATH = os.path.join(EXAMPLES_DIR, "mcp_servers", "simple") + +SIMPLE_SERVER_COMMAND = [ + "uv", + "run", + "--directory", + SIMPLE_SERVER_PATH, + "simple", +] + +# Baseline dict tool (for comparison) +DICT_SEARCH: MCPToolDefinition = { + "name": "search", + "description": "Search for information", + "inputSchema": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"}, + }, + "required": ["query"], + }, +} + +default_rubric = EvalRubric( + fail_threshold=0.6, + warn_threshold=0.8, + fail_on_tool_selection=False, +) + + +# ============================================================================= +# EVAL SUITE +# ============================================================================= + + +@tool_eval() +async def eval_comprehensive_comparison() -> EvalSuite: + """Compare tool performance across multiple sources.""" + suite = EvalSuite( + name="Multi-Source Comparative Evaluation", + system_message="You are a helpful assistant with various tools available.", + rubric=default_rubric, + ) + + loaded_tracks: list[str] = [] + + # Always add baseline dict tools + suite.add_tool_definitions([DICT_SEARCH], track="dict_baseline") + loaded_tracks.append("dict_baseline") + + print("\n Loading tool sources...") + + # Load from Arcade Gateway + try: + print(" β†’ Loading Arcade Gateway (Math)...") + await asyncio.wait_for( + suite.add_arcade_gateway( + gateway_slug="Math", + arcade_api_key=ARCADE_API_KEY, + arcade_user_id=ARCADE_USER_ID, + track="arcade_gateway", + ), + timeout=10.0, + ) + loaded_tracks.append("arcade_gateway") + print(" βœ“ Arcade Gateway") + except asyncio.TimeoutError: + print(" βœ— Arcade Gateway - timeout") + except Exception as e: + print(f" βœ— Arcade Gateway - {type(e).__name__}: {e}") + + # Load from stdio MCP server + try: + print(" β†’ Loading stdio MCP server (simple)...") + await asyncio.wait_for( + suite.add_mcp_stdio_server( + command=SIMPLE_SERVER_COMMAND, + env={"PYTHONUNBUFFERED": "1"}, + track="stdio_simple", + ), + timeout=15.0, + ) + loaded_tracks.append("stdio_simple") + print(" βœ“ Stdio MCP server") + except asyncio.TimeoutError: + print(" βœ— Stdio MCP server - timeout") + except Exception as e: + print(f" βœ— Stdio MCP server - {type(e).__name__}: {e}") + + print(f"\n Loaded tracks: {loaded_tracks}\n") + + # ========================================================================= + # TEST CASE 1: Math operation (Arcade Gateway vs baseline) + # ========================================================================= + + if "arcade_gateway" 
in loaded_tracks: + case1 = suite.add_comparative_case( + name="Math addition - Gateway vs Baseline", + user_message="What is 15 plus 27?", + ) + case1.for_track( + "arcade_gateway", + expected_tool_calls=[ + ExpectedMCPToolCall( + tool_name="Math_Add", + args={"a": 15, "b": 27}, + ) + ], + critics=[ + BinaryCritic(critic_field="a", weight=0.5), + BinaryCritic(critic_field="b", weight=0.5), + ], + ) + case1.for_track( + "dict_baseline", + expected_tool_calls=[ + ExpectedMCPToolCall( + tool_name="search", + args={"query": "15 plus 27"}, + ) + ], + critics=[SimilarityCritic(critic_field="query", weight=1.0, similarity_threshold=0.3)], + ) + + # ========================================================================= + # TEST CASE 2: Echo operation (stdio vs baseline) + # ========================================================================= + + if "stdio_simple" in loaded_tracks: + case2 = suite.add_comparative_case( + name="Echo message - Stdio vs Baseline", + user_message="Echo 'Hello World'", + ) + case2.for_track( + "stdio_simple", + expected_tool_calls=[ + ExpectedMCPToolCall( + tool_name="echo", + args={"message": "Hello World"}, + ) + ], + critics=[ + BinaryCritic(critic_field="message", weight=1.0), + ], + ) + case2.for_track( + "dict_baseline", + expected_tool_calls=[ + ExpectedMCPToolCall( + tool_name="search", + args={"query": "Hello World"}, + ) + ], + critics=[SimilarityCritic(critic_field="query", weight=1.0, similarity_threshold=0.5)], + ) + + # ========================================================================= + # TEST CASE 3: Conversational context + # ========================================================================= + + if "arcade_gateway" in loaded_tracks: + case3 = suite.add_comparative_case( + name="Math with context", + user_message="Now add them together", + additional_messages=[ + {"role": "user", "content": "I have two numbers: 50 and 25"}, + {"role": "assistant", "content": "I'll remember those numbers."}, + ], + ) + case3.for_track( + "arcade_gateway", + expected_tool_calls=[ + ExpectedMCPToolCall( + tool_name="Math_Add", + args={"a": 50, "b": 25}, + ) + ], + critics=[ + BinaryCritic(critic_field="a", weight=0.5), + BinaryCritic(critic_field="b", weight=0.5), + ], + ) + case3.for_track( + "dict_baseline", + expected_tool_calls=[ + ExpectedMCPToolCall( + tool_name="search", + args={"query": "50 plus 25"}, + ) + ], + critics=[SimilarityCritic(critic_field="query", weight=1.0, similarity_threshold=0.3)], + ) + + return suite diff --git a/examples/evals/eval_http_mcp_server.py b/examples/evals/eval_http_mcp_server.py new file mode 100644 index 000000000..b70952cc1 --- /dev/null +++ b/examples/evals/eval_http_mcp_server.py @@ -0,0 +1,142 @@ +"""Remote HTTP/SSE MCP server evaluation. + +This example demonstrates loading and evaluating tools from remote MCP servers +accessible via HTTP or Server-Sent Events (SSE). + +NOTE: This requires a running HTTP MCP server. Update the configuration below +with your server details. 
+ +Run: + arcade evals examples/evals/eval_http_mcp_server.py \\ + -p openai:gpt-4o \\ + -k openai:YOUR_KEY \\ + -o results.html -d +""" + +import asyncio +import os + +from arcade_evals import ( + BinaryCritic, + EvalRubric, + EvalSuite, + ExpectedMCPToolCall, + tool_eval, +) + +# ============================================================================= +# CONFIGURATION - Update these for your HTTP MCP server +# ============================================================================= + +# Example: GitHub Copilot MCP (requires GitHub token) +HTTP_MCP_URL = os.environ.get("MCP_SERVER_URL", "https://api.githubcopilot.com/mcp/") +HTTP_MCP_TOKEN = os.environ.get("GITHUB_PAT", "YOUR_GITHUB_TOKEN_HERE") + +# Example: SSE-based MCP server +SSE_MCP_URL = os.environ.get("SSE_MCP_URL", "https://mcp.example.com/sse") + +default_rubric = EvalRubric( + fail_threshold=0.7, + warn_threshold=0.9, +) + + +# ============================================================================= +# EVAL SUITE - HTTP MCP Server +# ============================================================================= + + +@tool_eval() +async def eval_http_mcp_server() -> EvalSuite: + """Evaluate tools from HTTP MCP server.""" + suite = EvalSuite( + name="HTTP MCP Server Evaluation", + system_message="You are a helpful assistant with access to remote tools.", + rubric=default_rubric, + ) + + print("\n Loading HTTP MCP server...") + + try: + await asyncio.wait_for( + suite.add_mcp_server( + url=HTTP_MCP_URL, + headers={"Authorization": f"Bearer {HTTP_MCP_TOKEN}"}, + use_sse=False, # Use HTTP streaming + ), + timeout=15.0, + ) + print(" βœ“ HTTP MCP server") + except asyncio.TimeoutError: + print(" βœ— HTTP MCP server - timeout") + return suite + except Exception as e: + print(f" βœ— HTTP MCP server - {type(e).__name__}: {e}") + return suite + + # Add test cases based on your server's tools + # Example: If your server has an echo tool + suite.add_case( + name="HTTP server tool call", + user_message="Echo 'Hello from HTTP'", + expected_tool_calls=[ + ExpectedMCPToolCall( + tool_name="echo", # Adjust to match your server's tool names + args={"message": "Hello from HTTP"}, + ) + ], + critics=[ + BinaryCritic(critic_field="message", weight=1.0), + ], + ) + + return suite + + +# ============================================================================= +# EVAL SUITE - SSE MCP Server +# ============================================================================= + + +@tool_eval() +async def eval_sse_mcp_server() -> EvalSuite: + """Evaluate tools from SSE MCP server.""" + suite = EvalSuite( + name="SSE MCP Server Evaluation", + system_message="You are a helpful assistant with access to SSE-connected tools.", + rubric=default_rubric, + ) + + print("\n Loading SSE MCP server...") + + try: + await asyncio.wait_for( + suite.add_mcp_server( + url=SSE_MCP_URL, + use_sse=True, # Use SSE transport + headers={"Accept": "text/event-stream"}, + ), + timeout=15.0, + ) + print(" βœ“ SSE MCP server") + except asyncio.TimeoutError: + print(" βœ— SSE MCP server - timeout") + return suite + except Exception as e: + print(f" βœ— SSE MCP server - {type(e).__name__}: {e}") + return suite + + # Add test cases for your SSE server's tools + suite.add_case( + name="SSE server tool call", + user_message="Get status", + expected_tool_calls=[ + ExpectedMCPToolCall( + tool_name="get_status", # Adjust to match your server's tools + args={}, + ) + ], + critics=[], + ) + + return suite diff --git a/examples/evals/eval_stdio_mcp_server.py 
b/examples/evals/eval_stdio_mcp_server.py new file mode 100644 index 000000000..30ad12c81 --- /dev/null +++ b/examples/evals/eval_stdio_mcp_server.py @@ -0,0 +1,124 @@ +"""Local stdio MCP server evaluation. + +This example demonstrates loading and evaluating tools from a local MCP server +running as a subprocess via stdio (standard input/output). + +Run: + arcade evals examples/evals/eval_stdio_mcp_server.py \\ + -p openai:gpt-4o \\ + -k openai:YOUR_KEY \\ + -o results.html -d +""" + +import asyncio +import os + +from arcade_evals import ( + BinaryCritic, + EvalRubric, + EvalSuite, + ExpectedMCPToolCall, + tool_eval, +) + +# ============================================================================= +# CONFIGURATION +# ============================================================================= + +# Path to the simple echo server +EXAMPLES_DIR = os.path.dirname(os.path.dirname(__file__)) +SIMPLE_SERVER_PATH = os.path.join(EXAMPLES_DIR, "mcp_servers", "simple") + +# Stdio server command +SIMPLE_SERVER_COMMAND = [ + "uv", + "run", + "--directory", + SIMPLE_SERVER_PATH, + "simple", +] + +default_rubric = EvalRubric( + fail_threshold=0.7, + warn_threshold=0.9, +) + + +# ============================================================================= +# EVAL SUITE +# ============================================================================= + + +@tool_eval() +async def eval_stdio_simple_server() -> EvalSuite: + """Evaluate simple echo server via stdio.""" + suite = EvalSuite( + name="Stdio MCP Server - Simple Echo", + system_message="You are a helpful assistant that can echo messages.", + rubric=default_rubric, + ) + + print("\n Loading stdio MCP server (simple)...") + + try: + await asyncio.wait_for( + suite.add_mcp_stdio_server( + command=SIMPLE_SERVER_COMMAND, + env={"PYTHONUNBUFFERED": "1"}, + ), + timeout=15.0, + ) + print(" βœ“ Simple MCP server (stdio)") + except asyncio.TimeoutError: + print(" βœ— Simple MCP server (stdio) - timeout") + return suite + except Exception as e: + print(f" βœ— Simple MCP server (stdio) - {type(e).__name__}: {e}") + return suite + + # Test Case 1: Simple echo + suite.add_case( + name="Echo - Hello", + user_message="Echo the word 'Hello'", + expected_tool_calls=[ + ExpectedMCPToolCall( + tool_name="echo", + args={"message": "Hello"}, + ) + ], + critics=[ + BinaryCritic(critic_field="message", weight=1.0), + ], + ) + + # Test Case 2: Echo with punctuation + suite.add_case( + name="Echo - Hello, World!", + user_message="Echo this: Hello, World!", + expected_tool_calls=[ + ExpectedMCPToolCall( + tool_name="echo", + args={"message": "Hello, World!"}, + ) + ], + critics=[ + BinaryCritic(critic_field="message", weight=1.0), + ], + ) + + # Test Case 3: Echo longer phrase + suite.add_case( + name="Echo - Longer phrase", + user_message="Please echo: The quick brown fox", + expected_tool_calls=[ + ExpectedMCPToolCall( + tool_name="echo", + args={"message": "The quick brown fox"}, + ) + ], + critics=[ + BinaryCritic(critic_field="message", weight=1.0), + ], + ) + + return suite diff --git a/libs/arcade-cli/arcade_cli/display.py b/libs/arcade-cli/arcade_cli/display.py index fae74b5e2..bfc760fe1 100644 --- a/libs/arcade-cli/arcade_cli/display.py +++ b/libs/arcade-cli/arcade_cli/display.py @@ -1,4 +1,5 @@ -from typing import TYPE_CHECKING, Any +from pathlib import Path +from typing import TYPE_CHECKING, Any, Optional from arcade_core.schema import ToolDefinition from rich.console import Console @@ -323,14 +324,14 @@ def display_tool_messages(tool_messages: list[dict]) -> 
None: ) -def display_eval_results(results: list[list[dict[str, Any]]], show_details: bool = False) -> None: - """ - Display evaluation results in a format inspired by pytest's output. - - Args: - results: List of dictionaries containing evaluation results for each model. - show_details: Whether to show detailed results for each case. - """ +def _display_results_to_console( + output_console: Console, + results: list[list[dict[str, Any]]], + show_details: bool = False, + failed_only: bool = False, + original_counts: Optional[tuple[int, int, int, int]] = None, +) -> None: + """Display evaluation results to a Rich console.""" total_passed = 0 total_failed = 0 total_warned = 0 @@ -343,9 +344,9 @@ def display_eval_results(results: list[list[dict[str, Any]]], show_details: bool cases = model_results.get("cases", []) total_cases += len(cases) - console.print(f"[bold]Model:[/bold] [bold magenta]{model}[/bold magenta]") + output_console.print(f"[bold]Model:[/bold] [bold magenta]{model}[/bold magenta]") if show_details: - console.print(f"[bold magenta]{rubric}[/bold magenta]") + output_console.print(f"[bold magenta]{rubric}[/bold magenta]") for case in cases: evaluation = case["evaluation"] @@ -365,24 +366,123 @@ def display_eval_results(results: list[list[dict[str, Any]]], show_details: bool # Display one-line summary for each case with score as a percentage score_percentage = evaluation.score * 100 - console.print(f"{status} {case['name']} -- Score: {score_percentage:.2f}%") + output_console.print(f"{status} {case['name']} -- Score: {score_percentage:.2f}%") if show_details: # Show detailed information for each case - console.print(f"[bold]User Input:[/bold] {case['input']}\n") - console.print("[bold]Details:[/bold]") - console.print(_format_evaluation(evaluation)) - console.print("-" * 80) - - # Summary - summary = ( - f"[bold]Summary -- [/bold]Total: {total_cases} -- [green]Passed: {total_passed}[/green]" - ) - if total_warned > 0: - summary += f" -- [yellow]Warnings: {total_warned}[/yellow]" - if total_failed > 0: - summary += f" -- [red]Failed: {total_failed}[/red]" - console.print(summary + "\n") + output_console.print(f"[bold]User Input:[/bold] {case['input']}\n") + output_console.print("[bold]Details:[/bold]") + output_console.print(_format_evaluation(evaluation)) + output_console.print("-" * 80) + + output_console.print("") + + # Summary - use original counts if filtering, otherwise use current counts + if failed_only and original_counts: + # Unpack original counts + orig_total, orig_passed, orig_failed, orig_warned = original_counts + + # Show disclaimer before summary + output_console.print( + f"[bold yellow]Note: Showing only {total_cases} failed evaluation(s) (--only-failed)[/bold yellow]" + ) + + # Build summary with original counts + summary = ( + f"[bold]Summary -- [/bold]Total: {orig_total} -- [green]Passed: {orig_passed}[/green]" + ) + if orig_warned > 0: + summary += f" -- [yellow]Warnings: {orig_warned}[/yellow]" + if orig_failed > 0: + summary += f" -- [red]Failed: {orig_failed}[/red]" + else: + # Normal summary with current counts + summary = ( + f"[bold]Summary -- [/bold]Total: {total_cases} -- [green]Passed: {total_passed}[/green]" + ) + if total_warned > 0: + summary += f" -- [yellow]Warnings: {total_warned}[/yellow]" + if total_failed > 0: + summary += f" -- [red]Failed: {total_failed}[/red]" + + output_console.print(summary + "\n") + + +def display_eval_results( + results: list[list[dict[str, Any]]], + show_details: bool = False, + output_file: Optional[str] = None, + 
failed_only: bool = False, + original_counts: Optional[tuple[int, int, int, int]] = None, + output_formats: list[str] | None = None, + include_context: bool = False, +) -> None: + """ + Display evaluation results in a format inspired by pytest's output. + + Args: + results: List of dictionaries containing evaluation results for each model. + show_details: Whether to show detailed results for each case. + output_file: Optional file path to write results to. + failed_only: Whether only failed cases are being displayed (adds disclaimer). + original_counts: Optional tuple of (total_cases, total_passed, total_failed, total_warned) + from before filtering. Used when failed_only is True. + output_formats: List of output formats for file output (e.g., ['txt', 'md', 'html']). + include_context: Whether to include system_message and additional_messages. + """ + # Always display to terminal with Rich formatting + try: + _display_results_to_console(console, results, show_details, failed_only, original_counts) + except Exception as e: + console.print(f"[red]Error displaying results to console: {type(e).__name__}: {e}[/red]") + + # Also write to file(s) if requested using the specified formatter(s) + if output_file and output_formats: + from arcade_cli.formatters import get_formatter + + # Get base path without extension + base_path = Path(output_file) + base_name = base_path.stem + parent_dir = base_path.parent + + try: + parent_dir.mkdir(parents=True, exist_ok=True) + except PermissionError: + console.print(f"[red]Error: Permission denied creating directory {parent_dir}[/red]") + return + except OSError as e: + console.print(f"[red]Error creating directory: {e}[/red]") + return + + for fmt in output_formats: + # Define output_path early so it's available in exception handlers + output_path = parent_dir / f"{base_name}.{fmt}" + try: + formatter = get_formatter(fmt) + formatted_output = formatter.format( + results, + show_details=show_details, + failed_only=failed_only, + original_counts=original_counts, + include_context=include_context, + ) + + # Build output path with proper extension + output_path = parent_dir / f"{base_name}.{formatter.file_extension}" + + with open(output_path, "w", encoding="utf-8") as f: + f.write(formatted_output) + + console.print(f"[green]βœ“ Results written to {output_path}[/green]") + + except PermissionError: + console.print(f"[red]Error: Permission denied writing to {output_path}[/red]") + except OSError as e: + console.print(f"[red]Error writing file: {e}[/red]") + except Exception as e: + console.print( + f"[red]Error formatting results ({fmt}): {type(e).__name__}: {e}[/red]" + ) def _format_evaluation(evaluation: "EvaluationResult") -> str: diff --git a/libs/arcade-cli/arcade_cli/evals_runner.py b/libs/arcade-cli/arcade_cli/evals_runner.py new file mode 100644 index 000000000..4c16d5ec1 --- /dev/null +++ b/libs/arcade-cli/arcade_cli/evals_runner.py @@ -0,0 +1,515 @@ +""" +Evaluation and capture mode execution logic for the CLI. + +This module contains the async execution functions for running evaluations +and capture mode operations. 
+""" + +from __future__ import annotations + +import asyncio +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable + +from rich.console import Console +from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn +from rich.text import Text + +from arcade_cli.display import display_eval_results +from arcade_cli.formatters import get_capture_formatter +from arcade_cli.utils import ModelSpec, filter_failed_evaluations + +if TYPE_CHECKING: + from arcade_evals import CaptureResult + +logger = logging.getLogger(__name__) + + +# All supported output formats +ALL_FORMATS = ["txt", "md", "html", "json"] + + +def parse_output_formats(format_str: str, console: Console | None = None) -> list[str]: + """ + Parse output format string into a list of formats. + + Supports: + - Single format: "md" -> ["md"] + - Comma-separated: "md,html" -> ["md", "html"] + - "all" keyword: "all" -> ["txt", "md", "html", "json"] + + Args: + format_str: The format string from CLI. + console: Optional Rich console for error messages (unused now - raises instead). + + Returns: + List of valid format strings. + + Raises: + ValueError: If any invalid formats are provided. + """ + if format_str.lower() == "all": + return ALL_FORMATS.copy() + + formats = [f.strip().lower() for f in format_str.split(",")] + valid_formats = [f for f in formats if f in ALL_FORMATS] + invalid_formats = [f for f in formats if f and f not in ALL_FORMATS] + + # Fail fast on invalid formats (parse-time validation) + if invalid_formats: + valid_list = ", ".join(ALL_FORMATS) + raise ValueError( + f"Invalid format(s): {', '.join(invalid_formats)}. Valid formats: {valid_list}" + ) + + return valid_formats + + +# --- Result Types for Error Handling --- + + +@dataclass +class EvalTaskResult: + """Result of running a single evaluation task.""" + + suite_name: str + model: str + provider: str + success: bool + result: Any | None = None # EvalResult on success + error: str | None = None + error_type: str | None = None + + @property + def display_name(self) -> str: + """Get display name in format 'provider/model'.""" + return f"{self.provider}/{self.model}" + + @classmethod + def from_success( + cls, suite_name: str, model: str, provider: str, result: Any + ) -> EvalTaskResult: + """Create a successful result.""" + return cls( + suite_name=suite_name, model=model, provider=provider, success=True, result=result + ) + + @classmethod + def from_error( + cls, suite_name: str, model: str, provider: str, error: Exception + ) -> EvalTaskResult: + """Create a failed result from an exception.""" + return cls( + suite_name=suite_name, + model=model, + provider=provider, + success=False, + error=str(error), + error_type=type(error).__name__, + ) + + +@dataclass +class CaptureTaskResult: + """Result of running a single capture task.""" + + suite_name: str + model: str + provider: str + success: bool + result: list[CaptureResult] | None = None # List of CaptureResult on success + error: str | None = None + error_type: str | None = None + + @property + def display_name(self) -> str: + """Get display name in format 'provider/model'.""" + return f"{self.provider}/{self.model}" + + @classmethod + def from_success( + cls, suite_name: str, model: str, provider: str, result: list[CaptureResult] + ) -> CaptureTaskResult: + """Create a successful result.""" + return cls( + suite_name=suite_name, model=model, provider=provider, success=True, result=result + ) + + @classmethod + def 
from_error( + cls, suite_name: str, model: str, provider: str, error: Exception + ) -> CaptureTaskResult: + """Create a failed result from an exception.""" + return cls( + suite_name=suite_name, + model=model, + provider=provider, + success=False, + error=str(error), + error_type=type(error).__name__, + ) + + +# --- Task Wrappers with Error Handling --- + + +async def _run_eval_task( + suite_func: Callable[..., Any], + model_spec: ModelSpec, + max_concurrent: int, + include_context: bool = False, +) -> EvalTaskResult: + """ + Run a single evaluation task with error handling. + + Returns EvalTaskResult with success/failure info instead of raising. + """ + suite_name = suite_func.__name__ + + try: + result = await suite_func( + provider_api_key=model_spec.api_key, + model=model_spec.model, + max_concurrency=max_concurrent, + provider=model_spec.provider.value, + include_context=include_context, + ) + return EvalTaskResult.from_success( + suite_name, model_spec.model, model_spec.provider.value, result + ) + + except Exception as e: + logger.warning( + "Evaluation task failed: suite=%s, model=%s, provider=%s, error=%s: %s", + suite_name, + model_spec.model, + model_spec.provider.value, + type(e).__name__, + str(e), + exc_info=True, # Include full traceback for debugging + ) + return EvalTaskResult.from_error(suite_name, model_spec.model, model_spec.provider.value, e) + + +async def _run_capture_task( + suite_func: Callable[..., Any], + model_spec: ModelSpec, + max_concurrent: int, + include_context: bool, +) -> CaptureTaskResult: + """ + Run a single capture task with error handling. + + Returns CaptureTaskResult with success/failure info instead of raising. + """ + suite_name = suite_func.__name__ + + try: + result = await suite_func( + provider_api_key=model_spec.api_key, + model=model_spec.model, + max_concurrency=max_concurrent, + provider=model_spec.provider.value, + capture_mode=True, + include_context=include_context, + ) + return CaptureTaskResult.from_success( + suite_name, model_spec.model, model_spec.provider.value, result + ) + + except Exception as e: + logger.warning( + "Capture task failed: suite=%s, model=%s, provider=%s, error=%s: %s", + suite_name, + model_spec.model, + model_spec.provider.value, + type(e).__name__, + str(e), + exc_info=True, # Include full traceback for debugging + ) + return CaptureTaskResult.from_error( + suite_name, model_spec.model, model_spec.provider.value, e + ) + + +# --- Main Runner Functions --- + + +async def run_evaluations( + eval_suites: list[Callable[..., Any]], + model_specs: list[ModelSpec], + max_concurrent: int, + show_details: bool, + output_file: str | None, + output_format: str, + failed_only: bool, + console: Console, + include_context: bool = False, +) -> None: + """ + Run evaluation suites and display results. + + Individual task failures are caught and reported without crashing the entire batch. + + Args: + eval_suites: List of decorated evaluation suite functions. + model_specs: List of ModelSpec objects containing provider, model, and API key. + max_concurrent: Maximum concurrent evaluations. + show_details: Whether to show detailed results. + output_file: Optional file path to write results. + output_format: Format for file output ('txt', 'md'). + failed_only: Whether to show only failed evaluations. + console: Rich console for output. + include_context: Whether to include system_message and additional_messages. 
+ """ + tasks = [] + + for suite_func in eval_suites: + console.print( + Text.assemble( + ("Running evaluations in ", "bold"), + (suite_func.__name__, "bold blue"), + ) + ) + for model_spec in model_specs: + task = asyncio.create_task( + _run_eval_task( + suite_func=suite_func, + model_spec=model_spec, + max_concurrent=max_concurrent, + include_context=include_context, + ) + ) + tasks.append(task) + + # Track progress with Rich progress bar (compatible with Rich console) + # Note: task_results is collected synchronously as each async task completes. + # The append() is atomic in CPython due to the GIL, and we await each future + # sequentially within the for-loop, so this is safe. + task_results: list[EvalTaskResult] = [] + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + console=console, + transient=False, + ) as progress: + task_id = progress.add_task("[cyan]Running evaluations...", total=len(tasks)) + for f in asyncio.as_completed(tasks): + result = await f + task_results.append(result) + # Update progress with completed task info + progress.update( + task_id, + advance=1, + description=f"[cyan]Completed: {result.suite_name} ({result.display_name})", + ) + + # Separate successes and failures + successful = [r for r in task_results if r.success] + failed = [r for r in task_results if not r.success] + + # Report failures + if failed: + console.print(f"\n[bold yellow]⚠️ {len(failed)} evaluation(s) failed:[/bold yellow]") + for fail in failed: + console.print( + f" β€’ {fail.suite_name} ({fail.display_name}): [red]{fail.error_type}[/red] - {fail.error}" + ) + + # Process successful results + # Normalize results structure: ensure each result is a list (for consistent formatting) + # - Regular evals return a single dict -> wrap in list + # - Comparative evals return a list of dicts -> keep as is + all_evaluations: list[list[dict[str, Any]]] = [] + for r in successful: + if r.result is None: + continue + if isinstance(r.result, list): + # Comparative eval: already a list of results (one per track) + all_evaluations.append(r.result) + else: + # Regular eval: single dict, wrap in list for consistent structure + all_evaluations.append([r.result]) + + if not all_evaluations: + console.print("\n[bold red]❌ No evaluations completed successfully.[/bold red]") + return + + # Filter to show only failed evaluations if requested + original_counts = None + if failed_only: + all_evaluations, original_counts = filter_failed_evaluations(all_evaluations) + + # Parse output_format as a list (handles comma-separated and "all") + output_formats = parse_output_formats(output_format, console) + + display_eval_results( + all_evaluations, + show_details=show_details, + output_file=output_file, + failed_only=failed_only, + original_counts=original_counts, + output_formats=output_formats, + include_context=include_context, + ) + + # Summary when there were failures + if failed: + console.print(f"\n[bold]Summary:[/bold] {len(successful)} succeeded, {len(failed)} failed") + + +async def run_capture( + eval_suites: list[Callable[..., Any]], + model_specs: list[ModelSpec], + max_concurrent: int, + include_context: bool, + output_file: str | None, + output_format: str, + console: Console, +) -> None: + """ + Run evaluation suites in capture mode and output results. + + Capture mode records tool calls without scoring them. + Individual task failures are caught and reported without crashing the entire batch. 
+ + Args: + eval_suites: List of decorated evaluation suite functions. + model_specs: List of ModelSpec objects containing provider, model, and API key. + max_concurrent: Maximum concurrent operations. + include_context: Whether to include system_message and additional_messages. + output_file: Optional file path to write results. + output_format: Output format ('json', 'txt', 'md', 'html'). + console: Rich console for output. + """ + tasks = [] + + for suite_func in eval_suites: + console.print( + Text.assemble( + ("Capturing tool calls from ", "bold"), + (suite_func.__name__, "bold cyan"), + ) + ) + for model_spec in model_specs: + task = asyncio.create_task( + _run_capture_task( + suite_func=suite_func, + model_spec=model_spec, + max_concurrent=max_concurrent, + include_context=include_context, + ) + ) + tasks.append(task) + + # Track progress with Rich progress bar (compatible with Rich console) + # Note: task_results is collected synchronously as each async task completes. + # The append() is atomic in CPython due to the GIL, and we await each future + # sequentially within the for-loop, so this is safe. + task_results: list[CaptureTaskResult] = [] + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + console=console, + transient=False, + ) as progress: + task_id = progress.add_task("[cyan]Capturing tool calls...", total=len(tasks)) + for f in asyncio.as_completed(tasks): + result = await f + task_results.append(result) + # Update progress with completed task info + progress.update( + task_id, + advance=1, + description=f"[cyan]Completed: {result.suite_name} ({result.display_name})", + ) + + # Separate successes and failures + successful = [r for r in task_results if r.success] + failed = [r for r in task_results if not r.success] + + # Report failures + if failed: + console.print(f"\n[bold yellow]⚠️ {len(failed)} capture(s) failed:[/bold yellow]") + for fail in failed: + console.print( + f" β€’ {fail.suite_name} ({fail.display_name}): [red]{fail.error_type}[/red] - {fail.error}" + ) + + # Collect successful captures + all_captures: list[CaptureResult] = [] + for r in successful: + if r.result is not None: + all_captures.extend(r.result) + + if not all_captures: + console.print("\n[bold red]❌ No captures completed successfully.[/bold red]") + return + + # Parse output formats (handles comma-separated and "all") + output_formats = parse_output_formats(output_format, console) + + # Output to file(s) or console + if output_file: + # Get base path without extension + base_path = Path(output_file) + base_name = base_path.stem + parent_dir = base_path.parent + + try: + parent_dir.mkdir(parents=True, exist_ok=True) + except PermissionError: + console.print( + f"\n[red]❌ Error: Permission denied creating directory {parent_dir}[/red]" + ) + return + except OSError as e: + console.print(f"\n[red]❌ Error creating directory: {e}[/red]") + return + + for fmt in output_formats: + # Define file_path early so it's available in exception handlers + file_path = parent_dir / f"{base_name}.{fmt}" + try: + formatter = get_capture_formatter(fmt) + formatted_output = formatter.format(all_captures, include_context=include_context) + + # Build output path with proper extension + file_path = parent_dir / f"{base_name}.{formatter.file_extension}" + + with open(file_path, "w", encoding="utf-8") as outfile: + outfile.write(formatted_output) + console.print( + f"\n[green]βœ“ Capture results written to[/green] [bold]{file_path}[/bold]" 
+ ) + + except ValueError as e: + console.print(f"\n[red]❌ {e}[/red]") + except PermissionError: + console.print(f"\n[red]❌ Error: Permission denied writing to {file_path}[/red]") + except OSError as e: + console.print(f"\n[red]❌ Error writing file: {e}[/red]") + else: + # Console output: always use JSON for best copy-paste experience + console.print("\n[bold]Capture Results:[/bold]") + json_formatter = get_capture_formatter("json") + console.print(json_formatter.format(all_captures, include_context=include_context)) + + # Summary + total_cases = sum(len(cap.captured_cases) for cap in all_captures) + total_calls = sum( + sum(len(case.tool_calls) for case in cap.captured_cases) for cap in all_captures + ) + console.print( + f"\n[bold green]Captured {total_calls} tool calls across {total_cases} cases[/bold green]" + ) + + # Summary when there were failures + if failed: + console.print(f"\n[bold]Summary:[/bold] {len(successful)} succeeded, {len(failed)} failed") diff --git a/libs/arcade-cli/arcade_cli/formatters/__init__.py b/libs/arcade-cli/arcade_cli/formatters/__init__.py new file mode 100644 index 000000000..6b23329a3 --- /dev/null +++ b/libs/arcade-cli/arcade_cli/formatters/__init__.py @@ -0,0 +1,102 @@ +"""Formatters for evaluation and capture results output.""" + +from difflib import get_close_matches + +from arcade_cli.formatters.base import CaptureFormatter, EvalResultFormatter +from arcade_cli.formatters.html import CaptureHtmlFormatter, HtmlFormatter +from arcade_cli.formatters.json import CaptureJsonFormatter, JsonFormatter +from arcade_cli.formatters.markdown import CaptureMarkdownFormatter, MarkdownFormatter +from arcade_cli.formatters.text import CaptureTextFormatter, TextFormatter + +# Registry of available formatters for evaluations +FORMATTERS: dict[str, type[EvalResultFormatter]] = { + "txt": TextFormatter, + "md": MarkdownFormatter, + "html": HtmlFormatter, + "json": JsonFormatter, +} + +# Registry of available formatters for capture mode +CAPTURE_FORMATTERS: dict[str, type[CaptureFormatter]] = { + "json": CaptureJsonFormatter, + "txt": CaptureTextFormatter, + "md": CaptureMarkdownFormatter, + "html": CaptureHtmlFormatter, +} + + +def get_formatter(format_name: str) -> EvalResultFormatter: + """ + Get a formatter instance by name. + + Args: + format_name: The format name (e.g., 'txt', 'md'). + + Returns: + An instance of the appropriate formatter. + + Raises: + ValueError: If the format is not supported. Suggests similar format names if available. + """ + formatter_class = FORMATTERS.get(format_name.lower()) + if formatter_class is None: + supported = list(FORMATTERS.keys()) + + # Try to find a close match for better error messages + close_matches = get_close_matches(format_name.lower(), supported, n=1, cutoff=0.6) + + error_msg = f"Unsupported format '{format_name}'." + if close_matches: + error_msg += f" Did you mean '{close_matches[0]}'?" + error_msg += f" Supported formats: {', '.join(supported)}" + + raise ValueError(error_msg) + return formatter_class() + + +def get_capture_formatter(format_name: str) -> CaptureFormatter: + """ + Get a capture formatter instance by name. + + Args: + format_name: The format name (e.g., 'json', 'txt', 'md', 'html'). + + Returns: + An instance of the appropriate formatter. + + Raises: + ValueError: If the format is not supported. Suggests similar format names if available. 
+ """ + formatter_class = CAPTURE_FORMATTERS.get(format_name.lower()) + if formatter_class is None: + supported = list(CAPTURE_FORMATTERS.keys()) + + close_matches = get_close_matches(format_name.lower(), supported, n=1, cutoff=0.6) + + error_msg = f"Unsupported capture format '{format_name}'." + if close_matches: + error_msg += f" Did you mean '{close_matches[0]}'?" + error_msg += f" Supported formats: {', '.join(supported)}" + + raise ValueError(error_msg) + return formatter_class() + + +__all__ = [ + # Eval formatters + "FORMATTERS", + "EvalResultFormatter", + "HtmlFormatter", + "JsonFormatter", + "MarkdownFormatter", + "TextFormatter", + "get_formatter", + # Capture formatters + "CAPTURE_FORMATTERS", + "CaptureFormatter", + "CaptureHtmlFormatter", + "CaptureJsonFormatter", + "CaptureMarkdownFormatter", + "CaptureTextFormatter", + "get_capture_formatter", +] diff --git a/libs/arcade-cli/arcade_cli/formatters/base.py b/libs/arcade-cli/arcade_cli/formatters/base.py new file mode 100644 index 000000000..3b1d61661 --- /dev/null +++ b/libs/arcade-cli/arcade_cli/formatters/base.py @@ -0,0 +1,791 @@ +"""Base formatter for evaluation and capture results.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from arcade_evals import CaptureResult + +# Type alias for capture results +CaptureResults = list["CaptureResult"] + +# --- Type Aliases --- +# The results structure: list of suites, each containing list of model results +EvalResults = list[list[dict[str, Any]]] + +# Model -> Suite -> Cases mapping +ModelSuiteGroups = dict[str, dict[str, list[dict[str, Any]]]] + +# Statistics tuple: (total, passed, failed, warned) +EvalStats = tuple[int, int, int, int] + +# Comparative grouping: model -> base_suite -> case_name -> {input, tracks: {track: case_result}} +ComparativeCaseData = dict[str, Any] # {input, tracks: {track_name: case_result}} +ComparativeSuiteData = dict[str, ComparativeCaseData] # case_name -> ComparativeCaseData +ComparativeGroups = dict[str, dict[str, ComparativeSuiteData]] # model -> suite -> cases + +# --- Constants --- +# Maximum field value length before truncation (for display) +MAX_FIELD_DISPLAY_LENGTH = 60 +TRUNCATION_SUFFIX = "..." + + +def truncate_field_value(value: str, max_length: int = MAX_FIELD_DISPLAY_LENGTH) -> str: + """ + Truncate long field values for display. + + Args: + value: The string value to potentially truncate. + max_length: Maximum allowed length (default: 60). + + Returns: + The original value if within limits, or truncated with "..." suffix. + """ + if len(value) > max_length: + return value[: max_length - len(TRUNCATION_SUFFIX)] + TRUNCATION_SUFFIX + return value + + +def group_results_by_model( + results: EvalResults, +) -> tuple[ModelSuiteGroups, int, int, int, int]: + """ + Group evaluation results by model and suite, collecting statistics. + + This is the shared logic used by all formatters and display functions. + + Args: + results: Nested list of evaluation results by suite and model. 
+ + Returns: + A tuple of: + - model_groups: Dict mapping model -> suite -> list of cases + - total_passed: Count of passed evaluations + - total_failed: Count of failed evaluations + - total_warned: Count of warned evaluations + - total_cases: Total count of all cases + """ + total_passed = 0 + total_failed = 0 + total_warned = 0 + total_cases = 0 + model_groups: ModelSuiteGroups = {} + + for eval_suite in results: + for model_results in eval_suite: + model = model_results.get("model", "Unknown Model") + + # suite_name is always set by EvalSuite.evaluate() + suite_name = model_results.get("suite_name") or "Unnamed Suite" + + cases = model_results.get("cases", []) + total_cases += len(cases) + + if model not in model_groups: + model_groups[model] = {} + + if suite_name not in model_groups[model]: + model_groups[model][suite_name] = [] + + for case in cases: + evaluation = case["evaluation"] + if evaluation.passed: + total_passed += 1 + elif evaluation.warning: + total_warned += 1 + else: + total_failed += 1 + + model_groups[model][suite_name].append(case) + + return model_groups, total_passed, total_failed, total_warned, total_cases + + +def is_comparative_result(results: EvalResults) -> bool: + """ + Check if results contain comparative evaluations. + + Comparative results have a 'track_name' field that indicates they came + from a multi-track comparative evaluation. + + Args: + results: Nested list of evaluation results. + + Returns: + True if any result has a 'track_name' field. + """ + for eval_suite in results: + for model_results in eval_suite: + if model_results.get("track_name"): + return True + return False + + +def _extract_base_suite_name(suite_name: str, track_name: str) -> str: + """ + Extract the base suite name by removing the track suffix. + + Examples: + "My Suite [track_a]" with track "track_a" -> "My Suite" + "Suite Name [some_track]" with track "some_track" -> "Suite Name" + """ + suffix = f" [{track_name}]" + if suite_name.endswith(suffix): + return suite_name[: -len(suffix)] + return suite_name + + +def group_comparative_by_case( + results: EvalResults, +) -> tuple[ComparativeGroups, int, int, int, int, dict[str, list[str]]]: + """ + Group comparative results by model, suite, and case name. + + This allows showing the same case across different tracks side-by-side. + + Args: + results: Nested list of evaluation results (must be comparative). 
+ + Returns: + A tuple of: + - comparative_groups: {model: {base_suite: {case_name: {input, tracks: {track: result}}}}} + - total_passed: Count of passed evaluations + - total_failed: Count of failed evaluations + - total_warned: Count of warned evaluations + - total_cases: Total count of all cases (unique case_name * tracks) + - suite_track_order: Dict mapping base_suite -> list of track names for that suite + """ + total_passed = 0 + total_failed = 0 + total_warned = 0 + total_cases = 0 + + # Track order per suite (different suites can have different tracks) + suite_track_order: dict[str, list[str]] = {} + + # Structure: model -> base_suite -> case_name -> {input, tracks: {track: case_result}} + comparative_groups: ComparativeGroups = {} + + for eval_suite in results: + for model_results in eval_suite: + model = model_results.get("model", "Unknown Model") + suite_name = model_results.get("suite_name") or "Unnamed Suite" + track_name = model_results.get("track_name", "default") + + # Extract base suite name (without track suffix) + base_suite = _extract_base_suite_name(suite_name, track_name) + + # Track the order of tracks per suite + if base_suite not in suite_track_order: + suite_track_order[base_suite] = [] + if track_name not in suite_track_order[base_suite]: + suite_track_order[base_suite].append(track_name) + + cases = model_results.get("cases", []) + total_cases += len(cases) + + if model not in comparative_groups: + comparative_groups[model] = {} + + if base_suite not in comparative_groups[model]: + comparative_groups[model][base_suite] = {} + + for case in cases: + case_name = case["name"] + evaluation = case["evaluation"] + + # Count stats + if evaluation.passed: + total_passed += 1 + elif evaluation.warning: + total_warned += 1 + else: + total_failed += 1 + + # Initialize case entry if needed + if case_name not in comparative_groups[model][base_suite]: + comparative_groups[model][base_suite][case_name] = { + "input": case.get("input", ""), + "system_message": case.get("system_message"), + "additional_messages": case.get("additional_messages"), + "tracks": {}, + } + + # Store this track's result for this case + comparative_groups[model][base_suite][case_name]["tracks"][track_name] = { + "evaluation": evaluation, + "name": case_name, + "input": case.get("input", ""), + } + + return ( + comparative_groups, + total_passed, + total_failed, + total_warned, + total_cases, + suite_track_order, + ) + + +def compute_track_differences( + case_data: ComparativeCaseData, + track_order: list[str], +) -> dict[str, list[str]]: + """ + Compute which fields differ between tracks for a given case. + + Compares each track against the first track (baseline). + + Args: + case_data: The case data with tracks. + track_order: List of track names in order. + + Returns: + Dict mapping track_name -> list of field names that differ from baseline. 
+ """ + differences: dict[str, list[str]] = {} + tracks = case_data.get("tracks", {}) + + if len(tracks) < 2 or not track_order: + return differences + + # First track is baseline + baseline_track = track_order[0] + if baseline_track not in tracks: + return differences + + baseline_result = tracks[baseline_track] + baseline_eval = baseline_result.get("evaluation") + if not baseline_eval or not hasattr(baseline_eval, "results"): + return differences + + # Build baseline field values + baseline_fields: dict[str, Any] = {} + for critic_result in baseline_eval.results: + field = critic_result.get("field", "") + baseline_fields[field] = { + "actual": critic_result.get("actual"), + "match": critic_result.get("match"), + "score": critic_result.get("score"), + } + + # Compare other tracks to baseline + for track_name in track_order[1:]: + if track_name not in tracks: + continue + + track_result = tracks[track_name] + track_eval = track_result.get("evaluation") + if not track_eval or not hasattr(track_eval, "results"): + continue + + diff_fields: list[str] = [] + + for critic_result in track_eval.results: + field = critic_result.get("field", "") + actual = critic_result.get("actual") + match = critic_result.get("match") + + # Check if this field exists in baseline and differs + if field in baseline_fields: + baseline_data = baseline_fields[field] + # Different if actual value differs or match status differs + if actual != baseline_data["actual"] or match != baseline_data["match"]: + diff_fields.append(field) + else: + # Field exists in this track but not baseline + diff_fields.append(field) + + differences[track_name] = diff_fields + + return differences + + +# Type for case-first comparative grouping +# Structure: suite -> case_name -> model -> {input, tracks: {track: result}} +CaseFirstComparativeGroups = dict[str, dict[str, dict[str, dict[str, Any]]]] + + +def is_multi_model_comparative(results: EvalResults) -> bool: + """ + Check if comparative results contain multiple models. + + Args: + results: Nested list of evaluation results. + + Returns: + True if this is a comparative result with more than one unique model. + """ + if not is_comparative_result(results): + return False + + models: set[str] = set() + for eval_suite in results: + for model_results in eval_suite: + model = model_results.get("model", "Unknown") + models.add(model) + if len(models) > 1: + return True + return False + + +def group_comparative_by_case_first( + results: EvalResults, +) -> tuple[CaseFirstComparativeGroups, list[str], dict[str, list[str]], int, int, int, int]: + """ + Group comparative results by suite -> case -> model for case-first comparison. + + When multiple models run the same comparative evaluation, this groups results + so the same case from different models appears together. + + Args: + results: Nested list of comparative evaluation results. 
+ + Returns: + A tuple of: + - case_groups: {suite: {case_name: {model: {input, tracks: {track: result}}}}} + - model_order: List of model names in order of appearance + - suite_track_order: Dict mapping suite -> list of track names + - total_passed, total_failed, total_warned, total_cases + """ + total_passed = 0 + total_failed = 0 + total_warned = 0 + total_cases = 0 + + model_order: list[str] = [] + suite_track_order: dict[str, list[str]] = {} + + # Structure: base_suite -> case_name -> model -> {input, tracks: {track: result}} + case_groups: CaseFirstComparativeGroups = {} + + for eval_suite in results: + for model_results in eval_suite: + model = model_results.get("model", "Unknown Model") + suite_name = model_results.get("suite_name") or "Unnamed Suite" + track_name = model_results.get("track_name", "default") + + # Track model order + if model not in model_order: + model_order.append(model) + + # Extract base suite name (without track suffix) + base_suite = _extract_base_suite_name(suite_name, track_name) + + # Track the order of tracks per suite + if base_suite not in suite_track_order: + suite_track_order[base_suite] = [] + if track_name not in suite_track_order[base_suite]: + suite_track_order[base_suite].append(track_name) + + cases = model_results.get("cases", []) + total_cases += len(cases) + + # Initialize suite + if base_suite not in case_groups: + case_groups[base_suite] = {} + + for case in cases: + case_name = case["name"] + evaluation = case["evaluation"] + + # Count stats + if evaluation.passed: + total_passed += 1 + elif evaluation.warning: + total_warned += 1 + else: + total_failed += 1 + + # Initialize case + if case_name not in case_groups[base_suite]: + case_groups[base_suite][case_name] = {} + + # Initialize model entry for this case + if model not in case_groups[base_suite][case_name]: + case_groups[base_suite][case_name][model] = { + "input": case.get("input", ""), + "system_message": case.get("system_message"), + "additional_messages": case.get("additional_messages"), + "tracks": {}, + } + + # Store this track's result + case_groups[base_suite][case_name][model]["tracks"][track_name] = { + "evaluation": evaluation, + "name": case_name, + "input": case.get("input", ""), + } + + return ( + case_groups, + model_order, + suite_track_order, + total_passed, + total_failed, + total_warned, + total_cases, + ) + + +# ============================================================================= +# MULTI-MODEL HELPERS +# ============================================================================= + + +def is_multi_model_eval(results: EvalResults) -> bool: + """ + Check if evaluation results contain multiple models. + + Args: + results: Nested list of evaluation results. + + Returns: + True if more than one unique model is present. + """ + models: set[str] = set() + for eval_suite in results: + for model_results in eval_suite: + model = model_results.get("model", "Unknown") + models.add(model) + if len(models) > 1: + return True + return False + + +def is_multi_model_capture(captures: CaptureResults) -> bool: + """ + Check if capture results contain multiple models. + + Args: + captures: List of CaptureResult objects. + + Returns: + True if more than one unique model is present. 
+ """ + models = {c.model for c in captures} + return len(models) > 1 + + +# Type for multi-model comparison: suite -> case -> model -> case_result +MultiModelComparisonData = dict[str, dict[str, dict[str, dict[str, Any]]]] + +# Type for per-model stats: model -> {passed, failed, warned, total, pass_rate} +PerModelStats = dict[str, dict[str, Any]] + + +def group_eval_for_comparison( + results: EvalResults, +) -> tuple[MultiModelComparisonData, list[str], PerModelStats]: + """ + Reorganize evaluation results for cross-model comparison. + + Groups results by suite -> case -> model, enabling side-by-side tables. + + Args: + results: Nested list of evaluation results. + + Returns: + A tuple of: + - comparison_data: {suite: {case_name: {model: case_result}}} + - model_order: List of model names in order of appearance + - per_model_stats: {model: {passed, failed, warned, total, pass_rate}} + """ + comparison_data: MultiModelComparisonData = {} + model_order: list[str] = [] + per_model_stats: PerModelStats = {} + + for eval_suite in results: + for model_results in eval_suite: + model = model_results.get("model", "Unknown Model") + suite_name = model_results.get("suite_name") or "Unnamed Suite" + cases = model_results.get("cases", []) + + # Track model order + if model not in model_order: + model_order.append(model) + + # Initialize per-model stats + if model not in per_model_stats: + per_model_stats[model] = { + "passed": 0, + "failed": 0, + "warned": 0, + "total": 0, + } + + # Initialize suite in comparison data + if suite_name not in comparison_data: + comparison_data[suite_name] = {} + + for case in cases: + case_name = case["name"] + evaluation = case["evaluation"] + + # Update per-model stats + per_model_stats[model]["total"] += 1 + if evaluation.passed: + per_model_stats[model]["passed"] += 1 + elif evaluation.warning: + per_model_stats[model]["warned"] += 1 + else: + per_model_stats[model]["failed"] += 1 + + # Initialize case in suite + if case_name not in comparison_data[suite_name]: + comparison_data[suite_name][case_name] = {} + + # Store this model's result for this case + comparison_data[suite_name][case_name][model] = { + "evaluation": evaluation, + "input": case.get("input", ""), + "name": case_name, + } + + # Calculate pass rates + for _model, stats in per_model_stats.items(): + if stats["total"] > 0: + stats["pass_rate"] = (stats["passed"] / stats["total"]) * 100 + else: + stats["pass_rate"] = 0.0 + + return comparison_data, model_order, per_model_stats + + +def find_best_model( + case_models: dict[str, dict[str, Any]], +) -> tuple[str | None, float]: + """ + Find the model with the highest score for a case. + + Args: + case_models: Dict mapping model -> case_result with evaluation. + + Returns: + Tuple of (best_model_name, best_score). Returns (None, 0.0) if no models + or if all evaluations are missing. + Returns ("Tie", score) if multiple models share the highest score. 
+ """ + if not case_models: + return None, 0.0 + + best_model: str | None = None + best_score = -1.0 + tie = False + found_valid_evaluation = False + + for model, case_result in case_models.items(): + evaluation = case_result.get("evaluation") + if not evaluation: + continue + + found_valid_evaluation = True + score = evaluation.score + if score > best_score: + best_score = score + best_model = model + tie = False + elif score == best_score: + tie = True + + # Return 0.0 if no valid evaluations found (not -1.0) + if not found_valid_evaluation: + return None, 0.0 + + if tie: + return "Tie", best_score + + return best_model, best_score + + +# Type for grouped captures: suite -> case_name -> {user_message, models: {model: [tool_calls]}} +GroupedCaptures = dict[str, dict[str, dict[str, Any]]] + + +def group_captures_by_case( + captures: CaptureResults, +) -> tuple[GroupedCaptures, list[str]]: + """ + Group capture results by suite and case for multi-model comparison. + + Args: + captures: List of CaptureResult objects. + + Returns: + A tuple of: + - grouped: {suite: {case_key: {user_message, system_message, track_name, models: {model: captured_case}}}} + - model_order: List of model names in order of appearance + + Note: For comparative captures with tracks, case_key includes the track name + to keep them separate (e.g., "weather_case [track_a]"). + """ + grouped: GroupedCaptures = {} + model_order: list[str] = [] + + for capture in captures: + suite_name = capture.suite_name + model = capture.model + + # Track model order + if model not in model_order: + model_order.append(model) + + # Initialize suite + if suite_name not in grouped: + grouped[suite_name] = {} + + for case in capture.captured_cases: + # Include track_name in the key for comparative captures + track_name = getattr(case, "track_name", None) + case_key = f"{case.case_name} [{track_name}]" if track_name else case.case_name + + # Initialize case + if case_key not in grouped[suite_name]: + grouped[suite_name][case_key] = { + "user_message": case.user_message, + "system_message": case.system_message, + "additional_messages": case.additional_messages, + "track_name": track_name, + "models": {}, + } + + # Store this model's captured case + grouped[suite_name][case_key]["models"][model] = case + + return grouped, model_order + + +def group_captures_by_case_then_track( + captures: CaptureResults, +) -> tuple[dict[str, dict[str, dict[str, Any]]], list[str], list[str | None]]: + """ + Group capture results by suite, case, then track for tab-based display. + + Args: + captures: List of CaptureResult objects. 
+ + Returns: + A tuple of: + - grouped: {suite: {base_case_name: {tracks: {track: {models: {model: case}}}, user_message, ...}}} + - model_order: List of model names in order + - track_order: List of track names in order (None for non-comparative) + """ + grouped: dict[str, dict[str, dict[str, Any]]] = {} + model_order: list[str] = [] + track_order: list[str | None] = [] + + for capture in captures: + suite_name = capture.suite_name + model = capture.model + + if model not in model_order: + model_order.append(model) + + if suite_name not in grouped: + grouped[suite_name] = {} + + for case in capture.captured_cases: + track_name = getattr(case, "track_name", None) + base_case_name = case.case_name + + # Track order + if track_name and track_name not in track_order: + track_order.append(track_name) + + # Initialize case + if base_case_name not in grouped[suite_name]: + grouped[suite_name][base_case_name] = { + "user_message": case.user_message, + "system_message": case.system_message, + "additional_messages": case.additional_messages, + "tracks": {}, # {track_name: {models: {model: case}}} + } + + # Initialize track + track_key = track_name or "_default" + if track_key not in grouped[suite_name][base_case_name]["tracks"]: + grouped[suite_name][base_case_name]["tracks"][track_key] = { + "models": {}, + } + + # Store case under track and model + grouped[suite_name][base_case_name]["tracks"][track_key]["models"][model] = case + + # If no tracks, add None to track_order for consistent handling + if not track_order: + track_order = [None] + + return grouped, model_order, track_order + + +class EvalResultFormatter(ABC): + """ + Abstract base class for evaluation result formatters. + + Implement this class to add new output formats (txt, md, json, html, etc.). + """ + + @property + @abstractmethod + def file_extension(self) -> str: + """Return the default file extension for this format (e.g., 'txt', 'md').""" + ... + + @abstractmethod + def format( + self, + results: EvalResults, + show_details: bool = False, + failed_only: bool = False, + original_counts: EvalStats | None = None, + include_context: bool = False, + ) -> str: + """ + Format evaluation results into a string. + + Args: + results: Nested list of evaluation results by suite and model. + show_details: Whether to show detailed results for each case. + failed_only: Whether only failed cases are being displayed. + original_counts: Optional (total, passed, failed, warned) from before filtering. + include_context: Whether to include system_message and additional_messages. + + Returns: + Formatted string representation of the results. + """ + ... + + +class CaptureFormatter(ABC): + """ + Abstract base class for capture result formatters. + + Implement this class to add new output formats for capture mode. + """ + + @property + @abstractmethod + def file_extension(self) -> str: + """Return the default file extension for this format.""" + ... + + @abstractmethod + def format( + self, + captures: CaptureResults, + include_context: bool = False, + ) -> str: + """ + Format capture results into a string. + + Args: + captures: List of CaptureResult objects. + include_context: Whether to include system_message and additional_messages. + + Returns: + Formatted string representation of the capture results. + """ + ... 
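The grouping and comparison helpers above are the shared building blocks the concrete formatters are written against. As a rough sketch of how a caller might combine them outside of a formatter (illustrative only, not part of this changeset; it assumes `results` is the nested `EvalResults` structure produced by the eval runner, with `evaluation` objects exposing `.passed`, `.warning`, and `.score` as they do throughout `base.py`):

```python
# Illustrative sketch -- not part of the diff. Prints the best-scoring model
# per case and a per-model pass rate, using only helpers defined in base.py.
from arcade_cli.formatters.base import (
    EvalResults,
    find_best_model,
    group_eval_for_comparison,
)


def summarize_winners(results: EvalResults) -> None:
    """Summarize a multi-model evaluation run on stdout."""
    comparison_data, model_order, per_model_stats = group_eval_for_comparison(results)

    for suite_name, cases in comparison_data.items():
        print(f"Suite: {suite_name}")
        for case_name, case_models in cases.items():
            # `best` is a model name, "Tie", or None when no evaluation exists.
            best, score = find_best_model(case_models)
            print(f"  {case_name}: best={best or 'n/a'} ({score * 100:.0f}%)")

    for model in model_order:
        stats = per_model_stats[model]
        print(f"{model}: {stats['passed']}/{stats['total']} passed ({stats['pass_rate']:.1f}%)")
```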
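`EvalResultFormatter` and `CaptureFormatter` are the extension points for new output formats. A minimal sketch of what a custom formatter could look like, built against only the interface above (the `CsvFormatter` name and its naive, unquoted CSV output are hypothetical and not part of this changeset):

```python
# Hypothetical CsvFormatter -- a sketch of the extension point, not part of
# the diff. It uses only the EvalResultFormatter interface and the
# group_results_by_model helper defined in base.py.
from __future__ import annotations

from arcade_cli.formatters.base import (
    EvalResultFormatter,
    EvalResults,
    EvalStats,
    group_results_by_model,
)


class CsvFormatter(EvalResultFormatter):
    """Emit one naive (unquoted) CSV row per evaluated case."""

    @property
    def file_extension(self) -> str:
        return "csv"

    def format(
        self,
        results: EvalResults,
        show_details: bool = False,
        failed_only: bool = False,
        original_counts: EvalStats | None = None,
        include_context: bool = False,
    ) -> str:
        # Reuse the shared grouping logic; the aggregate counts are unused here.
        model_groups, *_counts = group_results_by_model(results)
        rows = ["model,suite,case,status,score"]
        for model, suites in model_groups.items():
            for suite_name, cases in suites.items():
                for case in cases:
                    evaluation = case["evaluation"]
                    if evaluation.passed:
                        status = "passed"
                    elif evaluation.warning:
                        status = "warned"
                    else:
                        status = "failed"
                    rows.append(
                        f"{model},{suite_name},{case['name']},{status},{evaluation.score:.2f}"
                    )
        return "\n".join(rows)
```

Registration is not shown in this part of the diff; presumably a new formatter is added to the `FORMATTERS` mapping exported from `arcade_cli.formatters` (mirroring how `CAPTURE_FORMATTERS` is consulted by `get_capture_formatter`), but treat that wiring as an assumption.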
diff --git a/libs/arcade-cli/arcade_cli/formatters/html.py b/libs/arcade-cli/arcade_cli/formatters/html.py new file mode 100644 index 000000000..ab900aa20 --- /dev/null +++ b/libs/arcade-cli/arcade_cli/formatters/html.py @@ -0,0 +1,2878 @@ +"""HTML formatter for evaluation and capture results with full color support.""" + +import json +from datetime import datetime, timezone +from typing import Any + +from arcade_cli.formatters.base import ( + CaptureFormatter, + CaptureResults, + ComparativeCaseData, + EvalResultFormatter, + compute_track_differences, + find_best_model, + group_comparative_by_case, + group_comparative_by_case_first, + group_eval_for_comparison, + group_results_by_model, + is_comparative_result, + is_multi_model_capture, + is_multi_model_comparative, + is_multi_model_eval, + truncate_field_value, +) + + +class HtmlFormatter(EvalResultFormatter): + """ + HTML formatter for evaluation results. + + Produces a styled HTML document with colors matching the terminal output. + + Security Note: All user-controllable data MUST be escaped via _escape_html() + before being inserted into HTML. This includes case names, inputs, model names, + suite names, and any evaluation results or error messages. + """ + + def __init__(self) -> None: + """Initialize formatter with ID tracking for uniqueness.""" + super().__init__() + self._id_cache: dict[tuple[str, str, str], str] = {} + self._used_ids: set[str] = set() + + @property + def file_extension(self) -> str: + return "html" + + def format( + self, + results: list[list[dict[str, Any]]], + show_details: bool = False, + failed_only: bool = False, + original_counts: tuple[int, int, int, int] | None = None, + include_context: bool = False, + ) -> str: + # Check if this is a comparative evaluation + if is_comparative_result(results): + return self._format_comparative( + results, show_details, failed_only, original_counts, include_context + ) + + # Check if this is a multi-model evaluation + if is_multi_model_eval(results): + return self._format_multi_model( + results, show_details, failed_only, original_counts, include_context + ) + + return self._format_regular( + results, show_details, failed_only, original_counts, include_context + ) + + def _format_regular( + self, + results: list[list[dict[str, Any]]], + show_details: bool = False, + failed_only: bool = False, + original_counts: tuple[int, int, int, int] | None = None, + include_context: bool = False, + ) -> str: + """Format regular (non-comparative) evaluation results.""" + # Use shared grouping logic + model_groups, total_passed, total_failed, total_warned, total_cases = ( + group_results_by_model(results) + ) + + # Calculate pass rate + if total_cases > 0: + if failed_only and original_counts and original_counts[0] > 0: + pass_rate = (original_counts[1] / original_counts[0]) * 100 + else: + pass_rate = (total_passed / total_cases) * 100 + else: + pass_rate = 0 + + # Build HTML + html_parts = [self._get_html_header()] + + # Title and timestamp + html_parts.append('
') + html_parts.append("

🎯 Evaluation Results

") + html_parts.append( + f'

Generated: {datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")}

' + ) + + # Summary section + html_parts.append('
') + html_parts.append("

πŸ“Š Summary

") + + if failed_only and original_counts: + orig_total, orig_passed, orig_failed, orig_warned = original_counts + html_parts.append( + f'
⚠️ Showing only {total_cases} failed evaluation(s)
' + ) + html_parts.append('
') + html_parts.append( + f'
Total{orig_total}
' + ) + html_parts.append( + f'
Passed{orig_passed}
' + ) + if orig_warned > 0: + html_parts.append( + f'
Warnings{orig_warned}
' + ) + html_parts.append( + f'
Failed{orig_failed}
' + ) + else: + html_parts.append('
') + html_parts.append( + f'
Total{total_cases}
' + ) + html_parts.append( + f'
Passed{total_passed}
' + ) + if total_warned > 0: + html_parts.append( + f'
Warnings{total_warned}
' + ) + if total_failed > 0: + html_parts.append( + f'
Failed{total_failed}
' + ) + + html_parts.append("
") # stats-grid + html_parts.append( + f'
Pass Rate: {pass_rate:.1f}%
' + ) + html_parts.append("
") # summary-section + + # Results by model + html_parts.append("

πŸ“‹ Results by Model

") + + for model, suites in model_groups.items(): + html_parts.append('
') + html_parts.append(f"

πŸ€– {self._escape_html(model)}

") + + for suite_name, cases in suites.items(): + # Show suite/file name + html_parts.append('
') + html_parts.append( + f'

πŸ“ {self._escape_html(suite_name)}

' + ) + + # Show summary table only when NOT showing details (avoid duplication) + if not show_details: + html_parts.append('') + html_parts.append( + "" + ) + html_parts.append("") + + for case in cases: + evaluation = case["evaluation"] + if evaluation.passed: + status_class = "passed" + status_text = "βœ… PASSED" + elif evaluation.warning: + status_class = "warned" + status_text = "⚠️ WARNED" + else: + status_class = "failed" + status_text = "❌ FAILED" + + score_pct = evaluation.score * 100 + case_name = self._escape_html(case["name"]) + + html_parts.append(f'') + html_parts.append(f'') + html_parts.append(f"") + html_parts.append(f'') + html_parts.append("") + + html_parts.append("
StatusCaseScore
{status_text}{case_name}{score_pct:.1f}%
") + + # Detailed results - each case is individually expandable + if show_details: + html_parts.append( + '

πŸ’‘ Click on any case below to expand details

' + ) + for case in cases: + evaluation = case["evaluation"] + if evaluation.passed: + status_class = "passed" + status_badge = 'PASSED' + status_icon = "βœ…" + elif evaluation.warning: + status_class = "warned" + status_badge = 'WARNED' + status_icon = "⚠️" + else: + status_class = "failed" + status_badge = 'FAILED' + status_icon = "❌" + + case_name = self._escape_html(case["name"]) + score_pct = evaluation.score * 100 + + # Each case is a collapsible details element (collapsed by default) + html_parts.append(f'
') + html_parts.append( + f'' + f"{status_icon} {case_name} " + f'{score_pct:.1f}% ' + f"{status_badge}" + f"" + ) + html_parts.append('
') + html_parts.append( + f"

Input: {self._escape_html(case['input'])}

" + ) + + # Context section (if include_context is True) + if include_context: + system_msg = case.get("system_message") + addl_msgs = case.get("additional_messages") + if system_msg or addl_msgs: + html_parts.append('
') + html_parts.append("

πŸ“‹ Context

") + if system_msg: + html_parts.append( + f'
' + f"System Message: " + f"{self._escape_html(system_msg)}" + f"
" + ) + if addl_msgs: + conversation_html = self._format_conversation(addl_msgs) + html_parts.append( + f'
' + f"πŸ’¬ Conversation Context ({len(addl_msgs)} messages)" + f"{conversation_html}" + f"
" + ) + html_parts.append("
") + + # Evaluation details + html_parts.append(self._format_evaluation_details(evaluation)) + html_parts.append("
") + html_parts.append("
") + + html_parts.append("
") # suite-section + + html_parts.append("
") # model-section + + html_parts.append("
") # container + html_parts.append("") + + return "\n".join(html_parts) + + def _format_evaluation_details(self, evaluation: Any) -> str: + """Format evaluation details as HTML table.""" + if evaluation.failure_reason: + return f'
❌ Failure Reason: {self._escape_html(evaluation.failure_reason)}
' + + lines = [''] + lines.append( + "" + ) + lines.append("") + + for critic_result in evaluation.results: + is_criticized = critic_result.get("is_criticized", True) + field = self._escape_html(critic_result["field"]) + score = critic_result["score"] + weight = critic_result["weight"] + expected = self._escape_html(str(critic_result["expected"])) + actual = self._escape_html(str(critic_result["actual"])) + + # Truncate long values for table readability + expected = truncate_field_value(expected) + actual = truncate_field_value(actual) + + if is_criticized: + if critic_result["match"]: + match_cell = 'βœ… Match' + row_class = "match-row" + else: + match_cell = '❌ No Match' + row_class = "nomatch-row" + score_cell = f"{score:.2f}/{weight:.2f}" + else: + match_cell = 'β€” Un-criticized' + row_class = "uncriticized-row" + score_cell = "-" + + lines.append(f'') + lines.append(f'') + lines.append(f"") + lines.append(f'') + lines.append(f"") + lines.append(f"") + lines.append("") + + lines.append("
FieldMatchScoreExpectedActual
{field}{match_cell}{score_cell}{expected}{actual}
") + return "\n".join(lines) + + def _escape_html(self, text: str) -> str: + """Escape HTML special characters.""" + return ( + text.replace("&", "&amp;") + .replace("<", "&lt;") + .replace(">", "&gt;") + .replace('"', "&quot;") + .replace("'", "&#x27;") + ) + + def _make_safe_id(self, suite_name: str, case_name: str, model_name: str = "") -> str: + """Generate a safe ID for HTML attributes and CSS selectors. + + Removes or replaces characters that could break HTML attributes or + CSS selectors, including quotes, brackets, and special characters. + Ensures uniqueness by appending a counter when duplicates are detected. + + Args: + suite_name: The suite name. + case_name: The case name. + model_name: Optional model name. + + Returns: + A sanitized string safe for use in HTML id/data attributes, guaranteed unique. + """ + import re + + def sanitize(s: str) -> str: + # Replace common separators with underscores + s = s.replace(" ", "_").replace("-", "_") + # Remove brackets and parentheses + s = s.replace("[", "").replace("]", "").replace("(", "").replace(")", "") + # Remove quotes that would break HTML attributes + s = s.replace('"', "").replace("'", "") + # Remove any remaining non-alphanumeric characters except underscores + s = re.sub(r"[^\w]", "", s) + return s + + # Check cache for idempotence - same inputs should return same ID + cache_key = (suite_name, case_name, model_name) + if cache_key in self._id_cache: + return self._id_cache[cache_key] + + suite_id = sanitize(suite_name) + case_id_part = sanitize(case_name) + base_id = f"{suite_id}__{case_id_part}" + + if model_name: + model_id = sanitize(model_name) + base_id = f"{model_id}__{base_id}" + + # Ensure uniqueness by appending a counter if this ID already exists + unique_id = base_id + counter = 1 + while unique_id in self._used_ids: + unique_id = f"{base_id}_{counter}" + counter += 1 + + # Cache the result and mark ID as used + self._id_cache[cache_key] = unique_id + self._used_ids.add(unique_id) + return unique_id + + def _format_conversation(self, messages: list[dict]) -> str: + """Format conversation messages as rich HTML for context display.""" + html_parts = ['
'] + + for msg in messages: + role = msg.get("role", "unknown") + content = msg.get("content") + tool_calls = msg.get("tool_calls", []) + tool_name = msg.get("name", "") # For tool responses + + role_class = f"msg msg-{role}" + role_label = { + "user": "πŸ‘€ User", + "assistant": "πŸ€– Assistant", + "tool": "πŸ”§ Tool", + "system": "βš™οΈ System", + }.get(role, f"πŸ’¬ {role.title()}") + + # Add tool name to label for tool responses + if role == "tool" and tool_name: + role_label = f"πŸ”§ Tool ({tool_name})" + + html_parts.append(f'
') + html_parts.append(f'
{role_label}
') + + if content: + # For tool responses, try to format JSON nicely + if role == "tool": + try: + parsed_content = json.loads(content) + formatted_content = json.dumps(parsed_content, indent=2) + html_parts.append( + f'
{self._escape_html(formatted_content)}
' + ) + except (json.JSONDecodeError, TypeError): + # Not valid JSON, show as regular content + html_parts.append( + f'
{self._escape_html(str(content))}
' + ) + else: + html_parts.append( + f'
{self._escape_html(str(content))}
' + ) + + # Handle tool calls in assistant messages + if tool_calls: + html_parts.append('
') + for tc in tool_calls: + tc_func = tc.get("function", {}) + tc_name = tc_func.get("name", "unknown") + tc_args = tc_func.get("arguments", "{}") + try: + args_formatted = json.dumps(json.loads(tc_args), indent=2) + except (json.JSONDecodeError, TypeError): + args_formatted = str(tc_args) + html_parts.append( + f'
' + f'πŸ› οΈ {self._escape_html(tc_name)}' + f'
{self._escape_html(args_formatted)}
' + f"
" + ) + html_parts.append("
") + + html_parts.append("
") + + html_parts.append("
") + return "\n".join(html_parts) + + # ========================================================================= + # MULTI-MODEL EVALUATION FORMATTING + # ========================================================================= + + def _format_multi_model( + self, + results: list[list[dict[str, Any]]], + show_details: bool = False, + failed_only: bool = False, + original_counts: tuple[int, int, int, int] | None = None, + include_context: bool = False, + ) -> str: + """Format multi-model evaluation results with comparison tables.""" + comparison_data, model_order, per_model_stats = group_eval_for_comparison(results) + + # Build HTML + html_parts = [self._get_html_header()] + html_parts.append(self._get_multi_model_styles()) + + # Container + html_parts.append('
') + html_parts.append("

πŸ”„ Multi-Model Evaluation Results

") + html_parts.append( + f'

Generated: {datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")}

' + ) + html_parts.append(f'

Models: {", ".join(model_order)}

') + + # Per-Model Summary Section + html_parts.append('
') + html_parts.append("

πŸ“Š Per-Model Summary

") + html_parts.append('') + html_parts.append("") + html_parts.append( + "" + ) + html_parts.append("") + + best_model = None + best_rate = -1.0 + for model in model_order: + stats = per_model_stats[model] + rate = stats["pass_rate"] + + if rate > best_rate: + best_rate = rate + best_model = model + + row_class = "best-model" if rate == best_rate and best_model == model else "" + html_parts.append(f'') + html_parts.append(f'') + html_parts.append(f'') + html_parts.append(f'') + html_parts.append(f'') + html_parts.append(f"") + html_parts.append(f'') + html_parts.append("") + + html_parts.append("
ModelPassedFailedWarnedTotalPass Rate
{self._escape_html(model)}{stats["passed"]}{stats["failed"]}{stats["warned"]}{stats['total']}{rate:.1f}%
") + + if best_model: + html_parts.append( + f'

πŸ† Best Overall: {self._escape_html(best_model)} ({best_rate:.1f}% pass rate)

' + ) + html_parts.append("
") + + # Cross-Model Comparison Section + html_parts.append('
') + html_parts.append("

βš”οΈ Cross-Model Comparison

") + + for suite_name, cases in comparison_data.items(): + html_parts.append('
') + html_parts.append(f"

Suite: {self._escape_html(suite_name)}

") + + # Comparison table + html_parts.append('') + html_parts.append("") + html_parts.append("") + for model in model_order: + html_parts.append(f"") + html_parts.append("") + html_parts.append("") + + for case_name, case_models in cases.items(): + html_parts.append("") + html_parts.append(f'') + + for model in model_order: + if model in case_models: + evaluation = case_models[model]["evaluation"] + score = evaluation.score * 100 + if evaluation.passed: + cell_class = "passed" + icon = "βœ“" + elif evaluation.warning: + cell_class = "warned" + icon = "⚠" + else: + cell_class = "failed" + icon = "βœ—" + html_parts.append(f'') + else: + html_parts.append('') + + # Best model + best, _ = find_best_model(case_models) + if best == "Tie": + html_parts.append('') + elif best and best != "N/A": + html_parts.append(f'') + else: + html_parts.append('') + + html_parts.append("") + + html_parts.append("
Case{self._escape_html(model)}Best
{self._escape_html(case_name)}{icon} {score:.0f}%-🀝 TieπŸ† {self._escape_html(best)}-
") + html_parts.append("
") + + # Detailed results + if show_details: + html_parts.append('
') + html_parts.append("

Detailed Results

") + + for case_name, case_models in cases.items(): + html_parts.append('
') + html_parts.append(f"
{self._escape_html(case_name)}
") + + for model in model_order: + if model not in case_models: + continue + + case_result = case_models[model] + evaluation = case_result["evaluation"] + + html_parts.append('
') + html_parts.append( + f"{self._escape_html(model)}: Score {evaluation.score * 100:.1f}%" + ) + html_parts.append(self._format_evaluation_details(evaluation)) + html_parts.append("
") + + html_parts.append("
") + + html_parts.append("
") + + html_parts.append("
") + + # Footer + html_parts.append("
") # container + html_parts.append("") + + return "\n".join(html_parts) + + def _get_multi_model_styles(self) -> str: + """Return additional CSS for multi-model views.""" + return """ + + """ + + # ========================================================================= + # COMPARATIVE EVALUATION FORMATTING + # ========================================================================= + + def _format_comparative( + self, + results: list[list[dict[str, Any]]], + show_details: bool = False, + failed_only: bool = False, + original_counts: tuple[int, int, int, int] | None = None, + include_context: bool = False, + ) -> str: + """Format comparative evaluation results with tabbed track view.""" + # Check if this is multi-model comparative - use case-first grouping + if is_multi_model_comparative(results): + return self._format_comparative_case_first( + results, show_details, failed_only, original_counts, include_context + ) + + return self._format_comparative_single_model( + results, show_details, failed_only, original_counts, include_context + ) + + def _format_comparative_single_model( + self, + results: list[list[dict[str, Any]]], + show_details: bool = False, + failed_only: bool = False, + original_counts: tuple[int, int, int, int] | None = None, + include_context: bool = False, + ) -> str: + """Format single-model comparative evaluation results.""" + # Use comparative grouping + ( + comparative_groups, + total_passed, + total_failed, + total_warned, + total_cases, + suite_track_order, + ) = group_comparative_by_case(results) + + # Collect all unique tracks for header + all_tracks: list[str] = [] + for tracks in suite_track_order.values(): + for t in tracks: + if t not in all_tracks: + all_tracks.append(t) + + # Calculate pass rate + if total_cases > 0: + if failed_only and original_counts and original_counts[0] > 0: + pass_rate = (original_counts[1] / original_counts[0]) * 100 + else: + pass_rate = (total_passed / total_cases) * 100 + else: + pass_rate = 0 + + # Build HTML + html_parts = [self._get_html_header()] + + # Title and timestamp + html_parts.append('
') + html_parts.append("

πŸ“Š Comparative Evaluation Results

") + html_parts.append( + f'

Generated: {datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")}

' + ) + + # Tracks list (only show if there are multiple tracks) + if len(all_tracks) > 1: + html_parts.append('
') + html_parts.append("All Tracks:") + for track in all_tracks: + html_parts.append(f'{self._escape_html(track)}') + html_parts.append("
") + + # Summary section + html_parts.append('
') + html_parts.append("

πŸ“Š Summary

") + + if failed_only and original_counts: + orig_total, orig_passed, orig_failed, orig_warned = original_counts + html_parts.append( + f'
⚠️ Showing only {total_cases} failed evaluation(s)
' + ) + html_parts.append('
') + html_parts.append( + f'
Total{orig_total}
' + ) + html_parts.append( + f'
Passed{orig_passed}
' + ) + if orig_warned > 0: + html_parts.append( + f'
Warnings{orig_warned}
' + ) + html_parts.append( + f'
Failed{orig_failed}
' + ) + else: + html_parts.append('
') + html_parts.append( + f'
Total{total_cases}
' + ) + html_parts.append( + f'
Passed{total_passed}
' + ) + if total_warned > 0: + html_parts.append( + f'
Warnings{total_warned}
' + ) + if total_failed > 0: + html_parts.append( + f'
Failed{total_failed}
' + ) + + html_parts.append("
") # stats-grid + html_parts.append( + f'
Pass Rate: {pass_rate:.1f}%
' + ) + html_parts.append("
") # summary-section + + # Results by model + html_parts.append("

πŸ“‹ Comparative Results by Model

") + + for model, suites in comparative_groups.items(): + html_parts.append('
') + html_parts.append(f"

πŸ€– {self._escape_html(model)}

") + + for suite_name, cases in suites.items(): + # Get track order for this specific suite + track_order = suite_track_order.get(suite_name, []) + + html_parts.append('
') + # Only show COMPARATIVE badge if there are multiple tracks + badge = ( + 'COMPARATIVE' + if len(track_order) > 1 + else "" + ) + html_parts.append( + f'

πŸ“ {self._escape_html(suite_name)} {badge}

' + ) + + # Show tracks for this suite (only if multiple) + if len(track_order) > 1: + html_parts.append('
') + html_parts.append("Tracks:") + for track in track_order: + html_parts.append( + f'{self._escape_html(track)}' + ) + html_parts.append("
") + + for case_name, case_data in cases.items(): + # Context section (if include_context is True) + if include_context: + system_msg = case_data.get("system_message") + addl_msgs = case_data.get("additional_messages") + if system_msg or addl_msgs: + html_parts.append('
') + html_parts.append("

πŸ“‹ Context

") + if system_msg: + html_parts.append( + f'
' + f"System Message: " + f"{self._escape_html(system_msg)}" + f"
" + ) + if addl_msgs: + conversation_html = self._format_conversation(addl_msgs) + html_parts.append( + f'
' + f"πŸ’¬ Conversation Context ({len(addl_msgs)} messages)" + f"{conversation_html}" + f"
" + ) + html_parts.append("
") + + html_parts.append( + self._format_comparative_case_html( + case_name, case_data, track_order, show_details, suite_name + ) + ) + + html_parts.append("
") # suite-section + + html_parts.append("
") # model-section + + # JavaScript for tab switching + html_parts.append(self._get_tab_script()) + + html_parts.append("
") # container + html_parts.append("") + + return "\n".join(html_parts) + + def _format_comparative_case_first( + self, + results: list[list[dict[str, Any]]], + show_details: bool = False, + failed_only: bool = False, + original_counts: tuple[int, int, int, int] | None = None, + include_context: bool = False, + ) -> str: + """Format multi-model comparative evaluation grouped by case first.""" + # Get case-first grouping + ( + case_groups, + model_order, + suite_track_order, + total_passed, + total_failed, + total_warned, + total_cases, + ) = group_comparative_by_case_first(results) + + # Collect all unique tracks + all_tracks: list[str] = [] + for tracks in suite_track_order.values(): + for t in tracks: + if t not in all_tracks: + all_tracks.append(t) + + # Calculate pass rate + if total_cases > 0: + if failed_only and original_counts and original_counts[0] > 0: + pass_rate = (original_counts[1] / original_counts[0]) * 100 + else: + pass_rate = (total_passed / total_cases) * 100 + else: + pass_rate = 0 + + # Build HTML + html_parts = [self._get_html_header()] + html_parts.append(self._get_multi_model_styles()) + + html_parts.append('
') + html_parts.append("

πŸ“Š Comparative Evaluation Results (Multi-Model)

") + html_parts.append( + f'

Generated: {datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")}

' + ) + + # Models and tracks info + html_parts.append('
') + html_parts.append(f"

Models: {', '.join(model_order)}

") + # Only show tracks list if there are multiple tracks + if len(all_tracks) > 1: + html_parts.append('
') + html_parts.append("Tracks:") + for track in all_tracks: + html_parts.append(f'{self._escape_html(track)}') + html_parts.append("
") + html_parts.append("
") + + # Summary section + html_parts.append('
') + html_parts.append("

πŸ“Š Summary

") + + if failed_only and original_counts: + orig_total, orig_passed, orig_failed, orig_warned = original_counts + html_parts.append( + f'
⚠️ Showing only {total_cases} failed evaluation(s)
' + ) + html_parts.append('
') + html_parts.append( + f'
Total{orig_total}
' + ) + html_parts.append( + f'
Passed{orig_passed}
' + ) + if orig_warned > 0: + html_parts.append( + f'
Warnings{orig_warned}
' + ) + html_parts.append( + f'
Failed{orig_failed}
' + ) + else: + html_parts.append('
') + html_parts.append( + f'
Total{total_cases}
' + ) + html_parts.append( + f'
Passed{total_passed}
' + ) + if total_warned > 0: + html_parts.append( + f'
Warnings{total_warned}
' + ) + if total_failed > 0: + html_parts.append( + f'
Failed{total_failed}
' + ) + + html_parts.append("
") # stats-grid + html_parts.append( + f'
Pass Rate: {pass_rate:.1f}%
' + ) + html_parts.append("
") # summary-section + + # Results grouped by case + html_parts.append("

πŸ“‹ Results by Case

") + + for suite_name, cases in case_groups.items(): + track_order = suite_track_order.get(suite_name, []) + + html_parts.append('
') + # Only show COMPARATIVE badge if there are multiple tracks + badge = ( + 'COMPARATIVE' if len(track_order) > 1 else "" + ) + html_parts.append( + f'

πŸ“ {self._escape_html(suite_name)} {badge}

' + ) + + # Show tracks for this suite (only if multiple) + if len(track_order) > 1: + html_parts.append('
') + html_parts.append("Tracks:") + for track in track_order: + html_parts.append( + f'{self._escape_html(track)}' + ) + html_parts.append("
") + + for case_name, model_data in cases.items(): + # Case container + html_parts.append('
') + html_parts.append(f"

πŸ“‹ Case: {self._escape_html(case_name)}

") + + # Get input and context from first model + first_model_data = next(iter(model_data.values()), {}) + case_input = first_model_data.get("input", "") + if case_input: + html_parts.append( + f'

Input: {self._escape_html(case_input)}

' + ) + + # Context section (if include_context is True) + if include_context: + system_msg = first_model_data.get("system_message") + addl_msgs = first_model_data.get("additional_messages") + if system_msg or addl_msgs: + html_parts.append('
') + html_parts.append("

πŸ“‹ Context

") + if system_msg: + html_parts.append( + f'
' + f"System Message: " + f"{self._escape_html(system_msg)}" + f"
" + ) + if addl_msgs: + conversation_html = self._format_conversation(addl_msgs) + html_parts.append( + f'
' + f"πŸ’¬ Conversation Context ({len(addl_msgs)} messages)" + f"{conversation_html}" + f"
" + ) + html_parts.append("
") + + # Show each model's results for this case + for model in model_order: + if model not in model_data: + html_parts.append('
') + html_parts.append( + f'
πŸ€– {self._escape_html(model)}
' + ) + html_parts.append('
No data
') + html_parts.append("
") + continue + + model_case_data = model_data[model] + html_parts.append('
') + html_parts.append( + f'
πŸ€– {self._escape_html(model)}
' + ) + + # Show track comparison for this model + html_parts.append( + self._format_comparative_case_html( + case_name, model_case_data, track_order, show_details, suite_name, model + ) + ) + + html_parts.append("
") # model-panel + + html_parts.append("
") # case-group + + html_parts.append("
") # suite-section + + # JavaScript for tab switching + html_parts.append(self._get_tab_script()) + + html_parts.append("
") # container + html_parts.append("") + + return "\n".join(html_parts) + + def _format_comparative_case_html( + self, + case_name: str, + case_data: ComparativeCaseData, + track_order: list[str], + show_details: bool, + suite_name: str = "", + model_name: str = "", + ) -> str: + """Format a single comparative case as HTML with tabbed details.""" + lines: list[str] = [] + tracks = case_data.get("tracks", {}) + + # Compute differences from baseline + differences = compute_track_differences(case_data, track_order) + + # Generate unique ID for this case's tabs - include suite name and model for uniqueness + # Sanitize all parts for use in HTML attributes and CSS selectors + case_id = self._make_safe_id(suite_name, case_name, model_name) + + lines.append('
') + + # Case header + lines.append('
') + lines.append(f"
{self._escape_html(case_name)}
") + lines.append( + f'

Input: ' + f"{self._escape_html(case_data.get('input', 'N/A'))}

" + ) + lines.append("
") + + # Comparison summary table + lines.append('') + lines.append( + "" + ) + lines.append("") + + for i, track_name in enumerate(track_order): + is_baseline = i == 0 + row_class = "baseline" if is_baseline else "" + + if track_name not in tracks: + lines.append(f'') + lines.append(f"") + lines.append('') + lines.append('') + lines.append('') + lines.append("") + continue + + track_result = tracks[track_name] + evaluation = track_result.get("evaluation") + + if not evaluation: + lines.append(f'') + lines.append(f"") + lines.append('') + lines.append('') + lines.append('') + lines.append("") + continue + + # Status + if evaluation.passed: + status_class = "passed" + status_text = "βœ… PASSED" + elif evaluation.warning: + status_class = "warned" + status_text = "⚠️ WARNED" + else: + status_class = "failed" + status_text = "❌ FAILED" + + # Score + score_pct = evaluation.score * 100 + + # Differences + diff_fields = differences.get(track_name, []) + if is_baseline: + diff_html = '(baseline)' + elif diff_fields: + diff_html = " ".join( + f'{self._escape_html(f)}' for f in diff_fields + ) + else: + diff_html = 'β€”' + + lines.append(f'') + lines.append(f"") + lines.append(f'') + lines.append(f'') + lines.append(f"") + lines.append("") + + lines.append("
TrackStatusScoreDifferences
{self._escape_html(track_name)}⚠️ N/Aβ€”No data
{self._escape_html(track_name)}⚠️ N/Aβ€”No evaluation
{self._escape_html(track_name)}{status_text}{score_pct:.1f}%{diff_html}
") + + # Detailed results with tabs (if show_details) + if show_details: + # Find tracks with data for proper active tab handling + tracks_with_data = [ + (i, tn) + for i, tn in enumerate(track_order) + if tn in tracks and tracks[tn].get("evaluation") + ] + + # Tab buttons - show all tracks, style N/A differently but keep clickable + lines.append('
') + first_with_data = tracks_with_data[0][0] if tracks_with_data else 0 + for i, track_name in enumerate(track_order): + has_data = track_name in tracks and tracks[track_name].get("evaluation") + active = "active" if i == first_with_data else "" + na_class = "" if has_data else "na-track" + diff_class = "has-diff" if differences.get(track_name) else "" + lines.append( + f'" + ) + lines.append("
") # track-tabs + + # Tab panels container - include panels for ALL tracks + lines.append('
') + for i, track_name in enumerate(track_order): + has_data = track_name in tracks and tracks[track_name].get("evaluation") + active = "active" if i == first_with_data else "" + + lines.append( + f'
' + ) + + if not has_data: + # Show informative N/A panel + lines.append('
') + lines.append('Viewing track:') + lines.append( + f'{self._escape_html(track_name)}' + ) + lines.append("
") + lines.append('
') + lines.append('
β„Ή
') # noqa: RUF001 + lines.append("

Track Not Configured

") + lines.append( + f"

The {self._escape_html(track_name)} track " + f"was not configured for this test case.

" + ) + lines.append("

") + lines.append( + "This happens when a comparative case uses .for_track() " + "to define expectations only for specific tracks. " + "Tracks without expectations are skipped during evaluation." + ) + lines.append("

") + lines.append('
') + lines.append("To include this track:") + lines.append("
case.for_track(")
+                    lines.append(f'    "{self._escape_html(track_name)}",')
+                    lines.append("    expected_tool_calls=[...],")
+                    lines.append("    critics=[...]")
+                    lines.append(")
") + lines.append("
") + lines.append("
") # na-panel-content + else: + # Show normal evaluation panel + track_result = tracks[track_name] + evaluation = track_result.get("evaluation") + lines.append('
') + lines.append('Viewing track:') + lines.append( + f'{self._escape_html(track_name)}' + ) + lines.append("
") + lines.append(self._format_evaluation_details(evaluation)) + + lines.append("
") # track-panel + lines.append("
") # track-panels-container + + lines.append("
") # comparative-case + + return "\n".join(lines) + + def _get_tab_script(self) -> str: + """Return JavaScript for tab switching.""" + return """ + +""" + + def _get_html_header(self) -> str: + """Return HTML header with embedded CSS for styling.""" + return """ + + + + + Evaluation Results + + + +""" + + +class CaptureHtmlFormatter(CaptureFormatter): + """HTML formatter for capture results.""" + + @property + def file_extension(self) -> str: + return "html" + + def format( + self, + captures: CaptureResults, + include_context: bool = False, + ) -> str: + """Format capture results as HTML.""" + # Check for multi-model captures + if is_multi_model_capture(captures): + return self._format_multi_model(captures, include_context) + + return self._format_single_model(captures, include_context) + + def _format_single_model( + self, + captures: CaptureResults, + include_context: bool = False, + ) -> str: + """Format single-model capture results as HTML.""" + total_cases = 0 + total_calls = 0 + + # Build captures HTML + captures_html = [] + for capture in captures: + cases_html = [] + for case in capture.captured_cases: + total_cases += 1 + tool_calls_html = [] + + for tc in case.tool_calls: + total_calls += 1 + args_html = "" + if tc.args: + args_json = json.dumps(tc.args, indent=2) + args_html = f'
{self._escape_html(args_json)}
' + tool_calls_html.append( + f'
' + f'{self._escape_html(tc.name)}' + f"{args_html}" + f"
" + ) + + if not tool_calls_html: + tool_calls_html.append('
No tool calls captured
') + + context_html = "" + if include_context: + context_parts = [] + if case.system_message: + context_parts.append( + f'
' + f"System Message: " + f"{self._escape_html(case.system_message)}" + f"
" + ) + if case.additional_messages: + conversation_html = self._format_conversation(case.additional_messages) + context_parts.append( + f'
' + f"πŸ’¬ Conversation Context ({len(case.additional_messages)} messages)" + f"{conversation_html}" + f"
" + ) + if context_parts: + context_html = f'
{"".join(context_parts)}
' + + # track_name is set for comparative cases + track_name = getattr(case, "track_name", None) + track_html = "" + if track_name: + track_html = f'{self._escape_html(track_name)}' + + cases_html.append( + f'
' + f'

{self._escape_html(case.case_name)} {track_html}

' + f'
' + f"User: {self._escape_html(case.user_message)}" + f"
" + f"{context_html}" + f'

Tool Calls

{"".join(tool_calls_html)}
' + f"
" + ) + + captures_html.append( + f'
' + f'

{self._escape_html(capture.suite_name)}

' + f'
' + f"Model: {self._escape_html(capture.model)}" + f"Provider: {self._escape_html(capture.provider)}" + f"
" + f'
{"".join(cases_html)}
' + f"
" + ) + + return self._get_capture_html(captures_html, total_cases, total_calls) + + def _escape_html(self, text: str) -> str: + """Escape HTML special characters.""" + return ( + text.replace("&", "&amp;") + .replace("<", "&lt;") + .replace(">", "&gt;") + .replace('"', "&quot;") + .replace("'", "&#x27;") + ) + + def _format_conversation(self, messages: list[dict]) -> str: + """Format conversation messages as a rich HTML conversation view.""" + html_parts = ['
'] + + for msg in messages: + role = msg.get("role", "unknown") + content = msg.get("content", "") + tool_calls = msg.get("tool_calls", []) + name = msg.get("name", "") + + # Role-specific styling + role_class = f"msg-{role}" + role_icon = { + "user": "πŸ‘€", + "assistant": "πŸ€–", + "tool": "πŸ”§", + "system": "βš™οΈ", + }.get(role, "πŸ’¬") + role_label = role.capitalize() + + html_parts.append(f'
') + html_parts.append( + f'
' + f'{role_icon}' + f'{role_label}' + ) + + # Show tool name for tool responses + if role == "tool" and name: + html_parts.append(f'({self._escape_html(name)})') + + html_parts.append("
") # Close msg-header + + # Message content + if content: + # For tool responses, try to format JSON nicely + if role == "tool": + try: + parsed_content = json.loads(content) + formatted_content = json.dumps(parsed_content, indent=2) + html_parts.append( + f'
{self._escape_html(formatted_content)}
' + ) + except (json.JSONDecodeError, TypeError): + # Not valid JSON, show as regular content + html_parts.append( + f'
{self._escape_html(str(content))}
' + ) + else: + html_parts.append( + f'
{self._escape_html(str(content))}
' + ) + + # Tool calls (for assistant messages) + if tool_calls: + html_parts.append('
') + for tc in tool_calls: + func = tc.get("function", {}) + tc_name = func.get("name", "unknown") + tc_args = func.get("arguments", "{}") + + # Parse and pretty-print arguments + try: + args_dict = json.loads(tc_args) if isinstance(tc_args, str) else tc_args + args_formatted = json.dumps(args_dict, indent=2) + except (json.JSONDecodeError, TypeError): + args_formatted = str(tc_args) + + html_parts.append( + f'
' + f'πŸ“ž {self._escape_html(tc_name)}' + f'
{self._escape_html(args_formatted)}
' + f"
" + ) + html_parts.append("
") + + html_parts.append("
") # Close msg + + html_parts.append("
") # Close conversation + return "\n".join(html_parts) + + def _format_multi_model( + self, + captures: CaptureResults, + include_context: bool = False, + ) -> str: + """Format multi-model capture results with track tabs.""" + from arcade_cli.formatters.base import group_captures_by_case_then_track + + grouped_data, model_order, track_order = group_captures_by_case_then_track(captures) + + html_parts: list[str] = [] + + # HTML head with track tab styles + html_parts.append(""" + + + + + Multi-Model Capture Results + + + +""") + + html_parts.append("

πŸ”„ Multi-Model Capture Results

") + html_parts.append( + f'

Models: {", ".join(self._escape_html(m) for m in model_order)}

' + ) + + total_cases = 0 + total_calls = 0 + case_idx = 0 + + for suite_name, cases in grouped_data.items(): + html_parts.append('
') + html_parts.append(f"

{self._escape_html(suite_name)}

") + + for case_name, case_data in cases.items(): + total_cases += 1 + case_idx += 1 + case_id = f"case_{case_idx}" + html_parts.append('
') + + user_msg = case_data.get("user_message", "") + tracks_data = case_data.get("tracks", {}) + + html_parts.append('
') + html_parts.append(f"

{self._escape_html(case_name)}

") + if user_msg: + html_parts.append( + f"

User: {self._escape_html(user_msg)}

" + ) + html_parts.append("
") + + # Check if we have multiple tracks + track_keys = list(tracks_data.keys()) + has_multiple_tracks = len(track_keys) > 1 or ( + len(track_keys) == 1 and track_keys[0] != "_default" + ) + + if has_multiple_tracks: + # Render track tabs + html_parts.append('
') + for i, track_key in enumerate(track_keys): + active = "active" if i == 0 else "" + display_name = track_key if track_key != "_default" else "Default" + html_parts.append( + f'" + ) + html_parts.append("
") + + # Render track panels + html_parts.append('
') + for i, track_key in enumerate(track_keys): + active = "active" if i == 0 else "" + track_data = tracks_data[track_key] + html_parts.append( + f'
' + ) + + display_name = track_key if track_key != "_default" else "Default" + html_parts.append( + f'
🏷️ {self._escape_html(display_name)}
' + ) + + # Render model panels within track + models_dict = track_data.get("models", {}) + for model in model_order: + if model not in models_dict: + html_parts.append('
') + html_parts.append( + f'
{self._escape_html(model)}
' + ) + html_parts.append('
No data
') + html_parts.append("
") + continue + + captured_case = models_dict[model] + html_parts.append('
') + html_parts.append( + f'
{self._escape_html(model)}
' + ) + + if captured_case.tool_calls: + for tc in captured_case.tool_calls: + total_calls += 1 + args_html = "" + if tc.args: + args_json = json.dumps(tc.args, indent=2) + args_html = f'
{self._escape_html(args_json)}
' + html_parts.append( + f'
' + f'{self._escape_html(tc.name)}' + f"{args_html}
" + ) + else: + html_parts.append('
No tool calls
') + + html_parts.append("
") # model-panel + + html_parts.append("
") # track-panel + html_parts.append("
") # track-panels + else: + # No tracks - render models directly + track_key = track_keys[0] if track_keys else "_default" + track_data = tracks_data.get(track_key, {}) + models_dict = track_data.get("models", {}) + + for model in model_order: + if model not in models_dict: + html_parts.append('
') + html_parts.append( + f'
{self._escape_html(model)}
' + ) + html_parts.append('
No data
') + html_parts.append("
") + continue + + captured_case = models_dict[model] + html_parts.append('
') + html_parts.append( + f'
{self._escape_html(model)}
' + ) + + if captured_case.tool_calls: + for tc in captured_case.tool_calls: + total_calls += 1 + args_html = "" + if tc.args: + args_json = json.dumps(tc.args, indent=2) + args_html = ( + f'
{self._escape_html(args_json)}
' + ) + html_parts.append( + f'
' + f'{self._escape_html(tc.name)}' + f"{args_html}
" + ) + else: + html_parts.append('
No tool calls
') + + html_parts.append("
") + + # Context section + system_msg = case_data.get("system_message") + addl_msgs = case_data.get("additional_messages") + if include_context and (system_msg or addl_msgs): + html_parts.append('
') + html_parts.append("

Context

") + if system_msg: + html_parts.append( + f"

System: {self._escape_html(system_msg)}

" + ) + if addl_msgs: + html_parts.append(self._format_conversation(addl_msgs)) + html_parts.append("
") + + html_parts.append("
") # case-group + + html_parts.append("
") # suite-section + + # Summary + total_suites = len(grouped_data) + html_parts.append(f""" +
+

Summary

+

Suites: {total_suites} | Cases: {total_cases} | Models: {len(model_order)} | Tool Calls: {total_calls}

+
+ + + + +""") + + return "\n".join(html_parts) + + def _get_capture_html( + self, captures_html: list[str], total_cases: int, total_calls: int + ) -> str: + """Return complete HTML document for capture results.""" + return f""" + + + + + Capture Results + + + +

🎯 Capture Results

+ {"".join(captures_html)} +
+

Summary

+
+
+
{total_cases}
+
Total Cases
+
+
+
{total_calls}
+
Tool Calls
+
+
+
+ +""" diff --git a/libs/arcade-cli/arcade_cli/formatters/json.py b/libs/arcade-cli/arcade_cli/formatters/json.py new file mode 100644 index 000000000..361974b2a --- /dev/null +++ b/libs/arcade-cli/arcade_cli/formatters/json.py @@ -0,0 +1,690 @@ +"""JSON formatter for evaluation and capture results.""" + +import json +from datetime import datetime, timezone +from typing import Any + +from arcade_cli.formatters.base import ( + CaptureFormatter, + CaptureResults, + EvalResultFormatter, + EvalResults, + EvalStats, + find_best_model, + group_comparative_by_case, + group_comparative_by_case_first, + group_eval_for_comparison, + group_results_by_model, + is_comparative_result, + is_multi_model_capture, + is_multi_model_comparative, + is_multi_model_eval, +) + + +class JsonFormatter(EvalResultFormatter): + """ + JSON formatter for evaluation results. + + Produces a structured JSON document containing all evaluation data, + suitable for programmatic processing, dashboards, or further analysis. + """ + + @property + def file_extension(self) -> str: + return "json" + + def format( + self, + results: EvalResults, + show_details: bool = False, + failed_only: bool = False, + original_counts: EvalStats | None = None, + include_context: bool = False, + ) -> str: + """Format evaluation results as JSON.""" + # Check if this is a comparative evaluation + if is_comparative_result(results): + output = self._format_comparative( + results, show_details, failed_only, original_counts, include_context + ) + elif is_multi_model_eval(results): + output = self._format_multi_model( + results, show_details, failed_only, original_counts, include_context + ) + else: + output = self._format_regular( + results, show_details, failed_only, original_counts, include_context + ) + + return json.dumps(output, indent=2, default=str) + + def _format_regular( + self, + results: EvalResults, + show_details: bool = False, + failed_only: bool = False, + original_counts: EvalStats | None = None, + include_context: bool = False, + ) -> dict[str, Any]: + """Format regular (non-comparative) evaluation results.""" + model_groups, total_passed, total_failed, total_warned, total_cases = ( + group_results_by_model(results) + ) + + # Calculate pass rate + if total_cases > 0: + if failed_only and original_counts and original_counts[0] > 0: + pass_rate = (original_counts[1] / original_counts[0]) * 100 + else: + pass_rate = (total_passed / total_cases) * 100 + else: + pass_rate = 0 + + output: dict[str, Any] = { + "type": "evaluation", + "generated_at": datetime.now(timezone.utc).isoformat(), + "summary": { + "total_cases": total_cases, + "passed": total_passed, + "failed": total_failed, + "warned": total_warned, + "pass_rate": round(pass_rate, 2), + }, + "models": {}, + } + + if failed_only and original_counts: + output["summary"]["original_counts"] = { + "total": original_counts[0], + "passed": original_counts[1], + "failed": original_counts[2], + "warned": original_counts[3], + } + output["summary"]["filtered"] = True + + # Build model results + for model, suites in model_groups.items(): + output["models"][model] = {"suites": {}} + + for suite_name, cases in suites.items(): + suite_data: dict[str, Any] = { + "case_count": len(cases), + "cases": [], + } + + for case in cases: + case_data = self._serialize_case(case, show_details, include_context) + suite_data["cases"].append(case_data) + + output["models"][model]["suites"][suite_name] = suite_data + + return output + + def _format_comparative( + self, + results: EvalResults, + show_details: 
bool = False, + failed_only: bool = False, + original_counts: EvalStats | None = None, + include_context: bool = False, + ) -> dict[str, Any]: + """Format comparative evaluation results.""" + # Check if this is multi-model comparative - use case-first grouping + if is_multi_model_comparative(results): + return self._format_comparative_case_first( + results, show_details, failed_only, original_counts, include_context + ) + + return self._format_comparative_single_model( + results, show_details, failed_only, original_counts, include_context + ) + + def _format_comparative_single_model( + self, + results: EvalResults, + show_details: bool = False, + failed_only: bool = False, + original_counts: EvalStats | None = None, + include_context: bool = False, + ) -> dict[str, Any]: + """Format single-model comparative evaluation results.""" + ( + comparative_groups, + total_passed, + total_failed, + total_warned, + total_cases, + suite_track_order, + ) = group_comparative_by_case(results) + + # Collect all unique tracks + all_tracks: list[str] = [] + for tracks in suite_track_order.values(): + for t in tracks: + if t not in all_tracks: + all_tracks.append(t) + + # Calculate pass rate + if total_cases > 0: + if failed_only and original_counts and original_counts[0] > 0: + pass_rate = (original_counts[1] / original_counts[0]) * 100 + else: + pass_rate = (total_passed / total_cases) * 100 + else: + pass_rate = 0 + + output: dict[str, Any] = { + "type": "comparative_evaluation", + "generated_at": datetime.now(timezone.utc).isoformat(), + "tracks": all_tracks, + "summary": { + "total_cases": total_cases, + "passed": total_passed, + "failed": total_failed, + "warned": total_warned, + "pass_rate": round(pass_rate, 2), + }, + "models": {}, + } + + if failed_only and original_counts: + output["summary"]["original_counts"] = { + "total": original_counts[0], + "passed": original_counts[1], + "failed": original_counts[2], + "warned": original_counts[3], + } + output["summary"]["filtered"] = True + + # Build model results + for model, suites in comparative_groups.items(): + output["models"][model] = {"suites": {}} + + for suite_name, cases in suites.items(): + track_order = suite_track_order.get(suite_name, []) + + suite_data: dict[str, Any] = { + "tracks": track_order, + "case_count": len(cases), + "cases": {}, + } + + for case_name, case_data in cases.items(): + tracks_data = case_data.get("tracks", {}) + + case_output: dict[str, Any] = { + "input": case_data.get("input", ""), + "tracks": {}, + } + + # Add context if requested + if include_context: + system_msg = case_data.get("system_message") + addl_msgs = case_data.get("additional_messages") + if system_msg: + case_output["system_message"] = system_msg + if addl_msgs: + case_output["additional_messages"] = addl_msgs + + for track_name in track_order: + if track_name not in tracks_data: + case_output["tracks"][track_name] = {"status": "missing"} + continue + + track_result = tracks_data[track_name] + evaluation = track_result.get("evaluation") + + if not evaluation: + case_output["tracks"][track_name] = {"status": "no_evaluation"} + continue + + track_data: dict[str, Any] = { + "status": self._get_status(evaluation), + "score": round(evaluation.score * 100, 2), + "passed": evaluation.passed, + "warning": evaluation.warning, + } + + if evaluation.failure_reason: + track_data["failure_reason"] = evaluation.failure_reason + + if show_details and evaluation.results: + track_data["details"] = self._serialize_critic_results( + evaluation.results + ) + + 
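+                        # Illustrative shape of one track entry at this point (example
+                        # values only, not real results): {"status": "passed",
+                        # "score": 92.5, "passed": True, "warning": False, "details": [...]}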
case_output["tracks"][track_name] = track_data + + suite_data["cases"][case_name] = case_output + + output["models"][model]["suites"][suite_name] = suite_data + + return output + + def _format_comparative_case_first( + self, + results: EvalResults, + show_details: bool = False, + failed_only: bool = False, + original_counts: EvalStats | None = None, + include_context: bool = False, + ) -> dict[str, Any]: + """Format multi-model comparative evaluation grouped by case first.""" + # Get case-first grouping + ( + case_groups, + model_order, + suite_track_order, + total_passed, + total_failed, + total_warned, + total_cases, + ) = group_comparative_by_case_first(results) + + # Collect all unique tracks + all_tracks: list[str] = [] + for tracks in suite_track_order.values(): + for t in tracks: + if t not in all_tracks: + all_tracks.append(t) + + # Calculate pass rate + if total_cases > 0: + if failed_only and original_counts and original_counts[0] > 0: + pass_rate = (original_counts[1] / original_counts[0]) * 100 + else: + pass_rate = (total_passed / total_cases) * 100 + else: + pass_rate = 0 + + output: dict[str, Any] = { + "type": "multi_model_comparative_evaluation", + "generated_at": datetime.now(timezone.utc).isoformat(), + "models": model_order, + "tracks": all_tracks, + "summary": { + "total_cases": total_cases, + "passed": total_passed, + "failed": total_failed, + "warned": total_warned, + "pass_rate": round(pass_rate, 2), + }, + "grouped_by_case": {}, + } + + if failed_only and original_counts: + output["summary"]["original_counts"] = { + "total": original_counts[0], + "passed": original_counts[1], + "failed": original_counts[2], + "warned": original_counts[3], + } + output["summary"]["filtered"] = True + + # Build case-first structure + for suite_name, cases in case_groups.items(): + track_order = suite_track_order.get(suite_name, []) + output["grouped_by_case"][suite_name] = {"tracks": track_order, "cases": {}} + + for case_name, model_data in cases.items(): + first_model_data = next(iter(model_data.values()), {}) + case_output: dict[str, Any] = { + "input": first_model_data.get("input", ""), + "models": {}, + } + + # Add context if requested + if include_context: + system_msg = first_model_data.get("system_message") + addl_msgs = first_model_data.get("additional_messages") + if system_msg: + case_output["system_message"] = system_msg + if addl_msgs: + case_output["additional_messages"] = addl_msgs + + for model in model_order: + if model not in model_data: + case_output["models"][model] = {"status": "missing"} + continue + + model_case_data = model_data[model] + tracks_data = model_case_data.get("tracks", {}) + + model_output: dict[str, Any] = {"tracks": {}} + + for track_name in track_order: + if track_name not in tracks_data: + model_output["tracks"][track_name] = {"status": "missing"} + continue + + track_result = tracks_data[track_name] + evaluation = track_result.get("evaluation") + + if not evaluation: + model_output["tracks"][track_name] = {"status": "no_evaluation"} + continue + + track_data: dict[str, Any] = { + "status": self._get_status(evaluation), + "score": round(evaluation.score * 100, 2), + "passed": evaluation.passed, + "warning": evaluation.warning, + } + + if evaluation.failure_reason: + track_data["failure_reason"] = evaluation.failure_reason + + if show_details and evaluation.results: + track_data["details"] = self._serialize_critic_results( + evaluation.results + ) + + model_output["tracks"][track_name] = track_data + + case_output["models"][model] = model_output 
+ + output["grouped_by_case"][suite_name]["cases"][case_name] = case_output + + return output + + def _format_multi_model( + self, + results: EvalResults, + show_details: bool = False, + failed_only: bool = False, + original_counts: EvalStats | None = None, + include_context: bool = False, + ) -> dict[str, Any]: + """Format multi-model evaluation results with comparison structure.""" + comparison_data, model_order, per_model_stats = group_eval_for_comparison(results) + + # Calculate totals + total_passed = sum(s["passed"] for s in per_model_stats.values()) + total_failed = sum(s["failed"] for s in per_model_stats.values()) + total_warned = sum(s["warned"] for s in per_model_stats.values()) + total_cases = sum(s["total"] for s in per_model_stats.values()) + + # Calculate pass rate + if total_cases > 0: + if failed_only and original_counts and original_counts[0] > 0: + pass_rate = (original_counts[1] / original_counts[0]) * 100 + else: + pass_rate = (total_passed / total_cases) * 100 + else: + pass_rate = 0 + + output: dict[str, Any] = { + "type": "multi_model_evaluation", + "generated_at": datetime.now(timezone.utc).isoformat(), + "models": model_order, + "summary": { + "total_evaluations": total_cases, + "unique_cases": sum(len(cases) for cases in comparison_data.values()), + "passed": total_passed, + "failed": total_failed, + "warned": total_warned, + "pass_rate": round(pass_rate, 2), + }, + "per_model_stats": {}, + "comparison": {}, + } + + if failed_only and original_counts: + output["summary"]["original_counts"] = { + "total": original_counts[0], + "passed": original_counts[1], + "failed": original_counts[2], + "warned": original_counts[3], + } + output["summary"]["filtered"] = True + + # Per-model statistics + best_model = None + best_rate = -1.0 + for model in model_order: + stats = per_model_stats[model] + output["per_model_stats"][model] = { + "total": stats["total"], + "passed": stats["passed"], + "failed": stats["failed"], + "warned": stats["warned"], + "pass_rate": round(stats["pass_rate"], 2), + } + if stats["pass_rate"] > best_rate: + best_rate = stats["pass_rate"] + best_model = model + + if best_model: + output["summary"]["best_model"] = best_model + output["summary"]["best_pass_rate"] = round(best_rate, 2) + + # Build comparison structure + for suite_name, cases in comparison_data.items(): + output["comparison"][suite_name] = {} + + for case_name, case_models in cases.items(): + case_output: dict[str, Any] = { + "results_by_model": {}, + } + + # Add context from first model if requested + if include_context: + first_model_result = next(iter(case_models.values()), {}) + system_msg = first_model_result.get("system_message") + addl_msgs = first_model_result.get("additional_messages") + if system_msg: + case_output["system_message"] = system_msg + if addl_msgs: + case_output["additional_messages"] = addl_msgs + + for model in model_order: + if model not in case_models: + case_output["results_by_model"][model] = {"status": "missing"} + continue + + case_result = case_models[model] + evaluation = case_result["evaluation"] + + model_data: dict[str, Any] = { + "status": self._get_status(evaluation), + "score": round(evaluation.score * 100, 2), + "passed": evaluation.passed, + "warning": evaluation.warning, + } + + if evaluation.failure_reason: + model_data["failure_reason"] = evaluation.failure_reason + + if show_details and evaluation.results: + model_data["details"] = self._serialize_critic_results(evaluation.results) + + case_output["results_by_model"][model] = model_data + + # 
Find best model for this case + best, best_score = find_best_model(case_models) + case_output["best_model"] = best + case_output["best_score"] = round(best_score * 100, 2) + + output["comparison"][suite_name][case_name] = case_output + + return output + + def _serialize_case( + self, case: dict[str, Any], show_details: bool, include_context: bool = False + ) -> dict[str, Any]: + """Serialize a single evaluation case.""" + evaluation = case["evaluation"] + + case_data: dict[str, Any] = { + "name": case["name"], + "input": case.get("input", ""), + "status": self._get_status(evaluation), + "score": round(evaluation.score * 100, 2), + "passed": evaluation.passed, + "warning": evaluation.warning, + } + + # Add context if requested + if include_context: + system_msg = case.get("system_message") + addl_msgs = case.get("additional_messages") + if system_msg: + case_data["system_message"] = system_msg + if addl_msgs: + case_data["additional_messages"] = addl_msgs + + if evaluation.failure_reason: + case_data["failure_reason"] = evaluation.failure_reason + + if show_details and evaluation.results: + case_data["details"] = self._serialize_critic_results(evaluation.results) + + return case_data + + def _serialize_critic_results(self, results: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Serialize critic results for detailed output.""" + serialized = [] + for critic_result in results: + item: dict[str, Any] = { + "field": critic_result["field"], + "match": critic_result["match"], + "score": critic_result["score"], + "weight": critic_result["weight"], + "expected": critic_result["expected"], + "actual": critic_result["actual"], + } + + if "is_criticized" in critic_result: + item["is_criticized"] = critic_result["is_criticized"] + + serialized.append(item) + + return serialized + + def _get_status(self, evaluation: Any) -> str: + """Get status string from evaluation.""" + if evaluation.passed: + return "passed" + elif evaluation.warning: + return "warned" + else: + return "failed" + + +class CaptureJsonFormatter(CaptureFormatter): + """JSON formatter for capture results.""" + + @property + def file_extension(self) -> str: + return "json" + + def format( + self, + captures: CaptureResults, + include_context: bool = False, + ) -> str: + """Format capture results as JSON.""" + # Check for multi-model captures + if is_multi_model_capture(captures): + output_data = self._format_multi_model(captures, include_context) + else: + output_data = { + "type": "capture", + "captures": [cap.to_dict(include_context=include_context) for cap in captures], + } + return json.dumps(output_data, indent=2) + + def _format_multi_model( + self, + captures: CaptureResults, + include_context: bool = False, + ) -> dict[str, Any]: + """Format multi-model capture results with track-aware structure.""" + from arcade_cli.formatters.base import group_captures_by_case_then_track + + grouped_data, model_order, track_order = group_captures_by_case_then_track(captures) + has_tracks = len(track_order) > 1 or (track_order and track_order[0] is not None) + + track_names = [t for t in track_order if t is not None] if has_tracks else [] + + output: dict[str, Any] = { + "type": "multi_model_capture", + "generated_at": datetime.now(timezone.utc).isoformat(), + "models": model_order, + "tracks": track_names if track_names else None, + "summary": { + "total_suites": len(grouped_data), + "total_cases": sum(len(cases) for cases in grouped_data.values()), + "models_count": len(model_order), + "tracks_count": len(track_names) if track_names else 
0, + }, + "grouped_by_case": {}, + } + + for suite_name, cases in grouped_data.items(): + output["grouped_by_case"][suite_name] = {} + + for case_name, case_data in cases.items(): + case_output: dict[str, Any] = { + "user_message": case_data.get("user_message", ""), + } + + if include_context: + if case_data.get("system_message"): + case_output["system_message"] = case_data["system_message"] + if case_data.get("additional_messages"): + case_output["additional_messages"] = case_data["additional_messages"] + + tracks_data = case_data.get("tracks", {}) + track_keys = list(tracks_data.keys()) + has_multiple_tracks = len(track_keys) > 1 or ( + len(track_keys) == 1 and track_keys[0] != "_default" + ) + + if has_multiple_tracks: + # Structure with tracks + case_output["tracks"] = {} + for track_key in track_keys: + track_display = track_key if track_key != "_default" else "default" + track_data = tracks_data[track_key] + models_dict = track_data.get("models", {}) + + track_output: dict[str, Any] = {"models": {}} + for model in model_order: + if model not in models_dict: + track_output["models"][model] = {"status": "missing"} + continue + + captured_case = models_dict[model] + track_output["models"][model] = { + "tool_calls": [ + {"name": tc.name, "args": tc.args} + for tc in captured_case.tool_calls + ], + } + + case_output["tracks"][track_display] = track_output + else: + # No tracks - flat structure + track_key = track_keys[0] if track_keys else "_default" + track_data = tracks_data.get(track_key, {}) + models_dict = track_data.get("models", {}) + + case_output["models"] = {} + for model in model_order: + if model not in models_dict: + case_output["models"][model] = {"status": "missing"} + continue + + captured_case = models_dict[model] + case_output["models"][model] = { + "tool_calls": [ + {"name": tc.name, "args": tc.args} + for tc in captured_case.tool_calls + ], + } + + output["grouped_by_case"][suite_name][case_name] = case_output + + return output diff --git a/libs/arcade-cli/arcade_cli/formatters/markdown.py b/libs/arcade-cli/arcade_cli/formatters/markdown.py new file mode 100644 index 000000000..ea5b34860 --- /dev/null +++ b/libs/arcade-cli/arcade_cli/formatters/markdown.py @@ -0,0 +1,1284 @@ +"""Markdown formatter for evaluation and capture results.""" + +import json +from datetime import datetime, timezone +from typing import Any + +from arcade_cli.formatters.base import ( + CaptureFormatter, + CaptureResults, + ComparativeCaseData, + EvalResultFormatter, + compute_track_differences, + find_best_model, + group_comparative_by_case, + group_comparative_by_case_first, + group_eval_for_comparison, + group_results_by_model, + is_comparative_result, + is_multi_model_capture, + is_multi_model_comparative, + is_multi_model_eval, + truncate_field_value, +) + +# Markdown-specific truncation length (slightly shorter for table readability) +MD_MAX_FIELD_LENGTH = 50 + + +class MarkdownFormatter(EvalResultFormatter): + """ + Markdown formatter for evaluation results. + + Produces a well-structured Markdown document with tables and collapsible sections. 
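+
+    For a regular (non-comparative) run, the output roughly follows this
+    skeleton (model, suite, case names and numbers are illustrative placeholders):
+
+        # Evaluation Results
+
+        ## Summary
+
+        | Metric | Count |
+        |--------|-------|
+        | **Total** | 4 |
+        | βœ… Passed | 3 |
+        | ❌ Failed | 1 |
+
+        **Pass Rate:** 75.0%
+
+        ## Results by Model
+
+        ### πŸ€– gpt-4o
+
+        #### πŸ“ MySuite
+
+        | Status | Case | Score |
+        |--------|------|-------|
+        | βœ… | greet user | 100.0% |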
+ """ + + @property + def file_extension(self) -> str: + return "md" + + def format( + self, + results: list[list[dict[str, Any]]], + show_details: bool = False, + failed_only: bool = False, + original_counts: tuple[int, int, int, int] | None = None, + include_context: bool = False, + ) -> str: + # Check if this is a comparative evaluation + if is_comparative_result(results): + return self._format_comparative( + results, show_details, failed_only, original_counts, include_context + ) + + # Check if this is a multi-model evaluation + if is_multi_model_eval(results): + return self._format_multi_model( + results, show_details, failed_only, original_counts, include_context + ) + + return self._format_regular( + results, show_details, failed_only, original_counts, include_context + ) + + def _format_regular( + self, + results: list[list[dict[str, Any]]], + show_details: bool = False, + failed_only: bool = False, + original_counts: tuple[int, int, int, int] | None = None, + include_context: bool = False, + ) -> str: + """Format regular (non-comparative) evaluation results.""" + lines: list[str] = [] + + # Header + lines.append("# Evaluation Results") + lines.append("") + lines.append( + f"**Generated:** {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S UTC')}" + ) + lines.append("") + + # Use shared grouping logic + model_groups, total_passed, total_failed, total_warned, total_cases = ( + group_results_by_model(results) + ) + + # Summary section + lines.append("## Summary") + lines.append("") + + if failed_only and original_counts: + orig_total, orig_passed, orig_failed, orig_warned = original_counts + lines.append(f"> ⚠️ **Note:** Showing only {total_cases} failed evaluation(s)") + lines.append("") + lines.append("| Metric | Count |") + lines.append("|--------|-------|") + lines.append(f"| **Total** | {orig_total} |") + lines.append(f"| βœ… Passed | {orig_passed} |") + if orig_warned > 0: + lines.append(f"| ⚠️ Warnings | {orig_warned} |") + lines.append(f"| ❌ Failed | {orig_failed} |") + else: + lines.append("| Metric | Count |") + lines.append("|--------|-------|") + lines.append(f"| **Total** | {total_cases} |") + lines.append(f"| βœ… Passed | {total_passed} |") + if total_warned > 0: + lines.append(f"| ⚠️ Warnings | {total_warned} |") + if total_failed > 0: + lines.append(f"| ❌ Failed | {total_failed} |") + + # Pass rate + if total_cases > 0: + if failed_only and original_counts and original_counts[0] > 0: + pass_rate = (original_counts[1] / original_counts[0]) * 100 + else: + pass_rate = (total_passed / total_cases) * 100 + lines.append("") + lines.append(f"**Pass Rate:** {pass_rate:.1f}%") + + lines.append("") + + # Results by model + lines.append("## Results by Model") + lines.append("") + + for model, suites in model_groups.items(): + lines.append(f"### πŸ€– {model}") + lines.append("") + + for suite_name, cases in suites.items(): + lines.append(f"#### πŸ“ {suite_name}") + lines.append("") + + # Results table + lines.append("| Status | Case | Score |") + lines.append("|--------|------|-------|") + + for case in cases: + evaluation = case["evaluation"] + if evaluation.passed: + status = "βœ…" + elif evaluation.warning: + status = "⚠️" + else: + status = "❌" + + score_pct = evaluation.score * 100 + case_name = case["name"].replace("|", "\\|") + lines.append(f"| {status} | {case_name} | {score_pct:.1f}% |") + + lines.append("") + + # Detailed results if requested + if show_details: + lines.append("
") + lines.append("Detailed Results") + lines.append("") + + for case in cases: + evaluation = case["evaluation"] + if evaluation.passed: + status_text = "βœ… PASSED" + elif evaluation.warning: + status_text = "⚠️ WARNED" + else: + status_text = "❌ FAILED" + + lines.append(f"##### {case['name']}") + lines.append("") + lines.append(f"**Status:** {status_text} ") + lines.append(f"**Score:** {evaluation.score * 100:.2f}%") + lines.append("") + lines.append(f"**Input:** `{case['input']}`") + lines.append("") + + # Context section (if include_context is True) + if include_context: + system_msg = case.get("system_message") + addl_msgs = case.get("additional_messages") + if system_msg or addl_msgs: + lines.append("**πŸ“‹ Context:**") + lines.append("") + if system_msg: + lines.append(f"> **System:** {system_msg}") + lines.append("") + if addl_msgs: + lines.append( + f"
πŸ’¬ Conversation ({len(addl_msgs)} messages)" + ) + lines.append("") + lines.extend(self._format_conversation_md(addl_msgs)) + lines.append("
") + lines.append("") + + # Evaluation details + lines.append(self._format_evaluation_details(evaluation)) + lines.append("") + lines.append("---") + lines.append("") + + lines.append("
") + lines.append("") + + return "\n".join(lines) + + def _format_evaluation_details(self, evaluation: Any) -> str: + """Format evaluation details as markdown.""" + lines: list[str] = [] + + if evaluation.failure_reason: + lines.append(f"**Failure Reason:** {evaluation.failure_reason}") + else: + lines.append("| Field | Match | Score | Expected | Actual |") + lines.append("|-------|-------|-------|----------|--------|") + + for critic_result in evaluation.results: + is_criticized = critic_result.get("is_criticized", True) + field = critic_result["field"] + score = critic_result["score"] + weight = critic_result["weight"] + expected = str(critic_result["expected"]).replace("|", "\\|") + actual = str(critic_result["actual"]).replace("|", "\\|") + + # Truncate long values for table readability + expected = truncate_field_value(expected, MD_MAX_FIELD_LENGTH) + actual = truncate_field_value(actual, MD_MAX_FIELD_LENGTH) + + if is_criticized: + match_icon = "βœ…" if critic_result["match"] else "❌" + lines.append( + f"| {field} | {match_icon} | {score:.2f}/{weight:.2f} | `{expected}` | `{actual}` |" + ) + else: + lines.append(f"| {field} | β€” | - | `{expected}` | `{actual}` |") + + return "\n".join(lines) + + # ========================================================================= + # MULTI-MODEL EVALUATION FORMATTING + # ========================================================================= + + def _format_multi_model( + self, + results: list[list[dict[str, Any]]], + show_details: bool = False, + failed_only: bool = False, + original_counts: tuple[int, int, int, int] | None = None, + include_context: bool = False, + ) -> str: + """Format evaluation results with multi-model comparison tables.""" + lines: list[str] = [] + + # Header + lines.append("# Multi-Model Evaluation Results") + lines.append("") + lines.append( + f"**Generated:** {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S UTC')}" + ) + lines.append("") + + # Get comparison data + comparison_data, model_order, per_model_stats = group_eval_for_comparison(results) + + # Calculate totals + total_cases = sum(s["total"] for s in per_model_stats.values()) + total_passed = sum(s["passed"] for s in per_model_stats.values()) + total_failed = sum(s["failed"] for s in per_model_stats.values()) + total_warned = sum(s["warned"] for s in per_model_stats.values()) + + # Models being compared + lines.append(f"**Models Compared:** {', '.join(f'`{m}`' for m in model_order)}") + lines.append("") + + # Per-Model Summary Table + lines.append("## Per-Model Summary") + lines.append("") + lines.append("| Model | Passed | Failed | Warned | Total | Pass Rate |") + lines.append("|-------|--------|--------|--------|-------|-----------|") + + best_model = None + best_rate = -1.0 + for model in model_order: + stats = per_model_stats[model] + rate = stats["pass_rate"] + rate_str = f"{rate:.1f}%" + + # Track best model + if rate > best_rate: + best_rate = rate + best_model = model + + lines.append( + f"| `{model}` | {stats['passed']} | {stats['failed']} | " + f"{stats['warned']} | {stats['total']} | {rate_str} |" + ) + + lines.append("") + if best_model: + lines.append(f"**πŸ† Best Overall:** `{best_model}` ({best_rate:.1f}% pass rate)") + lines.append("") + + # Cross-Model Comparison by Suite + lines.append("## Cross-Model Comparison") + lines.append("") + + for suite_name, cases in comparison_data.items(): + lines.append(f"### πŸ“ {suite_name}") + lines.append("") + + # Build comparison table header + header = "| Case |" + separator = "|------|" + for 
model in model_order: + header += f" {model} |" + separator += "--------|" + header += " Best |" + separator += "------|" + + lines.append(header) + lines.append(separator) + + # Build rows for each case + for case_name, case_models in cases.items(): + row = f"| {case_name} |" + + for model in model_order: + if model in case_models: + evaluation = case_models[model]["evaluation"] + score = evaluation.score * 100 + if evaluation.passed: + cell = f"βœ… {score:.0f}%" + elif evaluation.warning: + cell = f"⚠️ {score:.0f}%" + else: + cell = f"❌ {score:.0f}%" + else: + cell = "β€”" + row += f" {cell} |" + + # Find best model for this case + best, best_score = find_best_model(case_models) + if best == "Tie": + row += " Tie |" + elif best: + row += f" `{best}` |" + else: + row += " β€” |" + + lines.append(row) + + lines.append("") + + # Detailed results per case (if requested) + if show_details: + lines.append("
") + lines.append("πŸ“‹ Detailed Results") + lines.append("") + + for case_name, case_models in cases.items(): + lines.append(f"#### {case_name}") + lines.append("") + + for model in model_order: + if model not in case_models: + continue + + case_result = case_models[model] + evaluation = case_result["evaluation"] + + lines.append(f"**{model}:** Score {evaluation.score * 100:.1f}%") + lines.append("") + lines.append(self._format_evaluation_details(evaluation)) + lines.append("") + + lines.append("---") + lines.append("") + + lines.append("
") + lines.append("") + + # Overall summary + lines.append("## Overall Summary") + lines.append("") + if failed_only and original_counts: + orig_total, orig_passed, orig_failed, orig_warned = original_counts + lines.append("> ⚠️ Showing only failed evaluations") + lines.append("") + lines.append(f"- **Total Cases:** {orig_total}") + lines.append(f"- **Passed:** {orig_passed}") + lines.append(f"- **Failed:** {orig_failed}") + if orig_warned > 0: + lines.append(f"- **Warned:** {orig_warned}") + else: + # Note: total_cases counts each model's run of each case separately + unique_cases = sum(len(cases) for cases in comparison_data.values()) + lines.append(f"- **Unique Cases:** {unique_cases}") + lines.append(f"- **Total Evaluations:** {total_cases} ({len(model_order)} models)") + lines.append(f"- **Passed:** {total_passed}") + lines.append(f"- **Failed:** {total_failed}") + if total_warned > 0: + lines.append(f"- **Warned:** {total_warned}") + + lines.append("") + return "\n".join(lines) + + # ========================================================================= + # COMPARATIVE EVALUATION FORMATTING + # ========================================================================= + + def _format_comparative( + self, + results: list[list[dict[str, Any]]], + show_details: bool = False, + failed_only: bool = False, + original_counts: tuple[int, int, int, int] | None = None, + include_context: bool = False, + ) -> str: + """Format comparative evaluation results showing tracks side-by-side.""" + # Check if this is multi-model comparative - use case-first grouping + if is_multi_model_comparative(results): + return self._format_comparative_case_first( + results, show_details, failed_only, original_counts, include_context + ) + + # Single model comparative - use original model-first grouping + return self._format_comparative_single_model( + results, show_details, failed_only, original_counts, include_context + ) + + def _format_comparative_single_model( + self, + results: list[list[dict[str, Any]]], + show_details: bool = False, + failed_only: bool = False, + original_counts: tuple[int, int, int, int] | None = None, + include_context: bool = False, + ) -> str: + """Format single-model comparative evaluation results.""" + lines: list[str] = [] + + # Header + lines.append("# Comparative Evaluation Results") + lines.append("") + lines.append( + f"**Generated:** {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S UTC')}" + ) + lines.append("") + + # Use comparative grouping + ( + comparative_groups, + total_passed, + total_failed, + total_warned, + total_cases, + suite_track_order, + ) = group_comparative_by_case(results) + + # Collect all unique tracks for summary + all_tracks: list[str] = [] + for tracks in suite_track_order.values(): + for t in tracks: + if t not in all_tracks: + all_tracks.append(t) + + # Summary section + lines.append("## Summary") + lines.append("") + lines.append(f"**Tracks compared:** {', '.join(f'`{t}`' for t in all_tracks)}") + lines.append("") + + if failed_only and original_counts: + orig_total, orig_passed, orig_failed, orig_warned = original_counts + lines.append(f"> ⚠️ **Note:** Showing only {total_cases} failed evaluation(s)") + lines.append("") + lines.append("| Metric | Count |") + lines.append("|--------|-------|") + lines.append(f"| **Total** | {orig_total} |") + lines.append(f"| βœ… Passed | {orig_passed} |") + if orig_warned > 0: + lines.append(f"| ⚠️ Warnings | {orig_warned} |") + lines.append(f"| ❌ Failed | {orig_failed} |") + else: + lines.append("| Metric | 
Count |") + lines.append("|--------|-------|") + lines.append(f"| **Total** | {total_cases} |") + lines.append(f"| βœ… Passed | {total_passed} |") + if total_warned > 0: + lines.append(f"| ⚠️ Warnings | {total_warned} |") + if total_failed > 0: + lines.append(f"| ❌ Failed | {total_failed} |") + + # Pass rate + if total_cases > 0: + if failed_only and original_counts and original_counts[0] > 0: + pass_rate = (original_counts[1] / original_counts[0]) * 100 + else: + pass_rate = (total_passed / total_cases) * 100 + lines.append("") + lines.append(f"**Pass Rate:** {pass_rate:.1f}%") + + lines.append("") + + # Results by model + lines.append("## Results by Model") + lines.append("") + + for model, suites in comparative_groups.items(): + lines.append(f"### πŸ€– {model}") + lines.append("") + + for suite_name, cases in suites.items(): + # Get track order for this specific suite + track_order = suite_track_order.get(suite_name, []) + + lines.append(f"#### πŸ“Š {suite_name} (Comparative)") + lines.append("") + lines.append(f"**Tracks:** {', '.join(f'`{t}`' for t in track_order)}") + lines.append("") + + # List all cases with summary comparison + for case_name, case_data in cases.items(): + # Context section (if include_context is True) + if include_context: + system_msg = case_data.get("system_message") + addl_msgs = case_data.get("additional_messages") + if system_msg or addl_msgs: + lines.append("
") + lines.append("πŸ“‹ Context") + lines.append("") + if system_msg: + lines.append(f"**System Message:** {system_msg}") + lines.append("") + if addl_msgs: + lines.append(f"**πŸ’¬ Conversation ({len(addl_msgs)} messages):**") + lines.append("") + for msg in addl_msgs: + role = msg.get("role", "unknown") + content = msg.get("content", "") + name = msg.get("name", "") + role_icons = { + "user": "πŸ‘€", + "assistant": "πŸ€–", + "tool": "πŸ”§", + "system": "βš™οΈ", + } + icon = role_icons.get(role, "πŸ’¬") + label = ( + f"{icon} **{role.title()}**" + if not name + else f"{icon} **{role.title()}** (`{name}`)" + ) + lines.append(f"> {label}") + if content: + if role == "tool": + try: + import json + + parsed = json.loads(content) + formatted = json.dumps(parsed, indent=2) + lines.append("> ```json") + for json_line in formatted.split("\n"): + lines.append(f"> {json_line}") + lines.append("> ```") + except (json.JSONDecodeError, TypeError): + lines.append(f"> {content}") + else: + lines.append(f"> {content}") + tool_calls = msg.get("tool_calls", []) + if tool_calls: + for tc in tool_calls: + func = tc.get("function", {}) + tc_name = func.get("name", "unknown") + tc_args = func.get("arguments", "{}") + lines.append(f"> πŸ”§ **{tc_name}**") + try: + import json + + args_dict = ( + json.loads(tc_args) + if isinstance(tc_args, str) + else tc_args + ) + formatted = json.dumps(args_dict, indent=2) + lines.append("> ```json") + for arg_line in formatted.split("\n"): + lines.append(f"> {arg_line}") + lines.append("> ```") + except (json.JSONDecodeError, TypeError): + lines.append(f"> `{tc_args}`") + lines.append(">") + lines.append("
") + lines.append("") + + lines.extend( + self._format_comparative_case( + case_name, case_data, track_order, show_details + ) + ) + + return "\n".join(lines) + + def _format_comparative_case_first( + self, + results: list[list[dict[str, Any]]], + show_details: bool = False, + failed_only: bool = False, + original_counts: tuple[int, int, int, int] | None = None, + include_context: bool = False, + ) -> str: + """Format multi-model comparative evaluation grouped by case first.""" + lines: list[str] = [] + + # Get case-first grouping + ( + case_groups, + model_order, + suite_track_order, + total_passed, + total_failed, + total_warned, + total_cases, + ) = group_comparative_by_case_first(results) + + # Collect all unique tracks + all_tracks: list[str] = [] + for tracks in suite_track_order.values(): + for t in tracks: + if t not in all_tracks: + all_tracks.append(t) + + # Header + lines.append("# Comparative Evaluation Results (Multi-Model)") + lines.append("") + lines.append( + f"**Generated:** {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S UTC')}" + ) + lines.append("") + lines.append(f"**Models:** {', '.join(f'`{m}`' for m in model_order)}") + lines.append("") + lines.append(f"**Tracks:** {', '.join(f'`{t}`' for t in all_tracks)}") + lines.append("") + + # Summary section + lines.append("## Summary") + lines.append("") + + if failed_only and original_counts: + orig_total, orig_passed, orig_failed, orig_warned = original_counts + lines.append(f"> ⚠️ **Note:** Showing only {total_cases} failed evaluation(s)") + lines.append("") + lines.append("| Metric | Count |") + lines.append("|--------|-------|") + lines.append(f"| **Total** | {orig_total} |") + lines.append(f"| βœ… Passed | {orig_passed} |") + if orig_warned > 0: + lines.append(f"| ⚠️ Warnings | {orig_warned} |") + lines.append(f"| ❌ Failed | {orig_failed} |") + else: + lines.append("| Metric | Count |") + lines.append("|--------|-------|") + lines.append(f"| **Total** | {total_cases} |") + lines.append(f"| βœ… Passed | {total_passed} |") + if total_warned > 0: + lines.append(f"| ⚠️ Warnings | {total_warned} |") + if total_failed > 0: + lines.append(f"| ❌ Failed | {total_failed} |") + + # Pass rate + if total_cases > 0: + if failed_only and original_counts and original_counts[0] > 0: + pass_rate = (original_counts[1] / original_counts[0]) * 100 + else: + pass_rate = (total_passed / total_cases) * 100 + lines.append("") + lines.append(f"**Pass Rate:** {pass_rate:.1f}%") + + lines.append("") + + # Results grouped by case + lines.append("## Results by Case") + lines.append("") + + for suite_name, cases in case_groups.items(): + track_order = suite_track_order.get(suite_name, []) + + lines.append(f"### πŸ“Š {suite_name}") + lines.append("") + lines.append(f"**Tracks:** {', '.join(f'`{t}`' for t in track_order)}") + lines.append("") + + for case_name, model_data in cases.items(): + # Case header + lines.append(f"#### πŸ“‹ Case: {case_name}") + lines.append("") + + # Get input and context from first model + first_model_data = next(iter(model_data.values()), {}) + case_input = first_model_data.get("input", "") + if case_input: + lines.append(f"**Input:** `{case_input}`") + lines.append("") + + # Context section (if include_context is True) + if include_context: + system_msg = first_model_data.get("system_message") + addl_msgs = first_model_data.get("additional_messages") + if system_msg or addl_msgs: + lines.append("
") + lines.append("πŸ“‹ Context") + lines.append("") + if system_msg: + lines.append(f"**System Message:** {system_msg}") + lines.append("") + if addl_msgs: + lines.append(f"**πŸ’¬ Conversation ({len(addl_msgs)} messages):**") + lines.append("") + for msg in addl_msgs: + role = msg.get("role", "unknown") + content = msg.get("content", "") + name = msg.get("name", "") + role_icons = { + "user": "πŸ‘€", + "assistant": "πŸ€–", + "tool": "πŸ”§", + "system": "βš™οΈ", + } + icon = role_icons.get(role, "πŸ’¬") + label = ( + f"{icon} **{role.title()}**" + if not name + else f"{icon} **{role.title()}** (`{name}`)" + ) + lines.append(f"> {label}") + if content: + # For tool responses, format as JSON code block + if role == "tool": + try: + import json + + parsed = json.loads(content) + formatted = json.dumps(parsed, indent=2) + lines.append("> ```json") + for json_line in formatted.split("\n"): + lines.append(f"> {json_line}") + lines.append("> ```") + except (json.JSONDecodeError, TypeError): + lines.append(f"> {content}") + else: + lines.append(f"> {content}") + # Handle tool calls + tool_calls = msg.get("tool_calls", []) + if tool_calls: + for tc in tool_calls: + func = tc.get("function", {}) + tc_name = func.get("name", "unknown") + tc_args = func.get("arguments", "{}") + lines.append(f"> πŸ”§ **{tc_name}**") + try: + import json + + args_dict = ( + json.loads(tc_args) + if isinstance(tc_args, str) + else tc_args + ) + formatted = json.dumps(args_dict, indent=2) + lines.append("> ```json") + for arg_line in formatted.split("\n"): + lines.append(f"> {arg_line}") + lines.append("> ```") + except (json.JSONDecodeError, TypeError): + lines.append(f"> `{tc_args}`") + lines.append(">") + lines.append("
") + lines.append("") + + # Show each model's results for this case + for model in model_order: + if model not in model_data: + lines.append(f"##### πŸ€– {model}") + lines.append("") + lines.append("*(No data)*") + lines.append("") + continue + + model_case_data = model_data[model] + lines.append(f"##### πŸ€– {model}") + lines.append("") + + # Show track comparison for this model + lines.extend( + self._format_comparative_case( + case_name, model_case_data, track_order, show_details + ) + ) + + lines.append("---") + lines.append("") + + return "\n".join(lines) + + def _format_comparative_case( + self, + case_name: str, + case_data: ComparativeCaseData, + track_order: list[str], + show_details: bool, + ) -> list[str]: + """Format a single comparative case showing all tracks.""" + lines: list[str] = [] + tracks = case_data.get("tracks", {}) + + lines.append(f"##### Case: {case_name}") + lines.append("") + lines.append(f"**Input:** `{case_data.get('input', 'N/A')}`") + lines.append("") + + # Compute differences from baseline + differences = compute_track_differences(case_data, track_order) + + # Summary comparison table + lines.append("| Track | Status | Score | Differences |") + lines.append("|-------|--------|-------|-------------|") + + for track_name in track_order: + if track_name not in tracks: + lines.append(f"| `{track_name}` | ⚠️ | N/A | *No data* |") + continue + + track_result = tracks[track_name] + evaluation = track_result.get("evaluation") + + if not evaluation: + lines.append(f"| `{track_name}` | ⚠️ | N/A | *No evaluation* |") + continue + + # Status + if evaluation.passed: + status = "βœ…" + elif evaluation.warning: + status = "⚠️" + else: + status = "❌" + + # Score + score_pct = evaluation.score * 100 + + # Differences from baseline + diff_fields = differences.get(track_name, []) + if track_name == track_order[0]: + diff_text = "*(baseline)*" + elif diff_fields: + diff_text = ", ".join(f"`{f}`" for f in diff_fields) + else: + diff_text = "β€”" + + lines.append(f"| `{track_name}` | {status} | {score_pct:.1f}% | {diff_text} |") + + lines.append("") + + # Detailed results per track (collapsible) + if show_details: + for track_name in track_order: + if track_name not in tracks: + continue + + track_result = tracks[track_name] + evaluation = track_result.get("evaluation") + + if not evaluation: + continue + + lines.append("
") + lines.append(f"πŸ“‹ {track_name} β€” Detailed Results") + lines.append("") + lines.append(self._format_evaluation_details(evaluation)) + lines.append("") + lines.append("
") + lines.append("") + + lines.append("---") + lines.append("") + + return lines + + def _format_conversation_md(self, messages: list[dict]) -> list[str]: + """Format conversation messages as Markdown for context display.""" + lines: list[str] = [] + + for msg in messages: + role = msg.get("role", "unknown") + content = msg.get("content", "") + tool_calls = msg.get("tool_calls", []) + name = msg.get("name", "") + + role_icons = {"user": "πŸ‘€", "assistant": "πŸ€–", "tool": "πŸ”§", "system": "βš™οΈ"} + icon = role_icons.get(role, "πŸ’¬") + label = ( + f"{icon} **{role.title()}**" + if not name + else f"{icon} **{role.title()}** (`{name}`)" + ) + + lines.append(f"> {label}") + + if content: + # For tool responses, try to format JSON nicely + if role == "tool": + try: + parsed = json.loads(content) + formatted = json.dumps(parsed, indent=2) + lines.append("> ```json") + for json_line in formatted.split("\n"): + lines.append(f"> {json_line}") + lines.append("> ```") + except (json.JSONDecodeError, TypeError): + lines.append(f"> {content}") + else: + lines.append(f"> {content}") + + # Handle tool calls in assistant messages + if tool_calls: + for tc in tool_calls: + func = tc.get("function", {}) + tc_name = func.get("name", "unknown") + tc_args = func.get("arguments", "{}") + lines.append(f"> πŸ”§ **{tc_name}**") + try: + args_dict = json.loads(tc_args) if isinstance(tc_args, str) else tc_args + formatted = json.dumps(args_dict, indent=2) + lines.append("> ```json") + for arg_line in formatted.split("\n"): + lines.append(f"> {arg_line}") + lines.append("> ```") + except (json.JSONDecodeError, TypeError): + lines.append(f"> `{tc_args}`") + + lines.append(">") + + return lines + + +class CaptureMarkdownFormatter(CaptureFormatter): + """Markdown formatter for capture results.""" + + @property + def file_extension(self) -> str: + return "md" + + def format( + self, + captures: CaptureResults, + include_context: bool = False, + ) -> str: + """Format capture results as Markdown.""" + # Check for multi-model captures + if is_multi_model_capture(captures): + return self._format_multi_model(captures, include_context) + + return self._format_single_model(captures, include_context) + + def _format_single_model( + self, + captures: CaptureResults, + include_context: bool = False, + ) -> str: + """Format single-model capture results.""" + lines: list[str] = [] + lines.append("# Capture Results") + lines.append("") + + total_cases = 0 + total_calls = 0 + + for capture in captures: + lines.append(f"## {capture.suite_name}") + lines.append("") + lines.append(f"- **Model:** {capture.model}") + lines.append(f"- **Provider:** {capture.provider}") + lines.append("") + + for case in capture.captured_cases: + total_cases += 1 + lines.append(f"### Case: {case.case_name}") + lines.append("") + + # track_name is set for comparative cases + track_name = getattr(case, "track_name", None) + if track_name: + lines.append(f"**Track:** `{track_name}`") + lines.append("") + + lines.append(f"**User Message:** {case.user_message}") + lines.append("") + + if include_context and case.system_message: + lines.append(f"**System Message:** {case.system_message}") + lines.append("") + + lines.append("#### Tool Calls") + lines.append("") + + if case.tool_calls: + for tc in case.tool_calls: + total_calls += 1 + lines.append(f"**`{tc.name}`**") + if tc.args: + lines.append("") + lines.append("```json") + lines.append(json.dumps(tc.args, indent=2)) + lines.append("```") + lines.append("") + else: + lines.append("*No tool calls 
captured*") + lines.append("") + + if include_context and case.additional_messages: + lines.append("
") + lines.append( + f"πŸ’¬ Conversation Context ({len(case.additional_messages)} messages)" + ) + lines.append("") + lines.extend(self._format_conversation_md(case.additional_messages)) + lines.append("
") + lines.append("") + + lines.append("---") + lines.append("") + + lines.append("## Summary") + lines.append("") + lines.append(f"- **Total Cases:** {total_cases}") + lines.append(f"- **Total Tool Calls:** {total_calls}") + lines.append("") + + return "\n".join(lines) + + def _format_multi_model( + self, + captures: CaptureResults, + include_context: bool = False, + ) -> str: + """Format multi-model capture results with track sections.""" + from arcade_cli.formatters.base import group_captures_by_case_then_track + + grouped, model_order, track_order = group_captures_by_case_then_track(captures) + has_tracks = len(track_order) > 1 or (track_order and track_order[0] is not None) + + lines: list[str] = [] + lines.append("# Multi-Model Capture Results") + lines.append("") + + # Show models being compared + lines.append(f"**Models Compared:** {', '.join(f'`{m}`' for m in model_order)}") + if has_tracks: + track_names = [t for t in track_order if t is not None] + lines.append(f"**Tracks:** {' | '.join(f'`{t}`' for t in track_names)}") + lines.append("") + + total_cases = 0 + total_calls = 0 + + for suite_name, cases in grouped.items(): + lines.append(f"## {suite_name}") + lines.append("") + + for case_name, case_data in cases.items(): + total_cases += 1 + lines.append(f"### Case: {case_name}") + lines.append("") + lines.append(f"**User Message:** {case_data.get('user_message', 'N/A')}") + lines.append("") + + if include_context and case_data.get("system_message"): + lines.append(f"**System Message:** {case_data['system_message']}") + lines.append("") + + tracks_data = case_data.get("tracks", {}) + track_keys = list(tracks_data.keys()) + has_multiple_tracks = len(track_keys) > 1 or ( + len(track_keys) == 1 and track_keys[0] != "_default" + ) + + if has_multiple_tracks: + # Show tool calls by track with clear sections + for track_key in track_keys: + track_display = track_key if track_key != "_default" else "Default" + lines.append(f"#### Track: `{track_display}`") + lines.append("") + + track_data = tracks_data[track_key] + models_dict = track_data.get("models", {}) + + # Model comparison table within track + lines.append("| Model | Tools Called |") + lines.append("|-------|-------------|") + + for model in model_order: + if model not in models_dict: + lines.append(f"| `{model}` | *(no data)* |") + continue + + captured_case = models_dict[model] + if captured_case.tool_calls: + tool_names = ", ".join( + f"`{tc.name}`" for tc in captured_case.tool_calls + ) + total_calls += len(captured_case.tool_calls) + else: + tool_names = "*(none)*" + lines.append(f"| `{model}` | {tool_names} |") + + lines.append("") + + # Detailed tool calls per model + for model in model_order: + if model not in models_dict: + continue + + captured_case = models_dict[model] + if not captured_case.tool_calls: + continue + + lines.append("
") + lines.append(f"πŸ€– {model} - Details") + lines.append("") + + for tc in captured_case.tool_calls: + lines.append(f"**`{tc.name}`**") + if tc.args: + lines.append("") + lines.append("```json") + lines.append(json.dumps(tc.args, indent=2)) + lines.append("```") + lines.append("") + + lines.append("
") + lines.append("") + + lines.append("---") + lines.append("") + else: + # No tracks - show models directly + lines.append("#### Tool Calls by Model") + lines.append("") + + track_key = track_keys[0] if track_keys else "_default" + track_data = tracks_data.get(track_key, {}) + models_dict = track_data.get("models", {}) + + lines.append("| Model | Tools Called |") + lines.append("|-------|-------------|") + + for model in model_order: + if model not in models_dict: + lines.append(f"| `{model}` | *(no data)* |") + continue + + captured_case = models_dict[model] + if captured_case.tool_calls: + tool_names = ", ".join( + f"`{tc.name}`" for tc in captured_case.tool_calls + ) + total_calls += len(captured_case.tool_calls) + else: + tool_names = "*(none)*" + lines.append(f"| `{model}` | {tool_names} |") + + lines.append("") + + # Detailed tool calls per model (collapsible) + for model in model_order: + if model not in models_dict: + continue + + captured_case = models_dict[model] + if not captured_case.tool_calls: + continue + + lines.append("
") + lines.append(f"πŸ€– {model} - Tool Call Details") + lines.append("") + + for tc in captured_case.tool_calls: + lines.append(f"**`{tc.name}`**") + if tc.args: + lines.append("") + lines.append("```json") + lines.append(json.dumps(tc.args, indent=2)) + lines.append("```") + lines.append("") + + lines.append("
") + lines.append("") + + # Context (shared, show once) + if include_context and case_data.get("additional_messages"): + lines.append("
") + lines.append( + f"πŸ’¬ Conversation Context " + f"({len(case_data['additional_messages'])} messages)" + ) + lines.append("") + lines.extend(self._format_conversation_md(case_data["additional_messages"])) + lines.append("
") + lines.append("") + + lines.append("---") + lines.append("") + + # Summary + lines.append("## Summary") + lines.append("") + lines.append(f"- **Models:** {len(model_order)}") + lines.append(f"- **Unique Cases:** {total_cases}") + lines.append(f"- **Total Tool Calls:** {total_calls}") + lines.append("") + + return "\n".join(lines) + + def _format_conversation_md(self, messages: list[dict]) -> list[str]: + """Format conversation messages as rich Markdown.""" + lines: list[str] = [] + + for msg in messages: + role = msg.get("role", "unknown") + content = msg.get("content", "") + tool_calls = msg.get("tool_calls", []) + name = msg.get("name", "") + + # Role-specific icons and formatting + role_info = { + "user": ("πŸ‘€", "**User**"), + "assistant": ("πŸ€–", "**Assistant**"), + "tool": ("πŸ”§", "**Tool**"), + "system": ("βš™οΈ", "**System**"), + }.get(role, ("πŸ’¬", f"**{role.capitalize()}**")) + + icon, label = role_info + + # Header line + if role == "tool" and name: + lines.append(f"> {icon} {label} (`{name}`)") + else: + lines.append(f"> {icon} {label}") + + lines.append(">") + + # Content + if content: + # Indent content and handle multi-line + for line in content.split("\n"): + lines.append(f"> {line}") + elif role == "assistant" and not content and tool_calls: + lines.append("> *(calling tools...)*") + + # Tool calls for assistant messages + if tool_calls: + lines.append(">") + for tc in tool_calls: + func = tc.get("function", {}) + tc_name = func.get("name", "unknown") + tc_args = func.get("arguments", "{}") + + # Parse and pretty-print arguments + try: + args_dict = json.loads(tc_args) if isinstance(tc_args, str) else tc_args + args_formatted = json.dumps(args_dict, indent=2) + except (json.JSONDecodeError, TypeError): + args_formatted = str(tc_args) + + lines.append(f"> πŸ“ž **`{tc_name}`**") + lines.append(">") + lines.append("> ```json") + for arg_line in args_formatted.split("\n"): + lines.append(f"> {arg_line}") + lines.append("> ```") + + lines.append("") # Blank line between messages + + return lines diff --git a/libs/arcade-cli/arcade_cli/formatters/text.py b/libs/arcade-cli/arcade_cli/formatters/text.py new file mode 100644 index 000000000..88bf3bc06 --- /dev/null +++ b/libs/arcade-cli/arcade_cli/formatters/text.py @@ -0,0 +1,1086 @@ +"""Plain text formatter for evaluation and capture results.""" + +import json +from typing import Any + +from arcade_cli.formatters.base import ( + CaptureFormatter, + CaptureResults, + ComparativeCaseData, + EvalResultFormatter, + compute_track_differences, + find_best_model, + group_comparative_by_case, + group_comparative_by_case_first, + group_eval_for_comparison, + group_results_by_model, + is_comparative_result, + is_multi_model_capture, + is_multi_model_comparative, + is_multi_model_eval, +) + + +class TextFormatter(EvalResultFormatter): + """ + Plain text formatter for evaluation results. + + Produces output similar to pytest's format with simple ASCII formatting. 
+ """ + + @property + def file_extension(self) -> str: + return "txt" + + def format( + self, + results: list[list[dict[str, Any]]], + show_details: bool = False, + failed_only: bool = False, + original_counts: tuple[int, int, int, int] | None = None, + include_context: bool = False, + ) -> str: + # Check if this is a comparative evaluation + if is_comparative_result(results): + return self._format_comparative( + results, show_details, failed_only, original_counts, include_context + ) + + # Check if this is a multi-model evaluation + if is_multi_model_eval(results): + return self._format_multi_model( + results, show_details, failed_only, original_counts, include_context + ) + + return self._format_regular( + results, show_details, failed_only, original_counts, include_context + ) + + def _format_regular( + self, + results: list[list[dict[str, Any]]], + show_details: bool = False, + failed_only: bool = False, + original_counts: tuple[int, int, int, int] | None = None, + include_context: bool = False, + ) -> str: + """Format regular (non-comparative) evaluation results.""" + lines: list[str] = [] + + # Use shared grouping logic + model_groups, total_passed, total_failed, total_warned, total_cases = ( + group_results_by_model(results) + ) + + # Output grouped results + for model, suites in model_groups.items(): + lines.append(f"Model: {model}") + lines.append("=" * 60) + + for suite_name, cases in suites.items(): + lines.append(f" Suite: {suite_name}") + lines.append(" " + "-" * 56) + + for case in cases: + evaluation = case["evaluation"] + if evaluation.passed: + status = "PASSED" + elif evaluation.warning: + status = "WARNED" + else: + status = "FAILED" + + score_percentage = evaluation.score * 100 + lines.append(f" {status} {case['name']} -- Score: {score_percentage:.2f}%") + + if show_details: + lines.append(f" User Input: {case['input']}") + lines.append("") + + # Context section (if include_context is True) + if include_context: + system_msg = case.get("system_message") + addl_msgs = case.get("additional_messages") + if system_msg or addl_msgs: + lines.append(" Context:") + if system_msg: + lines.append(f" System: {system_msg}") + if addl_msgs: + lines.append(f" Conversation ({len(addl_msgs)} messages):") + for conv_line in self._format_conversation_text(addl_msgs): + lines.append(f" {conv_line}") + lines.append("") + + lines.append(" Details:") + for detail_line in self._format_evaluation(evaluation).split("\n"): + lines.append(f" {detail_line}") + lines.append(" " + "-" * 52) + + lines.append("") + + lines.append("") + + # Summary + if failed_only and original_counts: + orig_total, orig_passed, orig_failed, orig_warned = original_counts + lines.append(f"Note: Showing only {total_cases} failed evaluation(s) (--only-failed)") + summary = f"Summary -- Total: {orig_total} -- Passed: {orig_passed}" + if orig_warned > 0: + summary += f" -- Warnings: {orig_warned}" + if orig_failed > 0: + summary += f" -- Failed: {orig_failed}" + else: + summary = f"Summary -- Total: {total_cases} -- Passed: {total_passed}" + if total_warned > 0: + summary += f" -- Warnings: {total_warned}" + if total_failed > 0: + summary += f" -- Failed: {total_failed}" + + lines.append(summary) + lines.append("") + + return "\n".join(lines) + + def _format_evaluation(self, evaluation: Any) -> str: + """Format evaluation details.""" + result_lines = [] + if evaluation.failure_reason: + result_lines.append(f"Failure Reason: {evaluation.failure_reason}") + else: + for critic_result in evaluation.results: + is_criticized = 
critic_result.get("is_criticized", True) + field = critic_result["field"] + score = critic_result["score"] + weight = critic_result["weight"] + expected = critic_result["expected"] + actual = critic_result["actual"] + + if is_criticized: + match_str = "Match" if critic_result["match"] else "No Match" + result_lines.append( + f"{field}: {match_str}\n" + f" Score: {score:.2f}/{weight:.2f}\n" + f" Expected: {expected}\n" + f" Actual: {actual}" + ) + else: + result_lines.append( + f"{field}: Un-criticized\n Expected: {expected}\n Actual: {actual}" + ) + return "\n".join(result_lines) + + # ========================================================================= + # MULTI-MODEL EVALUATION FORMATTING + # ========================================================================= + + def _format_multi_model( + self, + results: list[list[dict[str, Any]]], + show_details: bool = False, + failed_only: bool = False, + original_counts: tuple[int, int, int, int] | None = None, + include_context: bool = False, + ) -> str: + """Format multi-model evaluation results with comparison tables.""" + lines: list[str] = [] + + # Get comparison data + comparison_data, model_order, per_model_stats = group_eval_for_comparison(results) + + # Header + lines.append("=" * 78) + lines.append("MULTI-MODEL EVALUATION RESULTS") + lines.append("=" * 78) + lines.append("") + lines.append(f"Models: {', '.join(model_order)}") + lines.append("") + + # Per-Model Summary Table + lines.append("-" * 78) + lines.append("PER-MODEL SUMMARY") + lines.append("-" * 78) + lines.append("") + + # Build header row + header = f"{'Model':<20} {'Passed':>8} {'Failed':>8} {'Warned':>8} {'Total':>8} {'Pass Rate':>10}" + lines.append(header) + lines.append("-" * len(header)) + + best_model = None + best_rate = -1.0 + for model in model_order: + stats = per_model_stats[model] + rate = stats["pass_rate"] + + if rate > best_rate: + best_rate = rate + best_model = model + + lines.append( + f"{model:<20} {stats['passed']:>8} {stats['failed']:>8} " + f"{stats['warned']:>8} {stats['total']:>8} {rate:>9.1f}%" + ) + + lines.append("") + if best_model: + lines.append(f"Best Overall: {best_model} ({best_rate:.1f}% pass rate)") + lines.append("") + + # Cross-Model Comparison by Suite + lines.append("-" * 78) + lines.append("CROSS-MODEL COMPARISON") + lines.append("-" * 78) + lines.append("") + + for suite_name, cases in comparison_data.items(): + lines.append(f"Suite: {suite_name}") + lines.append("") + + # Build comparison table header - dynamic based on model count + # Calculate column widths + case_col_width = 30 + model_col_width = 12 + best_col_width = 15 + + header_parts = [f"{'Case':<{case_col_width}}"] + for model in model_order: + # Truncate model name if too long + display_name = ( + model[: model_col_width - 1] if len(model) > model_col_width - 1 else model + ) + header_parts.append(f"{display_name:>{model_col_width}}") + header_parts.append(f"{'Best':>{best_col_width}}") + + header_line = " ".join(header_parts) + lines.append(header_line) + lines.append("-" * len(header_line)) + + # Build rows for each case + for case_name, case_models in cases.items(): + # Truncate case name if needed + display_case = ( + case_name[: case_col_width - 1] + if len(case_name) > case_col_width - 1 + else case_name + ) + row_parts = [f"{display_case:<{case_col_width}}"] + + for model in model_order: + if model in case_models: + evaluation = case_models[model]["evaluation"] + score = evaluation.score * 100 + if evaluation.passed: + cell = f"OK {score:.0f}%" + elif 
evaluation.warning: + cell = f"WN {score:.0f}%" + else: + cell = f"FL {score:.0f}%" + else: + cell = "-" + row_parts.append(f"{cell:>{model_col_width}}") + + # Find best model for this case + best, _ = find_best_model(case_models) + if best == "Tie": + best_cell = "Tie" + elif best: + best_cell = ( + best[: best_col_width - 1] if len(best) > best_col_width - 1 else best + ) + else: + best_cell = "-" + row_parts.append(f"{best_cell:>{best_col_width}}") + + lines.append(" ".join(row_parts)) + + lines.append("") + + # Detailed results per case (if requested) + if show_details: + lines.append(" Detailed Results:") + lines.append(" " + "-" * 70) + + for case_name, case_models in cases.items(): + lines.append(f" Case: {case_name}") + + for model in model_order: + if model not in case_models: + continue + + case_result = case_models[model] + evaluation = case_result["evaluation"] + + lines.append(f" [{model}] Score: {evaluation.score * 100:.1f}%") + + # Show evaluation details indented + eval_details = self._format_evaluation(evaluation) + for line in eval_details.split("\n"): + lines.append(f" {line}") + + lines.append("") + + lines.append("") + + # Overall summary + total_cases = sum(s["total"] for s in per_model_stats.values()) + total_passed = sum(s["passed"] for s in per_model_stats.values()) + total_failed = sum(s["failed"] for s in per_model_stats.values()) + total_warned = sum(s["warned"] for s in per_model_stats.values()) + + lines.append("=" * 78) + if failed_only and original_counts: + orig_total, orig_passed, orig_failed, orig_warned = original_counts + lines.append("Note: Showing only failed evaluations (--only-failed)") + lines.append( + f"Summary -- Total: {orig_total} -- Passed: {orig_passed} -- " + f"Failed: {orig_failed} -- Warned: {orig_warned}" + ) + else: + unique_cases = sum(len(cases) for cases in comparison_data.values()) + lines.append( + f"Summary -- Unique Cases: {unique_cases} -- " + f"Total Evaluations: {total_cases} ({len(model_order)} models)" + ) + lines.append( + f" Passed: {total_passed} -- Failed: {total_failed} -- Warned: {total_warned}" + ) + lines.append("") + + return "\n".join(lines) + + # ========================================================================= + # COMPARATIVE EVALUATION FORMATTING + # ========================================================================= + + def _format_comparative( + self, + results: list[list[dict[str, Any]]], + show_details: bool = False, + failed_only: bool = False, + original_counts: tuple[int, int, int, int] | None = None, + include_context: bool = False, + ) -> str: + """Format comparative evaluation results showing tracks side-by-side.""" + # Check if this is multi-model comparative - use case-first grouping + if is_multi_model_comparative(results): + return self._format_comparative_case_first( + results, show_details, failed_only, original_counts, include_context + ) + + return self._format_comparative_single_model( + results, show_details, failed_only, original_counts, include_context + ) + + def _format_comparative_single_model( + self, + results: list[list[dict[str, Any]]], + show_details: bool = False, + failed_only: bool = False, + original_counts: tuple[int, int, int, int] | None = None, + include_context: bool = False, + ) -> str: + """Format single-model comparative evaluation results.""" + lines: list[str] = [] + + # Use comparative grouping + ( + comparative_groups, + total_passed, + total_failed, + total_warned, + total_cases, + suite_track_order, + ) = group_comparative_by_case(results) + + # 
Collect all unique tracks for header + all_tracks: list[str] = [] + for tracks in suite_track_order.values(): + for t in tracks: + if t not in all_tracks: + all_tracks.append(t) + + lines.append("=" * 76) + lines.append("COMPARATIVE EVALUATION RESULTS") + lines.append("=" * 76) + lines.append("") + lines.append(f"All Tracks: {' vs '.join(all_tracks)}") + lines.append("") + + # Output grouped results + for model, suites in comparative_groups.items(): + lines.append(f"Model: {model}") + lines.append("=" * 76) + + for suite_name, cases in suites.items(): + # Get track order for this specific suite + track_order = suite_track_order.get(suite_name, []) + + lines.append(f" Suite: {suite_name} (Comparative)") + lines.append(f" Tracks: {' vs '.join(track_order)}") + lines.append(" " + "-" * 72) + + for case_name, case_data in cases.items(): + # Context section (if include_context is True) + if include_context: + system_msg = case_data.get("system_message") + addl_msgs = case_data.get("additional_messages") + if system_msg or addl_msgs: + lines.append(" " + "-" * 40) + lines.append(" πŸ“‹ CONTEXT") + lines.append(" " + "-" * 40) + if system_msg: + lines.append(f" System Message: {system_msg}") + if addl_msgs: + lines.append(f" πŸ’¬ Conversation ({len(addl_msgs)} messages):") + for msg in addl_msgs: + role = msg.get("role", "unknown").upper() + content = msg.get("content", "") + name = msg.get("name", "") + role_label = f"[{role}]" if not name else f"[{role}: {name}]" + lines.append(f" {role_label}") + if content: + # For tool responses, try to format JSON + if role.lower() == "tool": + try: + import json + + parsed = json.loads(content) + formatted = json.dumps(parsed, indent=2) + for json_line in formatted.split("\n"): + lines.append(f" {json_line}") + except (json.JSONDecodeError, TypeError): + lines.append(f" {content}") + else: + lines.append(f" {content}") + # Handle tool calls + tool_calls = msg.get("tool_calls", []) + if tool_calls: + for tc in tool_calls: + func = tc.get("function", {}) + tc_name = func.get("name", "unknown") + tc_args = func.get("arguments", "{}") + lines.append(f" πŸ”§ {tc_name}") + try: + import json + + args_dict = ( + json.loads(tc_args) + if isinstance(tc_args, str) + else tc_args + ) + formatted = json.dumps(args_dict, indent=2) + for arg_line in formatted.split("\n"): + lines.append(f" {arg_line}") + except (json.JSONDecodeError, TypeError): + lines.append(f" {tc_args}") + lines.append(" " + "-" * 40) + + lines.extend( + self._format_comparative_case_text( + case_name, case_data, track_order, show_details + ) + ) + + lines.append("") + + # Summary + if failed_only and original_counts: + orig_total, orig_passed, orig_failed, orig_warned = original_counts + lines.append(f"Note: Showing only {total_cases} failed evaluation(s) (--only-failed)") + summary = f"Summary -- Total: {orig_total} -- Passed: {orig_passed}" + if orig_warned > 0: + summary += f" -- Warnings: {orig_warned}" + if orig_failed > 0: + summary += f" -- Failed: {orig_failed}" + else: + summary = f"Summary -- Total: {total_cases} -- Passed: {total_passed}" + if total_warned > 0: + summary += f" -- Warnings: {total_warned}" + if total_failed > 0: + summary += f" -- Failed: {total_failed}" + + lines.append(summary) + lines.append("") + + return "\n".join(lines) + + def _format_comparative_case_first( + self, + results: list[list[dict[str, Any]]], + show_details: bool = False, + failed_only: bool = False, + original_counts: tuple[int, int, int, int] | None = None, + include_context: bool = False, + ) -> str: 
+ """Format multi-model comparative evaluation grouped by case first.""" + lines: list[str] = [] + + # Get case-first grouping + ( + case_groups, + model_order, + suite_track_order, + total_passed, + total_failed, + total_warned, + total_cases, + ) = group_comparative_by_case_first(results) + + # Collect all unique tracks + all_tracks: list[str] = [] + for tracks in suite_track_order.values(): + for t in tracks: + if t not in all_tracks: + all_tracks.append(t) + + lines.append("=" * 78) + lines.append("COMPARATIVE EVALUATION RESULTS (MULTI-MODEL)") + lines.append("=" * 78) + lines.append("") + lines.append(f"Models: {', '.join(model_order)}") + lines.append(f"Tracks: {', '.join(all_tracks)}") + lines.append("") + + # Results grouped by case + for suite_name, cases in case_groups.items(): + track_order = suite_track_order.get(suite_name, []) + + lines.append("-" * 78) + lines.append(f"SUITE: {suite_name}") + lines.append(f"Tracks: {' vs '.join(track_order)}") + lines.append("-" * 78) + lines.append("") + + for case_name, model_data in cases.items(): + # Case header + lines.append(" " + "=" * 72) + lines.append(f" CASE: {case_name}") + lines.append(" " + "=" * 72) + + # Get input and context from first model + first_model_data = next(iter(model_data.values()), {}) + case_input = first_model_data.get("input", "") + if case_input: + lines.append(f" Input: {case_input}") + + # Context section (if include_context is True) + if include_context: + system_msg = first_model_data.get("system_message") + addl_msgs = first_model_data.get("additional_messages") + if system_msg or addl_msgs: + lines.append("") + lines.append(" " + "-" * 40) + lines.append(" πŸ“‹ CONTEXT") + lines.append(" " + "-" * 40) + if system_msg: + lines.append(f" System Message: {system_msg}") + if addl_msgs: + lines.append(f" πŸ’¬ Conversation ({len(addl_msgs)} messages):") + for msg in addl_msgs: + role = msg.get("role", "unknown").upper() + content = msg.get("content", "") + name = msg.get("name", "") + role_label = f"[{role}]" if not name else f"[{role}: {name}]" + lines.append(f" {role_label}") + if content: + # For tool responses, try to format JSON + if role.lower() == "tool": + try: + import json + + parsed = json.loads(content) + formatted = json.dumps(parsed, indent=2) + for json_line in formatted.split("\n"): + lines.append(f" {json_line}") + except (json.JSONDecodeError, TypeError): + lines.append(f" {content}") + else: + lines.append(f" {content}") + # Handle tool calls in assistant messages + tool_calls = msg.get("tool_calls", []) + if tool_calls: + for tc in tool_calls: + func = tc.get("function", {}) + tc_name = func.get("name", "unknown") + tc_args = func.get("arguments", "{}") + lines.append(f" πŸ”§ {tc_name}") + try: + import json + + args_dict = ( + json.loads(tc_args) + if isinstance(tc_args, str) + else tc_args + ) + formatted = json.dumps(args_dict, indent=2) + for arg_line in formatted.split("\n"): + lines.append(f" {arg_line}") + except (json.JSONDecodeError, TypeError): + lines.append(f" {tc_args}") + lines.append(" " + "-" * 40) + + lines.append("") + + # Show each model's results for this case + for model in model_order: + if model not in model_data: + lines.append(f" [{model}] (no data)") + lines.append("") + continue + + model_case_data = model_data[model] + lines.append(f" [{model}]") + + # Show track comparison for this model + case_lines = self._format_comparative_case_text( + case_name, model_case_data, track_order, show_details + ) + # Indent the case lines + for line in case_lines: + 
lines.append(" " + line) + + lines.append("") + + # Summary + lines.append("=" * 78) + if failed_only and original_counts: + orig_total, orig_passed, orig_failed, orig_warned = original_counts + lines.append(f"Note: Showing only {total_cases} failed evaluation(s) (--only-failed)") + summary = f"Summary -- Total: {orig_total} -- Passed: {orig_passed}" + if orig_warned > 0: + summary += f" -- Warnings: {orig_warned}" + if orig_failed > 0: + summary += f" -- Failed: {orig_failed}" + else: + summary = f"Summary -- Total: {total_cases} -- Passed: {total_passed}" + if total_warned > 0: + summary += f" -- Warnings: {total_warned}" + if total_failed > 0: + summary += f" -- Failed: {total_failed}" + + lines.append(summary) + lines.append("") + + return "\n".join(lines) + + def _format_comparative_case_text( + self, + case_name: str, + case_data: ComparativeCaseData, + track_order: list[str], + show_details: bool, + ) -> list[str]: + """Format a single comparative case in text format.""" + lines: list[str] = [] + tracks = case_data.get("tracks", {}) + + lines.append("") + lines.append(" " + "─" * 68) + lines.append(f" CASE: {case_name}") + lines.append(" " + "─" * 68) + lines.append(f" Input: {case_data.get('input', 'N/A')}") + lines.append("") + + # Compute differences from baseline + differences = compute_track_differences(case_data, track_order) + + # Build comparison table header + lines.append(" β”Œβ”€ COMPARISON ─────────────────────────────────────────────────────┐") + lines.append( + " β”‚ {:20s} β”‚ {:8s} β”‚ {:8s} β”‚ {:24s} β”‚".format( + "Track", "Status", "Score", "Differences" + ) + ) + lines.append(" β”œ" + "─" * 22 + "β”Ό" + "─" * 10 + "β”Ό" + "─" * 10 + "β”Ό" + "─" * 26 + "─") + + for track_name in track_order: + if track_name not in tracks: + lines.append( + " β”‚ {:20s} β”‚ {:8s} β”‚ {:8s} β”‚ {:24s} β”‚".format( + track_name[:20], "N/A", "N/A", "No data" + ) + ) + continue + + track_result = tracks[track_name] + evaluation = track_result.get("evaluation") + + if not evaluation: + lines.append( + " β”‚ {:20s} β”‚ {:8s} β”‚ {:8s} β”‚ {:24s} β”‚".format( + track_name[:20], "N/A", "N/A", "No evaluation" + ) + ) + continue + + # Status + if evaluation.passed: + status = "PASSED" + elif evaluation.warning: + status = "WARNED" + else: + status = "FAILED" + + # Score + score_str = f"{evaluation.score * 100:.1f}%" + + # Differences from baseline + diff_fields = differences.get(track_name, []) + if track_name == track_order[0]: + diff_text = "(baseline)" + elif diff_fields: + diff_text = ", ".join(diff_fields)[:24] + else: + diff_text = "β€”" + + lines.append( + f" β”‚ {track_name[:20]:20s} β”‚ {status:8s} β”‚ {score_str:8s} β”‚ {diff_text[:24]:24s} β”‚" + ) + + lines.append(" β””" + "─" * 22 + "β”΄" + "─" * 10 + "β”΄" + "─" * 10 + "β”΄" + "─" * 26 + "β”˜") + lines.append("") + + # Detailed results per track + if show_details: + for track_name in track_order: + if track_name not in tracks: + continue + + track_result = tracks[track_name] + evaluation = track_result.get("evaluation") + + if not evaluation: + continue + + lines.append(f" [{track_name}] Details:") + for detail_line in self._format_evaluation(evaluation).split("\n"): + lines.append(f" {detail_line}") + lines.append("") + + return lines + + def _format_conversation_text(self, messages: list[dict]) -> list[str]: + """Format conversation messages as plain text for context display.""" + lines: list[str] = [] + + for msg in messages: + role = msg.get("role", "unknown").upper() + content = msg.get("content", "") + tool_calls = 
msg.get("tool_calls", []) + name = msg.get("name", "") + + role_label = f"[{role}]" if not name else f"[{role}: {name}]" + lines.append(f" {role_label}") + + if content: + # For tool responses, try to format JSON nicely + if role.lower() == "tool": + try: + parsed = json.loads(content) + formatted = json.dumps(parsed, indent=2) + for json_line in formatted.split("\n"): + lines.append(f" {json_line}") + except (json.JSONDecodeError, TypeError): + lines.append(f" {content}") + else: + lines.append(f" {content}") + + # Handle tool calls in assistant messages + if tool_calls: + for tc in tool_calls: + func = tc.get("function", {}) + tc_name = func.get("name", "unknown") + tc_args = func.get("arguments", "{}") + lines.append(f" πŸ”§ {tc_name}") + try: + args_dict = json.loads(tc_args) if isinstance(tc_args, str) else tc_args + formatted = json.dumps(args_dict, indent=2) + for arg_line in formatted.split("\n"): + lines.append(f" {arg_line}") + except (json.JSONDecodeError, TypeError): + lines.append(f" {tc_args}") + + return lines + + +class CaptureTextFormatter(CaptureFormatter): + """Plain text formatter for capture results.""" + + @property + def file_extension(self) -> str: + return "txt" + + def format( + self, + captures: CaptureResults, + include_context: bool = False, + ) -> str: + """Format capture results as plain text.""" + # Check for multi-model captures + if is_multi_model_capture(captures): + return self._format_multi_model(captures, include_context) + + return self._format_single_model(captures, include_context) + + def _format_single_model( + self, + captures: CaptureResults, + include_context: bool = False, + ) -> str: + """Format single-model capture results.""" + lines: list[str] = [] + lines.append("=" * 70) + lines.append("CAPTURE RESULTS") + lines.append("=" * 70) + lines.append("") + + total_cases = 0 + total_calls = 0 + + for capture in captures: + lines.append(f"Suite: {capture.suite_name}") + lines.append(f"Model: {capture.model}") + lines.append(f"Provider: {capture.provider}") + lines.append("-" * 70) + + for case in capture.captured_cases: + total_cases += 1 + lines.append("") + lines.append(f" Case: {case.case_name}") + # track_name is set for comparative cases + track_name = getattr(case, "track_name", None) + if track_name: + lines.append(f" Track: {track_name}") + lines.append(f" User Message: {case.user_message}") + + if include_context and case.system_message: + lines.append(f" System Message: {case.system_message}") + + lines.append("") + lines.append(" Tool Calls:") + if case.tool_calls: + for tc in case.tool_calls: + total_calls += 1 + lines.append(f" - {tc.name}") + if tc.args: + for key, value in tc.args.items(): + lines.append(f" {key}: {self._format_value(value)}") + else: + lines.append(" (no tool calls)") + + if include_context and case.additional_messages: + lines.append("") + lines.append( + f" Conversation Context ({len(case.additional_messages)} messages):" + ) + lines.extend(self._format_conversation_text(case.additional_messages)) + + lines.append("") + + lines.append("") + + lines.append("=" * 70) + lines.append(f"Summary: {total_calls} tool calls across {total_cases} cases") + lines.append("") + + return "\n".join(lines) + + def _format_multi_model( + self, + captures: CaptureResults, + include_context: bool = False, + ) -> str: + """Format multi-model capture results with track sections.""" + from arcade_cli.formatters.base import group_captures_by_case_then_track + + grouped_data, model_order, track_order = 
group_captures_by_case_then_track(captures) + has_tracks = len(track_order) > 1 or (track_order and track_order[0] is not None) + + lines: list[str] = [] + + lines.append("=" * 78) + lines.append("MULTI-MODEL CAPTURE RESULTS") + lines.append("=" * 78) + lines.append("") + lines.append(f"Models: {', '.join(model_order)}") + if has_tracks: + track_names = [t for t in track_order if t is not None] + lines.append(f"Tracks: {' | '.join(track_names)}") + lines.append("") + + for suite_name, cases in grouped_data.items(): + lines.append("-" * 78) + lines.append(f"SUITE: {suite_name}") + lines.append("-" * 78) + lines.append("") + + for case_name, case_data in cases.items(): + lines.append(" " + "=" * 72) + lines.append(f" CASE: {case_name}") + lines.append(" " + "=" * 72) + + user_msg = case_data.get("user_message", "") + if user_msg: + lines.append(f" User Message: {user_msg}") + lines.append("") + + tracks_data = case_data.get("tracks", {}) + track_keys = list(tracks_data.keys()) + has_multiple_tracks = len(track_keys) > 1 or ( + len(track_keys) == 1 and track_keys[0] != "_default" + ) + + if has_multiple_tracks: + # Show track sections + for track_key in track_keys: + track_display = track_key if track_key != "_default" else "Default" + lines.append(" " + "β”Œ" + "─" * 70 + "┐") + lines.append(f" β”‚ 🏷️ TRACK: {track_display:<57s} β”‚") + lines.append(" " + "β”œ" + "─" * 70 + "─") + + track_data = tracks_data[track_key] + models_dict = track_data.get("models", {}) + + for model in model_order: + if model not in models_dict: + lines.append(f" β”‚ [{model}] (no data)") + continue + + captured_case = models_dict[model] + lines.append(f" β”‚ [{model}]") + + if captured_case.tool_calls: + for tc in captured_case.tool_calls: + lines.append(f" β”‚ - {tc.name}") + if tc.args: + for key, value in tc.args.items(): + lines.append( + f" β”‚ {key}: {self._format_value(value)}" + ) + else: + lines.append(" β”‚ (no tool calls)") + lines.append(" β”‚") + + lines.append(" " + "β””" + "─" * 70 + "β”˜") + lines.append("") + else: + # No tracks - render models directly + track_key = track_keys[0] if track_keys else "_default" + track_data = tracks_data.get(track_key, {}) + models_dict = track_data.get("models", {}) + + lines.append(" Tool Calls by Model:") + lines.append(" " + "-" * 70) + + for model in model_order: + if model not in models_dict: + lines.append(f" [{model}] (no data)") + continue + + captured_case = models_dict[model] + lines.append(f" [{model}]") + + if captured_case.tool_calls: + for tc in captured_case.tool_calls: + lines.append(f" - {tc.name}") + if tc.args: + for key, value in tc.args.items(): + lines.append( + f" {key}: {self._format_value(value)}" + ) + else: + lines.append(" (no tool calls)") + lines.append("") + + # Context section + system_msg = case_data.get("system_message") + addl_msgs = case_data.get("additional_messages") + if include_context and (system_msg or addl_msgs): + lines.append(" πŸ“‹ Context:") + if system_msg: + lines.append(f" System: {system_msg}") + if addl_msgs: + lines.append(f" Conversation ({len(addl_msgs)} messages):") + lines.extend(self._format_conversation_text(addl_msgs)) + lines.append("") + + lines.append("") + + # Summary + total_models = len(model_order) + total_suites = len(grouped_data) + total_cases = sum(len(cases) for cases in grouped_data.values()) + track_info = f", {len([t for t in track_order if t])} track(s)" if has_tracks else "" + + lines.append("=" * 78) + lines.append( + f"Summary: {total_cases} cases across {total_suites} suite(s), " + 
f"{total_models} model(s){track_info}" + ) + lines.append("") + + return "\n".join(lines) + + def _format_conversation_text(self, messages: list[dict]) -> list[str]: + """Format conversation messages as plain text.""" + lines: list[str] = [] + + for i, msg in enumerate(messages): + role = msg.get("role", "unknown") + content = msg.get("content", "") + tool_calls = msg.get("tool_calls", []) + name = msg.get("name", "") + + # Role indicators + role_prefix = { + "user": " [USER]", + "assistant": " [ASSISTANT]", + "tool": " [TOOL]", + "system": " [SYSTEM]", + }.get(role, f" [{role.upper()}]") + + # Add separator between messages + if i > 0: + lines.append(" " + "-" * 50) + + # Header + if role == "tool" and name: + lines.append(f"{role_prefix} ({name})") + else: + lines.append(role_prefix) + + # Content + if content: + # Indent content lines + for line in content.split("\n"): + if line.strip(): + lines.append(f" {line}") + elif role == "assistant" and not content and tool_calls: + lines.append(" (calling tools...)") + + # Tool calls for assistant messages + if tool_calls: + for tc in tool_calls: + func = tc.get("function", {}) + tc_name = func.get("name", "unknown") + tc_args = func.get("arguments", "{}") + + lines.append(f" -> {tc_name}") + + # Parse and format arguments + try: + args_dict = json.loads(tc_args) if isinstance(tc_args, str) else tc_args + args_formatted = json.dumps(args_dict, indent=2) + for arg_line in args_formatted.split("\n"): + lines.append(f" {arg_line}") + except (json.JSONDecodeError, TypeError): + lines.append(f" {tc_args}") + + return lines + + def _format_value(self, value: Any) -> str: + """Format a value for display, truncating if too long.""" + str_value = str(value) + if len(str_value) > 60: + return str_value[:57] + "..." + return str_value diff --git a/libs/arcade-cli/arcade_cli/main.py b/libs/arcade-cli/arcade_cli/main.py index d9ad89135..78b52b3a7 100644 --- a/libs/arcade-cli/arcade_cli/main.py +++ b/libs/arcade-cli/arcade_cli/main.py @@ -11,8 +11,6 @@ from arcade_core.constants import CREDENTIALS_FILE_PATH, PROD_COORDINATOR_HOST, PROD_ENGINE_HOST from arcadepy import Arcade from rich.console import Console -from rich.text import Text -from tqdm import tqdm from arcade_cli.authn import ( OAuthLoginError, @@ -22,9 +20,7 @@ perform_oauth_login, save_credentials_from_whoami, ) -from arcade_cli.display import ( - display_eval_results, -) +from arcade_cli.evals_runner import run_capture, run_evaluations from arcade_cli.org import app as org_app from arcade_cli.project import app as project_app from arcade_cli.secret import app as secret_app @@ -32,14 +28,19 @@ from arcade_cli.show import show_logic from arcade_cli.usage.command_tracker import TrackedTyper, TrackedTyperGroup from arcade_cli.utils import ( + ModelSpec, Provider, compute_base_url, + expand_provider_configs, + get_default_model, get_eval_files, handle_cli_error, load_eval_suites, log_engine_health, + parse_output_paths, + parse_provider_spec, require_dependency, - resolve_provider_api_key, + resolve_provider_api_keys, version_callback, ) @@ -404,23 +405,54 @@ def evals( "-c", help="Maximum number of concurrent evaluations (default: 1)", ), - models: str = typer.Option( - "gpt-4o", - "--models", - "-m", - help="The models to use for evaluation (default: gpt-4o). Use commas to separate multiple models. 
All models must belong to the same provider.", - ), - provider: Provider = typer.Option( - Provider.OPENAI, - "--provider", + use_provider: Optional[str] = typer.Option( + None, + "--use-provider", "-p", - help="The provider of the models to use for evaluation.", + help="Provider(s) and models to use. Format: 'provider' or 'provider:model1,model2'. " + "Multiple providers: separate with spaces. " + "Examples: 'openai' or 'openai:gpt-4o anthropic:claude-sonnet-4-5-20250929'", ), - provider_api_key: str = typer.Option( + api_key: Optional[list[str]] = typer.Option( None, - "--provider-api-key", + "--api-key", "-k", - help="The model provider API key. If not provided, will look for the appropriate environment variable based on the provider (e.g., OPENAI_API_KEY for openai provider), first in the current environment, then in the current working directory's .env file.", + help="API key(s) for provider(s). Format: 'provider:key'. " + "Can be repeated. Examples: --api-key openai:sk-... --api-key anthropic:sk-ant-...", + ), + only_failed: bool = typer.Option( + False, + "--only-failed", + "-f", + help="Show only failed evaluations", + ), + output: Optional[list[str]] = typer.Option( + None, + "--output", + "-o", + help="Output file(s) with auto-detected format from extension. " + "Examples: -o results.json, -o results.md -o results.html, -o results (all formats). " + "Can be repeated for multiple formats.", + ), + capture: bool = typer.Option( + False, + "--capture", + help="Run in capture mode - record tool calls without evaluation scoring", + ), + include_context: bool = typer.Option( + False, + "--include-context", + help="Include system_message and additional_messages in output (works for both eval and capture modes)", + ), + host: Optional[str] = typer.Option( + None, + "--host", + help="Arcade API host for gateway connections (e.g., 'api.bosslevel.dev')", + ), + port: Optional[int] = typer.Option( + None, + "--port", + help="Arcade API port for gateway connections (default: 443 for HTTPS)", ), debug: bool = typer.Option(False, "--debug", help="Show debug information"), ) -> None: @@ -444,27 +476,87 @@ def evals( pip_install_command=r"pip install arcade-tdk", ) - models_list = models.split(",") # Use 'models_list' to avoid shadowing + # --- Build model specs from flags --- + model_specs: list[ModelSpec] = [] + + # Resolve API keys from --api-key flags and environment + api_keys = resolve_provider_api_keys(api_keys_specs=api_key) + + if use_provider: + # Parse provider specs - supports space-separated values + # e.g., "openai:gpt-4o anthropic:claude" + provider_specs = use_provider.split() + try: + provider_configs = [parse_provider_spec(spec) for spec in provider_specs] + except ValueError as e: + handle_cli_error(str(e), should_exit=True) + return # For type checker + + # Expand to model specs + try: + model_specs = expand_provider_configs(provider_configs, api_keys) + except ValueError as e: + handle_cli_error(str(e), should_exit=True) + return # For type checker + else: + # Default: OpenAI with default model + if not api_keys.get(Provider.OPENAI): + handle_cli_error( + "API key not found for provider 'openai'. 
" + "Please provide it via --api-key openai:KEY, set the OPENAI_API_KEY environment variable, " + "or add it to a .env file in the current directory.\n\n" + "Tip: Use --use-provider to specify a different provider (e.g., --use-provider anthropic)", + should_exit=True, + ) + return # For type checker - # Resolve the API key for the provider - resolved_api_key = resolve_provider_api_key(provider, provider_api_key) - if not resolved_api_key: - provider_env_vars = { - Provider.OPENAI: "OPENAI_API_KEY", - } - env_var_name = provider_env_vars.get(provider, f"{provider.upper()}_API_KEY") - handle_cli_error( - f"API key not found for provider '{provider.value}'. " - f"Please provide it via --provider-api-key,-k argument, set the {env_var_name} environment variable, " - f"or add it to a .env file in the current directory.", - should_exit=True, - ) + model_specs = [ + ModelSpec( + provider=Provider.OPENAI, + model=get_default_model(Provider.OPENAI), + api_key=api_keys[Provider.OPENAI], # type: ignore[arg-type] + ) + ] + + if not model_specs: + handle_cli_error("No models specified. Use --use-provider to specify models.") + return eval_files = get_eval_files(directory) if not eval_files: return - console.print("\nRunning evaluations", style="bold") + # Warn about incompatible flag combinations + if capture: + console.print("\nRunning in capture mode", style="bold cyan") + if only_failed: + console.print("[yellow]⚠️ --only-failed is ignored in capture mode[/yellow]") + if show_details: + console.print("[yellow]⚠️ --details is ignored in capture mode[/yellow]") + else: + console.print("\nRunning evaluations", style="bold") + + # Show which models will be used + unique_providers = {spec.provider.value for spec in model_specs} + if len(unique_providers) > 1: + console.print( + f"[bold cyan]Using {len(model_specs)} model(s) across {len(unique_providers)} providers[/bold cyan]" + ) + for spec in model_specs: + console.print(f" β€’ {spec.display_name}", style="dim") + + # Set arcade URL override BEFORE loading suites (so MCP connections use it) + if host or port: + # Build URL from --host and --port + if not host: + handle_cli_error("--port requires --host to be specified", should_exit=True) + return + + # Default to HTTPS on port 443 + scheme = "https" + port_str = f":{port}" if port and port != 443 else "" + constructed_url = f"{scheme}://{host}{port_str}" + os.environ["ARCADE_API_BASE_URL"] = constructed_url # Use the new function to load eval suites eval_suites = load_eval_suites(eval_files) @@ -480,39 +572,44 @@ def evals( style="bold", ) - async def run_evaluations() -> None: - all_evaluations = [] - tasks = [] - for suite_func in eval_suites: - console.print( - Text.assemble( - ("Running evaluations in ", "bold"), - (suite_func.__name__, "bold blue"), - ) - ) - for model in models_list: - task = asyncio.create_task( - suite_func( - provider_api_key=resolved_api_key, - model=model, - max_concurrency=max_concurrent, - ) - ) - tasks.append(task) - - # Track progress and results as suite functions complete - with tqdm(total=len(tasks), desc="Evaluations Progress") as pbar: - results = [] - for f in asyncio.as_completed(tasks): - results.append(await f) - pbar.update(1) + # Parse output paths with smart format detection + final_output_file: str | None = None + final_output_formats: list[str] = [] - # TODO error handling on each eval - all_evaluations.extend(results) - display_eval_results(all_evaluations, show_details=show_details) + if output: + try: + final_output_file, final_output_formats = 
parse_output_paths(output) + except ValueError as e: + handle_cli_error(str(e), should_exit=True) + return try: - asyncio.run(run_evaluations()) + if capture: + asyncio.run( + run_capture( + eval_suites=eval_suites, + model_specs=model_specs, + max_concurrent=max_concurrent, + include_context=include_context, + output_file=final_output_file, + output_format=",".join(final_output_formats) if final_output_formats else "txt", + console=console, + ) + ) + else: + asyncio.run( + run_evaluations( + eval_suites=eval_suites, + model_specs=model_specs, + max_concurrent=max_concurrent, + show_details=show_details, + output_file=final_output_file, + output_format=",".join(final_output_formats) if final_output_formats else "txt", + failed_only=only_failed, + include_context=include_context, + console=console, + ) + ) except Exception as e: handle_cli_error("Failed to run evaluations", e, debug) diff --git a/libs/arcade-cli/arcade_cli/utils.py b/libs/arcade-cli/arcade_cli/utils.py index 256be8e5e..c58180699 100644 --- a/libs/arcade-cli/arcade_cli/utils.py +++ b/libs/arcade-cli/arcade_cli/utils.py @@ -4,14 +4,13 @@ import shlex import sys import traceback -import webbrowser from dataclasses import dataclass from datetime import datetime from enum import Enum from importlib import metadata from pathlib import Path from textwrap import dedent -from typing import Any, Callable, Union, cast +from typing import Any, Callable, cast from urllib.parse import urlparse import idna @@ -35,18 +34,9 @@ Arcade, ) from arcadepy.types import AuthorizationResponse -from openai import OpenAI, Stream -from openai.types.chat.chat_completion import Choice as ChatCompletionChoice -from openai.types.chat.chat_completion_chunk import ChatCompletionChunk -from openai.types.chat.chat_completion_chunk import ( - Choice as ChatCompletionChunkChoice, -) from pydantic import ValidationError from rich.console import Console -from rich.live import Live -from rich.markdown import Markdown from rich.markup import escape -from rich.text import Text from typer.core import TyperGroup from typer.models import Context @@ -77,6 +67,302 @@ class Provider(str, Enum): """Supported model providers for evaluations.""" OPENAI = "openai" + ANTHROPIC = "anthropic" + + +# ============================================================================ +# Default Models Configuration +# ============================================================================ +# Edit these values to change the default models used by the CLI. +# These are used when --models is not specified. +# +# Note: Anthropic models include date suffixes (e.g., -20250929) which may need +# periodic updates. Check https://docs.anthropic.com/en/docs/about-claude/models +# for the latest model identifiers. + +DEFAULT_MODELS: dict[Provider, str] = { + Provider.OPENAI: "gpt-4o", + Provider.ANTHROPIC: "claude-sonnet-4-5-20250929", +} + + +def get_default_model(provider: Provider) -> str: + """Get the default model for a provider. + + Args: + provider: The provider to get the default model for. + + Returns: + The default model name for the provider. 
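+
+    Example (values come from DEFAULT_MODELS above):
+        >>> get_default_model(Provider.OPENAI)
+        'gpt-4o'
+        >>> get_default_model(Provider.ANTHROPIC)
+        'claude-sonnet-4-5-20250929'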
+ """ + return DEFAULT_MODELS.get(provider, "gpt-4o") + + +# ============================================================================ +# Output Format Detection +# ============================================================================ + +ALL_OUTPUT_FORMATS = ["txt", "md", "html", "json"] + + +def parse_output_paths(output_paths: list[str] | None) -> tuple[str | None, list[str]]: + """Parse --output/-o paths into base path and format list. + + Supports: + - Single file with extension: "results.json" β†’ ("results", ["json"]) + - Multiple files: ["results.md", "results.html"] β†’ ("results", ["md", "html"]) + - No extension: "results" β†’ ("results", ["txt", "md", "html", "json"]) + + Args: + output_paths: List of output paths from --output/-o flag. + + Returns: + Tuple of (base_path, formats). Returns (None, []) if no paths. + + Raises: + ValueError: If paths have inconsistent base names or invalid extensions. + """ + if not output_paths: + return None, [] + + # Extract base path and formats + base_path = None + formats: list[str] = [] + + for path_str in output_paths: + path = Path(path_str) + stem = path.stem + ext = path.suffix.lstrip(".") + + # Determine base path (all paths should have same base) + if base_path is None: + base_path = str(Path(path.parent) / stem) + elif str(Path(path.parent) / stem) != base_path: + raise ValueError( + f"Output paths have different base names: '{base_path}' vs '{Path(path.parent) / stem}'. " + "All outputs must use the same base path." + ) + + # No extension means all formats + if not ext: + formats = ALL_OUTPUT_FORMATS.copy() + break + + # Validate extension + if ext not in ALL_OUTPUT_FORMATS: + valid = ", ".join(ALL_OUTPUT_FORMATS) + raise ValueError(f"Invalid output format '.{ext}'. Valid extensions: {valid}") + + if ext not in formats: + formats.append(ext) + + return base_path, formats + + +def parse_api_key_spec(spec: str) -> tuple[Provider, str]: + """Parse --api-key value into (provider, key). + + Args: + spec: API key spec string. Format: "provider:key" + Examples: "openai:sk-...", "anthropic:sk-ant-..." + + Returns: + Tuple of (Provider, api_key_string). + + Raises: + ValueError: If format is invalid or provider is unknown. + """ + if ":" not in spec: + raise ValueError( + f"Invalid --api-key format: '{spec}'. " + "Expected format: 'provider:key' (e.g., 'openai:sk-...')" + ) + + provider_str, key = spec.split(":", 1) + provider_str = provider_str.strip().lower() + key = key.strip() + + if not key: + raise ValueError(f"Empty API key for provider '{provider_str}'") + + try: + provider = Provider(provider_str) + except ValueError: + valid_providers = [p.value for p in Provider] + raise ValueError( + f"Invalid provider '{provider_str}' in --api-key. " + f"Valid providers: {', '.join(valid_providers)}" + ) + + return provider, key + + +# ============================================================================ +# Multi-Provider Model Specification +# ============================================================================ + + +@dataclass +class ProviderConfig: + """Configuration for a single provider from CLI input. 
+ + Parsed from --use-provider flag values like: + - "openai" -> provider=OPENAI, models=[] (use default) + - "openai:gpt-4o,gpt-4o-mini" -> provider=OPENAI, models=["gpt-4o", "gpt-4o-mini"] + """ + + provider: Provider + models: list[str] # Empty list means use default model + + def get_models(self) -> list[str]: + """Get models, using default if none specified.""" + if self.models: + return self.models + return [get_default_model(self.provider)] + + +@dataclass +class ModelSpec: + """A specific model to run evaluations against. + + This is the expanded form used by the runner - one ModelSpec per + (provider, model, api_key) combination. + """ + + provider: Provider + model: str + api_key: str + + @property + def display_name(self) -> str: + """Get display name in format 'provider/model'.""" + return f"{self.provider.value}/{self.model}" + + +def parse_provider_spec(spec: str) -> ProviderConfig: + """Parse a --use-provider value into a ProviderConfig. + + Args: + spec: Provider spec string. Examples: + - "openai" -> use OpenAI with default model + - "openai:gpt-4o" -> use OpenAI with gpt-4o + - "anthropic:claude-sonnet-4-5-20250929,claude-3-haiku-20240307" + + Returns: + ProviderConfig with parsed provider and models. + + Raises: + ValueError: If provider name is invalid. + + Examples: + >>> parse_provider_spec("openai") + ProviderConfig(provider=Provider.OPENAI, models=[]) + >>> parse_provider_spec("openai:gpt-4o,gpt-4o-mini") + ProviderConfig(provider=Provider.OPENAI, models=['gpt-4o', 'gpt-4o-mini']) + """ + if ":" in spec: + provider_str, models_str = spec.split(":", 1) + models = [m.strip() for m in models_str.split(",") if m.strip()] + else: + provider_str = spec.strip() + models = [] + + # Validate provider + provider_str_lower = provider_str.lower() + try: + provider = Provider(provider_str_lower) + except ValueError: + valid_providers = [p.value for p in Provider] + raise ValueError( + f"Invalid provider '{provider_str}'. Valid providers: {', '.join(valid_providers)}" + ) + + return ProviderConfig(provider=provider, models=models) + + +def expand_provider_configs( + configs: list[ProviderConfig], + api_keys: dict[Provider, str | None], +) -> list[ModelSpec]: + """Expand provider configs into individual ModelSpecs with resolved API keys. + + Args: + configs: List of ProviderConfig from parsed --use-provider flags. + api_keys: Dict mapping Provider to API key (from flags or env vars). + + Returns: + List of ModelSpec, one per (provider, model) combination. + + Raises: + ValueError: If API key is missing for any provider. + """ + model_specs: list[ModelSpec] = [] + + for config in configs: + api_key = api_keys.get(config.provider) + if not api_key: + env_var = f"{config.provider.value.upper()}_API_KEY" + raise ValueError( + f"API key required for provider '{config.provider.value}'. " + f"Provide via --{config.provider.value}-key or set {env_var} environment variable." + ) + + for model in config.get_models(): + model_specs.append(ModelSpec(provider=config.provider, model=model, api_key=api_key)) + + return model_specs + + +def resolve_provider_api_keys( + api_keys_specs: list[str] | None = None, +) -> dict[Provider, str | None]: + """Resolve API keys for all providers from flags and environment. + + Priority: --api-key flag > environment variable > .env file + + Args: + api_keys_specs: List of provider:key specs from --api-key flags. + + Returns: + Dict mapping Provider to resolved API key (or None if not found). 
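+
+    Example (illustrative; assumes neither OPENAI_API_KEY nor ANTHROPIC_API_KEY
+    is set in the environment or in a .env file):
+
+        keys = resolve_provider_api_keys(["openai:sk-test"])
+        # keys[Provider.OPENAI] == "sk-test"
+        # keys[Provider.ANTHROPIC] is None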
+ """ + from dotenv import dotenv_values + + # Load .env file + env_values = dotenv_values(".env") + + # Start with empty dict + keys: dict[Provider, str | None] = { + Provider.OPENAI: None, + Provider.ANTHROPIC: None, + } + + # Parse --api-key provider:key specs (highest priority) + if api_keys_specs: + for spec in api_keys_specs: + try: + provider, key = parse_api_key_spec(spec) + keys[provider] = key + except ValueError as e: + # Re-raise to let CLI handle error + raise ValueError(str(e)) from e + + # Fallback to environment variables and .env file + def resolve_key_from_env(env_var: str) -> str | None: + # Check current environment + key = os.environ.get(env_var) + if key: + return key + # Check .env file + return env_values.get(env_var) + + # Set from environment if not already set by --api-key + if keys[Provider.OPENAI] is None: + keys[Provider.OPENAI] = resolve_key_from_env("OPENAI_API_KEY") + if keys[Provider.ANTHROPIC] is None: + keys[Provider.ANTHROPIC] = resolve_key_from_env("ANTHROPIC_API_KEY") + + return keys class CLIError(Exception): @@ -319,77 +605,6 @@ def get_tools_from_engine( return tools -def get_tool_messages(choice: dict) -> list[dict]: - if hasattr(choice, "tool_messages") and choice.tool_messages: - return choice.tool_messages # type: ignore[no-any-return] - return [] - - -@dataclass -class StreamingResult: - role: str - full_message: str - tool_messages: list - tool_authorization: dict | None - - -def handle_streaming_content(stream: Stream[ChatCompletionChunk], model: str) -> StreamingResult: - """ - Display the streamed markdown chunks as a single line. - """ - from rich.live import Live - - full_message = "" - tool_messages = [] - tool_authorization = None - role = "" - printed_role: bool = False - - with Live(console=console, refresh_per_second=10) as live: - for chunk in stream: - choice = chunk.choices[0] - role = choice.delta.role or role - - # Display and get tool messages if they exist - tool_messages += get_tool_messages(choice) # type: ignore[arg-type] - tool_authorization = get_tool_authorization(choice) - - chunk_message = choice.delta.content - - if role == "assistant" and tool_authorization: - continue # Skip the message if it's an auth request (handled later in handle_tool_authorization) - - if role == "assistant" and not printed_role: - console.print(f"\n[blue][bold]Assistant[/bold] ({model}):[/blue] ") - printed_role = True - - if chunk_message: - full_message += chunk_message - markdown_chunk = Markdown(full_message) - live.update(markdown_chunk) - - # Markdownify URLs in the final message if applicable - if role == "assistant": - full_message = markdownify_urls(full_message) - live.update(Markdown(full_message)) - - return StreamingResult(role, full_message, tool_messages, tool_authorization) - - -def markdownify_urls(message: str) -> str: - """ - Convert URLs in the message to markdown links. 
- """ - import re - - # This regex will match URLs that are not already formatted as markdown links: - # [Link text](https://example.com) - url_pattern = r"(?)", message) - - def validate_and_get_config( validate_api: bool = True, validate_user: bool = True, @@ -555,106 +770,6 @@ def log_engine_health(client: Arcade) -> None: ) -@dataclass -class ChatInteractionResult: - history: list[dict] - tool_messages: list[dict] - tool_authorization: dict | None - - -def handle_chat_interaction( - client: OpenAI, - model: str, - history: list[dict], - user_email: str | None, - stream: bool = False, -) -> ChatInteractionResult: - """ - Handle a single chat-request/chat-response interaction for both streamed and non-streamed responses. - Handling the chat response includes: - - Streaming the response if the stream flag is set - - Displaying the response in the console - - Getting the tool messages and tool authorization from the response - - Updating the history with the response, tool calls, and tool responses - """ - if stream: - # TODO Fix this in the client so users don't deal with these - # typing issues - response = client.chat.completions.create( # type: ignore[call-overload] - model=model, - messages=history, - tool_choice="generate", - user=user_email, - stream=True, - ) - streaming_result = handle_streaming_content(response, model) - role, message_content = streaming_result.role, streaming_result.full_message - tool_messages, tool_authorization = ( - streaming_result.tool_messages, - streaming_result.tool_authorization, - ) - else: - response = client.chat.completions.create( # type: ignore[call-overload] - model=model, - messages=history, - tool_choice="generate", - user=user_email, - stream=False, - ) - message_content = response.choices[0].message.content or "" - - # Get extra fields from the response - tool_messages = get_tool_messages(response.choices[0]) - tool_authorization = get_tool_authorization(response.choices[0]) - - role = response.choices[0].message.role - - if role == "assistant" and tool_authorization: - pass # Skip the message if it's an auth request (handled later in handle_tool_authorization) - elif role == "assistant": - message_content = markdownify_urls(message_content) - console.print( - f"\n[blue][bold]Assistant[/bold] ({model}):[/blue] ", - Markdown(message_content), - ) - else: - console.print(f"\n[bold]{role}:[/bold] {message_content}") - - history += tool_messages - history.append({"role": role, "content": message_content}) - - return ChatInteractionResult(history, tool_messages, tool_authorization) - - -def handle_tool_authorization( - arcade_client: Arcade, - tool_authorization: AuthorizationResponse, - history: list[dict[str, Any]], - openai_client: OpenAI, - model: str, - user_email: str | None, - stream: bool, -) -> ChatInteractionResult: - with Live(console=console, refresh_per_second=4) as live: - if tool_authorization.url: - authorization_url = str(tool_authorization.url) - webbrowser.open(authorization_url) - message = ( - "You'll need to authorize this action in your browser.\n\n" - f"If a browser doesn't open automatically, click [this link]({authorization_url}) " - f"or copy this URL and paste it into your browser:\n\n{authorization_url}" - ) - live.update(Markdown(message, style="dim")) - - wait_for_authorization_completion(arcade_client, tool_authorization) - - message = "Thanks for authorizing the action! Sending your request..." 
- live.update(Text(message, style="dim")) - - history.pop() - return handle_chat_interaction(openai_client, model, history, user_email, stream) - - def wait_for_authorization_completion( client: Arcade, tool_authorization: AuthorizationResponse | None ) -> None: @@ -677,28 +792,6 @@ def wait_for_authorization_completion( continue -def get_tool_authorization( - choice: Union[ChatCompletionChoice, ChatCompletionChunkChoice], -) -> dict | None: - """ - Get the tool authorization from a chat response's choice. - """ - if hasattr(choice, "tool_authorizations") and choice.tool_authorizations: - return choice.tool_authorizations[0] # type: ignore[no-any-return] - return None - - -def is_authorization_pending(tool_authorization: dict | None) -> bool: - """ - Check if the authorization for a tool call is pending. - Expects a chat response's choice.tool_authorizations as input. - """ - is_auth_pending = ( - tool_authorization is not None and tool_authorization.get("status", "") == "pending" - ) - return is_auth_pending - - def get_eval_files(directory: str) -> list[Path]: """ Get a list of evaluation files starting with 'eval_' and ending with '.py' in the given directory. @@ -1020,6 +1113,7 @@ def resolve_provider_api_key(provider: Provider, provider_api_key: str | None = # Map providers to their environment variable names provider_env_vars = { Provider.OPENAI: "OPENAI_API_KEY", + Provider.ANTHROPIC: "ANTHROPIC_API_KEY", } env_var_name = provider_env_vars.get(provider) @@ -1042,6 +1136,65 @@ def resolve_provider_api_key(provider: Provider, provider_api_key: str | None = return None +def filter_failed_evaluations( + all_evaluations: list[list[dict[str, Any]]], +) -> tuple[list[list[dict[str, Any]]], tuple[int, int, int, int]]: + """ + Filter evaluation results to show only failed cases and calculate original counts. 
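+
+    A case is kept only if it neither passed nor warned. For example
+    (hypothetical counts), a run of 10 cases with 7 passed, 1 warned, and
+    2 failed returns just the 2 failed cases plus original_counts of
+    (10, 7, 2, 1).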
+ + Args: + all_evaluations: List of evaluation results with structure: + [[{model: str, rubric: str, cases: [{name, input, evaluation}]}]] + + Returns: + Tuple of (filtered_evaluations, original_counts) where original_counts is + (total_cases, total_passed, total_failed, total_warned) + """ + original_total_cases = 0 + original_total_passed = 0 + original_total_failed = 0 + original_total_warned = 0 + + # Calculate original counts before filtering + for eval_suite in all_evaluations: + for model_results in eval_suite: + for case in model_results.get("cases", []): + evaluation = case["evaluation"] + original_total_cases += 1 + if evaluation.passed: + original_total_passed += 1 + elif evaluation.warning: + original_total_warned += 1 + else: + original_total_failed += 1 + + # Filter to show only failed evaluations + filtered_evaluations = [] + for eval_suite in all_evaluations: + filtered_suite = [] + for model_results in eval_suite: + filtered_cases = [ + case + for case in model_results.get("cases", []) + if not case["evaluation"].passed and not case["evaluation"].warning + ] + if filtered_cases: # Only include model results with failed cases + filtered_model_results = model_results.copy() + filtered_model_results["cases"] = filtered_cases + filtered_suite.append(filtered_model_results) + if filtered_suite: + filtered_evaluations.append(filtered_suite) + + original_counts = ( + original_total_cases, + original_total_passed, + original_total_failed, + original_total_warned, + ) + + return filtered_evaluations, original_counts + + def require_dependency( package_name: str, command_name: str, diff --git a/libs/arcade-core/arcade_core/converters/__init__.py b/libs/arcade-core/arcade_core/converters/__init__.py new file mode 100644 index 000000000..0204f5af2 --- /dev/null +++ b/libs/arcade-core/arcade_core/converters/__init__.py @@ -0,0 +1,34 @@ +"""Converters for transforming tool definitions between formats.""" + +from .anthropic import ( + AnthropicInputSchema, + AnthropicInputSchemaProperty, + AnthropicToolList, + AnthropicToolSchema, + to_anthropic, +) +from .openai import ( + OpenAIFunctionParameterProperty, + OpenAIFunctionParameters, + OpenAIFunctionSchema, + OpenAIToolList, + OpenAIToolSchema, + to_openai, +) +from .utils import denormalize_tool_name, normalize_tool_name + +__all__ = [ + "AnthropicInputSchema", + "AnthropicInputSchemaProperty", + "AnthropicToolList", + "AnthropicToolSchema", + "OpenAIFunctionParameterProperty", + "OpenAIFunctionParameters", + "OpenAIFunctionSchema", + "OpenAIToolList", + "OpenAIToolSchema", + "denormalize_tool_name", + "normalize_tool_name", + "to_anthropic", + "to_openai", +] diff --git a/libs/arcade-core/arcade_core/converters/anthropic.py b/libs/arcade-core/arcade_core/converters/anthropic.py new file mode 100644 index 000000000..1c87378f2 --- /dev/null +++ b/libs/arcade-core/arcade_core/converters/anthropic.py @@ -0,0 +1,194 @@ +"""Converter for converting Arcade ToolDefinition to Anthropic tool schema.""" + +from typing import Any, TypedDict + +from arcade_core.catalog import MaterializedTool +from arcade_core.converters.utils import normalize_tool_name +from arcade_core.schema import InputParameter, ValueSchema + +# ---------------------------------------------------------------------------- +# Type definitions for JSON tool schemas used by Anthropic APIs. +# Defines the proper types for tool schemas to ensure +# compatibility with Anthropic's Messages API tool use feature. 
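+# A converted tool is shaped roughly like (tool name is illustrative):
+#   {"name": "Math_Add", "description": "...",
+#    "input_schema": {"type": "object", "properties": {...}, "required": [...]}}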
+# Reference: https://docs.anthropic.com/en/docs/build-with-claude/tool-use +# ---------------------------------------------------------------------------- + + +class AnthropicInputSchemaProperty(TypedDict, total=False): + """Type definition for a property within Anthropic input schema.""" + + type: str + """The JSON Schema type for this property.""" + + description: str + """Description of the property.""" + + enum: list[Any] + """Allowed values for enum properties.""" + + items: dict[str, Any] + """Schema for array items when type is 'array'.""" + + properties: dict[str, "AnthropicInputSchemaProperty"] + """Nested properties when type is 'object'.""" + + required: list[str] + """Required fields for nested objects.""" + + +class AnthropicInputSchema(TypedDict, total=False): + """Type definition for Anthropic tool input schema.""" + + type: str + """Must be 'object' for tool input schemas.""" + + properties: dict[str, AnthropicInputSchemaProperty] + """The properties of the tool input parameters.""" + + required: list[str] + """List of required parameter names.""" + + +class AnthropicToolSchema(TypedDict, total=False): + """ + Schema for a tool definition passed to Anthropic's `tools` parameter. + + Unlike OpenAI, Anthropic uses a flat structure without a wrapper object. + The schema uses `input_schema` instead of `parameters`. + """ + + name: str + """The name of the tool.""" + + description: str + """Description of what the tool does.""" + + input_schema: AnthropicInputSchema + """JSON Schema describing the tool's input parameters.""" + + +# Type alias for a list of Anthropic tool schemas +AnthropicToolList = list[AnthropicToolSchema] + + +# ---------------------------------------------------------------------------- +# Converters +# ---------------------------------------------------------------------------- +def to_anthropic(tool: MaterializedTool) -> AnthropicToolSchema: + """Convert a MaterializedTool to Anthropic tool schema format. + + Args: + tool: The MaterializedTool to convert + + Returns: + The Anthropic tool schema format (what is passed to the Anthropic API) + """ + name = normalize_tool_name(tool.definition.fully_qualified_name) + description = tool.description + input_schema = _convert_input_parameters_to_json_schema(tool.definition.input.parameters) + + return _create_tool_schema(name, description, input_schema) + + +def _create_tool_schema( + name: str, description: str, input_schema: AnthropicInputSchema +) -> AnthropicToolSchema: + """Create a properly typed Anthropic tool schema. 
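+
+    For illustration (hypothetical tool), the returned mapping has the shape:
+
+        {
+            "name": "Math_Add",
+            "description": "Add two numbers together.",
+            "input_schema": {"type": "object", "properties": {...}, "required": [...]},
+        }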
+ + Args: + name: The name of the tool + description: Description of what the tool does + input_schema: JSON schema for the tool input parameters + + Returns: + A properly typed AnthropicToolSchema + """ + tool: AnthropicToolSchema = { + "name": name, + "description": description, + "input_schema": input_schema, + } + + return tool + + +def _convert_value_schema_to_json_schema( + value_schema: ValueSchema, +) -> AnthropicInputSchemaProperty: + """Convert Arcade ValueSchema to JSON Schema format for Anthropic.""" + type_mapping = { + "string": "string", + "integer": "integer", + "number": "number", + "boolean": "boolean", + "json": "object", + "array": "array", + } + + schema: AnthropicInputSchemaProperty = {"type": type_mapping[value_schema.val_type]} + + if value_schema.val_type == "array" and value_schema.inner_val_type: + items_schema: dict[str, Any] = {"type": type_mapping[value_schema.inner_val_type]} + + # For arrays, enum should be applied to the items, not the array itself + if value_schema.enum: + items_schema["enum"] = value_schema.enum + + schema["items"] = items_schema + else: + # Handle enum for non-array types + if value_schema.enum: + schema["enum"] = value_schema.enum + + # Handle object properties + if value_schema.val_type == "json" and value_schema.properties: + schema["properties"] = { + name: _convert_value_schema_to_json_schema(nested_schema) + for name, nested_schema in value_schema.properties.items() + } + + return schema + + +def _convert_input_parameters_to_json_schema( + parameters: list[InputParameter], +) -> AnthropicInputSchema: + """Convert list of InputParameter to JSON schema parameters object. + + Unlike OpenAI's strict mode, Anthropic uses standard JSON Schema: + - Only actually required parameters are listed in 'required' + - No need to add 'null' to optional parameter types + - No 'additionalProperties: false' requirement + """ + if not parameters: + # Minimal JSON schema for a tool with no input parameters + return { + "type": "object", + "properties": {}, + } + + properties: dict[str, AnthropicInputSchemaProperty] = {} + required: list[str] = [] + + for parameter in parameters: + param_schema = _convert_value_schema_to_json_schema(parameter.value_schema) + + if parameter.description: + param_schema["description"] = parameter.description + + properties[parameter.name] = param_schema + + # Only add actually required parameters to the required list + if parameter.required: + required.append(parameter.name) + + json_schema: AnthropicInputSchema = { + "type": "object", + "properties": properties, + } + + # Only include 'required' if there are required parameters + if required: + json_schema["required"] = required + + return json_schema diff --git a/libs/arcade-core/arcade_core/converters/openai.py b/libs/arcade-core/arcade_core/converters/openai.py index c4fb317cf..38e3860f6 100644 --- a/libs/arcade-core/arcade_core/converters/openai.py +++ b/libs/arcade-core/arcade_core/converters/openai.py @@ -3,6 +3,7 @@ from typing import Any, Literal, TypedDict from arcade_core.catalog import MaterializedTool +from arcade_core.converters.utils import normalize_tool_name from arcade_core.schema import InputParameter, ValueSchema # ---------------------------------------------------------------------------- @@ -101,7 +102,7 @@ def to_openai(tool: MaterializedTool) -> OpenAIToolSchema: Returns: The OpenAI JsonToolSchema format (what is passed to the OpenAI API) """ - name = tool.definition.fully_qualified_name.replace(".", "_") + name = 
normalize_tool_name(tool.definition.fully_qualified_name) description = tool.description parameters_schema = _convert_input_parameters_to_json_schema(tool.definition.input.parameters) return _create_tool_schema(name, description, parameters_schema) diff --git a/libs/arcade-core/arcade_core/converters/utils.py b/libs/arcade-core/arcade_core/converters/utils.py new file mode 100644 index 000000000..71325b8b0 --- /dev/null +++ b/libs/arcade-core/arcade_core/converters/utils.py @@ -0,0 +1,54 @@ +"""Shared utilities for tool name conversion across providers. + +This module contains common utilities used by both OpenAI and Anthropic converters. +""" + + +def normalize_tool_name(name: str) -> str: + """ + Normalize a tool name for LLM provider compatibility. + + Both OpenAI and Anthropic have restrictions on tool names: + - OpenAI: allows alphanumeric, hyphens, underscores (max 64 chars) + - Anthropic: allows alphanumeric and underscores only (no dots) + + Arcade uses dot notation for fully qualified names (e.g., "Google.Search"), + so we normalize by replacing dots with underscores. + + Args: + name: The original tool name (e.g., "Google.Search") + + Returns: + The normalized tool name (e.g., "Google_Search") + + Examples: + >>> normalize_tool_name("Google.Search") + 'Google_Search' + >>> normalize_tool_name("MyTool") + 'MyTool' + >>> normalize_tool_name("Namespace.Sub.Tool") + 'Namespace_Sub_Tool' + """ + return name.replace(".", "_") + + +def denormalize_tool_name(normalized_name: str, separator: str = ".") -> str: + """ + Reverse the normalization of a tool name. + + This converts provider-format names back to Arcade's dot notation. + Note: This is a best-effort reversal and may not be accurate if the original + name contained underscores. + + Args: + normalized_name: The normalized tool name (e.g., "Google_Search") + separator: The separator to use (default: ".") + + Returns: + The denormalized tool name (e.g., "Google.Search") + + Examples: + >>> denormalize_tool_name("Google_Search") + 'Google.Search' + """ + return normalized_name.replace("_", separator) diff --git a/libs/arcade-core/pyproject.toml b/libs/arcade-core/pyproject.toml index 33fecfcb8..47751bde4 100644 --- a/libs/arcade-core/pyproject.toml +++ b/libs/arcade-core/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "arcade-core" -version = "4.1.0" +version = "4.2.0" description = "Arcade Core - Core library for Arcade platform" readme = "README.md" license = { text = "MIT" } diff --git a/libs/arcade-evals/arcade_evals/__init__.py b/libs/arcade-evals/arcade_evals/__init__.py index 8a081573c..83d1c0925 100644 --- a/libs/arcade-evals/arcade_evals/__init__.py +++ b/libs/arcade-evals/arcade_evals/__init__.py @@ -1,15 +1,49 @@ +from ._evalsuite._providers import ProviderName +from ._evalsuite._tool_registry import MCPToolDefinition +from .capture import CapturedCase, CapturedToolCall, CaptureResult from .critic import BinaryCritic, DatetimeCritic, NoneCritic, NumericCritic, SimilarityCritic -from .eval import EvalRubric, EvalSuite, ExpectedToolCall, NamedExpectedToolCall, tool_eval +from .eval import ( + AnyExpectedToolCall, + EvalRubric, + EvalSuite, + ExpectedMCPToolCall, + ExpectedToolCall, + NamedExpectedToolCall, + tool_eval, +) +from .loaders import ( + clear_tools_cache, + load_arcade_mcp_gateway_async, + load_from_stdio_async, + load_mcp_remote_async, + load_stdio_arcade_async, +) +from .weights import FuzzyWeight, Weight, validate_and_normalize_critic_weights __all__ = [ + "AnyExpectedToolCall", "BinaryCritic", + 
"CaptureResult", + "CapturedCase", + "CapturedToolCall", "DatetimeCritic", "EvalRubric", "EvalSuite", + "ExpectedMCPToolCall", "ExpectedToolCall", + "FuzzyWeight", + "MCPToolDefinition", "NamedExpectedToolCall", "NoneCritic", "NumericCritic", + "ProviderName", "SimilarityCritic", + "Weight", + "clear_tools_cache", + "load_arcade_mcp_gateway_async", + "load_mcp_remote_async", + "load_from_stdio_async", + "load_stdio_arcade_async", "tool_eval", + "validate_and_normalize_critic_weights", ] diff --git a/libs/arcade-evals/arcade_evals/_evalsuite/__init__.py b/libs/arcade-evals/arcade_evals/_evalsuite/__init__.py new file mode 100644 index 000000000..681d1b922 --- /dev/null +++ b/libs/arcade-evals/arcade_evals/_evalsuite/__init__.py @@ -0,0 +1 @@ +"""Internal implementation details for EvalSuite""" diff --git a/libs/arcade-evals/arcade_evals/_evalsuite/_anthropic_schema.py b/libs/arcade-evals/arcade_evals/_evalsuite/_anthropic_schema.py new file mode 100644 index 000000000..c66141259 --- /dev/null +++ b/libs/arcade-evals/arcade_evals/_evalsuite/_anthropic_schema.py @@ -0,0 +1,57 @@ +"""Anthropic tool schema conversion (internal). + +Converts MCP-style tool schemas to Anthropic's tool format. + +Anthropic uses standard JSON Schema, so conversion is minimal: +- Rename inputSchema -> input_schema (camelCase to snake_case) +- Normalize tool names (dots to underscores, as Anthropic doesn't allow dots) +- No strict mode transformations needed +- Standard JSON Schema constraints are preserved +""" + +from __future__ import annotations + +from typing import Any + +from arcade_core.converters.utils import normalize_tool_name as _normalize_tool_name + + +def convert_mcp_to_anthropic_tool(mcp_tool: dict[str, Any]) -> dict[str, Any]: + """ + Convert an MCP tool definition to Anthropic tool format. + + This is a minimal conversion since Anthropic accepts standard JSON Schema. + Changes: + - Rename `inputSchema` to `input_schema` + - Normalize tool name (dots to underscores) + + Args: + mcp_tool: MCP-style tool definition with keys: + - name (required) + - description (optional) + - inputSchema (optional, JSON Schema) + + Returns: + Anthropic tool definition with keys: + - name + - description + - input_schema + """ + return { + "name": _normalize_tool_name(mcp_tool["name"]), + "description": mcp_tool.get("description", ""), + "input_schema": mcp_tool.get("inputSchema", {"type": "object", "properties": {}}), + } + + +def convert_mcp_tools_to_anthropic(mcp_tools: list[dict[str, Any]]) -> list[dict[str, Any]]: + """ + Convert a list of MCP tool definitions to Anthropic tool format. + + Args: + mcp_tools: List of MCP-style tool definitions. + + Returns: + List of Anthropic tool definitions. + """ + return [convert_mcp_to_anthropic_tool(tool) for tool in mcp_tools] diff --git a/libs/arcade-evals/arcade_evals/_evalsuite/_capture.py b/libs/arcade-evals/arcade_evals/_evalsuite/_capture.py new file mode 100644 index 000000000..711f9e8eb --- /dev/null +++ b/libs/arcade-evals/arcade_evals/_evalsuite/_capture.py @@ -0,0 +1,180 @@ +"""Capture mode mixin for EvalSuite. + +This module provides the capture functionality as a mixin class, +keeping it separate from the main evaluation logic in eval.py. 
+""" + +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING, Any + +from arcade_evals.capture import CapturedCase, CapturedToolCall, CaptureResult + +if TYPE_CHECKING: + from arcade_evals._evalsuite._comparative import ComparativeCaseBuilder + from arcade_evals._evalsuite._providers import ProviderName + from arcade_evals._evalsuite._tool_registry import EvalSuiteToolRegistry + from arcade_evals._evalsuite._tracks import TrackManager + from arcade_evals._evalsuite._types import EvalRubric + from arcade_evals.eval import EvalCase + + +class _EvalSuiteCaptureMixin: + """Mixin providing capture mode functionality for EvalSuite.""" + + # These attributes are defined in EvalSuite + name: str + cases: list[EvalCase] + max_concurrent: int + rubric: EvalRubric + _internal_registry: EvalSuiteToolRegistry | None + _comparative_case_builders: list[ComparativeCaseBuilder] + _track_manager: TrackManager + + # These methods are defined in EvalSuite + async def _run_openai( + self, + client: Any, + model: str, + case: EvalCase, + registry: EvalSuiteToolRegistry | None = None, + ) -> list[tuple[str, dict[str, Any]]]: + raise NotImplementedError # Implemented in EvalSuite + + async def _run_anthropic( + self, + client: Any, + model: str, + case: EvalCase, + registry: EvalSuiteToolRegistry | None = None, + ) -> list[tuple[str, dict[str, Any]]]: + raise NotImplementedError # Implemented in EvalSuite + + def _process_tool_calls( + self, + tool_calls: list[tuple[str, dict[str, Any]]], + registry: EvalSuiteToolRegistry | None = None, + ) -> list[tuple[str, dict[str, Any]]]: + raise NotImplementedError # Implemented in EvalSuite + + def _create_eval_case(self, *args: Any, **kwargs: Any) -> EvalCase: + raise NotImplementedError # Implemented in EvalSuite + + async def capture( + self, + client: Any, # AsyncOpenAI | AsyncAnthropic + model: str, + provider: ProviderName = "openai", + include_context: bool = False, + ) -> CaptureResult: + """ + Run the evaluation suite in capture mode - records tool calls without scoring. + + Capture mode runs each case and records the tool calls made by the model, + without evaluating or scoring them. This is useful for: + - Generating expected tool calls for new test cases + - Debugging model behavior + - Creating baseline recordings + + Handles both regular cases and comparative cases. For comparative cases, + each track is captured separately with its own tool registry. + + Args: + client: The LLM client instance (AsyncOpenAI or AsyncAnthropic). + model: The model to use. + provider: The provider name ("openai" or "anthropic"). + include_context: Whether to include system_message and additional_messages + in the output. + + Returns: + A CaptureResult containing all captured tool calls. + """ + all_captured: list[CapturedCase] = [] + semaphore = asyncio.Semaphore(self.max_concurrent) + + async def capture_case( + case: EvalCase, + registry: EvalSuiteToolRegistry | None = None, + track: str | None = None, + ) -> CapturedCase: + """Capture a case using the specified registry.""" + async with semaphore: + use_registry = registry or self._internal_registry + if use_registry is None or use_registry.tool_count() == 0: + raise ValueError( + "No tools registered. Use add_* convenience methods or pass catalog=ToolCatalog." 
+ ) + + # Get tool calls based on provider + if provider == "anthropic": + predicted_args = await self._run_anthropic( + client, model, case, registry=use_registry + ) + else: + predicted_args = await self._run_openai( + client, model, case, registry=use_registry + ) + + # Process tool calls (resolve names, fill defaults) + filled_actual_tool_calls = self._process_tool_calls( + predicted_args, registry=use_registry + ) + + # Convert to CapturedToolCall objects + tool_calls = [ + CapturedToolCall(name=name, args=args) + for name, args in filled_actual_tool_calls + ] + + return CapturedCase( + case_name=case.name, + user_message=case.user_message, + tool_calls=tool_calls, + system_message=case.system_message if include_context else None, + additional_messages=case.additional_messages if include_context else None, + track_name=track, + ) + + # Capture regular cases (using default registry) + if self.cases: + tasks = [capture_case(case) for case in self.cases] + regular_captured = await asyncio.gather(*tasks) + all_captured.extend(regular_captured) + + # Capture comparative cases (each track separately) + if self._comparative_case_builders: + for builder in self._comparative_case_builders: + comp_case = builder.build() + + # For each track configured in this comparative case + for track_name in comp_case.track_configs: + if not self._track_manager.has_track(track_name): + continue # Skip missing tracks + + track_registry = self._track_manager.get_registry(track_name) + + # Create an EvalCase from the comparative case + # Use case-specific rubric if defined, otherwise use suite default + case_rubric = comp_case.rubric or self.rubric + eval_case = self._create_eval_case( + name=comp_case.name, # Don't embed track in name - use track_name field + user_message=comp_case.user_message, + system_message=comp_case.system_message, + additional_messages=comp_case.additional_messages, + expected_tool_calls=[], # Not needed for capture + rubric=case_rubric, + critics=[], # Not needed for capture + ) + + captured = await capture_case( + eval_case, registry=track_registry, track=track_name + ) + all_captured.append(captured) + + return CaptureResult( + suite_name=self.name, + model=model, + provider=provider, + captured_cases=list(all_captured), + ) diff --git a/libs/arcade-evals/arcade_evals/_evalsuite/_comparative.py b/libs/arcade-evals/arcade_evals/_evalsuite/_comparative.py new file mode 100644 index 000000000..fc2027e1c --- /dev/null +++ b/libs/arcade-evals/arcade_evals/_evalsuite/_comparative.py @@ -0,0 +1,132 @@ +"""Comparative case builder for multi-track evaluations. + +Provides a fluent API for defining evaluation cases that run against +multiple tool tracks with track-specific expected results and critics. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from arcade_evals._evalsuite._types import ( + ComparativeCase, + EvalRubric, + ExpectedMCPToolCall, + ExpectedToolCall, +) + +if TYPE_CHECKING: + from arcade_evals.critic import Critic + + +class ComparativeCaseBuilder: + """Fluent builder for creating comparative cases. 
+ + Example: + builder = ComparativeCaseBuilder( + suite=suite, + name="weather_query", + user_message="What's the weather?", + ) + builder.for_track( + "Google Weather", + expected_tool_calls=[...], + critics=[...], + ).for_track( + "OpenWeather", + expected_tool_calls=[...], + critics=[...], + ) + """ + + def __init__( + self, + suite: Any, # EvalSuite - avoid circular import + name: str, + user_message: str, + system_message: str = "", + additional_messages: list[dict[str, str]] | None = None, + rubric: EvalRubric | None = None, + ) -> None: + """Initialize the builder. + + Args: + suite: The parent EvalSuite. + name: Unique case name. + user_message: User message (shared across tracks). + system_message: System message (shared across tracks). + additional_messages: Additional context (shared). + rubric: Default rubric (shared, can be overridden). + """ + self._suite = suite + self._case = ComparativeCase( + name=name, + user_message=user_message, + system_message=system_message, + additional_messages=additional_messages or [], + rubric=rubric, + ) + + def for_track( + self, + track_name: str, + expected_tool_calls: list[ExpectedToolCall | ExpectedMCPToolCall], + critics: list[Critic] | None = None, + ) -> ComparativeCaseBuilder: + """Add track-specific configuration. + + Args: + track_name: The track name (must be registered via add_*_tools). + expected_tool_calls: Expected tool calls for this track. + critics: Critics for this track. + + Returns: + Self for method chaining. + + Raises: + ValueError: If track doesn't exist. + """ + # Validate track exists + if not self._suite._track_manager.has_track(track_name): + available = self._suite._track_manager.get_track_names() + raise ValueError( + f"Track '{track_name}' not found. " + f"Available tracks: {available}. " + f"Register tracks first using add_*_tools(track=...)." + ) + + self._case.add_track_config( + track_name=track_name, + expected_tool_calls=expected_tool_calls, + critics=critics, + ) + return self + + def build(self) -> ComparativeCase: + """Build and return the comparative case. + + Returns: + The configured ComparativeCase. + + Raises: + ValueError: If no tracks configured. + """ + if not self._case.track_configs: + raise ValueError( + f"No tracks configured for comparative case '{self._case.name}'. " + f"Use .for_track() to add at least one track configuration." + ) + return self._case + + @property + def case(self) -> ComparativeCase: + """Access the underlying case for inspection. + + Note: This is primarily for testing. The case may be incomplete + if tracks haven't been configured yet. Use build() to validate + and finalize the case. + + Returns: + The ComparativeCase (may be incomplete). + """ + return self._case diff --git a/libs/arcade-evals/arcade_evals/_evalsuite/_comparative_execution.py b/libs/arcade-evals/arcade_evals/_evalsuite/_comparative_execution.py new file mode 100644 index 000000000..a0e69251a --- /dev/null +++ b/libs/arcade-evals/arcade_evals/_evalsuite/_comparative_execution.py @@ -0,0 +1,233 @@ +"""Comparative evaluation execution mixin for EvalSuite. + +This module provides the execution logic for comparative evaluations, +allowing the same cases to be run against multiple tool tracks. 
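+
+A rough end-to-end sketch (track names, tools, and arguments are illustrative, and
+both tracks are assumed to have been registered via the add_* methods with track=...):
+
+    suite.add_comparative_case(
+        name="weather_query",
+        user_message="What's the weather in NYC?",
+    ).for_track(
+        "Google Weather",
+        expected_tool_calls=[ExpectedMCPToolCall(tool_name="Google_GetWeather", args={"city": "NYC"})],
+    ).for_track(
+        "OpenWeather",
+        expected_tool_calls=[ExpectedMCPToolCall(tool_name="get_current", args={"location": "NYC"})],
+    )
+    results = await suite.run_comparative(client, "gpt-4o")
+    # results["Google Weather"]["cases"] and results["OpenWeather"]["cases"] hold the per-track results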
+""" + +from __future__ import annotations + +import asyncio +import time +from typing import TYPE_CHECKING, Any + +from arcade_evals._evalsuite._comparative import ComparativeCaseBuilder +from arcade_evals._evalsuite._types import ComparativeCase, EvalRubric + +if TYPE_CHECKING: + from arcade_evals._evalsuite._providers import ProviderName + from arcade_evals._evalsuite._tool_registry import EvalSuiteToolRegistry + from arcade_evals._evalsuite._tracks import TrackManager + + +class _EvalSuiteComparativeMixin: + """Mixin providing comparative evaluation execution methods.""" + + # Type hints for attributes from EvalSuite + name: str + system_message: str + rubric: EvalRubric # EvalSuite always has a rubric (default_factory) + max_concurrent: int + _comparative_case_builders: list[ComparativeCaseBuilder] + _track_manager: TrackManager + _create_eval_case: Any # Method from EvalSuite to create EvalCase + _convert_to_named_expected_tool_call: Any # Method from EvalSuite + _add_none_critics: Any # Method from EvalSuite + _process_tool_calls: Any # Method from EvalSuite + _run_openai: Any # Method from EvalSuite + _run_anthropic: Any # Method from EvalSuite + + def add_comparative_case( + self, + name: str, + user_message: str, + system_message: str | None = None, + additional_messages: list[dict[str, str]] | None = None, + rubric: EvalRubric | None = None, + ) -> ComparativeCaseBuilder: + """Create a comparative case that runs against multiple tool tracks. + + Use .for_track() on the returned builder to configure track-specific + expected tool calls and critics. + + Args: + name: Unique case name. + user_message: User message (shared across all tracks). + system_message: System message (shared, defaults to suite's system_message). + additional_messages: Additional context messages (shared). + rubric: Evaluation rubric (shared, defaults to suite's rubric). + + Returns: + A ComparativeCaseBuilder for fluent track configuration. + + Example: + suite.add_comparative_case( + name="weather_query", + user_message="What's the weather in NYC?", + ).for_track( + "Google Weather", + expected_tool_calls=[ExpectedMCPToolCall("Google_GetWeather", city="NYC")], + critics=[RangeCritic(field="temperature", min_val=0, max_val=100)], + ).for_track( + "OpenWeather", + expected_tool_calls=[ExpectedMCPToolCall("get_current", location="NYC")], + critics=[RangeCritic(field="main.temp", min_val=273, max_val=373)], + ) + """ + builder = ComparativeCaseBuilder( + suite=self, + name=name, + user_message=user_message, + system_message=system_message or self.system_message, + additional_messages=additional_messages, + rubric=rubric or self.rubric, + ) + # Store the builder (validated at execution time to allow fluent configuration) + self._comparative_case_builders.append(builder) + return builder + + async def run_comparative( + self, + client: Any, + model: str, + provider: ProviderName = "openai", + ) -> dict[str, dict[str, Any]]: + """Run comparative cases across all configured tracks. + + Args: + client: The LLM client instance. + model: The model to evaluate. + provider: The provider name. + + Returns: + Dictionary mapping track names to their results. 
+ Each track result contains: + - model: The model name + - suite_name: The suite name + - track_name: The track name + - cases: List of case results + + Example: + results = await suite.run_comparative(client, "gpt-4o") + # results["Google Weather"]["cases"][0] -> first case result + # results["OpenWeather"]["cases"][0] -> same case, different track + """ + if not self._comparative_case_builders: + raise ValueError( + "No comparative cases defined. Use add_comparative_case() to add cases." + ) + + # Build and validate all cases upfront + comparative_cases: list[ComparativeCase] = [] + all_required_tracks: set[str] = set() + for builder in self._comparative_case_builders: + comp_case = builder.build() # Validates that tracks are configured + comparative_cases.append(comp_case) + all_required_tracks.update(comp_case.track_configs.keys()) + + # Validate all required tracks exist upfront (fail fast) + missing_tracks = [t for t in all_required_tracks if not self._track_manager.has_track(t)] + if missing_tracks: + available = self._track_manager.get_track_names() + raise ValueError( + f"Missing track registries: {missing_tracks}. " + f"Available tracks: {available}. " + f"Ensure you registered tools with track=''." + ) + + # Initialize track results structure + track_results: dict[str, dict[str, Any]] = {} + for track_name in all_required_tracks: + track_results[track_name] = { + "model": model, + "suite_name": self.name, + "track_name": track_name, + "rubric": self.rubric, + "cases": [], + } + + # Prepare all async tasks for parallel execution + semaphore = asyncio.Semaphore(self.max_concurrent) + tasks: list[tuple[str, Any]] = [] # (track_name, task) + + for comp_case in comparative_cases: + for track_name, track_config in comp_case.track_configs.items(): + registry = self._track_manager.get_registry(track_name) + # We validated above that all registries exist, so this should never be None + if registry is None: + raise RuntimeError( + f"Registry for '{track_name}' unexpectedly None after validation" + ) + + # Create EvalCase from comparative case + track config + expected_tool_calls = [ + self._convert_to_named_expected_tool_call(tc) + for tc in track_config.expected_tool_calls + ] + critics = self._add_none_critics(expected_tool_calls, track_config.critics or []) + + eval_case = self._create_eval_case( + name=comp_case.name, + system_message=comp_case.system_message, + user_message=comp_case.user_message, + expected_tool_calls=expected_tool_calls, + rubric=comp_case.rubric or self.rubric, + critics=critics, + additional_messages=comp_case.additional_messages, + ) + + # Create task for this case+track combination + async def run_track_case( + _case: Any, # EvalCase + _reg: EvalSuiteToolRegistry, + _t_name: str, + ) -> dict[str, Any]: + async with semaphore: + start = time.time() + print(f" [TASK START] {_case.name} @ {_t_name}", flush=True) + if provider == "anthropic": + predicted_args = await self._run_anthropic( + client, model, _case, registry=_reg + ) + else: + predicted_args = await self._run_openai( + client, model, _case, registry=_reg + ) + elapsed = time.time() - start + print( + f" [TASK DONE] {_case.name} @ {_t_name} ({elapsed:.1f}s)", + flush=True, + ) + + filled_actual_tool_calls = self._process_tool_calls( + predicted_args, registry=_reg + ) + evaluation = _case.evaluate(filled_actual_tool_calls) + + return { + "name": _case.name, + "track": _t_name, + "input": _case.user_message, + "system_message": _case.system_message, + "additional_messages": _case.additional_messages, + 
"expected_tool_calls": [ + {"name": tc.name, "args": tc.args} + for tc in _case.expected_tool_calls + ], + "predicted_tool_calls": [ + {"name": name, "args": args} + for name, args in filled_actual_tool_calls + ], + "evaluation": evaluation, + } + + task = run_track_case(eval_case, registry, track_name) + tasks.append((track_name, task)) + + # Execute all tasks in parallel (respecting max_concurrent via semaphore) + results = await asyncio.gather(*[task for _, task in tasks]) + + # Organize results by track + for (track_name, _), result in zip(tasks, results): + track_results[track_name]["cases"].append(result) + + return track_results diff --git a/libs/arcade-evals/arcade_evals/_evalsuite/_convenience.py b/libs/arcade-evals/arcade_evals/_evalsuite/_convenience.py new file mode 100644 index 000000000..d6ae963a9 --- /dev/null +++ b/libs/arcade-evals/arcade_evals/_evalsuite/_convenience.py @@ -0,0 +1,265 @@ +"""EvalSuite convenience methods (internal-only). + +This module contains only the functionality introduced in this PR: +- tool registration convenience methods +- unified internal registry plumbing helpers +- track-based tool registration for comparative evaluations + +It is intentionally not exported from `arcade_evals.__init__`. +""" + +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Any, Callable + +from arcade_evals._evalsuite._tool_registry import EvalSuiteToolRegistry, MCPToolDefinition +from arcade_evals._evalsuite._tracks import TrackManager +from arcade_evals.loaders import ( + load_arcade_mcp_gateway_async, + load_from_stdio_async, + load_mcp_remote_async, +) + +if TYPE_CHECKING: + from arcade_core import ToolCatalog + + +class _EvalSuiteConvenienceMixin: + """Mixin providing convenience tool registration methods.""" + + _internal_registry: EvalSuiteToolRegistry | None + _track_manager: TrackManager + _python_tool_func_map: dict[str, Callable] + _python_func_to_tool_name: dict[Callable, str] + strict_mode: bool # Attribute from EvalSuite dataclass + + def _get_registry(self, track: str | None = None) -> EvalSuiteToolRegistry: + """Get the registry for a track or the default internal registry. + + Args: + track: Optional track name. If provided, gets or creates the track registry. + If None, uses the default internal registry. + + Returns: + The appropriate EvalSuiteToolRegistry. + + Raises: + RuntimeError: If internal registry not initialized. + """ + if track is not None: + # Get existing track registry or create new one + registry = self._track_manager.get_registry(track) + if registry is None: + # Create new registry for this track + registry = EvalSuiteToolRegistry(strict_mode=self.strict_mode) + self._track_manager.create_track(track, registry) + return registry + + # Default: use internal registry + if self._internal_registry is None: + raise RuntimeError("Internal registry not initialized. This should not happen.") + return self._internal_registry + + def get_tracks(self) -> list[str]: + """Get all registered track names. + + Returns: + List of track names in registration order. + """ + return self._track_manager.get_track_names() + + def add_tool_definitions( + self, + tools: list[MCPToolDefinition], + *, + track: str | None = None, + ) -> Any: + """Add tool definitions directly from MCP-style dictionaries. + + Args: + tools: List of tool definitions. Each must have: + - name (str): Required. The unique tool name. + - description (str): Optional. Defaults to "". + - inputSchema (dict): Optional. JSON Schema for parameters. 
+ Defaults to {"type": "object", "properties": {}}. + track: Optional track name. If provided, tools are added to that track's + isolated registry. Use for comparative evaluations. + + Returns: + Self for method chaining. + + Raises: + TypeError: If a tool definition is not a dictionary. + ValueError: If a tool definition is missing 'name' or the name is already registered. + """ + registry = self._get_registry(track) + for tool in tools: + if not isinstance(tool, dict): + raise TypeError("Tool definitions must be dictionaries") + if "name" not in tool: + raise ValueError("Tool definition must include 'name'") + # Copy to avoid mutating input dict + tool_copy = dict(tool) + tool_copy.setdefault("description", "") + tool_copy.setdefault("inputSchema", {"type": "object", "properties": {}}) + registry.add_tool(tool_copy) + return self + + async def add_mcp_server( + self, + url: str, + *, + headers: dict[str, str] | None = None, + timeout: int = 10, + track: str | None = None, + use_sse: bool = False, + ) -> Any: + """Add tools from an MCP HTTP server. + + Args: + url: The MCP server URL. + headers: Optional HTTP headers. + timeout: Connection timeout in seconds. + track: Optional track name for comparative evaluations. + use_sse: If True, use Server-Sent Events (SSE) transport. + + Returns: + Self for method chaining. + """ + registry = self._get_registry(track) + tools = await load_mcp_remote_async(url, timeout=timeout, headers=headers, use_sse=use_sse) + if not tools: + warnings.warn( + f"No tools loaded from {url}. Server may be unavailable.", + UserWarning, + stacklevel=2, + ) + return self + registry.add_tools(tools) + return self + + async def add_mcp_stdio_server( + self, + command: list[str], + *, + env: dict[str, str] | None = None, + timeout: int = 10, + track: str | None = None, + ) -> Any: + """Add tools from an MCP stdio server. + + Args: + command: Command to start the MCP server. + env: Optional environment variables. + timeout: Connection timeout in seconds. + track: Optional track name for comparative evaluations. + + Returns: + Self for method chaining. + """ + registry = self._get_registry(track) + tools = await load_from_stdio_async(command, timeout=timeout, env=env) + if not tools: + warnings.warn( + f"No tools loaded from stdio command: {' '.join(command)}", + UserWarning, + stacklevel=2, + ) + return self + registry.add_tools(tools) + return self + + async def add_arcade_gateway( + self, + gateway_slug: str, + *, + arcade_api_key: str | None = None, + arcade_user_id: str | None = None, + base_url: str | None = None, + timeout: int = 10, + track: str | None = None, + ) -> Any: + """Add tools from an Arcade MCP gateway. + + Args: + gateway_slug: The Arcade gateway slug. + arcade_api_key: Optional API key. + arcade_user_id: Optional user ID. + base_url: Optional base URL. + timeout: Connection timeout in seconds. + track: Optional track name for comparative evaluations. + + Returns: + Self for method chaining. 
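+
+        Example (gateway slug and credentials are illustrative):
+            await suite.add_arcade_gateway(
+                "my-gateway",
+                arcade_api_key=os.environ.get("ARCADE_API_KEY"),
+                arcade_user_id=os.environ.get("ARCADE_USER_ID"),
+            )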
+ """ + registry = self._get_registry(track) + + tools = await load_arcade_mcp_gateway_async( + gateway_slug, + arcade_api_key=arcade_api_key, + arcade_user_id=arcade_user_id, + base_url=base_url, # Let loader handle default/env var + timeout=timeout, + ) + + if not tools: + warnings.warn( + f"No tools loaded from Arcade gateway: {gateway_slug}", + UserWarning, + stacklevel=2, + ) + return self + registry.add_tools(tools) + return self + + def add_tool_catalog( + self, + catalog: ToolCatalog, + *, + track: str | None = None, + ) -> Any: + """Add tools from a ToolCatalog to the internal registry. + + Args: + catalog: A ToolCatalog containing registered Python tools. + track: Optional track name for comparative evaluations. + + Returns: + Self for method chaining. + """ + # Delegate to the shared helper method defined in EvalSuite + self._register_catalog_tools(catalog, track=track) # type: ignore[attr-defined] + return self + + def get_tool_count(self, track: str | None = None) -> int: + """Get the number of registered tools. + + Args: + track: Optional track name. If provided, counts tools in that track. + + Returns: + Number of tools. + """ + if track is not None: + registry = self._track_manager.get_registry(track) + return registry.tool_count() if registry else 0 + if self._internal_registry is None: + return 0 + return self._internal_registry.tool_count() + + def list_tool_names(self, track: str | None = None) -> list[str]: + """List all registered tool names. + + Args: + track: Optional track name. If provided, lists tools in that track. + + Returns: + List of tool names. + """ + if track is not None: + registry = self._track_manager.get_registry(track) + return registry.tool_names() if registry else [] + if self._internal_registry is None: + return [] + return self._internal_registry.tool_names() diff --git a/libs/arcade-evals/arcade_evals/_evalsuite/_openai_schema.py b/libs/arcade-evals/arcade_evals/_evalsuite/_openai_schema.py new file mode 100644 index 000000000..4b92cbce2 --- /dev/null +++ b/libs/arcade-evals/arcade_evals/_evalsuite/_openai_schema.py @@ -0,0 +1,149 @@ +"""OpenAI tool schema conversion (internal). + +Converts MCP-style tool schemas to OpenAI's tool format with strict mode support. + +OpenAI strict mode requirements: +- additionalProperties: false at all object levels +- properties and required present on all object schemas +- required includes ALL properties (optional params use null union types) +- Unsupported JSON Schema keywords are stripped +""" + +from __future__ import annotations + +import copy +from typing import Any + +# Maximum recursion depth to prevent infinite loops in circular schema references +_MAX_SCHEMA_DEPTH = 50 + +# Keywords not supported by OpenAI strict mode that should be stripped +_UNSUPPORTED_STRICT_MODE_KEYWORDS = frozenset({ + "minimum", + "maximum", + "exclusiveMinimum", + "exclusiveMaximum", + "minLength", + "maxLength", + "pattern", + "format", + "default", + "nullable", + "minItems", + "maxItems", + "uniqueItems", + "minProperties", + "maxProperties", +}) + + +class SchemaConversionError(Exception): + """Raised when schema conversion fails.""" + + +def convert_to_strict_mode_schema(parameters: dict[str, Any]) -> dict[str, Any]: + """ + Convert an input JSON schema (MCP `inputSchema`) to OpenAI strict mode format. 
+ + OpenAI strict mode requires: + - additionalProperties: false at all object levels + - properties and required present on all object schemas + - required includes ALL properties + - optional params become union types with null (e.g., ["string", "null"]) + - unsupported JSON Schema keywords stripped + """ + result = copy.deepcopy(parameters) + strict_schema = _apply_strict_mode_recursive(result, depth=0) + return { + "type": "object", + "properties": strict_schema.get("properties", {}), + "required": strict_schema.get("required", []), + "additionalProperties": False, + } + + +def _apply_strict_mode_recursive(schema: dict[str, Any], *, depth: int = 0) -> dict[str, Any]: + if depth > _MAX_SCHEMA_DEPTH: + raise SchemaConversionError( + f"Schema nesting exceeds maximum depth of {_MAX_SCHEMA_DEPTH}. " + "This may indicate a circular reference in the schema." + ) + + # Strip unsupported keywords that OpenAI strict mode doesn't allow + for keyword in _UNSUPPORTED_STRICT_MODE_KEYWORDS: + schema.pop(keyword, None) + + # OpenAI strict mode enum handling: + # 1. OpenAI requires enum values to be strings + # 2. OpenAI validates that enum values match the declared type + # 3. When we convert enum values to strings, we must also change the type to "string" + # + # Example: {"type": "integer", "enum": [0, 1, 2]} becomes {"type": "string", "enum": ["0", "1", "2"]} + # Example: {"type": ["integer", "null"], "enum": [0, 1]} becomes {"type": ["string", "null"], "enum": ["0", "1"]} + # + # Without this fix, OpenAI returns: "enum value 0 does not validate against {'type': ['integer', 'null']}" + if "enum" in schema: + schema["enum"] = [str(v) for v in schema["enum"]] + # Change type to string to match the stringified enum values + current_type = schema.get("type") + if current_type and current_type != "string": + if isinstance(current_type, str): + schema["type"] = "string" + elif isinstance(current_type, list) and "string" not in current_type: + # Replace non-string types with string, preserve null if present + has_null = "null" in current_type + if has_null: + schema["type"] = ["string", "null"] + else: + # Single type without null should be simplified to string + schema["type"] = "string" + + schema_type = schema.get("type") + + if schema_type == "object": + schema["additionalProperties"] = False + schema.setdefault("properties", {}) + + properties = schema.get("properties", {}) + required = set(schema.get("required", [])) + + new_properties: dict[str, Any] = {} + all_param_names: list[str] = [] + + for param_name, param_schema in properties.items(): + if not isinstance(param_schema, dict): + new_properties[param_name] = param_schema + all_param_names.append(param_name) + continue + + processed_schema = _apply_strict_mode_recursive(param_schema, depth=depth + 1) + + # Optional param: add null to type union + if param_name not in required: + param_type = processed_schema.get("type") + if isinstance(param_type, str): + processed_schema["type"] = [param_type, "null"] + elif isinstance(param_type, list) and "null" not in param_type: + processed_schema["type"] = [*param_type, "null"] + + new_properties[param_name] = processed_schema + all_param_names.append(param_name) + + schema["properties"] = new_properties + schema["required"] = all_param_names + + elif schema_type == "array": + items = schema.get("items") + if isinstance(items, dict): + schema["items"] = _apply_strict_mode_recursive(items, depth=depth + 1) + + for combiner in ("anyOf", "oneOf", "allOf"): + if combiner in schema: + schema[combiner] = [ + 
_apply_strict_mode_recursive(option, depth=depth + 1) + if isinstance(option, dict) + else option + for option in schema[combiner] + ] + + return schema diff --git a/libs/arcade-evals/arcade_evals/_evalsuite/_providers.py b/libs/arcade-evals/arcade_evals/_evalsuite/_providers.py new file mode 100644 index 000000000..c5752ebda --- /dev/null +++ b/libs/arcade-evals/arcade_evals/_evalsuite/_providers.py @@ -0,0 +1,151 @@ +"""Provider abstractions and message conversion utilities. + +This module contains: +- ProviderName type for supported LLM providers +- Message conversion utilities for different provider formats + +Anthropic has different message format requirements than OpenAI: +- Only "user" and "assistant" roles (system is a separate parameter) +- tool_use/tool_result content blocks instead of tool_calls/tool role +""" + +from __future__ import annotations + +import json +import logging +from typing import Any, Literal + +logger = logging.getLogger(__name__) + +# Supported LLM providers for evaluations +ProviderName = Literal["openai", "anthropic"] + + +def convert_messages_to_anthropic(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """ + Convert OpenAI-format messages to Anthropic format. + + Anthropic only supports "user" and "assistant" roles (system is a separate parameter). + + Key differences handled: + - "system" -> skipped (handled separately in Anthropic API) + - "user" -> "user" (pass through) + - "assistant" -> "assistant" (pass through) + - "assistant" with "tool_calls" -> "assistant" with tool_use content blocks + - "tool" -> "user" with tool_result content block + - "function" (legacy) -> "user" with tool_result content block + + Args: + messages: List of OpenAI-format messages + + Returns: + List of Anthropic-format messages + """ + anthropic_messages: list[dict[str, Any]] = [] + + for msg in messages: + role = msg.get("role", "") + + if role == "system": + # Skip system messages - Anthropic API takes system as a separate parameter. + # In _run_anthropic(), we pass system=case.system_message to client.messages.create(). + # This is the correct approach per Anthropic's API design. + continue + + elif role == "user": + # User messages convert directly + content = msg.get("content", "") + if content: + anthropic_messages.append({"role": "user", "content": content}) + + elif role == "assistant": + if "tool_calls" in msg and msg.get("tool_calls"): + # Convert OpenAI tool_calls to Anthropic tool_use blocks + # Anthropic supports mixed content: text blocks + tool_use blocks + content_blocks: list[dict[str, Any]] = [] + + # Include text content if present (assistant can say something before using tools) + text_content = msg.get("content") + if text_content: + content_blocks.append({"type": "text", "text": text_content}) + + # Add tool_use blocks + for tool_call in msg.get("tool_calls", []): + function = tool_call.get("function") + if not function: + continue # Skip malformed tool calls + + # Parse arguments JSON + arguments_str = function.get("arguments", "{}") + try: + arguments = json.loads(arguments_str) if arguments_str else {} + except json.JSONDecodeError as e: + logger.warning( + "Failed to parse tool arguments JSON for '%s': %s. 
Using empty dict.", + function.get("name", "unknown"), + e, + ) + arguments = {} + + content_blocks.append({ + "type": "tool_use", + "id": tool_call.get("id", ""), + "name": function.get("name", ""), + "input": arguments, + }) + + if content_blocks: + anthropic_messages.append({"role": "assistant", "content": content_blocks}) + else: + # Regular assistant message (no tool calls) + content = msg.get("content", "") + if content: + anthropic_messages.append({"role": "assistant", "content": content}) + + elif role == "tool": + # Convert OpenAI tool response to Anthropic tool_result block + tool_result_block = { + "type": "tool_result", + "tool_use_id": msg.get("tool_call_id", ""), + "content": msg.get("content", ""), + } + # Batch consecutive tool results into the last user message + if anthropic_messages and anthropic_messages[-1]["role"] == "user": + # Add to existing user message's content array + last_content = anthropic_messages[-1]["content"] + if isinstance(last_content, list): + last_content.append(tool_result_block) + else: + # Convert string content to array with both blocks + anthropic_messages[-1]["content"] = [ + {"type": "text", "text": last_content}, + tool_result_block, + ] + else: + # Start new user message with tool result + anthropic_messages.append({"role": "user", "content": [tool_result_block]}) + + elif role == "function": + # Legacy OpenAI function role (deprecated) - same as tool + tool_result_block = { + "type": "tool_result", + "tool_use_id": msg.get("name", ""), # function uses "name" not "tool_call_id" + "content": msg.get("content", ""), + } + # Batch consecutive tool results into the last user message + if anthropic_messages and anthropic_messages[-1]["role"] == "user": + # Add to existing user message's content array + last_content = anthropic_messages[-1]["content"] + if isinstance(last_content, list): + last_content.append(tool_result_block) + else: + # Convert string content to array with both blocks + anthropic_messages[-1]["content"] = [ + {"type": "text", "text": last_content}, + tool_result_block, + ] + else: + # Start new user message with tool result + anthropic_messages.append({"role": "user", "content": [tool_result_block]}) + + return anthropic_messages diff --git a/libs/arcade-evals/arcade_evals/_evalsuite/_tool_registry.py b/libs/arcade-evals/arcade_evals/_evalsuite/_tool_registry.py new file mode 100644 index 000000000..fef2b51d4 --- /dev/null +++ b/libs/arcade-evals/arcade_evals/_evalsuite/_tool_registry.py @@ -0,0 +1,283 @@ +"""EvalSuite internal unified tool registry (not part of the public API).""" + +from __future__ import annotations + +from typing import Any, Literal, TypedDict + +from arcade_core.converters.anthropic import to_anthropic +from arcade_core.converters.utils import normalize_tool_name + +from arcade_evals._evalsuite._anthropic_schema import convert_mcp_to_anthropic_tool +from arcade_evals._evalsuite._openai_schema import convert_to_strict_mode_schema + +ToolFormat = Literal["openai", "anthropic"] + + +class _MCPToolDefinitionRequired(TypedDict): + """Required fields for MCP-style tool definition.""" + + name: str + + +class MCPToolDefinition(_MCPToolDefinitionRequired, total=False): + """MCP-style tool definition structure. + + This is the format expected by `add_tool_definitions()` and used internally + by the EvalSuiteToolRegistry. + + Required: + name: The unique tool name. + + Optional: + description: Human-readable description (defaults to ""). 
+ inputSchema: JSON Schema for input parameters + (defaults to {"type": "object", "properties": {}}). + """ + + description: str + inputSchema: dict[str, Any] + + +class EvalSuiteToolRegistry: + """ + A minimal internal registry that stores tools in MCP-style descriptors: + + { + "name": "...", + "description": "...", + "inputSchema": { ... JSON Schema ... } + } + + EvalSuite converts Python tools into this format too, so there is only one + runtime path for OpenAI tool formatting. + + Note: Tools are stored with their original names (e.g., "Google.Search"), + but Anthropic requires underscores (e.g., "Google_Search"). The registry + maintains a mapping to look up tools by either format. + """ + + def __init__(self, *, strict_mode: bool = True) -> None: + self._tools: dict[str, dict[str, Any]] = {} + self._strict_mode = strict_mode + # Mapping from normalized names (underscores) to original names (dots) + # e.g., {"Google_Search": "Google.Search"} + self._normalized_to_original: dict[str, str] = {} + # Store original MaterializedTool objects for direct Anthropic conversion (Python tools only) + self._materialized_tools: dict[str, Any] = {} + + @property + def strict_mode(self) -> bool: + return self._strict_mode + + @strict_mode.setter + def strict_mode(self, value: bool) -> None: + self._strict_mode = value + + def add_tool( + self, + tool_descriptor: MCPToolDefinition | dict[str, Any], + materialized_tool: Any = None, + ) -> None: + """Add a tool to the registry. + + Args: + tool_descriptor: MCP-style tool definition. + materialized_tool: Optional MaterializedTool for direct Anthropic conversion (Python tools only). + """ + if "name" not in tool_descriptor: + raise ValueError("Tool descriptor must have a 'name' field") + name = tool_descriptor["name"] + if name in self._tools: + raise ValueError( + f"Tool '{name}' already registered. " + "Each tool name must be unique across all sources (MCP servers, gateways, catalogs)." + ) + self._tools[name] = dict(tool_descriptor) + + # Store MaterializedTool if provided (for direct Anthropic conversion) + if materialized_tool is not None: + self._materialized_tools[name] = materialized_tool + + # Build normalized name mapping for Anthropic/OpenAI lookups + # e.g., "Google.Search" -> normalized key "Google_Search" + normalized_name = normalize_tool_name(name) + if normalized_name != name: + # Check for collision: if the normalized name already exists as a direct tool + # (e.g., registering "Google.Search" when "Google_Search" already exists), + # the normalized lookup would be ambiguous + if normalized_name in self._tools: + # The underscore version is registered directly, so normalized lookups + # should prefer that. Don't add to mapping to avoid ambiguity. 
+ pass + elif normalized_name in self._normalized_to_original: + # Another dotted tool already maps to this normalized name + # e.g., "A.B" and "A_B" (as "A.B") would both normalize to "A_B" + # Keep the first mapping to avoid silent overwrites + pass + else: + self._normalized_to_original[normalized_name] = name + + def add_tools(self, tools: list[MCPToolDefinition] | list[dict[str, Any]]) -> None: + for tool in tools: + self.add_tool(tool) + + def list_tools_for_model(self, tool_format: ToolFormat = "openai") -> list[dict[str, Any]]: + if tool_format == "openai": + return self._to_openai_format() + elif tool_format == "anthropic": + return self._to_anthropic_format() + else: + raise ValueError(f"Tool format '{tool_format}' is not supported") + + def _to_openai_format(self) -> list[dict[str, Any]]: + """Convert stored MCP tools to OpenAI function calling format. + + Note: Tool names are normalized (dots replaced with underscores) because + OpenAI function names don't allow dots. + """ + openai_tools: list[dict[str, Any]] = [] + for tool in self._tools.values(): + parameters = tool.get("inputSchema", {"type": "object", "properties": {}}) + if self._strict_mode and isinstance(parameters, dict): + parameters = convert_to_strict_mode_schema(parameters) + + # Normalize tool name for OpenAI (e.g., "Google.Search" -> "Google_Search") + # OpenAI function names don't allow dots + tool_name = normalize_tool_name(tool["name"]) + + openai_tool: dict[str, Any] = { + "type": "function", + "function": { + "name": tool_name, + "description": tool.get("description", ""), + "parameters": parameters, + }, + } + if self._strict_mode: + openai_tool["function"]["strict"] = True + openai_tools.append(openai_tool) + + return openai_tools + + def _to_anthropic_format(self) -> list[dict[str, Any]]: + """Convert stored tools to Anthropic format. + + Uses direct to_anthropic() from arcade-core for Python tools (when MaterializedTool available), + falls back to convert_mcp_to_anthropic_tool() for MCP/remote tools (JSON descriptors only). + """ + anthropic_tools: list[dict[str, Any]] = [] + for tool_name, tool_descriptor in self._tools.items(): + # Python tools: use direct converter (we have MaterializedTool) + if tool_name in self._materialized_tools: + anthropic_tool = to_anthropic(self._materialized_tools[tool_name]) + anthropic_tools.append(dict(anthropic_tool)) + else: + # MCP/remote tools: convert from JSON descriptor (no MaterializedTool available) + # Used for tools from: load_mcp_remote_async(), load_from_stdio_async(), + # load_arcade_mcp_gateway_async(), or add_tool_definitions() + anthropic_tools.append(convert_mcp_to_anthropic_tool(tool_descriptor)) + + return anthropic_tools + + def _resolve_tool_name(self, tool_name: str) -> str | None: + """Resolve a tool name to its original registry key. + + Handles both original names (e.g., "Google.Search") and + normalized names (e.g., "Google_Search" from Anthropic). + + Args: + tool_name: The tool name to resolve. + + Returns: + The original tool name if found, None otherwise. + """ + # First, try direct lookup + if tool_name in self._tools: + return tool_name + # Then, check if it's a normalized name (from Anthropic) + original_name = self._normalized_to_original.get(tool_name) + if original_name and original_name in self._tools: + return original_name + return None + + def normalize_args(self, tool_name: str, args: dict[str, Any]) -> dict[str, Any]: + """Apply schema defaults to arguments. 
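+
+        A rough illustration (tool name and schema are hypothetical, with the
+        "units" property declaring a default of "metric"):
+
+            registry.normalize_args("Weather_GetCurrent", {"city": "NYC", "units": None})
+            # -> {"city": "NYC", "units": "metric"}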
+ + Fills in default values from the tool schema for: + - Missing parameters (key not in args) + - Null parameters (value is None), which OpenAI strict mode sends for optional params + + This ensures that optional parameters with defaults are properly filled + even when the model sends null values. + """ + resolved_name = self._resolve_tool_name(tool_name) + tool = self._tools.get(resolved_name) if resolved_name else None + if not tool: + return args + + schema = tool.get("inputSchema", {}) + if not isinstance(schema, dict): + return args + + properties = schema.get("properties", {}) + if not isinstance(properties, dict): + return args + + normalized = dict(args) + for prop_name, prop_schema in properties.items(): + # Apply default if parameter is missing OR if it's null (None) + # OpenAI strict mode sends null for optional parameters that weren't provided + should_apply_default = ( + isinstance(prop_schema, dict) + and "default" in prop_schema + and (prop_name not in normalized or normalized[prop_name] is None) + ) + if should_apply_default: + normalized[prop_name] = prop_schema["default"] + return normalized + + def get_tool_schema(self, tool_name: str) -> dict[str, Any] | None: + resolved_name = self._resolve_tool_name(tool_name) + return self._tools.get(resolved_name) if resolved_name else None + + def has_tool(self, tool_name: str) -> bool: + return self._resolve_tool_name(tool_name) is not None + + def resolve_tool_name(self, tool_name: str) -> str | None: + """Public method to resolve a tool name to its original registry key. + + This is useful for callers that need to look up tools by names + returned from providers (e.g., Anthropic returns underscore names). + + Args: + tool_name: The tool name to resolve. + + Returns: + The original tool name if found, None otherwise. + """ + return self._resolve_tool_name(tool_name) + + def process_tool_call(self, tool_name: str, args: dict[str, Any]) -> tuple[str, dict[str, Any]]: + """Resolve tool name and apply schema defaults in one step. + + This combines name resolution (for Anthropic underscore -> dot conversion) + with schema default application. + + Args: + tool_name: The tool name (may be in provider format like "Google_Search"). + args: The arguments from the tool call. + + Returns: + Tuple of (resolved_name, args_with_defaults). + resolved_name will be the original registered name (e.g., "Google.Search") + or the input name if not found in registry. + """ + resolved_name = self._resolve_tool_name(tool_name) or tool_name + args_with_defaults = self.normalize_args(tool_name, args) + return resolved_name, args_with_defaults + + def tool_names(self) -> list[str]: + return list(self._tools.keys()) + + def tool_count(self) -> int: + return len(self._tools) diff --git a/libs/arcade-evals/arcade_evals/_evalsuite/_tracks.py b/libs/arcade-evals/arcade_evals/_evalsuite/_tracks.py new file mode 100644 index 000000000..c6db0fc99 --- /dev/null +++ b/libs/arcade-evals/arcade_evals/_evalsuite/_tracks.py @@ -0,0 +1,97 @@ +"""Track management for comparative evaluations. + +A track represents an isolated tool registry with a unique name. +This enables running the same evaluation cases against different +tool sources (e.g., different MCP servers) for comparison. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from arcade_evals._evalsuite._tool_registry import EvalSuiteToolRegistry + + +class TrackManager: + """Manages named tracks, each with its own isolated tool registry. 
+ + Tracks enable comparative evaluations where the same cases are run + against different tool sources. + + Example: + manager = TrackManager() + manager.create_track("Google Weather", registry1) + manager.create_track("OpenWeather", registry2) + + for track_name in manager.get_track_names(): + registry = manager.get_registry(track_name) + # Run cases against this registry + """ + + def __init__(self) -> None: + self._tracks: dict[str, EvalSuiteToolRegistry] = {} + + def create_track(self, name: str, registry: EvalSuiteToolRegistry) -> str: + """Create a new track with an isolated registry. + + Args: + name: Unique track name. + registry: The tool registry for this track. + + Returns: + The track name (for use as track ID). + + Raises: + ValueError: If track name already exists. + """ + if name in self._tracks: + raise ValueError(f"Track '{name}' already exists. Use a unique track name.") + self._tracks[name] = registry + return name + + def get_registry(self, track_name: str) -> EvalSuiteToolRegistry | None: + """Get the registry for a track. + + Args: + track_name: The track name. + + Returns: + The registry if found, None otherwise. + """ + return self._tracks.get(track_name) + + def get_track_names(self) -> list[str]: + """Get all registered track names. + + Returns: + List of track names in registration order. + """ + return list(self._tracks.keys()) + + def has_track(self, name: str) -> bool: + """Check if a track exists. + + Args: + name: The track name. + + Returns: + True if track exists, False otherwise. + """ + return name in self._tracks + + def track_count(self) -> int: + """Get number of registered tracks. + + Returns: + Number of tracks. + """ + return len(self._tracks) + + def get_all_registries(self) -> dict[str, EvalSuiteToolRegistry]: + """Get all registries by track name. + + Returns: + Dictionary mapping track names to registries. + """ + return dict(self._tracks) diff --git a/libs/arcade-evals/arcade_evals/_evalsuite/_types.py b/libs/arcade-evals/arcade_evals/_evalsuite/_types.py new file mode 100644 index 000000000..a43063d52 --- /dev/null +++ b/libs/arcade-evals/arcade_evals/_evalsuite/_types.py @@ -0,0 +1,176 @@ +"""Shared types for eval suite modules. + +This module contains dataclasses and types that are shared between +eval.py and the _evalsuite submodules, avoiding circular imports. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Callable + +if TYPE_CHECKING: + from arcade_evals.critic import Critic + + +@dataclass +class ExpectedToolCall: + """ + Represents an expected tool call for a Python tool (registered via ToolCatalog). + + Use this for Python functions decorated with @tool. + + Attributes: + func: The Python function itself. + args: A dictionary containing the expected arguments for the tool. + + Example: + ExpectedToolCall(func=my_tool_function, args={"x": 1, "y": 2}) + ExpectedToolCall(my_tool_function, {"x": 1}) # Positional args supported + """ + + func: Callable + args: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class ExpectedMCPToolCall: + """ + Represents an expected tool call identified by tool name (string). + + Use this for: + - Tools loaded from MCP servers (local stdio or remote HTTP) + - Tools loaded from Arcade Gateways + - Manual tool definitions (dictionaries with name/description/inputSchema) + + Attributes: + tool_name: The name of the tool (e.g., "Weather_GetCurrent"). + args: A dictionary containing the expected arguments for the tool. 
+ + Example: + ExpectedMCPToolCall(tool_name="Calculator_Add", args={"a": 5, "b": 3}) + ExpectedMCPToolCall("Calculator_Add", {"a": 5}) # Positional args supported + """ + + tool_name: str + args: dict[str, Any] = field(default_factory=dict) + + +# Type alias for mixed usage (Python tools + MCP tools in same test case) +AnyExpectedToolCall = ExpectedToolCall | ExpectedMCPToolCall + + +@dataclass +class NamedExpectedToolCall: + """ + Represents a tool call with its name and arguments. + + Attributes: + name: The name of the tool. + args: A dictionary containing the expected arguments for the tool. + """ + + name: str + args: dict[str, Any] + + +@dataclass +class EvalRubric: + """ + Defines the rubric for evaluating an AI model's performance on a task. + + Attributes: + fail_threshold: The minimum score required to pass the evaluation (between 0.0 and 1.0). + warn_threshold: The score threshold for issuing a warning (between 0.0 and 1.0). + fail_on_tool_selection: Whether to fail the evaluation if the tool selection is incorrect. + fail_on_tool_call_quantity: Whether to fail the evaluation if the number of tool calls is incorrect. + tool_selection_weight: The weight assigned to the tool selection score (between 0.0 and 1.0). + """ + + fail_threshold: float = 0.8 + warn_threshold: float = 0.9 + fail_on_tool_selection: bool = True + fail_on_tool_call_quantity: bool = True + tool_selection_weight: float = 1.0 + + def __str__(self) -> str: + """Return a complete string representation of the rubric configuration.""" + return ( + f"EvalRubric(fail_threshold={self.fail_threshold}, " + f"warn_threshold={self.warn_threshold}, " + f"fail_on_tool_selection={self.fail_on_tool_selection}, " + f"fail_on_tool_call_quantity={self.fail_on_tool_call_quantity}, " + f"tool_selection_weight={self.tool_selection_weight})" + ) + + def __repr__(self) -> str: + """Return the same string representation for repr.""" + return self.__str__() + + +@dataclass +class TrackConfig: + """Configuration for a single track within a comparative case. + + Attributes: + expected_tool_calls: Expected tool calls for this track. + critics: Critics to evaluate tool arguments for this track. + """ + + expected_tool_calls: list[ExpectedToolCall | ExpectedMCPToolCall] + critics: list[Critic] = field(default_factory=list) + + +@dataclass +class ComparativeCase: + """A case that runs against multiple tracks for comparison. + + Shared context (messages) is defined once, while each track has + its own expected tool calls and critics. + + Attributes: + name: Unique case name. + user_message: User message (shared across tracks). + system_message: System message (shared across tracks). + additional_messages: Additional context messages (shared). + rubric: Evaluation rubric (shared, can be overridden per track). + track_configs: Track-specific configurations. + """ + + name: str + user_message: str + system_message: str = "" + additional_messages: list[dict[str, str]] = field(default_factory=list) + rubric: EvalRubric | None = None + track_configs: dict[str, TrackConfig] = field(default_factory=dict) + + def add_track_config( + self, + track_name: str, + expected_tool_calls: list[ExpectedToolCall | ExpectedMCPToolCall], + critics: list[Critic] | None = None, + ) -> None: + """Add configuration for a track. + + Args: + track_name: The track name. + expected_tool_calls: Expected tool calls for this track. + critics: Critics for this track. + + Raises: + ValueError: If track already configured. 
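+
+        Example (illustrative; the track, tool, and argument names are placeholders):
+            case.add_track_config(
+                "Google Weather",
+                expected_tool_calls=[ExpectedMCPToolCall("Weather_GetCurrent", {"city": "Paris"})],
+            )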
+ """ + if track_name in self.track_configs: + raise ValueError(f"Track '{track_name}' already configured for case '{self.name}'.") + self.track_configs[track_name] = TrackConfig( + expected_tool_calls=expected_tool_calls, + critics=critics or [], + ) + + def get_configured_tracks(self) -> list[str]: + """Get list of tracks configured for this case. + + Returns: + List of track names. + """ + return list(self.track_configs.keys()) diff --git a/libs/arcade-evals/arcade_evals/capture.py b/libs/arcade-evals/arcade_evals/capture.py new file mode 100644 index 000000000..d5ad4aeb9 --- /dev/null +++ b/libs/arcade-evals/arcade_evals/capture.py @@ -0,0 +1,186 @@ +""" +Capture mode for EvalSuite. + +Capture mode runs evaluation cases and records tool calls from the model +without scoring or evaluating them. This is useful for: +- Generating expected tool calls for new test cases +- Debugging model behavior +- Creating baseline recordings +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from openai import AsyncOpenAI + +if TYPE_CHECKING: + from arcade_evals.eval import EvalSuite + + +@dataclass +class CapturedToolCall: + """ + A captured tool call from the model during capture mode. + + Attributes: + name: The name of the tool that was called. + args: The arguments passed to the tool. + """ + + name: str + args: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + return {"name": self.name, "args": self.args} + + +@dataclass +class CapturedCase: + """ + Result of running a single case in capture mode. + + Attributes: + case_name: The name of the evaluation case. + user_message: The user message that triggered the tool calls. + tool_calls: List of tool calls made by the model. + system_message: The system message (included if include_context is True). + additional_messages: Additional messages (included if include_context is True). + track_name: The track name for comparative captures (None for regular cases). + """ + + case_name: str + user_message: str + tool_calls: list[CapturedToolCall] = field(default_factory=list) + system_message: str | None = None + additional_messages: list[dict[str, Any]] | None = None + track_name: str | None = None + + @staticmethod + def _try_parse_json(value: str) -> Any: + """Try to parse a JSON string, returning the original string if parsing fails.""" + try: + return json.loads(value) + except json.JSONDecodeError: + return value + + @staticmethod + def _normalize_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """ + Normalize additional_messages by parsing JSON strings into proper objects. + + OpenAI returns: + - Tool call arguments as JSON strings in assistant messages + - Tool response content as JSON strings in tool messages + + For cleaner output, we parse these into proper objects. 
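+
+        For example (illustrative values), an assistant tool call whose arguments are the
+        string '{"query": "cats"}' is rewritten with the parsed dict {"query": "cats"},
+        and a tool message whose content is '{"result": 42}' becomes {"result": 42}.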
+ """ + normalized = [] + for msg in messages: + msg_copy = dict(msg) + + # Parse tool call arguments in assistant messages + if "tool_calls" in msg_copy and isinstance(msg_copy["tool_calls"], list): + normalized_tool_calls = [] + for tc in msg_copy["tool_calls"]: + tc_copy = dict(tc) + if "function" in tc_copy and isinstance(tc_copy["function"], dict): + func = dict(tc_copy["function"]) + if "arguments" in func and isinstance(func["arguments"], str): + func["arguments"] = CapturedCase._try_parse_json(func["arguments"]) + tc_copy["function"] = func + normalized_tool_calls.append(tc_copy) + msg_copy["tool_calls"] = normalized_tool_calls + + # Parse content in tool response messages + if msg_copy.get("role") == "tool" and isinstance(msg_copy.get("content"), str): + msg_copy["content"] = CapturedCase._try_parse_json(msg_copy["content"]) + + normalized.append(msg_copy) + return normalized + + def to_dict(self, include_context: bool = False) -> dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + result: dict[str, Any] = { + "case_name": self.case_name, + "user_message": self.user_message, + "tool_calls": [tc.to_dict() for tc in self.tool_calls], + } + if self.track_name: + result["track_name"] = self.track_name + if include_context: + result["system_message"] = self.system_message + # Normalize additional_messages to parse JSON string arguments + raw_messages = self.additional_messages or [] + result["additional_messages"] = self._normalize_messages(raw_messages) + return result + + +@dataclass +class CaptureResult: + """ + Result of running an EvalSuite in capture mode. + + Attributes: + suite_name: The name of the evaluation suite. + model: The model used for capture. + provider: The provider used (openai, anthropic). + captured_cases: List of captured cases with tool calls. + """ + + suite_name: str + model: str + provider: str + captured_cases: list[CapturedCase] = field(default_factory=list) + + def to_dict(self, include_context: bool = False) -> dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + return { + "suite_name": self.suite_name, + "model": self.model, + "provider": self.provider, + "captured_cases": [c.to_dict(include_context) for c in self.captured_cases], + } + + def to_json(self, include_context: bool = False, indent: int = 2) -> str: + """Convert to JSON string.""" + return json.dumps(self.to_dict(include_context), indent=indent) + + def write_to_file(self, file_path: str, include_context: bool = False, indent: int = 2) -> None: + """Write capture results to a JSON file.""" + with open(file_path, "w") as f: + f.write(self.to_json(include_context, indent)) + + +# --- Helper functions for running capture mode --- + + +async def _capture_with_openai( + suite: EvalSuite, api_key: str, model: str, include_context: bool = False +) -> CaptureResult: + """Run capture mode with OpenAI client.""" + async with AsyncOpenAI(api_key=api_key) as client: + return await suite.capture( + client, model, provider="openai", include_context=include_context + ) + + +async def _capture_with_anthropic( + suite: EvalSuite, api_key: str, model: str, include_context: bool = False +) -> CaptureResult: + """Run capture mode with Anthropic client.""" + try: + from anthropic import AsyncAnthropic + except ImportError as e: + raise ImportError( + "The 'anthropic' package is required for Anthropic provider. 
" + "Install it with: pip install anthropic" + ) from e + + async with AsyncAnthropic(api_key=api_key) as client: + return await suite.capture( + client, model, provider="anthropic", include_context=include_context + ) diff --git a/libs/arcade-evals/arcade_evals/critic.py b/libs/arcade-evals/arcade_evals/critic.py index b4f5c1d64..7f78952e1 100644 --- a/libs/arcade-evals/arcade_evals/critic.py +++ b/libs/arcade-evals/arcade_evals/critic.py @@ -7,16 +7,34 @@ from dateutil import parser from arcade_evals.errors import WeightError +from arcade_evals.weights import FuzzyWeight, Weight, resolve_weight @dataclass class Critic(ABC): + """ + Base class for all critics. + + Attributes: + critic_field: The field name this critic evaluates. + weight: The weight for this critic. Can be a float (0.0-1.0) or FuzzyWeight enum. + When using FuzzyWeight, weights are auto-normalized to sum to 1.0. + """ + critic_field: str - weight: float + weight: Weight def __post_init__(self) -> None: - if self.weight < 0 or self.weight > 1: - raise WeightError(f"Critic weight must be between 0 and 1, got {self.weight}") + if isinstance(self.weight, FuzzyWeight): + return + + if self.weight < 0: + raise WeightError(f"Critic weight must be non-negative, got {self.weight}") + + @property + def resolved_weight(self) -> float: + """Get the weight as a float value.""" + return resolve_weight(self.weight) @abstractmethod def evaluate(self, expected: Any, actual: Any) -> dict[str, Any]: @@ -32,6 +50,10 @@ class NoneCritic(Critic): a NoneCritic is used to indicate that the field was not criticized. """ + # Marker attribute to identify placeholder critics without isinstance checks + # (avoids circular imports in weights.py) + _is_placeholder: ClassVar[bool] = True + weight: float = 0.0 def __post_init__(self) -> None: @@ -108,7 +130,7 @@ def evaluate(self, expected: Any, actual: Any) -> dict[str, float | bool]: actual_casted = actual match = expected == actual_casted - return {"match": match, "score": self.weight if match else 0.0} + return {"match": match, "score": self.resolved_weight if match else 0.0} @dataclass @@ -158,7 +180,10 @@ def evaluate(self, expected: Any, actual: Any) -> dict[str, Any]: normalized_expected = float((float(expected) - min_val) / (max_val - min_val)) normalized_actual = float((float(actual) - min_val) / (max_val - min_val)) score = float(1 - abs(normalized_expected - normalized_actual)) - return {"match": bool(score >= self.match_threshold), "score": float(score * self.weight)} + return { + "match": bool(score >= self.match_threshold), + "score": float(score * self.resolved_weight), + } @dataclass @@ -207,7 +232,23 @@ def __init__( self.similarity_threshold = similarity_threshold self.metric = metric - def evaluate(self, expected: str, actual: str) -> dict[str, float | bool]: + def evaluate(self, expected: Any, actual: Any) -> dict[str, float | bool]: + # IMPORTANT: Convert non-string values to strings before TF-IDF comparison. + # sklearn's TfidfVectorizer calls .lower() on inputs, which fails on lists/dicts. + # This commonly occurs when SimilarityCritic is used for tool arguments that are + # lists (e.g., teams_to_add=["Engineering", "Platform"]) instead of strings. + # Lists are joined with spaces to create comparable text representations. 
+ if not isinstance(expected, str): + expected = ( + " ".join(str(item) for item in expected) + if isinstance(expected, list) + else str(expected) + ) + if not isinstance(actual, str): + actual = ( + " ".join(str(item) for item in actual) if isinstance(actual, list) else str(actual) + ) + if self.metric == "cosine": try: from sklearn.feature_extraction.text import TfidfVectorizer @@ -216,14 +257,35 @@ def evaluate(self, expected: str, actual: str) -> dict[str, float | bool]: raise ImportError( "Use `pip install 'arcade-evals` to install the required dependencies for similarity metrics." ) - vectorizer = TfidfVectorizer() - tfidf_matrix = vectorizer.fit_transform([expected, actual]) - similarity = cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0] + + # Handle edge case: empty strings or strings with no valid tokens + # TfidfVectorizer fails with "empty vocabulary" for such inputs + if not expected.strip() or not actual.strip(): + # Both empty = match, one empty = no match + is_match = expected.strip() == actual.strip() + return { + "match": is_match, + "score": self.resolved_weight if is_match else 0.0, + } + + try: + vectorizer = TfidfVectorizer() + tfidf_matrix = vectorizer.fit_transform([expected, actual]) + similarity = float(cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0]) + except ValueError: + # TfidfVectorizer raises ValueError for empty vocabulary + # (e.g., only numbers/punctuation which get filtered as stop words) + # Fall back to exact string match + is_match = expected == actual + return { + "match": is_match, + "score": self.resolved_weight if is_match else 0.0, + } else: raise ValueError(f"Unsupported similarity metric: {self.metric}") return { "match": similarity >= self.similarity_threshold, - "score": min(similarity * self.weight, self.weight), + "score": min(similarity * self.resolved_weight, self.resolved_weight), } @@ -278,7 +340,7 @@ def evaluate(self, expected: str, actual: str) -> dict[str, float | bool]: if time_diff_seconds <= tolerance_seconds: # Full score if within tolerance - return {"match": True, "score": self.weight} + return {"match": True, "score": self.resolved_weight} elif time_diff_seconds >= max_difference_seconds: # No score if beyond max_difference return {"match": False, "score": 0.0} @@ -287,5 +349,5 @@ def evaluate(self, expected: str, actual: str) -> dict[str, float | bool]: ratio = 1 - (time_diff_seconds / max_difference_seconds) # Ensure ratio is not negative ratio = max(ratio, 0) - score = self.weight * ratio + score = self.resolved_weight * ratio return {"match": False, "score": score} diff --git a/libs/arcade-evals/arcade_evals/eval.py b/libs/arcade-evals/arcade_evals/eval.py index 02bc00d00..27d926a69 100644 --- a/libs/arcade-evals/arcade_evals/eval.py +++ b/libs/arcade-evals/arcade_evals/eval.py @@ -2,6 +2,7 @@ import functools import inspect import json +import logging from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Callable @@ -11,64 +12,46 @@ from openai import AsyncOpenAI from scipy.optimize import linear_sum_assignment +from arcade_evals._evalsuite._capture import _EvalSuiteCaptureMixin +from arcade_evals._evalsuite._comparative_execution import _EvalSuiteComparativeMixin +from arcade_evals._evalsuite._convenience import _EvalSuiteConvenienceMixin +from arcade_evals._evalsuite._providers import ( + ProviderName, + convert_messages_to_anthropic, +) +from arcade_evals._evalsuite._tool_registry import EvalSuiteToolRegistry +from arcade_evals._evalsuite._tracks import TrackManager + +# 
Import shared types from _types module (breaks circular dependencies) +from arcade_evals._evalsuite._types import ( + AnyExpectedToolCall, + EvalRubric, + ExpectedMCPToolCall, + ExpectedToolCall, + NamedExpectedToolCall, +) from arcade_evals.critic import NoneCritic -from arcade_evals.errors import WeightError +from arcade_evals.weights import validate_and_normalize_critic_weights if TYPE_CHECKING: from arcade_core import ToolCatalog + from arcade_evals._evalsuite._comparative import ComparativeCaseBuilder from arcade_evals.critic import Critic +logger = logging.getLogger(__name__) -@dataclass -class ExpectedToolCall: - """ - Represents an expected tool call with the function itself and arguments. - - Attributes: - func: The function itself. - args: A dictionary containing the expected arguments for the tool. - """ - - func: Callable - args: dict[str, Any] - - -@dataclass -class NamedExpectedToolCall: - """ - Represents a tool call with its name and arguments. - - Attributes: - name: The name of the tool. - args: A dictionary containing the expected arguments for the tool. - """ - - name: str - args: dict[str, Any] - - -@dataclass -class EvalRubric: - """ - Defines the rubric for evaluating an AI model's performance on a task. - - Attributes: - fail_threshold: The minimum score required to pass the evaluation (between 0.0 and 1.0). - warn_threshold: The score threshold for issuing a warning (between 0.0 and 1.0). - fail_on_tool_selection: Whether to fail the evaluation if the tool selection is incorrect. - fail_on_tool_call_quantity: Whether to fail the evaluation if the number of tool calls is incorrect. - tool_selection_weight: The weight assigned to the tool selection score (between 0.0 and 1.0). - """ - - fail_threshold: float = 0.8 - warn_threshold: float = 0.9 - fail_on_tool_selection: bool = True - fail_on_tool_call_quantity: bool = True - tool_selection_weight: float = 1.0 - - def __str__(self) -> str: - return f"Fail threshold: {self.fail_threshold}\nWarn threshold: {self.warn_threshold}\n" +# Re-export for backwards compatibility (these are now defined in _types.py) +__all__ = [ + "AnyExpectedToolCall", + "EvalCase", + "EvalRubric", + "EvalSuite", + "EvaluationResult", + "ExpectedMCPToolCall", + "ExpectedToolCall", + "NamedExpectedToolCall", +] @dataclass @@ -93,8 +76,14 @@ class EvaluationResult: @property def fail(self) -> bool: + """Returns True if the evaluation failed (excluding warnings).""" return not self.passed and not self.warning + @property + def warn(self) -> bool: + """Returns True if the evaluation is in warning state.""" + return self.warning + def add( self, field: str, @@ -151,6 +140,13 @@ def compute_final_score(self, total_weight: float) -> None: self.score = total_score / total_weight if total_weight > 0 else 0.0 +# Import capture mode helpers (defined in capture.py to keep this file focused) +from arcade_evals.capture import ( # noqa: E402 + _capture_with_anthropic, + _capture_with_openai, +) + + @dataclass class EvalCase: """ @@ -176,29 +172,11 @@ class EvalCase: def __post_init__(self) -> None: if self.critics is not None: - self._validate_critics() + validate_and_normalize_critic_weights(self.critics) else: # if no critics are provided, set to empty list self.critics = [] - def _validate_critics(self) -> None: - """ - Validate the sum of critic weights. - - Raises: - WeightError: If the sum of critic weights exceeds 1.0. 
- """ - if not self.critics: - return - - total_weight = sum(critic.weight for critic in self.critics) - if total_weight > 1.0: - raise WeightError(f"Sum of critic weights must not exceed 1.0, got {total_weight}") - - for critic in self.critics: - if critic.weight < 0.1 and not isinstance(critic, NoneCritic): - raise WeightError(f"Critic weights should be at least 0.1, got {critic.weight}") - def check_tool_selection_failure(self, actual_tools: list[str]) -> bool: """ Check if tool selection failure should occur. @@ -306,21 +284,25 @@ def evaluate( try: result = critic.evaluate(expected_value, actual_value) total_score += result["score"] - total_weight += critic.weight + total_weight += critic.resolved_weight evaluation_result.add( critic.critic_field, result, - critic.weight, + critic.resolved_weight, expected_value, actual_value, ) except Exception as e: - # TODO: log or console - print(f"Critic evaluation failed for field '{critic.critic_field}': {e}") + logger.warning( + "Critic evaluation failed for field '%s': %s", + critic.critic_field, + e, + exc_info=True, + ) evaluation_result.add( critic.critic_field, {"match": False, "score": 0.0}, - critic.weight, + critic.resolved_weight, expected_value, actual_value, ) @@ -378,8 +360,10 @@ def _create_cost_matrix( result = critic.evaluate(expected_value, actual_value) score += result.get("score", 0.0) except Exception as e: - print( - f"Critic evaluation failed for field '{critic.critic_field}': {e}" + logger.warning( + "Critic evaluation failed for field '%s': %s", + critic.critic_field, + e, ) cost_matrix[i, j] = score @@ -387,7 +371,7 @@ def _create_cost_matrix( @dataclass -class EvalSuite: +class EvalSuite(_EvalSuiteCaptureMixin, _EvalSuiteConvenienceMixin, _EvalSuiteComparativeMixin): """ A suite for evaluating AI model performance on specific tasks or scenarios. @@ -397,46 +381,166 @@ class EvalSuite: Attributes: name: The name of the evaluation suite. system_message: The system message to be used for all cases in this suite. - catalog: A ToolCatalog object containing registered tools. + catalog: A ToolCatalog containing registered Python tools. cases: A list of EvalCase objects representing individual test scenarios. rubric: The evaluation rubric for this case. max_concurrent: Maximum number of concurrent evaluations. + strict_mode: Whether to enable strict-mode schema conversion for MCP-style tools. """ name: str system_message: str - catalog: "ToolCatalog" + catalog: "ToolCatalog | None" = None cases: list[EvalCase] = field(default_factory=list) rubric: EvalRubric = field(default_factory=EvalRubric) max_concurrent: int = 1 + strict_mode: bool = True + + # Internal unified registry for MCP-style tools added via convenience methods. + _internal_registry: EvalSuiteToolRegistry | None = field(default=None, init=False, repr=False) + + # Track manager for comparative evaluations (isolated registries per track). + _track_manager: TrackManager = field(default_factory=TrackManager, init=False, repr=False) + + # Comparative case builders for multi-track evaluations (validated at execution time). + _comparative_case_builders: list["ComparativeCaseBuilder"] = field( + default_factory=list, init=False, repr=False + ) + + # Python tool helpers (used when Python tools are added via add_tool_catalog()). 
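+    # _python_tool_func_map maps fully qualified tool names to their Python callables
+    # (used to fill parameter defaults), while _python_func_to_tool_name is the reverse
+    # lookup used when expected tool calls reference the function object directly.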
+ _python_tool_func_map: dict[str, Callable] = field(default_factory=dict, init=False, repr=False) + _python_func_to_tool_name: dict[Callable, str] = field( + default_factory=dict, init=False, repr=False + ) + + def __post_init__(self) -> None: + """Initialize internal registry and auto-convert catalog if provided.""" + # Always create the internal registry + self._internal_registry = EvalSuiteToolRegistry(strict_mode=self.strict_mode) + + # If catalog was passed, convert those tools to the internal registry + if self.catalog is not None: + self._register_catalog_tools(self.catalog) + + def _register_catalog_tools(self, catalog: "ToolCatalog", *, track: str | None = None) -> None: + """Convert and register tools from a ToolCatalog to the internal registry. + + This helper is used by both __post_init__ (for catalog= parameter) and + add_tool_catalog() (for post-init registration). + + Args: + catalog: The ToolCatalog to register. + track: Optional track name for comparative evaluations. + """ + registry = self._get_registry(track) + + # Convert Python tools from ToolCatalog and store in unified registry format. + # We use to_openai() to extract the normalized tool schema, then pass the + # original MaterializedTool to the registry. This allows: + # - OpenAI: Uses the extracted MCP-style schema (stored in registry) + # - Anthropic: Uses direct to_anthropic() converter (via stored MaterializedTool) + # This avoids double-conversion overhead while maintaining unified storage. + for tool in catalog: + # Use OpenAI converter to get the tool name and base schema + openai_tool = to_openai(tool) + func_schema = openai_tool.get("function", {}) + tool_name = func_schema.get("name") + if not tool_name: + continue + + description = func_schema.get("description") or "" + parameters = func_schema.get("parameters") or {"type": "object", "properties": {}} + registry.add_tool( + { + "name": tool_name, + "description": description, + "inputSchema": dict(parameters), + }, + materialized_tool=tool, # Pass for direct Anthropic conversion + ) + + # Keep track of Python function for defaults + python_func = getattr(tool, "tool", None) + if callable(python_func): + self._python_tool_func_map[tool_name] = python_func + self._python_func_to_tool_name[python_func] = tool_name def _convert_to_named_expected_tool_call( - self, tc: ExpectedToolCall | tuple[Callable, dict[str, Any]] + self, tc: AnyExpectedToolCall | tuple[Callable, dict[str, Any]] ) -> NamedExpectedToolCall: """ - Convert an ExpectedToolCall or a tuple to a NamedExpectedToolCall + Convert an ExpectedToolCall, ExpectedMCPToolCall, or tuple to a NamedExpectedToolCall with default arguments populated. Args: - tc: The tool call, either as an ExpectedToolCall or a tuple. + tc: The tool call - ExpectedToolCall (Python), ExpectedMCPToolCall (MCP), or tuple. Returns: A NamedExpectedToolCall instance. 
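+
+        Example (illustrative):
+            ExpectedMCPToolCall("Calculator_Add", {"a": 5}) becomes
+            NamedExpectedToolCall(name="Calculator_Add", args={"a": 5}), with any schema
+            defaults for omitted parameters filled in.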
""" + # Handle MCP tools (ExpectedMCPToolCall) + if isinstance(tc, ExpectedMCPToolCall): + return self._convert_mcp_tool_call(tc.tool_name, tc.args) + + # Handle Python tools (ExpectedToolCall or tuple) if isinstance(tc, tuple): func, args = tc else: + # ExpectedToolCall func = tc.func args = tc.args + args_with_defaults = self._fill_args_with_defaults(func, args) - tool_name = str(self.catalog.find_tool_by_func(func).get_fully_qualified_name()) + # Try convenience method registration first, then fall back to catalog + tool_name = self._python_func_to_tool_name.get(func) + if not tool_name: + if self.catalog is not None: + tool_name = str(self.catalog.find_tool_by_func(func).get_fully_qualified_name()) + else: + raise ValueError( + "Python tool callables require ToolCatalog or add_tool_catalog() registration." + ) return NamedExpectedToolCall(name=tool_name, args=args_with_defaults) + def _convert_mcp_tool_call(self, tool_name: str, args: dict[str, Any]) -> NamedExpectedToolCall: + """Convert an MCP tool reference to a NamedExpectedToolCall (NEW in this PR).""" + args_with_defaults = dict(args) + # Apply schema defaults from internal registry + if self._internal_registry is not None and self._internal_registry.has_tool(tool_name): + args_with_defaults = self._internal_registry.normalize_args( + tool_name, args_with_defaults + ) + return NamedExpectedToolCall(name=tool_name, args=args_with_defaults) + + def _create_eval_case( + self, + name: str, + system_message: str, + user_message: str, + expected_tool_calls: list[NamedExpectedToolCall], + rubric: EvalRubric, + critics: list["Critic"], + additional_messages: list[dict[str, str]], + ) -> "EvalCase": + """Factory method to create EvalCase instances. + + Used by the comparative mixin to create EvalCase without circular imports. + """ + return EvalCase( + name=name, + system_message=system_message, + user_message=user_message, + expected_tool_calls=expected_tool_calls, + rubric=rubric, + critics=critics, + additional_messages=additional_messages, + ) + def add_case( self, name: str, user_message: str, - expected_tool_calls: list[ExpectedToolCall] | list[tuple[Callable, dict[str, Any]]], + expected_tool_calls: list[AnyExpectedToolCall] | list[tuple[Callable, dict[str, Any]]], critics: list["Critic"] | None = None, system_message: str | None = None, rubric: EvalRubric | None = None, @@ -448,7 +552,7 @@ def add_case( Args: name: The name of the evaluation case. user_message: The user's input message. - expected_tool_calls: A list of expected tool calls as ExpectedToolCall instances. + expected_tool_calls: A list of expected tool calls (ExpectedToolCall, ExpectedMCPToolCall, or tuples). critics: List of critics to evaluate the tool arguments. system_message: The system message to be used. rubric: The evaluation rubric for this case. @@ -606,50 +710,83 @@ def extend_case( ) self.cases.append(new_case) - async def run(self, client: AsyncOpenAI, model: str) -> dict[str, Any]: + def _process_tool_calls( + self, + tool_calls: list[tuple[str, dict[str, Any]]], + registry: EvalSuiteToolRegistry | None = None, + ) -> list[tuple[str, dict[str, Any]]]: + """ + Process tool calls by resolving names and applying defaults. + + Args: + tool_calls: List of (tool_name, args) tuples. + registry: Optional registry to use. If None, uses _internal_registry. + + Returns: + List of processed (tool_name, args_with_defaults) tuples. 
+ """ + effective_registry = registry or self._internal_registry + if effective_registry is None: + return tool_calls + + processed_calls = [] + for tool_name, args in tool_calls: + # Resolve name and apply schema defaults (handles Anthropic "Google_Search" -> "Google.Search") + resolved_name, args_with_defaults = effective_registry.process_tool_call( + tool_name, args + ) + + # Apply Python function defaults if available + if resolved_name in self._python_tool_func_map: + args_with_defaults = self._fill_args_with_defaults( + self._python_tool_func_map[resolved_name], args_with_defaults + ) + + processed_calls.append((resolved_name, args_with_defaults)) + return processed_calls + + async def run( + self, + client: Any, # AsyncOpenAI | AsyncAnthropic - use Any to avoid import dependency + model: str, + provider: ProviderName = "openai", + ) -> dict[str, Any]: """ Run the evaluation suite. Args: - client: The AsyncOpenAI client instance. + client: The LLM client instance (AsyncOpenAI or AsyncAnthropic). model: The model to evaluate. + provider: The provider name ("openai" or "anthropic"). + Returns: A dictionary containing the evaluation results. """ - results: dict[str, Any] = {"model": model, "rubric": self.rubric, "cases": []} + results: dict[str, Any] = { + "model": model, + "suite_name": self.name, + "rubric": self.rubric, + "cases": [], + } semaphore = asyncio.Semaphore(self.max_concurrent) async def sem_task(case: EvalCase) -> dict[str, Any]: async with semaphore: - # Prepare messages - messages = [{"role": "system", "content": case.system_message}] - messages.extend(case.additional_messages) - messages.append({"role": "user", "content": case.user_message}) - - tools = get_formatted_tools(self.catalog, tool_format="openai") - - # Get the model response - response = await client.chat.completions.create( # type: ignore[call-overload] - model=model, - messages=messages, - tool_choice="auto", - tools=tools, - user="eval_user", - seed=42, - stream=False, - ) + # All tools are in internal registry (unified container) + if self._internal_registry is None or self._internal_registry.tool_count() == 0: + raise ValueError( + "No tools registered. Use add_* convenience methods or pass catalog=ToolCatalog." 
+ ) + + # Get tool calls based on provider + if provider == "anthropic": + predicted_args = await self._run_anthropic(client, model, case) + else: + predicted_args = await self._run_openai(client, model, case) - # Extract and fill default arguments for actual tool calls - predicted_args = get_tool_args(response) - filled_actual_tool_calls = [] - for tool_name, args in predicted_args: - tool = self.catalog.get_tool_by_name(tool_name) - if tool is None: - raise ValueError(f"Tool '{tool_name}' not found in catalog.") - func = tool.tool - args_with_defaults = self._fill_args_with_defaults(func, args) - filled_actual_tool_calls.append((tool_name, args_with_defaults)) + # Process tool calls (resolve names, fill defaults) + filled_actual_tool_calls = self._process_tool_calls(predicted_args) # Evaluate the case evaluation = case.evaluate(filled_actual_tool_calls) @@ -658,6 +795,8 @@ async def sem_task(case: EvalCase) -> dict[str, Any]: result = { "name": case.name, "input": case.user_message, + "system_message": case.system_message, + "additional_messages": case.additional_messages, "expected_tool_calls": [ {"name": tc.name, "args": tc.args} for tc in case.expected_tool_calls ], @@ -674,6 +813,93 @@ async def sem_task(case: EvalCase) -> dict[str, Any]: results["cases"] = case_results return results + async def _run_openai( + self, + client: AsyncOpenAI, + model: str, + case: "EvalCase", + registry: EvalSuiteToolRegistry | None = None, + ) -> list[tuple[str, dict[str, Any]]]: + """Run evaluation using OpenAI client. + + Args: + client: The OpenAI client. + model: The model name. + case: The evaluation case. + registry: Optional registry to use. If None, uses _internal_registry. + + Returns: + List of tool calls. + """ + effective_registry = registry or self._internal_registry + if effective_registry is None: + raise RuntimeError("No registry available") + + # Prepare messages + messages: list[dict[str, Any]] = [{"role": "system", "content": case.system_message}] + messages.extend(case.additional_messages) + messages.append({"role": "user", "content": case.user_message}) + + tools = effective_registry.list_tools_for_model(tool_format="openai") + + # Get the model response + response = await client.chat.completions.create( # type: ignore[arg-type] + model=model, + messages=messages, + tool_choice="auto", + tools=tools, + user="eval_user", + seed=42, + stream=False, + ) + + return get_tool_args(response, normalize_names=False) + + async def _run_anthropic( + self, + client: Any, # AsyncAnthropic + model: str, + case: "EvalCase", + registry: EvalSuiteToolRegistry | None = None, + ) -> list[tuple[str, dict[str, Any]]]: + """Run evaluation using Anthropic client. + + Args: + client: The Anthropic client. + model: The model name. + case: The evaluation case. + registry: Optional registry to use. If None, uses _internal_registry. + + Returns: + List of tool calls. 
+ """ + effective_registry = registry or self._internal_registry + if effective_registry is None: + raise RuntimeError("No registry available") + + # Convert OpenAI-format messages to Anthropic format + anthropic_messages = convert_messages_to_anthropic(case.additional_messages) + anthropic_messages.append({"role": "user", "content": case.user_message}) + + tools = effective_registry.list_tools_for_model(tool_format="anthropic") + + # Get the model response + response = await client.messages.create( + model=model, + max_tokens=4096, + system=case.system_message, + messages=anthropic_messages, + tools=tools, + ) + + # Extract tool calls from Anthropic response + tool_calls: list[tuple[str, dict[str, Any]]] = [] + for block in response.content: + if block.type == "tool_use": + tool_calls.append((block.name, block.input)) + + return tool_calls + def get_formatted_tools(catalog: "ToolCatalog", tool_format: str = "openai") -> OpenAIToolList: """Get the formatted tools from the catalog. @@ -692,12 +918,16 @@ def get_formatted_tools(catalog: "ToolCatalog", tool_format: str = "openai") -> raise ValueError(f"Tool format for '{tool_format}' is not supported") -def get_tool_args(chat_completion: Any) -> list[tuple[str, dict[str, Any]]]: +def get_tool_args( + chat_completion: Any, normalize_names: bool = True +) -> list[tuple[str, dict[str, Any]]]: """ Returns the tool arguments from the chat completion object. Args: chat_completion: The chat completion object. + normalize_names: Whether to normalize tool names (convert _ to .). + Set to False for MCP tools that use underscores. Returns: A list of tuples containing the tool name and arguments. @@ -706,8 +936,11 @@ def get_tool_args(chat_completion: Any) -> list[tuple[str, dict[str, Any]]]: message = chat_completion.choices[0].message if message.tool_calls: for tool_call in message.tool_calls: + tool_name = tool_call.function.name + if normalize_names: + tool_name = normalize_name(tool_name) tool_args_list.append(( - normalize_name(tool_call.function.name), + tool_name, json.loads(tool_call.function.arguments), )) return tool_args_list @@ -749,20 +982,110 @@ async def wrapper( provider_api_key: str, model: str, max_concurrency: int = 1, - ) -> list[dict[str, Any]]: - suite = func() + provider: ProviderName = "openai", + capture_mode: bool = False, + include_context: bool = False, + ) -> list[Any]: + """ + Run evaluation or capture mode. + + Returns: + In evaluation mode: list[dict[str, Any]] with evaluation results. + In capture mode: list[CaptureResult] with captured tool calls. 
+ """ + # Support both sync and async suite creation functions + import asyncio + import inspect + + if inspect.iscoroutinefunction(func): + suite = await func() + else: + result = func() + # Handle case where sync func returns a coroutine + if asyncio.iscoroutine(result): + suite = await result + else: + suite = result + if not isinstance(suite, EvalSuite): raise TypeError("Eval function must return an EvalSuite") suite.max_concurrent = max_concurrency - results = [] - async with AsyncOpenAI( - api_key=provider_api_key, - ) as client: - result = await suite.run(client, model) - results.append(result) - return results + + if capture_mode: + # Run in capture mode + if provider == "anthropic": + capture_result = await _capture_with_anthropic( + suite, provider_api_key, model, include_context + ) + else: + capture_result = await _capture_with_openai( + suite, provider_api_key, model, include_context + ) + return [capture_result] + else: + # Run in evaluation mode + if provider == "anthropic": + eval_result = await _run_with_anthropic(suite, provider_api_key, model) + else: + eval_result = await _run_with_openai(suite, provider_api_key, model) + + # For comparative evaluations, eval_result is already a list of track results + # For regular evaluations, it's a single dict that needs wrapping + if isinstance(eval_result, list): + return eval_result + return [eval_result] wrapper.__tool_eval__ = True # type: ignore[attr-defined] return wrapper return decorator + + +async def _run_with_openai( + suite: "EvalSuite", api_key: str, model: str +) -> dict[str, Any] | list[dict[str, Any]]: + """Run evaluation suite with OpenAI client. + + Returns: + For regular evaluations: A single result dict. + For comparative evaluations: A list of result dicts (one per track). + """ + async with AsyncOpenAI(api_key=api_key) as client: + # Check if this suite has comparative cases + if suite._comparative_case_builders: + # Run comparative evaluation - returns dict[track_name, result] + track_results = await suite.run_comparative(client, model, provider="openai") + # Convert to list of results for consistent handling + return list(track_results.values()) + else: + # Run regular evaluation + return await suite.run(client, model, provider="openai") + + +async def _run_with_anthropic( + suite: "EvalSuite", api_key: str, model: str +) -> dict[str, Any] | list[dict[str, Any]]: + """Run evaluation suite with Anthropic client. + + Returns: + For regular evaluations: A single result dict. + For comparative evaluations: A list of result dicts (one per track). + """ + try: + from anthropic import AsyncAnthropic + except ImportError as e: + raise ImportError( + "The 'anthropic' package is required for Anthropic provider. " + "Install it with: pip install anthropic" + ) from e + + async with AsyncAnthropic(api_key=api_key) as client: + # Check if this suite has comparative cases + if suite._comparative_case_builders: + # Run comparative evaluation - returns dict[track_name, result] + track_results = await suite.run_comparative(client, model, provider="anthropic") + # Convert to list of results for consistent handling + return list(track_results.values()) + else: + # Run regular evaluation + return await suite.run(client, model, provider="anthropic") diff --git a/libs/arcade-evals/arcade_evals/loaders.py b/libs/arcade-evals/arcade_evals/loaders.py new file mode 100644 index 000000000..882023d39 --- /dev/null +++ b/libs/arcade-evals/arcade_evals/loaders.py @@ -0,0 +1,440 @@ +""" +MCP Server Tool Loaders. 
+ +Public API (async-only): +- `load_from_stdio_async` +- `load_mcp_remote_async` +- `load_arcade_mcp_gateway_async` +- `load_stdio_arcade_async` + +Requires the MCP SDK: pip install mcp +""" + +from __future__ import annotations + +import asyncio +import logging +import os +from typing import Any +from urllib.parse import urlsplit, urlunsplit + +logger = logging.getLogger(__name__) + + +class MCPSessionFilter(logging.Filter): + """Filter to suppress/rewrite misleading MCP SDK session termination messages. + + The MCP SDK logs "Session termination failed: 202" when sessions close gracefully. + HTTP 202 (Accepted) is the correct response for MCP notifications per spec, + not an error. This filter suppresses the misleading error message. + """ + + def filter(self, record: logging.LogRecord) -> bool: + """Return False to suppress log record, True to allow it.""" + message = record.getMessage() + + # Suppress the misleading "Session termination failed: 202" message + # HTTP 202 is the correct response for MCP session close notifications + is_termination_message = "Session termination failed" in message + has_202_code = "202" in message + + return not (is_termination_message and has_202_code) + + +# Apply filter to MCP SDK loggers to suppress misleading session messages +def _configure_mcp_logging() -> None: + """Configure MCP SDK logging to suppress misleading messages.""" + mcp_loggers = [ + "mcp", + "mcp.client", + "mcp.client.session", + "mcp.client.sse", + "mcp.client.stdio", + "mcp.client.streamable_http", + ] + + session_filter = MCPSessionFilter() + for logger_name in mcp_loggers: + mcp_logger = logging.getLogger(logger_name) + mcp_logger.addFilter(session_filter) + + +# Configure MCP logging on module import +_configure_mcp_logging() + +# ============================================================================= +# CONFIGURATION CONSTANTS +# ============================================================================= + +# Default Arcade API base URL (production) +ARCADE_API_BASE_URL = "https://api.arcade.dev" + +# ============================================================================= +# TOOL CACHE - Prevents redundant connections to the same MCP source +# Uses asyncio locks to prevent concurrent loads to the same MCP source +# ============================================================================= + +# Cache for loaded tools: key is (url, headers_hash), value is list of tools +_tools_cache: dict[str, list[dict[str, Any]]] = {} + +# Per-key asyncio locks to prevent concurrent loads to same source +_cache_locks: dict[str, asyncio.Lock] = {} + + +def _make_cache_key(url: str, headers: dict[str, str] | None) -> str: + """Create a cache key from URL and headers.""" + headers_str = str(sorted((headers or {}).items())) + return f"{url}|{headers_str}" + + +# Lock acquisition timeout (seconds) - prevents indefinite hangs +LOCK_TIMEOUT_SECONDS = 60 + + +async def _get_cache_lock(cache_key: str) -> asyncio.Lock: + """Get or create an asyncio lock for the given cache key.""" + if cache_key not in _cache_locks: + _cache_locks[cache_key] = asyncio.Lock() + return _cache_locks[cache_key] + + +async def _acquire_lock_with_timeout( + lock: asyncio.Lock, timeout: float = LOCK_TIMEOUT_SECONDS +) -> bool: + """Acquire a lock with timeout. Returns True if acquired, False on timeout.""" + try: + await asyncio.wait_for(lock.acquire(), timeout=timeout) + except asyncio.TimeoutError: + return False + else: + return True + + +def clear_tools_cache() -> None: + """Clear the tools cache. 
Useful for testing or forcing fresh connections.""" + _tools_cache.clear() + _cache_locks.clear() + + +def _get_arcade_base_url() -> str: + """Get the Arcade API base URL, checking env var at runtime.""" + return os.environ.get("ARCADE_API_BASE_URL", ARCADE_API_BASE_URL) + + +# ============================================================================= +# MCP SDK IMPORT +# ============================================================================= + + +def _require_mcp() -> tuple[Any, Any, Any, Any, Any]: + """ + Import MCP SDK with a helpful error message. + + Returns: + (ClientSession, StdioServerParameters, stdio_client, sse_client, streamablehttp_client) + """ + try: + import mcp + import mcp.client.sse as mcp_client_sse + import mcp.client.stdio as mcp_client_stdio + import mcp.client.streamable_http as mcp_client_http + + ClientSession = mcp.ClientSession + StdioServerParameters = mcp.StdioServerParameters + stdio_client = mcp_client_stdio.stdio_client + sse_client = mcp_client_sse.sse_client + streamablehttp_client = mcp_client_http.streamablehttp_client + + except ImportError as e: + raise ImportError( + "MCP SDK is required for arcade-evals. " + "Install with: pip install 'arcade-mcp[evals]' or pip install mcp" + ) from e + + return ClientSession, StdioServerParameters, stdio_client, sse_client, streamablehttp_client + + +# ============================================================================= +# UTILITIES +# ============================================================================= + + +def _tool_to_dict(tool: Any) -> dict[str, Any]: + """Convert an MCP Tool object to the MCP-style dict format used by EvalSuite.""" + return { + "name": tool.name, + "description": tool.description or "", + "inputSchema": tool.inputSchema, + } + + +def _ensure_mcp_path(url: str) -> str: + """Ensure the URL path ends with '/mcp' (without duplicating). + + Preserves query strings and fragments. + """ + parts = urlsplit(url) + path = (parts.path or "").rstrip("/") + + # If any path segment is already "mcp" (e.g. "/mcp" or "/mcp/{slug}" or "/foo/mcp"), + # treat it as already pointing at an MCP endpoint. + segments = [seg for seg in path.split("/") if seg] + if "mcp" in segments: + normalized_path = "/" + "/".join(segments) if segments else "" + return urlunsplit(( + parts.scheme, + parts.netloc, + normalized_path, + parts.query, + parts.fragment, + )) + + new_path = (f"{path}/mcp" if path else "/mcp") if path != "" else "/mcp" + return urlunsplit((parts.scheme, parts.netloc, new_path, parts.query, parts.fragment)) + + +def _build_arcade_mcp_url(gateway_slug: str | None, base_url: str) -> str: + """Build the Arcade MCP gateway URL.""" + base = base_url.rstrip("/") + if gateway_slug: + return f"{base}/mcp/{gateway_slug}" + return f"{base}/mcp" + + +# ============================================================================= +# PUBLIC API (async-only) +# ============================================================================= + + +async def load_from_stdio_async( + command: list[str], + *, + env: dict[str, str] | None = None, + timeout: int = 10, +) -> list[dict[str, Any]]: + """ + Load tools from an MCP server via stdio. + + Results are cached by command to avoid starting multiple subprocesses + for the same server. Concurrent requests for the same command will wait + for the first request to complete and share the result. + + Args: + command: Command to run the MCP server (e.g., ["python", "server.py"]). + env: Additional environment variables to pass to the server. 
+ timeout: Timeout in seconds (not used by MCP SDK, kept for API compatibility). + + Returns: + List of tool definitions in MCP format. + """ + if not command: + return [] + + del timeout # MCP SDK manages timeouts internally + + cache_key = f"stdio|{' '.join(command)}|{sorted((env or {}).items())!s}" + + # Fast path: check cache without lock (no locking overhead for cache hits) + if cache_key in _tools_cache: + logger.debug(f"Using cached tools for stdio: {command[0]}") + return _tools_cache[cache_key].copy() + + # Cache miss - acquire lock and check again (double-checked locking) + lock = await _get_cache_lock(cache_key) + if not await _acquire_lock_with_timeout(lock): + raise TimeoutError(f"Timeout waiting for lock on stdio: {command[0]}") + + try: + # Re-check cache (another request may have populated it while we waited) + if cache_key in _tools_cache: + logger.debug(f"Using cached tools for stdio: {command[0]}") + return _tools_cache[cache_key].copy() + + ClientSession, StdioServerParameters, stdio_client, _, _ = _require_mcp() + + process_env = os.environ.copy() + if env: + process_env.update(env) + + server_params = StdioServerParameters( + command=command[0], + args=command[1:] if len(command) > 1 else [], + env=process_env, + ) + async with ( + stdio_client(server_params) as (read, write), + ClientSession(read, write) as session, + ): + await session.initialize() + result = await session.list_tools() + tools = [_tool_to_dict(tool) for tool in result.tools] + + # Cache the result + _tools_cache[cache_key] = tools.copy() + return tools + finally: + lock.release() + + +async def load_mcp_remote_async( + url: str, + *, + headers: dict[str, str] | None = None, + timeout: int = 10, + use_sse: bool = False, +) -> list[dict[str, Any]]: + """ + Load tools from a remote MCP server via URL (HTTP or SSE transport). + + Results are cached to avoid redundant connections when multiple models + load the same MCP source. Concurrent requests for the same URL will wait + for the first request to complete and share the result. + + Args: + url: URL of the MCP server. + headers: Additional headers to send with the request. + timeout: Timeout in seconds (not used by MCP SDK, kept for API compatibility). + use_sse: Whether to use SSE transport. If False, uses streamable-http. + + Returns: + List of tool definitions in MCP format. 
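+
+    Example (illustrative; the URL and token are placeholders):
+        tools = await load_mcp_remote_async(
+            "https://example.com/mcp",
+            headers={"Authorization": "Bearer <token>"},
+        )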
+ """ + del timeout # MCP SDK manages timeout internally + + url = _ensure_mcp_path(url) + cache_key = _make_cache_key(url, headers) + + # Fast path: check cache without lock (no locking overhead for cache hits) + if cache_key in _tools_cache: + logger.debug(f"Using cached tools for {url}") + return _tools_cache[cache_key].copy() + + # Cache miss - acquire lock and check again (double-checked locking) + lock = await _get_cache_lock(cache_key) + if not await _acquire_lock_with_timeout(lock): + raise TimeoutError(f"Timeout waiting for lock on HTTP: {url}") + + try: + # Re-check cache (another request may have populated it while we waited) + if cache_key in _tools_cache: + logger.debug(f"Using cached tools for {url}") + return _tools_cache[cache_key].copy() + + # Load MCP SDK (deferred import) + ClientSession, _, _, sse_client, streamablehttp_client = _require_mcp() + + # Load from MCP server + tools: list[dict[str, Any]] = [] + + if use_sse: + async with ( + sse_client(url, headers=headers) as (read, write), + ClientSession(read, write) as session, + ): + await session.initialize() + result = await session.list_tools() + tools = [_tool_to_dict(tool) for tool in result.tools] + else: + async with ( + streamablehttp_client(url, headers=headers) as (read, write, _), + ClientSession(read, write) as session, + ): + await session.initialize() + result = await session.list_tools() + tools = [_tool_to_dict(tool) for tool in result.tools] + + # Cache the result + _tools_cache[cache_key] = tools.copy() + return tools + finally: + lock.release() + + +async def load_arcade_mcp_gateway_async( + gateway_slug: str | None = None, + *, + arcade_api_key: str | None = None, + arcade_user_id: str | None = None, + base_url: str | None = None, + timeout: int = 10, +) -> list[dict[str, Any]]: + """ + Load tools from an Arcade MCP gateway. + + Args: + gateway_slug: Optional gateway slug (if None, connects to base MCP endpoint). + arcade_api_key: Arcade API key (defaults to ARCADE_API_KEY env var). + arcade_user_id: Arcade user ID (defaults to ARCADE_USER_ID env var). + base_url: Arcade API base URL (defaults to ARCADE_API_BASE_URL env var or production). + timeout: Timeout in seconds (not used by MCP SDK, kept for API compatibility). + + Returns: + List of tool definitions in MCP format (deduplicated by name). 
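+
+    Example (illustrative; the gateway slug is a placeholder and credentials come from
+    ARCADE_API_KEY / ARCADE_USER_ID env vars):
+        tools = await load_arcade_mcp_gateway_async("my-gateway")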
+ """ + api_key = arcade_api_key or os.environ.get("ARCADE_API_KEY") + user_id = arcade_user_id or os.environ.get("ARCADE_USER_ID") + + headers: dict[str, str] = {} + if api_key: + # Arcade Gateway expects "Bearer " format + if api_key.startswith("Bearer "): + headers["Authorization"] = api_key + else: + headers["Authorization"] = f"Bearer {api_key}" + if user_id: + # Note: Header is "Arcade-User-Id" (not "Arcade-User-ID") + headers["Arcade-User-Id"] = user_id + + # Use provided base_url or check env var at runtime + effective_base_url = base_url or _get_arcade_base_url() + url = _build_arcade_mcp_url(gateway_slug, effective_base_url) + tools = await load_mcp_remote_async(url, headers=headers, timeout=timeout) + + # Deduplicate tools by name (gateway may return duplicates) + seen: dict[str, dict[str, Any]] = {} + for tool in tools: + name = tool.get("name") + if name and name not in seen: + seen[name] = tool + return list(seen.values()) + + +async def load_stdio_arcade_async( + command: list[str], + *, + arcade_api_key: str | None = None, + arcade_user_id: str | None = None, + tool_secrets: dict[str, str] | None = None, + timeout: int = 10, +) -> list[dict[str, Any]]: + """ + Load tools from an Arcade MCP server via stdio. + + Convenience wrapper that sets Arcade env vars and delegates to `load_from_stdio_async`. + + Args: + command: Command to run the MCP server (e.g., ["python", "server.py"]). + arcade_api_key: Arcade API key (defaults to ARCADE_API_KEY env var). + arcade_user_id: Arcade user ID (defaults to ARCADE_USER_ID env var). + tool_secrets: Additional secrets to pass as environment variables. + timeout: Timeout in seconds. + + Returns: + List of tool definitions in MCP format. + """ + env: dict[str, str] = {} + + if arcade_api_key: + env["ARCADE_API_KEY"] = arcade_api_key + elif "ARCADE_API_KEY" in os.environ: + env["ARCADE_API_KEY"] = os.environ["ARCADE_API_KEY"] + + if arcade_user_id: + env["ARCADE_USER_ID"] = arcade_user_id + elif "ARCADE_USER_ID" in os.environ: + env["ARCADE_USER_ID"] = os.environ["ARCADE_USER_ID"] + + if tool_secrets: + env.update(tool_secrets) + + return await load_from_stdio_async(command, timeout=timeout, env=env if env else None) diff --git a/libs/arcade-evals/arcade_evals/weights.py b/libs/arcade-evals/arcade_evals/weights.py new file mode 100644 index 000000000..33bb979d8 --- /dev/null +++ b/libs/arcade-evals/arcade_evals/weights.py @@ -0,0 +1,221 @@ +""" +Weight definitions and normalization for arcade-evals. + +This module contains: +- FuzzyWeight enum for qualitative weight assignment +- Weight type alias (float | FuzzyWeight) +- Normalization functions for critic weights +- Validation utilities for weight constraints +""" + +from enum import Enum +from typing import TYPE_CHECKING + +from arcade_evals.errors import WeightError + +if TYPE_CHECKING: + from arcade_evals.critic import Critic + + +def _is_placeholder_critic(critic: "Critic") -> bool: + """ + Check if a critic is a placeholder (like NoneCritic). + + Uses duck typing via the _is_placeholder class attribute to avoid + circular imports between weights.py and critic.py. + """ + return getattr(critic, "_is_placeholder", False) + + +class FuzzyWeight(Enum): + """ + Qualitative weight buckets for critic importance. + + Instead of manually calculating float weights, use these qualitative + buckets to express relative importance. Weights are auto-normalized + using Softmax-inspired scaling. + + Example: + >>> critics = [ + ... BinaryCritic(critic_field="owner", weight=FuzzyWeight.HIGH), + ... 
BinaryCritic(critic_field="state", weight=FuzzyWeight.LOW), + ... ] + # HIGH (5) gets 62.5% weight, LOW (3) gets 37.5% weight + + Weight Buckets (linear scale, uniform increment of 1): + - MINIMAL: 1 - Almost negligible, rarely affects outcome + - VERY_LOW: 2 - Rarely important, edge case checking + - LOW: 3 - Minor importance + - MEDIUM: 4 - Standard importance (default) + - HIGH: 5 - Important parameter + - VERY_HIGH: 6 - Critical, must-match parameter + - CRITICAL: 7 - Absolutely essential, highest priority + """ + + MINIMAL = 1 + VERY_LOW = 2 + LOW = 3 + MEDIUM = 4 + HIGH = 5 + VERY_HIGH = 6 + CRITICAL = 7 + + +# Type alias for weight parameter +Weight = float | FuzzyWeight + + +def normalize_fuzzy_weights(critics: list["Critic"]) -> list[float]: + """ + Normalize a list of critic weights to sum to 1.0. + + Uses Softmax-inspired normalization: each weight is divided by the + sum of all weights, ensuring: + 1. All weights sum to exactly 1.0 + 2. Relative proportions are preserved + + Args: + critics: List of critics with weight attributes. + Weights can be float or FuzzyWeight. + + Returns: + List of normalized float weights in the same order as input critics. + + Example: + >>> from arcade_evals.critic import BinaryCritic + >>> critics = [ + ... BinaryCritic("a", FuzzyWeight.HIGH), + ... BinaryCritic("b", FuzzyWeight.LOW), + ... ] + >>> normalize_fuzzy_weights(critics) + [0.625, 0.375] # HIGH (5) / (5 + 3), LOW (3) / (5 + 3) + """ + if not critics: + return [] + + # Extract raw weight values (convert FuzzyWeight to float) + raw_weights: list[float] = [] + for critic in critics: + if isinstance(critic.weight, FuzzyWeight): + raw_weights.append(float(critic.weight.value)) + else: + raw_weights.append(float(critic.weight)) + + # Calculate total for normalization + total = sum(raw_weights) + if total <= 0: + # Edge case: all weights are zero or negative + # Return zeros to indicate no scoring should occur + return [0.0] * len(critics) + + # Normalize weights (simple division by sum) + return [w / total for w in raw_weights] + + +def resolve_weight(weight: Weight) -> float: + """ + Resolve a Weight value to a float. + + Used when a single weight needs to be resolved without full normalization. + + Args: + weight: Either a float or FuzzyWeight enum. + + Returns: + Float weight value. + """ + if isinstance(weight, FuzzyWeight): + return weight.value + return float(weight) + + +# ============================================================================= +# Critic Weight Validation and Normalization +# ============================================================================= + + +def validate_and_normalize_critic_weights(critics: list["Critic"]) -> None: + """ + Validate and normalize critic weights in-place. + + If any critic uses FuzzyWeight, all weights are normalized using + Softmax-inspired scaling to sum to 1.0. Otherwise, validates that + all float weights are non-negative. + + This function modifies critics in-place, setting their `weight` attribute + to the normalized float value. The original weight is preserved in + `_original_weight` for FuzzyWeight critics. + + Args: + critics: List of critics to validate and normalize. + + Raises: + WeightError: If any float weight is negative. + + Example: + >>> critics = [ + ... BinaryCritic(critic_field="a", weight=FuzzyWeight.HIGH), + ... BinaryCritic(critic_field="b", weight=FuzzyWeight.LOW), + ... 
] + >>> validate_and_normalize_critic_weights(critics) + >>> critics[0].weight # Now normalized float + 0.625 + """ + if not critics: + return + + # Check if any critic uses FuzzyWeight + has_fuzzy = any(isinstance(c.weight, FuzzyWeight) for c in critics) + + if has_fuzzy: + _normalize_fuzzy_critic_weights(critics) + else: + _validate_float_critic_weights(critics) + + +def _normalize_fuzzy_critic_weights(critics: list["Critic"]) -> None: + """ + Normalize critic weights when FuzzyWeight is used. + + Filters out placeholder critics (like NoneCritic, which always has weight=0) + and normalizes the remaining critics' weights to sum to 1.0. + + Args: + critics: List of critics to normalize (modified in-place). + """ + # Filter out placeholder critics for normalization (they keep weight=0) + non_placeholder_critics = [c for c in critics if not _is_placeholder_critic(c)] + + if not non_placeholder_critics: + return + + normalized = normalize_fuzzy_weights(non_placeholder_critics) + + for critic, norm_weight in zip(non_placeholder_critics, normalized): + # Store original weight for reference + critic._original_weight = critic.weight # type: ignore[attr-defined] + # Set normalized weight for evaluation + critic.weight = norm_weight + + +def _validate_float_critic_weights(critics: list["Critic"]) -> None: + """ + Validate that all float critic weights are non-negative. + + This is the legacy validation path used when no FuzzyWeight is present. + Float weights are allowed to be any non-negative value; normalization + happens implicitly through the scoring calculation. + + Args: + critics: List of critics to validate. + + Raises: + WeightError: If any weight is negative. + """ + for critic in critics: + if _is_placeholder_critic(critic): + continue + + weight = resolve_weight(critic.weight) + if weight < 0: + raise WeightError(f"Critic weight must be non-negative, got {weight}") diff --git a/libs/tests/arcade_evals/__init__.py b/libs/tests/arcade_evals/__init__.py new file mode 100644 index 000000000..7bd405114 --- /dev/null +++ b/libs/tests/arcade_evals/__init__.py @@ -0,0 +1 @@ +"""Make arcade_evals tests a package to avoid pytest module name collisions.""" diff --git a/libs/tests/arcade_evals/test_capture_execution.py b/libs/tests/arcade_evals/test_capture_execution.py new file mode 100644 index 000000000..7bd820655 --- /dev/null +++ b/libs/tests/arcade_evals/test_capture_execution.py @@ -0,0 +1,105 @@ +"""Tests for capture mode execution.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest +from arcade_evals import EvalSuite + +# Mark all tests in this module as requiring evals dependencies +pytestmark = pytest.mark.evals + + +class TestCaptureMode: + """Tests for EvalSuite.capture() method.""" + + @pytest.mark.asyncio + async def test_capture_records_tool_calls_without_scoring(self) -> None: + """Test that capture mode records tool calls without evaluation.""" + suite = EvalSuite(name="test", system_message="test") + suite.add_tool_definitions([ + {"name": "search", "description": "Search", "inputSchema": { + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"] + }} + ]) + suite.add_case(name="test case", user_message="search for cats", expected_tool_calls=[]) + + mock_client = AsyncMock() + mock_tool_call = MagicMock() + mock_tool_call.id = "call_123" + mock_tool_call.function.name = "search" + mock_tool_call.function.arguments = '{"query": "cats"}' + + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + 
mock_response.choices[0].message.tool_calls = [mock_tool_call] + mock_client.chat.completions.create.return_value = mock_response + + result = await suite.capture(mock_client, "gpt-4o", provider="openai") + + # Should return CaptureResult with captured_cases + assert result.suite_name == "test" + assert result.model == "gpt-4o" + assert result.provider == "openai" + assert len(result.captured_cases) == 1 + + captured = result.captured_cases[0] + # Should have recorded the tool call + assert len(captured.tool_calls) == 1 + assert captured.tool_calls[0].name == "search" + assert captured.tool_calls[0].args == {"query": "cats"} + + @pytest.mark.asyncio + async def test_capture_raises_without_tools(self) -> None: + """Test that capture mode raises error when no tools registered.""" + suite = EvalSuite(name="test", system_message="test") + suite.add_case(name="test", user_message="test", expected_tool_calls=[]) + + mock_client = AsyncMock() + + with pytest.raises(ValueError, match="No tools registered"): + await suite.capture(mock_client, "gpt-4o", provider="openai") + + @pytest.mark.asyncio + async def test_capture_works_with_anthropic_provider(self) -> None: + """Test capture mode works with Anthropic provider.""" + suite = EvalSuite(name="test", system_message="test") + suite.add_tool_definitions([{"name": "search", "description": "Search", "inputSchema": {}}]) + suite.add_case(name="test", user_message="test", expected_tool_calls=[]) + + mock_client = AsyncMock() + mock_tool_block = MagicMock() + mock_tool_block.type = "tool_use" + mock_tool_block.name = "search" + mock_tool_block.input = {"query": "test"} + + mock_response = MagicMock() + mock_response.content = [mock_tool_block] + mock_client.messages.create.return_value = mock_response + + result = await suite.capture(mock_client, "claude-3", provider="anthropic") + + assert len(result.captured_cases) == 1 + assert len(result.captured_cases[0].tool_calls) == 1 + + @pytest.mark.asyncio + async def test_capture_respects_max_concurrent(self) -> None: + """Test that capture mode respects max_concurrent setting.""" + suite = EvalSuite(name="test", system_message="test", max_concurrent=2) + suite.add_tool_definitions([{"name": "tool1", "description": "Test", "inputSchema": {}}]) + + # Add 3 cases + for i in range(3): + suite.add_case(name=f"case{i}", user_message=f"test{i}", expected_tool_calls=[]) + + mock_client = AsyncMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.tool_calls = None + mock_client.chat.completions.create.return_value = mock_response + + result = await suite.capture(mock_client, "gpt-4o", provider="openai") + + # All 3 cases should be captured + assert len(result.captured_cases) == 3 diff --git a/libs/tests/arcade_evals/test_comparative.py b/libs/tests/arcade_evals/test_comparative.py new file mode 100644 index 000000000..6fa297b4e --- /dev/null +++ b/libs/tests/arcade_evals/test_comparative.py @@ -0,0 +1,594 @@ +"""Tests for comparative evaluation cases.""" + +import pytest +from arcade_evals import EvalSuite, ExpectedMCPToolCall +from arcade_evals._evalsuite._comparative import ComparativeCaseBuilder +from arcade_evals._evalsuite._types import ( + ComparativeCase, + ExpectedToolCall, + TrackConfig, +) + +# Mark all tests in this module as requiring evals dependencies +pytestmark = pytest.mark.evals + + +class TestTrackConfig: + """Tests for TrackConfig dataclass.""" + + def test_create_track_config(self) -> None: + """Test creating a TrackConfig.""" + expected: 
list[ExpectedToolCall | ExpectedMCPToolCall] = [ + ExpectedMCPToolCall("TestTool", args={"arg1": "value1"}) + ] + config = TrackConfig(expected_tool_calls=expected) + + assert config.expected_tool_calls == expected + assert config.critics == [] + + def test_create_track_config_with_critics(self) -> None: + """Test creating a TrackConfig with critics.""" + from arcade_evals.critic import Critic, SimilarityCritic + + expected: list[ExpectedToolCall | ExpectedMCPToolCall] = [ + ExpectedMCPToolCall("TestTool", args={"arg1": "value1"}) + ] + critics: list[Critic] = [SimilarityCritic(critic_field="arg1", weight=1.0)] + config = TrackConfig(expected_tool_calls=expected, critics=critics) + + assert config.expected_tool_calls == expected + assert config.critics == critics + + +class TestComparativeCase: + """Tests for ComparativeCase dataclass.""" + + def test_create_comparative_case(self) -> None: + """Test creating a ComparativeCase.""" + case = ComparativeCase( + name="test_case", + user_message="What's the weather?", + system_message="You are helpful.", + ) + + assert case.name == "test_case" + assert case.user_message == "What's the weather?" + assert case.system_message == "You are helpful." + assert case.additional_messages == [] + assert case.track_configs == {} + + def test_add_track_config(self) -> None: + """Test adding track configuration.""" + case = ComparativeCase( + name="test_case", + user_message="What's the weather?", + ) + expected: list[ExpectedToolCall | ExpectedMCPToolCall] = [ + ExpectedMCPToolCall("GetWeather", args={"city": "NYC"}) + ] + + case.add_track_config("Track1", expected) + + assert "Track1" in case.track_configs + assert case.track_configs["Track1"].expected_tool_calls == expected + + def test_add_duplicate_track_config_raises(self) -> None: + """Test that adding duplicate track config raises.""" + case = ComparativeCase( + name="test_case", + user_message="What's the weather?", + ) + expected1: list[ExpectedToolCall | ExpectedMCPToolCall] = [ + ExpectedMCPToolCall("Tool1", args={"arg": "v1"}) + ] + expected2: list[ExpectedToolCall | ExpectedMCPToolCall] = [ + ExpectedMCPToolCall("Tool2", args={"arg": "v2"}) + ] + + case.add_track_config("Track1", expected1) + + with pytest.raises(ValueError, match="already configured"): + case.add_track_config("Track1", expected2) + + def test_get_configured_tracks(self) -> None: + """Test getting list of configured tracks.""" + case = ComparativeCase( + name="test_case", + user_message="What's the weather?", + ) + track1: list[ExpectedToolCall | ExpectedMCPToolCall] = [ExpectedMCPToolCall("Tool1")] + track2: list[ExpectedToolCall | ExpectedMCPToolCall] = [ExpectedMCPToolCall("Tool2")] + case.add_track_config("Track1", track1) + case.add_track_config("Track2", track2) + + tracks = case.get_configured_tracks() + + assert tracks == ["Track1", "Track2"] + + +class TestComparativeCaseBuilder: + """Tests for ComparativeCaseBuilder fluent API.""" + + def test_builder_creates_case(self) -> None: + """Test builder creates a comparative case.""" + suite = EvalSuite(name="Test Suite", system_message="Test") + # Register a track first + suite.add_tool_definitions([{"name": "Tool1"}], track="Track1") + + builder = ComparativeCaseBuilder( + suite=suite, + name="test_case", + user_message="Test message", + system_message="System message", + ) + + assert builder.case.name == "test_case" + assert builder.case.user_message == "Test message" + assert builder.case.system_message == "System message" + + def test_builder_for_track(self) -> None: + 
"""Test builder for_track method.""" + suite = EvalSuite(name="Test Suite", system_message="Test") + suite.add_tool_definitions([{"name": "Tool1"}], track="Track1") + + builder = ComparativeCaseBuilder( + suite=suite, + name="test_case", + user_message="Test message", + ) + + result = builder.for_track( + "Track1", + expected_tool_calls=[ExpectedMCPToolCall("Tool1", args={"arg": "value"})], + ) + + assert result is builder # Returns self for chaining + assert "Track1" in builder.case.track_configs + + def test_builder_for_track_nonexistent_raises(self) -> None: + """Test for_track raises for nonexistent track.""" + suite = EvalSuite(name="Test Suite", system_message="Test") + + builder = ComparativeCaseBuilder( + suite=suite, + name="test_case", + user_message="Test message", + ) + + with pytest.raises(ValueError, match="not found"): + builder.for_track( + "NonexistentTrack", + expected_tool_calls=[ExpectedMCPToolCall("Tool1")], + ) + + def test_builder_chaining(self) -> None: + """Test builder supports method chaining.""" + suite = EvalSuite(name="Test Suite", system_message="Test") + suite.add_tool_definitions([{"name": "Tool1"}], track="Track1") + suite.add_tool_definitions([{"name": "Tool2"}], track="Track2") + + builder = ComparativeCaseBuilder( + suite=suite, + name="test_case", + user_message="Test message", + ) + + builder.for_track( + "Track1", + expected_tool_calls=[ExpectedMCPToolCall("Tool1")], + ).for_track( + "Track2", + expected_tool_calls=[ExpectedMCPToolCall("Tool2")], + ) + + assert len(builder.case.track_configs) == 2 + assert "Track1" in builder.case.track_configs + assert "Track2" in builder.case.track_configs + + def test_builder_build_empty_raises(self) -> None: + """Test build raises when no tracks configured.""" + suite = EvalSuite(name="Test Suite", system_message="Test") + + builder = ComparativeCaseBuilder( + suite=suite, + name="test_case", + user_message="Test message", + ) + + with pytest.raises(ValueError, match="No tracks configured"): + builder.build() + + +class TestEvalSuiteTrackIntegration: + """Tests for EvalSuite track integration.""" + + def test_add_tool_definitions_with_track(self) -> None: + """Test adding tool definitions to a specific track.""" + suite = EvalSuite(name="Test", system_message="Test") + + suite.add_tool_definitions( + [{"name": "TestTool", "description": "A test"}], + track="MyTrack", + ) + + tracks = suite.get_tracks() + assert "MyTrack" in tracks + assert suite.get_tool_count(track="MyTrack") == 1 + assert suite.list_tool_names(track="MyTrack") == ["TestTool"] + + def test_add_tool_definitions_multiple_tracks(self) -> None: + """Test adding tools to multiple tracks.""" + suite = EvalSuite(name="Test", system_message="Test") + + suite.add_tool_definitions([{"name": "Tool1"}], track="Track1") + suite.add_tool_definitions([{"name": "Tool2"}], track="Track2") + + assert len(suite.get_tracks()) == 2 + assert suite.list_tool_names(track="Track1") == ["Tool1"] + assert suite.list_tool_names(track="Track2") == ["Tool2"] + + def test_tracks_are_isolated(self) -> None: + """Test that tracks have isolated tool registries.""" + suite = EvalSuite(name="Test", system_message="Test") + + suite.add_tool_definitions([{"name": "Tool1"}], track="Track1") + suite.add_tool_definitions([{"name": "Tool2"}], track="Track2") + + # Each track only sees its own tools + track1_tools = suite.list_tool_names(track="Track1") + track2_tools = suite.list_tool_names(track="Track2") + + assert "Tool1" in track1_tools + assert "Tool2" not in track1_tools + assert 
"Tool2" in track2_tools + assert "Tool1" not in track2_tools + + def test_default_registry_separate_from_tracks(self) -> None: + """Test that default registry is separate from tracks.""" + suite = EvalSuite(name="Test", system_message="Test") + + # Add to default registry + suite.add_tool_definitions([{"name": "DefaultTool"}]) + # Add to track + suite.add_tool_definitions([{"name": "TrackTool"}], track="MyTrack") + + # Default registry + assert suite.get_tool_count() == 1 + assert suite.list_tool_names() == ["DefaultTool"] + + # Track registry + assert suite.get_tool_count(track="MyTrack") == 1 + assert suite.list_tool_names(track="MyTrack") == ["TrackTool"] + + def test_add_comparative_case(self) -> None: + """Test add_comparative_case method.""" + suite = EvalSuite(name="Test", system_message="Test") + suite.add_tool_definitions([{"name": "Tool1"}], track="Track1") + + builder = suite.add_comparative_case( + name="weather_query", + user_message="What's the weather in NYC?", + ) + + assert builder is not None + assert builder.case.name == "weather_query" + + # Configure track and verify + builder.for_track( + "Track1", + expected_tool_calls=[ExpectedMCPToolCall("Tool1", args={"city": "NYC"})], + ) + + assert "Track1" in builder.case.track_configs + + def test_add_comparative_case_uses_suite_defaults(self) -> None: + """Test add_comparative_case uses suite defaults.""" + from arcade_evals import EvalRubric + + rubric = EvalRubric(fail_threshold=0.9) + suite = EvalSuite( + name="Test", + system_message="Default system message", + rubric=rubric, + ) + + builder = suite.add_comparative_case( + name="test", + user_message="Test message", + ) + + assert builder.case.system_message == "Default system message" + assert builder.case.rubric == rubric + + def test_get_tracks_empty(self) -> None: + """Test get_tracks when no tracks registered.""" + suite = EvalSuite(name="Test", system_message="Test") + + assert suite.get_tracks() == [] + + def test_method_chaining_still_works(self) -> None: + """Test that method chaining still works with track parameter.""" + suite = EvalSuite(name="Test", system_message="Test") + + # Chaining should still work + result = suite.add_tool_definitions( + [{"name": "Tool1"}], + track="Track1", + ).add_tool_definitions( + [{"name": "Tool2"}], + track="Track2", + ) + + assert result is suite + assert len(suite.get_tracks()) == 2 + + +class TestRunComparative: + """Tests for EvalSuite.run_comparative method.""" + + @pytest.mark.asyncio + async def test_run_comparative_no_cases_raises(self) -> None: + """Test run_comparative raises when no cases defined.""" + from unittest.mock import AsyncMock + + suite = EvalSuite(name="Test", system_message="Test") + client = AsyncMock() + + with pytest.raises(ValueError, match="No comparative cases defined"): + await suite.run_comparative(client, "gpt-4o") + + @pytest.mark.asyncio + async def test_run_comparative_missing_track_raises(self) -> None: + """Test builder raises when track doesn't exist (fail-fast validation).""" + suite = EvalSuite(name="Test", system_message="Test") + suite.add_tool_definitions([{"name": "Tool1"}], track="Track1") + + # Builder validates tracks exist at configuration time (fail-fast) + builder = suite.add_comparative_case( + name="test_case", + user_message="Test", + ).for_track( + "Track1", + expected_tool_calls=[ExpectedMCPToolCall("Tool1")], + ) + + # Attempting to add non-existent track should raise immediately + with pytest.raises(ValueError, match="Track 'NonExistentTrack' not found"): + 
builder.for_track( + "NonExistentTrack", + expected_tool_calls=[ExpectedMCPToolCall("Tool2")], + ) + + @pytest.mark.asyncio + async def test_run_comparative_no_tracks_configured_raises(self) -> None: + """Test run_comparative raises when builder has no tracks.""" + from unittest.mock import AsyncMock + + suite = EvalSuite(name="Test", system_message="Test") + # Add case but don't configure any tracks + suite.add_comparative_case( + name="test_case", + user_message="Test", + ) + + client = AsyncMock() + with pytest.raises(ValueError, match="No tracks configured"): + await suite.run_comparative(client, "gpt-4o") + + @pytest.mark.asyncio + async def test_run_comparative_basic_execution(self) -> None: + """Test run_comparative executes cases across tracks.""" + from unittest.mock import AsyncMock, MagicMock + + suite = EvalSuite(name="Test Suite", system_message="You are helpful") + + # Register tools for two tracks + suite.add_tool_definitions( + [{"name": "GetWeather", "description": "Get weather"}], + track="Track1", + ) + suite.add_tool_definitions( + [{"name": "FetchWeather", "description": "Fetch weather"}], + track="Track2", + ) + + # Add comparative case + suite.add_comparative_case( + name="weather_query", + user_message="What's the weather?", + ).for_track( + "Track1", + expected_tool_calls=[ExpectedMCPToolCall("GetWeather", args={"city": "NYC"})], + ).for_track( + "Track2", + expected_tool_calls=[ExpectedMCPToolCall("FetchWeather", args={"city": "NYC"})], + ) + + # Mock OpenAI client + client = AsyncMock() + mock_response = MagicMock() + mock_message = MagicMock() + mock_tool_call = MagicMock() + mock_tool_call.function.name = "GetWeather" + mock_tool_call.function.arguments = '{"city": "NYC"}' + mock_message.tool_calls = [mock_tool_call] + mock_response.choices = [MagicMock(message=mock_message)] + client.chat.completions.create.return_value = mock_response + + # Run comparative evaluation + results = await suite.run_comparative(client, "gpt-4o", provider="openai") + + # Verify structure + assert "Track1" in results + assert "Track2" in results + assert results["Track1"]["model"] == "gpt-4o" + assert results["Track1"]["suite_name"] == "Test Suite" + assert results["Track1"]["track_name"] == "Track1" + assert len(results["Track1"]["cases"]) == 1 + assert len(results["Track2"]["cases"]) == 1 + + # Verify case results + track1_case = results["Track1"]["cases"][0] + assert track1_case["name"] == "weather_query" + assert track1_case["track"] == "Track1" + assert track1_case["input"] == "What's the weather?" 
+ assert "evaluation" in track1_case + + @pytest.mark.asyncio + async def test_run_comparative_multiple_cases(self) -> None: + """Test run_comparative with multiple comparative cases.""" + from unittest.mock import AsyncMock, MagicMock + + suite = EvalSuite(name="Test", system_message="Test") + suite.add_tool_definitions([{"name": "Tool1"}], track="Track1") + suite.add_tool_definitions([{"name": "Tool2"}], track="Track2") + + # Add two comparative cases + suite.add_comparative_case( + name="case1", + user_message="Query 1", + ).for_track( + "Track1", + expected_tool_calls=[ExpectedMCPToolCall("Tool1")], + ).for_track( + "Track2", + expected_tool_calls=[ExpectedMCPToolCall("Tool2")], + ) + + suite.add_comparative_case( + name="case2", + user_message="Query 2", + ).for_track( + "Track1", + expected_tool_calls=[ExpectedMCPToolCall("Tool1")], + ) + + # Mock client + client = AsyncMock() + mock_response = MagicMock() + mock_message = MagicMock() + mock_message.tool_calls = [] + mock_response.choices = [MagicMock(message=mock_message)] + client.chat.completions.create.return_value = mock_response + + results = await suite.run_comparative(client, "gpt-4o") + + # Verify both tracks present + assert "Track1" in results + assert "Track2" in results + + # Track1 should have 2 cases, Track2 should have 1 case + assert len(results["Track1"]["cases"]) == 2 + assert len(results["Track2"]["cases"]) == 1 + + # Verify case names + track1_names = {case["name"] for case in results["Track1"]["cases"]} + assert track1_names == {"case1", "case2"} + track2_names = {case["name"] for case in results["Track2"]["cases"]} + assert track2_names == {"case1"} + + @pytest.mark.asyncio + async def test_run_comparative_anthropic_provider(self) -> None: + """Test run_comparative with Anthropic provider.""" + from unittest.mock import AsyncMock, MagicMock + + suite = EvalSuite(name="Test", system_message="Test") + suite.add_tool_definitions([{"name": "TestTool"}], track="Track1") + + suite.add_comparative_case( + name="test", + user_message="Test query", + ).for_track( + "Track1", + expected_tool_calls=[ExpectedMCPToolCall("TestTool")], + ) + + # Mock Anthropic client + client = AsyncMock() + mock_response = MagicMock() + mock_response.content = [] + client.messages.create.return_value = mock_response + + results = await suite.run_comparative(client, "claude-3-5-sonnet", provider="anthropic") + + assert "Track1" in results + assert len(results["Track1"]["cases"]) == 1 + # Verify Anthropic client was called + assert client.messages.create.called + + @pytest.mark.asyncio + async def test_run_comparative_track_deleted_after_config(self) -> None: + """Test run_comparative when track is deleted after case configuration. + + This tests the execution-time validation that ensures tracks still exist + when run_comparative is called (edge case for programmatic track deletion). 
+ """ + from unittest.mock import AsyncMock + + suite = EvalSuite(name="Test", system_message="Test") + + # Register track and configure case + suite.add_tool_definitions([{"name": "Tool1"}], track="Track1") + suite.add_comparative_case( + name="test_case", + user_message="Test", + ).for_track( + "Track1", + expected_tool_calls=[ExpectedMCPToolCall("Tool1")], + ) + + # Simulate track being removed (edge case - programmatic deletion) + # This bypasses builder validation but triggers run_comparative validation + suite._track_manager._tracks.clear() + + client = AsyncMock() + + # Should raise at execution time with helpful error + with pytest.raises(ValueError, match="Missing track registries.*Track1"): + await suite.run_comparative(client, "gpt-4o") + + @pytest.mark.asyncio + async def test_run_comparative_registry_none_defensive_check(self) -> None: + """Test the defensive RuntimeError if registry is None after validation. + + This tests the defensive programming check that should never trigger + in normal operation but protects against race conditions or bugs. + """ + from unittest.mock import AsyncMock + + suite = EvalSuite(name="Test", system_message="Test") + suite.add_tool_definitions([{"name": "Tool1"}], track="Track1") + + suite.add_comparative_case( + name="test_case", + user_message="Test", + ).for_track( + "Track1", + expected_tool_calls=[ExpectedMCPToolCall("Tool1")], + ) + + client = AsyncMock() + + # Patch get_registry to return None during execution loop + # has_track() will pass validation, but get_registry() will return None + # This simulates a race condition where track is deleted between validation and execution + original_has_track = suite._track_manager.has_track + + def patched_get_registry(track_name: str) -> None: + # Return None to trigger the defensive check + return None + + def patched_has_track(track_name: str) -> bool: + # Return True to pass validation + return original_has_track(track_name) + + # Apply patches using patch.object to satisfy mypy + from unittest.mock import patch + + with ( + patch.object(suite._track_manager, "get_registry", patched_get_registry), + patch.object(suite._track_manager, "has_track", patched_has_track), + ): + # Should raise RuntimeError (defensive check) + with pytest.raises(RuntimeError, match="Registry.*unexpectedly None after validation"): + await suite.run_comparative(client, "gpt-4o") diff --git a/libs/tests/arcade_evals/test_comparative_execution.py b/libs/tests/arcade_evals/test_comparative_execution.py new file mode 100644 index 000000000..d349a2fc4 --- /dev/null +++ b/libs/tests/arcade_evals/test_comparative_execution.py @@ -0,0 +1,249 @@ +"""Tests for comparative evaluation execution logic.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest +from arcade_evals import ( + BinaryCritic, + EvalRubric, + EvalSuite, + ExpectedMCPToolCall, +) + +# Mark all tests in this module as requiring evals dependencies +pytestmark = pytest.mark.evals + + +class TestRunComparative: + """Tests for EvalSuite.run_comparative() method.""" + + @pytest.mark.asyncio + async def test_for_track_validates_track_exists(self) -> None: + """Test that for_track raises error if track doesn't exist.""" + suite = EvalSuite(name="test", system_message="test") + + # Add tools to track1 only + suite.add_tool_definitions([{"name": "tool1", "description": "Test", "inputSchema": {}}], track="track1") + + # Try to add comparative case with track2 (doesn't exist) + case = suite.add_comparative_case(name="test", user_message="test") + 
case.for_track("track1", expected_tool_calls=[ExpectedMCPToolCall("tool1", args={})]) + + # for_track validates immediately + with pytest.raises(ValueError, match="Track.*not found"): + case.for_track("track2", expected_tool_calls=[ExpectedMCPToolCall("tool2", args={})]) + + @pytest.mark.asyncio + async def test_run_comparative_returns_track_results(self) -> None: + """Test that run_comparative returns dict with track results.""" + suite = EvalSuite(name="test", system_message="test") + + # Add tools to two tracks + suite.add_tool_definitions([{"name": "tool1", "description": "Test", "inputSchema": {}}], track="track1") + suite.add_tool_definitions([{"name": "tool2", "description": "Test", "inputSchema": {}}], track="track2") + + # Add comparative case + case = suite.add_comparative_case(name="case1", user_message="test") + case.for_track("track1", expected_tool_calls=[ExpectedMCPToolCall("tool1", args={})]) + case.for_track("track2", expected_tool_calls=[ExpectedMCPToolCall("tool2", args={})]) + + mock_client = AsyncMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.tool_calls = None + mock_client.chat.completions.create.return_value = mock_response + + result = await suite.run_comparative(mock_client, "gpt-4o", provider="openai") + + # Should return dict with track names as keys + assert isinstance(result, dict) + assert "track1" in result + assert "track2" in result + + # Each track should have model, suite_name, track_name, cases + assert result["track1"]["model"] == "gpt-4o" + assert result["track1"]["suite_name"] == "test" + assert result["track1"]["track_name"] == "track1" + assert "cases" in result["track1"] + assert len(result["track1"]["cases"]) == 1 + + @pytest.mark.asyncio + async def test_run_comparative_raises_without_comparative_cases(self) -> None: + """Test that run_comparative raises error when no comparative cases defined.""" + suite = EvalSuite(name="test", system_message="test") + suite.add_tool_definitions([{"name": "tool1", "description": "Test", "inputSchema": {}}]) + + mock_client = AsyncMock() + + with pytest.raises(ValueError, match="No comparative cases defined"): + await suite.run_comparative(mock_client, "gpt-4o", provider="openai") + + @pytest.mark.asyncio + async def test_run_comparative_respects_max_concurrent(self) -> None: + """Test that run_comparative respects max_concurrent setting.""" + suite = EvalSuite(name="test", system_message="test", max_concurrent=2) + + # Add tools + suite.add_tool_definitions([{"name": "tool1", "description": "Test", "inputSchema": {}}], track="track1") + + # Add 3 cases + for i in range(3): + case = suite.add_comparative_case(name=f"case{i}", user_message=f"test{i}") + case.for_track("track1", expected_tool_calls=[]) + + mock_client = AsyncMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.tool_calls = None + mock_client.chat.completions.create.return_value = mock_response + + # Semaphore with max_concurrent=2 will be used + result = await suite.run_comparative(mock_client, "gpt-4o", provider="openai") + + # All cases should complete + assert len(result["track1"]["cases"]) == 3 + + @pytest.mark.asyncio + async def test_run_comparative_with_anthropic_provider(self) -> None: + """Test run_comparative works with Anthropic provider.""" + suite = EvalSuite(name="test", system_message="test") + + suite.add_tool_definitions([{"name": "search", "description": "Search", "inputSchema": {}}], track="track1") + + 
case = suite.add_comparative_case(name="test", user_message="search for cats") + case.for_track("track1", expected_tool_calls=[ExpectedMCPToolCall("search", args={"query": "cats"})]) + + mock_client = AsyncMock() + mock_response = MagicMock() + mock_response.content = [] + mock_client.messages.create.return_value = mock_response + + result = await suite.run_comparative(mock_client, "claude-3", provider="anthropic") + + assert "track1" in result + assert result["track1"]["model"] == "claude-3" + + +class TestComparativeCaseBuilder: + """Tests for ComparativeCaseBuilder fluent API.""" + + def test_for_track_returns_builder_for_chaining(self) -> None: + """Test that for_track returns builder for method chaining.""" + suite = EvalSuite(name="test", system_message="test") + suite.add_tool_definitions([{"name": "t1", "description": "Test", "inputSchema": {}}], track="track1") + suite.add_tool_definitions([{"name": "t2", "description": "Test", "inputSchema": {}}], track="track2") + + builder = suite.add_comparative_case(name="test", user_message="test") + result1 = builder.for_track("track1", expected_tool_calls=[]) + result2 = result1.for_track("track2", expected_tool_calls=[]) + + # Should return same builder for chaining + assert result1 is builder + assert result2 is builder + + def test_comparative_case_with_custom_rubric(self) -> None: + """Test that comparative cases can have custom rubrics.""" + suite = EvalSuite(name="test", system_message="test") + suite.add_tool_definitions([{"name": "t1", "description": "Test", "inputSchema": {}}], track="track1") + + strict_rubric = EvalRubric(fail_threshold=0.7, warn_threshold=0.9) + + # Rubric is set on the case, not per track + builder = suite.add_comparative_case(name="test", user_message="test", rubric=strict_rubric) + builder.for_track("track1", expected_tool_calls=[]) + + # Build and verify rubric is stored on the case + comp_case = builder.build() + assert comp_case.rubric == strict_rubric + + def test_for_track_with_track_specific_critics(self) -> None: + """Test that tracks can have specific critics.""" + suite = EvalSuite(name="test", system_message="test") + suite.add_tool_definitions([{"name": "t1", "description": "Test", "inputSchema": {}}], track="track1") + + critics = [BinaryCritic(critic_field="query", weight=1.0)] + + builder = suite.add_comparative_case(name="test", user_message="test") + builder.for_track("track1", expected_tool_calls=[], critics=critics) + + comp_case = builder.build() + assert comp_case.track_configs["track1"].critics == critics + + def test_build_raises_if_no_tracks_configured(self) -> None: + """Test that build() raises error if no tracks are configured.""" + suite = EvalSuite(name="test", system_message="test") + builder = suite.add_comparative_case(name="test", user_message="test") + + with pytest.raises(ValueError, match="No tracks configured"): + builder.build() + + +class TestComparativeTrackValidation: + """Tests for track validation in comparative evaluations.""" + + def test_for_track_validates_track_exists(self) -> None: + """Test that for_track validates track exists immediately.""" + suite = EvalSuite(name="test", system_message="test") + + # Register only track1 + suite.add_tool_definitions([{"name": "t1", "description": "Test", "inputSchema": {}}], track="track1") + + # Try to use nonexistent_track + case = suite.add_comparative_case(name="test", user_message="test") + case.for_track("track1", expected_tool_calls=[]) + + # for_track validates immediately + with pytest.raises(ValueError, 
match="Track.*not found"): + case.for_track("nonexistent_track", expected_tool_calls=[]) + + def test_for_track_error_lists_available_tracks(self) -> None: + """Test that error message lists available tracks.""" + suite = EvalSuite(name="test", system_message="test") + + suite.add_tool_definitions([{"name": "t1", "description": "Test", "inputSchema": {}}], track="available_track") + + case = suite.add_comparative_case(name="test", user_message="test") + + with pytest.raises(ValueError) as exc_info: + case.for_track("missing_track", expected_tool_calls=[]) + + error_msg = str(exc_info.value) + assert "missing_track" in error_msg + assert "available_track" in error_msg + + +class TestComparativeConcurrencyControl: + """Tests for concurrency control in comparative execution.""" + + @pytest.mark.asyncio + async def test_semaphore_limits_concurrent_tasks(self) -> None: + """Test that semaphore properly limits concurrent API calls.""" + suite = EvalSuite(name="test", system_message="test", max_concurrent=1) + + suite.add_tool_definitions([{"name": "t1", "description": "Test", "inputSchema": {}}], track="track1") + + # Add 3 cases - with max_concurrent=1, they should run sequentially + for i in range(3): + case = suite.add_comparative_case(name=f"case{i}", user_message="test") + case.for_track("track1", expected_tool_calls=[]) + + call_count = 0 + + async def mock_create(**kwargs): + nonlocal call_count + call_count += 1 + # Simulate some delay + import asyncio + await asyncio.sleep(0.01) + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.tool_calls = None + return mock_response + + mock_client = AsyncMock() + mock_client.chat.completions.create = mock_create + + await suite.run_comparative(mock_client, "gpt-4o", provider="openai") + + # All 3 cases should have been called + assert call_count == 3 diff --git a/libs/tests/arcade_evals/test_convenience_async.py b/libs/tests/arcade_evals/test_convenience_async.py new file mode 100644 index 000000000..968ee1937 --- /dev/null +++ b/libs/tests/arcade_evals/test_convenience_async.py @@ -0,0 +1,293 @@ +"""Tests for async MCP convenience methods in EvalSuite.""" + +from unittest.mock import patch + +import pytest +from arcade_evals import EvalSuite + +# Mark all tests in this module as requiring evals dependencies +pytestmark = pytest.mark.evals + + +class TestAddMcpServer: + """Tests for add_mcp_server async convenience method.""" + + @pytest.mark.asyncio + async def test_add_mcp_server_loads_and_registers_tools(self) -> None: + """Test that add_mcp_server loads tools and adds them to registry.""" + suite = EvalSuite(name="test", system_message="test") + + mock_tools = [ + {"name": "tool1", "description": "Test tool 1", "inputSchema": {}}, + {"name": "tool2", "description": "Test tool 2", "inputSchema": {}}, + ] + + with patch("arcade_evals._evalsuite._convenience.load_mcp_remote_async") as mock_load: + mock_load.return_value = mock_tools + + result = await suite.add_mcp_server( + "http://localhost:8000", + headers={"Authorization": "Bearer token"}, + timeout=15, + use_sse=True, + ) + + # Verify loader was called with correct args + mock_load.assert_called_once_with( + "http://localhost:8000", + timeout=15, + headers={"Authorization": "Bearer token"}, + use_sse=True, + ) + + # Verify tools were registered + tools = suite._internal_registry.list_tools_for_model("openai") + assert len(tools) == 2 + + # Verify returns self for chaining + assert result is suite + + @pytest.mark.asyncio + async def 
test_add_mcp_server_with_track(self) -> None: + """Test add_mcp_server with track parameter.""" + suite = EvalSuite(name="test", system_message="test") + + mock_tools = [{"name": "tool1", "description": "Test", "inputSchema": {}}] + + with patch("arcade_evals._evalsuite._convenience.load_mcp_remote_async") as mock_load: + mock_load.return_value = mock_tools + + await suite.add_mcp_server("http://localhost:8000", track="github") + + # Verify track was created + assert suite._track_manager.has_track("github") + + # Verify tool is in track registry + track_registry = suite._track_manager.get_registry("github") + assert track_registry is not None + track_tools = track_registry.list_tools_for_model("openai") + assert len(track_tools) == 1 + assert track_tools[0]["function"]["name"] == "tool1" + + @pytest.mark.asyncio + async def test_add_mcp_server_warns_on_empty_tools(self) -> None: + """Test that add_mcp_server warns when no tools are loaded.""" + suite = EvalSuite(name="test", system_message="test") + + with patch("arcade_evals._evalsuite._convenience.load_mcp_remote_async") as mock_load: + mock_load.return_value = [] # Empty tools + + with pytest.warns(UserWarning, match="No tools loaded from"): + await suite.add_mcp_server("http://localhost:8000") + + @pytest.mark.asyncio + async def test_add_mcp_server_handles_loader_exception(self) -> None: + """Test that add_mcp_server propagates loader exceptions.""" + suite = EvalSuite(name="test", system_message="test") + + with patch("arcade_evals._evalsuite._convenience.load_mcp_remote_async") as mock_load: + mock_load.side_effect = TimeoutError("Connection timeout") + + with pytest.raises(TimeoutError, match="Connection timeout"): + await suite.add_mcp_server("http://localhost:8000") + + +class TestAddMcpStdioServer: + """Tests for add_mcp_stdio_server async convenience method.""" + + @pytest.mark.asyncio + async def test_add_mcp_stdio_server_loads_and_registers_tools(self) -> None: + """Test that add_mcp_stdio_server loads tools and adds them to registry.""" + suite = EvalSuite(name="test", system_message="test") + + mock_tools = [ + {"name": "linear_search", "description": "Search", "inputSchema": {}}, + {"name": "linear_create", "description": "Create", "inputSchema": {}}, + ] + + with patch("arcade_evals._evalsuite._convenience.load_from_stdio_async") as mock_load: + mock_load.return_value = mock_tools + + command = ["python", "-m", "arcade_mcp_server", "stdio"] + env = {"ARCADE_API_KEY": "test_key"} + + result = await suite.add_mcp_stdio_server(command, env=env, timeout=20) + + # Verify loader was called with correct args + mock_load.assert_called_once_with(command, timeout=20, env=env) + + # Verify tools were registered + tools = suite._internal_registry.list_tools_for_model("openai") + assert len(tools) == 2 + + # Verify returns self for chaining + assert result is suite + + @pytest.mark.asyncio + async def test_add_mcp_stdio_server_with_track(self) -> None: + """Test add_mcp_stdio_server with track parameter.""" + suite = EvalSuite(name="test", system_message="test") + + mock_tools = [{"name": "tool1", "description": "Test", "inputSchema": {}}] + + with patch("arcade_evals._evalsuite._convenience.load_from_stdio_async") as mock_load: + mock_load.return_value = mock_tools + + await suite.add_mcp_stdio_server(["python", "server.py"], track="linear") + + # Verify track was created + assert suite._track_manager.has_track("linear") + + @pytest.mark.asyncio + async def test_add_mcp_stdio_server_warns_on_empty_tools(self) -> None: + """Test that 
add_mcp_stdio_server warns when no tools are loaded.""" + suite = EvalSuite(name="test", system_message="test") + + with patch("arcade_evals._evalsuite._convenience.load_from_stdio_async") as mock_load: + mock_load.return_value = [] + + with pytest.warns(UserWarning, match="No tools loaded from stdio"): + await suite.add_mcp_stdio_server(["python", "server.py"]) + + @pytest.mark.asyncio + async def test_add_mcp_stdio_server_handles_loader_exception(self) -> None: + """Test that add_mcp_stdio_server propagates loader exceptions.""" + suite = EvalSuite(name="test", system_message="test") + + with patch("arcade_evals._evalsuite._convenience.load_from_stdio_async") as mock_load: + mock_load.side_effect = TimeoutError("Stdio timeout") + + with pytest.raises(TimeoutError, match="Stdio timeout"): + await suite.add_mcp_stdio_server(["python", "server.py"]) + + +class TestAddArcadeGateway: + """Tests for add_arcade_gateway async convenience method.""" + + @pytest.mark.asyncio + async def test_add_arcade_gateway_loads_and_registers_tools(self) -> None: + """Test that add_arcade_gateway loads tools and adds them to registry.""" + suite = EvalSuite(name="test", system_message="test") + + mock_tools = [ + {"name": "Github_CreateIssue", "description": "Create issue", "inputSchema": {}}, + {"name": "Github_GetIssue", "description": "Get issue", "inputSchema": {}}, + ] + + with patch( + "arcade_evals._evalsuite._convenience.load_arcade_mcp_gateway_async" + ) as mock_load: + mock_load.return_value = mock_tools + + result = await suite.add_arcade_gateway( + "my-gateway", + arcade_api_key="test_key", + arcade_user_id="test@example.com", + base_url="https://api.arcade.dev", + timeout=10, + ) + + # Verify loader was called with correct args + mock_load.assert_called_once_with( + "my-gateway", + arcade_api_key="test_key", + arcade_user_id="test@example.com", + base_url="https://api.arcade.dev", + timeout=10, + ) + + # Verify tools were registered + tools = suite._internal_registry.list_tools_for_model("openai") + assert len(tools) == 2 + + # Verify returns self for chaining + assert result is suite + + @pytest.mark.asyncio + async def test_add_arcade_gateway_with_track(self) -> None: + """Test add_arcade_gateway with track parameter.""" + suite = EvalSuite(name="test", system_message="test") + + mock_tools = [{"name": "tool1", "description": "Test", "inputSchema": {}}] + + with patch( + "arcade_evals._evalsuite._convenience.load_arcade_mcp_gateway_async" + ) as mock_load: + mock_load.return_value = mock_tools + + await suite.add_arcade_gateway("my-gateway", track="arcade") + + # Verify track was created + assert suite._track_manager.has_track("arcade") + + @pytest.mark.asyncio + async def test_add_arcade_gateway_warns_on_empty_tools(self) -> None: + """Test that add_arcade_gateway warns when no tools are loaded.""" + suite = EvalSuite(name="test", system_message="test") + + with patch( + "arcade_evals._evalsuite._convenience.load_arcade_mcp_gateway_async" + ) as mock_load: + mock_load.return_value = [] + + with pytest.warns(UserWarning, match="No tools loaded from Arcade gateway"): + await suite.add_arcade_gateway("my-gateway") + + @pytest.mark.asyncio + async def test_add_arcade_gateway_handles_loader_exception(self) -> None: + """Test that add_arcade_gateway propagates loader exceptions.""" + suite = EvalSuite(name="test", system_message="test") + + with patch( + "arcade_evals._evalsuite._convenience.load_arcade_mcp_gateway_async" + ) as mock_load: + mock_load.side_effect = Exception("Gateway connection 
failed") + + with pytest.raises(Exception, match="Gateway connection failed"): + await suite.add_arcade_gateway("my-gateway") + + +class TestAsyncConvenienceMethodChaining: + """Tests for method chaining with async MCP methods.""" + + @pytest.mark.asyncio + async def test_chaining_multiple_mcp_sources(self) -> None: + """Test that async methods can be chained together.""" + suite = EvalSuite(name="test", system_message="test") + + mock_http_tools = [{"name": "http_tool", "description": "HTTP", "inputSchema": {}}] + mock_stdio_tools = [{"name": "stdio_tool", "description": "Stdio", "inputSchema": {}}] + mock_gateway_tools = [ + {"name": "gateway_tool", "description": "Gateway", "inputSchema": {}} + ] + + with ( + patch( + "arcade_evals._evalsuite._convenience.load_mcp_remote_async" + ) as mock_http, + patch( + "arcade_evals._evalsuite._convenience.load_from_stdio_async" + ) as mock_stdio, + patch( + "arcade_evals._evalsuite._convenience.load_arcade_mcp_gateway_async" + ) as mock_gateway, + ): + mock_http.return_value = mock_http_tools + mock_stdio.return_value = mock_stdio_tools + mock_gateway.return_value = mock_gateway_tools + + # Chain all three methods + result = await suite.add_mcp_server("http://localhost:8000") + result = await result.add_mcp_stdio_server(["python", "server.py"]) + result = await result.add_arcade_gateway("my-gateway") + + # Verify all tools were registered + tools = suite._internal_registry.list_tools_for_model("openai") + assert len(tools) == 3 + tool_names = [t["function"]["name"] for t in tools] + assert "http_tool" in tool_names + assert "stdio_tool" in tool_names + assert "gateway_tool" in tool_names + + # Verify final result is still the suite + assert result is suite diff --git a/libs/tests/arcade_evals/test_critics.py b/libs/tests/arcade_evals/test_critics.py new file mode 100644 index 000000000..bcd37e382 --- /dev/null +++ b/libs/tests/arcade_evals/test_critics.py @@ -0,0 +1,386 @@ +"""Tests for critic evaluation logic.""" + +import pytest +from arcade_evals.critic import ( + BinaryCritic, + NoneCritic, + NumericCritic, + SimilarityCritic, +) +from arcade_evals.errors import WeightError +from arcade_evals.weights import FuzzyWeight + +# Mark all tests in this module as requiring evals dependencies +pytestmark = pytest.mark.evals + + +class TestNoneCritic: + """Tests for NoneCritic placeholder.""" + + def test_none_critic_always_returns_zero_score(self) -> None: + """Test that NoneCritic always returns score 0.""" + critic = NoneCritic(critic_field="test", weight=0.0) + result = critic.evaluate("expected", "actual") + + assert result["score"] == 0.0 + assert result["match"] is None + assert result["is_criticized"] is False + + def test_none_critic_has_marker_attribute(self) -> None: + """Test that NoneCritic has _is_placeholder marker.""" + critic = NoneCritic(critic_field="test", weight=0.0) + assert hasattr(critic, "_is_placeholder") + assert critic._is_placeholder is True + + +class TestBinaryCritic: + """Tests for BinaryCritic exact equality comparisons.""" + + def test_binary_critic_exact_match_returns_full_weight(self) -> None: + """Test that exact match returns full weight as score.""" + critic = BinaryCritic(critic_field="name", weight=1.0) + result = critic.evaluate("Alice", "Alice") + + assert result["match"] is True + assert result["score"] == 1.0 + + def test_binary_critic_mismatch_returns_zero_score(self) -> None: + """Test that mismatch returns score 0.""" + critic = BinaryCritic(critic_field="name", weight=1.0) + result = 
critic.evaluate("Alice", "Bob") + + assert result["match"] is False + assert result["score"] == 0.0 + + def test_binary_critic_partial_weight(self) -> None: + """Test that partial weight is respected.""" + critic = BinaryCritic(critic_field="name", weight=0.5) + result = critic.evaluate("Alice", "Alice") + + assert result["match"] is True + assert result["score"] == 0.5 + + def test_binary_critic_cast_actual_to_expected_type(self) -> None: + """Test that actual value is cast to expected type.""" + critic = BinaryCritic(critic_field="count", weight=1.0) + # Expect int, get string + result = critic.evaluate(42, "42") + + assert result["match"] is True + assert result["score"] == 1.0 + + def test_binary_critic_none_handling(self) -> None: + """Test None value handling.""" + critic = BinaryCritic(critic_field="optional", weight=1.0) + + # None == None + result = critic.evaluate(None, None) + assert result["match"] is True + + # None != value + result = critic.evaluate(None, "value") + assert result["match"] is False + + # String "None" is cast to None + result = critic.evaluate(None, "None") + assert result["match"] is True + + +class TestNumericCritic: + """Tests for NumericCritic fuzzy numeric comparisons.""" + + def test_numeric_critic_exact_match_returns_full_score(self) -> None: + """Test that exact match returns full weight as score.""" + critic = NumericCritic( + critic_field="temperature", weight=1.0, value_range=(0.0, 100.0) + ) + result = critic.evaluate(50.0, 50.0) + + assert result["match"] is True + assert result["score"] == 1.0 + + def test_numeric_critic_close_values_high_score(self) -> None: + """Test that close values get high scores.""" + critic = NumericCritic( + critic_field="temperature", + weight=1.0, + value_range=(0.0, 100.0), + match_threshold=0.9, + ) + # Within 10% of range + result = critic.evaluate(50.0, 55.0) + + assert result["score"] >= 0.9 + assert result["match"] is True + + def test_numeric_critic_far_values_low_score(self) -> None: + """Test that far values get low scores.""" + critic = NumericCritic( + critic_field="temperature", weight=1.0, value_range=(0.0, 100.0) + ) + # Far apart + result = critic.evaluate(10.0, 90.0) + + assert result["score"] < 0.3 + assert result["match"] is False + + def test_numeric_critic_respects_match_threshold(self) -> None: + """Test that match_threshold correctly determines match status.""" + critic = NumericCritic( + critic_field="value", + weight=1.0, + value_range=(0.0, 100.0), + match_threshold=0.95, + ) + # Score is 0.9 (within 10% of range) - below 0.95 threshold + result = critic.evaluate(50.0, 60.0) + + assert result["score"] == 0.9 + assert result["match"] is False # Below threshold + + def test_numeric_critic_at_range_boundaries(self) -> None: + """Test evaluation at range boundaries.""" + critic = NumericCritic(critic_field="value", weight=1.0, value_range=(0.0, 100.0)) + + # At min boundary + result = critic.evaluate(0.0, 0.0) + assert result["match"] is True + assert result["score"] == 1.0 + + # At max boundary + result = critic.evaluate(100.0, 100.0) + assert result["match"] is True + assert result["score"] == 1.0 + + def test_numeric_critic_outside_range_handled(self) -> None: + """Test that values outside range are handled (extrapolation).""" + critic = NumericCritic(critic_field="value", weight=1.0, value_range=(0.0, 100.0)) + + # Actual is outside range + result = critic.evaluate(50.0, 150.0) + # Normalized difference will be large, score will be low or negative + assert result["score"] <= 0.0 + + def 
test_numeric_critic_partial_weight(self) -> None: + """Test that partial weight is respected.""" + critic = NumericCritic(critic_field="value", weight=0.5, value_range=(0.0, 100.0)) + result = critic.evaluate(50.0, 50.0) + + assert result["score"] == 0.5 # Perfect match * 0.5 weight + + def test_numeric_critic_invalid_range_raises_error(self) -> None: + """Test that invalid range (min >= max) raises ValueError.""" + with pytest.raises(ValueError, match="Invalid value_range"): + NumericCritic(critic_field="value", weight=1.0, value_range=(100.0, 0.0)) + + with pytest.raises(ValueError, match="Invalid value_range"): + NumericCritic(critic_field="value", weight=1.0, value_range=(50.0, 50.0)) + + +class TestSimilarityCritic: + """Tests for SimilarityCritic text similarity comparisons.""" + + def test_similarity_critic_exact_match_returns_full_score(self) -> None: + """Test that exact string match returns full weight as score.""" + critic = SimilarityCritic(critic_field="query", weight=1.0) + result = critic.evaluate("search for cats", "search for cats") + + assert result["match"] is True + assert result["score"] == 1.0 + + def test_similarity_critic_very_similar_strings_high_score(self) -> None: + """Test that very similar strings get high scores.""" + critic = SimilarityCritic( + critic_field="query", weight=1.0, similarity_threshold=0.5 + ) + result = critic.evaluate("search for cats", "search for cat") + + # Very similar (just plural difference) + assert result["score"] >= 0.5 + assert result["match"] is True + + def test_similarity_critic_different_strings_low_score(self) -> None: + """Test that different strings get low scores.""" + critic = SimilarityCritic(critic_field="query", weight=1.0) + result = critic.evaluate("search for cats", "weather in Paris") + + assert result["score"] < 0.3 + assert result["match"] is False + + def test_similarity_critic_respects_threshold(self) -> None: + """Test that similarity_threshold correctly determines match status.""" + critic = SimilarityCritic( + critic_field="query", weight=1.0, similarity_threshold=0.9 + ) + result = critic.evaluate("hello world", "hello there") + + # Similarity might be ~0.6-0.7 - below 0.9 threshold + assert result["match"] is False + + def test_similarity_critic_partial_weight(self) -> None: + """Test that partial weight is respected.""" + critic = SimilarityCritic(critic_field="query", weight=0.5) + result = critic.evaluate("test", "test") + + assert result["score"] == 0.5 # Perfect match * 0.5 weight + + def test_similarity_critic_handles_empty_strings(self) -> None: + """Test handling of empty strings.""" + critic = SimilarityCritic(critic_field="query", weight=1.0) + + # Empty == Empty + result = critic.evaluate("", "") + # TF-IDF can't compute similarity for empty strings - should handle gracefully + assert "score" in result + assert "match" in result + + def test_similarity_critic_converts_lists_to_strings(self) -> None: + """Test that lists are converted to space-separated strings.""" + critic = SimilarityCritic(critic_field="tags", weight=1.0) + + # Lists should be joined with spaces + result = critic.evaluate( + ["python", "security"], ["python", "security", "best-practices"] + ) + + # Should be comparing "python security" vs "python security best-practices" + assert "score" in result + assert result["score"] > 0.5 # Should have some similarity + + def test_similarity_critic_converts_non_strings(self) -> None: + """Test that non-string values are converted to strings.""" + critic = 
SimilarityCritic(critic_field="value", weight=1.0) + + # Numbers to strings + result = critic.evaluate(12345, 12345) + assert result["match"] is True + assert result["score"] == 1.0 + + # Dict to string + result = critic.evaluate({"key": "value"}, {"key": "value"}) + assert result["score"] > 0.8 # Should match after stringification + + def test_similarity_critic_unsupported_metric_raises_error(self) -> None: + """Test that unsupported metric raises ValueError.""" + with pytest.raises(ValueError, match="Unsupported similarity metric"): + SimilarityCritic(critic_field="query", weight=1.0, metric="hamming") + + def test_similarity_critic_requires_sklearn(self) -> None: + """Test that SimilarityCritic raises ImportError without sklearn.""" + from unittest.mock import patch + + critic = SimilarityCritic(critic_field="query", weight=1.0) + + # Patch the import inside evaluate() to simulate missing sklearn + with patch.dict("sys.modules", {"sklearn.feature_extraction.text": None}): + with pytest.raises(ImportError, match="pip install.*arcade-evals"): + critic.evaluate("test", "test2") + + +class TestCriticWeights: + """Tests for critic weight validation and FuzzyWeight support.""" + + def test_negative_weight_raises_error(self) -> None: + """Test that negative weights raise WeightError.""" + with pytest.raises(WeightError, match="non-negative"): + BinaryCritic(critic_field="test", weight=-0.5) + + def test_fuzzy_weight_skips_validation(self) -> None: + """Test that FuzzyWeight skips validation (normalized later).""" + # Should not raise even though FuzzyWeight.CRITICAL might be > 1 + critic = BinaryCritic(critic_field="test", weight=FuzzyWeight.CRITICAL) + assert critic.weight == FuzzyWeight.CRITICAL + + def test_zero_weight_allowed(self) -> None: + """Test that zero weight is allowed.""" + critic = BinaryCritic(critic_field="test", weight=0.0) + assert critic.weight == 0.0 + + def test_large_weight_allowed(self) -> None: + """Test that weights > 1.0 are allowed (softmax normalization handles).""" + critic = BinaryCritic(critic_field="test", weight=5.0) + assert critic.weight == 5.0 + + def test_resolved_weight_returns_float(self) -> None: + """Test that resolved_weight property returns float.""" + critic = BinaryCritic(critic_field="test", weight=0.8) + assert isinstance(critic.resolved_weight, float) + assert critic.resolved_weight == 0.8 + + def test_resolved_weight_with_fuzzy_weight(self) -> None: + """Test resolved_weight with FuzzyWeight enum.""" + critic = BinaryCritic(critic_field="test", weight=FuzzyWeight.HIGH) + # FuzzyWeight.HIGH has value 5 (int) + assert isinstance(critic.resolved_weight, (int, float)) + assert critic.resolved_weight > 0.0 + + +class TestCriticEdgeCases: + """Tests for edge cases in critic evaluation.""" + + def test_binary_critic_with_complex_types(self) -> None: + """Test BinaryCritic with dicts and lists.""" + critic = BinaryCritic(critic_field="config", weight=1.0) + + # Dict comparison + result = critic.evaluate({"a": 1, "b": 2}, {"a": 1, "b": 2}) + assert result["match"] is True + + # List comparison + result = critic.evaluate([1, 2, 3], [1, 2, 3]) + assert result["match"] is True + + # Nested structures + result = critic.evaluate({"list": [1, 2]}, {"list": [1, 2]}) + assert result["match"] is True + + def test_numeric_critic_with_string_numbers(self) -> None: + """Test NumericCritic casts string numbers to float.""" + critic = NumericCritic(critic_field="value", weight=1.0, value_range=(0.0, 100.0)) + result = critic.evaluate("50.0", "50.0") + + assert 
result["match"] is True + assert result["score"] == 1.0 + + def test_similarity_critic_case_insensitive(self) -> None: + """Test that SimilarityCritic handles case differences.""" + critic = SimilarityCritic(critic_field="query", weight=1.0) + result = critic.evaluate("Hello World", "hello world") + + # Should still have high similarity (lowercase conversion happens in TF-IDF) + assert result["score"] > 0.9 + assert result["match"] is True + + def test_similarity_critic_punctuation_differences(self) -> None: + """Test SimilarityCritic with punctuation variations.""" + critic = SimilarityCritic( + critic_field="query", weight=1.0, similarity_threshold=0.8 + ) + result = critic.evaluate("search for cats!", "search for cats") + + # Should have very high similarity despite punctuation + assert result["score"] >= 0.8 + assert result["match"] is True + + def test_numeric_critic_with_negative_ranges(self) -> None: + """Test NumericCritic with negative value ranges.""" + critic = NumericCritic( + critic_field="temperature", weight=1.0, value_range=(-50.0, 50.0) + ) + result = critic.evaluate(-10.0, -10.0) + + assert result["match"] is True + assert result["score"] == 1.0 + + # Test scoring across negative range + result = critic.evaluate(-50.0, 50.0) + assert result["score"] == 0.0 # Maximum difference + + def test_numeric_critic_floating_point_precision(self) -> None: + """Test NumericCritic handles floating point precision correctly.""" + critic = NumericCritic(critic_field="value", weight=1.0, value_range=(0.0, 1.0)) + result = critic.evaluate(0.333333, 0.333334) + + # Very close values should have very high score + assert result["score"] > 0.999 + assert result["match"] is True diff --git a/libs/tests/arcade_evals/test_loaders.py b/libs/tests/arcade_evals/test_loaders.py new file mode 100644 index 000000000..370f553a8 --- /dev/null +++ b/libs/tests/arcade_evals/test_loaders.py @@ -0,0 +1,781 @@ +"""Tests for MCP server loaders (official MCP SDK wrappers).""" + +import importlib.util +import sys +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# Mark all tests in this module as requiring evals dependencies +pytestmark = pytest.mark.evals + +# Import the loaders module directly by file path to avoid arcade_core dependency +_LOADERS_PATH = Path(__file__).parent.parent.parent / "arcade-evals" / "arcade_evals" / "loaders.py" +spec = importlib.util.spec_from_file_location("loaders", _LOADERS_PATH) +loaders = importlib.util.module_from_spec(spec) +sys.modules["arcade_evals.loaders"] = loaders +spec.loader.exec_module(loaders) + + +class TestLoadFromStdio: + """Tests for load_from_stdio function.""" + + def setup_method(self): + """Clear cache before each test.""" + loaders.clear_tools_cache() + + def teardown_method(self): + """Clear cache after each test.""" + loaders.clear_tools_cache() + + @pytest.mark.asyncio + async def test_empty_command_returns_empty_list(self): + """Empty command should return empty list without importing MCP.""" + result = await loaders.load_from_stdio_async([]) + assert result == [] + + @pytest.mark.asyncio + async def test_env_vars_are_merged_into_stdio_server_parameters(self): + """Env vars should be merged with current env and passed to StdioServerParameters.""" + mock_tool = MagicMock() + mock_tool.name = "t" + mock_tool.description = "d" + mock_tool.inputSchema = {"type": "object", "properties": {}} + + mock_list_result = MagicMock() + mock_list_result.tools = [mock_tool] + + mock_session = AsyncMock() + 
mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock(return_value=mock_list_result) + + mock_client_session_cls = MagicMock() + mock_client_session_cls.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_client_session_cls.return_value.__aexit__ = AsyncMock(return_value=None) + + mock_stdio_client = MagicMock() + mock_stdio_client.return_value.__aenter__ = AsyncMock(return_value=("read", "write")) + mock_stdio_client.return_value.__aexit__ = AsyncMock(return_value=None) + + mock_sse_client = MagicMock() + mock_stdio_params_cls = MagicMock() + + with patch.object(loaders, "_require_mcp") as mock_require: + mock_require.return_value = ( + mock_client_session_cls, + mock_stdio_params_cls, + mock_stdio_client, + mock_sse_client, + MagicMock(), # streamablehttp_client + ) + + await loaders.load_from_stdio_async(["echo"], env={"TEST_VAR": "test_value"}) + + # Ensure env merged and passed into server params + _, call_kwargs = mock_stdio_params_cls.call_args + assert "env" in call_kwargs + assert call_kwargs["env"]["TEST_VAR"] == "test_value" + + +class TestLoadFromHttp: + """Tests for load_from_http function.""" + + def setup_method(self): + """Clear cache before each test.""" + loaders.clear_tools_cache() + + def teardown_method(self): + """Clear cache after each test.""" + loaders.clear_tools_cache() + + @pytest.mark.asyncio + async def test_url_gets_mcp_appended(self): + """URL without /mcp should get it appended before calling sse_client.""" + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock(return_value=MagicMock(tools=[])) + + mock_client_session_cls = MagicMock() + mock_client_session_cls.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_client_session_cls.return_value.__aexit__ = AsyncMock(return_value=None) + + mock_sse_client = MagicMock() + mock_sse_client.return_value.__aenter__ = AsyncMock(return_value=("read", "write")) + mock_sse_client.return_value.__aexit__ = AsyncMock(return_value=None) + + with patch.object(loaders, "_require_mcp") as mock_require: + mock_require.return_value = ( + mock_client_session_cls, + MagicMock(), + MagicMock(), + mock_sse_client, + MagicMock(), # streamablehttp_client + ) + + await loaders.load_mcp_remote_async("http://localhost:8000", use_sse=True) + called_url = mock_sse_client.call_args[0][0] + assert called_url.endswith("/mcp") + + @pytest.mark.asyncio + async def test_url_with_mcp_not_duplicated(self): + """URL with /mcp should not get duplicated.""" + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock(return_value=MagicMock(tools=[])) + + mock_client_session_cls = MagicMock() + mock_client_session_cls.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_client_session_cls.return_value.__aexit__ = AsyncMock(return_value=None) + + mock_sse_client = MagicMock() + mock_sse_client.return_value.__aenter__ = AsyncMock(return_value=("read", "write")) + mock_sse_client.return_value.__aexit__ = AsyncMock(return_value=None) + + with patch.object(loaders, "_require_mcp") as mock_require: + mock_require.return_value = ( + mock_client_session_cls, + MagicMock(), + MagicMock(), + mock_sse_client, + MagicMock(), # streamablehttp_client + ) + + await loaders.load_mcp_remote_async("http://localhost:8000/mcp", use_sse=True) + called_url = mock_sse_client.call_args[0][0] + assert "/mcp/mcp" not in called_url + + @pytest.mark.asyncio + async def test_headers_are_passed(self): + 
"""Custom headers should be passed to sse_client.""" + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock(return_value=MagicMock(tools=[])) + + mock_client_session_cls = MagicMock() + mock_client_session_cls.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_client_session_cls.return_value.__aexit__ = AsyncMock(return_value=None) + + mock_sse_client = MagicMock() + mock_sse_client.return_value.__aenter__ = AsyncMock(return_value=("read", "write")) + mock_sse_client.return_value.__aexit__ = AsyncMock(return_value=None) + + with patch.object(loaders, "_require_mcp") as mock_require: + mock_require.return_value = ( + mock_client_session_cls, + MagicMock(), + MagicMock(), + mock_sse_client, + MagicMock(), # streamablehttp_client + ) + + await loaders.load_mcp_remote_async( + "http://localhost:8000", + headers={"Authorization": "Bearer token123"}, + use_sse=True, + ) + _, call_kwargs = mock_sse_client.call_args + assert call_kwargs["headers"]["Authorization"] == "Bearer token123" + + @pytest.mark.asyncio + async def test_returns_tools_from_response(self): + """Should convert SDK Tool objects into dicts.""" + mock_tool1 = MagicMock() + mock_tool1.name = "tool1" + mock_tool1.description = "Test tool 1" + mock_tool1.inputSchema = {"type": "object", "properties": {}} + + mock_tool2 = MagicMock() + mock_tool2.name = "tool2" + mock_tool2.description = None + mock_tool2.inputSchema = {"type": "object", "properties": {}} + + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock(return_value=MagicMock(tools=[mock_tool1, mock_tool2])) + + mock_client_session_cls = MagicMock() + mock_client_session_cls.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_client_session_cls.return_value.__aexit__ = AsyncMock(return_value=None) + + mock_sse_client = MagicMock() + mock_sse_client.return_value.__aenter__ = AsyncMock(return_value=("read", "write")) + mock_sse_client.return_value.__aexit__ = AsyncMock(return_value=None) + + with patch.object(loaders, "_require_mcp") as mock_require: + mock_require.return_value = ( + mock_client_session_cls, + MagicMock(), + MagicMock(), + mock_sse_client, + MagicMock(), # streamablehttp_client + ) + + result = await loaders.load_mcp_remote_async("http://localhost:8000", use_sse=True) + assert result == [ + { + "name": "tool1", + "description": "Test tool 1", + "inputSchema": {"type": "object", "properties": {}}, + }, + { + "name": "tool2", + "description": "", + "inputSchema": {"type": "object", "properties": {}}, + }, + ] + + +class TestLoadArcadeMcpGateway: + """Tests for load_arcade_mcp_gateway function.""" + + def setup_method(self): + """Clear cache before each test.""" + loaders.clear_tools_cache() + + def teardown_method(self): + """Clear cache after each test.""" + loaders.clear_tools_cache() + + @pytest.mark.asyncio + async def test_builds_correct_url_and_headers_with_slug(self): + """Should build correct Arcade MCP URL and pass auth headers.""" + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock(return_value=MagicMock(tools=[])) + + mock_client_session_cls = MagicMock() + mock_client_session_cls.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_client_session_cls.return_value.__aexit__ = AsyncMock(return_value=None) + + # Arcade gateway uses streamable-http (returns 3 values) + mock_streamable_client = MagicMock() + 
mock_streamable_client.return_value.__aenter__ = AsyncMock( + return_value=("read", "write", "session_id") + ) + mock_streamable_client.return_value.__aexit__ = AsyncMock(return_value=None) + + with patch.object(loaders, "_require_mcp") as mock_require: + mock_require.return_value = ( + mock_client_session_cls, + MagicMock(), + MagicMock(), + MagicMock(), # sse_client + mock_streamable_client, + ) + + await loaders.load_arcade_mcp_gateway_async( + "my-gateway", + arcade_api_key="key", + arcade_user_id="user", + ) + + called_url = mock_streamable_client.call_args[0][0] + called_headers = mock_streamable_client.call_args[1]["headers"] + assert called_url == "https://api.arcade.dev/mcp/my-gateway" + # Code adds "Bearer " prefix to key + assert called_headers["Authorization"] == "Bearer key" + assert called_headers["Arcade-User-Id"] == "user" + + @pytest.mark.asyncio + async def test_builds_correct_url_without_slug(self): + """Should build correct Arcade MCP URL without gateway slug.""" + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock(return_value=MagicMock(tools=[])) + + mock_client_session_cls = MagicMock() + mock_client_session_cls.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_client_session_cls.return_value.__aexit__ = AsyncMock(return_value=None) + + mock_streamable_client = MagicMock() + mock_streamable_client.return_value.__aenter__ = AsyncMock( + return_value=("read", "write", "session_id") + ) + mock_streamable_client.return_value.__aexit__ = AsyncMock(return_value=None) + + with patch.object(loaders, "_require_mcp") as mock_require: + mock_require.return_value = ( + mock_client_session_cls, + MagicMock(), + MagicMock(), + MagicMock(), # sse_client + mock_streamable_client, + ) + + await loaders.load_arcade_mcp_gateway_async(arcade_api_key="key") + + called_url = mock_streamable_client.call_args[0][0] + assert called_url == "https://api.arcade.dev/mcp" + + @pytest.mark.asyncio + async def test_custom_base_url(self): + """Should use custom base URL when provided.""" + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock(return_value=MagicMock(tools=[])) + + mock_client_session_cls = MagicMock() + mock_client_session_cls.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_client_session_cls.return_value.__aexit__ = AsyncMock(return_value=None) + + mock_streamable_client = MagicMock() + mock_streamable_client.return_value.__aenter__ = AsyncMock( + return_value=("read", "write", "session_id") + ) + mock_streamable_client.return_value.__aexit__ = AsyncMock(return_value=None) + + with patch.object(loaders, "_require_mcp") as mock_require: + mock_require.return_value = ( + mock_client_session_cls, + MagicMock(), + MagicMock(), + MagicMock(), # sse_client + mock_streamable_client, + ) + + await loaders.load_arcade_mcp_gateway_async( + "my-gateway", + base_url="https://staging.arcade.dev", + ) + + called_url = mock_streamable_client.call_args[0][0] + assert called_url == "https://staging.arcade.dev/mcp/my-gateway" + + +class TestLoadStdioArcade: + """Tests for load_stdio_arcade function.""" + + @pytest.mark.asyncio + async def test_passes_env_vars_to_stdio(self): + """Should pass Arcade env vars to stdio loader.""" + with patch.object(loaders, "load_from_stdio_async", new_callable=AsyncMock) as mock_stdio: + mock_stdio.return_value = [] + + await loaders.load_stdio_arcade_async( + ["python", "server.py"], + arcade_api_key="test_key", + 
arcade_user_id="test_user", + ) + + call_kwargs = mock_stdio.call_args[1] + assert call_kwargs["env"]["ARCADE_API_KEY"] == "test_key" + assert call_kwargs["env"]["ARCADE_USER_ID"] == "test_user" + + @pytest.mark.asyncio + async def test_includes_tool_secrets(self): + """Should include tool secrets in environment.""" + with patch.object(loaders, "load_from_stdio_async", new_callable=AsyncMock) as mock_stdio: + mock_stdio.return_value = [] + + await loaders.load_stdio_arcade_async( + ["python", "server.py"], + tool_secrets={"GITHUB_TOKEN": "gh_token", "SLACK_TOKEN": "slack_token"}, + ) + + call_kwargs = mock_stdio.call_args[1] + assert call_kwargs["env"]["GITHUB_TOKEN"] == "gh_token" + assert call_kwargs["env"]["SLACK_TOKEN"] == "slack_token" + + +class TestLazyImport: + """Tests for lazy MCP import behavior.""" + + def test_require_mcp_error_message(self): + """Should raise helpful ImportError when MCP SDK is not installed.""" + # If MCP is installed in the environment, this test isn't meaningful. + # Force an import failure by masking the module. + with patch.dict(sys.modules, {"mcp": None}): + with pytest.raises(ImportError) as exc: + loaders._require_mcp() + assert "pip install" in str(exc.value) + + @pytest.mark.asyncio + async def test_http_loader_raises_import_error_without_mcp(self): + """Test that HTTP loader raises ImportError when MCP SDK missing.""" + with patch.dict(sys.modules, {"mcp": None}): + with pytest.raises(ImportError, match="pip install"): + await loaders.load_mcp_remote_async("http://localhost:8000") + + @pytest.mark.asyncio + async def test_stdio_loader_raises_import_error_without_mcp(self): + """Test that stdio loader raises ImportError when MCP SDK missing.""" + with patch.dict(sys.modules, {"mcp": None}): + with pytest.raises(ImportError, match="pip install"): + await loaders.load_from_stdio_async(["python", "server.py"]) + + @pytest.mark.asyncio + async def test_arcade_gateway_loader_raises_import_error_without_mcp(self): + """Test that Arcade gateway loader raises ImportError when MCP SDK missing.""" + with patch.dict(sys.modules, {"mcp": None}): + with pytest.raises(ImportError, match="pip install"): + await loaders.load_arcade_mcp_gateway_async("my-gateway") + + +class TestEnsureMcpPath: + """Tests for _ensure_mcp_path utility function.""" + + def test_appends_mcp_to_bare_url(self): + """Should append /mcp to URL without path.""" + result = loaders._ensure_mcp_path("http://localhost:8000") + assert result == "http://localhost:8000/mcp" + + def test_appends_mcp_to_url_with_path(self): + """Should append /mcp to URL with existing path.""" + result = loaders._ensure_mcp_path("http://localhost:8000/api") + assert result == "http://localhost:8000/api/mcp" + + def test_does_not_duplicate_mcp(self): + """Should not duplicate /mcp if already present.""" + result = loaders._ensure_mcp_path("http://localhost:8000/mcp") + assert result == "http://localhost:8000/mcp" + + def test_handles_mcp_in_path(self): + """Should not add /mcp if 'mcp' is anywhere in path segments.""" + result = loaders._ensure_mcp_path("http://localhost:8000/mcp/my-slug") + assert result == "http://localhost:8000/mcp/my-slug" + + def test_preserves_query_string(self): + """Should preserve query string in URL.""" + result = loaders._ensure_mcp_path("http://localhost:8000?foo=bar") + assert result == "http://localhost:8000/mcp?foo=bar" + + def test_preserves_fragment(self): + """Should preserve fragment in URL.""" + result = loaders._ensure_mcp_path("http://localhost:8000#section") + assert result == 
"http://localhost:8000/mcp#section" + + +class TestBuildArcadeMcpUrl: + """Tests for _build_arcade_mcp_url utility function.""" + + def test_builds_url_with_slug(self): + """Should build correct URL with gateway slug.""" + result = loaders._build_arcade_mcp_url("my-gateway", "https://api.arcade.dev") + assert result == "https://api.arcade.dev/mcp/my-gateway" + + def test_builds_url_without_slug(self): + """Should build correct URL without gateway slug.""" + result = loaders._build_arcade_mcp_url(None, "https://api.arcade.dev") + assert result == "https://api.arcade.dev/mcp" + + def test_strips_trailing_slash(self): + """Should strip trailing slash from base URL.""" + result = loaders._build_arcade_mcp_url("my-gateway", "https://api.arcade.dev/") + assert result == "https://api.arcade.dev/mcp/my-gateway" + + +class TestToolToDict: + """Tests for _tool_to_dict utility function.""" + + def test_converts_tool_to_dict(self): + """Should convert MCP Tool object to dictionary.""" + mock_tool = MagicMock() + mock_tool.name = "my_tool" + mock_tool.description = "A description" + mock_tool.inputSchema = {"type": "object", "properties": {"x": {"type": "string"}}} + + result = loaders._tool_to_dict(mock_tool) + assert result == { + "name": "my_tool", + "description": "A description", + "inputSchema": {"type": "object", "properties": {"x": {"type": "string"}}}, + } + + def test_handles_none_description(self): + """Should handle None description.""" + mock_tool = MagicMock() + mock_tool.name = "my_tool" + mock_tool.description = None + mock_tool.inputSchema = {} + + result = loaders._tool_to_dict(mock_tool) + assert result["description"] == "" + + +class TestToolsCache: + """Tests for tools caching functionality.""" + + def setup_method(self): + """Clear cache before each test.""" + loaders.clear_tools_cache() + + def teardown_method(self): + """Clear cache after each test.""" + loaders.clear_tools_cache() + + def test_clear_tools_cache(self): + """Should clear the tools cache and locks.""" + # Add something to cache directly + loaders._tools_cache["test_key"] = [{"name": "tool1"}] + loaders._cache_locks["test_key"] = MagicMock() + assert len(loaders._tools_cache) == 1 + assert len(loaders._cache_locks) == 1 + + loaders.clear_tools_cache() + assert len(loaders._tools_cache) == 0 + assert len(loaders._cache_locks) == 0 + + def test_make_cache_key_different_urls(self): + """Should create different keys for different URLs.""" + key1 = loaders._make_cache_key("http://localhost:8000", None) + key2 = loaders._make_cache_key("http://localhost:9000", None) + assert key1 != key2 + + def test_make_cache_key_different_headers(self): + """Should create different keys for different headers.""" + key1 = loaders._make_cache_key("http://localhost:8000", {"Auth": "token1"}) + key2 = loaders._make_cache_key("http://localhost:8000", {"Auth": "token2"}) + assert key1 != key2 + + def test_make_cache_key_same_inputs(self): + """Should create same key for same inputs.""" + key1 = loaders._make_cache_key("http://localhost:8000", {"Auth": "token"}) + key2 = loaders._make_cache_key("http://localhost:8000", {"Auth": "token"}) + assert key1 == key2 + + @pytest.mark.asyncio + async def test_get_cache_lock_creates_lock(self): + """Should create a lock for a new key.""" + lock = await loaders._get_cache_lock("new_key") + assert isinstance(lock, type(loaders.asyncio.Lock())) + assert "new_key" in loaders._cache_locks + + @pytest.mark.asyncio + async def test_get_cache_lock_returns_same_lock(self): + """Should return same lock for same 
key.""" + lock1 = await loaders._get_cache_lock("same_key") + lock2 = await loaders._get_cache_lock("same_key") + assert lock1 is lock2 + + @pytest.mark.asyncio + async def test_acquire_lock_with_timeout_succeeds(self): + """Should acquire lock successfully when available.""" + lock = loaders.asyncio.Lock() + acquired = await loaders._acquire_lock_with_timeout(lock, timeout=1.0) + assert acquired is True + assert lock.locked() + lock.release() + + @pytest.mark.asyncio + async def test_acquire_lock_with_timeout_fails(self): + """Should return False when lock acquisition times out.""" + lock = loaders.asyncio.Lock() + await lock.acquire() # Hold the lock + + # Try to acquire with short timeout - should fail + acquired = await loaders._acquire_lock_with_timeout(lock, timeout=0.1) + assert acquired is False + + lock.release() + + @pytest.mark.asyncio + async def test_http_loader_caches_results(self): + """Should cache results and return cached on second call.""" + mock_tool = MagicMock() + mock_tool.name = "tool1" + mock_tool.description = "Test" + mock_tool.inputSchema = {} + + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock(return_value=MagicMock(tools=[mock_tool])) + + mock_client_session_cls = MagicMock() + mock_client_session_cls.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_client_session_cls.return_value.__aexit__ = AsyncMock(return_value=None) + + mock_sse_client = MagicMock() + mock_sse_client.return_value.__aenter__ = AsyncMock(return_value=("read", "write")) + mock_sse_client.return_value.__aexit__ = AsyncMock(return_value=None) + + with patch.object(loaders, "_require_mcp") as mock_require: + mock_require.return_value = ( + mock_client_session_cls, + MagicMock(), + MagicMock(), + mock_sse_client, + MagicMock(), # streamablehttp_client + ) + + # First call - should connect + result1 = await loaders.load_mcp_remote_async("http://localhost:8000", use_sse=True) + assert len(result1) == 1 + assert mock_sse_client.call_count == 1 + + # Second call - should use cache + result2 = await loaders.load_mcp_remote_async("http://localhost:8000", use_sse=True) + assert len(result2) == 1 + # sse_client should NOT be called again + assert mock_sse_client.call_count == 1 + + @pytest.mark.asyncio + async def test_http_loader_different_urls_not_cached(self): + """Should not use cache for different URLs.""" + mock_tool = MagicMock() + mock_tool.name = "tool1" + mock_tool.description = "Test" + mock_tool.inputSchema = {} + + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock(return_value=MagicMock(tools=[mock_tool])) + + mock_client_session_cls = MagicMock() + mock_client_session_cls.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_client_session_cls.return_value.__aexit__ = AsyncMock(return_value=None) + + mock_sse_client = MagicMock() + mock_sse_client.return_value.__aenter__ = AsyncMock(return_value=("read", "write")) + mock_sse_client.return_value.__aexit__ = AsyncMock(return_value=None) + + with patch.object(loaders, "_require_mcp") as mock_require: + mock_require.return_value = ( + mock_client_session_cls, + MagicMock(), + MagicMock(), + mock_sse_client, + MagicMock(), # streamablehttp_client + ) + + # First URL + await loaders.load_mcp_remote_async("http://localhost:8000", use_sse=True) + assert mock_sse_client.call_count == 1 + + # Different URL - should connect again + await loaders.load_mcp_remote_async("http://localhost:9000", 
use_sse=True) + assert mock_sse_client.call_count == 2 + + @pytest.mark.asyncio + async def test_http_loader_lock_timeout_raises_error(self): + """Should raise TimeoutError when lock acquisition times out.""" + # Create a lock and hold it + loaders._cache_locks["test_key"] = loaders.asyncio.Lock() + lock = loaders._cache_locks["test_key"] + await lock.acquire() + + try: + # Try to load with a key that will wait for the held lock + with ( + patch.object(loaders, "_make_cache_key", return_value="test_key"), + patch.object(loaders, "LOCK_TIMEOUT_SECONDS", 0.1), + ): + with pytest.raises(TimeoutError, match="Timeout waiting for lock"): + await loaders.load_mcp_remote_async("http://localhost:8000") + finally: + lock.release() + loaders.clear_tools_cache() + + @pytest.mark.asyncio + async def test_stdio_loader_lock_timeout_raises_error(self): + """Should raise TimeoutError when stdio lock acquisition times out.""" + # Create a specific cache key and hold its lock + cache_key = "stdio|python server.py|[]" + loaders._cache_locks[cache_key] = loaders.asyncio.Lock() + lock = loaders._cache_locks[cache_key] + await lock.acquire() + + try: + with patch.object(loaders, "LOCK_TIMEOUT_SECONDS", 0.1): + with pytest.raises(TimeoutError, match="Timeout waiting for lock on stdio"): + await loaders.load_from_stdio_async(["python", "server.py"]) + finally: + lock.release() + loaders.clear_tools_cache() + + @pytest.mark.asyncio + async def test_lock_released_after_connection_error(self): + """Should release lock even when MCP connection fails.""" + mock_session = AsyncMock() + mock_session.initialize = AsyncMock(side_effect=ConnectionError("Connection failed")) + + mock_client_session_cls = MagicMock() + mock_client_session_cls.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_client_session_cls.return_value.__aexit__ = AsyncMock(return_value=None) + + mock_sse_client = MagicMock() + mock_sse_client.return_value.__aenter__ = AsyncMock(return_value=("read", "write")) + mock_sse_client.return_value.__aexit__ = AsyncMock(return_value=None) + + with patch.object(loaders, "_require_mcp") as mock_require: + mock_require.return_value = ( + mock_client_session_cls, + MagicMock(), + MagicMock(), + mock_sse_client, + MagicMock(), + ) + + # First call should fail + with pytest.raises(ConnectionError): + await loaders.load_mcp_remote_async("http://localhost:8000", use_sse=True) + + # Lock should be released - second call should not timeout + cache_key = loaders._make_cache_key("http://localhost:8000/mcp", None) + lock = loaders._cache_locks.get(cache_key) + if lock: + assert not lock.locked(), "Lock should be released after error" + + +class TestMCPLoggingFilter: + """Tests for MCP SDK logging filter.""" + + def test_filter_suppresses_session_termination_202(self): + """Should suppress 'Session termination failed: 202' messages.""" + import logging + + log_filter = loaders.MCPSessionFilter() + + # Create a mock log record with the misleading message + record = logging.LogRecord( + name="mcp.client.session", + level=logging.WARNING, + pathname="", + lineno=0, + msg="Session termination failed: 202", + args=(), + exc_info=None, + ) + + # Should be filtered out (return False) + assert log_filter.filter(record) is False + + def test_filter_allows_other_messages(self): + """Should allow other log messages through.""" + import logging + + log_filter = loaders.MCPSessionFilter() + + # Create a log record with a normal message + record = logging.LogRecord( + name="mcp.client", + level=logging.INFO, + 
pathname="", + lineno=0, + msg="Connected to MCP server", + args=(), + exc_info=None, + ) + + # Should pass through (return True) + assert log_filter.filter(record) is True + + def test_filter_allows_real_errors(self): + """Should allow real error messages through.""" + import logging + + log_filter = loaders.MCPSessionFilter() + + # Create a log record with an actual error + record = logging.LogRecord( + name="mcp.client", + level=logging.ERROR, + pathname="", + lineno=0, + msg="Connection failed: Timeout", + args=(), + exc_info=None, + ) + + # Should pass through (return True) + assert log_filter.filter(record) is True diff --git a/libs/tests/arcade_evals/test_schema_converters.py b/libs/tests/arcade_evals/test_schema_converters.py new file mode 100644 index 000000000..b6a578643 --- /dev/null +++ b/libs/tests/arcade_evals/test_schema_converters.py @@ -0,0 +1,1028 @@ +"""Tests for MCP tool schema converters (OpenAI and Anthropic formats).""" + +import pytest +from arcade_evals._evalsuite._anthropic_schema import ( + convert_mcp_to_anthropic_tool, + convert_mcp_tools_to_anthropic, +) +from arcade_evals._evalsuite._openai_schema import ( + SchemaConversionError, + convert_to_strict_mode_schema, +) +from arcade_evals._evalsuite._tool_registry import EvalSuiteToolRegistry + +# Mark all tests in this module as requiring evals dependencies +pytestmark = pytest.mark.evals + + +class TestOpenAISchemaConversion: + """Tests for OpenAI strict mode schema conversion.""" + + def test_basic_schema_conversion(self): + """Test basic schema conversion to OpenAI strict mode.""" + input_schema = { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"}, + }, + "required": ["query"], + } + + result = convert_to_strict_mode_schema(input_schema) + + assert result["type"] == "object" + assert result["additionalProperties"] is False + assert "query" in result["properties"] + assert result["required"] == ["query"] + + def test_optional_params_get_null_union(self): + """Test that optional parameters get null union type.""" + input_schema = { + "type": "object", + "properties": { + "required_param": {"type": "string"}, + "optional_param": {"type": "integer"}, + }, + "required": ["required_param"], + } + + result = convert_to_strict_mode_schema(input_schema) + + # Required param should have single type + assert result["properties"]["required_param"]["type"] == "string" + + # Optional param should have null union + assert result["properties"]["optional_param"]["type"] == ["integer", "null"] + + # Both should be in required (OpenAI strict mode requirement) + assert set(result["required"]) == {"required_param", "optional_param"} + + def test_unsupported_keywords_stripped(self): + """Test that unsupported JSON Schema keywords are stripped.""" + input_schema = { + "type": "object", + "properties": { + "count": { + "type": "integer", + "minimum": 0, + "maximum": 100, + "default": 10, + }, + "name": { + "type": "string", + "minLength": 1, + "maxLength": 50, + "pattern": "^[a-z]+$", + "format": "hostname", + }, + }, + "required": ["count", "name"], + } + + result = convert_to_strict_mode_schema(input_schema) + + # These keywords should be stripped + count_prop = result["properties"]["count"] + assert "minimum" not in count_prop + assert "maximum" not in count_prop + assert "default" not in count_prop + + name_prop = result["properties"]["name"] + assert "minLength" not in name_prop + assert "maxLength" not in name_prop + assert "pattern" not in name_prop + assert "format" not in 
name_prop + + def test_enum_values_converted_to_strings(self): + """Test that enum values are converted to strings.""" + input_schema = { + "type": "object", + "properties": { + "status": { + "type": "string", + "enum": [1, 2, "three"], + }, + }, + "required": ["status"], + } + + result = convert_to_strict_mode_schema(input_schema) + + assert result["properties"]["status"]["enum"] == ["1", "2", "three"] + + def test_integer_enum_type_changed_to_string(self): + """Test that integer enums have their type changed to string. + + OpenAI strict mode validates enum values against the declared type. + When enum values are converted to strings, the type must also change. + + Example error without fix: + "enum value 0 does not validate against {'type': ['integer', 'null']}" + """ + input_schema = { + "type": "object", + "properties": { + "priority": { + "type": "integer", + "enum": [0, 1, 2, 3, 4], + "description": "Priority: 0=none, 1=urgent, 2=high, 3=medium, 4=low", + }, + }, + "required": ["priority"], + } + + result = convert_to_strict_mode_schema(input_schema) + + # Enum values should be strings + assert result["properties"]["priority"]["enum"] == ["0", "1", "2", "3", "4"] + # Type should be changed to string to match + assert result["properties"]["priority"]["type"] == "string" + + def test_optional_integer_enum_type_changed_to_string_null_union(self): + """Test that optional integer enums get type ["string", "null"]. + + When an integer enum is optional: + 1. Enum values are converted to strings + 2. Type changes from "integer" to ["string", "null"] + + This fixes: "enum value 0 does not validate against {'type': ['integer', 'null']}" + """ + input_schema = { + "type": "object", + "properties": { + "priority": { + "type": "integer", + "enum": [0, 1, 2, 3, 4], + }, + }, + "required": [], # priority is optional + } + + result = convert_to_strict_mode_schema(input_schema) + + # Enum values should be strings + assert result["properties"]["priority"]["enum"] == ["0", "1", "2", "3", "4"] + # Type should be ["string", "null"] for optional param + assert result["properties"]["priority"]["type"] == ["string", "null"] + + def test_string_enum_type_unchanged(self): + """Test that string enums keep their type as string.""" + input_schema = { + "type": "object", + "properties": { + "status": { + "type": "string", + "enum": ["active", "inactive", "pending"], + }, + }, + "required": ["status"], + } + + result = convert_to_strict_mode_schema(input_schema) + + # Enum values unchanged + assert result["properties"]["status"]["enum"] == ["active", "inactive", "pending"] + # Type remains string + assert result["properties"]["status"]["type"] == "string" + + def test_boolean_enum_type_changed_to_string(self): + """Test that boolean enums have their type changed to string.""" + input_schema = { + "type": "object", + "properties": { + "flag": { + "type": "boolean", + "enum": [True, False], + }, + }, + "required": ["flag"], + } + + result = convert_to_strict_mode_schema(input_schema) + + # Boolean values converted to strings + assert result["properties"]["flag"]["enum"] == ["True", "False"] + # Type changed to string + assert result["properties"]["flag"]["type"] == "string" + + def test_enum_with_list_type_no_null(self): + """Test that enums with list type but no null are converted to single string type.""" + input_schema = { + "type": "object", + "properties": { + "priority": { + "type": ["integer"], + "enum": [1, 2, 3], + }, + }, + "required": ["priority"], + } + + result = 
convert_to_strict_mode_schema(input_schema) + + assert result["properties"]["priority"]["enum"] == ["1", "2", "3"] + # Should be "string", not ["string"] + assert result["properties"]["priority"]["type"] == "string" + + def test_nested_object_enum_type_conversion(self): + """Test that enum type conversion works in nested objects.""" + input_schema = { + "type": "object", + "properties": { + "config": { + "type": "object", + "properties": { + "level": { + "type": "integer", + "enum": [1, 2, 3], + }, + }, + "required": ["level"], + }, + }, + "required": ["config"], + } + + result = convert_to_strict_mode_schema(input_schema) + + nested = result["properties"]["config"]["properties"]["level"] + assert nested["enum"] == ["1", "2", "3"] + assert nested["type"] == "string" + + def test_nested_object_gets_strict_mode(self): + """Test that nested objects also get strict mode treatment.""" + input_schema = { + "type": "object", + "properties": { + "user": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name"], + }, + }, + "required": ["user"], + } + + result = convert_to_strict_mode_schema(input_schema) + + nested = result["properties"]["user"] + assert nested["additionalProperties"] is False + # Both should be in required for nested object too + assert set(nested["required"]) == {"name", "age"} + # age is optional so should have null union + assert nested["properties"]["age"]["type"] == ["integer", "null"] + + def test_array_items_processed(self): + """Test that array items schema is processed.""" + input_schema = { + "type": "object", + "properties": { + "items": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": {"type": "integer", "minimum": 0}, + }, + "required": ["id"], + }, + }, + }, + "required": ["items"], + } + + result = convert_to_strict_mode_schema(input_schema) + + array_items = result["properties"]["items"]["items"] + assert array_items["additionalProperties"] is False + # minimum should be stripped from nested object property + assert "minimum" not in array_items["properties"]["id"] + + def test_empty_schema(self): + """Test conversion of empty schema.""" + input_schema = {"type": "object", "properties": {}} + + result = convert_to_strict_mode_schema(input_schema) + + assert result["type"] == "object" + assert result["properties"] == {} + assert result["additionalProperties"] is False + assert result["required"] == [] + + def test_max_depth_protection(self): + """Test that deeply nested schemas raise an error.""" + # Create a deeply nested schema that exceeds max depth + schema: dict = {"type": "object", "properties": {}} + current = schema + for i in range(60): # Exceeds _MAX_SCHEMA_DEPTH of 50 + current["properties"] = {"nested": {"type": "object", "properties": {}}} + current["required"] = ["nested"] + current = current["properties"]["nested"] + + with pytest.raises(SchemaConversionError, match="maximum depth"): + convert_to_strict_mode_schema(schema) + + +class TestAnthropicSchemaConversion: + """Tests for Anthropic schema conversion.""" + + def test_basic_conversion(self): + """Test basic MCP to Anthropic tool conversion.""" + mcp_tool = { + "name": "search_files", + "description": "Search for files", + "inputSchema": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"}, + }, + "required": ["query"], + }, + } + + result = convert_mcp_to_anthropic_tool(mcp_tool) + + assert result["name"] == "search_files" + assert 
result["description"] == "Search for files" + assert "input_schema" in result + assert "inputSchema" not in result + # Schema should be unchanged + assert result["input_schema"]["properties"]["query"]["type"] == "string" + + def test_schema_preserved_as_is(self): + """Test that the schema is preserved without modifications.""" + mcp_tool = { + "name": "test", + "description": "Test", + "inputSchema": { + "type": "object", + "properties": { + "count": { + "type": "integer", + "minimum": 0, + "maximum": 100, + "default": 10, + }, + }, + "required": ["count"], + }, + } + + result = convert_mcp_to_anthropic_tool(mcp_tool) + + # Unlike OpenAI, these keywords should be preserved + schema = result["input_schema"] + assert schema["properties"]["count"]["minimum"] == 0 + assert schema["properties"]["count"]["maximum"] == 100 + assert schema["properties"]["count"]["default"] == 10 + + def test_tool_name_dots_normalized_to_underscores(self): + """Test that dots in tool names are converted to underscores. + + Anthropic tool names must match pattern: ^[a-zA-Z0-9_-]{1,64}$ + Dots are not allowed, so they must be converted. + """ + mcp_tool = { + "name": "Google.Search", + "description": "Search Google", + "inputSchema": {"type": "object", "properties": {}}, + } + + result = convert_mcp_to_anthropic_tool(mcp_tool) + + assert result["name"] == "Google_Search" + + def test_tool_name_hyphens_preserved(self): + """Test that hyphens in tool names are preserved (they're valid).""" + mcp_tool = { + "name": "search-files", + "description": "Search files", + "inputSchema": {"type": "object", "properties": {}}, + } + + result = convert_mcp_to_anthropic_tool(mcp_tool) + + assert result["name"] == "search-files" + + def test_tool_name_multiple_dots(self): + """Test that multiple dots are all converted to underscores.""" + mcp_tool = { + "name": "Google.Gmail.Send.Email", + "description": "Send email", + "inputSchema": {"type": "object", "properties": {}}, + } + + result = convert_mcp_to_anthropic_tool(mcp_tool) + + assert result["name"] == "Google_Gmail_Send_Email" + + def test_missing_description_defaults_to_empty(self): + """Test that missing description defaults to empty string.""" + mcp_tool = { + "name": "test", + "inputSchema": {"type": "object", "properties": {}}, + } + + result = convert_mcp_to_anthropic_tool(mcp_tool) + + assert result["description"] == "" + + def test_missing_schema_defaults_to_empty_object(self): + """Test that missing inputSchema defaults to empty object schema.""" + mcp_tool = {"name": "test"} + + result = convert_mcp_to_anthropic_tool(mcp_tool) + + assert result["input_schema"] == {"type": "object", "properties": {}} + + def test_convert_multiple_tools(self): + """Test converting a list of MCP tools.""" + mcp_tools = [ + {"name": "tool1", "description": "First tool"}, + {"name": "tool2", "description": "Second tool"}, + ] + + result = convert_mcp_tools_to_anthropic(mcp_tools) + + assert len(result) == 2 + assert result[0]["name"] == "tool1" + assert result[1]["name"] == "tool2" + + +class TestToolRegistryOpenAIFormat: + """Tests for EvalSuiteToolRegistry OpenAI format output.""" + + def test_list_tools_openai_format(self): + """Test listing tools in OpenAI format.""" + registry = EvalSuiteToolRegistry(strict_mode=True) + registry.add_tool({ + "name": "search", + "description": "Search function", + "inputSchema": { + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"], + }, + }) + + tools = registry.list_tools_for_model("openai") + + assert 
len(tools) == 1 + tool = tools[0] + assert tool["type"] == "function" + assert tool["function"]["name"] == "search" + assert tool["function"]["strict"] is True + assert tool["function"]["parameters"]["additionalProperties"] is False + + def test_list_tools_openai_without_strict_mode(self): + """Test OpenAI format without strict mode.""" + registry = EvalSuiteToolRegistry(strict_mode=False) + registry.add_tool({ + "name": "search", + "description": "Search", + "inputSchema": { + "type": "object", + "properties": { + "query": {"type": "string", "minLength": 1}, + }, + "required": ["query"], + }, + }) + + tools = registry.list_tools_for_model("openai") + + tool = tools[0] + # No strict flag when strict_mode is False + assert "strict" not in tool["function"] + # Schema keywords should be preserved when strict_mode is False + assert tool["function"]["parameters"]["properties"]["query"]["minLength"] == 1 + + +class TestToolRegistryAnthropicFormat: + """Tests for EvalSuiteToolRegistry Anthropic format output.""" + + def test_list_tools_anthropic_format(self): + """Test listing tools in Anthropic format.""" + registry = EvalSuiteToolRegistry() + registry.add_tool({ + "name": "search", + "description": "Search function", + "inputSchema": { + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"], + }, + }) + + tools = registry.list_tools_for_model("anthropic") + + assert len(tools) == 1 + tool = tools[0] + # Anthropic format - flat structure + assert "type" not in tool + assert "function" not in tool + assert tool["name"] == "search" + assert tool["description"] == "Search function" + assert "input_schema" in tool + + def test_anthropic_format_preserves_schema(self): + """Test that Anthropic format preserves JSON Schema keywords.""" + registry = EvalSuiteToolRegistry(strict_mode=True) # strict_mode shouldn't affect Anthropic + registry.add_tool({ + "name": "test", + "description": "Test", + "inputSchema": { + "type": "object", + "properties": { + "count": { + "type": "integer", + "minimum": 0, + "maximum": 100, + }, + }, + "required": ["count"], + }, + }) + + tools = registry.list_tools_for_model("anthropic") + + # Schema should be preserved as-is for Anthropic + schema = tools[0]["input_schema"] + assert schema["properties"]["count"]["minimum"] == 0 + assert schema["properties"]["count"]["maximum"] == 100 + + def test_anthropic_format_no_null_union(self): + """Test that Anthropic format doesn't add null union types.""" + registry = EvalSuiteToolRegistry(strict_mode=True) + registry.add_tool({ + "name": "test", + "description": "Test", + "inputSchema": { + "type": "object", + "properties": { + "required_param": {"type": "string"}, + "optional_param": {"type": "integer"}, + }, + "required": ["required_param"], # optional_param is optional + }, + }) + + tools = registry.list_tools_for_model("anthropic") + + # Optional param should NOT have null union for Anthropic + optional_type = tools[0]["input_schema"]["properties"]["optional_param"]["type"] + assert optional_type == "integer" + assert not isinstance(optional_type, list) + + def test_anthropic_format_normalizes_tool_names(self): + """Test that Anthropic format normalizes tool names (dots to underscores).""" + registry = EvalSuiteToolRegistry() + registry.add_tool({ + "name": "Google.Gmail.Send", + "description": "Send email via Gmail", + "inputSchema": {"type": "object", "properties": {}}, + }) + + tools = registry.list_tools_for_model("anthropic") + + # Dots should be converted to underscores + assert 
tools[0]["name"] == "Google_Gmail_Send" + + +class TestToolRegistryOpenAINameNormalization: + """Tests for OpenAI format tool name normalization.""" + + def test_openai_format_normalizes_tool_names(self): + """Test that OpenAI format normalizes tool names (dots to underscores). + + OpenAI function names don't allow dots, so they must be converted. + """ + registry = EvalSuiteToolRegistry(strict_mode=True) + registry.add_tool({ + "name": "Google.Search", + "description": "Search Google", + "inputSchema": {"type": "object", "properties": {}}, + }) + + tools = registry.list_tools_for_model("openai") + + # Dots should be converted to underscores + assert tools[0]["function"]["name"] == "Google_Search" + + def test_openai_format_normalizes_multiple_dots(self): + """Test that multiple dots are all converted to underscores for OpenAI.""" + registry = EvalSuiteToolRegistry(strict_mode=True) + registry.add_tool({ + "name": "Google.Gmail.Send.Email", + "description": "Send email", + "inputSchema": {"type": "object", "properties": {}}, + }) + + tools = registry.list_tools_for_model("openai") + + assert tools[0]["function"]["name"] == "Google_Gmail_Send_Email" + + def test_openai_format_preserves_underscores(self): + """Test that underscores in tool names are preserved for OpenAI.""" + registry = EvalSuiteToolRegistry(strict_mode=True) + registry.add_tool({ + "name": "search_files", + "description": "Search files", + "inputSchema": {"type": "object", "properties": {}}, + }) + + tools = registry.list_tools_for_model("openai") + + assert tools[0]["function"]["name"] == "search_files" + + +class TestToolRegistryFormatComparison: + """Tests comparing OpenAI and Anthropic format outputs.""" + + def test_same_tool_different_formats(self): + """Test that the same tool produces correct different formats.""" + registry = EvalSuiteToolRegistry(strict_mode=True) + registry.add_tool({ + "name": "search", + "description": "Search for items", + "inputSchema": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"}, + "limit": {"type": "integer", "default": 10}, + }, + "required": ["query"], + }, + }) + + openai_tools = registry.list_tools_for_model("openai") + anthropic_tools = registry.list_tools_for_model("anthropic") + + # OpenAI format + openai_tool = openai_tools[0] + assert openai_tool["type"] == "function" + assert openai_tool["function"]["strict"] is True + openai_params = openai_tool["function"]["parameters"] + assert openai_params["additionalProperties"] is False + # limit should have null union in OpenAI + assert openai_params["properties"]["limit"]["type"] == ["integer", "null"] + # default should be stripped in OpenAI + assert "default" not in openai_params["properties"]["limit"] + + # Anthropic format + anthropic_tool = anthropic_tools[0] + assert "type" not in anthropic_tool + assert "function" not in anthropic_tool + anthropic_schema = anthropic_tool["input_schema"] + # limit should have simple type in Anthropic + assert anthropic_schema["properties"]["limit"]["type"] == "integer" + # default should be preserved in Anthropic + assert anthropic_schema["properties"]["limit"]["default"] == 10 + + def test_invalid_format_raises(self): + """Test that invalid format raises ValueError.""" + registry = EvalSuiteToolRegistry() + registry.add_tool({"name": "test"}) + + with pytest.raises(ValueError, match="not supported"): + registry.list_tools_for_model("invalid") # type: ignore + + +class TestToolRegistryMultipleTools: + """Tests for registry with multiple 
tools.""" + + def test_multiple_tools_both_formats(self): + """Test multiple tools converted to both formats.""" + registry = EvalSuiteToolRegistry() + registry.add_tools([ + {"name": "tool1", "description": "First"}, + {"name": "tool2", "description": "Second"}, + {"name": "tool3", "description": "Third"}, + ]) + + openai_tools = registry.list_tools_for_model("openai") + anthropic_tools = registry.list_tools_for_model("anthropic") + + assert len(openai_tools) == 3 + assert len(anthropic_tools) == 3 + + # Verify names are preserved + openai_names = {t["function"]["name"] for t in openai_tools} + anthropic_names = {t["name"] for t in anthropic_tools} + assert openai_names == {"tool1", "tool2", "tool3"} + assert anthropic_names == {"tool1", "tool2", "tool3"} + + +class TestToolNameResolution: + """Tests for tool name resolution (handling Anthropic normalized names).""" + + def test_resolve_original_name(self): + """Test that original names resolve correctly.""" + registry = EvalSuiteToolRegistry() + registry.add_tool({"name": "Google.Search"}) + + assert registry.resolve_tool_name("Google.Search") == "Google.Search" + + def test_resolve_normalized_name(self): + """Test that normalized names (underscores) resolve to original.""" + registry = EvalSuiteToolRegistry() + registry.add_tool({"name": "Google.Search"}) + + # Anthropic returns "Google_Search" but tool is stored as "Google.Search" + assert registry.resolve_tool_name("Google_Search") == "Google.Search" + + def test_resolve_unknown_name_returns_none(self): + """Test that unknown names return None.""" + registry = EvalSuiteToolRegistry() + registry.add_tool({"name": "Google.Search"}) + + assert registry.resolve_tool_name("Unknown.Tool") is None + assert registry.resolve_tool_name("Unknown_Tool") is None + + def test_has_tool_with_normalized_name(self): + """Test has_tool works with normalized names.""" + registry = EvalSuiteToolRegistry() + registry.add_tool({"name": "Slack.Post"}) + + assert registry.has_tool("Slack.Post") is True + assert registry.has_tool("Slack_Post") is True # Normalized + assert registry.has_tool("Unknown") is False + + def test_get_tool_schema_with_normalized_name(self): + """Test get_tool_schema works with normalized names.""" + registry = EvalSuiteToolRegistry() + registry.add_tool({ + "name": "Email.Send", + "description": "Send email", + "inputSchema": {"type": "object", "properties": {"to": {"type": "string"}}}, + }) + + # Original name + schema = registry.get_tool_schema("Email.Send") + assert schema is not None + assert schema["name"] == "Email.Send" + + # Normalized name + schema = registry.get_tool_schema("Email_Send") + assert schema is not None + assert schema["name"] == "Email.Send" + + def test_normalize_args_with_normalized_tool_name(self): + """Test normalize_args works when called with normalized name.""" + registry = EvalSuiteToolRegistry() + registry.add_tool({ + "name": "Calendar.Create", + "inputSchema": { + "type": "object", + "properties": { + "title": {"type": "string"}, + "duration": {"type": "integer", "default": 30}, + }, + }, + }) + + # Call normalize_args with the Anthropic-returned name + result = registry.normalize_args("Calendar_Create", {"title": "Meeting"}) + + # Should apply defaults even though lookup was by normalized name + assert result["title"] == "Meeting" + assert result["duration"] == 30 + + def test_normalize_args_replaces_null_with_default(self): + """Test normalize_args replaces null (None) values with defaults. 
+ + OpenAI strict mode sends null for optional parameters that weren't provided. + This test verifies that null values are replaced with schema defaults. + """ + registry = EvalSuiteToolRegistry() + registry.add_tool({ + "name": "Search", + "inputSchema": { + "type": "object", + "properties": { + "query": {"type": "string"}, + "limit": {"type": "integer", "default": 10}, + "offset": {"type": "integer", "default": 0}, + }, + }, + }) + + # OpenAI strict mode might send null for optional params + result = registry.normalize_args("Search", {"query": "test", "limit": None, "offset": None}) + + # Null values should be replaced with defaults + assert result["query"] == "test" + assert result["limit"] == 10 + assert result["offset"] == 0 + + def test_normalize_args_preserves_explicit_values(self): + """Test normalize_args preserves explicitly set values (non-null).""" + registry = EvalSuiteToolRegistry() + registry.add_tool({ + "name": "Search", + "inputSchema": { + "type": "object", + "properties": { + "query": {"type": "string"}, + "limit": {"type": "integer", "default": 10}, + }, + }, + }) + + # Explicit value should be preserved + result = registry.normalize_args("Search", {"query": "test", "limit": 50}) + + assert result["query"] == "test" + assert result["limit"] == 50 # Not replaced with default + + def test_multiple_dots_in_name(self): + """Test tools with multiple dots in name.""" + registry = EvalSuiteToolRegistry() + registry.add_tool({"name": "Google.Gmail.Send"}) + + # Should normalize all dots + assert registry.resolve_tool_name("Google_Gmail_Send") == "Google.Gmail.Send" + assert registry.has_tool("Google_Gmail_Send") is True + + def test_no_dot_in_name_no_mapping(self): + """Test that tools without dots don't create unnecessary mappings.""" + registry = EvalSuiteToolRegistry() + registry.add_tool({"name": "simple_tool"}) + + # Direct lookup works + assert registry.resolve_tool_name("simple_tool") == "simple_tool" + # No false positives + assert registry.resolve_tool_name("simple.tool") is None + + def test_mixed_tools_resolution(self): + """Test registry with mix of dotted and non-dotted names.""" + registry = EvalSuiteToolRegistry() + registry.add_tools([ + {"name": "Google.Search"}, + {"name": "simple_search"}, + {"name": "Slack.Channel.Create"}, + ]) + + # All originals resolve + assert registry.resolve_tool_name("Google.Search") == "Google.Search" + assert registry.resolve_tool_name("simple_search") == "simple_search" + assert registry.resolve_tool_name("Slack.Channel.Create") == "Slack.Channel.Create" + + # Normalized versions resolve to originals + assert registry.resolve_tool_name("Google_Search") == "Google.Search" + assert registry.resolve_tool_name("Slack_Channel_Create") == "Slack.Channel.Create" + + +class TestProcessToolCall: + """Tests for EvalSuiteToolRegistry.process_tool_call combined method.""" + + def test_process_tool_call_resolves_and_normalizes(self): + """Test that process_tool_call resolves name and applies defaults.""" + registry = EvalSuiteToolRegistry() + registry.add_tool({ + "name": "Google.Search", + "description": "Search", + "inputSchema": { + "type": "object", + "properties": { + "query": {"type": "string"}, + "limit": {"type": "integer", "default": 10}, + }, + }, + }) + + # Anthropic-style name with missing default arg + resolved_name, args = registry.process_tool_call("Google_Search", {"query": "test"}) + + assert resolved_name == "Google.Search" + assert args == {"query": "test", "limit": 10} + + def test_process_tool_call_unknown_tool(self): + 
"""Test that unknown tools keep original name.""" + registry = EvalSuiteToolRegistry() + registry.add_tool({"name": "KnownTool"}) + + resolved_name, args = registry.process_tool_call("UnknownTool", {"arg": "value"}) + + assert resolved_name == "UnknownTool" + assert args == {"arg": "value"} + + def test_process_tool_call_no_defaults_needed(self): + """Test when all args provided.""" + registry = EvalSuiteToolRegistry() + registry.add_tool({ + "name": "Tool", + "inputSchema": { + "type": "object", + "properties": {"a": {"type": "string", "default": "x"}, "b": {"type": "string"}}, + }, + }) + + resolved_name, args = registry.process_tool_call("Tool", {"a": "provided", "b": "also"}) + + assert resolved_name == "Tool" + assert args == {"a": "provided", "b": "also"} + + +class TestToolRegistryErrors: + """Tests for EvalSuiteToolRegistry error handling.""" + + def test_duplicate_tool_registration_raises_error(self): + """Test that registering the same tool twice raises ValueError.""" + registry = EvalSuiteToolRegistry() + registry.add_tool({"name": "Google.Search", "description": "Search"}) + + with pytest.raises(ValueError) as exc_info: + registry.add_tool({"name": "Google.Search", "description": "Search again"}) + + assert "already registered" in str(exc_info.value) + assert "Google.Search" in str(exc_info.value) + + def test_tool_without_name_raises_error(self): + """Test that registering a tool without name raises ValueError.""" + registry = EvalSuiteToolRegistry() + + with pytest.raises(ValueError) as exc_info: + registry.add_tool({"description": "No name tool"}) + + assert "name" in str(exc_info.value).lower() + + def test_empty_registry_tool_count(self): + """Test that empty registry has zero tools.""" + registry = EvalSuiteToolRegistry() + assert registry.tool_count() == 0 + assert registry.tool_names() == [] + + def test_empty_registry_list_tools(self): + """Test that empty registry returns empty list for both formats.""" + registry = EvalSuiteToolRegistry() + assert registry.list_tools_for_model("openai") == [] + assert registry.list_tools_for_model("anthropic") == [] + + def test_invalid_format_raises_error(self): + """Test that invalid tool format raises ValueError.""" + registry = EvalSuiteToolRegistry() + registry.add_tool({"name": "test"}) + + with pytest.raises(ValueError) as exc_info: + registry.list_tools_for_model("invalid_format") # type: ignore + + assert "not supported" in str(exc_info.value) + + +class TestToolNameCollisions: + """Tests for handling tool name collisions during normalization.""" + + def test_different_original_names_same_normalized(self): + """Test that tools with different original names but same normalized name are both registered. + + This is the expected behavior: `Google.Search` and `Google_Search` are treated as + different tools because the registry uses original names as keys. + The normalized name mapping only helps with lookup (for Anthropic format). + """ + registry = EvalSuiteToolRegistry() + registry.add_tool({"name": "Google.Search", "description": "Dot version"}) + registry.add_tool({"name": "Google_Search", "description": "Underscore version"}) + + # Both tools should be registered + assert registry.tool_count() == 2 + assert "Google.Search" in registry.tool_names() + assert "Google_Search" in registry.tool_names() + + def test_normalized_name_resolution_prefers_underscore_version(self): + """Test that when both Google.Search and Google_Search exist, + resolving 'Google_Search' returns the explicit underscore version. 
+ """ + registry = EvalSuiteToolRegistry() + registry.add_tool({"name": "Google.Search", "description": "Dot version"}) + registry.add_tool({"name": "Google_Search", "description": "Underscore version"}) + + # "Google_Search" should resolve to itself (explicit match) + resolved = registry.resolve_tool_name("Google_Search") + assert resolved == "Google_Search" + + # "Google.Search" should resolve to itself (exact match) + resolved = registry.resolve_tool_name("Google.Search") + assert resolved == "Google.Search" + + def test_normalized_lookup_when_only_dot_version_exists(self): + """Test that normalized name lookup works when only dot version exists.""" + registry = EvalSuiteToolRegistry() + registry.add_tool({"name": "Google.Search", "description": "Dot version"}) + + # "Google_Search" should resolve to "Google.Search" + resolved = registry.resolve_tool_name("Google_Search") + assert resolved == "Google.Search" + + def test_anthropic_format_normalizes_names(self): + """Test that Anthropic format output uses normalized names (underscores).""" + registry = EvalSuiteToolRegistry() + registry.add_tool({"name": "Google.Search", "description": "Search"}) + + tools = registry.list_tools_for_model("anthropic") + + # Anthropic format should have normalized name + assert tools[0]["name"] == "Google_Search" diff --git a/libs/tests/arcade_evals/test_tracks.py b/libs/tests/arcade_evals/test_tracks.py new file mode 100644 index 000000000..2d95b5d0d --- /dev/null +++ b/libs/tests/arcade_evals/test_tracks.py @@ -0,0 +1,126 @@ +"""Tests for track management in comparative evaluations.""" + +import pytest +from arcade_evals._evalsuite._tool_registry import EvalSuiteToolRegistry +from arcade_evals._evalsuite._tracks import TrackManager + +# Mark all tests in this module as requiring evals dependencies +pytestmark = pytest.mark.evals + + +class TestTrackManager: + """Tests for TrackManager class.""" + + def test_create_track(self) -> None: + """Test creating a new track.""" + manager = TrackManager() + registry = EvalSuiteToolRegistry() + + track_name = manager.create_track("Test Track", registry) + + assert track_name == "Test Track" + assert manager.has_track("Test Track") + assert manager.track_count() == 1 + + def test_create_duplicate_track_raises(self) -> None: + """Test that creating a duplicate track raises ValueError.""" + manager = TrackManager() + registry1 = EvalSuiteToolRegistry() + registry2 = EvalSuiteToolRegistry() + + manager.create_track("Track1", registry1) + + with pytest.raises(ValueError, match="already exists"): + manager.create_track("Track1", registry2) + + def test_get_registry(self) -> None: + """Test getting a registry by track name.""" + manager = TrackManager() + registry = EvalSuiteToolRegistry() + registry.add_tool({"name": "TestTool", "description": "Test"}) + + manager.create_track("MyTrack", registry) + retrieved = manager.get_registry("MyTrack") + + assert retrieved is registry + assert retrieved.has_tool("TestTool") + + def test_get_registry_nonexistent(self) -> None: + """Test getting a nonexistent registry returns None.""" + manager = TrackManager() + + result = manager.get_registry("NonexistentTrack") + + assert result is None + + def test_get_track_names(self) -> None: + """Test getting all track names.""" + manager = TrackManager() + manager.create_track("Track1", EvalSuiteToolRegistry()) + manager.create_track("Track2", EvalSuiteToolRegistry()) + manager.create_track("Track3", EvalSuiteToolRegistry()) + + names = manager.get_track_names() + + assert names == 
["Track1", "Track2", "Track3"] + + def test_get_track_names_empty(self) -> None: + """Test getting track names when empty.""" + manager = TrackManager() + + names = manager.get_track_names() + + assert names == [] + + def test_has_track(self) -> None: + """Test checking if track exists.""" + manager = TrackManager() + manager.create_track("Exists", EvalSuiteToolRegistry()) + + assert manager.has_track("Exists") is True + assert manager.has_track("DoesNotExist") is False + + def test_track_count(self) -> None: + """Test counting tracks.""" + manager = TrackManager() + + assert manager.track_count() == 0 + + manager.create_track("Track1", EvalSuiteToolRegistry()) + assert manager.track_count() == 1 + + manager.create_track("Track2", EvalSuiteToolRegistry()) + assert manager.track_count() == 2 + + def test_get_all_registries(self) -> None: + """Test getting all registries.""" + manager = TrackManager() + reg1 = EvalSuiteToolRegistry() + reg2 = EvalSuiteToolRegistry() + + manager.create_track("Track1", reg1) + manager.create_track("Track2", reg2) + + all_regs = manager.get_all_registries() + + assert len(all_regs) == 2 + assert all_regs["Track1"] is reg1 + assert all_regs["Track2"] is reg2 + + def test_registries_are_isolated(self) -> None: + """Test that each track has its own isolated registry.""" + manager = TrackManager() + reg1 = EvalSuiteToolRegistry() + reg2 = EvalSuiteToolRegistry() + + reg1.add_tool({"name": "Tool1", "description": "Tool for track 1"}) + reg2.add_tool({"name": "Tool2", "description": "Tool for track 2"}) + + manager.create_track("Track1", reg1) + manager.create_track("Track2", reg2) + + # Each registry only has its own tool + assert manager.get_registry("Track1").has_tool("Tool1") + assert not manager.get_registry("Track1").has_tool("Tool2") + assert manager.get_registry("Track2").has_tool("Tool2") + assert not manager.get_registry("Track2").has_tool("Tool1") diff --git a/libs/tests/arcade_evals/test_types.py b/libs/tests/arcade_evals/test_types.py new file mode 100644 index 000000000..cde8054bb --- /dev/null +++ b/libs/tests/arcade_evals/test_types.py @@ -0,0 +1,391 @@ +"""Tests for shared types in _types.py module.""" + +import pytest +from arcade_evals._evalsuite._types import ( + AnyExpectedToolCall, + ComparativeCase, + EvalRubric, + ExpectedMCPToolCall, + ExpectedToolCall, + NamedExpectedToolCall, + TrackConfig, +) + +# Mark all tests in this module as requiring evals dependencies +pytestmark = pytest.mark.evals + + +class TestExpectedToolCall: + """Tests for ExpectedToolCall dataclass.""" + + def test_create_with_func_and_args(self) -> None: + """Test creating ExpectedToolCall with function and args.""" + + def my_tool(x: int, y: int) -> int: + return x + y + + tc = ExpectedToolCall(func=my_tool, args={"x": 1, "y": 2}) + + assert tc.func is my_tool + assert tc.args == {"x": 1, "y": 2} + + def test_create_with_empty_args(self) -> None: + """Test creating ExpectedToolCall with default empty args.""" + + def my_tool() -> None: + pass + + tc = ExpectedToolCall(func=my_tool) + + assert tc.func is my_tool + assert tc.args == {} + + def test_create_positional_args(self) -> None: + """Test creating ExpectedToolCall with positional args.""" + + def my_tool(x: int) -> int: + return x + + tc = ExpectedToolCall(my_tool, {"x": 5}) + + assert tc.func is my_tool + assert tc.args == {"x": 5} + + +class TestExpectedMCPToolCall: + """Tests for ExpectedMCPToolCall dataclass.""" + + def test_create_with_name_and_args(self) -> None: + """Test creating ExpectedMCPToolCall with name 
and args.""" + tc = ExpectedMCPToolCall(tool_name="Calculator_Add", args={"a": 5, "b": 3}) + + assert tc.tool_name == "Calculator_Add" + assert tc.args == {"a": 5, "b": 3} + + def test_create_with_empty_args(self) -> None: + """Test creating ExpectedMCPToolCall with default empty args.""" + tc = ExpectedMCPToolCall(tool_name="GetTime") + + assert tc.tool_name == "GetTime" + assert tc.args == {} + + def test_create_positional_args(self) -> None: + """Test creating ExpectedMCPToolCall with positional args.""" + tc = ExpectedMCPToolCall("Weather_Get", {"city": "NYC"}) + + assert tc.tool_name == "Weather_Get" + assert tc.args == {"city": "NYC"} + + +class TestNamedExpectedToolCall: + """Tests for NamedExpectedToolCall dataclass.""" + + def test_create(self) -> None: + """Test creating NamedExpectedToolCall.""" + tc = NamedExpectedToolCall(name="MyTool", args={"param": "value"}) + + assert tc.name == "MyTool" + assert tc.args == {"param": "value"} + + def test_create_empty_args(self) -> None: + """Test creating NamedExpectedToolCall with empty args.""" + tc = NamedExpectedToolCall(name="SimpleTool", args={}) + + assert tc.name == "SimpleTool" + assert tc.args == {} + + +class TestAnyExpectedToolCallTypeAlias: + """Tests for AnyExpectedToolCall type alias.""" + + def test_type_alias_accepts_expected_tool_call(self) -> None: + """Test that ExpectedToolCall is valid for AnyExpectedToolCall.""" + + def my_func() -> None: + pass + + tc: AnyExpectedToolCall = ExpectedToolCall(func=my_func) + assert isinstance(tc, ExpectedToolCall) + + def test_type_alias_accepts_expected_mcp_tool_call(self) -> None: + """Test that ExpectedMCPToolCall is valid for AnyExpectedToolCall.""" + tc: AnyExpectedToolCall = ExpectedMCPToolCall(tool_name="Test") + assert isinstance(tc, ExpectedMCPToolCall) + + +class TestEvalRubric: + """Tests for EvalRubric dataclass.""" + + def test_default_values(self) -> None: + """Test EvalRubric has correct default values.""" + rubric = EvalRubric() + + assert rubric.fail_threshold == 0.8 + assert rubric.warn_threshold == 0.9 + assert rubric.fail_on_tool_selection is True + assert rubric.fail_on_tool_call_quantity is True + assert rubric.tool_selection_weight == 1.0 + + def test_custom_values(self) -> None: + """Test EvalRubric with custom values.""" + rubric = EvalRubric( + fail_threshold=0.7, + warn_threshold=0.85, + fail_on_tool_selection=False, + fail_on_tool_call_quantity=False, + tool_selection_weight=0.5, + ) + + assert rubric.fail_threshold == 0.7 + assert rubric.warn_threshold == 0.85 + assert rubric.fail_on_tool_selection is False + assert rubric.fail_on_tool_call_quantity is False + assert rubric.tool_selection_weight == 0.5 + + def test_str_representation(self) -> None: + """Test EvalRubric __str__ method.""" + rubric = EvalRubric() + + result = str(rubric) + + assert "EvalRubric(" in result + assert "fail_threshold=0.8" in result + assert "warn_threshold=0.9" in result + assert "fail_on_tool_selection=True" in result + assert "fail_on_tool_call_quantity=True" in result + assert "tool_selection_weight=1.0" in result + + def test_repr_representation(self) -> None: + """Test EvalRubric __repr__ method returns same as __str__.""" + rubric = EvalRubric(fail_threshold=0.75) + + assert repr(rubric) == str(rubric) + + def test_str_with_custom_values(self) -> None: + """Test __str__ reflects custom values.""" + rubric = EvalRubric( + fail_threshold=0.5, + warn_threshold=0.6, + fail_on_tool_selection=False, + fail_on_tool_call_quantity=False, + tool_selection_weight=2.0, + ) + + result 
= str(rubric) + + assert "fail_threshold=0.5" in result + assert "warn_threshold=0.6" in result + assert "fail_on_tool_selection=False" in result + assert "fail_on_tool_call_quantity=False" in result + assert "tool_selection_weight=2.0" in result + + +class TestTrackConfigFromTypes: + """Tests for TrackConfig dataclass from _types module.""" + + def test_create_with_expected_tool_calls(self) -> None: + """Test creating TrackConfig with expected tool calls.""" + + expected: list[ExpectedToolCall | ExpectedMCPToolCall] = [ + ExpectedMCPToolCall("Tool1", {"arg": "val"}) + ] + config = TrackConfig(expected_tool_calls=expected) + + assert config.expected_tool_calls == expected + assert config.critics == [] + + def test_create_with_critics(self) -> None: + """Test creating TrackConfig with critics.""" + from arcade_evals.critic import Critic, NoneCritic + + expected: list[ExpectedToolCall | ExpectedMCPToolCall] = [ExpectedMCPToolCall("Tool1")] + critics: list[Critic] = [NoneCritic(critic_field="field1")] + config = TrackConfig(expected_tool_calls=expected, critics=critics) + + assert config.expected_tool_calls == expected + assert config.critics == critics + + def test_mixed_expected_tool_calls(self) -> None: + """Test TrackConfig with mixed ExpectedToolCall and ExpectedMCPToolCall.""" + + def my_func() -> None: + pass + + expected: list[ExpectedToolCall | ExpectedMCPToolCall] = [ + ExpectedToolCall(func=my_func, args={"x": 1}), + ExpectedMCPToolCall(tool_name="MCPTool", args={"y": 2}), + ] + config = TrackConfig(expected_tool_calls=expected) + + assert len(config.expected_tool_calls) == 2 + assert isinstance(config.expected_tool_calls[0], ExpectedToolCall) + assert isinstance(config.expected_tool_calls[1], ExpectedMCPToolCall) + + +class TestComparativeCaseFromTypes: + """Tests for ComparativeCase dataclass from _types module.""" + + def test_default_values(self) -> None: + """Test ComparativeCase default values.""" + case = ComparativeCase( + name="test", + user_message="Hello", + ) + + assert case.name == "test" + assert case.user_message == "Hello" + assert case.system_message == "" + assert case.additional_messages == [] + assert case.rubric is None + assert case.track_configs == {} + + def test_with_rubric(self) -> None: + """Test ComparativeCase with custom rubric.""" + rubric = EvalRubric(fail_threshold=0.9) + case = ComparativeCase( + name="test", + user_message="Hello", + rubric=rubric, + ) + + assert case.rubric is rubric + + def test_add_track_config(self) -> None: + """Test adding track configuration.""" + case = ComparativeCase(name="test", user_message="Hello") + expected: list[ExpectedToolCall | ExpectedMCPToolCall] = [ + ExpectedMCPToolCall("Tool1", {"arg": "val"}) + ] + + case.add_track_config("Track1", expected) + + assert "Track1" in case.track_configs + assert case.track_configs["Track1"].expected_tool_calls == expected + + def test_add_track_config_with_critics(self) -> None: + """Test adding track config with critics.""" + from arcade_evals.critic import Critic, NoneCritic + + case = ComparativeCase(name="test", user_message="Hello") + expected: list[ExpectedToolCall | ExpectedMCPToolCall] = [ExpectedMCPToolCall("Tool1")] + critics: list[Critic] = [NoneCritic(critic_field="field")] + + case.add_track_config("Track1", expected, critics=critics) + + assert case.track_configs["Track1"].critics == critics + + def test_add_duplicate_track_raises(self) -> None: + """Test adding duplicate track config raises ValueError.""" + case = ComparativeCase(name="test", 
user_message="Hello") + expected: list[ExpectedToolCall | ExpectedMCPToolCall] = [ExpectedMCPToolCall("Tool1")] + + case.add_track_config("Track1", expected) + + with pytest.raises(ValueError, match="already configured"): + case.add_track_config("Track1", expected) + + def test_get_configured_tracks(self) -> None: + """Test getting list of configured tracks.""" + case = ComparativeCase(name="test", user_message="Hello") + + assert case.get_configured_tracks() == [] + + track1_calls: list[ExpectedToolCall | ExpectedMCPToolCall] = [ExpectedMCPToolCall("T1")] + track2_calls: list[ExpectedToolCall | ExpectedMCPToolCall] = [ExpectedMCPToolCall("T2")] + case.add_track_config("Track1", track1_calls) + case.add_track_config("Track2", track2_calls) + + tracks = case.get_configured_tracks() + + assert "Track1" in tracks + assert "Track2" in tracks + assert len(tracks) == 2 + + +class TestEvalSuiteCreateEvalCase: + """Tests for EvalSuite._create_eval_case factory method.""" + + def test_create_eval_case_basic(self) -> None: + """Test creating EvalCase via factory method.""" + from arcade_evals import EvalSuite + from arcade_evals._evalsuite._types import EvalRubric, NamedExpectedToolCall + + suite = EvalSuite(name="Test", system_message="System") + + case = suite._create_eval_case( + name="test_case", + system_message="Custom system", + user_message="Hello", + expected_tool_calls=[NamedExpectedToolCall(name="Tool1", args={"x": 1})], + rubric=EvalRubric(), + critics=[], + additional_messages=[], + ) + + assert case.name == "test_case" + assert case.system_message == "Custom system" + assert case.user_message == "Hello" + assert len(case.expected_tool_calls) == 1 + assert case.expected_tool_calls[0].name == "Tool1" + + def test_create_eval_case_with_critics(self) -> None: + """Test creating EvalCase with critics.""" + from arcade_evals import EvalSuite + from arcade_evals._evalsuite._types import EvalRubric, NamedExpectedToolCall + from arcade_evals.critic import Critic, SimilarityCritic + + suite = EvalSuite(name="Test", system_message="System") + critics: list[Critic] = [SimilarityCritic(critic_field="query", weight=1.0)] + + case = suite._create_eval_case( + name="test_case", + system_message="System", + user_message="Query", + expected_tool_calls=[NamedExpectedToolCall(name="Search", args={"query": "test"})], + rubric=EvalRubric(), + critics=critics, + additional_messages=[], + ) + + assert case.critics == critics + + def test_create_eval_case_with_additional_messages(self) -> None: + """Test creating EvalCase with additional messages.""" + from arcade_evals import EvalSuite + from arcade_evals._evalsuite._types import EvalRubric + + suite = EvalSuite(name="Test", system_message="System") + additional = [{"role": "assistant", "content": "Previous response"}] + + case = suite._create_eval_case( + name="test_case", + system_message="System", + user_message="Follow-up", + expected_tool_calls=[], + rubric=EvalRubric(), + critics=[], + additional_messages=additional, + ) + + assert case.additional_messages == additional + + def test_create_eval_case_with_custom_rubric(self) -> None: + """Test creating EvalCase with custom rubric.""" + from arcade_evals import EvalSuite + from arcade_evals._evalsuite._types import EvalRubric + + suite = EvalSuite(name="Test", system_message="System") + rubric = EvalRubric(fail_threshold=0.95, warn_threshold=0.98) + + case = suite._create_eval_case( + name="test_case", + system_message="System", + user_message="Test", + expected_tool_calls=[], + rubric=rubric, + 
critics=[], + additional_messages=[], + ) + + assert case.rubric.fail_threshold == 0.95 + assert case.rubric.warn_threshold == 0.98 diff --git a/libs/tests/arcade_mcp_server/test_context.py b/libs/tests/arcade_mcp_server/test_context.py index 39a4c7345..550e4b239 100644 --- a/libs/tests/arcade_mcp_server/test_context.py +++ b/libs/tests/arcade_mcp_server/test_context.py @@ -8,7 +8,6 @@ from arcade_mcp_server.context import get_current_model_context as get_current_context from arcade_mcp_server.context import set_current_model_context as set_current_context from arcade_mcp_server.types import ( - MCPTool, ModelHint, ModelPreferences, ) diff --git a/libs/tests/arcade_mcp_server/test_mcp_app.py b/libs/tests/arcade_mcp_server/test_mcp_app.py index 655ed78dd..97f61f67f 100644 --- a/libs/tests/arcade_mcp_server/test_mcp_app.py +++ b/libs/tests/arcade_mcp_server/test_mcp_app.py @@ -3,7 +3,7 @@ import subprocess import sys from typing import Annotated -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import Mock, patch import pytest from arcade_core.catalog import MaterializedTool diff --git a/libs/tests/arcade_mcp_server/test_openapi_docs.py b/libs/tests/arcade_mcp_server/test_openapi_docs.py index eb5046129..cc5700727 100644 --- a/libs/tests/arcade_mcp_server/test_openapi_docs.py +++ b/libs/tests/arcade_mcp_server/test_openapi_docs.py @@ -1,6 +1,5 @@ """Test that MCP routes appear in OpenAPI documentation.""" -import pytest from arcade_core import ToolCatalog from arcade_core.toolkit import Toolkit from arcade_mcp_server.settings import MCPSettings @@ -73,7 +72,6 @@ def test_mcp_routes_in_openapi(monkeypatch): # Verify the actual proxy is mounted (not routes) # The OpenAPI docs should exist but not interfere with the mount - import inspect mounts = [route for route in app.routes if hasattr(route, "app") and hasattr(route, "path")] mcp_mounts = [m for m in mounts if m.path == "/mcp"] diff --git a/libs/tests/arcade_mcp_server/test_settings.py b/libs/tests/arcade_mcp_server/test_settings.py index d47dc4912..cd74b7b80 100644 --- a/libs/tests/arcade_mcp_server/test_settings.py +++ b/libs/tests/arcade_mcp_server/test_settings.py @@ -1,6 +1,5 @@ """Tests for MCP Settings.""" -import pytest from arcade_mcp_server.settings import MCPSettings, ServerSettings diff --git a/libs/tests/arcade_mcp_server/transports/test_http_session_manager.py b/libs/tests/arcade_mcp_server/transports/test_http_session_manager.py index 61f5bc812..82cd859df 100644 --- a/libs/tests/arcade_mcp_server/transports/test_http_session_manager.py +++ b/libs/tests/arcade_mcp_server/transports/test_http_session_manager.py @@ -1,12 +1,10 @@ -from http import HTTPStatus -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, patch import pytest from arcade_mcp_server.transports.http_session_manager import ( MCP_SESSION_ID_HEADER, HTTPSessionManager, ) -from arcade_mcp_server.transports.http_streamable import HTTPStreamableTransport class TestHTTPSessionManager: diff --git a/libs/tests/arcade_mcp_server/transports/test_http_streamable.py b/libs/tests/arcade_mcp_server/transports/test_http_streamable.py index 37d8ca8a7..4d09f4a11 100644 --- a/libs/tests/arcade_mcp_server/transports/test_http_streamable.py +++ b/libs/tests/arcade_mcp_server/transports/test_http_streamable.py @@ -1,4 +1,3 @@ -import json from unittest.mock import AsyncMock, MagicMock, patch import pytest diff --git a/libs/tests/cli/deploy/test_deploy.py b/libs/tests/cli/deploy/test_deploy.py index 0a6768bc6..ff7bf3638 
100644 --- a/libs/tests/cli/deploy/test_deploy.py +++ b/libs/tests/cli/deploy/test_deploy.py @@ -2,7 +2,6 @@ import io import subprocess import tarfile -import time from pathlib import Path import pytest diff --git a/libs/tests/cli/test_capture_formatters.py b/libs/tests/cli/test_capture_formatters.py new file mode 100644 index 000000000..832ab8e5a --- /dev/null +++ b/libs/tests/cli/test_capture_formatters.py @@ -0,0 +1,926 @@ +"""Tests for capture mode formatters.""" + +from __future__ import annotations + +import json +from typing import TYPE_CHECKING +from unittest.mock import MagicMock + +import pytest +from arcade_cli.formatters import ( + CAPTURE_FORMATTERS, + CaptureHtmlFormatter, + CaptureJsonFormatter, + CaptureMarkdownFormatter, + CaptureTextFormatter, + get_capture_formatter, +) + +if TYPE_CHECKING: + from arcade_evals import CaptureResult + + +def _create_mock_capture_result( + suite_name: str = "TestSuite", + model: str = "gpt-4o", + provider: str = "openai", + cases: list[dict] | None = None, +) -> CaptureResult: + """Create a mock CaptureResult for testing.""" + if cases is None: + cases = [ + { + "case_name": "test_case_1", + "user_message": "What's the weather?", + "tool_calls": [ + {"name": "GetWeather", "args": {"city": "NYC", "units": "celsius"}}, + ], + "system_message": "You are helpful", + "additional_messages": [{"role": "user", "content": "Hi"}], + } + ] + + # Create mock capture result + capture = MagicMock() + capture.suite_name = suite_name + capture.model = model + capture.provider = provider + + # Create mock captured cases + captured_cases = [] + for case_data in cases: + case = MagicMock() + case.case_name = case_data["case_name"] + case.user_message = case_data["user_message"] + case.system_message = case_data.get("system_message") + case.additional_messages = case_data.get("additional_messages", []) + # Explicitly set track_name to None unless specified (avoids MagicMock) + case.track_name = case_data.get("track_name") + + # Create mock tool calls + tool_calls = [] + for tc_data in case_data.get("tool_calls", []): + tc = MagicMock() + tc.name = tc_data["name"] + tc.args = tc_data.get("args", {}) + tool_calls.append(tc) + case.tool_calls = tool_calls + + captured_cases.append(case) + + capture.captured_cases = captured_cases + + # Mock to_dict method + def to_dict(include_context: bool = False) -> dict: + result = { + "suite_name": capture.suite_name, + "model": capture.model, + "provider": capture.provider, + "captured_cases": [], + } + for case in captured_cases: + case_dict = { + "case_name": case.case_name, + "user_message": case.user_message, + "tool_calls": [{"name": tc.name, "args": tc.args} for tc in case.tool_calls], + } + if include_context: + case_dict["system_message"] = case.system_message + case_dict["additional_messages"] = case.additional_messages + result["captured_cases"].append(case_dict) + return result + + capture.to_dict = to_dict + + return capture + + +class TestGetCaptureFormatter: + """Tests for get_capture_formatter function.""" + + def test_get_json_formatter(self) -> None: + """Test getting JSON formatter.""" + formatter = get_capture_formatter("json") + assert isinstance(formatter, CaptureJsonFormatter) + + def test_get_txt_formatter(self) -> None: + """Test getting text formatter.""" + formatter = get_capture_formatter("txt") + assert isinstance(formatter, CaptureTextFormatter) + + def test_get_md_formatter(self) -> None: + """Test getting markdown formatter.""" + formatter = get_capture_formatter("md") + assert 
isinstance(formatter, CaptureMarkdownFormatter) + + def test_get_html_formatter(self) -> None: + """Test getting HTML formatter.""" + formatter = get_capture_formatter("html") + assert isinstance(formatter, CaptureHtmlFormatter) + + def test_case_insensitive(self) -> None: + """Test that format names are case insensitive.""" + assert isinstance(get_capture_formatter("JSON"), CaptureJsonFormatter) + assert isinstance(get_capture_formatter("TXT"), CaptureTextFormatter) + assert isinstance(get_capture_formatter("MD"), CaptureMarkdownFormatter) + assert isinstance(get_capture_formatter("HTML"), CaptureHtmlFormatter) + + def test_unsupported_format_raises(self) -> None: + """Test that unsupported formats raise ValueError.""" + with pytest.raises(ValueError, match="Unsupported capture format 'xlsx'"): + get_capture_formatter("xlsx") + + def test_close_match_suggestion(self) -> None: + """Test that close matches are suggested.""" + with pytest.raises(ValueError, match="Did you mean 'json'"): + get_capture_formatter("jsn") + + +class TestCaptureJsonFormatter: + """Tests for CaptureJsonFormatter.""" + + def test_file_extension(self) -> None: + """Test file extension is json.""" + formatter = CaptureJsonFormatter() + assert formatter.file_extension == "json" + + def test_format_basic(self) -> None: + """Test basic JSON formatting.""" + formatter = CaptureJsonFormatter() + capture = _create_mock_capture_result() + + output = formatter.format([capture]) + parsed = json.loads(output) + + assert "captures" in parsed + assert len(parsed["captures"]) == 1 + assert parsed["captures"][0]["suite_name"] == "TestSuite" + assert parsed["captures"][0]["model"] == "gpt-4o" + + def test_format_includes_tool_calls(self) -> None: + """Test that tool calls are included.""" + formatter = CaptureJsonFormatter() + capture = _create_mock_capture_result() + + output = formatter.format([capture]) + parsed = json.loads(output) + + case = parsed["captures"][0]["captured_cases"][0] + assert len(case["tool_calls"]) == 1 + assert case["tool_calls"][0]["name"] == "GetWeather" + assert case["tool_calls"][0]["args"]["city"] == "NYC" + + def test_format_with_context(self) -> None: + """Test formatting with context included.""" + formatter = CaptureJsonFormatter() + capture = _create_mock_capture_result() + + output = formatter.format([capture], include_context=True) + parsed = json.loads(output) + + case = parsed["captures"][0]["captured_cases"][0] + assert "system_message" in case + assert case["system_message"] == "You are helpful" + + def test_format_without_context(self) -> None: + """Test formatting without context (default).""" + formatter = CaptureJsonFormatter() + capture = _create_mock_capture_result() + + output = formatter.format([capture], include_context=False) + parsed = json.loads(output) + + case = parsed["captures"][0]["captured_cases"][0] + assert "system_message" not in case + + +class TestCaptureTextFormatter: + """Tests for CaptureTextFormatter.""" + + def test_file_extension(self) -> None: + """Test file extension is txt.""" + formatter = CaptureTextFormatter() + assert formatter.file_extension == "txt" + + def test_format_contains_suite_info(self) -> None: + """Test that suite info is in output.""" + formatter = CaptureTextFormatter() + capture = _create_mock_capture_result() + + output = formatter.format([capture]) + + assert "Suite: TestSuite" in output + assert "Model: gpt-4o" in output + assert "Provider: openai" in output + + def test_format_contains_case_info(self) -> None: + """Test that case info is 
in output.""" + formatter = CaptureTextFormatter() + capture = _create_mock_capture_result() + + output = formatter.format([capture]) + + assert "Case: test_case_1" in output + assert "User Message: What's the weather?" in output + + def test_format_contains_tool_calls(self) -> None: + """Test that tool calls are in output.""" + formatter = CaptureTextFormatter() + capture = _create_mock_capture_result() + + output = formatter.format([capture]) + + assert "GetWeather" in output + assert "city: NYC" in output + + def test_format_contains_summary(self) -> None: + """Test that summary is in output.""" + formatter = CaptureTextFormatter() + capture = _create_mock_capture_result() + + output = formatter.format([capture]) + + assert "Summary: 1 tool calls across 1 cases" in output + + def test_format_with_context(self) -> None: + """Test formatting with context.""" + formatter = CaptureTextFormatter() + capture = _create_mock_capture_result() + + output = formatter.format([capture], include_context=True) + + assert "System Message: You are helpful" in output + + +class TestCaptureMarkdownFormatter: + """Tests for CaptureMarkdownFormatter.""" + + def test_file_extension(self) -> None: + """Test file extension is md.""" + formatter = CaptureMarkdownFormatter() + assert formatter.file_extension == "md" + + def test_format_has_heading(self) -> None: + """Test that markdown has main heading.""" + formatter = CaptureMarkdownFormatter() + capture = _create_mock_capture_result() + + output = formatter.format([capture]) + + assert "# Capture Results" in output + + def test_format_has_suite_heading(self) -> None: + """Test that suite has heading.""" + formatter = CaptureMarkdownFormatter() + capture = _create_mock_capture_result() + + output = formatter.format([capture]) + + assert "## TestSuite" in output + + def test_format_has_case_heading(self) -> None: + """Test that case has heading.""" + formatter = CaptureMarkdownFormatter() + capture = _create_mock_capture_result() + + output = formatter.format([capture]) + + assert "### Case: test_case_1" in output + + def test_format_has_code_blocks(self) -> None: + """Test that tool args are in code blocks.""" + formatter = CaptureMarkdownFormatter() + capture = _create_mock_capture_result() + + output = formatter.format([capture]) + + assert "```json" in output + assert '"city": "NYC"' in output + assert "```" in output + + def test_format_has_summary(self) -> None: + """Test that summary is present.""" + formatter = CaptureMarkdownFormatter() + capture = _create_mock_capture_result() + + output = formatter.format([capture]) + + assert "## Summary" in output + assert "**Total Cases:** 1" in output + assert "**Total Tool Calls:** 1" in output + + +class TestCaptureHtmlFormatter: + """Tests for CaptureHtmlFormatter.""" + + def test_file_extension(self) -> None: + """Test file extension is html.""" + formatter = CaptureHtmlFormatter() + assert formatter.file_extension == "html" + + def test_format_is_valid_html(self) -> None: + """Test that output is valid HTML structure.""" + formatter = CaptureHtmlFormatter() + capture = _create_mock_capture_result() + + output = formatter.format([capture]) + + assert "" in output + assert "" in output + assert "" in output + assert "" in output + assert "" in output + assert "" in output + + def test_format_contains_styles(self) -> None: + """Test that CSS styles are included.""" + formatter = CaptureHtmlFormatter() + capture = _create_mock_capture_result() + + output = formatter.format([capture]) + + assert "" in output + + 
def test_format_contains_suite_info(self) -> None: + """Test that suite info is in output.""" + formatter = CaptureHtmlFormatter() + capture = _create_mock_capture_result() + + output = formatter.format([capture]) + + assert "TestSuite" in output + assert "gpt-4o" in output + + def test_format_contains_tool_calls(self) -> None: + """Test that tool calls are in output.""" + formatter = CaptureHtmlFormatter() + capture = _create_mock_capture_result() + + output = formatter.format([capture]) + + assert "GetWeather" in output + # Args should be HTML-escaped + assert "&quot;city&quot;" in output or '"city"' in output + + def test_format_escapes_html(self) -> None: + """Test that HTML special characters are escaped.""" + formatter = CaptureHtmlFormatter() + capture = _create_mock_capture_result( + cases=[ + { + "case_name": "Test ", + "suite_name": "Suite & Test", + "rubric": "Test", + "cases": [{ + "name": "", + "input": "test' OR '1'='1", + "evaluation": MockEvaluation( + passed=False, + score=0.0, + failure_reason="Error: ", + ), + }], + }]] + + formatter = HtmlFormatter() + output = formatter.format(results) + + # Should NOT contain raw script tags or other unescaped HTML + assert "", + "input": "Test bold", + "evaluation": MockEvaluation(passed=True, score=1.0), + } + ] + formatter = HtmlFormatter() + output = formatter.format(make_mock_results(cases=cases)) + + # Should escape < and > in case name + assert "&lt;script&gt;" in output + # Raw script tags should NOT be present (XSS prevention) + assert "