-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathengine_amp_MLM_NSP.py
108 lines (83 loc) · 4.63 KB
/
engine_amp_MLM_NSP.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# @Author: Khush Patel, Zhi lab
# Train function using Automatic Mixed Precision
from config import *
from model import *
def train_fn(data_loader, model, optimizer, device, scheduler, epoch, writer, seed_value_changer, steps_completed=0):
    """Run one training epoch with Automatic Mixed Precision for a joint
    MLM + NSP objective, periodically checkpointing and logging metrics.

    Args:
        data_loader: iterable of batch dicts with tensor values under keys
            'input_ids', 'attention_mask', 'labels', 'segment_type_ids',
            'masked_indices', 'next_sentence_label'.
        model: BERT-style model whose forward returns an object exposing
            .loss, .prediction_logits (MLM head) and .seq_relationship_logits
            (NSP head).
        optimizer: optimizer whose steps are driven through the GradScaler.
        device: torch device the batch tensors are moved to.
        scheduler: LR scheduler stepped once per batch; pickled into checkpoints.
        epoch: current epoch index (used in checkpoints, logs and the step id).
        writer: TensorBoard SummaryWriter for scalar metrics.
        seed_value_changer: run identifier folded into the global step id.
        steps_completed: offset added to the global step id (default 0).

    Returns:
        Mean training loss over the epoch (sum of batch losses / batch count).
    """
    model = model.train()
    final_loss = 0
    running_loss = 0
    # leave=True keeps all traces of the progress bar upon termination of iteration
    loop = tqdm(data_loader, leave=True)
    counter = 0
    mlm_running_acc = 0
    nsp_running_acc = 0
    scaler = torch.cuda.amp.GradScaler()
    for batch in loop:
        optimizer.zero_grad()
        # pull all tensor batches required for training
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        segment_type_ids = batch['segment_type_ids'].to(device)
        masked_indices = batch['masked_indices'].to(device)
        next_sentence_label = batch['next_sentence_label'].to(device)
        # forward pass under autocast so eligible ops run in half precision
        with torch.cuda.amp.autocast():
            outputs = model(input_ids, attention_mask=attention_mask,
                            labels=labels, token_type_ids=segment_type_ids, next_sentence_label=next_sentence_label)
        # extract loss
        loss = outputs.loss
        # backpropagation on the scaled loss to avoid fp16 gradient underflow
        scaler.scale(loss).backward()
        # update parameters (the scaler skips the step on inf/nan gradients)
        scaler.step(optimizer)
        # scheduler
        scheduler.step()
        scaler.update()
        running_loss += loss.item()
        # --- MLM accuracy, measured only at the masked positions ---
        predictions = outputs.prediction_logits
        preds = (predictions.argmax(dim=2)).cpu().numpy()
        labels = labels.cpu().numpy()
        masked_indices = masked_indices.cpu().numpy()
        acc = []
        for i in range(predictions.shape[0]):
            mask = masked_indices[i]
            pred = preds[i][mask]
            label = labels[i][mask]
            if len(label) != 0:
                acc.append(accuracy_score(label, pred))
        # BUGFIX: mean([]) raises StatisticsError when a batch happens to
        # contain no masked tokens at all; count such a batch as 0 accuracy.
        batch_mlm_acc = mean(acc) if acc else 0
        mlm_running_acc += batch_mlm_acc
        # --- NSP accuracy ---
        m = nn.Softmax(dim=1)
        preds_nsp = (torch.argmax(m(outputs.seq_relationship_logits), dim=1)).reshape(-1, 1)
        # NOTE(review): the reshape to (-1, 1) assumes next_sentence_label is
        # shaped (batch, 1); a 1-D label tensor would broadcast to (batch, batch)
        # here and inflate the count — confirm against the dataset collator.
        NSP_accuracy = (preds_nsp == next_sentence_label).sum().item() / preds_nsp.shape[0]
        nsp_running_acc += NSP_accuracy
        # global step id used for both TensorBoard and the log file
        # (hoisted: the original repeated this expression six times)
        global_step = steps_completed + seed_value_changer * len(data_loader) * num_of_epochs + epoch * len(data_loader) + counter
        # saving checkpoint (logical `and` instead of bitwise `&` on booleans)
        if (counter % save_every_step == 0) and (counter != 0):
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': running_loss / save_every_step,
                'number_of_training_steps': counter,
                'scheduler': scheduler,
                "seed_value_changer": seed_value_changer
            }, save_dir, pickle_module=dill)
        if (counter % measure_metrics_steps == 0) and (counter != 0):
            writer.add_scalar('MLM accuracy', mlm_running_acc/measure_metrics_steps, global_step)
            writer.add_scalar('NSP accuracy', nsp_running_acc/measure_metrics_steps, global_step)
            writer.add_scalar('Training loss', running_loss/measure_metrics_steps, global_step)
            logging.info(
                f"The value of mean MLM accuracy for steps {global_step} was {mlm_running_acc/measure_metrics_steps}")
            logging.info(
                f"The value of mean NSP accuracy for steps {global_step} was {nsp_running_acc/measure_metrics_steps}")
            logging.info(
                f"The value of loss for steps {global_step} was {running_loss/measure_metrics_steps}")
            # reset running metrics for the next measurement window
            running_loss = 0
            mlm_running_acc = 0
            nsp_running_acc = 0
        # Updating step counter
        counter += 1
        # print relevant info to progress bar
        loop.set_description(f'Epoch {epoch}')
        loop.set_postfix(loss=loss.item(), MLM_accuracy=batch_mlm_acc, NSP_accuracy=NSP_accuracy)
        final_loss += loss.item()
    return final_loss/counter