Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions src/sktime_mcp/tools/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,14 +305,11 @@ def export_code_tool(
# Optionally add fit/predict example
if include_fit_example:
# Priority: explicit argument > dataset used during fit_predict > "airline" fallback
effective_dataset = (
dataset
or handle_info.metadata.get("training_dataset")
or "airline"
)
# Resolve the dataset loader from DEMO_DATASETS
if effective_dataset in DEMO_DATASETS:
module_path = DEMO_DATASETS[effective_dataset]
effective_dataset = dataset or handle_info.metadata.get("training_dataset") or "airline"
# Resolve the dataset loader from discovered demo datasets
demo_datasets = _get_demo_datasets()
if effective_dataset in demo_datasets:
module_path = demo_datasets[effective_dataset]
module_parts = module_path.rsplit(".", 1)
loader_module = module_parts[0]
loader_func = module_parts[1]
Expand Down
4 changes: 1 addition & 3 deletions tests/test_evaluate_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ def test_summary_contains_expected_stat_keys(self):

for metric_name, stats in summary.items():
for key in ("mean", "std", "min", "max"):
assert key in stats, (
f"Expected '{key}' in summary['{metric_name}'], got {stats}"
)
assert key in stats, f"Expected '{key}' in summary['{metric_name}'], got {stats}"
assert isinstance(stats[key], float), (
f"summary['{metric_name}']['{key}'] should be float, got {type(stats[key])}"
)
Expand Down
12 changes: 3 additions & 9 deletions tests/test_export_code_dataset_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ def test_export_code_uses_training_dataset_by_default(self):
assert code_result["success"], code_result

code = code_result["code"]
assert "sunspots" in code, (
f"Expected 'sunspots' to appear in exported code, got:\n{code}"
)
assert "sunspots" in code, f"Expected 'sunspots' to appear in exported code, got:\n{code}"

def test_export_code_explicit_dataset_overrides_metadata(self):
"""An explicit dataset argument to export_code must take priority over metadata."""
Expand All @@ -66,9 +64,7 @@ def test_export_code_explicit_dataset_overrides_metadata(self):
assert code_result["success"], code_result

code = code_result["code"]
assert "airline" in code, (
f"Expected 'airline' to appear when explicitly set, got:\n{code}"
)
assert "airline" in code, f"Expected 'airline' to appear when explicitly set, got:\n{code}"

def test_export_code_falls_back_to_airline_without_fit(self):
"""export_code on a never-fitted handle must fall back to 'airline'."""
Expand All @@ -83,6 +79,4 @@ def test_export_code_falls_back_to_airline_without_fit(self):
assert code_result["success"], code_result

code = code_result["code"]
assert "airline" in code, (
f"Expected 'airline' fallback for unfitted handle, got:\n{code}"
)
assert "airline" in code, f"Expected 'airline' fallback for unfitted handle, got:\n{code}"
Loading