diff --git a/README.md b/README.md index 728ffab4..2778db47 100644 --- a/README.md +++ b/README.md @@ -163,6 +163,18 @@ If you are using a virtual environment, or if `sktime-mcp` is not on your `PATH` } ``` +## ⚙️ Configuration + +The server can be configured via environment variables: + +| Environment Variable | Description | Default | +|----------------------|-------------|---------| +| `SKTIME_MCP_LOG_LEVEL` | Logging verbosity (e.g. `INFO`, `DEBUG`, `WARNING`) | `"WARNING"` | +| `SKTIME_MCP_LOG_PATH` | Optional file path to output logs to in addition to stderr | (None) | +| `SKTIME_MCP_AUTO_FORMAT` | Automatically format time series data on load (`true`/`false`) | `"true"` | +| `SKTIME_MCP_JOB_MAX_AGE_HOURS` | Maximum age in hours before background jobs are cleared | `24` | +| `SKTIME_MCP_JOB_CLEANUP_INTERVAL` | Interval in seconds for periodic job cleanup checks | `3600` | + ## 📚 Available Tools ### Discovery & Search diff --git a/src/sktime_mcp/config.py b/src/sktime_mcp/config.py new file mode 100644 index 00000000..7631a172 --- /dev/null +++ b/src/sktime_mcp/config.py @@ -0,0 +1,62 @@ +""" +Configuration module for sktime-mcp. + +Centralizes environment variables and provides sensible defaults. +""" + +import os + + +class Settings: + """Server and runtime configuration settings.""" + + # -- Runtime & Server Settings -- + @property + def log_level(self) -> str: + """ + Logging level. + Env Var: SKTIME_MCP_LOG_LEVEL + Default: "WARNING" + """ + return os.environ.get("SKTIME_MCP_LOG_LEVEL", "WARNING").upper() + + @property + def log_path(self) -> str | None: + """ + Optional file path to output logs to in addition to stderr. + Env Var: SKTIME_MCP_LOG_PATH + Default: None + """ + return os.environ.get("SKTIME_MCP_LOG_PATH") + + # -- Data Formatting -- + @property + def auto_format(self) -> bool: + """ + Whether to automatically format time series data upon load. 
+ Env Var: SKTIME_MCP_AUTO_FORMAT + Default: True + """ + return os.environ.get("SKTIME_MCP_AUTO_FORMAT", "true").lower() == "true" + + # -- Job Management -- + @property + def job_max_age_hours(self) -> int: + """ + Maximum age in hours before a job is cleaned up. + Env Var: SKTIME_MCP_JOB_MAX_AGE_HOURS + Default: 24 + """ + return int(os.environ.get("SKTIME_MCP_JOB_MAX_AGE_HOURS", "24")) + + @property + def job_cleanup_interval_secs(self) -> int: + """ + Interval in seconds for periodic job cleanup. + Env Var: SKTIME_MCP_JOB_CLEANUP_INTERVAL + Default: 3600 + """ + return int(os.environ.get("SKTIME_MCP_JOB_CLEANUP_INTERVAL", "3600")) + + +settings = Settings() diff --git a/src/sktime_mcp/runtime/executor.py b/src/sktime_mcp/runtime/executor.py index 0dc2673f..9c8b09fc 100644 --- a/src/sktime_mcp/runtime/executor.py +++ b/src/sktime_mcp/runtime/executor.py @@ -8,7 +8,6 @@ import asyncio import inspect import logging -import os import uuid from typing import Any @@ -54,9 +53,9 @@ def __init__(self): self._handle_manager = get_handle_manager() self._job_manager = get_job_manager() self._data_handles = {} # Store data handles - self._auto_format_enabled = ( - os.environ.get("SKTIME_MCP_AUTO_FORMAT", "true").lower() == "true" - ) + from sktime_mcp.config import settings + + self._auto_format_enabled = settings.auto_format def instantiate( self, diff --git a/src/sktime_mcp/server.py b/src/sktime_mcp/server.py index 8c8a9471..790cbcd3 100644 --- a/src/sktime_mcp/server.py +++ b/src/sktime_mcp/server.py @@ -8,9 +8,12 @@ import asyncio import json import logging -import os +import sys +from io import TextIOWrapper from typing import Any +import anyio + try: import numpy as np @@ -30,6 +33,7 @@ from mcp.types import TextContent, Tool from sktime_mcp.composition.validator import get_composition_validator +from sktime_mcp.config import settings from sktime_mcp.tools.codegen import export_code_tool from sktime_mcp.tools.data_tools import ( load_data_source_async_tool, @@ 
async def run_server():
    """
    Run the MCP server over stdio until the session ends.

    ``sys.stdout`` is pointed at ``sys.stderr`` so that stray prints cannot
    corrupt the JSON-RPC stream; the MCP transport writes to a UTF-8 wrapper
    around the original stdout buffer instead. The redirect is not undone
    when the server stops.

    The periodic job-cleanup task is started alongside the server and is
    cancelled and awaited when the server exits, whether normally or through
    an exception.
    """
    # Stdio safety: redirect stdout to stderr to protect MCP JSON-RPC
    # streams from being corrupted by stray prints in third-party libraries.
    original_stdout = sys.stdout
    sys.stdout = sys.stderr

    # Explicitly wrap the original stdout buffer for the MCP server output
    mcp_stdout = anyio.wrap_file(TextIOWrapper(original_stdout.buffer, encoding="utf-8"))

    # Hold a strong reference: the event loop keeps only a weak one, so an
    # unreferenced task may be garbage-collected before it finishes.
    cleanup_task = asyncio.create_task(_periodic_job_cleanup())
    try:
        async with stdio_server(stdout=mcp_stdout) as (read_stream, write_stream):
            await server.run(read_stream, write_stream, server.create_initialization_options())
    finally:
        # Stop the background loop with the server. gather(...,
        # return_exceptions=True) absorbs the resulting CancelledError so it
        # does not mask whatever ended the server.
        cleanup_task.cancel()
        await asyncio.gather(cleanup_task, return_exceptions=True)