Skip to content

Commit

Permalink
Allow resuming training from a checkpoint that was trained on a diffe…
Browse files Browse the repository at this point in the history
…rent device (e.g. trained on GPU, resume on CPU).
  • Loading branch information
zhoupingjay committed Jan 30, 2024
1 parent cd45dc4 commit 1dcc02f
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ def estimate_loss(model, data, device):
model.train()
return out

def load_checkpoint(ckpt_file: str) -> dict:
def load_checkpoint(ckpt_file: str, device: str) -> dict:
print(f"---> load checkpoint: {ckpt_file}")
return torch.load(ckpt_file)
return torch.load(ckpt_file, map_location=device)

def save_checkpoint(ckpt_file: str,
model_args: dict,
Expand Down Expand Up @@ -124,7 +124,7 @@ def train(session_name:str = None):
checkpoint = None
if not args.resume_from is None:
# Resume from a checkpoint.
checkpoint = load_checkpoint(args.resume_from)
checkpoint = load_checkpoint(args.resume_from, device)
# Load model arguments
model_args = checkpoint['model_args']
print('model_args')
Expand Down Expand Up @@ -235,6 +235,10 @@ def train(session_name:str = None):
max_iters = args.num_iters
print(f"Start from step {iter} up to {max_iters}, evaluate every {eval_interval} steps.")
checkpoint = None # free the memory
if iter >= max_iters:
print(f"You specified num_iters={max_iters}, but the checkpoint was already trained with {iter} steps.")
print('No training step will be executed.')
return None

while iter < max_iters:
# Set learning rate for this iteration
Expand Down Expand Up @@ -323,6 +327,8 @@ def train(session_name:str = None):
# Save the final model checkpoint
if not args.no_save_model:
print(f"Saving model checkpoint to {args.output}")
# Update max_iter in training_config
training_config['max_iters'] = max_iters
save_checkpoint(ckpt_file=args.output,
model_args=model_args,
training_config=training_config,
Expand Down

0 comments on commit 1dcc02f

Please sign in to comment.