######################################
#           Kaihua Tang
######################################

import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn.functional as F

import utils.general_utils as utils
from data.dataloader import get_loader
from utils.checkpoint_utils import Checkpoint
from utils.training_utils import *
from utils.test_loader import test_loader
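

# train_tde drives the TDE training stage: in addition to a standard supervised
# loop, it maintains an exponential moving average of the batch-mean feature
# vector (`self.embed`) and saves it with every checkpoint. In the TDE method
# ("Total Direct Effect", Tang et al., NeurIPS 2020), such an averaged feature
# is what the classifier uses at inference time to remove the harmful momentum
# direction; that reading is inferred from the file name and the paper, and is
# not stated in this file itself.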
class train_tde():
    def __init__(self, args, config, logger, eval=False):
        # ============================================================================
        # create model
        logger.info('=====> Model construction from: ' + str(config['networks']['type']))
        model_type = config['networks']['type']
        model_file = config['networks'][model_type]['def_file']
        model_args = config['networks'][model_type]['params']

        logger.info('=====> Classifier construction from: ' + str(config['classifiers']['type']))
        classifier_type = config['classifiers']['type']
        classifier_file = config['classifiers'][classifier_type]['def_file']
        classifier_args = config['classifiers'][classifier_type]['params']

        model = utils.source_import(model_file).create_model(**model_args)
        classifier = utils.source_import(classifier_file).create_model(**classifier_args)

        model = nn.DataParallel(model).cuda()
        classifier = nn.DataParallel(classifier).cuda()

        # other initialization
        self.config = config
        self.logger = logger
        self.model = model
        self.classifier = classifier
        # create_optimizer / create_scheduler / create_loss come from the
        # wildcard import of utils.training_utils above
        self.optimizer = create_optimizer(model, classifier, logger, config)
        self.scheduler = create_scheduler(self.optimizer, logger, config)
        self.eval = eval
        self.training_opt = config['training_opt']
        self.checkpoint = Checkpoint(config)

        # get dataloader
        self.logger.info('=====> Get train dataloader')
        self.train_loader = get_loader(config, 'train', config['dataset']['testset'], logger)

        # get loss
        self.loss_fc = create_loss(logger, config, self.train_loader)

        # set eval
        if self.eval:
            test_func = test_loader(config)
            self.testing = test_func(config, logger, model, classifier, val=True)

    def update_embed(self, embed, input):
        # embed is only updated during training: exponential moving average of
        # the batch-mean feature with momentum 0.995
        assert len(input.shape) == 2
        with torch.no_grad():
            embed = embed * 0.995 + (1 - 0.995) * input.mean(0).view(-1)
        return embed
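
    # Note (explanatory, not from the original file): the recurrence above is
    # the standard exponential moving average
    #     embed_t = m * embed_{t-1} + (1 - m) * mean(features_t),  m = 0.995,
    # so a batch seen t steps ago contributes with weight (1 - m) * m**t and
    # the estimate tracks the long-run mean feature of the training stream.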

    def run(self):
        # Start Training
        self.logger.info('=====> Start TDE Training')

        # init embed (hard-coded to the backbone feature dimension, 2048,
        # e.g. a ResNet-50-style network)
        self.embed = torch.zeros(2048).cuda()

        # run epoch
        for epoch in range(self.training_opt['num_epochs']):
            self.logger.info('------------ Start Epoch {} -----------'.format(epoch))
            # preprocess for each epoch
            total_batch = len(self.train_loader)

            for step, (inputs, labels, _, _) in enumerate(self.train_loader):
                self.optimizer.zero_grad()
                # additional inputs
                inputs, labels = inputs.cuda(), labels.cuda()
                add_inputs = {}

                features = self.model(inputs)
                predictions = self.classifier(features, add_inputs)

                # update embed during training
                self.embed = self.update_embed(self.embed, features)

                # calculate loss
                loss = self.loss_fc(predictions, labels)
                iter_info_print = {self.training_opt['loss']: loss.sum().item()}
                loss.backward()
                self.optimizer.step()

                # calculate accuracy
                accuracy = (predictions.max(1)[1] == labels).sum().float() / predictions.shape[0]

                # log information
                iter_info_print.update({'Accuracy': accuracy.item(), 'Loss': loss.sum().item(), 'Poke LR': float(self.optimizer.param_groups[0]['lr'])})
                self.logger.info_iter(epoch, step, total_batch, iter_info_print, self.config['logger_opt']['print_iter'])

                first_batch = (epoch == 0) and (step == 0)
                if first_batch or (self.config['logger_opt']['print_grad'] and step % 1000 == 0):
                    utils.print_grad(self.classifier.named_parameters())
                    utils.print_grad(self.model.named_parameters())

            # evaluation on validation set
            if self.eval:
                val_acc = self.testing.run_val(epoch, self.embed)
            else:
                val_acc = 0.0

            # checkpoint (the moving-average embed is saved alongside the weights)
            add_dict = {}
            add_dict['embed'] = self.embed.cpu()
            self.logger.info('Embed Mean: {}'.format(self.embed.mean().item()))
            self.checkpoint.save(self.model, self.classifier, epoch, self.logger, acc=val_acc, add_dict=add_dict)

            # update scheduler
            self.scheduler.step()

        # save best model path
        self.checkpoint.save_best_model_path(self.logger)
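

# Usage sketch (hypothetical): names such as `load_config` and `create_logger`
# are assumptions about the surrounding project, not APIs defined in this file.
#
#     config = load_config('configs/tde.yaml')   # must supply the keys read in __init__
#     logger = create_logger(config)             # must expose info() and info_iter()
#     trainer = train_tde(args, config, logger, eval=True)
#     trainer.run()                              # trains, validates, and checkpoints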