diff --git a/README.md b/README.md index a4541fb6..7738cea0 100644 --- a/README.md +++ b/README.md @@ -584,38 +584,76 @@ Load custom files, SQL database queries, URLs, or inline JSON into the server as } ``` -#### 18. `save_data` -Persist an in-memory `data_handle` (such as predictions or transformed series) back to disk. +#### 18. `inspect_data` +Inspect a loaded data handle and return rich metadata (mtype, scitype, shape, columns, dtypes, frequency, cutoff, missing counts, head preview, summary statistics). * **Arguments:** - * `data_handle` (`str`, required): In-memory data handle to save. - * `path` (`str`, required): Local filesystem path where the file will be saved. - * `format` (`str`, optional): Output format (inferred from path extension if omitted: `"csv"`, `"parquet"`, `"json"`). + * `data_handle` (`str`, required): Handle to inspect. * **Returns:** ```json { "success": true, - "saved_path": "/home/user/forecasts.csv", - "message": "Data saved successfully." + "data_handle": "data_abc123", + "mtype": "pd.Series", + "scitype": "Series", + "shape": [60], + "cutoff": "2024-12-01", + "n_missing": 0, + "head": {}, + "summary_stats": {} + } + ``` + +#### 19. `split_data` +Split a time series handle into temporal train/test sets, returning two new handles. +* **Arguments:** + * `data_handle` (`str`, required): Handle to split. + * `test_size` (`float`, optional): Fraction in (0, 1) to hold out. Mutually exclusive with `fh`. + * `fh` (`int | list[int]`, optional): Forecast horizon — integer steps or list of relative indices (uses `max(fh)` steps). +* **Returns:** + ```json + { + "success": true, + "train_handle": "data_train123", + "test_handle": "data_test456", + "cutoff": "2024-06-01", + "train_size": 48, + "n_test": 12 } ``` -#### 19. `format_time_series` -Clean, fill missing values, deduplicate, and standardize loaded time series data. +#### 20. `transform_data` +Transform a data handle — format (auto-fix frequency/dupes/NaN) or convert mtype. * **Arguments:** - * `data_handle` (`str`, required): Target data handle. - * `auto_infer_freq` (`bool`, optional, default=`true`): Re-infer time delta frequency. - * `fill_missing` (`bool`, optional, default=`true`): Interpolate missing values using forward/backward fills. - * `remove_duplicates` (`bool`, optional, default=`true`): Deduplicate timestamps. + * `data_handle` (`str`, required): Handle to transform. + * `action` (`str`, optional, default=`"format"`): `"format"` or `"convert"`. + * `auto_infer_freq`, `fill_missing`, `remove_duplicates` (`bool`, optional): Format-mode options. + * `to_mtype` (`str`, optional): Required when `action="convert"` (e.g. `"pd.DataFrame"`). * **Returns:** ```json { "success": true, "data_handle": "data_abc123", - "changes_applied": ["inferred frequency: M", "filled 3 missing values"] + "changes_applied": ["Inferred and set frequency to 'MS'"] + } + ``` + +#### 21. `save_data` +Persist an in-memory `data_handle` (target series and exogenous features) to disk. +* **Arguments:** + * `data_handle` (`str`, required): In-memory data handle to save. + * `path` (`str`, required): Local filesystem path where the file will be saved. + * `format` (`str`, optional, default=`"csv"`): Output format — `"csv"`, `"parquet"`, or `"json"`. +* **Returns:** + ```json + { + "success": true, + "saved_path": "/home/user/forecasts.csv", + "format": "csv", + "rows": 60 } ``` -#### 20. `release_data_handle` +#### 22. `release_data_handle` Free a data handle and its contents from server memory. * **Arguments:** * `data_handle` (`str`, required): Handle ID to release. @@ -633,7 +671,7 @@ Free a data handle and its contents from server memory. These tools manage the serialization of estimator instances and generation of production-ready source code. -#### 21. `save_model` +#### 23. `save_model` Serialize an estimator blueprint or fitted model handle to disk using sktime-MLflow integration. * **Arguments:** * `estimator_handle` (`str`, required): Estimator or pipeline handle to save. @@ -649,7 +687,7 @@ Serialize an estimator blueprint or fitted model handle to disk using sktime-MLf } ``` -#### 22. `load_model` +#### 24. `load_model` Reload a serialized blueprint or fitted model back into an active `estimator_handle`. * **Arguments:** * `path` (`str`, required): Filesystem path to the model directory. @@ -662,7 +700,7 @@ Reload a serialized blueprint or fitted model back into an active `estimator_han } ``` -#### 23. `export_code` +#### 25. `export_code` Generate standalone, executable Python code to reproduce an estimator's structure and execution. * **Arguments:** * `handle` (`str`, required): Handle ID of the estimator/pipeline. diff --git a/src/sktime_mcp/server.py b/src/sktime_mcp/server.py index 8951258f..7d0e64b8 100644 --- a/src/sktime_mcp/server.py +++ b/src/sktime_mcp/server.py @@ -51,7 +51,7 @@ fit_predict_async_tool, fit_predict_tool, ) -from sktime_mcp.tools.format_tools import format_time_series_tool +from sktime_mcp.tools.inspect_data import inspect_data_tool from sktime_mcp.tools.instantiate import ( instantiate_estimator_tool, instantiate_pipeline_tool, @@ -69,7 +69,10 @@ get_available_tags, list_estimators_tool, ) +from sktime_mcp.tools.save_data import save_data_tool from sktime_mcp.tools.save_model import save_model_tool +from sktime_mcp.tools.split_data import split_data_tool +from sktime_mcp.tools.transform_data import transform_data_tool # --------------------------------------------------------------------------- @@ -519,34 +522,152 @@ async def list_tools() -> list[Tool]: }, ), Tool( - name="format_time_series", - description="Automatically format time series data (frequency, duplicates, missing values)", + name="inspect_data", + description=( + "Inspect a loaded data handle and return rich metadata for understanding " + "the series before modelling. Returns mtype, scitype, shape, column names, " + "dtypes, index level names, inferred frequency, cutoff (last training " + "timestamp), total missing-value count, a 5-row head preview, and " + "per-column summary statistics. Works on handles from load_data_source, " + "split_data, or transform_data. Does not modify the data." + ), + inputSchema={ + "type": "object", + "properties": { + "data_handle": { + "type": "string", + "description": ( + "Data handle ID to inspect (from load_data_source, split_data, " + "or transform_data)." + ), + }, + }, + "required": ["data_handle"], + }, + ), + Tool( + name="split_data", + description=( + "Split a time series data handle into temporal train and test sets, " + "registering both halves as new data handles. Provide exactly one of " + "test_size (fraction in (0, 1)) or fh (forecast horizon). fh may be an " + "integer (hold out that many final steps) or a list of relative horizon " + "indices (hold out max(fh) final steps). Returns train_handle, " + "test_handle, cutoff timestamp, train_size, and n_test." + ), + inputSchema={ + "type": "object", + "properties": { + "data_handle": { + "type": "string", + "description": "Data handle ID to split (from load_data_source or transform_data).", + }, + "test_size": { + "type": "number", + "description": ( + "Fraction of observations to hold out for the test set, " + "exclusive range (0.0, 1.0). Mutually exclusive with fh." + ), + }, + "fh": { + "description": ( + "Forecast horizon for the test window. Integer: hold out that " + "many final time steps. List of ints: hold out max(fh) final " + "steps (e.g. fh=[1,5,10] reserves 10 steps). " + "Mutually exclusive with test_size." + ), + }, + }, + "required": ["data_handle"], + }, + ), + Tool( + name="transform_data", + description=( + "Transform a loaded data handle and return a new handle. " + "action='format' (default): auto-fix common time series issues — " + "infer/set frequency, remove duplicate timestamps, fill index gaps, " + "and forward/backward-fill missing values; returns changes_applied. " + "action='convert': convert y to a different sktime mtype via convert_to() " + "(requires to_mtype, e.g. 'pd.DataFrame', 'pd.Series', 'np.ndarray'). " + "Replaces the legacy format_time_series tool." + ), inputSchema={ "type": "object", "properties": { "data_handle": { "type": "string", - "description": "Handle from load_data_source", + "description": "Data handle ID to transform.", + }, + "action": { + "type": "string", + "description": ( + "Transformation to apply: 'format' (default) or 'convert'." + ), + "enum": ["format", "convert"], + "default": "format", }, "auto_infer_freq": { "type": "boolean", - "description": "Automatically infer and set frequency (default: True)", + "description": "(format only) Infer and set DatetimeIndex frequency (default: true).", "default": True, }, "fill_missing": { "type": "boolean", - "description": "Fill missing values with forward/backward fill (default: True)", + "description": "(format only) Forward/backward fill missing values (default: true).", "default": True, }, "remove_duplicates": { "type": "boolean", - "description": "Remove duplicate timestamps (default: True)", + "description": "(format only) Drop duplicate timestamps, keeping first (default: true).", "default": True, }, + "to_mtype": { + "type": "string", + "description": ( + "(convert only, required) Target sktime mtype string, " + "e.g. 'pd.DataFrame', 'pd.Series', 'np.ndarray'." + ), + }, }, "required": ["data_handle"], }, ), + Tool( + name="save_data", + description=( + "Persist the target series (y) and any exogenous features (X) behind a " + "data handle to a local file. Combines y and X into one table. Creates " + "parent directories as needed. Supported formats: csv (default, writes " + "index as first column), parquet, json (records orient, ISO dates)." + ), + inputSchema={ + "type": "object", + "properties": { + "data_handle": { + "type": "string", + "description": ( + "Data handle ID to export (from load_data_source, split_data, " + "or transform_data)." + ), + }, + "path": { + "type": "string", + "description": ( + "Destination file path. Format is controlled by the format " + "argument, not the file extension." + ), + }, + "format": { + "type": "string", + "description": "Output format: csv (default), parquet, or json.", + "enum": ["csv", "parquet", "json"], + "default": "csv", + }, + }, + "required": ["data_handle", "path"], + }, + ), # -- Export / Persistence -------------------------------------------- Tool( name="export_code", @@ -840,12 +961,31 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: elif name == "release_data_handle": result = release_data_handle_tool(arguments["data_handle"]) - elif name == "format_time_series": - result = format_time_series_tool( - arguments["data_handle"], - arguments.get("auto_infer_freq", True), - arguments.get("fill_missing", True), - arguments.get("remove_duplicates", True), + elif name == "inspect_data": + result = inspect_data_tool(arguments["data_handle"]) + + elif name == "split_data": + result = split_data_tool( + data_handle=arguments["data_handle"], + test_size=arguments.get("test_size"), + fh=arguments.get("fh"), + ) + + elif name == "transform_data": + result = transform_data_tool( + data_handle=arguments["data_handle"], + action=arguments.get("action", "format"), + auto_infer_freq=arguments.get("auto_infer_freq", True), + fill_missing=arguments.get("fill_missing", True), + remove_duplicates=arguments.get("remove_duplicates", True), + to_mtype=arguments.get("to_mtype"), + ) + + elif name == "save_data": + result = save_data_tool( + data_handle=arguments["data_handle"], + path=arguments["path"], + format=arguments.get("format", "csv"), ) elif name == "auto_format_on_load": diff --git a/src/sktime_mcp/tools/__init__.py b/src/sktime_mcp/tools/__init__.py index d8d0f0d7..405595c3 100644 --- a/src/sktime_mcp/tools/__init__.py +++ b/src/sktime_mcp/tools/__init__.py @@ -13,6 +13,7 @@ fit_predict_tool, ) from sktime_mcp.tools.format_tools import format_time_series_tool +from sktime_mcp.tools.inspect_data import inspect_data_tool from sktime_mcp.tools.instantiate import ( instantiate_estimator_tool, instantiate_pipeline_tool, @@ -30,7 +31,10 @@ get_available_tags, list_estimators_tool, ) +from sktime_mcp.tools.save_data import save_data_tool from sktime_mcp.tools.save_model import save_model_tool +from sktime_mcp.tools.split_data import split_data_tool +from sktime_mcp.tools.transform_data import transform_data_tool __all__ = [ "list_estimators_tool", @@ -49,6 +53,10 @@ "release_data_handle_tool", "list_available_data_tool", "format_time_series_tool", + "inspect_data_tool", + "split_data_tool", + "transform_data_tool", + "save_data_tool", "export_code_tool", "save_model_tool", "check_job_status_tool", diff --git a/src/sktime_mcp/tools/inspect_data.py b/src/sktime_mcp/tools/inspect_data.py new file mode 100644 index 00000000..7df79ae3 --- /dev/null +++ b/src/sktime_mcp/tools/inspect_data.py @@ -0,0 +1,237 @@ +""" +Data inspection tool for sktime MCP. + +Provides rich metadata about loaded data handles including mtype, +scitype, shape, frequency, cutoff, missing values, and summary stats. +""" + +import logging +from typing import Any + +import pandas as pd + +from sktime_mcp.runtime.executor import get_executor + +logger = logging.getLogger(__name__) + + +def inspect_data_tool(data_handle: str) -> dict[str, Any]: + """Inspect a loaded data handle and return rich metadata. + + Provides comprehensive information about the data behind a handle, + including shape, column types, frequency, cutoff point, missing + value counts, a preview (head), and summary statistics. + + Parameters + ---------- + data_handle : str + The unique handle ID for the loaded data source (from load_data_source). + + Returns + ------- + dict + Dictionary containing detailed metadata and summary statistics: + - "success" : bool + True if the data handle was found and inspected successfully. + - "data_handle" : str + The inspected data handle ID. + - "mtype" : str + The format mtype (e.g. 'pd.Series', 'pd.DataFrame'). + - "scitype" : str + The sktime scientific type (e.g. 'Series', 'Panel'). + - "shape" : list of int + Shape list: [rows, columns] or [rows]. + - "columns" : list of str + Names of all variables (including exogenous features if present). + - "dtypes" : dict + Mapping of column names to string names of their data types. + - "index_names" : list of str + List of names of the index levels. + - "freq" : str or None + Inferred or declared frequency of the time index. + - "cutoff" : str or None + Cutoff timestamp/integer index indicating the end of the history. + - "n_missing" : int + Total count of missing values across the entire dataset. + - "head" : dict + Preview of the first 5 rows of data. + - "summary_stats" : dict + Statistical summary metrics (mean, std, min, max, etc.) per column. + - "error" : str, optional + Error message if "success" is False. + """ + executor = get_executor() + + if data_handle not in executor._data_handles: + return { + "success": False, + "error": f"Data handle '{data_handle}' not found", + "available_handles": list(executor._data_handles.keys()), + } + + data_info = executor._data_handles[data_handle] + y = data_info["y"] + X = data_info.get("X") + + try: + # --- mtype detection --- + mtype = type(y).__name__ + if isinstance(y, pd.DataFrame): + mtype = "pd.DataFrame" + elif isinstance(y, pd.Series): + mtype = "pd.Series" + + # --- scitype detection --- + scitype = _detect_scitype(y) + + # --- shape --- + shape = list(y.shape) + + # --- columns --- + if isinstance(y, pd.DataFrame): + columns = list(y.columns) + elif isinstance(y, pd.Series): + columns = [y.name if y.name else "target"] + else: + columns = [] + + # Add exogenous columns if present + if X is not None and isinstance(X, pd.DataFrame): + columns = columns + [f"X:{c}" for c in X.columns] + + # --- dtypes --- + if isinstance(y, pd.DataFrame): + dtypes = {str(col): str(dtype) for col, dtype in y.dtypes.items()} + elif isinstance(y, pd.Series): + dtypes = {y.name if y.name else "target": str(y.dtype)} + else: + dtypes = {} + + # --- index names --- + if hasattr(y.index, "names"): + index_names = [str(n) if n is not None else "index" for n in y.index.names] + else: + index_names = ["index"] + + # --- frequency --- + freq = None + if hasattr(y.index, "freq") and y.index.freq is not None: + freq = str(y.index.freq) + elif hasattr(y.index, "inferred_freq"): + freq = y.index.inferred_freq + + # --- cutoff --- + cutoff = None + try: + from sktime.datatypes import get_cutoff as sktime_get_cutoff + + cutoff_val = sktime_get_cutoff(y) + cutoff = str(cutoff_val) + except Exception: + # Fallback: use last index value + if len(y) > 0: + cutoff = str(y.index[-1]) + + # --- missing values --- + if isinstance(y, pd.DataFrame): + n_missing = int(y.isna().sum().sum()) + elif isinstance(y, pd.Series): + n_missing = int(y.isna().sum()) + else: + n_missing = 0 + + # --- head (first 5 rows) --- + head_data = _safe_head(y, n=5) + + # --- summary statistics --- + summary_stats = _safe_describe(y) + + return { + "success": True, + "data_handle": data_handle, + "mtype": mtype, + "scitype": scitype, + "shape": shape, + "columns": columns, + "dtypes": dtypes, + "index_names": index_names, + "freq": freq, + "cutoff": cutoff, + "n_missing": n_missing, + "head": head_data, + "summary_stats": summary_stats, + } + + except Exception as e: + logger.exception("Error inspecting data handle") + return { + "success": False, + "error": str(e), + "error_type": type(e).__name__, + } + + +def _detect_scitype(y: Any) -> str: + """Detect the sktime scitype of the data.""" + try: + from sktime.datatypes import scitype as sktime_scitype + + return sktime_scitype(y, candidate_scitypes=["Series", "Panel", "Hierarchical"]) + except Exception: + pass + + # Fallback heuristic + if isinstance(y, pd.Series): + return "Series" + if isinstance(y, pd.DataFrame): + if isinstance(y.index, pd.MultiIndex): + if y.index.nlevels >= 3: + return "Hierarchical" + return "Panel" + return "Series" + return "Unknown" + + +def _safe_head(y: Any, n: int = 5) -> dict: + """Return the first n rows as a JSON-safe dict.""" + try: + if isinstance(y, pd.Series): + head = y.head(n) + return {str(k): _safe_value(v) for k, v in head.items()} + if isinstance(y, pd.DataFrame): + head = y.head(n) + return { + str(idx): {str(col): _safe_value(val) for col, val in row.items()} + for idx, row in head.iterrows() + } + except Exception: + pass + return {} + + +def _safe_describe(y: Any) -> dict: + """Return summary statistics as a JSON-safe dict.""" + try: + if isinstance(y, (pd.Series, pd.DataFrame)): + desc = y.describe() + if isinstance(desc, pd.Series): + return {str(k): _safe_value(v) for k, v in desc.items()} + if isinstance(desc, pd.DataFrame): + return { + str(col): {str(stat): _safe_value(val) for stat, val in desc[col].items()} + for col in desc.columns + } + except Exception: + pass + return {} + + +def _safe_value(val: Any) -> Any: + """Convert a value to a JSON-safe type.""" + if isinstance(val, float) and (pd.isna(val) or val != val): + return None + if hasattr(val, "item"): + return val.item() + if isinstance(val, pd.Timestamp): + return val.isoformat() + return val diff --git a/src/sktime_mcp/tools/save_data.py b/src/sktime_mcp/tools/save_data.py new file mode 100644 index 00000000..0fa49d43 --- /dev/null +++ b/src/sktime_mcp/tools/save_data.py @@ -0,0 +1,122 @@ +""" +Data persistence tool for sktime MCP. + +Saves the data behind a handle to a local file in CSV, Parquet, or JSON format. +""" + +import logging +from pathlib import Path +from typing import Any + +import pandas as pd + +from sktime_mcp.runtime.executor import get_executor + +logger = logging.getLogger(__name__) + +# Supported output formats and their pandas writer methods +_FORMAT_WRITERS = { + "csv": "to_csv", + "parquet": "to_parquet", + "json": "to_json", +} + + +def save_data_tool( + data_handle: str, + path: str, + format: str = "csv", +) -> dict[str, Any]: + """Persist the data behind a handle to a local file. + + Supports CSV, Parquet, and JSON output formats. The target + directory is created automatically if it does not exist. + + Parameters + ---------- + data_handle : str + Handle ID of the data to save (from load_data_source, split_data, etc.). + path : str + Destination file path (e.g. "/tmp/forecast_output.csv"). + Note that the file extension is not used to infer the format; + use the `format` argument instead. + format : str, default="csv" + Output format. Must be one of: "csv", "parquet", or "json". + + Returns + ------- + dict + Dictionary containing success status and path information: + - "success" : bool + True if the data was written successfully, False otherwise. + - "saved_path" : str + Absolute path to the written file. + - "format" : str + The format used to write the file. + - "rows" : int + Number of rows written to the file. + - "error" : str, optional + Error message if "success" is False. + """ + executor = get_executor() + + # --- validation -------------------------------------------------------- + if data_handle not in executor._data_handles: + return { + "success": False, + "error": f"Data handle '{data_handle}' not found", + "available_handles": list(executor._data_handles.keys()), + } + + fmt = format.lower() + if fmt not in _FORMAT_WRITERS: + return { + "success": False, + "error": f"Unsupported format '{format}'. Choose from: {list(_FORMAT_WRITERS.keys())}", + } + + data_info = executor._data_handles[data_handle] + y = data_info["y"] + X = data_info.get("X") + + try: + # Combine y and X into a single DataFrame for export + if isinstance(y, pd.Series): + df = y.to_frame(name=y.name if y.name else "target") + elif isinstance(y, pd.DataFrame): + df = y.copy() + else: + # Best-effort: wrap in a DataFrame + df = pd.DataFrame(y, columns=["target"]) + + if X is not None and isinstance(X, pd.DataFrame): + df = pd.concat([df, X], axis=1) + + # Ensure target directory exists + abs_path = Path(path).resolve() + abs_path.parent.mkdir(parents=True, exist_ok=True) + + # Write + writer = getattr(df, _FORMAT_WRITERS[fmt]) + if fmt == "json": + writer(str(abs_path), orient="records", date_format="iso", indent=2) + elif fmt == "parquet": + writer(str(abs_path)) + else: + # CSV — include the index as a time column + writer(str(abs_path)) + + return { + "success": True, + "saved_path": str(abs_path), + "format": fmt, + "rows": len(df), + } + + except Exception as e: + logger.exception("Error saving data") + return { + "success": False, + "error": str(e), + "error_type": type(e).__name__, + } diff --git a/src/sktime_mcp/tools/split_data.py b/src/sktime_mcp/tools/split_data.py new file mode 100644 index 00000000..f948b794 --- /dev/null +++ b/src/sktime_mcp/tools/split_data.py @@ -0,0 +1,213 @@ +""" +Data splitting tool for sktime MCP. + +Provides temporal train/test splitting for time series data, +registering both halves as new data handles. +""" + +import logging +import uuid +from typing import Any + +from sktime_mcp.runtime.executor import get_executor + +logger = logging.getLogger(__name__) + + +def split_data_tool( + data_handle: str, + test_size: float | None = None, + fh: list[int] | int | None = None, +) -> dict[str, Any]: + """Split a time series data handle into train and test sets. + + Uses sktime's temporal_train_test_split() when available, falling + back to a pandas-based implementation. Exactly one of `test_size` + or `fh` must be provided. + + Parameters + ---------- + data_handle : str + Handle ID of the loaded data to split (from load_data_source). + test_size : float or None, default=None + Fraction of the data to hold out for testing (0.0–1.0). + Mutually exclusive with `fh`. + fh : int, list of int, or None, default=None + Forecast horizon — the number of final time steps to reserve + as the test set. Can be a single int or a list of relative + step indices. Mutually exclusive with `test_size`. + + Returns + ------- + dict + Dictionary containing the split train/test handles and metadata: + - "success" : bool + True if the split completed successfully, False otherwise. + - "train_handle" : str + The new unique data handle ID representing the training set. + - "test_handle" : str + The new unique data handle ID representing the test set. + - "cutoff" : str + The cutoff timestamp indicating the last training timestamp. + - "train_size" : int + Number of observations in the training set. + - "n_test" : int + Number of observations in the test set. + - "error" : str, optional + Error message if "success" is False. + """ + executor = get_executor() + + # --- validation -------------------------------------------------------- + if data_handle not in executor._data_handles: + return { + "success": False, + "error": f"Data handle '{data_handle}' not found", + "available_handles": list(executor._data_handles.keys()), + } + + if test_size is not None and fh is not None: + return { + "success": False, + "error": "Provide exactly one of 'test_size' or 'fh', not both.", + } + + if test_size is None and fh is None: + return { + "success": False, + "error": "Provide at least one of 'test_size' or 'fh'.", + } + + if test_size is not None and (test_size <= 0.0 or test_size >= 1.0): + return { + "success": False, + "error": f"test_size must be between 0.0 and 1.0 (exclusive), got {test_size}", + } + + if fh is not None: + if isinstance(fh, int): + if fh < 1: + return { + "success": False, + "error": f"fh must be a positive integer, got {fh}", + } + elif isinstance(fh, list): + if not fh or not all(isinstance(step, int) and step > 0 for step in fh): + return { + "success": False, + "error": "fh must be a non-empty list of positive integers.", + } + else: + return { + "success": False, + "error": f"fh must be an integer or list of integers, got {type(fh).__name__}", + } + + data_info = executor._data_handles[data_handle] + y = data_info["y"] + X = data_info.get("X") + + try: + # --- determine split point ---------------------------------------- + n = len(y) + + if test_size is not None: + n_test = max(1, int(n * test_size)) + elif isinstance(fh, int): + n_test = fh + else: + # fh as list of relative horizon indices — reserve max(fh) final steps + n_test = max(fh) + + if n_test >= n: + return { + "success": False, + "error": ( + f"Test set would be {n_test} samples but the series only has {n} observations." + ), + } + + split_idx = n - n_test + + # --- try sktime first --------------------------------------------- + try: + from sktime.split import temporal_train_test_split + + y_train, y_test = temporal_train_test_split(y, test_size=n_test / n) + if X is not None: + X_train, X_test = temporal_train_test_split(X, test_size=n_test / n) + else: + X_train, X_test = None, None + except Exception: + # Fallback: plain pandas slicing + y_train = y.iloc[:split_idx] + y_test = y.iloc[split_idx:] + if X is not None: + X_train = X.iloc[:split_idx] + X_test = X.iloc[split_idx:] + else: + X_train, X_test = None, None + + # --- cutoff ------------------------------------------------------- + cutoff = str(y_train.index[-1]) + + # --- register handles -------------------------------------------- + train_handle = f"data_{uuid.uuid4().hex[:8]}" + test_handle = f"data_{uuid.uuid4().hex[:8]}" + + base_meta = data_info.get("metadata", {}).copy() + + train_meta = { + **base_meta, + "split": "train", + "rows": len(y_train), + "start_date": str(y_train.index[0]), + "end_date": str(y_train.index[-1]), + "parent_handle": data_handle, + } + test_meta = { + **base_meta, + "split": "test", + "rows": len(y_test), + "start_date": str(y_test.index[0]), + "end_date": str(y_test.index[-1]), + "parent_handle": data_handle, + } + + executor._register_data_handle( + train_handle, + { + "y": y_train, + "X": X_train, + "metadata": train_meta, + "validation": data_info.get("validation", {}), + "config": data_info.get("config", {}), + }, + ) + executor._register_data_handle( + test_handle, + { + "y": y_test, + "X": X_test, + "metadata": test_meta, + "validation": data_info.get("validation", {}), + "config": data_info.get("config", {}), + }, + ) + + return { + "success": True, + "train_handle": train_handle, + "test_handle": test_handle, + "cutoff": cutoff, + "train_size": len(y_train), + "n_test": len(y_test), + } + + except Exception as e: + logger.exception("Error splitting data") + return { + "success": False, + "error": str(e), + "error_type": type(e).__name__, + } diff --git a/src/sktime_mcp/tools/transform_data.py b/src/sktime_mcp/tools/transform_data.py new file mode 100644 index 00000000..36fb9645 --- /dev/null +++ b/src/sktime_mcp/tools/transform_data.py @@ -0,0 +1,217 @@ +""" +Data transformation tool for sktime MCP. + +Provides two actions: + - "format": auto-fix frequency, duplicates, missing values (replaces format_time_series). + - "convert": convert data between sktime mtypes using convert_to(). +""" + +import logging +import uuid +from typing import Any + +import pandas as pd + +from sktime_mcp.runtime.executor import get_executor + +logger = logging.getLogger(__name__) + + +def transform_data_tool( + data_handle: str, + action: str = "format", + auto_infer_freq: bool = True, + fill_missing: bool = True, + remove_duplicates: bool = True, + to_mtype: str | None = None, +) -> dict[str, Any]: + """Transform a data handle — either format it or convert its mtype. + + Supports two modes controlled by the `action` argument. + + Parameters + ---------- + data_handle : str + Handle ID of the loaded data to transform (from load_data_source). + action : str, default="format" + The transformation action to perform. Must be one of: + - "format" : Auto-fix common time series issues like inferring frequency, + removing duplicate timestamps, and filling missing values. + - "convert" : Convert the data to a different sktime machine type (mtype). + auto_infer_freq : bool, default=True + (Format mode only) Infer and set frequency. + fill_missing : bool, default=True + (Format mode only) Forward/backward fill missing values. + remove_duplicates : bool, default=True + (Format mode only) Remove duplicate timestamps. + to_mtype : str or None, default=None + (Convert mode only) Target machine type string, e.g. "pd.DataFrame", + "pd.Series", "np.ndarray". + + Returns + ------- + dict + Dictionary containing the new data handle and a list of applied changes: + - "success" : bool + True if the transformation succeeded, False otherwise. + - "data_handle" : str + The new unique data handle ID representing the transformed data. + - "changes_applied" : list of str + A list of human-readable changes that were applied to the data. + - "metadata" : dict, optional + Updated metadata for the new handle. + - "error" : str, optional + Error message if "success" is False. + """ + if action not in ("format", "convert"): + return { + "success": False, + "error": f"Unknown action '{action}'. Must be 'format' or 'convert'.", + } + + if action == "convert" and not to_mtype: + return { + "success": False, + "error": "The 'to_mtype' argument is required when action='convert'.", + } + + executor = get_executor() + + if data_handle not in executor._data_handles: + return { + "success": False, + "error": f"Data handle '{data_handle}' not found", + "available_handles": list(executor._data_handles.keys()), + } + + try: + if action == "format": + return _action_format( + executor, + data_handle, + auto_infer_freq=auto_infer_freq, + fill_missing=fill_missing, + remove_duplicates=remove_duplicates, + ) + else: + return _action_convert(executor, data_handle, to_mtype) + except Exception as e: + logger.exception("Error transforming data") + return { + "success": False, + "error": str(e), + "error_type": type(e).__name__, + } + + +# --------------------------------------------------------------------------- +# Action: format +# --------------------------------------------------------------------------- + + +def _action_format( + executor: Any, + data_handle: str, + *, + auto_infer_freq: bool, + fill_missing: bool, + remove_duplicates: bool, +) -> dict[str, Any]: + """Delegate to the executor's existing format logic and wrap the result.""" + result = executor.format_data_handle( + data_handle, + auto_infer_freq=auto_infer_freq, + fill_missing=fill_missing, + remove_duplicates=remove_duplicates, + ) + + if not result.get("success"): + return result + + # Build a human-readable list of changes + changes_applied: list[str] = [] + changes = result.get("changes_made", {}) + + if changes.get("duplicates_removed", 0) > 0: + changes_applied.append(f"Removed {changes['duplicates_removed']} duplicate timestamps") + if changes.get("frequency_set"): + freq = changes.get("frequency", "?") + changes_applied.append(f"Inferred and set frequency to '{freq}'") + if changes.get("gaps_filled", 0) > 0: + changes_applied.append(f"Filled {changes['gaps_filled']} gaps in the time index") + if changes.get("missing_filled", 0) > 0: + changes_applied.append( + f"Filled {changes['missing_filled']} missing values (forward/backward fill)" + ) + if not changes_applied: + changes_applied.append("No changes needed — data was already clean") + + return { + "success": True, + "data_handle": result["data_handle"], + "changes_applied": changes_applied, + "metadata": result.get("metadata", {}), + } + + +# --------------------------------------------------------------------------- +# Action: convert +# --------------------------------------------------------------------------- + + +def _action_convert( + executor: Any, + data_handle: str, + to_mtype: str, +) -> dict[str, Any]: + """Convert the data to a different sktime mtype.""" + data_info = executor._data_handles[data_handle] + y = data_info["y"] + + try: + from sktime.datatypes import convert_to + except ImportError: + return { + "success": False, + "error": "sktime.datatypes.convert_to is not available in this environment.", + } + + original_mtype = type(y).__name__ + converted = convert_to(y, to_type=to_mtype) + + # Register as new handle + new_handle = f"data_{uuid.uuid4().hex[:8]}" + base_meta = data_info.get("metadata", {}).copy() + base_meta["mtype"] = to_mtype + base_meta["converted_from"] = original_mtype + base_meta["parent_handle"] = data_handle + + # Determine y and X for the new handle + if isinstance(converted, pd.DataFrame): + new_y = converted + new_X = None + elif isinstance(converted, pd.Series): + new_y = converted + new_X = data_info.get("X") + else: + # numpy or other — wrap as Series for consistency + new_y = converted + new_X = None + + executor._register_data_handle( + new_handle, + { + "y": new_y, + "X": new_X, + "metadata": base_meta, + "validation": data_info.get("validation", {}), + "config": data_info.get("config", {}), + }, + ) + + return { + "success": True, + "data_handle": new_handle, + "changes_applied": [f"Converted from '{original_mtype}' to '{to_mtype}'"], + "metadata": base_meta, + } diff --git a/tests/test_data_management.py b/tests/test_data_management.py new file mode 100644 index 00000000..83c0ed67 --- /dev/null +++ b/tests/test_data_management.py @@ -0,0 +1,276 @@ +"""Tests for the new Data Management tools: inspect_data, split_data, transform_data, save_data.""" + +import os +import tempfile +from pathlib import Path + +import pandas as pd + +from sktime_mcp.runtime.executor import Executor +from sktime_mcp.tools.inspect_data import inspect_data_tool +from sktime_mcp.tools.save_data import save_data_tool +from sktime_mcp.tools.split_data import split_data_tool +from sktime_mcp.tools.transform_data import transform_data_tool + + +def _make_executor_with_data(): + """Create an Executor and load the airline demo dataset into a handle.""" + executor = Executor() + # Load a small demo dataset (airline) + result = executor.load_data_source( + { + "type": "pandas", + "data": { + "date": pd.date_range("2020-01", periods=60, freq="MS") + .strftime("%Y-%m-%d") + .tolist(), + "value": list(range(60)), + }, + "time_column": "date", + "target_column": "value", + } + ) + assert result["success"], f"Setup failed: {result}" + return executor, result["data_handle"] + + +# ─────────────────────────────────────────────────────────────────────────── +# inspect_data +# ─────────────────────────────────────────────────────────────────────────── + + +class TestInspectData: + def test_inspect_returns_all_fields(self): + executor, handle = _make_executor_with_data() + # Patch singleton so the tool sees our executor + import sktime_mcp.tools.inspect_data as mod + + _orig = mod.get_executor + mod.get_executor = lambda: executor + try: + result = inspect_data_tool(handle) + finally: + mod.get_executor = _orig + + assert result["success"] + for key in [ + "mtype", + "scitype", + "shape", + "columns", + "dtypes", + "index_names", + "freq", + "cutoff", + "n_missing", + "head", + "summary_stats", + ]: + assert key in result, f"Missing key: {key}" + + assert result["shape"][0] == 60 + assert result["n_missing"] == 0 + + def test_inspect_unknown_handle(self): + result = inspect_data_tool("data_nonexistent") + assert not result["success"] + assert "not found" in result["error"] + + +# ─────────────────────────────────────────────────────────────────────────── +# split_data +# ─────────────────────────────────────────────────────────────────────────── + + +class TestSplitData: + def _patch(self, executor): + import sktime_mcp.tools.split_data as mod + + self._orig = mod.get_executor + mod.get_executor = lambda: executor + + def _unpatch(self): + import sktime_mcp.tools.split_data as mod + + mod.get_executor = self._orig + + def test_split_by_test_size(self): + executor, handle = _make_executor_with_data() + self._patch(executor) + try: + result = split_data_tool(handle, test_size=0.2) + finally: + self._unpatch() + + assert result["success"] + assert "train_handle" in result + assert "test_handle" in result + assert "cutoff" in result + assert result["train_size"] + result["n_test"] == 60 + + def test_split_by_fh(self): + executor, handle = _make_executor_with_data() + self._patch(executor) + try: + result = split_data_tool(handle, fh=12) + finally: + self._unpatch() + + assert result["success"] + assert result["n_test"] == 12 + assert result["train_size"] == 48 + + def test_split_by_fh_list_uses_max(self): + executor, handle = _make_executor_with_data() + self._patch(executor) + try: + result = split_data_tool(handle, fh=[1, 5, 10]) + finally: + self._unpatch() + + assert result["success"] + assert result["n_test"] == 10 + assert result["train_size"] == 50 + + def test_split_requires_one_arg(self): + executor, handle = _make_executor_with_data() + self._patch(executor) + try: + # Neither + r1 = split_data_tool(handle) + assert not r1["success"] + # Both + r2 = split_data_tool(handle, test_size=0.2, fh=12) + assert not r2["success"] + finally: + self._unpatch() + + def test_split_invalid_test_size(self): + executor, handle = _make_executor_with_data() + self._patch(executor) + try: + result = split_data_tool(handle, test_size=1.5) + finally: + self._unpatch() + assert not result["success"] + + def test_split_unknown_handle(self): + result = split_data_tool("data_nonexistent", test_size=0.2) + assert not result["success"] + + +# ─────────────────────────────────────────────────────────────────────────── +# transform_data +# ─────────────────────────────────────────────────────────────────────────── + + +class TestTransformData: + def _patch(self, executor): + import sktime_mcp.tools.transform_data as mod + + self._orig = mod.get_executor + mod.get_executor = lambda: executor + + def _unpatch(self): + import sktime_mcp.tools.transform_data as mod + + mod.get_executor = self._orig + + def test_format_action(self): + executor, handle = _make_executor_with_data() + self._patch(executor) + try: + result = transform_data_tool(handle, action="format") + finally: + self._unpatch() + + assert result["success"] + assert "data_handle" in result + assert "changes_applied" in result + assert isinstance(result["changes_applied"], list) + + def test_convert_requires_to_mtype(self): + executor, handle = _make_executor_with_data() + self._patch(executor) + try: + result = transform_data_tool(handle, action="convert") + finally: + self._unpatch() + assert not result["success"] + assert "to_mtype" in result["error"] + + def test_invalid_action(self): + executor, handle = _make_executor_with_data() + self._patch(executor) + try: + result = transform_data_tool(handle, action="bogus") + finally: + self._unpatch() + assert not result["success"] + assert "Unknown action" in result["error"] + + def test_unknown_handle(self): + result = transform_data_tool("data_nonexistent", action="format") + assert not result["success"] + + +# ─────────────────────────────────────────────────────────────────────────── +# save_data +# ─────────────────────────────────────────────────────────────────────────── + + +class TestSaveData: + def _patch(self, executor): + import sktime_mcp.tools.save_data as mod + + self._orig = mod.get_executor + mod.get_executor = lambda: executor + + def _unpatch(self): + import sktime_mcp.tools.save_data as mod + + mod.get_executor = self._orig + + def test_save_csv(self): + executor, handle = _make_executor_with_data() + self._patch(executor) + try: + with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as f: + path = f.name + result = save_data_tool(handle, path=path, format="csv") + finally: + self._unpatch() + + assert result["success"] + assert result["format"] == "csv" + assert result["rows"] == 60 + assert Path(result["saved_path"]).exists() + Path(path).unlink() + + def test_save_json(self): + executor, handle = _make_executor_with_data() + self._patch(executor) + try: + with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as f: + path = f.name + result = save_data_tool(handle, path=path, format="json") + finally: + self._unpatch() + + assert result["success"] + assert result["format"] == "json" + Path(path).unlink() + + def test_save_unsupported_format(self): + executor, handle = _make_executor_with_data() + self._patch(executor) + try: + result = save_data_tool(handle, path="/tmp/out.xlsx", format="xlsx") + finally: + self._unpatch() + assert not result["success"] + assert "Unsupported format" in result["error"] + + def test_save_unknown_handle(self): + result = save_data_tool("data_nonexistent", path="/tmp/out.csv") + assert not result["success"]