@@ -569,6 +569,7 @@ def _execute_training(
569569 phase : TrainPhase ,
570570 * ,
571571 _artifact_dir : Path | None = None ,
572+ _callback_dir : Path | None = None ,
572573 _execution_dir : Path | None = None ,
573574 _metric_dir : Path | None = None ,
574575 show_sampler_progress : bool = True ,
@@ -620,6 +621,7 @@ def _execute_training(
620621 res = TrainResults (
621622 label = phase .label ,
622623 _artifact_dir = _artifact_dir ,
624+ _callback_dir = _callback_dir ,
623625 _execution_dir = _execution_dir ,
624626 _metric_dir = _metric_dir ,
625627 )
@@ -684,6 +686,7 @@ def _execute_evaluation(
684686 phase : EvalPhase ,
685687 * ,
686688 _artifact_dir : Path | None = None ,
689+ _callback_dir : Path | None = None ,
687690 _execution_dir : Path | None = None ,
688691 _metric_dir : Path | None = None ,
689692 show_eval_progress : bool = False ,
@@ -719,6 +722,7 @@ def _execute_evaluation(
719722 res = EvalResults (
720723 label = phase .label ,
721724 _artifact_dir = _artifact_dir ,
725+ _callback_dir = _callback_dir ,
722726 _execution_dir = _execution_dir ,
723727 _metric_dir = _metric_dir ,
724728 )
@@ -742,6 +746,7 @@ def _execute_fit(
742746 phase : FitPhase ,
743747 * ,
744748 _artifact_dir : Path | None = None ,
749+ _callback_dir : Path | None = None ,
745750 _execution_dir : Path | None = None ,
746751 _metric_dir : Path | None = None ,
747752 ) -> FitResults :
@@ -765,6 +770,7 @@ def _execute_fit(
765770 res = FitResults (
766771 label = phase .label ,
767772 _artifact_dir = _artifact_dir ,
773+ _callback_dir = _callback_dir ,
768774 _execution_dir = _execution_dir ,
769775 _metric_dir = _metric_dir ,
770776 )
@@ -784,8 +790,7 @@ def _execute_phase_with_meta(
784790 self ,
785791 phase : TrainPhase | EvalPhase | FitPhase ,
786792 * ,
787- _path_suffix : Path | None = None ,
788- _run_idx : int | None = None ,
793+ phase_dir : Path | None = None ,
789794 ** kwargs ,
790795 ) -> tuple [PhaseResults , PhaseExecutionMeta ]:
791796 """
@@ -808,9 +813,9 @@ def _execute_phase_with_meta(
808813 if exp_dir is None and self ._exp_checkpointing is not None :
809814 exp_dir = self ._exp_checkpointing .directory
810815 if exp_dir is not None :
811- phase_dir = exp_dir / phase .label
812- phase_dir .mkdir (parents = True , exist_ok = True )
813- phase .checkpointing ._directory = phase_dir
816+ ckpt_dir = exp_dir / phase .label
817+ ckpt_dir .mkdir (parents = True , exist_ok = True )
818+ phase .checkpointing ._directory = ckpt_dir
814819
815820 # Skip callbacks and checkpointing when inside a callback
816821 run_hooks = not self ._in_callback
@@ -837,23 +842,8 @@ def _execute_phase_with_meta(
837842 self ._save_experiment_checkpoint (label = phase .label )
838843
839844 # ------------------------------------------------
840- # Compute phase-specific storage directories
845+ # Compute phase-specific storage directories from the caller-supplied phase_dir
841846 # ------------------------------------------------
842- if _path_suffix is not None :
843- # Called from within a group; suffix already contains the run prefix
844- phase_dir = self ._results_config .phase_dir (_path_suffix / phase .label )
845- elif _run_idx is not None :
846- # Top-level call from run_phase; prefix with run index for stable ordering
847- phase_dir = self ._results_config .phase_dir (f"{ _run_idx } _{ phase .label } " )
848- elif (
849- self ._active_phase_dir is not None
850- and self ._results_config .results_dir is not None
851- ):
852- # Called from preview_phase during callback execution → nest under callbacks/
853- phase_dir = self ._active_phase_dir / "callbacks" / phase .label
854- else :
855- phase_dir = None # pure preview with no disk storage
856-
857847 cfg = self ._results_config
858848 phase_execution_dir = (
859849 phase_dir / "execution_data"
@@ -870,6 +860,11 @@ def _execute_phase_with_meta(
870860 if phase_dir is not None and cfg .save_artifacts
871861 else None
872862 )
863+ phase_callback_dir = (
864+ phase_dir / "callbacks"
865+ if phase_dir is not None and cfg .save_execution
866+ else None
867+ )
873868
874869 # Track active phase dir so nested callback previews nest under callbacks/
875870 prev_active_phase_dir = self ._active_phase_dir
@@ -892,6 +887,7 @@ def _execute_phase_with_meta(
892887 phase_res : TrainResults = self ._execute_training (
893888 phase ,
894889 _artifact_dir = phase_artifact_dir ,
890+ _callback_dir = phase_callback_dir ,
895891 _execution_dir = phase_execution_dir ,
896892 _metric_dir = phase_metric_dir ,
897893 ** {k : v for k , v in kwargs .items () if k in train_keys },
@@ -901,6 +897,7 @@ def _execute_phase_with_meta(
901897 phase_res : EvalResults = self ._execute_evaluation (
902898 phase ,
903899 _artifact_dir = phase_artifact_dir ,
900+ _callback_dir = phase_callback_dir ,
904901 _execution_dir = phase_execution_dir ,
905902 _metric_dir = phase_metric_dir ,
906903 ** {k : v for k , v in kwargs .items () if k in eval_keys },
@@ -909,6 +906,7 @@ def _execute_phase_with_meta(
909906 phase_res : FitResults = self ._execute_fit (
910907 phase ,
911908 _artifact_dir = phase_artifact_dir ,
909+ _callback_dir = phase_callback_dir ,
912910 _execution_dir = phase_execution_dir ,
913911 _metric_dir = phase_metric_dir ,
914912 )
@@ -957,8 +955,7 @@ def _execute_group_with_meta(
957955 self ,
958956 group : PhaseGroup ,
959957 * ,
960- _path_suffix : Path | None = None ,
961- _run_idx : int | None = None ,
958+ group_dir : Path | None = None ,
962959 ** kwargs ,
963960 ) -> tuple [PhaseGroupResults , PhaseGroupExecutionMeta ]:
964961 """
@@ -1001,15 +998,6 @@ def _execute_group_with_meta(
1001998 # - construct result container
1002999 # - run each phase in order
10031000 # ------------------------------------------------
1004- if _path_suffix is not None :
1005- # Nested group; suffix already contains the run prefix
1006- group_suffix = _path_suffix / group .label
1007- elif _run_idx is not None :
1008- # Top-level call from run_group; prefix with run index for stable ordering
1009- group_suffix = Path (f"{ _run_idx } _{ group .label } " )
1010- else :
1011- group_suffix = Path (group .label )
1012-
10131001 group_results = PhaseGroupResults (label = group .label )
10141002 group_meta = PhaseGroupExecutionMeta (
10151003 label = group .label ,
@@ -1019,9 +1007,12 @@ def _execute_group_with_meta(
10191007 for element in group .all :
10201008 if isinstance (element , ExperimentPhase ):
10211009 # Run phase with meta tracking
1010+ element_dir = (
1011+ group_dir / element .label if group_dir is not None else None
1012+ )
10221013 phase_res , phase_meta = self ._execute_phase_with_meta (
10231014 phase = element ,
1024- _path_suffix = group_suffix ,
1015+ phase_dir = element_dir ,
10251016 ** kwargs ,
10261017 )
10271018
@@ -1032,9 +1023,10 @@ def _execute_group_with_meta(
10321023
10331024 elif isinstance (element , PhaseGroup ):
10341025 # Run group with meta tracking
1026+ sub_dir = group_dir / element .label if group_dir is not None else None
10351027 sub_res , sub_meta = self ._execute_group_with_meta (
10361028 group = element ,
1037- _path_suffix = group_suffix ,
1029+ group_dir = sub_dir ,
10381030 ** kwargs ,
10391031 )
10401032
@@ -1137,9 +1129,12 @@ def run_phase(
11371129
11381130 # Run phase and record phase-level meta data
11391131 try :
1132+ run_dir = self ._results_config .phase_dir (
1133+ f"{ len (self ._history )} _{ phase .label } " ,
1134+ )
11401135 res , meta = self ._execute_phase_with_meta (
11411136 phase = phase ,
1142- _run_idx = len ( self . _history ) ,
1137+ phase_dir = run_dir ,
11431138 ** kwargs ,
11441139 )
11451140 except Exception :
@@ -1196,9 +1191,12 @@ def run_group(
11961191
11971192 # Run group and record phase-level meta data
11981193 try :
1194+ run_dir = self ._results_config .phase_dir (
1195+ f"{ len (self ._history )} _{ group .label } " ,
1196+ )
11991197 res , meta = self ._execute_group_with_meta (
12001198 group = group ,
1201- _run_idx = len ( self . _history ) ,
1199+ group_dir = run_dir ,
12021200 ** kwargs ,
12031201 )
12041202 except Exception :
@@ -1344,10 +1342,11 @@ def preview_phase(
13441342 needs_restore = _phase_mutates_state (phase )
13451343 state = self .get_state () if needs_restore else None
13461344
1347- # Execute phase with checkpointing disabled
1345+ # Execute phase with checkpointing disabled (no disk storage for previews)
13481346 with self .disable_checkpointing ():
13491347 res , _ = self ._execute_phase_with_meta (
13501348 phase = phase ,
1349+ phase_dir = None ,
13511350 ** kwargs ,
13521351 )
13531352
@@ -1385,10 +1384,11 @@ def preview_group(
13851384 needs_restore = _phase_mutates_state (group )
13861385 state = self .get_state () if needs_restore else None
13871386
1388- # Execute group with checkpointing disabled
1387+ # Execute group with checkpointing disabled (no disk storage for previews)
13891388 with self .disable_checkpointing ():
13901389 res , _ = self ._execute_group_with_meta (
13911390 group = group ,
1391+ group_dir = None ,
13921392 ** kwargs ,
13931393 )
13941394
0 commit comments