trainer.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from gat import *
from encoder import *
from dep_arcs import DEPARCS
from data import *
from utils import *
class Trainer(object):
""" ABCD Model Definition"""
def __init__(self, cfg, word_dim, hidden_dim, arc_dim, num_heads, lstm_dp, weight_loss, device):
self.enc = BLSTMEncoder(word_dim, hidden_dim, device).to(device)
# self.enc_lstm = nn.LSTM(word_dim, hidden_dim, 1,
# bidirectional=bilstm, dropout=lstm_dp)
self.gat = GAT(2*hidden_dim, hidden_dim, arc_dim, arc_dim, num_heads).to(device)
if cfg["classifer"] == "Bilinear":
self.classifer = BilinearClassifier(2*hidden_dim, arc_dim, 4).to(device) #num_classes
print("Currently using BILINEAR CLASSIFIER")
else:
if "multi_layer" in cfg:
self.classifer = Classifier(2*2*hidden_dim+arc_dim, cfg["multi_layer"]).to(device)
else:
self.classifer = Classifier(2*2*hidden_dim+arc_dim).to(device)
self.cfg = cfg
self.arcs = DEPARCS
self.arcids = {v:k for k,v in DEPARCS.items()}
self.weight_loss = weight_loss
self.optimizer = None
self.batch_loss = torch.tensor(0).float().to(device)
self.device = device
if self.weight_loss:
self.loss_fn = nn.CrossEntropyLoss(weight=torch.tensor(cfg["inverse_label_weights"]).to(device)).to(device)
else:
self.loss_fn = nn.CrossEntropyLoss().to(device)
self.pos_flag = self.cfg.get("position_flag", True)
if self.pos_flag:
self.pos_enc = PositionalEncoder(word_dim, self.device)
def get_parameters(self):
self.enc_params = list(self.enc.parameters())
self.gat_params = list(self.gat.parameters())
self.cls_params = list(self.classifer.parameters())
self.model_params = self.enc_params + self.gat_params + self.cls_params
def create_optimizer(self):
if "optimizer_type" not in self.cfg:
# default Adam
self.optimizer = optim.Adam(self.model_params,
lr=self.cfg['lr'],
weight_decay=self.cfg['weight_decay'])
else:
opt_type = self.cfg["optimizer_type"]
if opt_type == "SGD":
self.optimizer = torch.optim.SGD(
self.model_params, lr=self.cfg['lr'],
momentum=self.cfg.get("momentum", 0.9),
weight_decay=self.cfg.get("weight_decay", 0.0))
elif opt_type == "Adadelta":
self.optimizer = torch.optim.Adadelta(self.model_params, lr=self.cfg['lr'])
elif opt_type == "RMSprop":
self.optimizer = torch.optim.RMSprop(self.model_params, lr=self.cfg['lr'])
else:
raise NotImplementedError(
"Not supported optimizer [{}]".format(opt_type))
def init_onehot(self):
# initialize onehot matrix
label_oh, labels = EncodeOnehot(self.arcs)
self.label_oh = label_oh
def prepare(self):
# get list of parameters and create optimizer for it
print("PREPARING TRIANING OPTIMIZER")
self.get_parameters()
self.create_optimizer()
def train(self):
if self.optimizer is None:
self.prepare()
self.enc.train()
self.gat.train()
self.classifer.train()
def eval(self):
self.enc.eval()
self.gat.eval()
self.classifer.eval()
def save_model(self, path, prefix):
torch.save(self.enc.state_dict(), path+"/"+prefix+ "_enc.pt")
torch.save(self.gat.state_dict(), path+"/"+prefix+"_gat.pt")
torch.save(self.classifer.state_dict(), path+"/"+prefix+"_clsf.pt")
def update(self):
""" Update the network
Args:
loss: loss to train the network; dict()
call outside the trainer, in main function
"""
#self.it = self.it + 1
self.optimizer.zero_grad() # set gradients as zero before update
self.batch_loss.backward(retain_graph=True)
#if self.scheduler is not None: self.scheduler.step()
if self.cfg["gradient_clip"]:
torch.nn.utils.clip_grad_norm_(self.model_params, 2.0)
self.optimizer.step()
#self.optimizer.zero_grad()
self.batch_loss = torch.tensor(0).float().to(self.device)
    def extract(self, sents, pairs):
        # gather source/target hidden states and arc embeddings for each candidate pair
        h_srcs, h_tgts, h_arcs = [], [], []
for tup in pairs:
h_srcs.append(sents[tup[0]])
h_tgts.append(sents[tup[1]])
vecs = tup[-1]
h_arcs.append(vecs)
srcs = torch.stack(h_srcs, dim=0)
tgts = torch.stack(h_tgts, dim=0)
arcs = torch.stack(h_arcs, dim=0)
return srcs, tgts, arcs
def main(self, sents, lengths, adj_pairs, golds, mode="Train"):
        if self.pos_flag:
            sents = self.pos_enc(sents.unsqueeze(0))
        sent_hidden = self.enc((sents.squeeze(0).unsqueeze(1), lengths)).float()  # length x bsize x h
        # pair node representations with their arc embeddings, attend, then classify each arc
        srcs, tgts, arcs = self.extract(sent_hidden.squeeze(1), adj_pairs)
        h_att2, h_src, h_tgt, h_arc = self.gat(srcs, tgts, arcs)
        preds = self.classifer(h_att2, h_src, h_tgt, h_arc)
#golds = golds.to(self.device)
if mode == "Train":
golds = golds.to(self.device)
self.batch_loss += self.loss_fn(preds, golds)
return preds
    def constructgraph(self, preds, adj_pairs, adjs, itov):
        # take the argmax label for every candidate arc
        pred_ind = torch.max(preds, dim=1)[1]
        # bucket the arcs by predicted label (four edge classes)
        a, b, c, d = [], [], [], []
        pred_pairs = []
        for _ind, p in enumerate(adj_pairs):
            src, tgt, arc = p[0], p[1], pred_ind[_ind].item()
            if arc == 0:
                a.append((src, tgt, arc))
            elif arc == 1:
                b.append((src, tgt, arc))
            elif arc == 2:
                c.append((src, tgt, arc))
            else:
                d.append((src, tgt, arc))
output_strs = PredictGraph(a, b, c, d, adjs, itov)
return output_strs
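

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). The
# config keys mirror those read above; the tensor shapes, hyperparameter
# values, and the `train_batches` iterator are assumptions inferred from
# main(): `sents` is (seq_len, word_dim), `adj_pairs` is a list of
# (src_idx, tgt_idx, ..., arc_vector) tuples, and `golds` holds gold edge
# labels in {0, 1, 2, 3}. Kept commented out because it depends on the
# project's data pipeline and the encoder/GAT module internals.
# ---------------------------------------------------------------------------
# if __name__ == "__main__":
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     cfg = {"classifer": "MLP", "lr": 1e-3, "weight_decay": 1e-5,
#            "gradient_clip": True, "position_flag": True}  # assumed values
#     trainer = Trainer(cfg, word_dim=300, hidden_dim=200, arc_dim=50,
#                       num_heads=4, lstm_dp=0.3, weight_loss=False, device=device)
#     trainer.train()  # builds the optimizer on the first call
#     for sents, lengths, adj_pairs, golds in train_batches:  # hypothetical loader
#         preds = trainer.main(sents, lengths, adj_pairs, golds, mode="Train")
#         trainer.update()  # backprop accumulated loss, clip, step, reset
#     trainer.eval()
#     trainer.save_model("checkpoints", "abcd")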