diff --git a/train.py b/train.py
index ef31e98..c3f55c6 100644
--- a/train.py
+++ b/train.py
@@ -74,9 +74,9 @@ def estimate_loss(model, data, device):
     model.train()
     return out
 
-def load_checkpoint(ckpt_file: str) -> dict:
+def load_checkpoint(ckpt_file: str, device: str) -> dict:
     print(f"---> load checkpoint: {ckpt_file}")
-    return torch.load(ckpt_file)
+    return torch.load(ckpt_file, map_location=device)
 
 def save_checkpoint(ckpt_file: str,
                     model_args: dict,
@@ -124,7 +124,7 @@ def train(session_name:str = None):
     checkpoint = None
     if not args.resume_from is None:
         # Resume from a checkpoint.
-        checkpoint = load_checkpoint(args.resume_from)
+        checkpoint = load_checkpoint(args.resume_from, device)
         # Load model arguments
         model_args = checkpoint['model_args']
         print('model_args')
@@ -235,6 +235,10 @@ def train(session_name:str = None):
         max_iters = args.num_iters
     print(f"Start from step {iter} up to {max_iters}, evaluate every {eval_interval} steps.")
     checkpoint = None # free the memory
+    if iter >= max_iters:
+        print(f"You specified num_iters={max_iters}, but the checkpoint was already trained with {iter} steps.")
+        print('No training step will be executed.')
+        return None
 
     while iter < max_iters:
         # Set learning rate for this iteration
@@ -323,6 +327,8 @@ def train(session_name:str = None):
     # Save the final model checkpoint
     if not args.no_save_model:
         print(f"Saving model checkpoint to {args.output}")
+        # Update max_iter in training_config
+        training_config['max_iters'] = max_iters
         save_checkpoint(ckpt_file=args.output,
                         model_args=model_args,
                         training_config=training_config,
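
Note (not part of the patch): a minimal sketch of how the device-aware load is expected to behave, assuming the standard torch.load API; the checkpoint path is hypothetical and stands in for whatever save_checkpoint() wrote.

    import torch

    # Resolve the target device first, then pass it to torch.load so a checkpoint
    # written on a CUDA machine can still be restored on a CPU-only host.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    checkpoint = torch.load('out/ckpt.pt', map_location=device)  # hypothetical path
    model_args = checkpoint['model_args']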