forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_transformers.py
2982 lines (2571 loc) · 148 KB
/
test_transformers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Owner(s): ["module: nn"]
import contextlib
from functools import partial
from collections import namedtuple
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import unittest
from unittest.mock import patch, MagicMock, ANY
import math
from torch.backends.cuda import sdp_kernel, SDPBackend
import torch.optim as optim
from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCUDA, onlyCPU
from typing import List, Tuple, Optional
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
TEST_FAIRSEQ,
run_tests,
parametrize,
freeze_rng_state,
TEST_WITH_CROSSREF,
slowTest,
set_default_dtype,
gradcheck,
make_tensor,
NOTEST_CPU
)
from torch.testing._internal.common_methods_invocations import wrapper_set_seed
from torch.testing._internal.common_cuda import (
SM80OrLater, PLATFORM_SUPPORTS_FLASH_ATTENTION,
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
PLATFORM_SUPPORTS_FUSED_ATTENTION
)
if TEST_FAIRSEQ:
import fairseq.models.transformer as fairseq_transformer
# Record describing a scaled-dot-product-attention input shape:
# (batch, num_heads, seq_len, head_dim). The typename string stays 'Sdpa_Shape'.
SdpaShape = namedtuple('Sdpa_Shape', 'batch num_heads seq_len head_dim')
@contextlib.contextmanager
def use_deterministic_algorithims(mode: bool, warn_only: bool):
    r"""Temporarily apply ``torch.use_deterministic_algorithms(mode, warn_only=warn_only)``.

    The deterministic-algorithms flag and its warn-only flag that were active on
    entry are restored on exit, whether the body finishes normally or raises.
    The value bound by ``as`` is an empty dict.

    The function name's spelling ("algorithims") is kept unchanged so that any
    existing references keep working.
    """
    saved_state = (
        torch.are_deterministic_algorithms_enabled(),
        torch.is_deterministic_algorithms_warn_only_enabled(),
    )
    try:
        torch.use_deterministic_algorithms(mode, warn_only=warn_only)
        yield {}
    finally:
        saved_mode, saved_warn_only = saved_state
        torch.use_deterministic_algorithms(saved_mode, warn_only=saved_warn_only)
# Per-dtype default absolute / relative tolerances used as floors by the
# tolerance helpers in this file.
# Found in torch/testing/_comparison.py
# NOTE(review): confirm these still match the values there. Only float16,
# bfloat16 and float32 have entries, so a lookup with any other dtype (e.g.
# torch.float64) raises KeyError.
default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float32: 1e-5}
default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float32: 1.3e-6}
# CUDA compute-capability flags, evaluated once at import time. `and`
# short-circuits, so get_device_capability() is only queried when CUDA is
# available; on CUDA-less machines each flag is False.
isSM86or89Device = torch.cuda.is_available() and torch.cuda.get_device_capability() in [(8, 6), (8, 9)]
isSM90Device = torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0)
isSM5xDevice = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 5
def get_rtol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
    """Return the largest element-wise relative error ``|(true - computed) / true|``.

    Entries whose ratio is NaN (0/0, i.e. both values zero) are replaced by
    ``default_rtol[computed_value.dtype]`` before the maximum is taken. Infinite
    ratios (nonzero error over a zero reference) are left to ``nan_to_num_``'s
    default ``posinf``/``neginf`` handling, which maps them to the dtype's
    largest finite value.
    """
    relative_error = ((true_value - computed_value) / true_value).abs()
    # 0/0 entries would otherwise poison max(); substitute the dtype's default rtol.
    nan_fill = default_rtol[computed_value.dtype]
    relative_error.nan_to_num_(nan=nan_fill)
    return relative_error.max().item()
def get_atol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
    """Return the largest element-wise absolute error ``max |true - computed|`` as a Python float."""
    return (true_value - computed_value).abs().max().item()
def get_tolerances(
    true_value: torch.Tensor,
    computed_value: torch.Tensor,
    fudge_factor: Optional[float] = None,
) -> Tuple[float, float]:
    """Returns the absolute and relative tolerances for comparing two tensors.

    Each tolerance is the observed error (from ``get_atol`` / ``get_rtol``),
    floored at the per-dtype default for ``computed_value.dtype`` and then
    multiplied by ``fudge_factor`` (treated as 1.0 when ``None``). A scaled rtol
    above 1e30 is replaced by the unscaled default rtol.

    Returns:
        ``(atol, rtol)`` as Python floats.
    """
    scale = 1.0 if fudge_factor is None else fudge_factor
    dtype = computed_value.dtype

    observed_atol = get_atol(true_value, computed_value)
    observed_rtol = get_rtol(true_value, computed_value)

    atol = scale * max(observed_atol, default_atol[dtype])
    rtol = scale * max(observed_rtol, default_rtol[dtype])

    # torch.isclose() behaves unexpectedly with enormous rtol values (such as the
    # near-max-float values get_rtol yields when the reference contains zeros);
    # presumably that is why they are clamped back to the default here. See:
    # https://github.com/pytorch/pytorch/issues/102400
    if rtol > 1e30:
        rtol = default_rtol[dtype]
    return atol, rtol
# For each scaled-dot-product-attention backend, a set of enable_* flags that
# turns on exactly that backend and disables the other two (presumably passed
# as keyword arguments to `sdp_kernel`).
backend_map = {
    SDPBackend.MATH: dict(
        enable_math=True, enable_flash=False, enable_mem_efficient=False,
    ),
    SDPBackend.FLASH_ATTENTION: dict(
        enable_math=False, enable_flash=True, enable_mem_efficient=False,
    ),
    SDPBackend.EFFICIENT_ATTENTION: dict(
        enable_math=False, enable_flash=False, enable_mem_efficient=True,
    ),
}
def rand_sdpa_tensor(shape: SdpaShape, device: str, dtype: torch.dtype, type: str,
                     requires_grad: bool = False, packed: bool = False) -> torch.Tensor:
    """Create a random (``torch.randn``) dense or nested tensor for SDPA tests.

    Args:
        shape (SdpaShape): Provides ``batch``, ``num_heads``, ``seq_len`` and
            ``head_dim``. When ``type == "nested"``, ``seq_len`` may be a list
            whose first ``batch`` entries give each component's length;
            otherwise it must be an int.
        device (str): Device to allocate on.
        dtype (torch.dtype): Element dtype.
        type (str): ``"nested"`` builds a nested tensor; any other value builds
            a dense tensor.
        requires_grad (bool, optional): ``requires_grad`` for the generated data
            (per component in the nested case). Defaults to False.
        packed (bool, optional): If True, Q, K and V are packed into a single last
            dimension of size ``3 * num_heads * head_dim``. Defaults to False.

    Returns:
        torch.Tensor: dense shape ``(batch, seq_len, num_heads, head_dim)``, or
        ``(batch, seq_len, 3 * num_heads * head_dim)`` when packed. Nested
        components have the same shape without the leading batch dimension.
    """
    batch, num_heads, seq_len, head_dim = shape.batch, shape.num_heads, shape.seq_len, shape.head_dim

    def _per_sequence_dims(length):
        # Shape of one sequence, with the heads folded together when packed.
        if packed:
            return (length, 3 * num_heads * head_dim)
        return (length, num_heads, head_dim)

    def _randn(size):
        return torch.randn(size, device=device, dtype=dtype, requires_grad=requires_grad)

    if type != "nested":
        assert (isinstance(seq_len, int))
        return _randn((batch,) + _per_sequence_dims(seq_len))

    # Nested: either one length per component, or the same int length repeated.
    lengths_given = isinstance(seq_len, list)
    return torch.nested.nested_tensor([
        _randn(_per_sequence_dims(seq_len[i] if lengths_given else seq_len))
        for i in range(batch)
    ])
def calculate_nt_tolerances(nt_ref_hp, nt_ref_lp, default_dtype, fudge_factor=1):
    """Derive ``(atol, rtol)`` from the component-wise error between two nested tensors.

    Starts from ``default_atol[default_dtype]`` / ``default_rtol[default_dtype]``
    and raises each to the worst error seen across the unbound components of
    ``nt_ref_hp`` and ``nt_ref_lp`` (presumably high- and low-precision
    references). ``fudge_factor`` scales only the observed absolute error, not
    the defaults and not the relative error.
    """
    # TODO use NT ops when we have implemented Max for NestedTensor instead of unrolling
    atol = default_atol[default_dtype]
    rtol = default_rtol[default_dtype]
    for hp_part, lp_part in zip(nt_ref_hp.unbind(), nt_ref_lp.unbind()):
        scaled_abs_err = (fudge_factor * (hp_part - lp_part).abs()).max().item()
        atol = max(scaled_abs_err, atol)
        rtol = max(get_rtol(hp_part, lp_part), rtol)
    return atol, rtol
class TestTransformers(NNTestCase):
    # Class-level switches read by the NNTestCase / common_utils test harness.
    # Presumably they enable a per-test CUDA memory-leak check and run each test
    # on a non-default CUDA stream; confirm in torch.testing._internal.common_utils.
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True
@onlyCUDA
@unittest.skip("4D mask not supported yet - activate when 4D mask supported")
def test_self_attn_TxT_attn_mask(self, device):
embed_dim = 16
num_heads = 4
batch_size = 10
tgt_len = 16
query = torch.rand(batch_size, tgt_len, embed_dim, device=device) # [N, T, D]
attn_mask = torch.randint(0, 2, (tgt_len, tgt_len)).cuda().float() # [T, T]
attn_mask = attn_mask.masked_fill(attn_mask == 0, float('-inf')).masked_fill(attn_mask == 1, 0.0)
attn_mask_4d = attn_mask.expand(batch_size, num_heads, tgt_len, tgt_len)
mta_model = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).cuda()
mta_model.eval()
# Generate 3D results
with torch.inference_mode():
output_mask_4d = mta_model(query, query, query, attn_mask=attn_mask_4d)[0]
output_mask_4d = output_mask_4d.transpose(0, 1) # [N, T, D]
output_mask_TxT = mta_model(query, query, query, attn_mask=attn_mask)[0]
output_mask_TxT = output_mask_TxT.transpose(0, 1) # [N, T, D]
self.assertEqual(output_mask_4d, output_mask_TxT)
@slowTest
def test_train_with_pad_and_catch_error(self, device):
iters = 100
pad_mask = torch.tensor([[1, 1, 0, 0]], dtype=torch.bool).to(device)
layer = nn.TransformerEncoderLayer(
d_model=2,
dim_feedforward=4,
nhead=2,
batch_first=True,
activation="gelu",
dropout=0,
)
criterion = nn.MSELoss()
encoder = nn.TransformerEncoder(layer, 2).to(device)
optimizer = optim.SGD(encoder.parameters(), lr=0.1, momentum=0.9)
encoder.train()
for i in range(iters):
encoder.train()
optimizer.zero_grad()
inputs = torch.cat([torch.randn(1, 2, 2), torch.zeros(1, 2, 2)], dim=1).to(device)
outputs = encoder(inputs, src_key_padding_mask=pad_mask)
loss = criterion(outputs[:, 0:2, :], inputs[:, 0:2, :])
loss.backward()
optimizer.step()
with torch.no_grad():
test = torch.cat([torch.randn(1, 2, 2), torch.zeros(1, 2, 2)], dim=1).to(device)
# Expect uint8 type not supported
ex = None
try:
test_train_uint8 = encoder(test, src_key_padding_mask=pad_mask.to(torch.uint8))
except AssertionError as e:
continue
self.assertFalse(e, "Failed to catch unsupported uint8 type exception")
test_train_bool = encoder(test, src_key_padding_mask=pad_mask)
encoder.eval()
# Expect long type not supported
ex = None
try:
test_eval_uint8 = encoder(test, src_key_padding_mask=pad_mask.to(torch.int64))
except AssertionError as e:
continue
self.assertFalse(e, "Failed to catch unsupported Long type exception")
test_eval_bool = encoder(test, src_key_padding_mask=pad_mask)
l1_bool = nn.L1Loss()(test_train_bool[:, 0:2, :], test_eval_bool[:, 0:2, :]).item()
self.assertTrue(l1_bool < 1e-4, "Eval/Train difference in pad_mask BOOL")
@parametrize("attn_mask_dim", [2, 3, None])
@parametrize("key_padding_mask_dim", [2, None])
@parametrize("mask_dtype", [torch.bool, torch.float32])
def test_multiheadattention_fastpath_attn_mask(self, device, attn_mask_dim, key_padding_mask_dim, mask_dtype):
with torch.no_grad():
B = 2
L = 4
D = 8
H = 4
if attn_mask_dim == 2:
attn_mask = make_tensor((L, L), dtype=mask_dtype, device=device)
elif attn_mask_dim == 3:
attn_mask = make_tensor((B * H, L, L), dtype=mask_dtype, device=device)
elif attn_mask_dim is None:
attn_mask = None
if key_padding_mask_dim == 2:
key_padding_mask = make_tensor((B, L), dtype=mask_dtype, device=device)
elif key_padding_mask_dim is None:
key_padding_mask = None
mha = nn.MultiheadAttention(D, H, batch_first=True, device=device)
X = torch.randn(B, L, D, device=device)
mha.train() # disable fast path
out, _ = mha(X, X, X, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)
mha.eval() # enable fast path
out_fp, _ = mha(X, X, X, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)
self.assertEqual(out, out_fp)
@parametrize("nhead", [1, 4, 8])
def test_transformerencoderlayer_src_mask(self, device, nhead):
batch_size = 2
seqlen = 4
d_model = 8
dim_feedforward = 32
model = torch.nn.TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
batch_first=True).to(device)
src = torch.rand(batch_size, seqlen, d_model).to(device) # bs, seqlen, d_model
src_mask = torch.zeros(seqlen, seqlen).to(torch.bool).to(device)
model(src, src_mask=src_mask)
model.eval()
with torch.no_grad():
model(src, src_mask=src_mask)
@parametrize("use_torchscript", [False])
@parametrize("enable_nested_tensor", [True, False])
@parametrize("use_autocast", [True, False])
@parametrize("d_model", [12, 256])
def test_transformerencoder_fastpath(self, device, use_torchscript, enable_nested_tensor, use_autocast, d_model):
"""
Test TransformerEncoder fastpath output matches slowpath output
"""
torch.manual_seed(1234)
nhead = 4
dim_feedforward = d_model
batch_first = True
model = torch.nn.TransformerEncoder(
torch.nn.TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
batch_first=batch_first),
num_layers=2,
enable_nested_tensor=enable_nested_tensor
).to(device).eval()
if use_torchscript:
model = torch.jit.script(model)
# each input is (input, mask)
input_mask_pairs = [
(
torch.rand(3, 2, d_model),
[
[0, 1],
[0, 1],
[1, 1]
]
),
(
torch.rand(2, 100, d_model),
[
[0] * 98 + [1] * 2,
[0] * 90 + [1] * 10
]
),
# softmax.cu switches from fast->slowpath at masked seqlen 1024. test 1024.
(
torch.rand(2, 1024, d_model),
[
[0] * 1020 + [1] * 4,
[0] * 1024,
]
),
(
torch.rand(1, 1026, d_model),
[[0] * 1024 + [1] * 2]
),
# softmax.cu switches from fast->slowpath at masked seqlen 1024. test range of masks above 1024.
(
torch.rand(4, 1040, d_model),
[
[0] * 1024 + [1] * 16,
[0] * 1025 + [1] * 15,
[0] * 1031 + [1] * 9,
[0] * 1040,
]
)
]
input_mask_pairs = [
(
torch.tensor(pair[0], device=device, dtype=torch.get_default_dtype()), # float input
torch.tensor(pair[1], device=device, dtype=torch.bool) # bool mask
) for pair in input_mask_pairs
]
maybe_autocast = torch.autocast("cuda", dtype=torch.float16) if use_autocast else contextlib.nullcontext()
with maybe_autocast:
for input, src_key_padding_mask in input_mask_pairs:
with torch.no_grad():
fastpath_output = model(input, src_key_padding_mask=src_key_padding_mask)
slowpath_output = model(input, src_key_padding_mask=src_key_padding_mask) # reference
# Make sure fastpath_output is same shape as slowpath_output and mask.
# When enable_nested_tensor=true, fastpath_output may be smaller than input tensor.
# Eg if input bs=1, seqlen=6, and we mask out 2 tokens, fastpath_output will have bs=1, seqlen=4.
# Expand back to old size to match.
bs, true_seqlen, embed_dim = fastpath_output.shape
expanded_seqlen = src_key_padding_mask.shape[1]
fastpath_output_expanded = torch.zeros(bs, expanded_seqlen, embed_dim, device=device)
fastpath_output_expanded[:, :true_seqlen, :] = fastpath_output
# no garauntees on output corresponding to masked tokens, so they may vary between slow/fast path. set all to 0.
fastpath_output_expanded = fastpath_output_expanded.masked_fill(src_key_padding_mask.unsqueeze(-1), 0)
slowpath_output = slowpath_output.masked_fill(src_key_padding_mask.unsqueeze(-1), 0)
torch.testing.assert_close(fastpath_output_expanded, slowpath_output, rtol=1e-7, atol=1e-5)
@parametrize("with_no_grad", [True, False])
@parametrize("training", [True, False])
@parametrize("enable_nested_tensor", [False])
def test_transformerencoder_square_input(self, with_no_grad, training, enable_nested_tensor, device):
"""
Test for edge cases when input of shape (batch size, sequence length, embedding dimension) has
batch size == sequence length
"""
model = torch.nn.TransformerEncoder(
torch.nn.TransformerEncoderLayer(d_model=4, nhead=2, dim_feedforward=16, dropout=0.0, batch_first=True),
num_layers=2,
enable_nested_tensor=enable_nested_tensor
).to(device)
with torch.no_grad():
# set constant weights of the model
for idx, p in enumerate(model.parameters()):
x = p.data
sz = x.view(-1).size(0)
shape = x.shape
x = torch.cos(torch.arange(0, sz).float().view(shape))
p.data.copy_(x)
if training:
model = model.train()
else:
model = model.eval()
x = torch.arange(0, 16).reshape(2, 2, 4).to(torch.get_default_dtype()).to(device)
src_mask = torch.Tensor([[0, 1], [0, 0]]).to(torch.bool).to(device)
if with_no_grad:
cm = torch.no_grad()
else:
cm = contextlib.nullcontext()
with cm:
result = model(x, mask=src_mask)
ref_output = torch.Tensor([[[2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351],
[2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351]],
[[2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689],
[2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689]]]
).to(device)
self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
    @parametrize("batch_first", [True, False])
    @parametrize("training", [True, False])
    @parametrize("enable_nested_tensor", [True, False])
    def test_transformerencoder(self, batch_first, training, enable_nested_tensor, device):
        """Deterministic TransformerEncoder regression test against hard-coded reference outputs.

        Layer weights are overwritten with a fixed cos(arange) pattern and dropout is 0.0,
        so outputs are reproducible. The test covers:
          - 1, 2 and 6 stacked layers, with and without a final LayerNorm;
          - no mask, an all-False src mask, and all-False / partially-True key padding masks.
        """
        def get_a_test_layer(activation, batch_first=False):
            # Build a d_model=4 / nhead=2 / dim_feedforward=16 layer on `device` with dropout
            # disabled and every parameter overwritten by cos(arange(numel)).
            d_model = 4
            nhead = 2
            dim_feedforward = 16
            dropout = 0.0
            layer = nn.TransformerEncoderLayer(
                d_model,
                nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                activation=activation,
                batch_first=batch_first,
            ).to(device)
            with torch.no_grad():
                # set constant weights of the model
                for idx, p in enumerate(layer.parameters()):
                    x = p.data
                    sz = x.view(-1).size(0)
                    shape = x.shape
                    x = torch.cos(torch.arange(0, sz).float().view(shape))
                    p.data.copy_(x)
            return layer

        # this is a deterministic test for TransformerEncoder
        activation = F.relu

        def _test(batch_first, training, enable_nested_tensor):
            # Tensor literals below are written sequence-first (seq=5, batch=2, d_model=4);
            # perm_fn swaps dims 0 and 1 to make them batch-first when batch_first is set.
            def perm_fn(x):
                return x.transpose(1, 0) if batch_first else x

            encoder_layer = get_a_test_layer(activation=activation,
                                             batch_first=batch_first)

            # test case 1: single layer, no norm
            model = nn.TransformerEncoder(
                encoder_layer, 1, enable_nested_tensor=enable_nested_tensor
            ).to(device)
            if not training:
                model = model.eval()

            # deterministic input
            encoder_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891],
                                                   [0.5387, 0.1655, 0.3565, 0.0471]],
                                                  [[0.8335, 0.2799, 0.5031, 0.2947],
                                                   [0.1402, 0.0318, 0.7636, 0.1346]],
                                                  [[0.6333, 0.9344, 0.1376, 0.9938],
                                                   [0.8924, 0.2872, 0.6692, 0.2944]],
                                                  [[0.9897, 0.6915, 0.3154, 0.1733],
                                                   [0.8645, 0.3513, 0.3064, 0.0767]],
                                                  [[0.8117, 0.2366, 0.4838, 0.7881],
                                                   [0.3718, 0.4945, 0.9511, 0.0864]]]
                                                 )).to(device)
            result = model(encoder_input)
            ref_output = perm_fn(torch.tensor([[[2.428589, 0.020835, -0.602055, -0.085249],
                                                [2.427987, 0.021213, -0.602496, -0.084103]],
                                               [[2.424689, 0.019155, -0.604793, -0.085672],
                                                [2.413863, 0.022211, -0.612486, -0.072490]],
                                               [[2.433774, 0.021598, -0.598343, -0.087548],
                                                [2.425104, 0.019748, -0.604515, -0.084839]],
                                               [[2.436185, 0.022682, -0.596625, -0.087261],
                                                [2.433556, 0.021891, -0.598509, -0.086832]],
                                               [[2.416246, 0.017512, -0.610712, -0.082961],
                                                [2.422901, 0.024187, -0.606178, -0.074929]]]
                                              )).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            # all 0 src_mask (all-False bool [seq, seq] mask): output must be unchanged
            src_mask = torch.zeros([5, 5]).to(device) == 1
            result = model(encoder_input, mask=src_mask)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            # all 0 (all-False bool [batch, seq] key padding mask): output must be unchanged
            mask = torch.zeros([2, 5]).to(device) == 1
            result = model(encoder_input, src_key_padding_mask=mask)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            # Pad out one position of batch 0 and two of batch 1; this mask is reused below.
            mask[0, 1] = 1
            mask[1, 3] = 1
            mask[1, 4] = 1
            result = model(encoder_input, src_key_padding_mask=mask)
            ref_output = perm_fn(torch.tensor([[[2.429026, 0.020793, -0.601741, -0.085642],
                                                [2.428811, 0.021445, -0.601912, -0.084252]],
                                               [[2.425009, 0.019155, -0.604566, -0.085899],
                                                [2.415408, 0.02249, -0.611415, -0.073]],
                                               [[2.434199, 0.021682, -0.598039, -0.087699],
                                                [2.42598, 0.019941, -0.603896, -0.085091]],
                                               [[2.436457, 0.022736, -0.59643, -0.08736],
                                                [2.434021, 0.022093, -0.598179, -0.08679]],
                                               [[2.416531, 0.017498, -0.610513, -0.083181],
                                                [2.4242, 0.024653, -0.605266, -0.074959]]]
                                              )).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            # test case 2, multiple layers no norm
            model = nn.TransformerEncoder(encoder_layer, 2, enable_nested_tensor=enable_nested_tensor).to(device)
            if not training:
                model = model.eval()
            result = model(encoder_input, src_key_padding_mask=mask)
            ref_output = perm_fn(torch.tensor([[[2.419051, 0.017446, -0.608738, -0.085003],
                                                [2.419102, 0.017452, -0.608703, -0.085026]],
                                               [[2.419043, 0.017445, -0.608744, -0.084999],
                                                [2.419052, 0.017446, -0.608738, -0.085004]],
                                               [[2.419067, 0.017448, -0.608727, -0.085010],
                                                [2.419098, 0.017452, -0.608706, -0.085024]],
                                               [[2.419072, 0.017449, -0.608724, -0.085012],
                                                [2.419119, 0.017455, -0.608691, -0.085034]],
                                               [[2.419019, 0.017442, -0.608761, -0.084989],
                                                [2.419075, 0.017449, -0.608722, -0.085014]]]
                                              )).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            # Same configuration with 6 layers.
            model = nn.TransformerEncoder(encoder_layer, 6, enable_nested_tensor=enable_nested_tensor).to(device)
            if not training:
                model = model.eval()
            result = model(encoder_input, src_key_padding_mask=mask)
            ref_output = perm_fn(torch.tensor([[[2.419101, 0.017453, -0.608703, -0.085025],
                                                [2.419101, 0.017453, -0.608704, -0.085025]],
                                               [[2.419101, 0.017453, -0.608703, -0.085025],
                                                [2.419101, 0.017453, -0.608704, -0.085025]],
                                               [[2.419101, 0.017453, -0.608703, -0.085025],
                                                [2.419101, 0.017453, -0.608704, -0.085025]],
                                               [[2.419101, 0.017453, -0.608703, -0.085025],
                                                [2.419101, 0.017453, -0.608704, -0.085025]],
                                               [[2.419101, 0.017453, -0.608703, -0.085025],
                                                [2.419101, 0.017453, -0.608704, -0.085025]]]
                                              )).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            # test case 3, multiple layers with norm
            # d_model = 4
            norm = nn.LayerNorm(4)
            model = nn.TransformerEncoder(encoder_layer, 2, norm=norm,
                                          enable_nested_tensor=enable_nested_tensor).to(device)
            if not training:
                model = model.eval()
            result = model(encoder_input, src_key_padding_mask=mask)
            ref_output = perm_fn(torch.tensor([[[1.695949, -0.357635, -0.893077, -0.445238],
                                                [1.695955, -0.357639, -0.893050, -0.445266]],
                                               [[1.695948, -0.357634, -0.893082, -0.445233],
                                                [1.695950, -0.357635, -0.893077, -0.445238]],
                                               [[1.695951, -0.357636, -0.893069, -0.445246],
                                                [1.695955, -0.357639, -0.893052, -0.445264]],
                                               [[1.695952, -0.357636, -0.893066, -0.445249],
                                                [1.695957, -0.357641, -0.893041, -0.445276]],
                                               [[1.695946, -0.357632, -0.893095, -0.445220],
                                                [1.695952, -0.357637, -0.893065, -0.445251]]]
                                              )).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            # Same configuration with 6 layers plus the final norm.
            model = nn.TransformerEncoder(encoder_layer, 6, norm=norm,
                                          enable_nested_tensor=enable_nested_tensor).to(device)
            if not training:
                model = model.eval()
            result = model(encoder_input, src_key_padding_mask=mask)
            ref_output = perm_fn(torch.tensor([[[1.695955, -0.357639, -0.893051, -0.445265],
                                                [1.695955, -0.357639, -0.893051, -0.445265]],
                                               [[1.695955, -0.357639, -0.893051, -0.445265],
                                                [1.695955, -0.357639, -0.893051, -0.445265]],
                                               [[1.695955, -0.357639, -0.893051, -0.445265],
                                                [1.695955, -0.357639, -0.893051, -0.445265]],
                                               [[1.695955, -0.357639, -0.893051, -0.445265],
                                                [1.695955, -0.357639, -0.893051, -0.445265]],
                                               [[1.695955, -0.357639, -0.893051, -0.445265],
                                                [1.695955, -0.357639, -0.893051, -0.445265]]]
                                              )).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

        # TODO: remove set default dtype to double by making ref_output more precise.
        # Added because this test was copied from test_nn.py, which has default
        # dtype double. If default dtype is float, tests will say tensors not close because
        # ref output precision too low
        with set_default_dtype(torch.double):
            if training:
                cm = contextlib.nullcontext()
            else:
                cm = torch.no_grad()  # transformer fast path requires no grad
            with cm:
                _test(batch_first, training, enable_nested_tensor)
@unittest.skipIf(sys.version_info < (3, 11), "not supported on pre-3.11 Python")
def test_encoder_padding_and_src_mask_bool(self):
encoder_layer = nn.TransformerEncoderLayer(
d_model=16,
nhead=2,
dim_feedforward=32,
dropout=0.1,
activation='relu',
batch_first=True,
)
encoder_norm = nn.LayerNorm(16)
encoder = nn.TransformerEncoder(
encoder_layer, 2, encoder_norm
)
inputs = torch.randn(2, 3, 16)
src_mask = torch.ones(3, 3, dtype=torch.bool).triu_(diagonal=1)
input_seq_len = torch.tensor([3, 2])
padding_mask = (
torch.arange(3)[None, :].cpu() >= input_seq_len[:, None]
)
with self.assertNoLogs(None):
encoder(
inputs,
mask=src_mask,
src_key_padding_mask=padding_mask,
)
@unittest.skipIf(sys.version_info < (3, 11), "not supported on pre-3.11 Python")
def test_decoder_padding_and_src_mask_bool(self):
def transformer_decoder(inputs, input_seq_len, memory):
decoder_layer = nn.TransformerDecoderLayer(
d_model=16,
nhead=2,
dim_feedforward=32,
dropout=0.1,
activation='relu',
batch_first=True,
)
decoder_norm = nn.LayerNorm(16)
decoder = nn.TransformerDecoder(
decoder_layer, 2, decoder_norm
)
src_mask = torch.ones(
inputs.shape[1], inputs.shape[1], dtype=torch.bool
).triu_(diagonal=1)
padding_mask = (
torch.arange(inputs.shape[1])[None, :].cpu()
>= input_seq_len[:, None]
)
return decoder(
inputs,
memory,
tgt_mask=src_mask,
tgt_key_padding_mask=padding_mask,
memory_key_padding_mask=padding_mask,
)
inputs = torch.randn(2, 3, 16)
memory = torch.randn(2, 3, 16)
input_seq_len = torch.tensor([3, 2])
with self.assertNoLogs(None):
transformer_decoder(inputs, input_seq_len, memory)
def test_encoder_is_causal(self):
d_model = 3
layer = torch.nn.TransformerEncoderLayer(d_model, 1, 6, batch_first=True)
layer.eval()
x = torch.randn(1, 5, d_model)
unmasked_output = layer(x)
mask = torch.nn.Transformer.generate_square_subsequent_mask(x.size(1))
is_causal_output = layer(x, src_mask=mask, is_causal=True)
masked_output = layer(x, src_mask=mask)
self.assertEqual(masked_output, is_causal_output)
@onlyCUDA
@parametrize("nb_heads", [1, 8])
@parametrize("bias", [True, False])
def test_mha_native_args(self, nb_heads, bias):
B, L, F = 8, 100, 128
batch_first = True
fast_path = True
use_pad_mask = (bias % 2) == 1
mha = nn.MultiheadAttention(
embed_dim=F,
num_heads=nb_heads,
batch_first=batch_first,
bias=bias
).cuda()
mha.eval()
ctx = torch.no_grad if fast_path else contextlib.nullcontext
with ctx():
x = torch.randn(B, L, F).cuda()
if not batch_first:
x = x.transpose(0, 1)
pad_mask = None
if use_pad_mask:
pad_mask = torch.zeros((B, L), dtype=torch.bool).cuda()
mha(query=x, key=x, value=x, key_padding_mask=pad_mask)
def test_kpm_mask_trailing_column_with_nested_tensor(self, device):
encoder_layer = nn.TransformerEncoderLayer(
d_model=256,
nhead=4,
dim_feedforward=512,
activation='gelu',
norm_first=False,
batch_first=False,
)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3, enable_nested_tensor=True).to(device)
x = torch.randn(10, 6, 256).to(device)
mask = torch.ones(6, 10)
mask[0, :] = 0 # here I masked 5 columns instead of just one
mask = mask.bool().to(device)
out = transformer_encoder(src=x, src_key_padding_mask=mask)
self.assertEqual(out.shape[1], 6)
# CPU unit test has_torch_functions in test environment,
# preventing successful completion
@onlyCUDA
def test_with_nested_tensor_input(self, device):
encoder_layer = nn.TransformerEncoderLayer(
d_model=256,
nhead=4,
dim_feedforward=512,
activation='gelu',
norm_first=False,
batch_first=True,
)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3, enable_nested_tensor=True).to(device)
transformer_encoder.eval()
with torch.no_grad():
x = torch.randn(6, 10, 256).to(device)
mask = torch.ones(6, 10)
mask[0, 0:] = 0 # here I masked 5 columns instead of just one
mask[2, 2:] = 0 # here I masked 5 columns instead of just one
mask[4, 4:] = 0 # here I masked 5 columns instead of just one
mask[5, 8:] = 0 # here I masked 5 columns instead of just one
mask = mask.bool().to(device)
x = torch._nested_tensor_from_mask(x, mask.logical_not(), mask_check=False)
out = transformer_encoder(src=x, src_key_padding_mask=None)
self.assertEqual(out.is_nested, True)
def test_script_encoder_subclass(self, device):
class MyCustomLayer(nn.TransformerEncoderLayer):
pass
encoder = nn.TransformerEncoder(
MyCustomLayer(d_model=256, nhead=8), num_layers=6
).to(device=device)
torch.jit.script(encoder)
# brazenly adapted from test_transformerencoderlayer_src_mask to test execution of
# torchscripted transformerencoderlayer subclass
def test_transformerencoderlayer_subclass(self, device):
class MyCustomLayer(nn.TransformerEncoderLayer):
pass
nhead = 4
batch_size = 2
seqlen = 4
d_model = 8
dim_feedforward = 32
model = MyCustomLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
batch_first=True).to(device)
script_model = torch.jit.script(model)
src = torch.rand(batch_size, seqlen, d_model).to(device) # bs, seqlen, d_model
src_mask = torch.zeros(seqlen, seqlen).to(torch.bool).to(device)
torch.manual_seed(42)
result = model(src, src_mask=src_mask)
torch.manual_seed(42)
scripted_result = script_model(src, src_mask=src_mask)
self.assertEqual(result, scripted_result)
model.eval()
script_model = torch.jit.script(model)
with torch.no_grad():
result = model(src, src_mask=src_mask)
scripted_result = script_model(src, src_mask=src_mask)
self.assertEqual(result, scripted_result)
def test_transformerencoderlayer_subclass_model(self, device):
class MyCustomLayer(nn.TransformerEncoderLayer):
pass
nhead = 4
batch_size = 2
seqlen = 4
d_model = 8
dim_feedforward = 32
layer = MyCustomLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
batch_first=True)
model = nn.TransformerEncoder(
layer, num_layers=6
).to(device=device)
script_model = torch.jit.script(model)
src = torch.rand(batch_size, seqlen, d_model).to(device) # bs, seqlen, d_model
src_mask = torch.zeros(seqlen, seqlen).to(torch.bool).to(device)
torch.manual_seed(42)
result = model(src, mask=src_mask)
torch.manual_seed(42)
scripted_result = script_model(src, mask=src_mask)
self.assertEqual(result, scripted_result)
model.eval()
script_model = torch.jit.script(model)
with torch.no_grad():
result = model(src, mask=src_mask)
scripted_result = script_model(src, mask=src_mask)
self.assertEqual(result, scripted_result)
@onlyCUDA
@unittest.skipIf(not TEST_FAIRSEQ, "Fairseq not found")
def test_decoder_only_layer(self):
    """Define a fairseq-backed, decoder-only wrapper module.

    NOTE(review): this body only defines ``FairseqDecoder`` and then returns.
    The class is never instantiated and nothing is asserted, so the test
    currently checks nothing beyond the class definition itself.
    TODO: exercise the wrapper or drop the test.
    """
    DEFAULT_PADDING_IDX = 0  # NOTE(review): not referenced anywhere in this body.

    class FairseqDecoder(torch.nn.Module):
        """Wrap ``fairseq_transformer.TransformerDecoder`` without encoder attention.

        At construction the wrapped decoder is optionally populated from
        ``torch_encoder`` via ``torch_to_fairseq``. It is then put in eval
        mode, moved to CUDA and cast to half precision.
        """

        def __init__(
            self,
            embed_dim,
            attention_heads,
            ffn_embed_dim,
            num_layers,
            embedding_layer,  # torch.nn.Embedding; must have a padding_idx field
            dropout=0,
            normalize_before=False,
            torch_encoder=None,  # torch encoder whose weights can be mapped over
            activation="relu",
        ):
            super().__init__()
            # Fill in the fairseq config; assignment order matches the
            # original construction sequence.
            config = fairseq_transformer.TransformerConfig()
            config.decoder.embed_dim = embed_dim
            config.decoder.output_dim = embed_dim
            config.decoder.attention_heads = attention_heads
            config.decoder.ffn_embed_dim = ffn_embed_dim
            config.dropout = dropout
            config.decoder.normalize_before = normalize_before
            config.decoder.layers = num_layers
            # Make embedding behavior the same as the other encoders.
            config.no_token_positional_embeddings = True
            config.no_scale_embedding = True
            config.activation_fn = activation

            vocab = {}  # TODO: verify what the decoder expects here

            decoder = fairseq_transformer.TransformerDecoder(
                config,
                vocab,
                embedding_layer,
                no_encoder_attn=True,
                output_projection=None,
            )
            if torch_encoder is not None:
                decoder = torch_to_fairseq(torch_encoder, decoder)
            self.decoder = decoder.eval().cuda().half()

        def forward(
            self,
            tokens,
            src_lengths=None,
            with_triangle_mask=False,
            incremental_state=None,
        ):
            """Run the wrapped decoder on ``tokens`` and return the first element of its output.

            ``full_context_alignment`` is passed as the negation of
            ``with_triangle_mask``. No encoder output is supplied.
            """
            decoder_outputs = self.decoder(
                prev_output_tokens=tokens,
                encoder_out=None,
                incremental_state=incremental_state,
                features_only=True,
                full_context_alignment=not with_triangle_mask,
                alignment_layer=None,
                alignment_heads=None,
                src_lengths=src_lengths,
                return_all_hiddens=False,
            )
            features = decoder_outputs[0]
            return features
@parametrize("input_dim,attn_mask_dim,is_causal",
[(3, None, False), (3, 2, False), (3, 2, True), (3, 3, False), (3, 3, True),
(4, None, False), (4, 2, False), (4, 2, True), (4, 4, False), (4, 4, True)],
name_fn=lambda input_dim, attn_dim, is_causal: (
f"{input_dim}D_input_dim_" + (
f"{attn_dim}D_{'causal_' if is_causal else ''}attn_mask"
if attn_dim is not None else "no_attn_mask")))
@parametrize("dropout_p", [0.0, 0.2, 0.5])
@sdp_kernel(enable_flash=False, enable_mem_efficient=False)
def test_scaled_dot_product_attention(self, device, input_dim, attn_mask_dim, is_causal, dropout_p):
def sdp_ref(
q,
k,
v,
attn_mask=None,
dropout_p=0.0):
E = q.size(-1)
q = q / math.sqrt(E)
# (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
if attn_mask is not None:
attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1))
else:
attn = torch.bmm(q, k.transpose(-2, -1))
attn = torch.nn.functional.softmax(attn, dim=-1)
if dropout_p > 0.0:
attn = torch.nn.functional.dropout(attn, p=dropout_p)
# (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
output = torch.bmm(attn, v)
return output
# TODO: Support cross-device / dtype testing properly when instantiate_device_type_tests() is used.
dtypes = [torch.double, torch.float]
for dtype in dtypes:
def rand_tensor(*shape):
return torch.randn(shape, device=device, dtype=dtype)
# This test compares python and C++ implementations of SDP.
N, N_prime, L, S, E = 5, 2, 4, 3, 6
if input_dim == 3:
query = rand_tensor(N, L, E)
key = rand_tensor(N, S, E)
value = rand_tensor(N, S, E)
elif input_dim == 4:
query = rand_tensor(N, N_prime, L, E)
key = rand_tensor(N, N_prime, S, E)
value = rand_tensor(N, N_prime, S, E)
else:
self.fail(f'Invalid input_dim {input_dim} encountered in SDP test')
attn_mask = None
if attn_mask_dim is not None:
assert attn_mask_dim in [2, input_dim]
mask_size = (L, S) if attn_mask_dim == 2 else ((N, L, S) if input_dim == 3 else (N, N_prime, L, S))
attn_mask = (torch.ones(mask_size, device=device, dtype=torch.bool).tril() if is_causal
else torch.randint(0, 2, size=mask_size, device=device, dtype=torch.bool))
with freeze_rng_state():
# Python impl only supports float mask and 3D inputs.
attn_mask_float = attn_mask
if attn_mask_float is not None:
attn_mask_float = torch.zeros_like(attn_mask, dtype=query.dtype)
attn_mask_float.masked_fill_(attn_mask.logical_not(), float("-inf"))
q, k, v = query.view(-1, L, E), key.view(-1, S, E), value.view(-1, S, E)
a = attn_mask_float
if a is not None and attn_mask_dim > 3:
a = a.view(-1, L, S)
expected = sdp_ref(q, k, v, attn_mask=a, dropout_p=dropout_p)
if input_dim > 3:
expected = expected.view(-1, N_prime, L, E)
with freeze_rng_state():
if is_causal: