Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 40 additions & 14 deletions src/sktime_mcp/runtime/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,43 @@ def __init__(self):
os.environ.get("SKTIME_MCP_AUTO_FORMAT", "true").lower() == "true"
)

@staticmethod
def _get_index_type(index: Any) -> str:
"""Return a compact, agent-friendly label for the time index."""
if isinstance(index, pd.DatetimeIndex):
return "datetime"
if isinstance(index, pd.PeriodIndex):
return "period"
if isinstance(index, pd.RangeIndex):
return "range"
if pd.api.types.is_integer_dtype(index.dtype):
return "integer"
return type(index).__name__.lower()

def _build_data_metadata(
    self,
    base_metadata: dict[str, Any],
    data: pd.DataFrame,
    y: pd.Series,
    X: pd.DataFrame | None,
) -> dict[str, Any]:
    """Augment adapter metadata with agent-facing shape semantics.

    Parameters
    ----------
    base_metadata : adapter-provided metadata; copied, never mutated.
    data : the full source frame, used only to report column dtypes.
    y : the target (Series, or DataFrame-like when multivariate).
    X : optional exogenous frame; ``None`` means no exogenous data.

    Returns
    -------
    dict with the base keys plus ``columns``, ``dtypes``,
    ``target_scitype``, ``target_variates``, ``has_exog``,
    ``exog_variates``, ``exog_columns`` (only when X is given),
    ``index_type``, ``n_target_columns`` and ``n_exog_columns``.
    """
    metadata = base_metadata.copy()
    # y may be DataFrame-like (the ndim guard below already allowed for
    # that); derive the column list from its actual shape so that
    # `columns`/`n_target_columns` stay consistent with `target_variates`
    # (previously n_target_columns was hard-coded to 1).
    multivariate_target = getattr(y, "ndim", 1) > 1
    if multivariate_target:
        target_columns = list(y.columns)
    else:
        target_columns = [y.name if hasattr(y, "name") and y.name else "target"]
    metadata["columns"] = target_columns
    if X is not None:
        metadata["exog_columns"] = list(X.columns)
    # Column dtypes let LLM agents distinguish the time index from the target.
    metadata["dtypes"] = {col: str(dtype) for col, dtype in data.dtypes.items()}
    metadata["target_scitype"] = "Series"
    metadata["target_variates"] = "multivariate" if multivariate_target else "univariate"
    metadata["has_exog"] = X is not None
    metadata["exog_variates"] = (
        "none" if X is None else ("multivariate" if len(X.columns) > 1 else "univariate")
    )
    metadata["index_type"] = self._get_index_type(y.index)
    metadata["n_target_columns"] = len(target_columns)
    metadata["n_exog_columns"] = 0 if X is None else len(X.columns)
    return metadata

def instantiate(
self,
estimator_name: str,
Expand Down Expand Up @@ -571,12 +608,7 @@ def load_data_source(self, config: dict[str, Any]) -> dict[str, Any]:
y, X = adapter.to_sktime_format(data)

# Update metadata to reflect the target and used columns
metadata = adapter.get_metadata().copy()
metadata["columns"] = [y.name if hasattr(y, "name") and y.name else "target"]
if X is not None:
metadata["exog_columns"] = list(X.columns)
# Inject column dtypes so LLMs can distinguish time index vs target
metadata["dtypes"] = {col: str(dtype) for col, dtype in data.dtypes.items()}
metadata = self._build_data_metadata(adapter.get_metadata(), data, y, X)
# Generate handle
data_handle = f"data_{uuid.uuid4().hex[:8]}"

Expand Down Expand Up @@ -609,8 +641,7 @@ def load_data_source(self, config: dict[str, Any]) -> dict[str, Any]:
except Exception as e:
logger.warning(f"Auto-formatting failed: {e}")
# Continue with unformatted data if formatting fails
_final_meta = adapter.get_metadata().copy()
_final_meta["dtypes"] = {col: str(dtype) for col, dtype in data.dtypes.items()}
_final_meta = self._build_data_metadata(adapter.get_metadata(), data, y, X)
return {
"success": True,
"data_handle": data_handle,
Expand Down Expand Up @@ -694,12 +725,7 @@ async def load_data_source_async(

y, X = adapter.to_sktime_format(data)

metadata = adapter.get_metadata().copy()
metadata["columns"] = [y.name if hasattr(y, "name") and y.name else "target"]
if X is not None:
metadata["exog_columns"] = list(X.columns)
# Inject column dtypes so LLMs can distinguish time index vs target
metadata["dtypes"] = {col: str(dtype) for col, dtype in data.dtypes.items()}
metadata = self._build_data_metadata(adapter.get_metadata(), data, y, X)
data_handle = f"data_{uuid.uuid4().hex[:8]}"

self._data_handles[data_handle] = {
Expand Down
56 changes: 56 additions & 0 deletions tests/test_data_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,59 @@ def test_load_and_predict_with_data_handle(self):

cleanup = executor.release_data_handle(result["data_handle"])
assert cleanup["success"]

def test_load_data_source_exposes_agent_friendly_metadata(self):
    """A frame with exogenous columns yields full shape-semantics metadata."""
    executor = get_executor()

    # 12 daily observations: one target plus two exogenous columns.
    source_config = {
        "type": "pandas",
        "data": {
            "date": pd.date_range(start="2020-01-01", periods=12, freq="D"),
            "sales": [100 + i for i in range(12)],
            "promo": [0, 1] * 6,
            "temp": [20 + i for i in range(12)],
        },
        "time_column": "date",
        "target_column": "sales",
        "exog_columns": ["promo", "temp"],
    }

    load_result = executor.load_data_source(source_config)
    assert load_result["success"]

    meta = load_result["metadata"]
    assert meta["target_scitype"] == "Series"
    assert meta["target_variates"] == "univariate"
    assert meta["has_exog"] is True
    assert meta["exog_variates"] == "multivariate"
    assert meta["index_type"] == "datetime"
    assert meta["n_target_columns"] == 1
    assert meta["n_exog_columns"] == 2

    release_result = executor.release_data_handle(load_result["data_handle"])
    assert release_result["success"]

def test_load_data_source_range_index_metadata(self):
    """A bare target column with no time column falls back to a range index."""
    executor = get_executor()

    # Minimal config: a single univariate target, no exogenous data.
    source_config = {
        "type": "pandas",
        "data": {
            "y": [1, 2, 3, 4, 5],
        },
    }

    load_result = executor.load_data_source(source_config)
    assert load_result["success"]

    meta = load_result["metadata"]
    assert meta["target_scitype"] == "Series"
    assert meta["target_variates"] == "univariate"
    assert meta["has_exog"] is False
    assert meta["exog_variates"] == "none"
    assert meta["index_type"] == "range"
    assert meta["n_target_columns"] == 1
    assert meta["n_exog_columns"] == 0

    release_result = executor.release_data_handle(load_result["data_handle"])
    assert release_result["success"]
Loading