Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions skydiscover/search/utils/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ def save(
with open(os.path.join(save_path, "metadata.json"), "w") as f:
json.dump(metadata, f)

self._write_evolution_trace(programs, best_program_id, last_iteration, save_path)

logger.info(f"[CHECKPOINT] Saved database with {len(programs)} programs to {save_path}")

def load(self, path: str) -> Tuple[Dict[str, Program], Optional[str], int]:
Expand Down Expand Up @@ -137,6 +139,50 @@ def load(self, path: str) -> Tuple[Dict[str, Program], Optional[str], int]:

return programs, best_program_id, last_iteration

def _write_evolution_trace(
    self,
    programs: Dict[str, Program],
    best_program_id: Optional[str],
    last_iteration: int,
    save_path: str,
) -> None:
    """Write ``evolution_trace.json`` — all programs in iteration order.

    Produces a single file describing every program in ``programs``: its
    id, iteration, generation, score, metrics, lineage (``parent_id``),
    timestamp and solution. Useful for plotting score trajectories,
    inspecting lineage, or replaying a run.

    The JSON is first written to a temporary sibling file and then moved
    over the final name with ``os.replace``, so a failure part-way through
    serialisation never leaves a truncated trace behind and never destroys
    a trace written by an earlier save to the same directory.

    Args:
        programs: Programs to record, keyed by id (only the values are used).
        best_program_id: Stored verbatim under ``"best_program_id"``.
        last_iteration: Stored verbatim under ``"last_iteration"``.
        save_path: Existing directory that receives ``evolution_trace.json``.
    """
    from skydiscover.utils.metrics import get_score

    # Ordered by iteration first; the timestamp only breaks ties between
    # programs recorded for the same iteration.
    entries = sorted(
        (
            {
                "id": program.id,
                "iteration_found": program.iteration_found,
                "generation": program.generation,
                # Empty or missing metrics are recorded as "no score".
                "score": get_score(program.metrics) if program.metrics else None,
                "metrics": program.metrics,
                "parent_id": program.parent_id,
                "timestamp": program.timestamp,
                "solution": program.solution,
            }
            for program in programs.values()
        ),
        key=lambda e: (e["iteration_found"], e["timestamp"]),
    )

    trace = {
        "last_iteration": last_iteration,
        "best_program_id": best_program_id,
        "total_programs": len(entries),
        "programs": entries,
    }

    trace_path = os.path.join(save_path, "evolution_trace.json")
    tmp_path = trace_path + ".tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(trace, f, indent=2, cls=SafeJSONEncoder)
        # Same-directory rename: the final file is swapped in as one step.
        os.replace(tmp_path, trace_path)
    finally:
        # After a successful replace the temp file no longer exists; this
        # only cleans up when serialisation or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

def _save_program(
self,
program: Program,
Expand Down
107 changes: 107 additions & 0 deletions tests/search/test_checkpoint_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""Tests for CheckpointManager.save() — evolution_trace.json output."""

import json
import time

import pytest

from skydiscover.config import DatabaseConfig
from skydiscover.search.base_database import Program
from skydiscover.search.utils.checkpoint_manager import CheckpointManager


def _make_program(id_: str, iteration: int, score: float, parent_id=None) -> Program:
    """Build a small ``Program`` for the tests.

    ``score`` is embedded both in the solution text and in the
    ``combined_score`` metric; ``iteration`` is used for ``iteration_found``
    and ``generation`` alike; the timestamp is the current ``time.time()``.
    """
    fields = dict(
        id=id_,
        solution=f"def solve(): return {score}",
        language="python",
        metrics={"combined_score": score},
        iteration_found=iteration,
        generation=iteration,
        parent_id=parent_id,
        timestamp=time.time(),
    )
    return Program(**fields)


@pytest.fixture
def manager(tmp_path):
    """Provide a ``CheckpointManager`` whose ``db_path`` is the per-test temp directory."""
    return CheckpointManager(DatabaseConfig(db_path=str(tmp_path)))


class TestEvolutionTrace:
    """Unit tests for the ``evolution_trace.json`` file written by ``CheckpointManager.save()``."""

    @staticmethod
    def _read_trace(directory):
        """Parse ``evolution_trace.json`` located directly inside ``directory``."""
        return json.loads((directory / "evolution_trace.json").read_text())

    def test_trace_file_is_created(self, manager, tmp_path):
        manager.save({"a": _make_program("a", 0, 0.5)}, None, "a", 0)
        trace_file = tmp_path / "evolution_trace.json"
        assert trace_file.exists()

    def test_trace_top_level_fields(self, manager, tmp_path):
        first = _make_program("a", 0, 0.5)
        second = _make_program("b", 1, 0.8, parent_id="a")
        manager.save({"a": first, "b": second}, None, "b", 1)
        payload = self._read_trace(tmp_path)

        assert payload["last_iteration"] == 1
        assert payload["best_program_id"] == "b"
        assert payload["total_programs"] == 2
        assert len(payload["programs"]) == 2

    def test_trace_sorted_by_iteration(self, manager, tmp_path):
        # Deliberately inserted out of iteration order.
        specs = [("c", 2, 0.9), ("a", 0, 0.5), ("b", 1, 0.7)]
        population = {pid: _make_program(pid, it, sc) for pid, it, sc in specs}
        manager.save(population, None, "c", 2)
        payload = self._read_trace(tmp_path)

        seen = [item["iteration_found"] for item in payload["programs"]]
        assert seen == sorted(seen)

    def test_trace_entry_fields(self, manager, tmp_path):
        only = _make_program("a", 0, 0.5)
        manager.save({"a": only}, None, "a", 0)
        payload = self._read_trace(tmp_path)

        record = payload["programs"][0]
        assert record["id"] == "a"
        assert record["score"] == pytest.approx(0.5)
        assert record["metrics"] == {"combined_score": 0.5}
        assert record["parent_id"] is None
        assert record["solution"] == only.solution
        assert "timestamp" in record

    def test_trace_preserves_parent_id(self, manager, tmp_path):
        root = _make_program("a", 0, 0.5)
        child = _make_program("b", 1, 0.8, parent_id="a")
        manager.save({"a": root, "b": child}, None, "b", 1)
        payload = self._read_trace(tmp_path)

        lookup = {}
        for item in payload["programs"]:
            lookup[item["id"]] = item
        assert lookup["a"]["parent_id"] is None
        assert lookup["b"]["parent_id"] == "a"

    def test_trace_no_metrics_score_is_none(self, manager, tmp_path):
        bare = Program(id="x", solution="pass", metrics={})
        manager.save({"x": bare}, None, None, 0)
        payload = self._read_trace(tmp_path)
        assert payload["programs"][0]["score"] is None

    def test_trace_overwritten_on_subsequent_save(self, manager, tmp_path):
        manager.save({"a": _make_program("a", 0, 0.5)}, None, "a", 0)

        # Second save into the same directory with one more program.
        enlarged = {
            "a": _make_program("a", 0, 0.5),
            "b": _make_program("b", 1, 0.9),
        }
        manager.save(enlarged, None, "b", 1)

        payload = self._read_trace(tmp_path)
        assert payload["total_programs"] == 2
        assert payload["last_iteration"] == 1
195 changes: 195 additions & 0 deletions tests/search/test_evolution_trace_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
"""Integration test: evolution_trace.json is written correctly during a real discovery run.

Runs the full pipeline (3 iterations, mocked LLM) and verifies:
- evolution_trace.json exists in every checkpoint directory
- Contents are sorted, complete, and structurally correct
- Old checkpoint artefacts (metadata.json, the programs/ directory) still exist
- Trace score matches what the evaluator returns
"""

import json
import os
import textwrap
from typing import Any, Dict, List
from unittest.mock import patch

import pytest

from skydiscover.api import DiscoveryResult, run_discovery
from skydiscover.config import Config, LLMModelConfig
from skydiscover.llm.base import LLMResponse


# Evaluator module written to disk for the run. It scores a candidate file
# 0.9 when it defines a function named ``solve``, 0.1 when it parses but has
# no such function, and 0.0 when it is not valid Python.
EVALUATOR_SOURCE = textwrap.dedent("""\
    import ast

    def evaluate(program_path: str) -> dict:
        with open(program_path) as f:
            source = f.read()
        try:
            tree = ast.parse(source)
        except SyntaxError:
            return {"combined_score": 0.0}
        for node in ast.walk(tree):
            if isinstance(node, ast.FunctionDef) and node.name == "solve":
                return {"combined_score": 0.9}
        return {"combined_score": 0.1}
""")

# Initial program: valid Python without a ``solve`` function.
SEED_SOURCE = textwrap.dedent("""\
    def hello():
        return "hi"
""")

# Code the fake LLM always "generates"; it defines ``solve``.
MOCK_LLM_CODE = textwrap.dedent("""\
    def solve(x):
        return x * 2
""")

# The generated code wrapped in a fenced ``python`` block.
MOCK_RESPONSE = f"```python\n{MOCK_LLM_CODE}```"


class FakeLLMPool:
    """Test double for the LLM pool: every call answers with ``MOCK_RESPONSE``."""

    def __init__(self, models_cfg):
        # Stored as given; nothing else in this fake reads it.
        self.models_cfg = models_cfg

    async def generate(self, system_message, messages, **kwargs):
        """Return one canned ``LLMResponse``; all arguments are ignored."""
        reply = LLMResponse(text=MOCK_RESPONSE)
        return reply

    async def generate_all(self, system_message, messages, **kwargs):
        """Return a one-element list holding a canned ``LLMResponse``; all arguments are ignored."""
        replies = []
        replies.append(LLMResponse(text=MOCK_RESPONSE))
        return replies


def _find_checkpoints(output_dir: str) -> List[str]:
checkpoint_dir = os.path.join(output_dir, "checkpoints")
if not os.path.isdir(checkpoint_dir):
return []
dirs = []
for name in os.listdir(checkpoint_dir):
full = os.path.join(checkpoint_dir, name)
if os.path.isdir(full) and name.startswith("checkpoint_"):
dirs.append(full)
return sorted(dirs)


class TestEvolutionTraceIntegration:
    """End-to-end checks of ``evolution_trace.json`` after a discovery run with a fake LLM."""

    def _run(self, tmp_path, iterations=3):
        """Run discovery inside ``tmp_path`` and return ``(result, output_dir)``.

        The evaluator and seed program are written to disk from the module
        constants, and ``LLMPool`` in the discovery controller module is
        patched with ``FakeLLMPool`` for the duration of the run.
        """
        evaluator_file = tmp_path / "evaluator.py"
        evaluator_file.write_text(EVALUATOR_SOURCE)
        seed_file = tmp_path / "seed.py"
        seed_file.write_text(SEED_SOURCE)
        output_dir = str(tmp_path / "output")

        config = Config.from_dict(
            {
                "max_iterations": iterations,
                "diff_based_generation": False,
                "monitor": {"enabled": False},
                "search": {"type": "topk"},
                "evaluator": {"evaluation_file": str(evaluator_file)},
                "llm": {
                    "models": [
                        {
                            "name": "fake-model",
                            "api_key": "fake",
                            "api_base": "http://localhost:1",
                        }
                    ]
                },
            }
        )

        with patch(
            "skydiscover.search.default_discovery_controller.LLMPool",
            FakeLLMPool,
        ):
            result = run_discovery(
                evaluator=str(evaluator_file),
                initial_program=str(seed_file),
                config=config,
                output_dir=output_dir,
                cleanup=False,
            )

        return result, output_dir

    @staticmethod
    def _checkpoints(output_dir):
        """Return the run's checkpoint directories, failing if there are none.

        Without this guard, tests that loop over the checkpoints would pass
        vacuously when the run produced no checkpoint at all.
        """
        checkpoints = _find_checkpoints(output_dir)
        assert checkpoints, "No checkpoints were created"
        return checkpoints

    @staticmethod
    def _load_json(ckpt, filename):
        """Parse the JSON file ``filename`` inside checkpoint directory ``ckpt``."""
        with open(os.path.join(ckpt, filename)) as f:
            return json.load(f)

    def _final_trace(self, output_dir):
        """Load the trace from the last checkpoint (expected to hold the most programs).

        ``_find_checkpoints`` already returns its result sorted, so the last
        element is taken directly.
        """
        return self._load_json(self._checkpoints(output_dir)[-1], "evolution_trace.json")

    def test_trace_file_exists_in_every_checkpoint(self, tmp_path):
        _, output_dir = self._run(tmp_path)
        for ckpt in self._checkpoints(output_dir):
            assert os.path.exists(os.path.join(ckpt, "evolution_trace.json")), (
                f"Missing evolution_trace.json in {ckpt}"
            )

    def test_old_checkpoint_artefacts_still_present(self, tmp_path):
        """metadata.json and the programs entry must still exist in every checkpoint."""
        _, output_dir = self._run(tmp_path)
        for ckpt in self._checkpoints(output_dir):
            assert os.path.exists(os.path.join(ckpt, "metadata.json"))
            assert os.path.exists(os.path.join(ckpt, "programs"))

    def test_trace_is_valid_json(self, tmp_path):
        _, output_dir = self._run(tmp_path)
        for ckpt in self._checkpoints(output_dir):
            trace = self._load_json(ckpt, "evolution_trace.json")
            assert "programs" in trace
            assert "last_iteration" in trace
            assert "best_program_id" in trace
            assert "total_programs" in trace

    def test_trace_programs_sorted_by_iteration(self, tmp_path):
        _, output_dir = self._run(tmp_path)
        trace = self._final_trace(output_dir)
        iterations = [p["iteration_found"] for p in trace["programs"]]
        assert iterations == sorted(iterations)

    def test_trace_total_programs_matches_programs_list(self, tmp_path):
        _, output_dir = self._run(tmp_path)
        for ckpt in self._checkpoints(output_dir):
            trace = self._load_json(ckpt, "evolution_trace.json")
            assert trace["total_programs"] == len(trace["programs"])

    def test_trace_best_program_id_matches_metadata(self, tmp_path):
        _, output_dir = self._run(tmp_path)
        for ckpt in self._checkpoints(output_dir):
            trace = self._load_json(ckpt, "evolution_trace.json")
            metadata = self._load_json(ckpt, "metadata.json")
            assert trace["best_program_id"] == metadata["best_program_id"]

    def test_trace_entries_have_required_fields(self, tmp_path):
        _, output_dir = self._run(tmp_path)
        trace = self._final_trace(output_dir)
        required = {
            "id",
            "iteration_found",
            "generation",
            "score",
            "metrics",
            "parent_id",
            "timestamp",
            "solution",
        }
        for entry in trace["programs"]:
            assert required.issubset(entry.keys()), f"Missing fields in entry: {entry.keys()}"

    def test_trace_score_matches_evaluator(self, tmp_path):
        """The mock LLM produces `def solve` which scores 0.9."""
        _, output_dir = self._run(tmp_path)
        trace = self._final_trace(output_dir)
        # All programs after iteration 0 should be from the mock LLM (def solve → 0.9)
        llm_programs = [p for p in trace["programs"] if p["iteration_found"] > 0]
        # Guard against a vacuous pass when the run recorded no generated program.
        assert llm_programs, "No LLM-generated programs found in the trace"
        for p in llm_programs:
            assert p["score"] == pytest.approx(0.9), f"Unexpected score: {p['score']}"

    def test_best_program_in_trace(self, tmp_path):
        _, output_dir = self._run(tmp_path)
        trace = self._final_trace(output_dir)
        ids = {p["id"] for p in trace["programs"]}
        assert trace["best_program_id"] in ids
Loading