Skip to content

Commit 8aaf009

Browse files
authored
Merge pull request #25 from ben-nowacki/main
Bug fixes with on-disk result storage
2 parents d93e3dc + 730f029 commit 8aaf009

File tree

10 files changed

+309
-225
lines changed

10 files changed

+309
-225
lines changed

modularml/core/experiment/experiment.py

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,7 @@ def _execute_training(
569569
phase: TrainPhase,
570570
*,
571571
_artifact_dir: Path | None = None,
572+
_callback_dir: Path | None = None,
572573
_execution_dir: Path | None = None,
573574
_metric_dir: Path | None = None,
574575
show_sampler_progress: bool = True,
@@ -620,6 +621,7 @@ def _execute_training(
620621
res = TrainResults(
621622
label=phase.label,
622623
_artifact_dir=_artifact_dir,
624+
_callback_dir=_callback_dir,
623625
_execution_dir=_execution_dir,
624626
_metric_dir=_metric_dir,
625627
)
@@ -684,6 +686,7 @@ def _execute_evaluation(
684686
phase: EvalPhase,
685687
*,
686688
_artifact_dir: Path | None = None,
689+
_callback_dir: Path | None = None,
687690
_execution_dir: Path | None = None,
688691
_metric_dir: Path | None = None,
689692
show_eval_progress: bool = False,
@@ -719,6 +722,7 @@ def _execute_evaluation(
719722
res = EvalResults(
720723
label=phase.label,
721724
_artifact_dir=_artifact_dir,
725+
_callback_dir=_callback_dir,
722726
_execution_dir=_execution_dir,
723727
_metric_dir=_metric_dir,
724728
)
@@ -742,6 +746,7 @@ def _execute_fit(
742746
phase: FitPhase,
743747
*,
744748
_artifact_dir: Path | None = None,
749+
_callback_dir: Path | None = None,
745750
_execution_dir: Path | None = None,
746751
_metric_dir: Path | None = None,
747752
) -> FitResults:
@@ -765,6 +770,7 @@ def _execute_fit(
765770
res = FitResults(
766771
label=phase.label,
767772
_artifact_dir=_artifact_dir,
773+
_callback_dir=_callback_dir,
768774
_execution_dir=_execution_dir,
769775
_metric_dir=_metric_dir,
770776
)
@@ -784,8 +790,7 @@ def _execute_phase_with_meta(
784790
self,
785791
phase: TrainPhase | EvalPhase | FitPhase,
786792
*,
787-
_path_suffix: Path | None = None,
788-
_run_idx: int | None = None,
793+
phase_dir: Path | None = None,
789794
**kwargs,
790795
) -> tuple[PhaseResults, PhaseExecutionMeta]:
791796
"""
@@ -808,9 +813,9 @@ def _execute_phase_with_meta(
808813
if exp_dir is None and self._exp_checkpointing is not None:
809814
exp_dir = self._exp_checkpointing.directory
810815
if exp_dir is not None:
811-
phase_dir = exp_dir / phase.label
812-
phase_dir.mkdir(parents=True, exist_ok=True)
813-
phase.checkpointing._directory = phase_dir
816+
ckpt_dir = exp_dir / phase.label
817+
ckpt_dir.mkdir(parents=True, exist_ok=True)
818+
phase.checkpointing._directory = ckpt_dir
814819

815820
# Skip callbacks and checkpointing when inside a callback
816821
run_hooks = not self._in_callback
@@ -837,23 +842,8 @@ def _execute_phase_with_meta(
837842
self._save_experiment_checkpoint(label=phase.label)
838843

839844
# ------------------------------------------------
840-
# Compute phase-specific storage directories
845+
# Compute phase-specific storage directories from the caller-supplied phase_dir
841846
# ------------------------------------------------
842-
if _path_suffix is not None:
843-
# Called from within a group; suffix already contains the run prefix
844-
phase_dir = self._results_config.phase_dir(_path_suffix / phase.label)
845-
elif _run_idx is not None:
846-
# Top-level call from run_phase; prefix with run index for stable ordering
847-
phase_dir = self._results_config.phase_dir(f"{_run_idx}_{phase.label}")
848-
elif (
849-
self._active_phase_dir is not None
850-
and self._results_config.results_dir is not None
851-
):
852-
# Called from preview_phase during callback execution → nest under callbacks/
853-
phase_dir = self._active_phase_dir / "callbacks" / phase.label
854-
else:
855-
phase_dir = None # pure preview with no disk storage
856-
857847
cfg = self._results_config
858848
phase_execution_dir = (
859849
phase_dir / "execution_data"
@@ -870,6 +860,11 @@ def _execute_phase_with_meta(
870860
if phase_dir is not None and cfg.save_artifacts
871861
else None
872862
)
863+
phase_callback_dir = (
864+
phase_dir / "callbacks"
865+
if phase_dir is not None and cfg.save_execution
866+
else None
867+
)
873868

874869
# Track active phase dir so nested callback previews nest under callbacks/
875870
prev_active_phase_dir = self._active_phase_dir
@@ -892,6 +887,7 @@ def _execute_phase_with_meta(
892887
phase_res: TrainResults = self._execute_training(
893888
phase,
894889
_artifact_dir=phase_artifact_dir,
890+
_callback_dir=phase_callback_dir,
895891
_execution_dir=phase_execution_dir,
896892
_metric_dir=phase_metric_dir,
897893
**{k: v for k, v in kwargs.items() if k in train_keys},
@@ -901,6 +897,7 @@ def _execute_phase_with_meta(
901897
phase_res: EvalResults = self._execute_evaluation(
902898
phase,
903899
_artifact_dir=phase_artifact_dir,
900+
_callback_dir=phase_callback_dir,
904901
_execution_dir=phase_execution_dir,
905902
_metric_dir=phase_metric_dir,
906903
**{k: v for k, v in kwargs.items() if k in eval_keys},
@@ -909,6 +906,7 @@ def _execute_phase_with_meta(
909906
phase_res: FitResults = self._execute_fit(
910907
phase,
911908
_artifact_dir=phase_artifact_dir,
909+
_callback_dir=phase_callback_dir,
912910
_execution_dir=phase_execution_dir,
913911
_metric_dir=phase_metric_dir,
914912
)
@@ -957,8 +955,7 @@ def _execute_group_with_meta(
957955
self,
958956
group: PhaseGroup,
959957
*,
960-
_path_suffix: Path | None = None,
961-
_run_idx: int | None = None,
958+
group_dir: Path | None = None,
962959
**kwargs,
963960
) -> tuple[PhaseGroupResults, PhaseGroupExecutionMeta]:
964961
"""
@@ -1001,15 +998,6 @@ def _execute_group_with_meta(
1001998
# - construct result container
1002999
# - run each phase in order
10031000
# ------------------------------------------------
1004-
if _path_suffix is not None:
1005-
# Nested group; suffix already contains the run prefix
1006-
group_suffix = _path_suffix / group.label
1007-
elif _run_idx is not None:
1008-
# Top-level call from run_group; prefix with run index for stable ordering
1009-
group_suffix = Path(f"{_run_idx}_{group.label}")
1010-
else:
1011-
group_suffix = Path(group.label)
1012-
10131001
group_results = PhaseGroupResults(label=group.label)
10141002
group_meta = PhaseGroupExecutionMeta(
10151003
label=group.label,
@@ -1019,9 +1007,12 @@ def _execute_group_with_meta(
10191007
for element in group.all:
10201008
if isinstance(element, ExperimentPhase):
10211009
# Run phase with meta tracking
1010+
element_dir = (
1011+
group_dir / element.label if group_dir is not None else None
1012+
)
10221013
phase_res, phase_meta = self._execute_phase_with_meta(
10231014
phase=element,
1024-
_path_suffix=group_suffix,
1015+
phase_dir=element_dir,
10251016
**kwargs,
10261017
)
10271018

@@ -1032,9 +1023,10 @@ def _execute_group_with_meta(
10321023

10331024
elif isinstance(element, PhaseGroup):
10341025
# Run group with meta tracking
1026+
sub_dir = group_dir / element.label if group_dir is not None else None
10351027
sub_res, sub_meta = self._execute_group_with_meta(
10361028
group=element,
1037-
_path_suffix=group_suffix,
1029+
group_dir=sub_dir,
10381030
**kwargs,
10391031
)
10401032

@@ -1137,9 +1129,12 @@ def run_phase(
11371129

11381130
# Run phase and record phase-level meta data
11391131
try:
1132+
run_dir = self._results_config.phase_dir(
1133+
f"{len(self._history)}_{phase.label}",
1134+
)
11401135
res, meta = self._execute_phase_with_meta(
11411136
phase=phase,
1142-
_run_idx=len(self._history),
1137+
phase_dir=run_dir,
11431138
**kwargs,
11441139
)
11451140
except Exception:
@@ -1196,9 +1191,12 @@ def run_group(
11961191

11971192
# Run group and record phase-level meta data
11981193
try:
1194+
run_dir = self._results_config.phase_dir(
1195+
f"{len(self._history)}_{group.label}",
1196+
)
11991197
res, meta = self._execute_group_with_meta(
12001198
group=group,
1201-
_run_idx=len(self._history),
1199+
group_dir=run_dir,
12021200
**kwargs,
12031201
)
12041202
except Exception:
@@ -1344,10 +1342,11 @@ def preview_phase(
13441342
needs_restore = _phase_mutates_state(phase)
13451343
state = self.get_state() if needs_restore else None
13461344

1347-
# Execute phase with checkpointing disabled
1345+
# Execute phase with checkpointing disabled (no disk storage for previews)
13481346
with self.disable_checkpointing():
13491347
res, _ = self._execute_phase_with_meta(
13501348
phase=phase,
1349+
phase_dir=None,
13511350
**kwargs,
13521351
)
13531352

@@ -1385,10 +1384,11 @@ def preview_group(
13851384
needs_restore = _phase_mutates_state(group)
13861385
state = self.get_state() if needs_restore else None
13871386

1388-
# Execute group with checkpointing disabled
1387+
# Execute group with checkpointing disabled (no disk storage for previews)
13891388
with self.disable_checkpointing():
13901389
res, _ = self._execute_group_with_meta(
13911390
group=group,
1391+
group_dir=None,
13921392
**kwargs,
13931393
)
13941394

modularml/core/experiment/results/artifact_store.py

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
import re
5+
import pickle
66
from dataclasses import dataclass
77
from pathlib import Path
88
from typing import Any, ClassVar
@@ -40,10 +40,15 @@ def __init__(
4040
def artifact(self) -> Any:
4141
"""The artifact object. Transparently loads from disk if serialized as a Path."""
4242
if isinstance(self._artifact, Path):
43-
import pickle
44-
4543
with self._artifact.open("rb") as f:
46-
return pickle.load(f)
44+
payload = pickle.load(f)
45+
46+
# New format: dict payload containing the artifact under "artifact" key
47+
if isinstance(payload, dict) and "artifact" in payload:
48+
return payload["artifact"]
49+
50+
# Legacy format: file contained just the artifact object
51+
return payload
4752
return self._artifact
4853

4954
def __repr__(self) -> str:
@@ -85,6 +90,7 @@ class ArtifactStore:
8590
def __init__(self, location: Path | None = None) -> None:
8691
self._location = location
8792
self._entries: dict[str, list[ArtifactEntry]] = {}
93+
self._count: int = 0
8894

8995
# ================================================
9096
# Writing
@@ -109,17 +115,18 @@ def log(
109115
The batch index. Defaults to None (epoch-level artifact).
110116
111117
"""
112-
stored = artifact
113118
if self._location is not None:
114-
stored = self._save_to_disk(
119+
artifact_value = self._save_to_disk(
115120
name=name,
116121
artifact=artifact,
117122
epoch_idx=epoch_idx,
118123
batch_idx=batch_idx,
119124
)
125+
else:
126+
artifact_value = artifact
120127
entry = ArtifactEntry(
121128
name=name,
122-
artifact=stored,
129+
artifact=artifact_value,
123130
epoch_idx=epoch_idx,
124131
batch_idx=batch_idx,
125132
)
@@ -135,15 +142,20 @@ def _save_to_disk(
135142
epoch_idx: int,
136143
batch_idx: int | None,
137144
) -> Path:
138-
"""Serialize ``artifact`` to disk and return the file path."""
145+
"""Serialize artifact and metadata to disk, returning the file path."""
139146
import pickle
140147

141-
batch_str = f"_b{batch_idx}" if batch_idx is not None else ""
142-
filename = f"{name}_e{epoch_idx}{batch_str}.pkl"
143-
filepath = Path(self._location) / filename # type: ignore[arg-type]
148+
filepath = Path(self._location) / f"{self._count}.pkl" # type: ignore[arg-type]
144149
filepath.parent.mkdir(parents=True, exist_ok=True)
150+
payload = {
151+
"name": name,
152+
"epoch_idx": epoch_idx,
153+
"batch_idx": batch_idx,
154+
"artifact": artifact,
155+
}
145156
with filepath.open("wb") as f:
146-
pickle.dump(artifact, f)
157+
pickle.dump(payload, f)
158+
self._count += 1
147159
return filepath
148160

149161
# ================================================
@@ -202,35 +214,31 @@ def from_directory(cls, location: Path) -> ArtifactStore:
202214
"""
203215
Reconstruct a store index from existing pickle files in ``location``.
204216
205-
Globs for ``*.pkl`` files and parses ``{name}_e{epoch}`` or
206-
``{name}_e{epoch}_b{batch}`` from filenames.
207-
208217
Args:
209218
location (Path): Directory containing serialized artifact files.
210219
211220
Returns:
212221
ArtifactStore: Store with ``_entries`` populated as ``Path`` references.
213222
214223
"""
224+
import pickle
225+
215226
store = cls(location=location)
216-
# Matches: {name}_e{epoch}.pkl or {name}_e{epoch}_b{batch}.pkl
217-
pattern = re.compile(r"^(.+)_e(\d+)(?:_b(\d+))?\.pkl$")
218-
for filepath in sorted(Path(location).glob("*.pkl")):
219-
m = pattern.match(filepath.name)
220-
if m is None:
221-
continue
222-
name = m.group(1)
223-
epoch_idx = int(m.group(2))
224-
batch_idx = int(m.group(3)) if m.group(3) is not None else None
227+
pkl_files = sorted(Path(location).glob("*.pkl"), key=lambda p: int(p.stem))
228+
for filepath in pkl_files:
229+
with filepath.open("rb") as f:
230+
payload = pickle.load(f)
231+
name = payload["name"]
225232
entry = ArtifactEntry(
226233
name=name,
227-
artifact=filepath,
228-
epoch_idx=epoch_idx,
229-
batch_idx=batch_idx,
234+
artifact=filepath, # lazy-load
235+
epoch_idx=payload["epoch_idx"],
236+
batch_idx=payload["batch_idx"],
230237
)
231238
if name not in store._entries:
232239
store._entries[name] = []
233240
store._entries[name].append(entry)
241+
store._count = len(pkl_files)
234242
return store
235243

236244
# ================================================

0 commit comments

Comments
 (0)