import torch
from argparse import ArgumentParser
from torch.utils.data import DataLoader, Dataset

from .incremental_learning import Inc_Learning_Appr
from datasets.exemplars_dataset import ExemplarsDataset


class Appr(Inc_Learning_Appr):
    """Class implementing the joint baseline"""

    def __init__(self, model, device, nepochs=100, lr=0.05, lr_min=1e-4, lr_factor=3, lr_patience=5, clipgrad=10000,
                 momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, eval_on_train=False,
                 logger=None, exemplars_dataset=None, freeze_after=-1):
        super(Appr, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, momentum, wd,
                                   multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger,
                                   exemplars_dataset)
        self.trn_datasets = []
        self.val_datasets = []
        self.freeze_after = freeze_after

        # Joint training already sees all previous data at every task, so exemplar memory is redundant
        have_exemplars = self.exemplars_dataset.max_num_exemplars + self.exemplars_dataset.max_num_exemplars_per_class
        assert have_exemplars == 0, 'Warning: Joint does not use exemplars. Comment this line to force it.'

    @staticmethod
    def exemplars_dataset_class():
        return ExemplarsDataset

    @staticmethod
    def extra_parser(args):
        """Returns a parser containing the approach specific parameters"""
        parser = ArgumentParser()
        parser.add_argument('--freeze-after', default=-1, type=int, required=False,
                            help='Freeze model except heads after the specified task '
                                 '(-1: normal Incremental Joint Training, no freeze) (default=%(default)s)')
        return parser.parse_known_args(args)
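
    # Hypothetical usage sketch (not part of the code shown here): the caller
    # is expected to split the approach-specific flags off the raw argument
    # list before constructing the approach, e.g.
    #   appr_args, extra_args = Appr.extra_parser(['--freeze-after', '2'])
    #   appr = Appr(model, device, freeze_after=appr_args.freeze_after)
    # where `model` and `device` are assumed to come from the surrounding code.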

    def post_train_process(self, t, trn_loader):
        """Runs after training all the epochs of the task (after the train session)"""
        if self.freeze_after > -1 and t >= self.freeze_after:
            # freeze the backbone but keep all classification heads trainable
            self.model.freeze_all()
            for head in self.model.heads:
                for param in head.parameters():
                    param.requires_grad = True

    def train_loop(self, t, trn_loader, val_loader):
        """Contains the epochs loop"""

        # add new datasets to existing cumulative ones
        self.trn_datasets.append(trn_loader.dataset)
        self.val_datasets.append(val_loader.dataset)
        trn_dset = JointDataset(self.trn_datasets)
        val_dset = JointDataset(self.val_datasets)
        # rebuild the loaders so each task trains on the union of all data seen so far
        trn_loader = DataLoader(trn_dset,
                                batch_size=trn_loader.batch_size,
                                shuffle=True,
                                num_workers=trn_loader.num_workers,
                                pin_memory=trn_loader.pin_memory)
        val_loader = DataLoader(val_dset,
                                batch_size=val_loader.batch_size,
                                shuffle=False,
                                num_workers=val_loader.num_workers,
                                pin_memory=val_loader.pin_memory)

        # continue training as usual
        super().train_loop(t, trn_loader, val_loader)

    def train_epoch(self, t, trn_loader):
        """Runs a single epoch"""
        if self.freeze_after < 0 or t <= self.freeze_after:
            self.model.train()
            if self.fix_bn and t > 0:
                self.model.freeze_bn()
        else:
            # backbone is frozen after `freeze_after`; only the heads keep training
            self.model.eval()
            for head in self.model.heads:
                head.train()
        for images, targets in trn_loader:
            # Forward current model
            outputs = self.model(images.to(self.device))
            loss = self.criterion(t, outputs, targets.to(self.device))
            # Backward
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clipgrad)
            self.optimizer.step()

    def criterion(self, t, outputs, targets):
        """Returns the loss value"""
        # single cross-entropy over the concatenated outputs of all heads
        return torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets)
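

# A minimal sketch (not in the original file) of what `criterion` computes,
# assuming two hypothetical heads with 3 and 2 classes: the per-head logits
# are concatenated into one 5-way classifier, so `targets` hold global class
# indices across all tasks, e.g.
#   outputs = [torch.randn(8, 3), torch.randn(8, 2)]
#   loss = torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1),
#                                            torch.randint(5, (8,)))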


class JointDataset(Dataset):
    """Characterizes a dataset for PyTorch -- this dataset accumulates each task dataset incrementally"""

    def __init__(self, datasets):
        self.datasets = datasets
        self._len = sum([len(d) for d in self.datasets])

    def __len__(self):
        """Denotes the total number of samples"""
        return self._len

    def __getitem__(self, index):
        # walk through the task datasets, offsetting the index into the one that contains it
        for d in self.datasets:
            if len(d) <= index:
                index -= len(d)
            else:
                x, y = d[index]
                return x, y
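

if __name__ == '__main__':
    # Minimal self-check of JointDataset (a sketch, not part of the original
    # framework): two small TensorDatasets are accumulated and indexed as one
    # concatenated dataset, mirroring how train_loop merges task datasets.
    from torch.utils.data import TensorDataset

    task0 = TensorDataset(torch.zeros(4, 2), torch.zeros(4, dtype=torch.long))
    task1 = TensorDataset(torch.ones(6, 2), torch.ones(6, dtype=torch.long))
    joint = JointDataset([task0, task1])
    assert len(joint) == 10
    x, y = joint[7]  # index 7 falls into the second task (local index 3)
    print(x, y)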