-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathresume_train_fixed_mask.py
64 lines (50 loc) · 2.54 KB
/
resume_train_fixed_mask.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# @Author: Khush Patel, Zhi lab, [email protected]
# local imports
import model as m
from config import *
from dataset_MLM import *
from engine_amp_MLM import *
if __name__ == "__main__":
    # Resume an interrupted training run: restore model/optimizer/scheduler
    # state from the checkpoint at `save_dir`, finish the interrupted epoch,
    # then run the remaining epochs, saving weights whenever the training
    # loss improves on the best seen since the resume.
    writer = SummaryWriter(tbpath)

    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only point `raw_data_path` at data you trust.
    with open(raw_data_path, "rb") as f:
        rawdata = pickle.load(f)

    model = m.model
    model = model.to(device)
    optim = AdamW(model.parameters(), lr=lr)
    # This freshly built scheduler is replaced by the checkpointed one below.
    scheduler = get_linear_schedule_with_warmup(
        optim, num_warmup_steps=0, num_training_steps=num_train_steps
    )

    # SECURITY: torch.load with a pickle module (dill here) deserializes
    # arbitrary objects; only load checkpoints you trust.
    checkpoint = torch.load(save_dir, pickle_module=dill)
    model.load_state_dict(checkpoint['model_state_dict'])
    optim.load_state_dict(checkpoint['optimizer_state_dict'])
    last_epoch = checkpoint['epoch']
    # Read from the checkpoint but not used below: best_loss restarts at inf.
    loss = checkpoint['loss']
    # NOTE(review): the checkpointed scheduler object is used as-is; confirm
    # it drives `optim` and not the optimizer instance it was saved with.
    scheduler = checkpoint['scheduler']
    steps_completed = checkpoint['number_of_training_steps']

    remaining_epochs = num_of_epochs - last_epoch
    print("Remaining epochs are", remaining_epochs)

    best_loss = np.inf

    # The dataset and loader are built once, right before they are consumed.
    # (The script previously built an identical pair earlier that was
    # overwritten before ever being iterated.)
    mbdataset = MBDataset(rawdata, seed_value_changer=run_no_mask)
    loader = torch.utils.data.DataLoader(
        mbdataset, batch_size=batch_size, shuffle=True, num_workers=12, pin_memory=True)

    try:
        # Resuming the last interrupted epoch; `steps_completed` is passed
        # only here, for the epoch that was cut short.
        train_loss = train_fn(
            loader, model, optimizer=optim, device=device, scheduler=scheduler,
            epoch=last_epoch, writer=writer, seed_value_changer=run_no_mask,
            steps_completed=steps_completed)
        print(f"Training loss at epoch {last_epoch} is {train_loss}")
        if train_loss < best_loss:
            best_loss = train_loss
            torch.save(model.state_dict(), save_wts_loc)
            print(f"Lowest training loss found at epoch {last_epoch}. Saving the model weights")

        # Continuing from the next epoch
        for epoch in range(last_epoch + 1, num_of_epochs):
            # Pass the current loop epoch (previously the stale `last_epoch`
            # was passed for every iteration).
            train_loss = train_fn(
                loader, model, optimizer=optim, device=device, scheduler=scheduler,
                epoch=epoch, writer=writer, seed_value_changer=run_no_mask)
            print(f"Training loss at epoch {epoch} is {train_loss}")
            if train_loss < best_loss:
                best_loss = train_loss
                torch.save(model.state_dict(), save_wts_loc)
                print(f"Lowest training loss found at epoch {epoch}. Saving the model weights")
    finally:
        # Flush and close the writer even if training raises.
        writer.flush()
        writer.close()

    sys.exit()