diff --git a/src/sktime_mcp/tools/codegen.py b/src/sktime_mcp/tools/codegen.py index 058fe969..7a400dc7 100644 --- a/src/sktime_mcp/tools/codegen.py +++ b/src/sktime_mcp/tools/codegen.py @@ -305,14 +305,11 @@ def export_code_tool( # Optionally add fit/predict example if include_fit_example: # Priority: explicit argument > dataset used during fit_predict > "airline" fallback - effective_dataset = ( - dataset - or handle_info.metadata.get("training_dataset") - or "airline" - ) - # Resolve the dataset loader from DEMO_DATASETS - if effective_dataset in DEMO_DATASETS: - module_path = DEMO_DATASETS[effective_dataset] + effective_dataset = dataset or handle_info.metadata.get("training_dataset") or "airline" + # Resolve the dataset loader from discovered demo datasets + demo_datasets = _get_demo_datasets() + if effective_dataset in demo_datasets: + module_path = demo_datasets[effective_dataset] module_parts = module_path.rsplit(".", 1) loader_module = module_parts[0] loader_func = module_parts[1] diff --git a/tests/test_evaluate_summary.py b/tests/test_evaluate_summary.py index a68ad0dd..030c1481 100644 --- a/tests/test_evaluate_summary.py +++ b/tests/test_evaluate_summary.py @@ -41,9 +41,7 @@ def test_summary_contains_expected_stat_keys(self): for metric_name, stats in summary.items(): for key in ("mean", "std", "min", "max"): - assert key in stats, ( - f"Expected '{key}' in summary['{metric_name}'], got {stats}" - ) + assert key in stats, f"Expected '{key}' in summary['{metric_name}'], got {stats}" assert isinstance(stats[key], float), ( f"summary['{metric_name}']['{key}'] should be float, got {type(stats[key])}" ) diff --git a/tests/test_export_code_dataset_tracking.py b/tests/test_export_code_dataset_tracking.py index 116b47a5..7086aeff 100644 --- a/tests/test_export_code_dataset_tracking.py +++ b/tests/test_export_code_dataset_tracking.py @@ -44,9 +44,7 @@ def test_export_code_uses_training_dataset_by_default(self): assert code_result["success"], code_result code = code_result["code"] - assert "sunspots" in code, ( - f"Expected 'sunspots' to appear in exported code, got:\n{code}" - ) + assert "sunspots" in code, f"Expected 'sunspots' to appear in exported code, got:\n{code}" def test_export_code_explicit_dataset_overrides_metadata(self): """An explicit dataset argument to export_code must take priority over metadata.""" @@ -66,9 +64,7 @@ def test_export_code_explicit_dataset_overrides_metadata(self): assert code_result["success"], code_result code = code_result["code"] - assert "airline" in code, ( - f"Expected 'airline' to appear when explicitly set, got:\n{code}" - ) + assert "airline" in code, f"Expected 'airline' to appear when explicitly set, got:\n{code}" def test_export_code_falls_back_to_airline_without_fit(self): """export_code on a never-fitted handle must fall back to 'airline'.""" @@ -83,6 +79,4 @@ def test_export_code_falls_back_to_airline_without_fit(self): assert code_result["success"], code_result code = code_result["code"] - assert "airline" in code, ( - f"Expected 'airline' fallback for unfitted handle, got:\n{code}" - ) + assert "airline" in code, f"Expected 'airline' fallback for unfitted handle, got:\n{code}"