diff --git a/src/sktime_mcp/server.py b/src/sktime_mcp/server.py index edee569a..1d64d88d 100644 --- a/src/sktime_mcp/server.py +++ b/src/sktime_mcp/server.py @@ -32,16 +32,12 @@ from sktime_mcp.composition.validator import get_composition_validator from sktime_mcp.tools.codegen import export_code_tool from sktime_mcp.tools.data_tools import ( - load_data_source_async_tool, load_data_source_tool, release_data_handle_tool, ) from sktime_mcp.tools.describe_estimator import describe_estimator_tool from sktime_mcp.tools.evaluate import evaluate_estimator_tool -from sktime_mcp.tools.fit_predict import ( - fit_predict_async_tool, - fit_predict_tool, -) +from sktime_mcp.tools.fit_predict import fit_predict_tool from sktime_mcp.tools.format_tools import format_time_series_tool from sktime_mcp.tools.instantiate import ( instantiate_estimator_tool, @@ -317,36 +313,10 @@ async def list_tools() -> list[Tool]: "description": "Forecast horizon (default: 12)", "default": 12, }, - }, - "required": ["estimator_handle"], - }, - ), - Tool( - name="fit_predict_async", - description=( - "Fit an estimator and generate predictions in the background. " - "Provide exactly ONE of 'dataset' (built-in demo name) " - "or 'data_handle' (from load_data_source)." - ), - inputSchema={ - "type": "object", - "properties": { - "estimator_handle": { - "type": "string", - "description": "Handle from instantiate_estimator", - }, - "dataset": { - "type": "string", - "description": "Demo dataset name: airline, sunspots, lynx, etc.", - }, - "data_handle": { - "type": "string", - "description": "Data handle from load_data_source (e.g. 'data_abc123')", - }, - "horizon": { - "type": "integer", - "description": "Forecast horizon (default: 12)", - "default": 12, + "background": { + "type": "boolean", + "description": "Run in the background as a job (default: false)", + "default": False, }, }, "required": ["estimator_handle"], @@ -423,24 +393,10 @@ async def list_tools() -> list[Tool]: "(pandas, sql, file, url)." ), }, - }, - "required": ["config"], - }, - ), - Tool( - name="load_data_source_async", - description=( - "Load data from any source in the background " - "(non-blocking). Returns a job_id to track " - "progress. The data_handle is available in " - "the job result when completed." - ), - inputSchema={ - "type": "object", - "properties": { - "config": { - "type": "object", - "description": "Data source configuration. Same format as load_data_source.", + "background": { + "type": "boolean", + "description": "Load data in the background (default: false)", + "default": False, }, }, "required": ["config"], @@ -675,17 +631,21 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: elif name == "fit_predict": result = fit_predict_tool( arguments["estimator_handle"], - arguments.get("dataset", ""), + arguments.get("dataset"), arguments.get("horizon", 12), data_handle=arguments.get("data_handle"), + background=arguments.get("background", False), ) elif name == "fit_predict_async": - result = fit_predict_async_tool( - estimator_handle=arguments["estimator_handle"], - dataset=arguments.get("dataset"), + # Deprecated: Route to unified fit_predict + logger.warning("fit_predict_async is deprecated; use fit_predict(background=true)") + result = fit_predict_tool( + arguments["estimator_handle"], + arguments.get("dataset"), + arguments.get("horizon", 12), data_handle=arguments.get("data_handle"), - horizon=arguments.get("horizon", 12), + background=True, ) elif name == "evaluate_estimator": @@ -705,10 +665,18 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: elif name == "list_available_data": result = list_available_data_tool(arguments.get("is_demo")) elif name == "load_data_source": - result = load_data_source_tool(arguments["config"]) + result = load_data_source_tool( + arguments["config"], + background=arguments.get("background", False), + ) elif name == "load_data_source_async": - result = load_data_source_async_tool(arguments["config"]) + # Deprecated: Route to unified load_data_source + logger.warning("load_data_source_async is deprecated; use load_data_source(background=true)") + result = load_data_source_tool( + arguments["config"], + background=True, + ) elif name == "list_data_sources": # Deprecated — info is now in load_data_source description diff --git a/src/sktime_mcp/tools/__init__.py b/src/sktime_mcp/tools/__init__.py index d8d0f0d7..15941854 100644 --- a/src/sktime_mcp/tools/__init__.py +++ b/src/sktime_mcp/tools/__init__.py @@ -2,16 +2,12 @@ from sktime_mcp.tools.codegen import export_code_tool from sktime_mcp.tools.data_tools import ( - load_data_source_async_tool, load_data_source_tool, release_data_handle_tool, ) from sktime_mcp.tools.describe_estimator import describe_estimator_tool from sktime_mcp.tools.evaluate import evaluate_estimator_tool -from sktime_mcp.tools.fit_predict import ( - fit_predict_async_tool, - fit_predict_tool, -) +from sktime_mcp.tools.fit_predict import fit_predict_tool from sktime_mcp.tools.format_tools import format_time_series_tool from sktime_mcp.tools.instantiate import ( instantiate_estimator_tool, @@ -42,10 +38,8 @@ "release_handle_tool", "load_model_tool", "fit_predict_tool", - "fit_predict_async_tool", "evaluate_estimator_tool", "load_data_source_tool", - "load_data_source_async_tool", "release_data_handle_tool", "list_available_data_tool", "format_time_series_tool", diff --git a/src/sktime_mcp/tools/data_tools.py b/src/sktime_mcp/tools/data_tools.py index 794776a5..9bc370ce 100644 --- a/src/sktime_mcp/tools/data_tools.py +++ b/src/sktime_mcp/tools/data_tools.py @@ -12,52 +12,72 @@ logger = logging.getLogger(__name__) -def load_data_source_tool(config: dict[str, Any]) -> dict[str, Any]: +def load_data_source_tool( + config: dict[str, Any], + background: bool = False, +) -> dict[str, Any]: """ - Load data from any source (pandas, SQL, file, etc.). + Load data from any source (Sync or Async). + + Supported source types: pandas, SQL, file, url. Args: config: Data source configuration - { - "type": "pandas" | "sql" | "file" | "url", - ... (type-specific configuration) - } + background: If True, runs the loading in the background and returns a job_id Returns: - Dictionary with: - - success: bool - - data_handle: str (handle ID for the loaded data) - - metadata: dict (information about the data) - - validation: dict (validation results) + If background=False (default): + Dictionary with data_handle and metadata. + If background=True: + Dictionary with success status, job_id, and tracking message. Examples: - # Pandas DataFrame - >>> load_data_source_tool({ - ... "type": "pandas", - ... "data": {"date": [...], "value": [...]}, - ... "time_column": "date", - ... "target_column": "value" - ... }) - - # SQL Database - >>> load_data_source_tool({ - ... "type": "sql", - ... "connection_string": "postgresql://user:pass@host:5432/db", - ... "query": "SELECT date, value FROM sales", - ... "time_column": "date", - ... "target_column": "value" - ... }) - - # CSV File - >>> load_data_source_tool({ - ... "type": "file", - ... "path": "/path/to/data.csv", - ... "time_column": "date", - ... "target_column": "value" - ... }) + # Sync loading (default) + >>> load_data_source_tool({"type": "file", "path": "data.csv"}) + + # Async loading + >>> load_data_source_tool({"type": "file", "path": "large.csv"}, background=True) """ executor = get_executor() - return executor.load_data_source(config) + + if not background: + return executor.load_data_source(config) + + # --- Async Logic --- + import asyncio + + from sktime_mcp.runtime.jobs import get_job_manager + + job_manager = get_job_manager() + source_type = config.get("type", "unknown") + + # create a background job for data loading + job_id = job_manager.create_job( + job_type="data_loading", + estimator_handle="", + dataset_name=source_type, + total_steps=3, # load, validate, format + ) + + # schedule on event loop + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + coro = executor.load_data_source_async(config, job_id) + asyncio.run_coroutine_threadsafe(coro, loop) + + return { + "success": True, + "job_id": job_id, + "message": ( + f"Data loading job started for source type '{source_type}'. " + f"Use check_job_status('{job_id}') to monitor progress." + ), + "source_type": source_type, + } def list_data_sources_tool() -> dict[str, Any]: @@ -104,71 +124,3 @@ def release_data_handle_tool(data_handle: str) -> dict[str, Any]: return executor.release_data_handle(data_handle) -def load_data_source_async_tool( - config: dict[str, Any], -) -> dict[str, Any]: - """ - Load data from any source in the background (non-blocking). - - Schedules the data loading as a background job and returns - immediately with a job_id. Use check_job_status to monitor - progress and retrieve the data_handle when done. - - Args: - config: Data source configuration (same as load_data_source) - - Returns: - Dictionary with: - - success: bool - - job_id: Job ID for tracking progress - - message: Information about the job - - Example: - >>> load_data_source_async_tool({ - ... "type": "file", - ... "path": "/path/to/large_data.csv", - ... "time_column": "date", - ... "target_column": "value" - ... }) - { - "success": True, - "job_id": "abc-123-def-456", - "message": "Data loading job started..." - } - """ - import asyncio - - from sktime_mcp.runtime.jobs import get_job_manager - - executor = get_executor() - job_manager = get_job_manager() - - source_type = config.get("type", "unknown") - - # create a background job for data loading - job_id = job_manager.create_job( - job_type="data_loading", - estimator_handle="", - dataset_name=source_type, - total_steps=3, # load, validate, format - ) - - # schedule on event loop - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - coro = executor.load_data_source_async(config, job_id) - asyncio.run_coroutine_threadsafe(coro, loop) - - return { - "success": True, - "job_id": job_id, - "message": ( - f"Data loading job started for source type '{source_type}'. " - f"Use check_job_status('{job_id}') to monitor progress." - ), - "source_type": source_type, - } diff --git a/src/sktime_mcp/tools/fit_predict.py b/src/sktime_mcp/tools/fit_predict.py index ef8f1b46..92296ee3 100644 --- a/src/sktime_mcp/tools/fit_predict.py +++ b/src/sktime_mcp/tools/fit_predict.py @@ -41,24 +41,26 @@ def _validate_horizon(horizon: int) -> dict[str, Any]: def fit_predict_tool( estimator_handle: str, - dataset: str, + dataset: str | None = None, horizon: int = 12, data_handle: str | None = None, + background: bool = False, ) -> dict[str, Any]: """ - Execute a complete fit-predict workflow. + Execute a complete fit-predict workflow (Sync or Async). Args: estimator_handle: Handle from instantiate_estimator dataset: Name of demo dataset (e.g., "airline", "sunspots") horizon: Forecast horizon (default: 12) data_handle: Optional handle from load_data_source for custom data + background: If True, runs the job in the background and returns a job_id Returns: - Dictionary with: - - success: bool - - predictions: Forecast values - - horizon: Number of steps predicted + If background=False (default): + Dictionary with predictions and success status. + If background=True: + Dictionary with success status, job_id, and tracking message. Example: >>> fit_predict_tool("est_abc123", "airline", horizon=12) @@ -80,7 +82,7 @@ def fit_predict_tool( "error": "Provide either 'dataset' or 'data_handle', not both.", } - if data_handle is None and (not dataset or not str(dataset).strip()): + if not data_handle and (not dataset or not str(dataset).strip()): return { "success": False, "error": ( @@ -88,109 +90,22 @@ def fit_predict_tool( "'data_handle' (from load_data_source) is required." ), } - executor = get_executor() - return executor.fit_predict(estimator_handle, dataset, horizon, data_handle=data_handle) - - -def predict_tool( - estimator_handle: str, - horizon: int = 12, -) -> dict[str, Any]: - """ - Generate predictions from a fitted estimator. - - Args: - estimator_handle: Handle of a fitted estimator - horizon: Forecast horizon - - Returns: - Dictionary with predictions - """ - validation = _validate_horizon(horizon) - if not validation["valid"]: - return { - "success": False, - "error": validation["error"], - } - executor = get_executor() - fh = list(range(1, horizon + 1)) - return executor.predict(estimator_handle, fh=fh) - - -def list_datasets_tool() -> dict[str, Any]: - """ - List available demo datasets. - Returns: - Dictionary with list of dataset names - """ executor = get_executor() - return { - "success": True, - "datasets": executor.list_datasets(), - } - - -def fit_predict_async_tool( - estimator_handle: str, - dataset: str | None = None, - data_handle: str | None = None, - horizon: int = 12, -) -> dict[str, Any]: - """ - Execute a fit-predict workflow in the background (non-blocking). - - Schedules the training as a background job and returns immediately - with a job_id. Use check_job_status to monitor progress. - - Accepts either a demo dataset name or a data handle from - load_data_source -- exactly one must be provided. - Args: - estimator_handle: Handle from instantiate_estimator - dataset: Name of demo dataset (e.g., "airline", "sunspots") - data_handle: Handle from load_data_source (e.g., "data_abc123") - horizon: Forecast horizon (default: 12) - - Returns: - Dictionary with: - - success: bool - - job_id: Job ID for tracking progress - - message: Information about the job - - Example: - >>> fit_predict_async_tool("est_abc123", dataset="airline", horizon=12) - >>> fit_predict_async_tool("est_abc123", data_handle="data_xyz", horizon=5) - """ - validation = _validate_horizon(horizon) - if not validation["valid"]: - return { - "success": False, - "error": validation["error"], - } - if dataset and data_handle: - return { - "success": False, - "error": "Provide either 'dataset' or 'data_handle', not both.", - } - - if not dataset and not data_handle: - return { - "success": False, - "error": ( - "Either 'dataset' (e.g. 'airline') or " - "'data_handle' (from load_data_source) is required." - ), - } + if not background: + return executor.fit_predict( + estimator_handle, str(dataset) if dataset else "", horizon, data_handle=data_handle + ) + # --- Async Logic --- import asyncio from sktime_mcp.runtime.jobs import get_job_manager - executor = get_executor() job_manager = get_job_manager() - # Get estimator info + # Get estimator info for better logging try: handle_info = executor._handle_manager.get_info(estimator_handle) estimator_name = handle_info.estimator_name @@ -237,3 +152,44 @@ def fit_predict_async_tool( "data_source": source_name, "horizon": horizon, } + + +def predict_tool( + estimator_handle: str, + horizon: int = 12, +) -> dict[str, Any]: + """ + Generate predictions from a fitted estimator. + + Args: + estimator_handle: Handle of a fitted estimator + horizon: Forecast horizon + + Returns: + Dictionary with predictions + """ + validation = _validate_horizon(horizon) + if not validation["valid"]: + return { + "success": False, + "error": validation["error"], + } + executor = get_executor() + fh = list(range(1, horizon + 1)) + return executor.predict(estimator_handle, fh=fh) + + +def list_datasets_tool() -> dict[str, Any]: + """ + List available demo datasets. + + Returns: + Dictionary with list of dataset names + """ + executor = get_executor() + return { + "success": True, + "datasets": executor.list_datasets(), + } + + diff --git a/tests/test_async_custom_data.py b/tests/test_async_custom_data.py index b033c55f..dd501628 100644 --- a/tests/test_async_custom_data.py +++ b/tests/test_async_custom_data.py @@ -43,13 +43,14 @@ def _load_custom_data(self): def test_async_with_dataset(self): """Async with a demo dataset should return a job_id.""" - from sktime_mcp.tools.fit_predict import fit_predict_async_tool + from sktime_mcp.tools.fit_predict import fit_predict_tool handle = self._get_estimator_handle() - result = fit_predict_async_tool( + result = fit_predict_tool( estimator_handle=handle, dataset="airline", horizon=3, + background=True, ) assert result["success"], f"Expected success, got: {result}" @@ -58,15 +59,16 @@ def test_async_with_dataset(self): def test_async_with_data_handle(self): """Async with a custom data handle should return a job_id.""" - from sktime_mcp.tools.fit_predict import fit_predict_async_tool + from sktime_mcp.tools.fit_predict import fit_predict_tool handle = self._get_estimator_handle() data_handle = self._load_custom_data() - result = fit_predict_async_tool( + result = fit_predict_tool( estimator_handle=handle, data_handle=data_handle, horizon=5, + background=True, ) assert result["success"], f"Expected success, got: {result}" @@ -75,14 +77,15 @@ def test_async_with_data_handle(self): def test_async_both_provided_error(self): """Providing both dataset and data_handle should fail.""" - from sktime_mcp.tools.fit_predict import fit_predict_async_tool + from sktime_mcp.tools.fit_predict import fit_predict_tool handle = self._get_estimator_handle() - result = fit_predict_async_tool( + result = fit_predict_tool( estimator_handle=handle, dataset="airline", data_handle="data_fake123", horizon=3, + background=True, ) assert not result["success"] @@ -91,12 +94,13 @@ def test_async_both_provided_error(self): def test_async_neither_provided_error(self): """Omitting both dataset and data_handle should fail.""" - from sktime_mcp.tools.fit_predict import fit_predict_async_tool + from sktime_mcp.tools.fit_predict import fit_predict_tool handle = self._get_estimator_handle() - result = fit_predict_async_tool( + result = fit_predict_tool( estimator_handle=handle, horizon=3, + background=True, ) assert not result["success"] @@ -104,13 +108,14 @@ def test_async_neither_provided_error(self): def test_async_invalid_data_handle(self): """An invalid data_handle should fail at the executor level.""" - from sktime_mcp.tools.fit_predict import fit_predict_async_tool + from sktime_mcp.tools.fit_predict import fit_predict_tool handle = self._get_estimator_handle() - result = fit_predict_async_tool( + result = fit_predict_tool( estimator_handle=handle, data_handle="data_nonexistent", horizon=3, + background=True, ) # The tool succeeds in scheduling the job (returns job_id), diff --git a/tests/test_async_data_loading.py b/tests/test_async_data_loading.py index 5b8d9ea6..0b69fd15 100644 --- a/tests/test_async_data_loading.py +++ b/tests/test_async_data_loading.py @@ -13,7 +13,7 @@ from sktime_mcp.runtime.executor import get_executor from sktime_mcp.runtime.jobs import JobStatus, get_job_manager -from sktime_mcp.tools.data_tools import load_data_source_async_tool +from sktime_mcp.tools.data_tools import load_data_source_tool class TestAsyncDataLoadingTool(unittest.TestCase): @@ -27,7 +27,7 @@ def test_returns_job_id(self): "time_column": "date", "target_column": "value", } - result = load_data_source_async_tool(config) + result = load_data_source_tool(config, background=True) self.assertTrue(result["success"]) self.assertIn("job_id", result) self.assertEqual(result["source_type"], "pandas")