# ex_pl_audioset.py (forked from fschmid56/EfficientAT)
import wandb
import numpy as np
import os
from torch import autocast
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import argparse
from sklearn import metrics
from contextlib import nullcontext
import torch.nn as nn
import torch.nn.functional as F
from torch.hub import download_url_to_file
import pickle
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import LearningRateMonitor
from datasets.audioset import get_test_set, get_full_training_set, get_ft_weighted_sampler
from models.mn.model import get_model as get_mobilenet
from models.dymn.model import get_model as get_dymn
from models.ensemble import get_ensemble_model
from models.preprocess import AugmentMelSTFT
from helpers.init import worker_init_fn
from helpers.utils import NAME_TO_WIDTH, exp_warmup_linear_down, mixup
preds_url = \
"https://github.com/fschmid56/EfficientAT/releases/download/v0.0.1/passt_enemble_logits_mAP_495.npy"
fname_to_index_url = "https://github.com/fschmid56/EfficientAT/releases/download/v0.0.1/fname_to_index.pkl"
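
# Example invocation (flags as defined in the argparse section at the bottom of this
# file; the values shown are illustrative, not prescriptive):
#   python ex_pl_audioset.py --pretrained --model_name=mn10_as --kd_lambda=0.1 --num_devices=1
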
class PLModule(pl.LightningModule):
def __init__(self, config):
super().__init__()
self.config = config
# model to preprocess waveform to mel spectrograms
self.mel = AugmentMelSTFT(n_mels=config.n_mels,
sr=config.resample_rate,
win_length=config.window_size,
hopsize=config.hop_size,
n_fft=config.n_fft,
freqm=config.freqm,
timem=config.timem,
fmin=config.fmin,
fmax=config.fmax,
fmin_aug_range=config.fmin_aug_range,
fmax_aug_range=config.fmax_aug_range
)
# load prediction model
model_name = config.model_name
pretrained_name = model_name if config.pretrained else None
width = NAME_TO_WIDTH(model_name) if model_name and config.pretrained else config.model_width
if model_name.startswith("dymn"):
model = get_dymn(width_mult=width, pretrained_name=pretrained_name,
strides=config.strides, pretrain_final_temp=config.pretrain_final_temp)
else:
model = get_mobilenet(width_mult=width, pretrained_name=pretrained_name,
strides=config.strides, head_type=config.head_type, se_dims=config.se_dims)
self.model = model
# prepare ingredients for knowledge distillation
assert 0 <= config.kd_lambda <= 1, "Lambda for Knowledge Distillation must be between 0 and 1."
self.distillation_loss = nn.BCEWithLogitsLoss(reduction="none")
# load stored teacher predictions
if not os.path.isfile(config.teacher_preds):
# download file
print("Download teacher predictions...")
download_url_to_file(preds_url, config.teacher_preds)
print(f"Load teacher predictions from file {config.teacher_preds}")
teacher_preds = np.load(config.teacher_preds)
teacher_preds = torch.from_numpy(teacher_preds).float()
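        # convert teacher logits to probabilities; a temperature > 1 softens the soft targets (standard KD practice)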
teacher_preds = torch.sigmoid(teacher_preds / config.temperature)
teacher_preds.requires_grad = False
self.teacher_preds = teacher_preds
if not os.path.isfile(config.fname_to_index):
print("Download filename to teacher prediction index dictionary...")
download_url_to_file(fname_to_index_url, config.fname_to_index)
with open(config.fname_to_index, 'rb') as f:
fname_to_index = pickle.load(f)
self.fname_to_index = fname_to_index
self.distributed_mode = config.num_devices > 1
self.training_step_outputs = []
self.validation_step_outputs = []
def mel_forward(self, x):
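        # merge batch and channel dims so the STFT runs over a flat batch of 1D waveforms, then restore the shape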
old_shape = x.size()
x = x.reshape(-1, old_shape[2])
x = self.mel(x)
x = x.reshape(old_shape[0], old_shape[1], x.shape[1], x.shape[2])
return x
def forward(self, x):
"""
:param x: batch of raw audio signals (waveforms)
        :return: model output (a tuple whose first element holds the clip-level logits, unpacked as 'y_hat, _' elsewhere)
"""
x = self.mel_forward(x)
x = self.model(x)
return x
def configure_optimizers(self):
"""
        This is how PyTorch Lightning expects optimizers and learning rate schedulers to be defined.
The specified items are used automatically in the optimization loop (no need to call optimizer.step() yourself).
:return: dict containing optimizer and learning rate scheduler
"""
if self.config.adamw:
optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.config.max_lr,
weight_decay=self.config.weight_decay)
else:
optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.max_lr,
weight_decay=self.config.weight_decay)
        # phases of the lr schedule: exponential warm-up, constant at max_lr, linear ramp-down to last_lr_value * max_lr, constant fine-tune
schedule_lambda = \
exp_warmup_linear_down(self.config.warm_up_len, self.config.ramp_down_len, self.config.ramp_down_start,
self.config.last_lr_value)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, schedule_lambda)
return {
'optimizer': optimizer,
'lr_scheduler': lr_scheduler
}
def on_train_epoch_start(self):
# in case of DyMN: update DyConv temperature
if hasattr(self.model, "update_params"):
self.model.update_params(self.current_epoch)
def training_step(self, train_batch, batch_idx):
"""
        :param train_batch: one batch from the train dataloader
        :param batch_idx: index of the batch within the current epoch
        :return: the loss used to update the model parameters; per-step metrics are collected in
            'self.training_step_outputs' and logged in 'on_train_epoch_end'
"""
x, f, y, i = train_batch
bs = x.size(0)
x = self.mel_forward(x)
rn_indices, lam = None, None
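        # mixup augmentation: convexly combine random pairs of spectrograms and, below, their target vectors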
if self.config.mixup_alpha:
rn_indices, lam = mixup(bs, self.config.mixup_alpha)
lam = lam.to(x.device)
x = x * lam.reshape(bs, 1, 1, 1) + \
x[rn_indices] * (1. - lam.reshape(bs, 1, 1, 1))
y_hat, _ = self.model(x)
y_mix = y * lam.reshape(bs, 1) + y[rn_indices] * (1. - lam.reshape(bs, 1))
samples_loss = F.binary_cross_entropy_with_logits(y_hat, y_mix, reduction="none")
else:
y_hat, _ = self.model(x)
samples_loss = F.binary_cross_entropy_with_logits(y_hat, y, reduction="none")
# hard label loss
label_loss = samples_loss.mean()
# distillation loss
if self.config.kd_lambda > 0:
# fetch the correct index in 'teacher_preds' for given filename
# insert -1 for files not in fname_to_index (proportion of files successfully downloaded from
# YouTube can vary for AudioSet)
indices = torch.tensor(
[self.fname_to_index[fname] if fname in self.fname_to_index else -1 for fname in f], dtype=torch.int64
)
# get indices of files we could not find the teacher predictions for
unknown_indices = indices == -1
y_soft_teacher = self.teacher_preds[indices]
y_soft_teacher = y_soft_teacher.to(y_hat.device).type_as(y_hat)
if self.config.mixup_alpha:
soft_targets_loss = \
self.distillation_loss(y_hat, y_soft_teacher).mean(dim=1) * lam.reshape(bs) + \
self.distillation_loss(y_hat, y_soft_teacher[rn_indices]).mean(dim=1) \
* (1. - lam.reshape(bs))
else:
                soft_targets_loss = self.distillation_loss(y_hat, y_soft_teacher).mean(dim=1)
# zero out loss for samples we don't have teacher predictions for
soft_targets_loss[unknown_indices] = soft_targets_loss[unknown_indices] * 0
soft_targets_loss = soft_targets_loss.mean()
# weighting losses
label_loss = self.config.kd_lambda * label_loss
soft_targets_loss = (1 - self.config.kd_lambda) * soft_targets_loss
else:
soft_targets_loss = torch.tensor(0., device=label_loss.device, dtype=label_loss.dtype)
# total loss is sum of lambda-weighted label and distillation loss
loss = label_loss + soft_targets_loss
results = {"loss": loss.detach().cpu(), "label_loss": label_loss.detach().cpu(),
"kd_loss": soft_targets_loss.detach().cpu()}
self.training_step_outputs.append(results)
return loss
def on_train_epoch_end(self):
"""
:return: a dict containing the metrics you want to log to Weights and Biases
"""
avg_loss = torch.stack([x['loss'] for x in self.training_step_outputs]).mean()
avg_label_loss = torch.stack([x['label_loss'] for x in self.training_step_outputs]).mean()
avg_kd_loss = torch.stack([x['kd_loss'] for x in self.training_step_outputs]).mean()
self.log_dict({'train/loss': torch.as_tensor(avg_loss).cuda(),
'train/label_loss': torch.as_tensor(avg_label_loss).cuda(),
'train/kd_loss': torch.as_tensor(avg_kd_loss).cuda()
}, sync_dist=True)
self.training_step_outputs.clear()
def validation_step(self, val_batch, batch_idx):
x, _, y = val_batch
x = self.mel_forward(x)
y_hat, _ = self.model(x)
loss = F.binary_cross_entropy_with_logits(y_hat, y)
preds = torch.sigmoid(y_hat)
results = {'val_loss': loss, "preds": preds, "targets": y}
results = {k: v.cpu() for k, v in results.items()}
self.validation_step_outputs.append(results)
def on_validation_epoch_end(self):
loss = torch.stack([x['val_loss'] for x in self.validation_step_outputs])
preds = torch.cat([x['preds'] for x in self.validation_step_outputs], dim=0)
targets = torch.cat([x['targets'] for x in self.validation_step_outputs], dim=0)
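        # gather results from all devices so mAP and ROC-AUC are computed over the full evaluation set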
all_preds = self.all_gather(preds).reshape(-1, preds.shape[-1]).cpu().float().numpy()
all_targets = self.all_gather(targets).reshape(-1, targets.shape[-1]).cpu().float().numpy()
all_loss = self.all_gather(loss).reshape(-1,)
try:
average_precision = metrics.average_precision_score(
all_targets, all_preds, average=None)
except ValueError:
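            # AudioSet has 527 classes; fall back to NaN per class if the metric cannot be computed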
average_precision = np.array([np.nan] * 527)
try:
roc = metrics.roc_auc_score(all_targets, all_preds, average=None)
except ValueError:
roc = np.array([np.nan] * 527)
logs = {'val/loss': torch.as_tensor(all_loss).mean().cuda(),
'val/ap': torch.as_tensor(average_precision).mean().cuda(),
'val/roc': torch.as_tensor(roc).mean().cuda()
}
self.log_dict(logs, sync_dist=False)
self.validation_step_outputs.clear()
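
# Hypothetical usage sketch of PLModule outside the Trainer (assumes 'args' has been
# parsed as below and that the model returns a (logits, features) tuple, as unpacked
# in training_step; shapes are illustrative):
#   module = PLModule(args)
#   logits, _ = module(torch.randn(2, 1, 320000))  # 2 mono clips, 10 s at 32 kHz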
def train(config):
    # train models on AudioSet, either from scratch or starting from ImageNet pre-trained weights
    # the PaSST ensemble (https://github.com/kkoutini/PaSST), stored in 'resources/passt_enemble_logits_mAP_495.npy',
    # can be used as a teacher for knowledge distillation
    # logging is done using wandb
wandb_logger = WandbLogger(
project="EfficientAudioTagging",
notes="Training efficient audio tagging models on AudioSet using Knowledge Distillation.",
        tags=["AudioSet", "Audio Tagging", "Knowledge Distillation"],
config=config,
name=config.experiment_name
)
train_dl = DataLoader(dataset=get_full_training_set(resample_rate=config.resample_rate,
roll=config.roll,
wavmix=config.wavmix,
gain_augment=config.gain_augment),
sampler=get_ft_weighted_sampler(config.epoch_len), # sampler important to balance classes
worker_init_fn=worker_init_fn,
num_workers=config.num_workers,
batch_size=config.batch_size)
# eval dataloader
eval_dl = DataLoader(dataset=get_test_set(resample_rate=config.resample_rate),
worker_init_fn=worker_init_fn,
num_workers=config.num_workers,
batch_size=config.batch_size)
    # create the PyTorch Lightning module
pl_module = PLModule(config)
# create monitor to keep track of learning rate - we want to check the behaviour of our learning rate schedule
lr_monitor = LearningRateMonitor(logging_interval='epoch')
    # create the PyTorch Lightning trainer, specifying the number of epochs to train, the logger,
    # the kind of device(s) to train on and possible callbacks
trainer = pl.Trainer(max_epochs=config.n_epochs,
logger=wandb_logger,
accelerator='auto',
devices=config.num_devices,
precision=config.precision,
num_sanity_val_steps=0,
callbacks=[lr_monitor])
# start training and validation for the specified number of epochs
trainer.fit(pl_module, train_dl, eval_dl)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train efficient audio tagging models on AudioSet using Knowledge Distillation.')
# general
parser.add_argument('--experiment_name', type=str, default="AudioSet")
parser.add_argument('--batch_size', type=int, default=120)
parser.add_argument('--num_workers', type=int, default=12)
parser.add_argument('--num_devices', type=int, default=4)
# evaluation
# if ensemble is set, 'model_name' is not used
parser.add_argument('--ensemble', nargs='+', default=[])
parser.add_argument('--model_name', type=str, default="mn10_as") # used also for training
parser.add_argument('--cuda', action='store_true', default=False)
# training
parser.add_argument('--precision', type=int, default=16)
parser.add_argument('--pretrained', action='store_true', default=False)
parser.add_argument('--pretrain_final_temp', type=float, default=30.0) # for DyMN
parser.add_argument('--model_width', type=float, default=1.0)
parser.add_argument('--strides', nargs=4, default=[2, 2, 2, 2], type=int)
parser.add_argument('--head_type', type=str, default="mlp")
parser.add_argument('--se_dims', type=str, default="c")
parser.add_argument('--n_epochs', type=int, default=200)
parser.add_argument('--mixup_alpha', type=float, default=0.3)
parser.add_argument('--epoch_len', type=int, default=100000)
parser.add_argument('--roll', action='store_true', default=False)
parser.add_argument('--wavmix', action='store_true', default=False)
parser.add_argument('--gain_augment', type=int, default=0)
# optimizer
parser.add_argument('--adamw', action='store_true', default=False)
parser.add_argument('--weight_decay', type=float, default=0.0001)
# lr schedule
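    # (lengths below are in epochs; Lightning steps a LambdaLR returned this way once per
    # epoch by default; last_lr_value is the final multiplier applied to max_lr)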
parser.add_argument('--max_lr', type=float, default=0.003)
parser.add_argument('--warm_up_len', type=int, default=8)
parser.add_argument('--ramp_down_start', type=int, default=80)
parser.add_argument('--ramp_down_len', type=int, default=95)
parser.add_argument('--last_lr_value', type=float, default=0.01)
# knowledge distillation
parser.add_argument('--teacher_preds', type=str,
default=os.path.join("resources", "passt_enemble_logits_mAP_495.npy"))
parser.add_argument('--fname_to_index', type=str,
default=os.path.join("resources", "fname_to_index.pkl"))
parser.add_argument('--temperature', type=float, default=1)
parser.add_argument('--kd_lambda', type=float, default=0.1)
# preprocessing
parser.add_argument('--resample_rate', type=int, default=32000)
parser.add_argument('--window_size', type=int, default=800)
parser.add_argument('--hop_size', type=int, default=320)
parser.add_argument('--n_fft', type=int, default=1024)
parser.add_argument('--n_mels', type=int, default=128)
parser.add_argument('--freqm', type=int, default=0)
parser.add_argument('--timem', type=int, default=0)
parser.add_argument('--fmin', type=int, default=0)
parser.add_argument('--fmax', type=int, default=None)
parser.add_argument('--fmin_aug_range', type=int, default=10)
parser.add_argument('--fmax_aug_range', type=int, default=2000)
args = parser.parse_args()
train(args)