train_fixed_mask_nsp.py
# @Author: Khush Patel, Zhi lab

# Standard-library and third-party imports (listed explicitly for clarity;
# the wildcard imports below may already provide some of these names).
import pickle
import sys

import numpy as np
import torch
from torch.optim import AdamW
from torch.utils.tensorboard import SummaryWriter
from transformers import get_linear_schedule_with_warmup

# Local imports (MBDataset and train_fn come from the dataset/engine modules;
# hyperparameters and paths come from config).
import model as m
from config import *
from dataset_MLM_NSP import *
from engine_amp_MLM_NSP import *

if __name__ == "__main__":
    writer = SummaryWriter(tbpath)

    # Load the pickled raw corpus and build the MLM+NSP dataset with a
    # fixed masking seed.
    with open(raw_data_path, "rb") as f:
        rawdata = pickle.load(f)
    mbdataset = MBDataset(rawdata, seed_value_changer=run_no_mask)
    loader = torch.utils.data.DataLoader(
        mbdataset, batch_size=batch_size, shuffle=True, num_workers=12,
        pin_memory=True)

    # Model, optimizer, and linear warmup/decay schedule.
    model = m.model
    model = model.to(device)
    optim = AdamW(model.parameters(), lr=lr)
    scheduler = get_linear_schedule_with_warmup(
        optim, num_warmup_steps=0, num_training_steps=num_train_steps
    )

    # Train, checkpointing whenever the epoch loss improves.
    best_loss = np.inf
    for epoch in range(num_of_epochs):
        train_loss = train_fn(loader, model, optimizer=optim, device=device,
                              scheduler=scheduler, epoch=epoch, writer=writer,
                              seed_value_changer=run_no_mask)
        print(f"Training loss at epoch {epoch} is {train_loss}")
        if train_loss < best_loss:
            best_loss = train_loss
            torch.save(model.state_dict(), save_wts_loc)
            print(f"Lowest training loss found at epoch {epoch}. Saving the model weights.")

    writer.flush()
    writer.close()
    sys.exit()
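For context, the script wildcard-imports its settings from a config module that is not shown here. Below is a minimal sketch of what that module plausibly contains, with every name inferred from how the script uses it; all paths and values are illustrative placeholders, not the authors' actual settings.

# config.py -- hypothetical sketch; paths and hyperparameters are
# illustrative placeholders inferred from their usage in the script.
import torch

tbpath = "runs/fixed_mask_nsp"            # TensorBoard log directory
raw_data_path = "data/pretraining.pkl"    # pickled raw corpus
save_wts_loc = "weights/best_model.pt"    # checkpoint path for best weights
batch_size = 32
lr = 5e-5
num_of_epochs = 10
num_train_steps = 100_000                 # total steps for the LR scheduler
run_no_mask = 0                           # seed offset for the fixed masking
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")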