diff --git a/skydiscover/search/utils/checkpoint_manager.py b/skydiscover/search/utils/checkpoint_manager.py
index 703beb05..0644f31b 100644
--- a/skydiscover/search/utils/checkpoint_manager.py
+++ b/skydiscover/search/utils/checkpoint_manager.py
@@ -84,6 +84,8 @@ def save(
         with open(os.path.join(save_path, "metadata.json"), "w") as f:
             json.dump(metadata, f)
 
+        self._write_evolution_trace(programs, best_program_id, last_iteration, save_path)
+
         logger.info(f"[CHECKPOINT] Saved database with {len(programs)} programs to {save_path}")
 
     def load(self, path: str) -> Tuple[Dict[str, Program], Optional[str], int]:
@@ -137,6 +139,60 @@ def load(self, path: str) -> Tuple[Dict[str, Program], Optional[str], int]:
 
         return programs, best_program_id, last_iteration
 
+    def _write_evolution_trace(
+        self,
+        programs: Dict[str, Program],
+        best_program_id: Optional[str],
+        last_iteration: int,
+        save_path: str,
+    ) -> None:
+        """Write ``evolution_trace.json`` listing programs in iteration order.
+
+        Produces a single file with one entry per program in *programs*: its
+        score, metrics, lineage (``parent_id``) and solution.  Useful for
+        plotting score trajectories, inspecting lineage, or replaying a run.
+        Only programs present in *programs* at save time are included.
+
+        Args:
+            programs: Programs to record (only the dict values are used).
+            best_program_id: Id of the current best program, if any.
+            last_iteration: Iteration number copied into the trace header.
+            save_path: Checkpoint directory that receives the file.
+        """
+        # NOTE(review): imported locally, presumably to avoid an import cycle
+        # with skydiscover.utils — confirm before hoisting to module level.
+        from skydiscover.utils.metrics import get_score
+
+        entries = [
+            {
+                "id": program.id,
+                "iteration_found": program.iteration_found,
+                "generation": program.generation,
+                "score": get_score(program.metrics) if program.metrics else None,
+                "metrics": program.metrics,
+                "parent_id": program.parent_id,
+                "timestamp": program.timestamp,
+                "solution": program.solution,
+            }
+            for program in programs.values()
+        ]
+
+        # Timestamp breaks ties within an iteration.  ``or 0`` keeps the sort
+        # from raising TypeError if either field is None on some program.
+        entries.sort(key=lambda e: (e["iteration_found"] or 0, e["timestamp"] or 0))
+
+        trace = {
+            "last_iteration": last_iteration,
+            "best_program_id": best_program_id,
+            "total_programs": len(entries),
+            "programs": entries,
+        }
+
+        trace_path = os.path.join(save_path, "evolution_trace.json")
+        # Explicit encoding: the platform default is not UTF-8 everywhere.
+        with open(trace_path, "w", encoding="utf-8") as f:
+            json.dump(trace, f, indent=2, cls=SafeJSONEncoder)
+
     def _save_program(
         self,
         program: Program,
diff --git a/tests/search/test_checkpoint_manager.py b/tests/search/test_checkpoint_manager.py
new file mode 100644
index 00000000..bdb4acc3
--- /dev/null
+++ b/tests/search/test_checkpoint_manager.py
@@ -0,0 +1,107 @@
+"""Tests for CheckpointManager.save() — evolution_trace.json output."""
+
+import json
+import time
+
+import pytest
+
+from skydiscover.config import DatabaseConfig
+from skydiscover.search.base_database import Program
+from skydiscover.search.utils.checkpoint_manager import CheckpointManager
+
+
+def _make_program(id_: str, iteration: int, score: float, parent_id=None) -> Program:
+    return Program(
+        id=id_,
+        solution=f"def solve(): return {score}",
+        language="python",
+        metrics={"combined_score": score},
+        iteration_found=iteration,
+        generation=iteration,
+        parent_id=parent_id,
+        timestamp=time.time(),
+    )
+
+
+@pytest.fixture
+def manager(tmp_path):
+    config = DatabaseConfig(db_path=str(tmp_path))
+    return CheckpointManager(config)
+
+
+class TestEvolutionTrace:
+    def test_trace_file_is_created(self, manager, tmp_path):
+        programs = {"a": _make_program("a", 0, 0.5)}
+        manager.save(programs, None, "a", 0)
+        assert (tmp_path / "evolution_trace.json").exists()
+
+    def test_trace_top_level_fields(self, manager, tmp_path):
+        programs = {
+            "a": _make_program("a", 0, 0.5),
+            "b": _make_program("b", 1, 0.8, parent_id="a"),
+        }
+        manager.save(programs, None, "b", 1)
+        trace = json.loads((tmp_path / "evolution_trace.json").read_text())
+
+        assert trace["last_iteration"] == 1
+        assert trace["best_program_id"] == "b"
+        assert trace["total_programs"] == 2
+        assert len(trace["programs"]) == 2
+
+    def test_trace_sorted_by_iteration(self, manager, tmp_path):
+        programs = {
+            "c": _make_program("c", 2, 0.9),
+            "a": _make_program("a", 0, 0.5),
+            "b": _make_program("b", 1, 0.7),
+        }
+        manager.save(programs, None, "c", 2)
+        trace = json.loads((tmp_path / "evolution_trace.json").read_text())
+
+        iterations = [p["iteration_found"] for p in trace["programs"]]
+        assert iterations == sorted(iterations)
+
+    def test_trace_entry_fields(self, manager, tmp_path):
+        prog = _make_program("a", 0, 0.5)
+        manager.save({"a": prog}, None, "a", 0)
+        trace = json.loads((tmp_path / "evolution_trace.json").read_text())
+
+        entry = trace["programs"][0]
+        assert entry["id"] == "a"
+        assert entry["score"] == pytest.approx(0.5)
+        assert entry["metrics"] == {"combined_score": 0.5}
+        assert entry["parent_id"] is None
+        assert entry["solution"] == prog.solution
+        assert "timestamp" in entry
+
+    def test_trace_preserves_parent_id(self, manager, tmp_path):
+        programs = {
+            "a": _make_program("a", 0, 0.5),
+            "b": _make_program("b", 1, 0.8, parent_id="a"),
+        }
+        manager.save(programs, None, "b", 1)
+        trace = json.loads((tmp_path / "evolution_trace.json").read_text())
+
+        by_id = {p["id"]: p for p in trace["programs"]}
+        assert by_id["a"]["parent_id"] is None
+        assert by_id["b"]["parent_id"] == "a"
+
+    def test_trace_no_metrics_score_is_none(self, manager, tmp_path):
+        prog = Program(id="x", solution="pass", metrics={})
+        manager.save({"x": prog}, None, None, 0)
+        trace = json.loads((tmp_path / "evolution_trace.json").read_text())
+        assert trace["programs"][0]["score"] is None
+
+    def test_trace_overwritten_on_subsequent_save(self, manager, tmp_path):
+        manager.save({"a": _make_program("a", 0, 0.5)}, None, "a", 0)
+        manager.save(
+            {
+                "a": _make_program("a", 0, 0.5),
+                "b": _make_program("b", 1, 0.9),
+            },
+            None,
+            "b",
+            1,
+        )
+        trace = json.loads((tmp_path / "evolution_trace.json").read_text())
+        assert trace["total_programs"] == 2
+        assert trace["last_iteration"] == 1
diff --git a/tests/search/test_evolution_trace_integration.py b/tests/search/test_evolution_trace_integration.py
new file mode 100644
index 00000000..3a7c0b96
--- /dev/null
+++ b/tests/search/test_evolution_trace_integration.py
@@ -0,0 +1,208 @@
+"""Integration test: evolution_trace.json is written correctly during a real discovery run.
+
+Runs the full pipeline (3 iterations, mocked LLM) and verifies:
+- evolution_trace.json exists in every checkpoint directory
+- Contents are sorted, complete, and structurally correct
+- Old checkpoint artefacts (metadata.json, programs/) still exist
+- Trace score matches what the evaluator returns
+"""
+
+import json
+import os
+import textwrap
+from typing import Any, Dict, List
+from unittest.mock import patch
+
+import pytest
+
+from skydiscover.api import DiscoveryResult, run_discovery
+from skydiscover.config import Config, LLMModelConfig
+from skydiscover.llm.base import LLMResponse
+
+
+EVALUATOR_SOURCE = textwrap.dedent("""\
+    import ast
+
+    def evaluate(program_path: str) -> dict:
+        with open(program_path) as f:
+            source = f.read()
+        try:
+            tree = ast.parse(source)
+        except SyntaxError:
+            return {"combined_score": 0.0}
+        for node in ast.walk(tree):
+            if isinstance(node, ast.FunctionDef) and node.name == "solve":
+                return {"combined_score": 0.9}
+        return {"combined_score": 0.1}
+""")
+
+SEED_SOURCE = textwrap.dedent("""\
+    def hello():
+        return "hi"
+""")
+
+MOCK_LLM_CODE = textwrap.dedent("""\
+    def solve(x):
+        return x * 2
+""")
+
+MOCK_RESPONSE = f"```python\n{MOCK_LLM_CODE}```"
+
+
+class FakeLLMPool:
+    def __init__(self, models_cfg):
+        self.models_cfg = models_cfg
+
+    async def generate(self, system_message, messages, **kwargs):
+        return LLMResponse(text=MOCK_RESPONSE)
+
+    async def generate_all(self, system_message, messages, **kwargs):
+        return [LLMResponse(text=MOCK_RESPONSE)]
+
+
+def _find_checkpoints(output_dir: str) -> List[str]:
+    """Return the checkpoint directories under *output_dir*, oldest first.
+
+    Assumes names look like ``checkpoint_<N>``.  Ordering uses the numeric
+    suffix so checkpoint_10 sorts after checkpoint_2; a plain string sort
+    would put it first and ``[-1]`` would no longer be the latest checkpoint.
+    """
+    checkpoint_dir = os.path.join(output_dir, "checkpoints")
+    if not os.path.isdir(checkpoint_dir):
+        return []
+
+    def order(path: str):
+        suffix = os.path.basename(path).split("_", 1)[1]
+        # Non-numeric suffixes sort first, alphabetically among themselves.
+        return (1, int(suffix), "") if suffix.isdecimal() else (0, 0, suffix)
+
+    dirs = [
+        os.path.join(checkpoint_dir, name)
+        for name in os.listdir(checkpoint_dir)
+        if name.startswith("checkpoint_")
+        and os.path.isdir(os.path.join(checkpoint_dir, name))
+    ]
+    return sorted(dirs, key=order)
+
+
+class TestEvolutionTraceIntegration:
+    def _run(self, tmp_path, iterations=3):
+        evaluator_file = tmp_path / "evaluator.py"
+        evaluator_file.write_text(EVALUATOR_SOURCE)
+        seed_file = tmp_path / "seed.py"
+        seed_file.write_text(SEED_SOURCE)
+        output_dir = str(tmp_path / "output")
+
+        config = Config.from_dict(
+            {
+                "max_iterations": iterations,
+                "diff_based_generation": False,
+                "monitor": {"enabled": False},
+                "search": {"type": "topk"},
+                "evaluator": {"evaluation_file": str(evaluator_file)},
+                "llm": {
+                    "models": [
+                        {
+                            "name": "fake-model",
+                            "api_key": "fake",
+                            "api_base": "http://localhost:1",
+                        }
+                    ]
+                },
+            }
+        )
+
+        with patch(
+            "skydiscover.search.default_discovery_controller.LLMPool",
+            FakeLLMPool,
+        ):
+            result = run_discovery(
+                evaluator=str(evaluator_file),
+                initial_program=str(seed_file),
+                config=config,
+                output_dir=output_dir,
+                cleanup=False,
+            )
+
+        return result, output_dir
+
+    def test_trace_file_exists_in_every_checkpoint(self, tmp_path):
+        _, output_dir = self._run(tmp_path)
+        checkpoints = _find_checkpoints(output_dir)
+        assert checkpoints, "No checkpoints were created"
+        for ckpt in checkpoints:
+            assert os.path.exists(os.path.join(ckpt, "evolution_trace.json")), (
+                f"Missing evolution_trace.json in {ckpt}"
+            )
+
+    def test_old_checkpoint_artefacts_still_present(self, tmp_path):
+        """metadata.json and programs/ must still exist in every checkpoint."""
+        _, output_dir = self._run(tmp_path)
+        for ckpt in _find_checkpoints(output_dir):
+            assert os.path.exists(os.path.join(ckpt, "metadata.json"))
+            assert os.path.exists(os.path.join(ckpt, "programs"))
+
+    def test_trace_is_valid_json(self, tmp_path):
+        _, output_dir = self._run(tmp_path)
+        for ckpt in _find_checkpoints(output_dir):
+            trace_path = os.path.join(ckpt, "evolution_trace.json")
+            with open(trace_path) as f:
+                trace = json.load(f)
+            assert "programs" in trace
+            assert "last_iteration" in trace
+            assert "best_program_id" in trace
+            assert "total_programs" in trace
+
+    def test_trace_programs_sorted_by_iteration(self, tmp_path):
+        _, output_dir = self._run(tmp_path)
+        # Check the final checkpoint (has the most programs)
+        ckpt = _find_checkpoints(output_dir)[-1]
+        with open(os.path.join(ckpt, "evolution_trace.json")) as f:
+            trace = json.load(f)
+        iterations = [p["iteration_found"] for p in trace["programs"]]
+        assert iterations == sorted(iterations)
+
+    def test_trace_total_programs_matches_programs_list(self, tmp_path):
+        _, output_dir = self._run(tmp_path)
+        for ckpt in _find_checkpoints(output_dir):
+            with open(os.path.join(ckpt, "evolution_trace.json")) as f:
+                trace = json.load(f)
+            assert trace["total_programs"] == len(trace["programs"])
+
+    def test_trace_best_program_id_matches_metadata(self, tmp_path):
+        _, output_dir = self._run(tmp_path)
+        for ckpt in _find_checkpoints(output_dir):
+            with open(os.path.join(ckpt, "evolution_trace.json")) as f:
+                trace = json.load(f)
+            with open(os.path.join(ckpt, "metadata.json")) as f:
+                metadata = json.load(f)
+            assert trace["best_program_id"] == metadata["best_program_id"]
+
+    def test_trace_entries_have_required_fields(self, tmp_path):
+        _, output_dir = self._run(tmp_path)
+        ckpt = _find_checkpoints(output_dir)[-1]
+        with open(os.path.join(ckpt, "evolution_trace.json")) as f:
+            trace = json.load(f)
+        required = {"id", "iteration_found", "generation", "score", "metrics", "parent_id",
+                    "timestamp", "solution"}
+        for entry in trace["programs"]:
+            assert required.issubset(entry.keys()), f"Missing fields in entry: {entry.keys()}"
+
+    def test_trace_score_matches_evaluator(self, tmp_path):
+        """The mock LLM produces `def solve` which scores 0.9."""
+        _, output_dir = self._run(tmp_path)
+        ckpt = _find_checkpoints(output_dir)[-1]
+        with open(os.path.join(ckpt, "evolution_trace.json")) as f:
+            trace = json.load(f)
+        # All programs after iteration 0 should be from the mock LLM (def solve → 0.9)
+        llm_programs = [p for p in trace["programs"] if p["iteration_found"] > 0]
+        for p in llm_programs:
+            assert p["score"] == pytest.approx(0.9), f"Unexpected score: {p['score']}"
+
+    def test_best_program_in_trace(self, tmp_path):
+        _, output_dir = self._run(tmp_path)
+        ckpt = _find_checkpoints(output_dir)[-1]
+        with open(os.path.join(ckpt, "evolution_trace.json")) as f:
+            trace = json.load(f)
+        ids = {p["id"] for p in trace["programs"]}
+        assert trace["best_program_id"] in ids