@@ -41,7 +41,6 @@ def get_weight(m):
41
41
constructor_args = (10 , 8 ),
42
42
input_size = (4 , 10 ),
43
43
reference_fn = lambda i , p : torch .mm (i , p [0 ].t ()) + p [1 ].view (1 , - 1 ).expand (4 , 8 ),
44
- test_cuda = (not TEST_WITH_ROCM )
45
44
),
46
45
dict (
47
46
module_name = 'Linear' ,
@@ -103,35 +102,30 @@ def get_weight(m):
103
102
constructor_args = (1 ,),
104
103
input_size = (10 , 20 ),
105
104
reference_fn = lambda i , _ : torch .exp (i ).div (torch .exp (i ).sum (1 , True ).expand (10 , 20 )),
106
- test_cuda = (not TEST_WITH_ROCM )
107
105
),
108
106
dict (
109
107
module_name = 'Softmax2d' ,
110
108
input_size = (1 , 3 , 10 , 20 ),
111
109
reference_fn = lambda i , _ : torch .exp (i ).div (torch .exp (i ).sum (1 , False )),
112
- test_cuda = (not TEST_WITH_ROCM )
113
110
),
114
111
dict (
115
112
module_name = 'LogSoftmax' ,
116
113
constructor_args = (1 ,),
117
114
input_size = (10 , 20 ),
118
115
reference_fn = lambda i , _ : torch .exp (i ).div_ (torch .exp (i ).sum (1 , True ).expand (10 , 20 )).log_ (),
119
- test_cuda = (not TEST_WITH_ROCM )
120
116
),
121
117
dict (
122
118
module_name = 'LogSoftmax' ,
123
119
constructor_args = (1 ,),
124
120
input_size = (1 , 3 , 10 , 20 ),
125
121
reference_fn = lambda i , _ : torch .exp (i ).div_ (torch .exp (i ).sum (1 , False )).log_ (),
126
122
desc = 'multiparam' ,
127
- test_cuda = (not TEST_WITH_ROCM )
128
123
),
129
124
dict (
130
125
module_name = 'ELU' ,
131
126
constructor_args = (2. ,),
132
127
input_size = (3 , 2 , 5 ),
133
128
reference_fn = lambda x , _ : torch .where (x >= 0 , x , 2 * (x .exp () - 1 )),
134
- test_cuda = (not TEST_WITH_ROCM ),
135
129
),
136
130
# TODO: reference function
137
131
dict (
@@ -204,7 +198,6 @@ def get_weight(m):
204
198
input_size = (2 , 3 , 4 ),
205
199
desc = '1d_multiparam' ,
206
200
reference_fn = lambda i , p : torch .clamp (i , min = 0 ) + torch .clamp (i , max = 0 ) * p [0 ][0 ],
207
- test_cuda = (not TEST_WITH_ROCM )
208
201
),
209
202
dict (
210
203
module_name = 'PReLU' ,
@@ -218,7 +211,6 @@ def get_weight(m):
218
211
input_size = (2 , 3 , 4 , 5 ),
219
212
desc = '2d_multiparam' ,
220
213
reference_fn = lambda i , p : torch .clamp (i , min = 0 ) + torch .clamp (i , max = 0 ) * p [0 ][0 ],
221
- test_cuda = (not TEST_WITH_ROCM )
222
214
),
223
215
dict (
224
216
module_name = 'PReLU' ,
@@ -232,31 +224,26 @@ def get_weight(m):
232
224
input_size = (2 , 3 , 4 , 5 , 6 ),
233
225
desc = '3d_multiparam' ,
234
226
reference_fn = lambda i , p : torch .clamp (i , min = 0 ) + torch .clamp (i , max = 0 ) * p [0 ][0 ],
235
- test_cuda = (not TEST_WITH_ROCM )
236
227
),
237
228
dict (
238
229
module_name = 'Softsign' ,
239
230
input_size = (3 , 2 , 5 ),
240
231
reference_fn = lambda i , _ : i .div (1 + torch .abs (i )),
241
- test_cuda = (not TEST_WITH_ROCM )
242
232
),
243
233
dict (
244
234
module_name = 'Softmin' ,
245
235
constructor_args = (1 ,),
246
236
input_size = (10 , 20 ),
247
- test_cuda = (not TEST_WITH_ROCM )
248
237
),
249
238
dict (
250
239
module_name = 'Softmin' ,
251
240
constructor_args = (1 ,),
252
241
input_size = (2 , 3 , 5 , 10 ),
253
242
desc = 'multidim' ,
254
- test_cuda = (not TEST_WITH_ROCM )
255
243
),
256
244
dict (
257
245
module_name = 'Tanhshrink' ,
258
246
input_size = (2 , 3 , 4 , 5 ),
259
- test_cuda = (not TEST_WITH_ROCM )
260
247
),
261
248
]
262
249
@@ -573,7 +560,6 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
573
560
reference_fn = lambda i , t , m :
574
561
kldivloss_reference (i , t , get_reduction (m )),
575
562
check_sum_reduction = True ,
576
- test_cuda = (not TEST_WITH_ROCM )
577
563
),
578
564
dict (
579
565
module_name = 'MSELoss' ,
@@ -590,7 +576,6 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
590
576
reference_fn = lambda i , t , m : - (t * i .log () + (1 - t ) * (1 - i ).log ()).sum () /
591
577
(i .numel () if get_reduction (m ) else 1 ),
592
578
check_gradgrad = False ,
593
- test_cuda = (not TEST_WITH_ROCM )
594
579
),
595
580
dict (
596
581
module_name = 'BCELoss' ,
@@ -601,7 +586,6 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
601
586
(i .numel () if get_reduction (m ) else 1 ),
602
587
desc = 'weights' ,
603
588
check_gradgrad = False ,
604
- test_cuda = (not TEST_WITH_ROCM )
605
589
),
606
590
dict (
607
591
module_name = 'CrossEntropyLoss' ,
@@ -660,7 +644,6 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
660
644
target_fn = lambda : torch .rand (5 , 10 ).mul (2 ).floor (),
661
645
reference_fn = lambda i , t , m : - (t * i .sigmoid ().log () + (1 - t ) * (- i ).sigmoid ().log ()).sum () / i .numel (),
662
646
check_gradgrad = False ,
663
- test_cuda = (not TEST_WITH_ROCM )
664
647
),
665
648
dict (
666
649
module_name = 'MultiMarginLoss' ,
@@ -759,7 +742,6 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
759
742
reference_fn = lambda i , t , m :
760
743
marginrankingloss_reference (i [0 ], i [1 ], t , reduction = get_reduction (m )),
761
744
check_sum_reduction = True ,
762
- test_cuda = (not TEST_WITH_ROCM )
763
745
),
764
746
dict (
765
747
module_name = 'MarginRankingLoss' ,
@@ -770,7 +752,6 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
770
752
marginrankingloss_reference (i [0 ], i [1 ], t , margin = 0.5 , reduction = get_reduction (m )),
771
753
desc = 'margin' ,
772
754
check_sum_reduction = True ,
773
- test_cuda = (not TEST_WITH_ROCM )
774
755
),
775
756
]
776
757
0 commit comments