Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 10 additions & 23 deletions src/sktime_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
)
from sktime_mcp.tools.codegen import export_code_tool
from sktime_mcp.tools.data_tools import (
load_data_source_async_tool,
load_data_source_tool,
release_data_handle_tool,
)
Expand Down Expand Up @@ -459,6 +458,7 @@ async def list_tools() -> list[Tool]:
name="load_data_source",
description=(
"Load data from various sources into a data handle for forecasting. "
"Can run synchronously (blocking) or asynchronously in the background. "
"Supported source types: "
"'pandas' - from a dict or inline data (keys: data, time_column, target_column). "
"'file' - from CSV, Excel (.xlsx), or Parquet (keys: path, time_column, target_column). "
Expand All @@ -481,24 +481,14 @@ async def list_tools() -> list[Tool]:
"(pandas, sql, file, url)."
),
},
},
"required": ["config"],
},
),
Tool(
name="load_data_source_async",
description=(
"Load data from any source in the background "
"(non-blocking). Returns a job_id to track "
"progress. The data_handle is available in "
"the job result when completed."
),
inputSchema={
"type": "object",
"properties": {
"config": {
"type": "object",
"description": "Data source configuration. Same format as load_data_source.",
"run_async": {
"type": "boolean",
"description": (
"If True, loads data in the background (non-blocking) and "
"returns a job_id. If False (default), blocks and returns the "
"data_handle directly."
),
"default": False,
},
},
"required": ["config"],
Expand Down Expand Up @@ -823,10 +813,7 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
elif name == "list_available_data":
result = list_available_data_tool(arguments.get("is_demo"))
elif name == "load_data_source":
result = load_data_source_tool(arguments["config"])

elif name == "load_data_source_async":
result = load_data_source_async_tool(arguments["config"])
result = load_data_source_tool(arguments["config"], arguments.get("run_async", False))

elif name == "list_data_sources":
# Deprecated — info is now in load_data_source description
Expand Down
2 changes: 0 additions & 2 deletions src/sktime_mcp/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from sktime_mcp.tools.codegen import export_code_tool
from sktime_mcp.tools.data_tools import (
load_data_source_async_tool,
load_data_source_tool,
release_data_handle_tool,
)
Expand Down Expand Up @@ -45,7 +44,6 @@
"fit_predict_async_tool",
"evaluate_estimator_tool",
"load_data_source_tool",
"load_data_source_async_tool",
"release_data_handle_tool",
"list_available_data_tool",
"format_time_series_tool",
Expand Down
268 changes: 133 additions & 135 deletions src/sktime_mcp/tools/data_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,63 +12,128 @@
logger = logging.getLogger(__name__)


# Strong references to in-flight background load tasks. The event loop only
# keeps weak references to the tasks it runs, so a task that nothing else
# holds on to may be garbage-collected before it finishes.
_BACKGROUND_LOAD_TASKS: set[Any] = set()


def load_data_source_tool(
    config: dict[str, Any],
    run_async: bool = False,
) -> dict[str, Any]:
    """Load data from any source (pandas, SQL, file, etc.).

    Can run synchronously (blocking) or asynchronously in the background.

    Parameters
    ----------
    config : dict
        Data source configuration dictionary. Must contain:
        - "type" : str
            Source type: "pandas", "sql", "file", or "url".
        - Additional type-specific configuration keys (e.g. "data", "path",
          "time_column", "target_column").
    run_async : bool, default=False
        If True, schedules the loading as a background job and
        returns a job_id immediately. If False, blocks until loaded
        and returns the data_handle directly.

    Returns
    -------
    dict
        If run_async is False, whatever the executor's ``load_data_source``
        returns; documented as containing:
        - "success" : bool
            True if the data was loaded successfully.
        - "data_handle" : str
            The unique handle ID for the loaded data.
        - "metadata" : dict
            Information about the loaded data.
        - "validation" : dict
            Validation results.

        If run_async is True, contains:
        - "success" : bool
            Always True once the job has been created and scheduled.
        - "job_id" : str
            Unique job ID to monitor progress via check_job_status.
        - "message" : str
            A user-friendly status message.
        - "source_type" : str
            The value of ``config["type"]``, or "unknown" when absent.

    Notes
    -----
    With ``run_async=True`` and no running event loop (e.g. a synchronous
    test or CLI call) the load coroutine is driven to completion on a
    temporary event loop before this function returns, so the call blocks
    even though a job_id is still returned.

    Examples
    --------
    # Synchronous Pandas DataFrame Loading
    >>> load_data_source_tool({
    ...     "type": "pandas",
    ...     "data": {"date": [...], "value": [...]},
    ...     "time_column": "date",
    ...     "target_column": "value"
    ... })

    # Asynchronous CSV File Loading
    >>> load_data_source_tool({
    ...     "type": "file",
    ...     "path": "/path/to/data.csv",
    ...     "time_column": "date",
    ...     "target_column": "value"
    ... }, run_async=True)
    """
    executor = get_executor()
    if not run_async:
        return executor.load_data_source(config)

    import asyncio

    from sktime_mcp.runtime.jobs import get_job_manager

    source_type = config.get("type", "unknown")

    # Create a background job record so callers can poll for progress.
    job_id = get_job_manager().create_job(
        job_type="data_loading",
        estimator_handle="",
        dataset_name=source_type,
        total_steps=3,  # load, validate, format
    )

    coro = executor.load_data_source_async(config, job_id)

    # Only the loop lookup lives in the ``try``: a RuntimeError raised while
    # scheduling must not be mistaken for "no running loop".
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No running event loop (e.g. sync test or CLI environment):
        # drive the coroutine to completion on a temporary loop.
        loop = asyncio.new_event_loop()
        try:
            loop.run_until_complete(coro)
        finally:
            loop.close()
    else:
        task = loop.create_task(coro)
        # Keep the task alive until it is done, then drop the reference.
        _BACKGROUND_LOAD_TASKS.add(task)
        task.add_done_callback(_BACKGROUND_LOAD_TASKS.discard)

    return {
        "success": True,
        "job_id": job_id,
        "message": (
            f"Data loading job started for source type '{source_type}'. "
            f"Use check_job_status('{job_id}') to monitor progress."
        ),
        "source_type": source_type,
    }


def list_data_sources_tool() -> dict[str, Any]:
"""List all available data source types.

Returns
-------
dict
Dictionary containing available data sources:
- "success" : bool
True if the list was retrieved successfully.
- "sources" : list of str
List of supported source type names.
- "descriptions" : dict
A mapping of source type names to their class and descriptions.
"""
from sktime_mcp.data import DataSourceRegistry

Expand All @@ -91,88 +156,21 @@ def list_data_sources_tool() -> dict[str, Any]:


def release_data_handle_tool(data_handle: str) -> dict[str, Any]:
    """Release a previously loaded data handle so its memory can be freed.

    Thin wrapper that forwards the request to the shared executor.

    Parameters
    ----------
    data_handle : str
        Identifier of the data handle to release.

    Returns
    -------
    dict
        The executor's response, passed through unchanged. Documented as
        containing:
        - "success" : bool
            True if the handle was released, False otherwise.
        - "message" : str, optional
            Detailed status or error message.
    """
    return get_executor().release_data_handle(data_handle)


def load_data_source_async_tool(
    config: dict[str, Any],
) -> dict[str, Any]:
    """Load data from any source in the background (non-blocking).

    Backward-compatible entry point equivalent to calling
    ``load_data_source_tool(config, run_async=True)``. The job creation and
    scheduling logic lives in ``load_data_source_tool`` so the two entry
    points cannot drift apart.

    Parameters
    ----------
    config : dict
        Data source configuration (same format as ``load_data_source_tool``).

    Returns
    -------
    dict
        Dictionary with:
        - "success" : bool
        - "job_id" : str
            Job ID for tracking progress via check_job_status.
        - "message" : str
            Information about the job.
        - "source_type" : str
            The value of ``config["type"]``, or "unknown" when absent.

    Examples
    --------
    >>> load_data_source_async_tool({
    ...     "type": "file",
    ...     "path": "/path/to/large_data.csv",
    ...     "time_column": "date",
    ...     "target_column": "value"
    ... })
    {
        "success": True,
        "job_id": "abc-123-def-456",
        "message": "Data loading job started..."
    }
    """
    return load_data_source_tool(config, run_async=True)
4 changes: 2 additions & 2 deletions tests/test_async_data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from sktime_mcp.runtime.executor import get_executor
from sktime_mcp.runtime.jobs import JobStatus, get_job_manager
from sktime_mcp.tools.data_tools import load_data_source_async_tool
from sktime_mcp.tools.data_tools import load_data_source_tool


class TestAsyncDataLoadingTool(unittest.TestCase):
Expand All @@ -27,7 +27,7 @@ def test_returns_job_id(self):
"time_column": "date",
"target_column": "value",
}
result = load_data_source_async_tool(config)
result = load_data_source_tool(config, run_async=True)
self.assertTrue(result["success"])
self.assertIn("job_id", result)
self.assertEqual(result["source_type"], "pandas")
Expand Down
Loading