Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions docs/source/dev-guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ This guide explains how the project is structured, how to develop new features,

- Python 3.10+
- pip
- Optional: `mkdocs` if you want to build the documentation site
- Optional: Sphinx if you want to build the documentation site

## Setup

Expand Down Expand Up @@ -114,8 +114,9 @@ Update `DEMO_DATASETS` in `src/sktime_mcp/runtime/executor.py` with a new loader
If you want to build the docs site locally:

```bash
pip install mkdocs
mkdocs serve
pip install -r docs/requirements.txt
cd docs
make html
```

The config lives in `mkdocs.yml` and the docs content is under `docs/`.
The config lives in `docs/source/conf.py` and the docs content is under `docs/source/`.
49 changes: 34 additions & 15 deletions src/sktime_mcp/runtime/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,16 @@ def fit(
return {"success": False, "error": f"Handle not found: {handle_id}"}

try:
if fh is not None:
instance.fit(y, X=X, fh=fh)
elif X is not None:
instance.fit(y, X=X)
is_classifier = getattr(instance, "_estimator_type", None) == "classifier"
if is_classifier:
instance.fit(X, y)
else:
instance.fit(y)
if fh is not None:
instance.fit(y, X=X, fh=fh)
elif X is not None:
instance.fit(y, X=X)
else:
instance.fit(y)

self._handle_manager.mark_fitted(handle_id)
return {"success": True, "handle": handle_id, "fitted": True}
Expand All @@ -159,20 +163,27 @@ def predict(
return {"success": False, "error": "Estimator not fitted"}

try:
if fh is None:
fh = list(range(1, 13))
is_classifier = getattr(instance, "_estimator_type", None) == "classifier"
if is_classifier:
predictions = instance.predict(X)
else:
if fh is None:
fh = list(range(1, 13))

predictions = instance.predict(fh=fh, X=X) if X is not None else instance.predict(fh=fh)
predictions = instance.predict(fh=fh, X=X) if X is not None else instance.predict(fh=fh)

if isinstance(predictions, pd.Series):
# Convert index to string to avoid JSON serialization issues with Period/DatetimeIndex
predictions_copy = predictions.copy()
predictions_copy.index = predictions_copy.index.astype(str)
result = predictions_copy.to_dict()
# Convert values to native Python types for JSON serialization
result = {k: float(v) if hasattr(v, "item") else v for k, v in predictions_copy.to_dict().items()}
elif isinstance(predictions, pd.DataFrame):
predictions_copy = predictions.copy()
predictions_copy.index = predictions_copy.index.astype(str)
result = predictions_copy.to_dict(orient="list")
# Convert lists of values to native Python types
raw_dict = predictions_copy.to_dict(orient="list")
result = {k: [float(v_i) if hasattr(v_i, "item") else v_i for v_i in v_list] for k, v_list in raw_dict.items()}
else:
result = predictions.tolist() if hasattr(predictions, "tolist") else predictions

Expand Down Expand Up @@ -577,6 +588,8 @@ def load_data_source(self, config: dict[str, Any]) -> dict[str, Any]:
metadata["exog_columns"] = list(X.columns)
# Inject column dtypes so LLMs can distinguish time index vs target
metadata["dtypes"] = {col: str(dtype) for col, dtype in data.dtypes.items()}
# Expose explicit shape metadata for agent reasoning (#416)
metadata["shape"] = list(data.shape) if hasattr(data, "shape") else len(data)
# Generate handle
data_handle = f"data_{uuid.uuid4().hex[:8]}"

Expand Down Expand Up @@ -611,6 +624,7 @@ def load_data_source(self, config: dict[str, Any]) -> dict[str, Any]:
# Continue with unformatted data if formatting fails
_final_meta = adapter.get_metadata().copy()
_final_meta["dtypes"] = {col: str(dtype) for col, dtype in data.dtypes.items()}
_final_meta["shape"] = list(data.shape) if hasattr(data, "shape") else len(data)
return {
"success": True,
"data_handle": data_handle,
Expand Down Expand Up @@ -788,11 +802,16 @@ def format_data_handle(

# 3. Infer and set frequency
if auto_infer_freq:
freq = y.index.freq
freq = None
if not isinstance(y.index, (pd.DatetimeIndex, pd.PeriodIndex)):
changes_made["frequency_set"] = True
changes_made["frequency"] = "Integer"
else:
freq = getattr(y.index, "freq", None)

if freq is None:
# Try to infer
freq = pd.infer_freq(y.index)
if freq is None:
# Try to infer
freq = pd.infer_freq(y.index)

if freq is None:
# Manual inference
Expand Down Expand Up @@ -853,7 +872,7 @@ def format_data_handle(
"metadata": {
**data_info["metadata"],
"formatted": True,
"frequency": str(y.index.freq) if y.index.freq else changes_made.get("frequency"),
"frequency": str(y.index.freq) if getattr(y.index, "freq", None) else changes_made.get("frequency"),
"rows": len(y),
"start_date": str(y.index.min()),
"end_date": str(y.index.max()),
Expand Down
109 changes: 107 additions & 2 deletions src/sktime_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
release_data_handle_tool,
)
from sktime_mcp.tools.describe_estimator import describe_estimator_tool
from sktime_mcp.tools.evaluate import evaluate_estimator_tool
from sktime_mcp.tools.evaluate import evaluate_estimator_tool, diagnose_residuals_tool
from sktime_mcp.tools.fit_predict import (
fit_predict_async_tool,
fit_predict_tool,
Expand Down Expand Up @@ -160,6 +160,7 @@ async def list_tools() -> list[Tool]:
"properties": {
"task": {
"type": "string",
"enum": ["forecasting", "classification", "regression", "transformation", "clustering"],
"description": (
"Task type filter: forecasting, classification, "
"regression, transformation, clustering"
Expand Down Expand Up @@ -368,13 +369,60 @@ async def list_tools() -> list[Tool]:
},
"cv_folds": {
"type": "integer",
"minimum": 2,
"maximum": 50,
"description": "Number of cross-validation folds (default: 3)",
"default": 3,
},
},
"required": ["estimator_handle", "dataset"],
},
),
Tool(
name="diagnose_residuals_tool",
description="Diagnose model failures by calculating statistical metrics (MAE, RMSE, Mean Bias) on residuals.",
inputSchema={
"type": "object",
"properties": {
"predictions": {
"type": ["object", "array"],
"description": "Forecasted values (dict or list)",
},
"actuals": {
"type": ["object", "array"],
"description": "Actual observed values (dict or list)",
},
},
"required": ["predictions", "actuals"],
},
),
# -- Batch Execution -------------------------------------------------
Tool(
name="run_tools_batch",
description=(
"Execute multiple read-only tools in a single batch to reduce latency. "
"Supported tools: list_estimators, describe_estimator, "
"get_available_tags, list_available_data."
),
inputSchema={
"type": "object",
"properties": {
"operations": {
"type": "array",
"items": {
"type": "object",
"properties": {
"tool": {"type": "string"},
"arguments": {"type": "object"}
},
"required": ["tool", "arguments"]
},
"description": "List of operations to execute in batch."
}
},
"required": ["operations"]
}
),
# -- Data ------------------------------------------------------------
Tool(
name="list_available_data",
Expand Down Expand Up @@ -695,12 +743,66 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
arguments.get("cv_folds", 3),
)

elif name == "diagnose_residuals_tool":
result = diagnose_residuals_tool(
predictions=arguments["predictions"],
actuals=arguments["actuals"],
)

elif name == "validate_pipeline":
validator = get_composition_validator()
validation = validator.validate_pipeline(arguments["components"])
result = validation.to_dict()
result["success"] = result["valid"]

elif name == "run_tools_batch":
batch_results = []
allowed_tools = {"list_estimators", "describe_estimator", "get_available_tags", "list_available_data"}

for i, op in enumerate(arguments.get("operations", [])):
tool_name = op.get("tool")
tool_args = op.get("arguments", {})

if tool_name not in allowed_tools:
batch_results.append({
"index": i,
"tool": tool_name,
"success": False,
"error": f"Tool '{tool_name}' is not supported in batch mode. Only read-only tools are allowed."
})
continue

try:
if tool_name == "list_estimators":
res = list_estimators_tool(
task=tool_args.get("task"),
tags=tool_args.get("tags"),
query=tool_args.get("query"),
limit=tool_args.get("limit", 50),
offset=tool_args.get("offset", 0),
)
elif tool_name == "describe_estimator":
res = describe_estimator_tool(tool_args.get("estimator", ""))
elif tool_name == "get_available_tags":
res = get_available_tags()
elif tool_name == "list_available_data":
res = list_available_data_tool(tool_args.get("is_demo"))

batch_results.append({
"index": i,
"tool": tool_name,
"success": True,
"result": res
})
except Exception as e:
batch_results.append({
"index": i,
"tool": tool_name,
"success": False,
"error": str(e)
})
result = {"batch_results": batch_results}

# -- Data ------------------------------------------------------------
elif name == "list_available_data":
result = list_available_data_tool(arguments.get("is_demo"))
Expand Down Expand Up @@ -831,7 +933,10 @@ async def run_server():
def main():
    """Entry point for the sktime-mcp server process.

    Runs the async server until it exits or the user interrupts with
    Ctrl+C, in which case shutdown is logged instead of raising.
    """
    logger.info("Starting sktime-mcp server...")
    server_coro = run_server()
    try:
        asyncio.run(server_coro)
    except KeyboardInterrupt:
        # Swallow the interrupt so Ctrl+C exits cleanly without a traceback.
        logger.info("Server stopped by user (Ctrl+C)")


if __name__ == "__main__":
Expand Down
14 changes: 13 additions & 1 deletion src/sktime_mcp/tools/describe_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ def describe_estimator_tool(estimator: str) -> dict[str, Any]:
registry = get_registry()
tag_resolver = get_tag_resolver()

if not isinstance(estimator, str):
return {
"success": False,
"error": f"'estimator' must be a string, got {type(estimator).__name__}.",
}

node = registry.get_estimator_by_name(estimator)
if node is None:
# Try case-insensitive search
Expand Down Expand Up @@ -69,7 +75,7 @@ def describe_estimator_tool(estimator: str) -> dict[str, Any]:
"hyperparameters": node.hyperparameters,
"tags": node.tags,
"tag_explanations": tag_explanations,
"docstring": doc[:500],
"docstring": doc,
}


Expand All @@ -84,6 +90,12 @@ def search_estimators_tool(query: str, limit: int = 20) -> dict[str, Any]:
Returns:
Dictionary with matching estimators
"""
if not isinstance(query, str):
return {
"success": False,
"error": f"'query' must be a string, got {type(query).__name__}.",
}

if limit < 1:
return {
"success": False,
Expand Down
60 changes: 60 additions & 0 deletions src/sktime_mcp/tools/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
from typing import Any

import numpy as np
from sktime.forecasting.model_evaluation import evaluate

try:
Expand Down Expand Up @@ -74,3 +75,62 @@ def evaluate_estimator_tool(
except Exception as e:
logger.exception("Error during evaluate")
return {"success": False, "error": str(e)}


def diagnose_residuals_tool(
predictions: dict[str, Any] | list[float],
actuals: dict[str, Any] | list[float],
) -> dict[str, Any]:
"""
Diagnose residuals by comparing predictions and actuals.

Args:
predictions: Forecasted values.
actuals: Actual observed values.

Returns:
Dictionary with statistical metrics (MAE, RMSE, Bias).
"""
try:
def extract_values(data):
if isinstance(data, dict):
# Handle nested dicts or flat dicts
try:
return np.array(list(data.values()), dtype=float)
except ValueError:
return np.array([float(v) for v in data.values() if isinstance(v, (int, float))])
return np.array(data, dtype=float)

y_pred = extract_values(predictions)
y_true = extract_values(actuals)

if len(y_pred) != len(y_true):
return {
"success": False,
"error": f"Length mismatch: predictions ({len(y_pred)}) vs actuals ({len(y_true)})",
}

residuals = y_true - y_pred
mae = float(np.mean(np.abs(residuals)))
mse = float(np.mean(residuals ** 2))
rmse = float(np.sqrt(mse))
bias = float(np.mean(residuals))

return {
"success": True,
"metrics": {
"MAE": mae,
"RMSE": rmse,
"Mean_Bias": bias,
},
"residuals": [float(r) for r in residuals],
"diagnosis": (
f"The model has a mean bias of {bias:.4f}. "
f"A positive bias means the model under-predicts on average, "
f"while a negative bias means it over-predicts. "
f"Average absolute error (MAE) is {mae:.4f}."
)
}
except Exception as e:
logger.exception("Error during diagnose_residuals")
return {"success": False, "error": str(e)}
6 changes: 6 additions & 0 deletions src/sktime_mcp/tools/job_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,12 @@ def cleanup_old_jobs_tool(max_age_hours: int = 24) -> dict[str, Any]:
Returns:
Dictionary with number of jobs removed
"""
if max_age_hours <= 0:
return {
"success": False,
"error": "max_age_hours must be a positive number.",
}

job_manager = get_job_manager()

count = job_manager.cleanup_old_jobs(max_age_hours)
Expand Down