diff --git a/src/sktime_mcp/data/base.py b/src/sktime_mcp/data/base.py index cabb753b..7d988dbd 100644 --- a/src/sktime_mcp/data/base.py +++ b/src/sktime_mcp/data/base.py @@ -79,8 +79,15 @@ def to_sktime_format(self, data: pd.DataFrame) -> tuple[pd.Series, Optional[pd.D # Get exogenous variables if specified if exog_cols: - valid_exog_cols = [col for col in exog_cols if col in data.columns] - X = data[valid_exog_cols] if valid_exog_cols else None + missing_exog_cols = [col for col in exog_cols if col not in data.columns] + if missing_exog_cols: + available_columns = ", ".join(repr(col) for col in data.columns) + raise ValueError( + f"Exogenous column(s) not found in data: {missing_exog_cols!r}. " + f"Available columns: [{available_columns}]" + ) + + X = data[exog_cols] else: # Use all columns except target as exogenous other_cols = [col for col in data.columns if col != target_col] diff --git a/src/sktime_mcp/runtime/executor.py b/src/sktime_mcp/runtime/executor.py index 35a2ab21..06929d1d 100644 --- a/src/sktime_mcp/runtime/executor.py +++ b/src/sktime_mcp/runtime/executor.py @@ -42,6 +42,28 @@ def _discover_demo_datasets() -> dict: DEMO_DATASETS = _discover_demo_datasets() +def _merge_adapter_validation_warnings( + validation_report: dict[str, Any], + metadata: dict[str, Any], +) -> dict[str, Any]: + """Merge warnings added during adapter conversion into validation output.""" + metadata_validation = metadata.get("validation") + if not isinstance(metadata_validation, dict): + return validation_report + + metadata_warnings = metadata_validation.get("warnings", []) + if not metadata_warnings: + return validation_report + + merged = validation_report.copy() + existing_warnings = list(merged.get("warnings", [])) + for warning in metadata_warnings: + if warning not in existing_warnings: + existing_warnings.append(warning) + merged["warnings"] = existing_warnings + return merged + + class Executor: """ Execution runtime for sktime estimators. 
# --- tests/test_data_exog_column_validation.py ---


def test_to_sktime_format_rejects_missing_explicit_exog_column():
    """Requesting an exogenous column that the data lacks must raise ValueError."""
    config = {
        "type": "pandas",
        "data": {
            "date": ["2020-01-01", "2020-01-02", "2020-01-03"],
            "value": [1, 2, 3],
            "promo": [0, 1, 0],
        },
        "time_column": "date",
        "target_column": "value",
        # "holiday" is deliberately absent from the data above.
        "exog_columns": ["promo", "holiday"],
    }
    adapter = PandasAdapter(config)
    frame = adapter.load()

    with pytest.raises(ValueError, match="Exogenous column\\(s\\) not found"):
        adapter.to_sktime_format(frame)


def test_load_data_source_returns_error_for_missing_explicit_exog_column():
    """The executor reports a missing exog column as a failed result dict."""
    config = {
        "type": "pandas",
        "data": {
            "date": ["2020-01-01", "2020-01-02", "2020-01-03"],
            "value": [1, 2, 3],
            "promo": [0, 1, 0],
        },
        "time_column": "date",
        "target_column": "value",
        "exog_columns": ["promo", "holiday"],
    }
    runtime = Executor()

    outcome = runtime.load_data_source(config)

    assert outcome["success"] is False
    assert outcome["error_type"] == "ValueError"
    message = outcome["error"]
    assert "Exogenous column(s) not found" in message
    # Both the missing name and an available column name appear in the text.
    assert "holiday" in message
    assert "promo" in message


def test_to_sktime_format_keeps_all_valid_explicit_exog_columns():
    """When every requested exog column exists, X has exactly those columns in order."""
    config = {
        "type": "pandas",
        "data": {
            "date": ["2020-01-01", "2020-01-02", "2020-01-03"],
            "value": [1, 2, 3],
            "promo": [0, 1, 0],
            "price": [10, 11, 12],
        },
        "time_column": "date",
        "target_column": "value",
        "exog_columns": ["promo", "price"],
    }
    adapter = PandasAdapter(config)
    frame = adapter.load()

    target, exog = adapter.to_sktime_format(frame)

    assert target.name == "value"
    assert list(exog.columns) == ["promo", "price"]


# --- tests/test_data_validation_warnings.py ---


def _ambiguous_target_config():
    """Build a pandas source config that names neither a time nor a target column."""
    frame = {
        "date": ["2020-01-01", "2020-01-02", "2020-01-03"],
        "value": [1, 2, 3],
    }
    return {"type": "pandas", "data": frame}


def test_load_data_source_propagates_default_target_warning():
    """Sync load exposes the default-target warning under result["validation"]."""
    outcome = Executor().load_data_source(_ambiguous_target_config())

    assert outcome["success"] is True
    reported = outcome["validation"]["warnings"]
    assert any("Target column not specified" in message for message in reported)


def test_load_data_source_async_propagates_default_target_warning():
    """Async load exposes the same default-target warning as the sync path."""
    runtime = Executor()
    outcome = asyncio.run(runtime.load_data_source_async(_ambiguous_target_config()))

    assert outcome["success"] is True
    reported = outcome["validation"]["warnings"]
    assert any("Target column not specified" in message for message in reported)