Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 28 additions & 60 deletions src/sktime_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,12 @@
from sktime_mcp.composition.validator import get_composition_validator
from sktime_mcp.tools.codegen import export_code_tool
from sktime_mcp.tools.data_tools import (
load_data_source_async_tool,
load_data_source_tool,
release_data_handle_tool,
)
from sktime_mcp.tools.describe_estimator import describe_estimator_tool
from sktime_mcp.tools.evaluate import evaluate_estimator_tool
from sktime_mcp.tools.fit_predict import (
fit_predict_async_tool,
fit_predict_tool,
)
from sktime_mcp.tools.fit_predict import fit_predict_tool
from sktime_mcp.tools.format_tools import format_time_series_tool
from sktime_mcp.tools.instantiate import (
instantiate_estimator_tool,
Expand Down Expand Up @@ -317,36 +313,10 @@ async def list_tools() -> list[Tool]:
"description": "Forecast horizon (default: 12)",
"default": 12,
},
},
"required": ["estimator_handle"],
},
),
Tool(
name="fit_predict_async",
description=(
"Fit an estimator and generate predictions in the background. "
"Provide exactly ONE of 'dataset' (built-in demo name) "
"or 'data_handle' (from load_data_source)."
),
inputSchema={
"type": "object",
"properties": {
"estimator_handle": {
"type": "string",
"description": "Handle from instantiate_estimator",
},
"dataset": {
"type": "string",
"description": "Demo dataset name: airline, sunspots, lynx, etc.",
},
"data_handle": {
"type": "string",
"description": "Data handle from load_data_source (e.g. 'data_abc123')",
},
"horizon": {
"type": "integer",
"description": "Forecast horizon (default: 12)",
"default": 12,
"background": {
"type": "boolean",
"description": "Run in the background as a job (default: false)",
"default": False,
},
},
"required": ["estimator_handle"],
Expand Down Expand Up @@ -423,24 +393,10 @@ async def list_tools() -> list[Tool]:
"(pandas, sql, file, url)."
),
},
},
"required": ["config"],
},
),
Tool(
name="load_data_source_async",
description=(
"Load data from any source in the background "
"(non-blocking). Returns a job_id to track "
"progress. The data_handle is available in "
"the job result when completed."
),
inputSchema={
"type": "object",
"properties": {
"config": {
"type": "object",
"description": "Data source configuration. Same format as load_data_source.",
"background": {
"type": "boolean",
"description": "Load data in the background (default: false)",
"default": False,
},
},
"required": ["config"],
Expand Down Expand Up @@ -675,17 +631,21 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
elif name == "fit_predict":
result = fit_predict_tool(
arguments["estimator_handle"],
arguments.get("dataset", ""),
arguments.get("dataset"),
arguments.get("horizon", 12),
data_handle=arguments.get("data_handle"),
background=arguments.get("background", False),
)

elif name == "fit_predict_async":
result = fit_predict_async_tool(
estimator_handle=arguments["estimator_handle"],
dataset=arguments.get("dataset"),
# Deprecated: Route to unified fit_predict
logger.warning("fit_predict_async is deprecated; use fit_predict(background=true)")
result = fit_predict_tool(
arguments["estimator_handle"],
arguments.get("dataset"),
arguments.get("horizon", 12),
data_handle=arguments.get("data_handle"),
horizon=arguments.get("horizon", 12),
background=True,
)

elif name == "evaluate_estimator":
Expand All @@ -705,10 +665,18 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
elif name == "list_available_data":
result = list_available_data_tool(arguments.get("is_demo"))
elif name == "load_data_source":
result = load_data_source_tool(arguments["config"])
result = load_data_source_tool(
arguments["config"],
background=arguments.get("background", False),
)

elif name == "load_data_source_async":
result = load_data_source_async_tool(arguments["config"])
# Deprecated: Route to unified load_data_source
logger.warning("load_data_source_async is deprecated; use load_data_source(background=true)")
result = load_data_source_tool(
arguments["config"],
background=True,
)

elif name == "list_data_sources":
# Deprecated — info is now in load_data_source description
Expand Down
8 changes: 1 addition & 7 deletions src/sktime_mcp/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,12 @@

from sktime_mcp.tools.codegen import export_code_tool
from sktime_mcp.tools.data_tools import (
load_data_source_async_tool,
load_data_source_tool,
release_data_handle_tool,
)
from sktime_mcp.tools.describe_estimator import describe_estimator_tool
from sktime_mcp.tools.evaluate import evaluate_estimator_tool
from sktime_mcp.tools.fit_predict import (
fit_predict_async_tool,
fit_predict_tool,
)
from sktime_mcp.tools.fit_predict import fit_predict_tool
from sktime_mcp.tools.format_tools import format_time_series_tool
from sktime_mcp.tools.instantiate import (
instantiate_estimator_tool,
Expand Down Expand Up @@ -42,10 +38,8 @@
"release_handle_tool",
"load_model_tool",
"fit_predict_tool",
"fit_predict_async_tool",
"evaluate_estimator_tool",
"load_data_source_tool",
"load_data_source_async_tool",
"release_data_handle_tool",
"list_available_data_tool",
"format_time_series_tool",
Expand Down
160 changes: 56 additions & 104 deletions src/sktime_mcp/tools/data_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,52 +12,72 @@
logger = logging.getLogger(__name__)


def load_data_source_tool(
    config: dict[str, Any],
    background: bool = False,
) -> dict[str, Any]:
    """
    Load data from any source (pandas, SQL, file, url), sync or async.

    Args:
        config: Data source configuration
            {
                "type": "pandas" | "sql" | "file" | "url",
                ... (type-specific configuration)
            }
        background: If True, schedule the load as a background job and
            return immediately with a job_id instead of blocking.

    Returns:
        If background=False (default):
            Dictionary with the executor's sync result
            (data_handle, metadata, etc.).
        If background=True:
            Dictionary with:
            - success: bool
            - job_id: Job ID for tracking progress
            - message: How to monitor the job
            - source_type: The configured source type

    Examples:
        # Sync loading (default)
        >>> load_data_source_tool({"type": "file", "path": "data.csv"})

        # Async loading
        >>> load_data_source_tool({"type": "file", "path": "large.csv"}, background=True)
    """
    executor = get_executor()

    if not background:
        # Blocking path: delegate directly to the executor.
        return executor.load_data_source(config)

    # --- Background path ---
    import asyncio
    import threading

    from sktime_mcp.runtime.jobs import get_job_manager

    job_manager = get_job_manager()
    source_type = config.get("type", "unknown")

    # Create a background job so callers can poll via check_job_status.
    job_id = job_manager.create_job(
        job_type="data_loading",
        estimator_handle="",
        dataset_name=source_type,
        total_steps=3,  # load, validate, format
    )

    coro = executor.load_data_source_async(config, job_id)
    try:
        # Preferred: schedule onto the already-running event loop (the MCP
        # server invokes tool handlers from inside its loop).
        loop = asyncio.get_running_loop()
        asyncio.run_coroutine_threadsafe(coro, loop)
    except RuntimeError:
        # No running loop (plain sync caller). The previous implementation
        # created a fresh loop and scheduled the coroutine on it but never
        # ran the loop, so the job silently never executed. Drive the
        # coroutine to completion on a daemon thread instead.
        threading.Thread(target=asyncio.run, args=(coro,), daemon=True).start()

    return {
        "success": True,
        "job_id": job_id,
        "message": (
            f"Data loading job started for source type '{source_type}'. "
            f"Use check_job_status('{job_id}') to monitor progress."
        ),
        "source_type": source_type,
    }


def list_data_sources_tool() -> dict[str, Any]:
Expand Down Expand Up @@ -104,71 +124,3 @@ def release_data_handle_tool(data_handle: str) -> dict[str, Any]:
return executor.release_data_handle(data_handle)


def load_data_source_async_tool(
    config: dict[str, Any],
) -> dict[str, Any]:
    """
    Load data from any source in the background (non-blocking).

    Schedules the data loading as a background job and returns
    immediately with a job_id. Use check_job_status to monitor
    progress and retrieve the data_handle when done.

    NOTE(review): deprecated in favor of the unified
    ``load_data_source_tool(config, background=True)``.

    Args:
        config: Data source configuration (same as load_data_source)

    Returns:
        Dictionary with:
        - success: bool
        - job_id: Job ID for tracking progress
        - message: Information about the job
        - source_type: The configured source type

    Example:
        >>> load_data_source_async_tool({
        ...     "type": "file",
        ...     "path": "/path/to/large_data.csv",
        ...     "time_column": "date",
        ...     "target_column": "value"
        ... })
        {
            "success": True,
            "job_id": "abc-123-def-456",
            "message": "Data loading job started..."
        }
    """
    import asyncio
    import threading

    from sktime_mcp.runtime.jobs import get_job_manager

    executor = get_executor()
    job_manager = get_job_manager()

    source_type = config.get("type", "unknown")

    # Create a background job so callers can poll via check_job_status.
    job_id = job_manager.create_job(
        job_type="data_loading",
        estimator_handle="",
        dataset_name=source_type,
        total_steps=3,  # load, validate, format
    )

    coro = executor.load_data_source_async(config, job_id)
    try:
        # Preferred: schedule onto the already-running event loop (the MCP
        # server invokes tool handlers from inside its loop).
        loop = asyncio.get_running_loop()
        asyncio.run_coroutine_threadsafe(coro, loop)
    except RuntimeError:
        # No running loop (plain sync caller). Previously a fresh loop was
        # created and the coroutine scheduled on it, but the loop was never
        # run, so the job silently never executed. Drive the coroutine to
        # completion on a daemon thread instead.
        threading.Thread(target=asyncio.run, args=(coro,), daemon=True).start()

    return {
        "success": True,
        "job_id": job_id,
        "message": (
            f"Data loading job started for source type '{source_type}'. "
            f"Use check_job_status('{job_id}') to monitor progress."
        ),
        "source_type": source_type,
    }
Loading