diff --git a/src/sktime_mcp/runtime/executor.py b/src/sktime_mcp/runtime/executor.py index 0dc2673f..4df24f2c 100644 --- a/src/sktime_mcp/runtime/executor.py +++ b/src/sktime_mcp/runtime/executor.py @@ -58,6 +58,43 @@ def __init__(self): os.environ.get("SKTIME_MCP_AUTO_FORMAT", "true").lower() == "true" ) + @staticmethod + def _get_index_type(index: Any) -> str: + """Return a compact, agent-friendly label for the time index.""" + if isinstance(index, pd.DatetimeIndex): + return "datetime" + if isinstance(index, pd.PeriodIndex): + return "period" + if isinstance(index, pd.RangeIndex): + return "range" + if pd.api.types.is_integer_dtype(index.dtype): + return "integer" + return type(index).__name__.lower() + + def _build_data_metadata( + self, + base_metadata: dict[str, Any], + data: pd.DataFrame, + y: pd.Series, + X: pd.DataFrame | None, + ) -> dict[str, Any]: + """Augment adapter metadata with agent-facing shape semantics.""" + metadata = base_metadata.copy() + metadata["columns"] = [y.name if hasattr(y, "name") and y.name else "target"] + if X is not None: + metadata["exog_columns"] = list(X.columns) + metadata["dtypes"] = {col: str(dtype) for col, dtype in data.dtypes.items()} + metadata["target_scitype"] = "Series" + metadata["target_variates"] = "multivariate" if getattr(y, "ndim", 1) > 1 else "univariate" + metadata["has_exog"] = X is not None + metadata["exog_variates"] = ( + "none" if X is None else ("multivariate" if len(X.columns) > 1 else "univariate") + ) + metadata["index_type"] = self._get_index_type(y.index) + metadata["n_target_columns"] = 1 + metadata["n_exog_columns"] = 0 if X is None else len(X.columns) + return metadata + def instantiate( self, estimator_name: str, @@ -571,12 +608,7 @@ def load_data_source(self, config: dict[str, Any]) -> dict[str, Any]: y, X = adapter.to_sktime_format(data) # Update metadata to reflect the target and used columns - metadata = adapter.get_metadata().copy() - metadata["columns"] = [y.name if hasattr(y, "name") and y.name else "target"] - if X is not None: - metadata["exog_columns"] = list(X.columns) - # Inject column dtypes so LLMs can distinguish time index vs target - metadata["dtypes"] = {col: str(dtype) for col, dtype in data.dtypes.items()} + metadata = self._build_data_metadata(adapter.get_metadata(), data, y, X) # Generate handle data_handle = f"data_{uuid.uuid4().hex[:8]}" @@ -609,8 +641,7 @@ def load_data_source(self, config: dict[str, Any]) -> dict[str, Any]: except Exception as e: logger.warning(f"Auto-formatting failed: {e}") # Continue with unformatted data if formatting fails - _final_meta = adapter.get_metadata().copy() - _final_meta["dtypes"] = {col: str(dtype) for col, dtype in data.dtypes.items()} + _final_meta = self._build_data_metadata(adapter.get_metadata(), data, y, X) return { "success": True, "data_handle": data_handle, @@ -694,12 +725,7 @@ async def load_data_source_async( y, X = adapter.to_sktime_format(data) - metadata = adapter.get_metadata().copy() - metadata["columns"] = [y.name if hasattr(y, "name") and y.name else "target"] - if X is not None: - metadata["exog_columns"] = list(X.columns) - # Inject column dtypes so LLMs can distinguish time index vs target - metadata["dtypes"] = {col: str(dtype) for col, dtype in data.dtypes.items()} + metadata = self._build_data_metadata(adapter.get_metadata(), data, y, X) data_handle = f"data_{uuid.uuid4().hex[:8]}" self._data_handles[data_handle] = { diff --git a/tests/test_data_sources.py b/tests/test_data_sources.py index 757dd53c..2295e6db 100644 --- a/tests/test_data_sources.py +++ b/tests/test_data_sources.py @@ -103,3 +103,59 @@ def test_load_and_predict_with_data_handle(self): cleanup = executor.release_data_handle(result["data_handle"]) assert cleanup["success"] + + def test_load_data_source_exposes_agent_friendly_metadata(self): + executor = get_executor() + + config = { + "type": "pandas", + "data": { + "date": pd.date_range(start="2020-01-01", periods=12, freq="D"), + "sales": [100 + i for i in range(12)], + "promo": [0, 1] * 6, + "temp": [20 + i for i in range(12)], + }, + "time_column": "date", + "target_column": "sales", + "exog_columns": ["promo", "temp"], + } + + result = executor.load_data_source(config) + assert result["success"] + + metadata = result["metadata"] + assert metadata["target_scitype"] == "Series" + assert metadata["target_variates"] == "univariate" + assert metadata["has_exog"] is True + assert metadata["exog_variates"] == "multivariate" + assert metadata["index_type"] == "datetime" + assert metadata["n_target_columns"] == 1 + assert metadata["n_exog_columns"] == 2 + + cleanup = executor.release_data_handle(result["data_handle"]) + assert cleanup["success"] + + def test_load_data_source_range_index_metadata(self): + executor = get_executor() + + config = { + "type": "pandas", + "data": { + "y": [1, 2, 3, 4, 5], + }, + } + + result = executor.load_data_source(config) + assert result["success"] + + metadata = result["metadata"] + assert metadata["target_scitype"] == "Series" + assert metadata["target_variates"] == "univariate" + assert metadata["has_exog"] is False + assert metadata["exog_variates"] == "none" + assert metadata["index_type"] == "range" + assert metadata["n_target_columns"] == 1 + assert metadata["n_exog_columns"] == 0 + + cleanup = executor.release_data_handle(result["data_handle"]) + assert cleanup["success"]