Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions src/sktime_mcp/runtime/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,28 @@
}


def _merge_adapter_validation_warnings(
validation_report: dict[str, Any],
metadata: dict[str, Any],
) -> dict[str, Any]:
"""Merge warnings added during adapter conversion into validation output."""
metadata_validation = metadata.get("validation")
if not isinstance(metadata_validation, dict):
return validation_report

metadata_warnings = metadata_validation.get("warnings", [])
if not metadata_warnings:
return validation_report

merged = validation_report.copy()
existing_warnings = list(merged.get("warnings", []))
for warning in metadata_warnings:
if warning not in existing_warnings:
existing_warnings.append(warning)
merged["warnings"] = existing_warnings
return merged


class Executor:
"""
Execution runtime for sktime estimators.
Expand Down Expand Up @@ -497,6 +519,7 @@ def load_data_source(self, config: dict[str, Any]) -> dict[str, Any]:

# Update metadata to reflect the target and used columns
metadata = adapter.get_metadata().copy()
validation_report = _merge_adapter_validation_warnings(validation_report, metadata)
metadata["columns"] = [y.name if hasattr(y, "name") and y.name else "target"]
if X is not None:
metadata["exog_columns"] = list(X.columns)
Expand Down Expand Up @@ -620,6 +643,7 @@ async def load_data_source_async(
y, X = adapter.to_sktime_format(data)

metadata = adapter.get_metadata().copy()
validation_report = _merge_adapter_validation_warnings(validation_report, metadata)
metadata["columns"] = [y.name if hasattr(y, "name") and y.name else "target"]
if X is not None:
metadata["exog_columns"] = list(X.columns)
Expand Down
38 changes: 38 additions & 0 deletions tests/test_data_validation_warnings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""Tests for data validation warning propagation."""

import asyncio
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

from sktime_mcp.runtime.executor import Executor


def _ambiguous_target_config():
return {
"type": "pandas",
"data": {
"date": ["2020-01-01", "2020-01-02", "2020-01-03"],
"value": [1, 2, 3],
},
}


def test_load_data_source_propagates_default_target_warning():
"""Default target warnings should appear in the top-level validation result."""
result = Executor().load_data_source(_ambiguous_target_config())

assert result["success"] is True
warnings = result["validation"]["warnings"]
assert any("Target column not specified" in warning for warning in warnings)


def test_load_data_source_async_propagates_default_target_warning():
"""Async load should expose the same default-target warning as sync load."""
result = asyncio.run(Executor().load_data_source_async(_ambiguous_target_config()))

assert result["success"] is True
warnings = result["validation"]["warnings"]
assert any("Target column not specified" in warning for warning in warnings)