-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: train.py
253 lines (221 loc) · 8.79 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
import datetime
import logging
import pickle
import sys
from typing import Tuple
import numpy as np
import torch
from transformers import DistilBertTokenizer
from data_processing import (
LVL,
RESPGROUP,
SUBTYPE,
TYPE,
batchify_act_seqs,
batchify_static_data,
process,
)
from load_staged_acts import get_dat_data
from model import IndependentCategorical, SAModel
from utils import field_accuracy, field_printer, set_seed
# --- Run configuration and data preparation (module-level script setup) ---
set_seed(0)  # fix RNG seeds for reproducibility
load_chkpnt = True  # resume training from the checkpoint hard-coded below
model_name = "SIAG4"  # Seq_Ind_Acts_Generation

# Console logging at INFO; file logging at DEBUG so per-batch detail is kept on disk.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
fh = logging.FileHandler(f"{model_name}_training.log")
fh.setLevel(logging.DEBUG)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
fh.setFormatter(formatter)
logger.addHandler(fh)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 80/20 train/test split of sequence data and the static (text) data.
trnseq, tstseq, trnstat, tststat = get_dat_data(split_frac=0.8)

# --- Hyperparameters ------------------------------------------------------
sequence_length = (
    5  # maximum number of independent category groups that make up a sequence
)
num_act_cats = 4  # number of independent fields in a category group
batch_sz = 12  # minibatch size, sequences of independent cat groups to be processed in parallel
rec_len = len(trnseq) // batch_sz  # num records in training set, used for batchifying
emb_dim = 192  # embedding dim for each categorical
embedding_dim_into_tran = (
    emb_dim * num_act_cats
)  # embedding dim size into transformer layers
tfmr_num_hidden = emb_dim * 4  # number of hidden units in transfomer linear layer
num_attn_heads = 8  # number of transformer attention heads
num_dec_layers = 4  # number of transformer decoder layers (main layers)
bptt = (
    sequence_length + 2
)  # back prop through time or sequence length, how long the sequence is that we are working with
num_epochs = 250

# tokenize, truncate, pad
(
    numer_trn_act_seqs,
    numer_tst_act_seqs,
    numer_trn_static_data,
    numer_tst_static_data,
) = process(trnseq, trnstat, tstseq, tststat, sequence_length + 1)

# check to ensure data is good: no rows dropped, and sequence/static
# frames remain index-aligned after numericalization.
assert (
    len(trnseq[~trnseq.index.isin(numer_trn_act_seqs.index)]) == 0
), "Records are being dropped during processing"
assert trnseq.index.identical(
    numer_trn_act_seqs.index
), "Mismatch between seq tokenized data and seq numericalized data."
assert numer_trn_static_data.index.identical(
    numer_trn_act_seqs.index
), "Mismatch between static data and seq data."

# Reshape into batches and move sequence tensors to the target device.
seq_data_trn = batchify_act_seqs(numer_trn_act_seqs, batch_sz).contiguous().to(device)
seq_data_tst = batchify_act_seqs(numer_tst_act_seqs, batch_sz).contiguous().to(device)
# Static data is truncated to a whole number of batches so it stays
# aligned with the batchified sequence tensor's first dimension.
static_data_trn = batchify_static_data(
    numer_trn_static_data[: seq_data_trn.shape[0] * batch_sz], batch_sz
)
static_data_tst = batchify_static_data(
    numer_tst_static_data[: seq_data_tst.shape[0] * batch_sz], batch_sz
)
# dims (mini_batch(batch_sz) x bptt x act_cats)
def gen_inp_data_set(seq_data: torch.Tensor, static_data: np.ndarray):
    """Generate (input, target, static) triples for next-step prediction.

    Advances through the 'group-of-sequences' dimension (dim 0) one group
    at a time. For each group the input drops the final timestep and the
    target drops the first, i.e. the target is the input shifted left by
    one (teacher forcing).

    Args:
        seq_data: tensor indexed as (group, timestep, ...); each group is
            sliced along its timestep dimension.
        static_data: array-like aligned with ``seq_data`` along dim 0
            (alignment is asserted earlier in this script).

    Yields:
        Tuple of (inp, target, static) for each group.
    """
    # NOTE: annotation fixed from `np.array` (a function) to `np.ndarray` (the type).
    for i in range(len(seq_data)):
        inp = seq_data[i, 0:-1]   # all timesteps but the last
        target = seq_data[i, 1:]  # all timesteps but the first (shifted by one)
        yield inp, target, static_data[i]
def validate(eval_model, seq_data, static_data) -> Tuple[float, float]:
    """Evaluate ``eval_model`` over the full held-out set.

    Runs every batch through the model under ``torch.no_grad``, accumulating
    loss and per-field top-3 accuracy, then logs per-batch and mean results.

    Returns:
        Tuple of (mean loss per group, mean accuracy across fields).
    """
    eval_model.eval()  # disable dropout/batchnorm updates for evaluation
    with torch.no_grad():
        val_loss = 0.0
        data_gen = gen_inp_data_set(seq_data, static_data)
        # data, tgt, static_data = next(data_gen)
        val_acc_l = []
        # NOTE(review): the loop variable `static_data` shadows the function
        # parameter of the same name; harmless here since the generator was
        # already constructed, but worth renaming.
        for data, tgt, static_data in data_gen:
            # Tokenize the raw static text for DistilBERT; relies on the
            # module-level `static_tokenizer` defined later in this script.
            static_data = static_tokenizer(
                static_data.tolist(), padding=True, truncation=True, return_tensors="pt"
            ).to(eval_model.device)
            preds = eval_model(data.to(eval_model.device), static_data)
            batch_loss = eval_model.loss(preds, tgt.to(eval_model.device))
            val_loss += batch_loss  # presumably a tensor — .item() is taken below
            # Top-3 accuracy for each independent categorical field
            # (uses module-level `fields`).
            val_acc = [
                field_accuracy(field, preds[idx], tgt[..., idx], 3)
                for idx, field in enumerate(fields)
            ]
            val_acc_l.append(val_acc)
            logger.debug(f"Val Acc: {val_acc}")
        # Average accuracy per field across all validation batches.
        avg_val_acc = [sum(i) / len(i) for i in zip(*val_acc_l)]
        logger.info(f"Mean Val Acc: {avg_val_acc}")
        # Pretty-print predictions vs. targets for the LAST batch only
        # (`preds`/`tgt` carry over from the final loop iteration).
        logger.info(
            [
                field_printer(field, preds[idx], tgt[..., idx])
                for idx, field in enumerate(fields)
            ]
        )
        return (val_loss.item() / len(seq_data), sum(avg_val_acc) / len(avg_val_acc))
# Torchtext-style field objects for the four independent categoricals,
# in the same order as the model's output heads.
fields = [TYPE, SUBTYPE, LVL, RESPGROUP]
total_trn_samples = len(trnseq)
# Wrap each field as an IndependentCategorical (name, vocab, sample count).
type_ = IndependentCategorical.from_torchtext_field("type_", TYPE, total_trn_samples)
subtype = IndependentCategorical.from_torchtext_field(
    "subtype", SUBTYPE, total_trn_samples
)
lvl = IndependentCategorical.from_torchtext_field("lvl", LVL, total_trn_samples)
respgroup = IndependentCategorical.from_torchtext_field(
    "respgroup", RESPGROUP, total_trn_samples
)
model = SAModel(
    sequence_length=sequence_length,
    categorical_embedding_dim=emb_dim,
    num_attn_heads=num_attn_heads,
    num_hidden=tfmr_num_hidden,
    num_transformer_layers=num_dec_layers,
    learning_rate=1e-5,
    independent_categoricals=[type_, subtype, lvl, respgroup],
    freeze_static_model_weights=False,  # fine-tune the DistilBERT branch too
    warmup_steps=(rec_len) * 1.5,  # about 1 epoch
    total_steps=num_epochs * (rec_len),
    device=device,
)
# Tokenizer for the static free-text inputs; vocab loaded from local weights dir.
static_tokenizer = DistilBertTokenizer.from_pretrained(
    "distilbert-base-uncased",
    return_tensors="pt",
    vocab_file="./distilbert_weights/vocab.txt",
)
if load_chkpnt:  # continue training from a saved checkpoint
    model_path = (
        "./saved_models/chkpnt-SIAG4-EP65-TRNLOSS0dot096-2020-08-29_14-01-27.ptm"
    )
    logger.info(f"Loading model from {model_path}")
    model.to(device)  # move model before loading optimizer
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    epoch = checkpoint["epoch"]  # resume epoch counter where the run left off
else:
    epoch = 0
    # Fresh run: persist the field vocabularies so inference can
    # numericalize inputs identically. Using `with` fixes the original's
    # leaked file handle (open() was never closed).
    with open(f"./saved_models/{model_name}_fields.pkl", "wb") as fields_file:
        pickle.dump([TYPE, SUBTYPE, LVL, RESPGROUP], fields_file)
model.to(device)
log_interval = 200  # batches between intra-epoch loss log lines
train_loss_record = []  # per-epoch mean training loss
val_acc_record = []  # per-epoch mean validation accuracy (drives checkpointing)
# --- Main training loop ---------------------------------------------------
for i in range(epoch, num_epochs):
    model.train()
    epoch_loss = 0.0
    counter = 0
    loss_tracker = []  # mean loss of each log_interval window this epoch
    data_gen = gen_inp_data_set(seq_data_trn, static_data_trn)
    for data, tgt, static_data_txt in data_gen:
        # Tokenize this batch's static text for the DistilBERT branch.
        static_data = static_tokenizer(
            static_data_txt.tolist(), padding=True, truncation=True, return_tensors="pt"
        ).to(device)
        batch_loss = model.learn(data.to(device), static_data, tgt.to(device))
        epoch_loss += batch_loss
        counter += 1
        if counter % log_interval == 0:
            # Record and log the mean loss over the last window, then reset.
            loss_tracker.append(epoch_loss / log_interval)
            logger.info(f"Epoch: {i}")
            logger.info(f"Record: {counter}/{rec_len}")
            logger.info(f"LR: {model.scheduler.get_last_lr()[0]}")
            logger.info(f"Loss: {(epoch_loss / log_interval):.3f}")
            epoch_loss = 0.0
    # NOTE(review): assumes at least log_interval batches per epoch;
    # otherwise loss_tracker is empty and this raises ZeroDivisionError.
    epoch_avg_loss = sum(loss_tracker) / len(loss_tracker)
    train_loss_record.append(epoch_avg_loss)
    val_loss, val_acc = validate(model, seq_data_tst, static_data_tst)
    val_acc_record.append(val_acc)
    logger.info(f"Validation Loss: {val_loss:.3f}")
    # Fixed typo in the original log message ("Valdation").
    logger.info(f"Validation Accuracy: {val_acc:.3f}")
    # save checkpoint only when validation accuracy beats all previous epochs
    if val_acc > max(val_acc_record[:-1], default=0):
        checkpoint_path = f"./saved_models/chkpnt-{model_name}-EP{i}-TRNLOSS{str(epoch_avg_loss)[:5].replace('.','dot')}-{datetime.datetime.today().strftime('%Y-%m-%d %H-%M-%S')}.ptm"
        # Cap path length and strip spaces from the timestamp.
        checkpoint_path = checkpoint_path[:260].replace(" ", "_")
        logger.info(f"Saving Checkpoint {checkpoint_path}")
        torch.save(
            {
                "epoch": i,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": model.optimizer.state_dict(),
                "scheduler_state_dict": model.scheduler.state_dict(),
            },
            checkpoint_path,
        )
# Persist final weights regardless of checkpointing.
torch.save(model.state_dict(), f"./saved_models/{model_name}.ptm")
def load_model(
    device: torch.device,
    chkpnt_path: str = "./saved_models/chkpnt-SIAG4-EP93-TRNLOSS7dot716-2020-08-30_10-36-06.ptm",
):
    """Rebuild the SAModel architecture and load trained weights for inference.

    The architecture hyperparameters come from this module's globals and must
    match those the checkpoint was trained with.

    Args:
        device: device to map the checkpoint onto and move the model to.
        chkpnt_path: checkpoint file to load; defaults to the previously
            hard-coded path, so existing callers are unaffected.

    Returns:
        The model in eval mode on ``device``.
    """
    # NOTE(review): learning_rate (1e-3) and freeze_static_model_weights (True)
    # differ from the training configuration above; irrelevant for inference
    # but worth confirming if this model is ever trained further.
    model = SAModel(
        sequence_length=sequence_length,
        categorical_embedding_dim=emb_dim,
        num_attn_heads=num_attn_heads,
        num_hidden=tfmr_num_hidden,
        num_transformer_layers=num_dec_layers,
        learning_rate=1e-3,
        independent_categoricals=[type_, subtype, lvl, respgroup],
        freeze_static_model_weights=True,
        device=device,
    )
    chkpnt = torch.load(chkpnt_path, map_location=device)
    model.load_state_dict(chkpnt["model_state_dict"])
    model.eval()  # disable dropout etc. for inference
    model.to(device)
    return model