def _supports_forecasting_fit_example(dataset: str) -> bool:
    """Return whether a demo dataset is compatible with the forecasting example template.

    The dataset name is looked up in ``DEMO_DATASETS``, whose values are dotted
    paths of the form ``"package.module.loader"``. The loader is imported and
    its signature inspected: a loader that accepts a ``return_X_y`` parameter
    is treated as a supervised (classification/regression style) loader and
    therefore as incompatible with the forecasting fit/predict template.

    Parameters
    ----------
    dataset : str
        Key to look up in ``DEMO_DATASETS``.

    Returns
    -------
    bool
        ``True`` when the loader can be resolved and does not expose a
        ``return_X_y`` parameter. ``False`` when the dataset is unknown, its
        dotted path is empty or malformed, the loader cannot be imported, or
        its signature cannot be introspected.
    """
    # Local import keeps this helper self-contained; import_module is the
    # documented replacement for calling __import__ directly.
    import importlib

    module_path = DEMO_DATASETS.get(dataset)
    if not module_path:
        return False

    module_name, _, loader_name = module_path.rpartition(".")
    if not module_name:
        # No dotted prefix means there is no module to import the loader from.
        return False

    try:
        loader = getattr(importlib.import_module(module_name), loader_name)
        signature = inspect.signature(loader)
    except (ImportError, AttributeError, TypeError, ValueError):
        # An unresolvable or un-introspectable loader cannot be verified as
        # forecasting-compatible; report it as unsupported instead of raising
        # so the caller can return a structured error.
        return False

    # Heuristic: supervised demo loaders typically expose return_X_y.
    return "return_X_y" not in signature.parameters
# NOTE(review): these are methods added to tests/test_codegen.py; the enclosing
# test class header is not visible in this diff, so re-indent when splicing in.
def test_include_fit_example_supported_dataset(self):
    """A forecasting-compatible demo dataset must still yield example code."""
    handle = self._create_handle()
    try:
        outcome = export_code_tool(handle, include_fit_example=True, dataset="airline")
        assert outcome["success"]
        generated = outcome["code"]
        assert "load_airline" in generated
    finally:
        # Always release the handle, even when an assertion above fails.
        self._cleanup_handle(handle)


def test_include_fit_example_rejects_unsupported_demo_dataset(self):
    """A supervised demo dataset must be refused for forecasting example generation."""
    handle = self._create_handle()
    try:
        outcome = export_code_tool(
            handle,
            include_fit_example=True,
            dataset="basic_motions",
        )
        assert outcome["success"] is False
        message = outcome["error"]
        assert "not supported for forecasting fit examples" in message
    finally:
        # Always release the handle, even when an assertion above fails.
        self._cleanup_handle(handle)