From fd777e2b74428165b90e08589369f1712b388df5 Mon Sep 17 00:00:00 2001 From: Jae Date: Thu, 12 Sep 2024 17:39:37 -0400 Subject: [PATCH] fixes checkpoint loading when file path includes numbers --- pytorch_generative/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_generative/trainer.py b/pytorch_generative/trainer.py index 54d24c3..0e58fa5 100644 --- a/pytorch_generative/trainer.py +++ b/pytorch_generative/trainer.py @@ -113,7 +113,7 @@ def _save_checkpoint(self): def _find_latest_epoch(self): files = glob.glob(self._path("trainer_state_[0-9]*.ckpt")) - epochs = sorted([int(re.findall(r"\d+", f)[0]) for f in files]) + epochs = sorted([int(re.findall(r"\d+", f)[-1]) for f in files]) if not epochs: raise FileNotFoundError(f"No checkpoints found in {self.log_dir}.") print(f"Found {len(epochs)} saved checkpoints.")