Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 35 additions & 2 deletions src/sktime_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from mcp.types import TextContent, Tool

from sktime_mcp.composition.validator import get_composition_validator
from sktime_mcp.tools.batch import run_tools_batch_tool
from sktime_mcp.tools.codegen import export_code_tool
from sktime_mcp.tools.data_tools import (
load_data_source_async_tool,
Expand Down Expand Up @@ -123,11 +124,11 @@ def sanitize_for_json(obj):
# --- Standard Python containers ---
if isinstance(obj, dict):
return {str(k): sanitize_for_json(v) for k, v in obj.items()}
if isinstance(obj, (list, tuple)):
if isinstance(obj, list | tuple):
return [sanitize_for_json(item) for item in obj]

# --- Already JSON-safe scalars ---
if isinstance(obj, (str, int, float, bool, type(None))):
if isinstance(obj, str | int | float | bool | type(None)):
return obj

# --- Fallback: objects with __dict__ or anything else ---
Expand Down Expand Up @@ -213,6 +214,35 @@ async def list_tools() -> list[Tool]:
),
inputSchema={"type": "object", "properties": {}},
),
Tool(
name="run_tools_batch",
description=(
"Run multiple read-only MCP tool calls in a single request to reduce "
"agent round-trips. Supported tools in MVP: list_estimators, "
"describe_estimator, get_available_tags, list_available_data."
),
inputSchema={
"type": "object",
"properties": {
"operations": {
"type": "array",
"description": (
"Ordered list of tool invocations. Each entry must contain "
"'tool' and optional 'arguments'."
),
"items": {
"type": "object",
"properties": {
"tool": {"type": "string"},
"arguments": {"type": "object"},
},
"required": ["tool"],
},
}
},
"required": ["operations"],
},
),
# -- Instantiation ---------------------------------------------------
Tool(
name="instantiate_estimator",
Expand Down Expand Up @@ -652,6 +682,9 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
elif name == "get_available_tags":
result = get_available_tags()

elif name == "run_tools_batch":
result = run_tools_batch_tool(arguments["operations"])

# -- Instantiation ---------------------------------------------------
elif name == "instantiate_estimator":
result = instantiate_estimator_tool(
Expand Down
2 changes: 2 additions & 0 deletions src/sktime_mcp/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tools module for sktime MCP."""

from sktime_mcp.tools.batch import run_tools_batch_tool
from sktime_mcp.tools.codegen import export_code_tool
from sktime_mcp.tools.data_tools import (
load_data_source_async_tool,
Expand Down Expand Up @@ -54,4 +55,5 @@
"check_job_status_tool",
"list_jobs_tool",
"cancel_job_tool",
"run_tools_batch_tool",
]
119 changes: 119 additions & 0 deletions src/sktime_mcp/tools/batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""
Batch tool execution helpers for sktime MCP.

The MVP intentionally supports only read-only tools so agent workflows can
reduce round-trips without mutating server state.
"""

from collections.abc import Callable
from typing import Any

from sktime_mcp.tools.describe_estimator import describe_estimator_tool
from sktime_mcp.tools.list_available_data import list_available_data_tool
from sktime_mcp.tools.list_estimators import get_available_tags, list_estimators_tool


def run_tools_batch_tool(operations: list[dict[str, Any]]) -> dict[str, Any]:
"""
Execute a list of read-only MCP operations in one call.

Args:
operations: List of operation dicts, each with:
- tool: tool name
- arguments: dict of arguments for that tool (optional)

Returns:
Dict with per-operation results and aggregate success.
"""
if not isinstance(operations, list) or len(operations) == 0:
return {
"success": False,
"error": "operations must be a non-empty list.",
}

dispatch: dict[str, Callable[..., dict[str, Any]]] = {
"list_estimators": list_estimators_tool,
"describe_estimator": describe_estimator_tool,
"get_available_tags": get_available_tags,
"list_available_data": list_available_data_tool,
}

results: list[dict[str, Any]] = []

for index, operation in enumerate(operations):
if not isinstance(operation, dict):
results.append(
{
"index": index,
"success": False,
"error": "Operation must be an object with 'tool' and optional 'arguments'.",
}
)
continue

tool_name = operation.get("tool")
arguments = operation.get("arguments", {})

if not isinstance(tool_name, str) or not tool_name.strip():
results.append(
{
"index": index,
"success": False,
"error": "Operation field 'tool' must be a non-empty string.",
}
)
continue

if not isinstance(arguments, dict):
results.append(
{
"index": index,
"tool": tool_name,
"success": False,
"error": "Operation field 'arguments' must be an object.",
}
)
continue

fn = dispatch.get(tool_name)
if fn is None:
results.append(
{
"index": index,
"tool": tool_name,
"success": False,
"error": (
"Unsupported tool for batch execution. "
"Allowed: list_estimators, describe_estimator, "
"get_available_tags, list_available_data."
),
}
)
continue

try:
item_result = fn(**arguments)
results.append(
{
"index": index,
"tool": tool_name,
"success": bool(item_result.get("success", True)),
"result": item_result,
}
)
except Exception as exc:
results.append(
{
"index": index,
"tool": tool_name,
"success": False,
"error": str(exc),
}
)

return {
"success": all(item.get("success", False) for item in results),
"count": len(results),
"results": results,
}

64 changes: 64 additions & 0 deletions tests/test_batch_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Tests for batch execution tool."""

import sys

sys.path.insert(0, "src")

import sktime_mcp.tools.batch as batch_tools


def test_run_tools_batch_success(monkeypatch):
"""run_tools_batch should execute supported read-only tools."""
monkeypatch.setattr(batch_tools, "get_available_tags", lambda: {"success": True, "tags": []})
monkeypatch.setattr(
batch_tools,
"list_estimators_tool",
lambda **kwargs: {"success": True, "estimators": [], "kwargs": kwargs},
)
monkeypatch.setattr(
batch_tools,
"list_available_data_tool",
lambda *args, **kwargs: {"success": True, "system_demos": ["airline"]},
)

result = batch_tools.run_tools_batch_tool(
[
{"tool": "get_available_tags"},
{"tool": "list_estimators", "arguments": {"task": "forecasting", "limit": 2}},
{"tool": "list_available_data"},
]
)

assert result["success"] is True
assert result["count"] == 3
assert all(item["success"] for item in result["results"])
assert result["results"][1]["result"]["success"] is True


def test_run_tools_batch_rejects_unsupported_tool():
"""Unsupported tools should fail with clear per-operation errors."""
result = batch_tools.run_tools_batch_tool(
[
{"tool": "release_handle", "arguments": {"handle": "est_123"}},
]
)

assert result["success"] is False
assert result["results"][0]["success"] is False
assert "Unsupported tool for batch execution" in result["results"][0]["error"]


def test_run_tools_batch_rejects_malformed_operation():
"""Malformed operations should be reported without crashing."""
result = batch_tools.run_tools_batch_tool(
[
{"tool": "", "arguments": {}},
{"tool": "list_estimators", "arguments": "not-a-dict"},
"not-an-object",
]
)

assert result["success"] is False
assert result["count"] == 3
assert all(item["success"] is False for item in result["results"])