-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathconformer.py
457 lines (386 loc) · 19.9 KB
/
conformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
import copy
import math
import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
def multi_head_sep_attention_forward(query,                           # type: Tensor
                                     key,                             # type: Tensor
                                     value,                           # type: Tensor
                                     embed_dim_to_check,              # type: int
                                     num_heads,                       # type: int
                                     in_proj_weight,                  # type: Tensor
                                     in_proj_bias,                    # type: Tensor
                                     bias_k,                          # type: Optional[Tensor]
                                     bias_v,                          # type: Optional[Tensor]
                                     add_zero_attn,                   # type: bool
                                     dropout_p,                       # type: float
                                     out_proj_weight,                 # type: Tensor
                                     out_proj_bias,                   # type: Tensor
                                     training=True,                   # type: bool
                                     key_padding_mask=None,           # type: Optional[Tensor]
                                     need_weights=True,               # type: bool
                                     attn_mask=None,                  # type: Optional[Tensor]
                                     use_separate_proj_weight=False,  # type: bool
                                     q_proj_weight=None,              # type: Optional[Tensor]
                                     k_proj_weight=None,              # type: Optional[Tensor]
                                     v_proj_weight=None,              # type: Optional[Tensor]
                                     static_k=None,                   # type: Optional[Tensor]
                                     static_v=None,                   # type: Optional[Tensor]
                                     shared_qk=False,                 # type: bool
                                     sep=False                        # type: bool
                                     ):
    """Multi-head attention forward pass with optional shared Q/K and "separable" mode.

    Args:
        query: tensor of shape (tgt_len, bsz, embed_dim).
        key: tensor whose first dimension is the source length; must have the
            same size as ``value``.
        value: tensor with the same size as ``key``.
        embed_dim_to_check: expected size of the last dimension of ``query``.
        num_heads: number of attention heads; must divide ``embed_dim``.
        in_proj_weight: packed input projection weight. It is split into three
            chunks (q, k, v), or two chunks (shared q/k, v) when ``shared_qk``
            is set on the self-attention path.
        in_proj_bias: packed input projection bias, or ``None``.
        bias_k, bias_v: optional extra key/value rows appended along the
            source dimension. Both must be given or both must be ``None``.
        add_zero_attn: if true, append an all-zero key/value position.
        dropout_p: dropout probability applied to the attention weights.
        out_proj_weight, out_proj_bias: output projection parameters.
        training: whether dropout is active.
        key_padding_mask: optional mask of shape (bsz, src_len); positions
            where it is true are filled with ``-inf`` before the softmax.
        need_weights: if true (and ``sep`` is false), also return the attention
            weights averaged over heads.
        attn_mask: optional additive mask of shape (tgt_len, src_len) or
            (bsz * num_heads, tgt_len, src_len). It is validated in both modes
            but is NOT applied when ``sep`` is true.
        use_separate_proj_weight: use ``q_proj_weight`` / ``k_proj_weight`` /
            ``v_proj_weight`` instead of the packed ``in_proj_weight``.
        q_proj_weight, k_proj_weight, v_proj_weight: separate projection
            weights, required when ``use_separate_proj_weight`` is true.
        static_k, static_v: optional pre-projected key/value tensors of shape
            (bsz * num_heads, src_len, head_dim) that replace the computed ones.
        shared_qk: on the packed self-attention path, use the key projection
            as the query as well.
        sep: if true, compute ``softmax(q) @ (softmax(k^T) @ v)`` instead of
            ``softmax(q @ k^T) @ v`` and never return attention weights.

    Returns:
        A tuple ``(attn_output, attn_weights)``. ``attn_output`` has shape
        (tgt_len, bsz, embed_dim). ``attn_weights`` has shape
        (bsz, tgt_len, src_len) when ``need_weights`` is true and ``sep`` is
        false, otherwise it is ``None``.

    Raises:
        RuntimeError: if ``attn_mask`` has an unsupported rank or size.
        AssertionError: if any of the shape/consistency checks fail.
    """
    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == embed_dim_to_check
    assert key.size() == value.size()

    head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
    scaling = float(head_dim) ** -0.5

    if not use_separate_proj_weight:
        if torch.equal(query, key) and torch.equal(key, value):
            # self-attention
            if shared_qk:
                # Packed weight holds only two chunks: the shared q/k projection and v.
                k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(2, dim=-1)
                q = k
            else:
                q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)

        elif torch.equal(key, value):
            # encoder-decoder attention: the first embed_dim rows project the
            # query, the remaining rows project key and value in one matmul.
            _w = in_proj_weight[:embed_dim, :]
            _b = in_proj_bias
            if _b is not None:
                _b = _b[:embed_dim]
            q = F.linear(query, _w, _b)

            if key is None:
                assert value is None
                k = None
                v = None
            else:
                _w = in_proj_weight[embed_dim:, :]
                _b = in_proj_bias
                if _b is not None:
                    _b = _b[embed_dim:]
                k, v = F.linear(key, _w, _b).chunk(2, dim=-1)

        else:
            # query, key and value all differ: project each with its own slice.
            _w = in_proj_weight[:embed_dim, :]
            _b = in_proj_bias
            if _b is not None:
                _b = _b[:embed_dim]
            q = F.linear(query, _w, _b)

            _w = in_proj_weight[embed_dim:embed_dim * 2, :]
            _b = in_proj_bias
            if _b is not None:
                _b = _b[embed_dim:embed_dim * 2]
            k = F.linear(key, _w, _b)

            _w = in_proj_weight[embed_dim * 2:, :]
            _b = in_proj_bias
            if _b is not None:
                _b = _b[embed_dim * 2:]
            v = F.linear(value, _w, _b)
    else:
        q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
        len1, len2 = q_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == query.size(-1)

        k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
        len1, len2 = k_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == key.size(-1)

        v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
        len1, len2 = v_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == value.size(-1)

        if in_proj_bias is not None:
            q = F.linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
            k = F.linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)])
            v = F.linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):])
        else:
            q = F.linear(query, q_proj_weight_non_opt, in_proj_bias)
            k = F.linear(key, k_proj_weight_non_opt, in_proj_bias)
            v = F.linear(value, v_proj_weight_non_opt, in_proj_bias)
    q = q * scaling

    if attn_mask is not None:
        if attn_mask.dim() == 2:
            attn_mask = attn_mask.unsqueeze(0)
            if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                raise RuntimeError('The size of the 2D attn_mask is not correct.')
        elif attn_mask.dim() == 3:
            if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
                raise RuntimeError('The size of the 3D attn_mask is not correct.')
        else:
            raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
        # attn_mask's dim is 3 now.

    if bias_k is not None and bias_v is not None:
        if static_k is None and static_v is None:
            k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
            # The masks gain one trailing source position for the appended bias row.
            if attn_mask is not None:
                attn_mask = F.pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = F.pad(key_padding_mask, (0, 1))
        else:
            assert static_k is None, "bias cannot be added to static key."
            assert static_v is None, "bias cannot be added to static value."
    else:
        assert bias_k is None
        assert bias_v is None

    # Fold the heads into the batch dimension: (len, bsz * num_heads, head_dim)
    # and then put the batch first.
    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if k is not None:
        k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    if v is not None:
        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)

    if static_k is not None:
        assert static_k.size(0) == bsz * num_heads
        assert static_k.size(2) == head_dim
        k = static_k

    if static_v is not None:
        assert static_v.size(0) == bsz * num_heads
        assert static_v.size(2) == head_dim
        v = static_v

    src_len = k.size(1)

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    if add_zero_attn:
        src_len += 1
        k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
        v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
        if attn_mask is not None:
            attn_mask = F.pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = F.pad(key_padding_mask, (0, 1))

    if sep:
        # Separable attention: normalise the keys over the source positions,
        # build a (head_dim x head_dim) context from K^T V, then mix it with
        # the queries normalised over the feature dimension.
        # NOTE: attn_mask is not used on this path.
        attn_output_weights = k.transpose(1, 2)
        assert list(attn_output_weights.size()) == [bsz * num_heads, head_dim, src_len]

        if key_padding_mask is not None:
            attn_output_weights = attn_output_weights.view(bsz, num_heads, head_dim, src_len)
            attn_output_weights = attn_output_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            )
            attn_output_weights = attn_output_weights.view(bsz * num_heads, head_dim, src_len)

        attn_output_weights = F.softmax(attn_output_weights, dim=-1)
        attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training)

        attn_output = torch.bmm(attn_output_weights, v)
        assert list(attn_output.size()) == [bsz * num_heads, head_dim, head_dim]

        attn_output_weights = F.softmax(q, dim=-1)
        attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training)

        attn_output = torch.bmm(attn_output_weights, attn_output)
        assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
        attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)

        return attn_output, None
    else:
        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_output_weights += attn_mask

        if key_padding_mask is not None:
            attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
            attn_output_weights = attn_output_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            )
            attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)

        attn_output_weights = F.softmax(attn_output_weights, dim=-1)
        attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training)

        attn_output = torch.bmm(attn_output_weights, v)
        assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
        attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)

        if need_weights:
            # average attention weights over heads
            attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
            return attn_output, attn_output_weights.sum(dim=1) / num_heads
        else:
            return attn_output, None
class MultiheadSeparableAttention(nn.Module):
    """Multi-head attention module with optional shared Q/K projection and separable mode.

    The module owns the projection parameters and delegates the computation to
    ``multi_head_sep_attention_forward``.

    Args:
        embed_dim: model dimension; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability forwarded to the attention function.
        bias: whether the input and output projections carry a bias.
        add_bias_kv: whether to learn an extra key/value row (``bias_k``/``bias_v``).
        add_zero_attn: forwarded to the attention function.
        kdim: feature size of the keys (defaults to ``embed_dim``).
        vdim: feature size of the values (defaults to ``embed_dim``).
        shared_qk: use one projection for both query and key. Only affects the
            parameter layout when ``kdim`` and ``vdim`` equal ``embed_dim``.
        sep: forwarded to the attention function to select the separable path.
    """
    __annotations__ = {
        'bias_k': torch._jit_internal.Optional[torch.Tensor],
        'bias_v': torch._jit_internal.Optional[torch.Tensor],
    }
    __constants__ = ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight']

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, shared_qk=False, sep=False):
        super(MultiheadSeparableAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.shared_qk = shared_qk
        self.sep = sep

        if self._qkv_same_embed_dim is False:
            self.q_proj_weight = nn.Parameter(torch.empty(embed_dim, embed_dim))
            self.k_proj_weight = nn.Parameter(torch.empty(embed_dim, self.kdim))
            self.v_proj_weight = nn.Parameter(torch.empty(embed_dim, self.vdim))
            self.register_parameter('in_proj_weight', None)
        else:
            if shared_qk:
                self.in_proj_weight = nn.Parameter(torch.empty(2 * embed_dim, embed_dim))
            else:
                self.in_proj_weight = nn.Parameter(torch.empty(3 * embed_dim, embed_dim))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias:
            # The two-chunk (shared q/k, v) bias layout only matches the packed
            # in_proj_weight. With separate projection weights the attention
            # function slices the bias into three embed_dim-sized chunks, so a
            # 2 * embed_dim bias would leave the value projection with an
            # empty bias.
            if shared_qk and self._qkv_same_embed_dim:
                self.in_proj_bias = nn.Parameter(torch.empty(2 * embed_dim))
            else:
                self.in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = nn.Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = nn.Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise the projection weights and zero the biases."""
        if self._qkv_same_embed_dim:
            nn.init.xavier_uniform_(self.in_proj_weight)
        else:
            nn.init.xavier_uniform_(self.q_proj_weight)
            nn.init.xavier_uniform_(self.k_proj_weight)
            nn.init.xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            # out_proj is built with the same ``bias`` flag, so its bias exists
            # exactly when in_proj_bias does.
            nn.init.constant_(self.in_proj_bias, 0.)
            nn.init.constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadSeparableAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True

        super(MultiheadSeparableAttention, self).__setstate__(state)

    def forward(self, query, key, value, key_padding_mask=None,
                need_weights=True, attn_mask=None):
        """Run attention and return ``(attn_output, attn_weights_or_None)``.

        All arguments are passed through to ``multi_head_sep_attention_forward``
        together with this module's parameters and flags.
        """
        if not self._qkv_same_embed_dim:
            return multi_head_sep_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight, shared_qk=self.shared_qk, sep=self.sep)
        else:
            return multi_head_sep_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, shared_qk=self.shared_qk, sep=self.sep)
class ConformerEncoderLayer(nn.Module):
    """Encoder layer: optional grouped 1-D convolution and/or self-attention, then a feed-forward block.

    Both sub-blocks use a residual connection followed by layer normalisation.

    Args:
        d_model: feature size of the input and output.
        nhead: number of attention heads; also the number of convolution groups.
        dim_feedforward: hidden size of the feed-forward block.
        dropout: dropout probability used throughout the layer.
        activation: ``"relu"`` or ``"gelu"``.
        attn: enable the self-attention branch.
        conv: enable the convolution branch.
        convsz: convolution kernel size; must be odd.
        shared_qk: forwarded to ``MultiheadSeparableAttention``.
        sep: forwarded to ``MultiheadSeparableAttention``.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", attn=False, conv=False, convsz=1, shared_qk=False, sep=False):
        super().__init__()

        self.conv = conv
        if conv:
            assert convsz % 2 == 1  # kernel size should be an odd number
            # Half-kernel padding keeps the output length equal to the input length.
            half_kernel = (convsz - 1) // 2
            self.conv_layer = nn.Conv1d(d_model, d_model, convsz, padding=half_kernel, groups=nhead)

        self.attn = attn
        if attn:
            self.self_attn = MultiheadSeparableAttention(
                d_model, nhead, dropout=dropout, shared_qk=shared_qk, sep=sep)

        # Position-wise feed-forward network.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        # States saved without an activation entry fall back to ReLU.
        if 'activation' not in state:
            state['activation'] = F.relu
        super().__setstate__(state)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Apply the layer to ``src`` and return a tensor of the same shape.

        ``src_mask`` and ``src_key_padding_mask`` are only consulted when the
        attention branch is enabled.
        """
        branch = src
        if self.conv:
            # Move dim 0 last for Conv1d, then restore the original layout.
            # Assumes src is (seq, batch, d_model) -- TODO confirm with callers.
            branch = self.conv_layer(src.permute(1, 2, 0)).permute(2, 0, 1)
        if self.attn:
            attn_result = self.self_attn(
                branch, branch, branch,
                attn_mask=src_mask,
                key_padding_mask=src_key_padding_mask)
            branch = attn_result[0]

        # First residual + norm (around the conv/attention branch).
        hidden = self.norm1(src + self.dropout1(branch))

        # Second residual + norm (around the feed-forward block).
        ff = self.linear1(hidden)
        ff = self.activation(ff)
        ff = self.dropout(ff)
        ff = self.linear2(ff)
        return self.norm2(hidden + self.dropout2(ff))
def _get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    replicas = []
    for _ in range(N):
        replicas.append(copy.deepcopy(module))
    return nn.ModuleList(replicas)
def _get_activation_fn(activation):
    """Map an activation name (``"relu"`` or ``"gelu"``) to its functional implementation.

    Raises:
        RuntimeError: if ``activation`` is neither ``"relu"`` nor ``"gelu"``.
    """
    if activation == "relu":
        chosen = F.relu
    elif activation == "gelu":
        chosen = F.gelu
    else:
        raise RuntimeError(f"activation should be relu/gelu, not {activation}")
    return chosen
class ConformerEncoder(nn.Module):
    """Stack of ``num_layers`` deep copies of ``encoder_layer`` with an optional final norm.

    Args:
        encoder_layer: layer to replicate; each copy is called with
            ``src_mask`` and ``src_key_padding_mask`` keyword arguments.
        num_layers: number of copies in the stack.
        norm: optional callable applied to the output of the last layer.
    """
    __constants__ = ['norm']

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, mask=None, src_key_padding_mask=None):
        """Pass ``src`` through every layer in order, then through ``norm`` if set."""
        hidden = src
        for layer in self.layers:
            hidden = layer(hidden, src_mask=mask, src_key_padding_mask=src_key_padding_mask)

        if self.norm is None:
            return hidden
        return self.norm(hidden)
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

    Even feature columns hold ``sin(pos * w_i)`` and odd columns hold
    ``cos(pos * w_i)`` with ``w_i = exp(-2i * ln(10000) / d_model)``.

    Args:
        d_model: feature size of the input; may be even or odd.
        dropout: dropout probability applied after the addition.
        max_len: number of positions precomputed in the ``pe`` buffer.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column, so
        # trim the angle table to the number of cosine slots.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])

        # Stored as (max_len, 1, d_model) so it broadcasts over dim 1 of the input.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return ``dropout(x + pe[:x.size(0)])``; dim 0 of ``x`` indexes the position."""
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)