Skip to content

Commit

Permalink
add get_checkpoint as different trainers have different init methods
Browse files Browse the repository at this point in the history
  • Loading branch information
hsulab committed Mar 17, 2024
1 parent 2ddaeba commit 3815188
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
5 changes: 5 additions & 0 deletions src/gdpx/potential/managers/reann/reann.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,11 @@ def _resolve_freeze_command(self, *args, **kwargs):

return self.freeze_command

def get_checkpoint(self):
""""""

return pathlib.Path(self.directory/self.ckpt_name).resolve()

def _train_from_the_restart(self, dataset, init_model) -> str:
"""Train from the restart"""
def _train_from_the_scratch(dataset, init_model) -> str:
Expand Down
14 changes: 10 additions & 4 deletions src/gdpx/trainer/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def forward(
if self._active:
curr_iter = int(self.directory.parent.name.split(".")[-1])
if curr_iter > 0:
self._print(" >>> Update init_models...")
self._print(">>> Update init_models...")
prev_wdir = (
self.directory.parent.parent
/ f"iter.{str(curr_iter-1).zfill(4)}"
Expand All @@ -104,10 +104,16 @@ def forward(
if p.is_dir() and re.match("m[0-9]+", p.name):
prev_mdirs.append(p)
# TODO: replace `m` with a constant
init_models = []
prev_mdirs = sorted(prev_mdirs, key=lambda p: int(p.name[1:]))
init_models = [(p / trainer.frozen_name).resolve() for p in prev_mdirs]
for p in prev_mdirs:
trainer.directory = p
if hasattr(trainer, "get_checkpoint"):
init_models.append(trainer.get_checkpoint())
else:
init_models.append((p / trainer.frozen_name).resolve())
for p in init_models:
self._print(" " * 8 + str(p))
self._print(f" {str(p)}")
assert init_models, "No previous models found."

# -
Expand All @@ -132,7 +138,7 @@ def forward(
models = worker.retrieve(include_retrieved=True)
self._print("Frozen Models: ")
for m in models:
self._print(f"{str(m) =}")
self._print(f" {str(m) =}")
potter_params = potter.as_dict()
potter_params["params"]["model"] = models
potter.register_calculator(potter_params["params"])
Expand Down

0 comments on commit 3815188

Please sign in to comment.