Skip to content

Commit

Permalink
feat: Enhance best checkpoint management with intelligent deletion
Browse files Browse the repository at this point in the history
- Implement intelligent deletion of previous best checkpoints
- Add special handling for initial epoch 0 checkpoint
- Use glob to find and remove old best checkpoints
- Preserve initial checkpoint during checkpoint cleanup
- Print checkpoint deletion for transparency
  • Loading branch information
beduffy committed Feb 15, 2025
1 parent 3d46059 commit a9178c8
Showing 1 changed file with 28 additions and 1 deletion.
29 changes: 28 additions & 1 deletion imitate_mouse/imitate_mouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
from datetime import datetime
from collections import deque
import glob

import numpy as np
import torch
Expand Down Expand Up @@ -316,9 +317,35 @@ def train_mouse_policy(args_dict, device='cuda'):

# Save best checkpoint
if avg_loss < best_loss:
previous_best = best_loss # Store to compare after update

# Delete previous best checkpoint (except initial epoch 0)
if best_loss != float('inf'): # Skip deletion for first checkpoint
previous_ckpt_pattern = os.path.join(os.path.dirname(__file__), 'checkpoints', f'mouse_act_policy_best_epoch*')
previous_ckpts = glob.glob(previous_ckpt_pattern)

# Preserve initial epoch 0 file explicitly
if epoch == 0:
initial_ckpt = os.path.join(os.path.dirname(__file__), 'checkpoints', 'mouse_act_policy_initial_epoch0.ckpt')
if initial_ckpt in previous_ckpts:
previous_ckpts.remove(initial_ckpt)

# Delete all other best checkpoints
for ckpt in previous_ckpts:
if os.path.exists(ckpt):
os.remove(ckpt)
print(f"Deleted previous checkpoint: {ckpt}")

best_loss = avg_loss
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
checkpoint_path = os.path.join(os.path.dirname(__file__), 'checkpoints', f'mouse_act_policy_best_epoch{epoch}_{timestamp}.ckpt')

# Special name for initial checkpoint
if epoch == 0:
checkpoint_name = f'mouse_act_policy_initial_epoch0.ckpt'
else:
checkpoint_name = f'mouse_act_policy_best_epoch{epoch}_{timestamp}.ckpt'

checkpoint_path = os.path.join(os.path.dirname(__file__), 'checkpoints', checkpoint_name)
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
print(f"Saving best checkpoint to: {checkpoint_path}")
torch.save(policy.state_dict(), checkpoint_path)
Expand Down

0 comments on commit a9178c8

Please sign in to comment.