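"""Base class for implementing incremental learning approaches.

Concrete approaches specialise this class through hooks such as
`extra_parser`, `exemplars_dataset_class`, `post_train_process` and
`criterion`.
"""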
import time
import torch
import numpy as np
from argparse import ArgumentParser

from loggers.exp_logger import ExperimentLogger
from datasets.exemplars_dataset import ExemplarsDataset


class Inc_Learning_Appr:
    """Basic class for implementing incremental learning approaches"""

    def __init__(self, model, device, nepochs=100, lr=0.05, lr_min=1e-4, lr_factor=3, lr_patience=5, clipgrad=10000,
                 momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False,
                 eval_on_train=False, logger: ExperimentLogger = None, exemplars_dataset: ExemplarsDataset = None):
        self.model = model
        self.device = device
        self.nepochs = nepochs
        self.lr = lr
        self.lr_min = lr_min
        self.lr_factor = lr_factor
        self.lr_patience = lr_patience
        self.clipgrad = clipgrad
        self.momentum = momentum
        self.wd = wd
        self.multi_softmax = multi_softmax
        self.logger = logger
        self.exemplars_dataset = exemplars_dataset
        self.warmup_epochs = wu_nepochs
        self.warmup_lr = lr * wu_lr_factor
        self.warmup_loss = torch.nn.CrossEntropyLoss()
        self.fix_bn = fix_bn
        self.eval_on_train = eval_on_train
        self.optimizer = None

    @staticmethod
    def extra_parser(args):
        """Returns a parser containing the approach-specific parameters"""
        parser = ArgumentParser()
        return parser.parse_known_args(args)

    @staticmethod
    def exemplars_dataset_class():
        """Returns an exemplar dataset to use during training if the approach needs it

        :return: ExemplarsDataset class or None
        """
        return None

    def _get_optimizer(self):
        """Returns the optimizer"""
        return torch.optim.SGD(self.model.parameters(), lr=self.lr, weight_decay=self.wd, momentum=self.momentum)

    def train(self, t, trn_loader, val_loader):
        """Main train structure"""
        self.pre_train_process(t, trn_loader)
        self.train_loop(t, trn_loader, val_loader)
        self.post_train_process(t, trn_loader)

    def pre_train_process(self, t, trn_loader):
        """Runs before training all epochs of the task (before the train session)"""
        # Warm-up phase: train only the newest head for a few epochs
        if self.warmup_epochs and t > 0:
            self.optimizer = torch.optim.SGD(self.model.heads[-1].parameters(), lr=self.warmup_lr)
            # Loop epochs -- train warm-up head
            for e in range(self.warmup_epochs):
                warmupclock0 = time.time()
                self.model.heads[-1].train()
                for images, targets in trn_loader:
                    outputs = self.model(images.to(self.device))
                    loss = self.warmup_loss(outputs[t], targets.to(self.device) - self.model.task_offset[t])
                    self.optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.model.heads[-1].parameters(), self.clipgrad)
                    self.optimizer.step()
                warmupclock1 = time.time()
                # Evaluate the warmed-up head on the training set
                with torch.no_grad():
                    total_loss, total_acc_taw = 0, 0
                    self.model.eval()
                    for images, targets in trn_loader:
                        outputs = self.model(images.to(self.device))
                        loss = self.warmup_loss(outputs[t], targets.to(self.device) - self.model.task_offset[t])
                        # Task-Aware prediction: pick the head that owns each target class
                        pred = torch.zeros_like(targets.to(self.device))
                        for m in range(len(pred)):
                            this_task = (self.model.task_cls.cumsum(0) <= targets[m]).sum()
                            pred[m] = outputs[this_task][m].argmax() + self.model.task_offset[this_task]
                        hits_taw = (pred == targets.to(self.device)).float()
                        total_loss += loss.item() * len(targets)
                        total_acc_taw += hits_taw.sum().item()
                total_num = len(trn_loader.dataset.labels)
                trn_loss, trn_acc = total_loss / total_num, total_acc_taw / total_num
                warmupclock2 = time.time()
                print('| Warm-up Epoch {:3d}, time={:5.1f}s/{:5.1f}s | Train: loss={:.3f}, TAw acc={:5.1f}% |'.format(
                    e + 1, warmupclock1 - warmupclock0, warmupclock2 - warmupclock1, trn_loss, 100 * trn_acc))
                self.logger.log_scalar(task=t, iter=e + 1, name="loss", value=trn_loss, group="warmup")
                self.logger.log_scalar(task=t, iter=e + 1, name="acc", value=100 * trn_acc, group="warmup")

    def train_loop(self, t, trn_loader, val_loader):
        """Contains the epochs loop"""
        lr = self.lr
        best_loss = np.inf
        patience = self.lr_patience
        best_model = self.model.get_copy()

        self.optimizer = self._get_optimizer()

        # Loop epochs
        for e in range(self.nepochs):
            # Train
            clock0 = time.time()
            self.train_epoch(t, trn_loader)
            clock1 = time.time()
            if self.eval_on_train:
                train_loss, train_acc, _ = self.eval(t, trn_loader)
                clock2 = time.time()
                print('| Epoch {:3d}, time={:5.1f}s/{:5.1f}s | Train: loss={:.3f}, TAw acc={:5.1f}% |'.format(
                    e + 1, clock1 - clock0, clock2 - clock1, train_loss, 100 * train_acc), end='')
                self.logger.log_scalar(task=t, iter=e + 1, name="loss", value=train_loss, group="train")
                self.logger.log_scalar(task=t, iter=e + 1, name="acc", value=100 * train_acc, group="train")
            else:
                print('| Epoch {:3d}, time={:5.1f}s | Train: skip eval |'.format(e + 1, clock1 - clock0), end='')

            # Valid
            clock3 = time.time()
            valid_loss, valid_acc, _ = self.eval(t, val_loader)
            clock4 = time.time()
            print(' Valid: time={:5.1f}s loss={:.3f}, TAw acc={:5.1f}% |'.format(
                clock4 - clock3, valid_loss, 100 * valid_acc), end='')
            self.logger.log_scalar(task=t, iter=e + 1, name="loss", value=valid_loss, group="valid")
            self.logger.log_scalar(task=t, iter=e + 1, name="acc", value=100 * valid_acc, group="valid")

            # Adapt learning rate - patience scheme - early stopping regularization
            if valid_loss < best_loss:
                # if the loss goes down, keep it as the best model and end line with a star ( * )
                best_loss = valid_loss
                best_model = self.model.get_copy()
                patience = self.lr_patience
                print(' *', end='')
            else:
                # if the loss does not go down, decrease patience
                patience -= 1
                if patience <= 0:
                    # if it runs out of patience, reduce the learning rate
                    lr /= self.lr_factor
                    print(' lr={:.1e}'.format(lr), end='')
                    if lr < self.lr_min:
                        # if the lr decreases below minimum, stop the training session
                        print()
                        break
                    # reset patience and recover best model so far to continue training
                    patience = self.lr_patience
                    self.optimizer.param_groups[0]['lr'] = lr
                    self.model.set_state_dict(best_model)
            self.logger.log_scalar(task=t, iter=e + 1, name="patience", value=patience, group="train")
            self.logger.log_scalar(task=t, iter=e + 1, name="lr", value=lr, group="train")
            print()
        self.model.set_state_dict(best_model)

    def post_train_process(self, t, trn_loader):
        """Runs after training all the epochs of the task (after the train session)"""
        pass

    def train_epoch(self, t, trn_loader):
        """Runs a single epoch"""
        self.model.train()
        if self.fix_bn and t > 0:
            self.model.freeze_bn()
        for images, targets in trn_loader:
            # Forward current model
            outputs = self.model(images.to(self.device))
            loss = self.criterion(t, outputs, targets.to(self.device))
            # Backward
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clipgrad)
            self.optimizer.step()

    def eval(self, t, val_loader):
        """Contains the evaluation code"""
        with torch.no_grad():
            total_loss, total_acc_taw, total_acc_tag, total_num = 0, 0, 0, 0
            self.model.eval()
            for images, targets in val_loader:
                # Forward current model
                outputs = self.model(images.to(self.device))
                loss = self.criterion(t, outputs, targets.to(self.device))
                hits_taw, hits_tag = self.calculate_metrics(outputs, targets)
                # Log
                total_loss += loss.item() * len(targets)
                total_acc_taw += hits_taw.sum().item()
                total_acc_tag += hits_tag.sum().item()
                total_num += len(targets)
        return total_loss / total_num, total_acc_taw / total_num, total_acc_tag / total_num

    def calculate_metrics(self, outputs, targets):
        """Contains the main Task-Aware and Task-Agnostic metrics"""
        pred = torch.zeros_like(targets.to(self.device))
        # Task-Aware Multi-Head: use the head of the task each target belongs to
        for m in range(len(pred)):
            this_task = (self.model.task_cls.cumsum(0) <= targets[m]).sum()
            pred[m] = outputs[this_task][m].argmax() + self.model.task_offset[this_task]
        hits_taw = (pred == targets.to(self.device)).float()
        # Task-Agnostic Multi-Head: predict over the concatenation of all heads
        if self.multi_softmax:
            # apply a per-head log-softmax so the heads' outputs are comparable
            outputs = [torch.nn.functional.log_softmax(output, dim=1) for output in outputs]
        pred = torch.cat(outputs, dim=1).argmax(1)
        hits_tag = (pred == targets.to(self.device)).float()
        return hits_taw, hits_tag

    def criterion(self, t, outputs, targets):
        """Returns the loss value"""
        return torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t])
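

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration, not part of the framework). The
# `_ToyNet` and `_NoopLogger` classes below are hypothetical stand-ins for the
# framework's multi-head network and ExperimentLogger: they only expose the
# pieces of the interface this class actually touches (`heads`, `task_cls`,
# `task_offset`, `get_copy`, `set_state_dict`, `freeze_bn`, `log_scalar`).
# With those assumptions, training on a single random task runs end to end.
if __name__ == '__main__':
    from copy import deepcopy
    from torch.utils.data import DataLoader, TensorDataset

    class _ToyNet(torch.nn.Module):
        """Hypothetical stand-in for the framework's multi-head network."""

        def __init__(self, num_classes=4):
            super().__init__()
            self.backbone = torch.nn.Linear(8, 16)
            self.heads = torch.nn.ModuleList([torch.nn.Linear(16, num_classes)])
            self.task_cls = torch.tensor([num_classes])  # classes per task
            self.task_offset = torch.tensor([0])         # first class id of each task

        def forward(self, x):
            # one output tensor per head, as Inc_Learning_Appr expects
            return [head(torch.relu(self.backbone(x))) for head in self.heads]

        def get_copy(self):
            return deepcopy(self.state_dict())

        def set_state_dict(self, state_dict):
            self.load_state_dict(deepcopy(state_dict))

        def freeze_bn(self):
            pass  # no batch-norm layers in this toy model

    class _NoopLogger:
        """Hypothetical stand-in for ExperimentLogger."""

        def log_scalar(self, task, iter, name, value, group):
            pass

    # A single random task: 64 samples, 8 features, 4 classes.
    x = torch.randn(64, 8)
    y = torch.randint(0, 4, (64,))
    loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)

    appr = Inc_Learning_Appr(_ToyNet(), device='cpu', nepochs=2, logger=_NoopLogger())
    appr.train(t=0, trn_loader=loader, val_loader=loader)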