
Commit c9f7d7b

iotamudelta authored and facebook-github-bot committed
mark unit tests as working, skip failing unit test (pytorch#12313)
Summary:

* enabled fp16 tests for test_torch
* enabled fp16 tests for test_nn
* enabled multilabelmargin loss for fp16
* removed skip for test_pdist_empty_col
* enabled test_nn tests that pass with compiler fixes etc.
* enabled test_legacy_nn tests that pass with compiler fixes etc.

cc ezyang bddppq

Pull Request resolved: pytorch#12313
Differential Revision: D10189922
Pulled By: bddppq
fbshipit-source-id: a5592817c04b14e355cb062d42ebea406f0c92b6
1 parent 8c64655 commit c9f7d7b
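
For context on the mechanics of this change: entries in test/common_nn.py are plain dicts consumed by the test generators, and an optional test_cuda key gates whether a CUDA variant of each test is emitted. Below is a minimal sketch of that gating, assuming the generator reads the key with a default of True; the Softsign entry is taken from this diff, but the surrounding logic is illustrative, not PyTorch's actual harness.

import os

# Assumption: TEST_WITH_ROCM is driven by an environment variable,
# as in PyTorch's test/common_utils.py.
TEST_WITH_ROCM = os.getenv('PYTORCH_TEST_WITH_ROCM', '0') == '1'

# Before this commit the entry carried an explicit flag,
#     test_cuda=(not TEST_WITH_ROCM),
# which suppressed the CUDA variant on ROCm builds. With the key
# removed, the lookup below falls back to True, so the CUDA variant
# now runs under ROCm as well.
entry = dict(
    module_name='Softsign',
    input_size=(3, 2, 5),
)

run_cuda_variant = entry.get('test_cuda', True)
print(run_cuda_variant)  # True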

File tree

3 files changed: +18 −85 lines

test/common_nn.py

-19
@@ -41,7 +41,6 @@ def get_weight(m):
         constructor_args=(10, 8),
         input_size=(4, 10),
         reference_fn=lambda i, p: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8),
-        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='Linear',
@@ -103,35 +102,30 @@ def get_weight(m):
         constructor_args=(1,),
         input_size=(10, 20),
         reference_fn=lambda i, _: torch.exp(i).div(torch.exp(i).sum(1, True).expand(10, 20)),
-        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='Softmax2d',
         input_size=(1, 3, 10, 20),
         reference_fn=lambda i, _: torch.exp(i).div(torch.exp(i).sum(1, False)),
-        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='LogSoftmax',
         constructor_args=(1,),
         input_size=(10, 20),
         reference_fn=lambda i, _: torch.exp(i).div_(torch.exp(i).sum(1, True).expand(10, 20)).log_(),
-        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='LogSoftmax',
         constructor_args=(1,),
         input_size=(1, 3, 10, 20),
         reference_fn=lambda i, _: torch.exp(i).div_(torch.exp(i).sum(1, False)).log_(),
         desc='multiparam',
-        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='ELU',
         constructor_args=(2.,),
         input_size=(3, 2, 5),
         reference_fn=lambda x, _: torch.where(x >= 0, x, 2 * (x.exp() - 1)),
-        test_cuda=(not TEST_WITH_ROCM),
     ),
     # TODO: reference function
     dict(
@@ -204,7 +198,6 @@ def get_weight(m):
         input_size=(2, 3, 4),
         desc='1d_multiparam',
         reference_fn=lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
-        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='PReLU',
@@ -218,7 +211,6 @@ def get_weight(m):
         input_size=(2, 3, 4, 5),
         desc='2d_multiparam',
         reference_fn=lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
-        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='PReLU',
@@ -232,31 +224,26 @@ def get_weight(m):
         input_size=(2, 3, 4, 5, 6),
         desc='3d_multiparam',
         reference_fn=lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
-        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='Softsign',
         input_size=(3, 2, 5),
         reference_fn=lambda i, _: i.div(1 + torch.abs(i)),
-        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='Softmin',
         constructor_args=(1,),
         input_size=(10, 20),
-        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='Softmin',
         constructor_args=(1,),
         input_size=(2, 3, 5, 10),
         desc='multidim',
-        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='Tanhshrink',
         input_size=(2, 3, 4, 5),
-        test_cuda=(not TEST_WITH_ROCM)
     ),
 ]

@@ -573,7 +560,6 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
         reference_fn=lambda i, t, m:
             kldivloss_reference(i, t, get_reduction(m)),
         check_sum_reduction=True,
-        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='MSELoss',
@@ -590,7 +576,6 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
         reference_fn=lambda i, t, m: -(t * i.log() + (1 - t) * (1 - i).log()).sum() /
             (i.numel() if get_reduction(m) else 1),
         check_gradgrad=False,
-        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='BCELoss',
@@ -601,7 +586,6 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
             (i.numel() if get_reduction(m) else 1),
         desc='weights',
         check_gradgrad=False,
-        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='CrossEntropyLoss',
@@ -660,7 +644,6 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
         target_fn=lambda: torch.rand(5, 10).mul(2).floor(),
         reference_fn=lambda i, t, m: -(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()).sum() / i.numel(),
         check_gradgrad=False,
-        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='MultiMarginLoss',
@@ -759,7 +742,6 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
         reference_fn=lambda i, t, m:
             marginrankingloss_reference(i[0], i[1], t, reduction=get_reduction(m)),
         check_sum_reduction=True,
-        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='MarginRankingLoss',
@@ -770,7 +752,6 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
             marginrankingloss_reference(i[0], i[1], t, margin=0.5, reduction=get_reduction(m)),
         desc='margin',
         check_sum_reduction=True,
-        test_cuda=(not TEST_WITH_ROCM)
     ),
 ]
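
The summary also mentions enabling the multilabelmargin loss for fp16. As a hedged illustration of the kind of coverage that implies (this snippet is not part of the commit), nn.MultiLabelMarginLoss can be exercised forward and backward on half-precision tensors on a CUDA or ROCm device:

import torch
import torch.nn as nn

if torch.cuda.is_available():
    loss = nn.MultiLabelMarginLoss()
    # (N, C) scores in half precision; targets are class indices
    # padded with -1, as MultiLabelMarginLoss expects.
    x = torch.randn(3, 4, device='cuda', dtype=torch.half, requires_grad=True)
    y = torch.tensor([[3, 0, -1, -1]] * 3, device='cuda')
    out = loss(x, y)  # scalar loss under the default reduction
    out.backward()
    print(out.item())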
