Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion src/sktime_mcp/tools/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Generates Python code to recreate estimators and pipelines.
"""

import inspect
import keyword
from typing import Any

Expand Down Expand Up @@ -51,6 +52,21 @@ def _is_valid_var_name(var_name: str) -> bool:
return isinstance(var_name, str) and var_name.isidentifier() and not keyword.iskeyword(var_name)


def _supports_forecasting_fit_example(dataset: str) -> bool:
"""Return whether a demo dataset is compatible with the forecasting example template."""
module_path = DEMO_DATASETS.get(dataset)
if module_path is None:
return False

module_parts = module_path.rsplit(".", 1)
module = __import__(module_parts[0], fromlist=[module_parts[1]])
loader = getattr(module, module_parts[1])
signature = inspect.signature(loader)

# Supervised/classification/regression demo loaders typically expose return_X_y.
return "return_X_y" not in signature.parameters


def _generate_single_estimator_code(
estimator_name: str, params: dict[str, Any], var_name: str = "model"
) -> dict[str, Any]:
Expand Down Expand Up @@ -254,7 +270,23 @@ def export_code_tool(
# Optionally add fit/predict example
if include_fit_example:
# Resolve the dataset loader from DEMO_DATASETS
if dataset and dataset in DEMO_DATASETS:
if dataset and dataset not in DEMO_DATASETS:
return {
"success": False,
"error": f"Unknown dataset: {dataset}",
"available": sorted(DEMO_DATASETS.keys()),
}

if dataset and not _supports_forecasting_fit_example(dataset):
return {
"success": False,
"error": (
f"Dataset '{dataset}' is not supported for forecasting fit examples. "
"Please use a forecasting-compatible demo dataset such as 'airline'."
),
}

if dataset:
module_path = DEMO_DATASETS[dataset]
module_parts = module_path.rsplit(".", 1)
loader_module = module_parts[0]
Expand Down
20 changes: 20 additions & 0 deletions tests/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,26 @@ def test_include_fit_example_false(self):
finally:
self._cleanup_handle(handle)

def test_include_fit_example_supported_dataset(self):
"""Supported forecasting demo datasets should still produce example code."""
handle = self._create_handle()
try:
result = export_code_tool(handle, include_fit_example=True, dataset="airline")
assert result["success"]
assert "load_airline" in result["code"]
finally:
self._cleanup_handle(handle)

def test_include_fit_example_rejects_unsupported_demo_dataset(self):
"""Supervised demo datasets should be rejected for forecasting example generation."""
handle = self._create_handle()
try:
result = export_code_tool(handle, include_fit_example=True, dataset="basic_motions")
assert result["success"] is False
assert "not supported for forecasting fit examples" in result["error"]
finally:
self._cleanup_handle(handle)

def test_custom_var_name(self):
"""Custom var_name should appear in the generated code."""
handle = self._create_handle()
Expand Down
Loading