diff --git a/docs/source/dev-guide.md b/docs/source/dev-guide.md index aab9d426..d76fae9b 100644 --- a/docs/source/dev-guide.md +++ b/docs/source/dev-guide.md @@ -6,7 +6,7 @@ This guide explains how the project is structured, how to develop new features, - Python 3.10+ - pip -- Optional: `mkdocs` if you want to build the documentation site +- Optional: Sphinx if you want to build the documentation site ## Setup @@ -114,8 +114,9 @@ Update `DEMO_DATASETS` in `src/sktime_mcp/runtime/executor.py` with a new loader If you want to build the docs site locally: ```bash -pip install mkdocs -mkdocs serve +pip install -r docs/requirements.txt +cd docs +make html ``` -The config lives in `mkdocs.yml` and the docs content is under `docs/`. +The config lives in `docs/source/conf.py` and the docs content is under `docs/source/`. diff --git a/src/sktime_mcp/runtime/executor.py b/src/sktime_mcp/runtime/executor.py index 0dc2673f..2a0c79a7 100644 --- a/src/sktime_mcp/runtime/executor.py +++ b/src/sktime_mcp/runtime/executor.py @@ -131,12 +131,16 @@ def fit( return {"success": False, "error": f"Handle not found: {handle_id}"} try: - if fh is not None: - instance.fit(y, X=X, fh=fh) - elif X is not None: - instance.fit(y, X=X) + is_classifier = getattr(instance, "_estimator_type", None) == "classifier" + if is_classifier: + instance.fit(X, y) else: - instance.fit(y) + if fh is not None: + instance.fit(y, X=X, fh=fh) + elif X is not None: + instance.fit(y, X=X) + else: + instance.fit(y) self._handle_manager.mark_fitted(handle_id) return {"success": True, "handle": handle_id, "fitted": True} @@ -159,20 +163,27 @@ def predict( return {"success": False, "error": "Estimator not fitted"} try: - if fh is None: - fh = list(range(1, 13)) + is_classifier = getattr(instance, "_estimator_type", None) == "classifier" + if is_classifier: + predictions = instance.predict(X) + else: + if fh is None: + fh = list(range(1, 13)) - predictions = instance.predict(fh=fh, X=X) if X is not None else instance.predict(fh=fh) + predictions = instance.predict(fh=fh, X=X) if X is not None else instance.predict(fh=fh) if isinstance(predictions, pd.Series): # Convert index to string to avoid JSON serialization issues with Period/DatetimeIndex predictions_copy = predictions.copy() predictions_copy.index = predictions_copy.index.astype(str) - result = predictions_copy.to_dict() + # Convert values to native Python types for JSON serialization + result = {k: float(v) if hasattr(v, "item") else v for k, v in predictions_copy.to_dict().items()} elif isinstance(predictions, pd.DataFrame): predictions_copy = predictions.copy() predictions_copy.index = predictions_copy.index.astype(str) - result = predictions_copy.to_dict(orient="list") + # Convert lists of values to native Python types + raw_dict = predictions_copy.to_dict(orient="list") + result = {k: [float(v_i) if hasattr(v_i, "item") else v_i for v_i in v_list] for k, v_list in raw_dict.items()} else: result = predictions.tolist() if hasattr(predictions, "tolist") else predictions @@ -577,6 +588,8 @@ def load_data_source(self, config: dict[str, Any]) -> dict[str, Any]: metadata["exog_columns"] = list(X.columns) # Inject column dtypes so LLMs can distinguish time index vs target metadata["dtypes"] = {col: str(dtype) for col, dtype in data.dtypes.items()} + # Expose explicit shape metadata for agent reasoning (#416) + metadata["shape"] = list(data.shape) if hasattr(data, "shape") else len(data) # Generate handle data_handle = f"data_{uuid.uuid4().hex[:8]}" @@ -611,6 +624,7 @@ def load_data_source(self, config: dict[str, Any]) -> dict[str, Any]: # Continue with unformatted data if formatting fails _final_meta = adapter.get_metadata().copy() _final_meta["dtypes"] = {col: str(dtype) for col, dtype in data.dtypes.items()} + _final_meta["shape"] = list(data.shape) if hasattr(data, "shape") else len(data) return { "success": True, "data_handle": data_handle, @@ -788,11 +802,16 @@ def format_data_handle( # 3. Infer and set frequency if auto_infer_freq: - freq = y.index.freq + freq = None + if not isinstance(y.index, (pd.DatetimeIndex, pd.PeriodIndex)): + changes_made["frequency_set"] = True + changes_made["frequency"] = "Integer" + else: + freq = getattr(y.index, "freq", None) - if freq is None: - # Try to infer - freq = pd.infer_freq(y.index) + if freq is None: + # Try to infer + freq = pd.infer_freq(y.index) if freq is None: # Manual inference @@ -853,7 +872,7 @@ def format_data_handle( "metadata": { **data_info["metadata"], "formatted": True, - "frequency": str(y.index.freq) if y.index.freq else changes_made.get("frequency"), + "frequency": str(y.index.freq) if getattr(y.index, "freq", None) else changes_made.get("frequency"), "rows": len(y), "start_date": str(y.index.min()), "end_date": str(y.index.max()), diff --git a/src/sktime_mcp/server.py b/src/sktime_mcp/server.py index edee569a..febd185c 100644 --- a/src/sktime_mcp/server.py +++ b/src/sktime_mcp/server.py @@ -37,7 +37,7 @@ release_data_handle_tool, ) from sktime_mcp.tools.describe_estimator import describe_estimator_tool -from sktime_mcp.tools.evaluate import evaluate_estimator_tool +from sktime_mcp.tools.evaluate import evaluate_estimator_tool, diagnose_residuals_tool from sktime_mcp.tools.fit_predict import ( fit_predict_async_tool, fit_predict_tool, @@ -160,6 +160,7 @@ async def list_tools() -> list[Tool]: "properties": { "task": { "type": "string", + "enum": ["forecasting", "classification", "regression", "transformation", "clustering"], "description": ( "Task type filter: forecasting, classification, " "regression, transformation, clustering" @@ -368,6 +369,8 @@ async def list_tools() -> list[Tool]: }, "cv_folds": { "type": "integer", + "minimum": 2, + "maximum": 50, "description": "Number of cross-validation folds (default: 3)", "default": 3, }, @@ -375,6 +378,51 @@ async def list_tools() -> list[Tool]: "required": ["estimator_handle", "dataset"], }, ), + Tool( + name="diagnose_residuals_tool", + description="Diagnose model failures by calculating statistical metrics (MAE, RMSE, Mean Bias) on residuals.", + inputSchema={ + "type": "object", + "properties": { + "predictions": { + "type": ["object", "array"], + "description": "Forecasted values (dict or list)", + }, + "actuals": { + "type": ["object", "array"], + "description": "Actual observed values (dict or list)", + }, + }, + "required": ["predictions", "actuals"], + }, + ), + # -- Batch Execution ------------------------------------------------- + Tool( + name="run_tools_batch", + description=( + "Execute multiple read-only tools in a single batch to reduce latency. " + "Supported tools: list_estimators, describe_estimator, " + "get_available_tags, list_available_data." + ), + inputSchema={ + "type": "object", + "properties": { + "operations": { + "type": "array", + "items": { + "type": "object", + "properties": { + "tool": {"type": "string"}, + "arguments": {"type": "object"} + }, + "required": ["tool", "arguments"] + }, + "description": "List of operations to execute in batch." + } + }, + "required": ["operations"] + } + ), # -- Data ------------------------------------------------------------ Tool( name="list_available_data", @@ -695,12 +743,66 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: arguments.get("cv_folds", 3), ) + elif name == "diagnose_residuals_tool": + result = diagnose_residuals_tool( + predictions=arguments["predictions"], + actuals=arguments["actuals"], + ) + elif name == "validate_pipeline": validator = get_composition_validator() validation = validator.validate_pipeline(arguments["components"]) result = validation.to_dict() result["success"] = result["valid"] + elif name == "run_tools_batch": + batch_results = [] + allowed_tools = {"list_estimators", "describe_estimator", "get_available_tags", "list_available_data"} + + for i, op in enumerate(arguments.get("operations", [])): + tool_name = op.get("tool") + tool_args = op.get("arguments", {}) + + if tool_name not in allowed_tools: + batch_results.append({ + "index": i, + "tool": tool_name, + "success": False, + "error": f"Tool '{tool_name}' is not supported in batch mode. Only read-only tools are allowed." + }) + continue + + try: + if tool_name == "list_estimators": + res = list_estimators_tool( + task=tool_args.get("task"), + tags=tool_args.get("tags"), + query=tool_args.get("query"), + limit=tool_args.get("limit", 50), + offset=tool_args.get("offset", 0), + ) + elif tool_name == "describe_estimator": + res = describe_estimator_tool(tool_args.get("estimator", "")) + elif tool_name == "get_available_tags": + res = get_available_tags() + elif tool_name == "list_available_data": + res = list_available_data_tool(tool_args.get("is_demo")) + + batch_results.append({ + "index": i, + "tool": tool_name, + "success": True, + "result": res + }) + except Exception as e: + batch_results.append({ + "index": i, + "tool": tool_name, + "success": False, + "error": str(e) + }) + result = {"batch_results": batch_results} + # -- Data ------------------------------------------------------------ elif name == "list_available_data": result = list_available_data_tool(arguments.get("is_demo")) @@ -831,7 +933,10 @@ async def run_server(): def main(): """Main entry point.""" logger.info("Starting sktime-mcp server...") - asyncio.run(run_server()) + try: + asyncio.run(run_server()) + except KeyboardInterrupt: + logger.info("Server stopped by user (Ctrl+C)") if __name__ == "__main__": diff --git a/src/sktime_mcp/tools/describe_estimator.py b/src/sktime_mcp/tools/describe_estimator.py index fbbcf36c..45140a58 100644 --- a/src/sktime_mcp/tools/describe_estimator.py +++ b/src/sktime_mcp/tools/describe_estimator.py @@ -42,6 +42,12 @@ def describe_estimator_tool(estimator: str) -> dict[str, Any]: registry = get_registry() tag_resolver = get_tag_resolver() + if not isinstance(estimator, str): + return { + "success": False, + "error": f"'estimator' must be a string, got {type(estimator).__name__}.", + } + node = registry.get_estimator_by_name(estimator) if node is None: # Try case-insensitive search @@ -69,7 +75,7 @@ def describe_estimator_tool(estimator: str) -> dict[str, Any]: "hyperparameters": node.hyperparameters, "tags": node.tags, "tag_explanations": tag_explanations, - "docstring": doc[:500], + "docstring": doc, } @@ -84,6 +90,12 @@ def search_estimators_tool(query: str, limit: int = 20) -> dict[str, Any]: Returns: Dictionary with matching estimators """ + if not isinstance(query, str): + return { + "success": False, + "error": f"'query' must be a string, got {type(query).__name__}.", + } + if limit < 1: return { "success": False, diff --git a/src/sktime_mcp/tools/evaluate.py b/src/sktime_mcp/tools/evaluate.py index e87b41ea..a8555ac3 100644 --- a/src/sktime_mcp/tools/evaluate.py +++ b/src/sktime_mcp/tools/evaluate.py @@ -7,6 +7,7 @@ import logging from typing import Any +import numpy as np from sktime.forecasting.model_evaluation import evaluate try: @@ -74,3 +75,62 @@ def evaluate_estimator_tool( except Exception as e: logger.exception("Error during evaluate") return {"success": False, "error": str(e)} + + +def diagnose_residuals_tool( + predictions: dict[str, Any] | list[float], + actuals: dict[str, Any] | list[float], +) -> dict[str, Any]: + """ + Diagnose residuals by comparing predictions and actuals. + + Args: + predictions: Forecasted values. + actuals: Actual observed values. + + Returns: + Dictionary with statistical metrics (MAE, RMSE, Bias). + """ + try: + def extract_values(data): + if isinstance(data, dict): + # Handle nested dicts or flat dicts + try: + return np.array(list(data.values()), dtype=float) + except ValueError: + return np.array([float(v) for v in data.values() if isinstance(v, (int, float))]) + return np.array(data, dtype=float) + + y_pred = extract_values(predictions) + y_true = extract_values(actuals) + + if len(y_pred) != len(y_true): + return { + "success": False, + "error": f"Length mismatch: predictions ({len(y_pred)}) vs actuals ({len(y_true)})", + } + + residuals = y_true - y_pred + mae = float(np.mean(np.abs(residuals))) + mse = float(np.mean(residuals ** 2)) + rmse = float(np.sqrt(mse)) + bias = float(np.mean(residuals)) + + return { + "success": True, + "metrics": { + "MAE": mae, + "RMSE": rmse, + "Mean_Bias": bias, + }, + "residuals": [float(r) for r in residuals], + "diagnosis": ( + f"The model has a mean bias of {bias:.4f}. " + f"A positive bias means the model under-predicts on average, " + f"while a negative bias means it over-predicts. " + f"Average absolute error (MAE) is {mae:.4f}." + ) + } + except Exception as e: + logger.exception("Error during diagnose_residuals") + return {"success": False, "error": str(e)} diff --git a/src/sktime_mcp/tools/job_tools.py b/src/sktime_mcp/tools/job_tools.py index 915c3781..05966812 100644 --- a/src/sktime_mcp/tools/job_tools.py +++ b/src/sktime_mcp/tools/job_tools.py @@ -134,6 +134,12 @@ def cleanup_old_jobs_tool(max_age_hours: int = 24) -> dict[str, Any]: Returns: Dictionary with number of jobs removed """ + if max_age_hours <= 0: + return { + "success": False, + "error": "max_age_hours must be a positive number.", + } + job_manager = get_job_manager() count = job_manager.cleanup_old_jobs(max_age_hours)