Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/sktime_mcp/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,15 @@ def to_sktime_format(self, data: pd.DataFrame) -> tuple[pd.Series, Optional[pd.D

# Get exogenous variables if specified
if exog_cols:
valid_exog_cols = [col for col in exog_cols if col in data.columns]
X = data[valid_exog_cols] if valid_exog_cols else None
missing_exog_cols = [col for col in exog_cols if col not in data.columns]
if missing_exog_cols:
available_columns = ", ".join(repr(col) for col in data.columns)
raise ValueError(
f"Exogenous column(s) not found in data: {missing_exog_cols!r}. "
f"Available columns: [{available_columns}]"
)

X = data[exog_cols]
else:
# Use all columns except target as exogenous
other_cols = [col for col in data.columns if col != target_col]
Expand Down
24 changes: 24 additions & 0 deletions src/sktime_mcp/runtime/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,28 @@ def _discover_demo_datasets() -> dict:
DEMO_DATASETS = _discover_demo_datasets()


def _merge_adapter_validation_warnings(
validation_report: dict[str, Any],
metadata: dict[str, Any],
) -> dict[str, Any]:
"""Merge warnings added during adapter conversion into validation output."""
metadata_validation = metadata.get("validation")
if not isinstance(metadata_validation, dict):
return validation_report

metadata_warnings = metadata_validation.get("warnings", [])
if not metadata_warnings:
return validation_report

merged = validation_report.copy()
existing_warnings = list(merged.get("warnings", []))
for warning in metadata_warnings:
if warning not in existing_warnings:
existing_warnings.append(warning)
merged["warnings"] = existing_warnings
return merged


class Executor:
"""
Execution runtime for sktime estimators.
Expand Down Expand Up @@ -524,6 +546,7 @@ def load_data_source(self, config: dict[str, Any]) -> dict[str, Any]:

# Update metadata to reflect the target and used columns
metadata = adapter.get_metadata().copy()
validation_report = _merge_adapter_validation_warnings(validation_report, metadata)
metadata["columns"] = [y.name if hasattr(y, "name") and y.name else "target"]
if X is not None:
metadata["exog_columns"] = list(X.columns)
Expand Down Expand Up @@ -647,6 +670,7 @@ async def load_data_source_async(
y, X = adapter.to_sktime_format(data)

metadata = adapter.get_metadata().copy()
validation_report = _merge_adapter_validation_warnings(validation_report, metadata)
metadata["columns"] = [y.name if hasattr(y, "name") and y.name else "target"]
if X is not None:
metadata["exog_columns"] = list(X.columns)
Expand Down
82 changes: 82 additions & 0 deletions tests/test_data_exog_column_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Tests for explicit exogenous column validation in data adapters."""

import sys
from pathlib import Path

import pytest

sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

from sktime_mcp.data.adapters.pandas_adapter import PandasAdapter
from sktime_mcp.runtime.executor import Executor


def test_to_sktime_format_rejects_missing_explicit_exog_column():
    """Requesting an exogenous column that the data lacks must raise ValueError."""
    frame = {
        "date": ["2020-01-01", "2020-01-02", "2020-01-03"],
        "value": [1, 2, 3],
        "promo": [0, 1, 0],
    }
    config = {
        "type": "pandas",
        "data": frame,
        "time_column": "date",
        "target_column": "value",
        # "holiday" is deliberately not one of the keys in ``frame``.
        "exog_columns": ["promo", "holiday"],
    }
    adapter = PandasAdapter(config)
    loaded = adapter.load()

    with pytest.raises(ValueError, match="Exogenous column\\(s\\) not found"):
        adapter.to_sktime_format(loaded)


def test_load_data_source_returns_error_for_missing_explicit_exog_column():
    """A missing exog column should come back from the executor as an error dict."""
    config = {
        "type": "pandas",
        "data": {
            "date": ["2020-01-01", "2020-01-02", "2020-01-03"],
            "value": [1, 2, 3],
            "promo": [0, 1, 0],
        },
        "time_column": "date",
        "target_column": "value",
        # "holiday" is requested but absent from the data mapping above.
        "exog_columns": ["promo", "holiday"],
    }

    outcome = Executor().load_data_source(config)

    assert outcome["success"] is False
    assert outcome["error_type"] == "ValueError"
    message = outcome["error"]
    # The message must carry the headline plus both the missing ("holiday")
    # and the requested-and-present ("promo") column names.
    for fragment in ("Exogenous column(s) not found", "holiday", "promo"):
        assert fragment in message


def test_to_sktime_format_keeps_all_valid_explicit_exog_columns():
    """When every requested exog column exists, X keeps them in the given order."""
    requested = ["promo", "price"]
    config = {
        "type": "pandas",
        "data": {
            "date": ["2020-01-01", "2020-01-02", "2020-01-03"],
            "value": [1, 2, 3],
            "promo": [0, 1, 0],
            "price": [10, 11, 12],
        },
        "time_column": "date",
        "target_column": "value",
        "exog_columns": requested,
    }
    adapter = PandasAdapter(config)

    target, exog = adapter.to_sktime_format(adapter.load())

    assert target.name == "value"
    assert list(exog.columns) == ["promo", "price"]

38 changes: 38 additions & 0 deletions tests/test_data_validation_warnings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""Tests for data validation warning propagation."""

import asyncio
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

from sktime_mcp.runtime.executor import Executor


def _ambiguous_target_config():
return {
"type": "pandas",
"data": {
"date": ["2020-01-01", "2020-01-02", "2020-01-03"],
"value": [1, 2, 3],
},
}


def test_load_data_source_propagates_default_target_warning():
    """The sync load result must list the default-target warning under validation."""
    executor = Executor()
    outcome = executor.load_data_source(_ambiguous_target_config())

    assert outcome["success"] is True
    reported = outcome["validation"]["warnings"]
    matching = [message for message in reported if "Target column not specified" in message]
    assert matching


def test_load_data_source_async_propagates_default_target_warning():
    """The async load result must list the same default-target warning as sync."""
    pending = Executor().load_data_source_async(_ambiguous_target_config())
    outcome = asyncio.run(pending)

    assert outcome["success"] is True
    reported = outcome["validation"]["warnings"]
    matching = [message for message in reported if "Target column not specified" in message]
    assert matching