diff --git a/train.py b/train.py
index 6994fb9b..8bb09044 100644
--- a/train.py
+++ b/train.py
@@ -606,11 +606,19 @@ def get_weight_decay(progress)
 
 total_tokens = step * TOTAL_BATCH_SIZE
 
+# Save pre-eval checkpoint so training isn't lost if evaluation crashes (e.g. OOM)
+checkpoint_path = "pre_eval_checkpoint.pt"
+torch.save(model.state_dict(), checkpoint_path)
+
 # Final eval
 model.eval()
 with autocast_ctx:
     val_bpb = evaluate_bpb(model, tokenizer, DEVICE_BATCH_SIZE)
 
+# Eval succeeded — remove the safety checkpoint
+if os.path.exists(checkpoint_path):
+    os.remove(checkpoint_path)
+
 # Final summary
 t_end = time.time()
 startup_time = t_start_training - t_start