diff --git a/src/gdpx/potential/managers/reann/reann.py b/src/gdpx/potential/managers/reann/reann.py index 5ad70ffa..9391a68b 100644 --- a/src/gdpx/potential/managers/reann/reann.py +++ b/src/gdpx/potential/managers/reann/reann.py @@ -288,6 +288,11 @@ def _resolve_freeze_command(self, *args, **kwargs): return self.freeze_command + def get_checkpoint(self): + """""" + + return pathlib.Path(self.directory/self.ckpt_name).resolve() + def _train_from_the_restart(self, dataset, init_model) -> str: """Train from the restart""" def _train_from_the_scratch(dataset, init_model) -> str: diff --git a/src/gdpx/trainer/interface.py b/src/gdpx/trainer/interface.py index d7e8c259..fc03ed2c 100644 --- a/src/gdpx/trainer/interface.py +++ b/src/gdpx/trainer/interface.py @@ -93,7 +93,7 @@ def forward( if self._active: curr_iter = int(self.directory.parent.name.split(".")[-1]) if curr_iter > 0: - self._print(" >>> Update init_models...") + self._print(">>> Update init_models...") prev_wdir = ( self.directory.parent.parent / f"iter.{str(curr_iter-1).zfill(4)}" @@ -104,10 +104,16 @@ def forward( if p.is_dir() and re.match("m[0-9]+", p.name): prev_mdirs.append(p) # TODO: replace `m` with a constant + init_models = [] prev_mdirs = sorted(prev_mdirs, key=lambda p: int(p.name[1:])) - init_models = [(p / trainer.frozen_name).resolve() for p in prev_mdirs] + for p in prev_mdirs: + trainer.directory = p + if hasattr(trainer, "get_checkpoint"): + init_models.append(trainer.get_checkpoint()) + else: + init_models.append((p / trainer.frozen_name).resolve()) for p in init_models: - self._print(" " * 8 + str(p)) + self._print(f" {str(p)}") assert init_models, "No previous models found." # - @@ -132,7 +138,7 @@ def forward( models = worker.retrieve(include_retrieved=True) self._print("Frozen Models: ") for m in models: - self._print(f"{str(m) =}") + self._print(f" {str(m) =}") potter_params = potter.as_dict() potter_params["params"]["model"] = models potter.register_calculator(potter_params["params"])