Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,18 @@ If you are using a virtual environment, or if `sktime-mcp` is not on your `PATH`
}
```

## ⚙️ Configuration

The server can be configured via environment variables:

| Environment Variable | Description | Default |
|----------------------|-------------|---------|
| `SKTIME_MCP_LOG_LEVEL` | Logging verbosity (e.g. `INFO`, `DEBUG`, `WARNING`) | `"WARNING"` |
| `SKTIME_MCP_LOG_PATH` | Optional file path to output logs to in addition to stderr | (None) |
| `SKTIME_MCP_AUTO_FORMAT` | Automatically format time series data on load (`true`/`false`) | `"true"` |
| `SKTIME_MCP_JOB_MAX_AGE_HOURS` | Maximum age in hours before background jobs are cleared | `24` |
| `SKTIME_MCP_JOB_CLEANUP_INTERVAL` | Interval in seconds for periodic job cleanup checks | `3600` |

## 📚 Available Tools

### Discovery & Search
Expand Down
62 changes: 62 additions & 0 deletions src/sktime_mcp/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""
Configuration module for sktime-mcp.

Centralizes environment variables and provides sensible defaults.
"""

import os


class Settings:
"""Server and runtime configuration settings."""

# -- Runtime & Server Settings --
@property
def log_level(self) -> str:
"""
Logging level.
Env Var: SKTIME_MCP_LOG_LEVEL
Default: "WARNING"
"""
return os.environ.get("SKTIME_MCP_LOG_LEVEL", "WARNING").upper()

@property
def log_path(self) -> str | None:
"""
Optional file path to output logs to in addition to stderr.
Env Var: SKTIME_MCP_LOG_PATH
Default: None
"""
return os.environ.get("SKTIME_MCP_LOG_PATH")

# -- Data Formatting --
@property
def auto_format(self) -> bool:
"""
Whether to automatically format time series data upon load.
Env Var: SKTIME_MCP_AUTO_FORMAT
Default: True
"""
return os.environ.get("SKTIME_MCP_AUTO_FORMAT", "true").lower() == "true"

# -- Job Management --
@property
def job_max_age_hours(self) -> int:
"""
Maximum age in hours before a job is cleaned up.
Env Var: SKTIME_MCP_JOB_MAX_AGE_HOURS
Default: 24
"""
return int(os.environ.get("SKTIME_MCP_JOB_MAX_AGE_HOURS", "24"))

@property
def job_cleanup_interval_secs(self) -> int:
"""
Interval in seconds for periodic job cleanup.
Env Var: SKTIME_MCP_JOB_CLEANUP_INTERVAL
Default: 3600
"""
return int(os.environ.get("SKTIME_MCP_JOB_CLEANUP_INTERVAL", "3600"))


settings = Settings()
7 changes: 3 additions & 4 deletions src/sktime_mcp/runtime/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import asyncio
import inspect
import logging
import os
import uuid
from typing import Any

Expand Down Expand Up @@ -54,9 +53,9 @@ def __init__(self):
self._handle_manager = get_handle_manager()
self._job_manager = get_job_manager()
self._data_handles = {} # Store data handles
self._auto_format_enabled = (
os.environ.get("SKTIME_MCP_AUTO_FORMAT", "true").lower() == "true"
)
from sktime_mcp.config import settings

self._auto_format_enabled = settings.auto_format

def instantiate(
self,
Expand Down
34 changes: 21 additions & 13 deletions src/sktime_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
import asyncio
import json
import logging
import os
import sys
from io import TextIOWrapper
from typing import Any

import anyio

try:
import numpy as np

Expand All @@ -30,6 +33,7 @@
from mcp.types import TextContent, Tool

from sktime_mcp.composition.validator import get_composition_validator
from sktime_mcp.config import settings
from sktime_mcp.tools.codegen import export_code_tool
from sktime_mcp.tools.data_tools import (
load_data_source_async_tool,
Expand Down Expand Up @@ -62,19 +66,15 @@
)
from sktime_mcp.tools.save_model import save_model_tool

# ---------------------------------------------------------------------------
# Server configuration via environment variables
# ---------------------------------------------------------------------------
JOB_MAX_AGE_HOURS = int(os.environ.get("SKTIME_MCP_JOB_MAX_AGE_HOURS", "24"))
JOB_CLEANUP_INTERVAL_SECS = int(os.environ.get("SKTIME_MCP_JOB_CLEANUP_INTERVAL", "3600"))

# Configure logging to stderr with detailed format
_LOG_LEVEL = os.environ.get("SKTIME_MCP_LOG_LEVEL", "WARNING").upper()
_handlers: list[logging.Handler] = [logging.StreamHandler(sys.stderr)]
if settings.log_path:
_handlers.append(logging.FileHandler(settings.log_path))

logging.basicConfig(
level=getattr(logging, _LOG_LEVEL, logging.WARNING),
level=getattr(logging, settings.log_level, logging.WARNING),
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[logging.StreamHandler()],
handlers=_handlers,
)
logger = logging.getLogger(__name__)
# Create MCP server instance
Expand Down Expand Up @@ -810,10 +810,10 @@ async def _periodic_job_cleanup():
from sktime_mcp.runtime.jobs import get_job_manager

while True:
await asyncio.sleep(JOB_CLEANUP_INTERVAL_SECS)
await asyncio.sleep(settings.job_cleanup_interval_secs)
try:
job_manager = get_job_manager()
removed = job_manager.cleanup_old_jobs(JOB_MAX_AGE_HOURS)
removed = job_manager.cleanup_old_jobs(settings.job_max_age_hours)
if removed:
logger.info(f"Periodic cleanup: removed {removed} old job(s)")
except Exception:
Expand All @@ -822,9 +822,17 @@ async def _periodic_job_cleanup():

async def run_server():
"""Run the MCP server."""
# Stdio safety: redirect stdout to stderr to protect MCP JSON-RPC
# streams from being corrupted by stray prints in third-party libraries.
original_stdout = sys.stdout
sys.stdout = sys.stderr

# Explicitly wrap the original stdout buffer for the MCP server output
mcp_stdout = anyio.wrap_file(TextIOWrapper(original_stdout.buffer, encoding="utf-8"))

asyncio.create_task(_periodic_job_cleanup())

async with stdio_server() as (read_stream, write_stream):
async with stdio_server(stdout=mcp_stdout) as (read_stream, write_stream):
await server.run(read_stream, write_stream, server.create_initialization_options())


Expand Down
1 change: 0 additions & 1 deletion tests/test_data_target_column_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,3 @@ def test_load_data_source_returns_error_for_missing_explicit_target_column():
assert result["error_type"] == "ValueError"
assert "Target column 'sales' not found" in result["error"]
assert "value" in result["error"]

Loading