
Commit 336f231

[release/2.8] fix miopen batchnorm changing output format (#2602)
Cherry-pick of pytorch#162112.
Parent: b2b161e · Commit: 336f231
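Note: a minimal repro sketch of the symptom this commit targets (my illustration, assuming a ROCm build where BatchNorm2d routes through MIOpen with the NHWC batchnorm path enabled; not part of the commit itself):

```python
import torch
import torch.nn as nn

# Illustration only, assuming a ROCm build with the MIOpen NHWC batchnorm
# path enabled: before this fix, y could come back NCHW-contiguous even
# though x is channels_last. After the fix, the input layout is preserved.
bn = nn.BatchNorm2d(8).cuda().to(memory_format=torch.channels_last)
x = torch.randn(2, 8, 4, 4, device="cuda").to(memory_format=torch.channels_last)
y = bn(x)
assert y.is_contiguous(memory_format=torch.channels_last)
```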

File tree (3 files changed: +20 −36 lines)

aten/src/ATen/native/miopen/BatchNorm_miopen.cpp
test/nn/test_convolution.py
test/test_nn.py

aten/src/ATen/native/miopen/BatchNorm_miopen.cpp

Lines changed: 8 additions & 14 deletions
@@ -7,6 +7,7 @@
 #include <ATen/NativeFunctions.h>
 #else
 #include <ATen/ops/empty.h>
+#include <ATen/ops/empty_like.h>
 #include <ATen/ops/miopen_batch_norm_native.h>
 #include <ATen/ops/miopen_batch_norm_backward_native.h>
 #endif
@@ -102,7 +103,7 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
     mode = miopenBNSpatial;
   }
 
-  auto output_t = at::empty(input->sizes(), input->options(), input->suggest_memory_format());
+  auto output_t = at::empty_like(input_t, input_t.options(), input_t.suggest_memory_format());
   TensorArg output{ output_t, "output", 0 };
 
   auto handle = getMiopenHandle();
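Note: a small Python analogue of the allocation change above (my sketch, not part of the commit). The C++ passes the input's `suggest_memory_format()` explicitly; in Python, `empty_like` preserves the layout by default, while a plain sizes-based `empty` comes back NCHW-contiguous:

```python
import torch

# Sketch of the intent of the C++ change (illustration only): allocate the
# output with the input's memory format so a channels_last input yields a
# channels_last output, instead of a default-contiguous buffer.
x = torch.randn(2, 8, 4, 4).to(memory_format=torch.channels_last)

out_like = torch.empty_like(x)    # preserves the NHWC strides
out_plain = torch.empty(x.shape)  # defaults to NCHW-contiguous
assert out_like.is_contiguous(memory_format=torch.channels_last)
assert not out_plain.is_contiguous(memory_format=torch.channels_last)
```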
@@ -170,22 +171,15 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
     const std::optional<Tensor>& save_var_t_opt,
     double epsilon) {
   // See [Note: hacky wrapper removal for optional tensor]
-  const Tensor& running_mean =
-      running_mean_opt.value_or(Tensor());
-  const Tensor& running_var =
-      running_var_opt.value_or(Tensor());
-  const Tensor& save_mean_t =
-      save_mean_t_opt.value_or(Tensor());
-  const Tensor& save_var_t =
-      save_var_t_opt.value_or(Tensor());
+  const Tensor& save_mean_t = save_mean_t_opt.value_or(Tensor());
+  const Tensor& save_var_t = save_var_t_opt.value_or(Tensor());
 
   auto grad_output_contig =
       grad_output_t.contiguous(input_t.suggest_memory_format());
-  TensorArg input{ input_t, "input", 1 },
-            grad_output{ grad_output_contig, "grad_output", 2 },
-            weight{ weight_t, "weight", 3 },
-            save_mean{ save_mean_t, "save_mean", 4 },
-            save_var{ save_var_t, "save_var", 5 };
+  TensorArg input{input_t, "input", 1},
+      grad_output{grad_output_contig, "grad_output", 2},
+      weight{weight_t, "weight", 3}, save_mean{save_mean_t, "save_mean", 4},
+      save_var{save_var_t, "save_var", 5};
   CheckedFrom c = "miopen_batch_norm_backward";
 
   checkAllDefined(c, {input, grad_output, weight, save_mean, save_var});

test/nn/test_convolution.py

Lines changed: 9 additions & 20 deletions
@@ -30,7 +30,6 @@
     skipCUDAIfMiopen,
     skipCUDAIfNoCudnn,
     skipCUDAIfNoMiopen,
-    skipCUDAIfNotMiopenSuggestNHWC,
     skipCUDAIfRocm,
     skipMeta,
     skipMPS,
@@ -52,9 +51,7 @@
     parametrize as parametrize_test,
     run_tests,
     set_default_dtype,
-    skipIfNotMiopenSuggestNHWC,
     skipIfRocmArch,
-    skipIfRocmVersionLessThan,
     subtest,
     TEST_SCIPY,
     TEST_WITH_ROCM,
@@ -66,6 +63,7 @@
 
 if TEST_WITH_ROCM:
     os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1"
+    os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM"] = "1"
 
 
 if TEST_SCIPY:
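Note: both test files opt the ROCm runs into the NHWC batchnorm path at import time. A hedged sketch of the pattern (variable names are from the diff; exactly when MIOpen reads them is backend-internal, so setting them before importing torch is the safe order):

```python
import os

# Set before any MIOpen work happens; both test files do this at module
# import time when running under ROCm.
os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1"
os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM"] = "1"

import torch  # noqa: E402  (import after the env vars are in place)
```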
@@ -717,7 +715,6 @@ def test_ConvTranspose2d_half_cublas_gemm(self):
     # Almost identical to the above `test_Conv2d_naive_groups`
     @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
     @tf32_on_and_off(0.001)
-    @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
     def test_Conv2d_groups_nobias(self):
         dev_dtypes = [("cpu", torch.float)]
         if TEST_CUDA:
@@ -763,7 +760,6 @@ def test_Conv2d_groups_nobias(self):
     # and https://github.com/pytorch/pytorch/pull/18463#issuecomment-477001024
     @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
     @tf32_on_and_off(0.001)
-    @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
     def test_Conv2d_groups_nobias_v2(self):
         torch.manual_seed(123)
         dev_dtypes = [("cpu", torch.float)]
@@ -898,7 +894,6 @@ def test_conv_tbc(self):
 
     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
-    @skipIfNotMiopenSuggestNHWC
     def test_grouped_conv_cudnn_nhwc_support(self):
         # in order to catch the hols in grouped convolution in nhwc support for earlier cudnn version
         input = torch.randn((16, 16, 8, 8), dtype=torch.float16, device="cuda").to(
@@ -3147,7 +3142,6 @@ def test_conv_noncontig_weights_and_bias(self, device):
 
     @onlyCUDA
     @largeTensorTest("12GB")
-    @skipIfRocmVersionLessThan((6, 0))
     def test_conv_transposed_large(self, device):
         dtype = torch.half if self.device_type == "cuda" else torch.float
         conv = nn.ConvTranspose2d(1, 1, 1, 1, bias=False).to(device).to(dtype)
@@ -3191,7 +3185,6 @@ def test_conv_transposed_large(self, device):
         self.assertEqual(maxdiff3, 0)
 
     @onlyCUDA
-    @skipCUDAIfRocm
     @largeTensorTest("12GB")
     def test_conv_large(self, device):
         dtype = torch.half if self.device_type == "cuda" else torch.float
@@ -3224,7 +3217,6 @@ def test_conv_large(self, device):
         self.assertEqual(grad1, grad2, atol=5e-2, rtol=5e-3)
 
     @onlyCUDA
-    @skipCUDAIfRocm
     @largeTensorTest("20GB", "cpu")
     @largeTensorTest("60GB", "cuda")
     def test_conv_large_batch_1(self, device):
@@ -3372,7 +3364,6 @@ def test_ConvTranspose3d_size_1_kernel(self, device):
     @dtypes(torch.float)
     @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
     @tf32_on_and_off(0.001)
-    @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
     def test_Conv2d_naive_groups(self, device, dtype):
         # Check that grouped convolutions matches two half convolutions
         m = nn.Conv2d(4, 4, kernel_size=3, groups=2).to(device, dtype)
@@ -3641,19 +3632,21 @@ def helper(
         )
 
     @onlyCUDA
-    @skipCUDAIfNotMiopenSuggestNHWC
     @dtypes(torch.half, torch.float, torch.cfloat)
     def test_conv_cudnn_nhwc(self, device, dtype):
         def helper(n, c, h, w, out_channels, kernel_size, groups):
-            input = torch.randint(-3, 3, (n, c, h, w), dtype=dtype, device=device).to(
-                memory_format=torch.channels_last
-            )
+            # randint with dtype=torch.cfloat fails with
+            # RuntimeError: check_random_bounds handles only integral, floating-point and boolean types
+            # must create randint and randint_like using default int64, then cast to desired
+            input = torch.randint(
+                -3, 3, (n, c, h, w), dtype=torch.int64, device=device
+            ).to(dtype, memory_format=torch.channels_last)
             input.requires_grad_()
             conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups).to(
                 device="cuda", dtype=dtype, memory_format=torch.channels_last
             )
             for p in conv.parameters():
-                p.data = torch.randint_like(p, -3, 3)
+                p.data = torch.randint_like(p, -3, 3, dtype=torch.int64).to(p.dtype)
 
             # use FP64 channels-first conv as reference
             ref_input = input.detach().clone().contiguous().double().requires_grad_()
@@ -3667,7 +3660,7 @@ def helper(n, c, h, w, out_channels, kernel_size, groups):
             out = conv(input)
             ref_out = ref_conv(ref_input)
 
-            grad = torch.randint_like(out, -3, 3)
+            grad = torch.randint_like(out, -3, 3, dtype=torch.int64).to(out.dtype)
             ref_grad = grad.detach().clone().double().contiguous()
 
             out.backward(grad)
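Note: the comment added in the hunk above explains the int64-then-cast dance. A standalone sketch of the workaround (illustration only):

```python
import torch

# torch.randint rejects complex dtypes ("check_random_bounds handles only
# integral, floating-point and boolean types"), so draw int64 values first,
# then cast; the same .to() call can also set the memory format.
x = torch.randint(-3, 3, (2, 4, 8, 8), dtype=torch.int64).to(
    torch.cfloat, memory_format=torch.channels_last
)
assert x.dtype == torch.cfloat
assert x.is_contiguous(memory_format=torch.channels_last)
```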
@@ -3694,7 +3687,6 @@ def helper(n, c, h, w, out_channels, kernel_size, groups):
             helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=16)
 
     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.half, torch.float)
     def test_conv_cudnn_ndhwc(self, device, dtype):
         def helper(n, c, d, h, w, out_channels, kernel_size, groups):
@@ -3824,7 +3816,6 @@ def _test_conv_cudnn_nhwc_nchw(self, layer, n, c, h, w, k, filter_size, device):
         )
 
     @onlyCUDA
-    @skipCUDAIfNotMiopenSuggestNHWC
     @tf32_on_and_off(0.05)
     def test_conv_cudnn_mismatch_memory_format(self, device):
         configs = [
@@ -3958,7 +3949,6 @@ def test_cudnn_convolution_add_relu(self, device, dtype):
         self.assertEqual(F.relu(conv2d_out + alpha * z), cudnn_out)
 
     @onlyCUDA
-    @skipCUDAIfRocm
     def test_convert_conv2d_weight_memory_format(self, device):
         input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device=device)
         model = nn.Sequential(nn.Conv2d(8, 4, 3), nn.BatchNorm2d(4)).to(device).float()
@@ -3978,7 +3968,6 @@ def test_convert_conv2d_weight_memory_format(self, device):
         self.assertTrue(out.is_contiguous(memory_format=memory_format))
 
     @onlyCUDA
-    @skipCUDAIfRocm
     def test_convert_conv3d_weight_memory_format(self, device):
         input = torch.randint(
             1, 10, (2, 8, 4, 4, 4), dtype=torch.float32, device=device

test/test_nn.py

Lines changed: 3 additions & 2 deletions
@@ -61,6 +61,7 @@
 
 if TEST_WITH_ROCM:
     os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1"
+    os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM"] = "1"
 
 # load_tests from common_utils is used to automatically filter tests for
 # sharding on sandcastle. This line silences flake warnings
@@ -3496,15 +3497,15 @@ def test_cudnn_forward_exception(self):
         self.assertRaisesRegex(RuntimeError, re.escape("input.size(-1) must be equal to input_size"), rnn, x_wrong)
 
     @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
-    @skipIfRocm
     def test_cudnn_weight_format(self):
         rnns = [
             nn.LSTM(10, 20, batch_first=True),
             nn.LSTM(10, 20, batch_first=True, proj_size=10),
             nn.GRU(10, 20, batch_first=True),
             nn.RNN(10, 20, batch_first=True)
         ]
-        first_warn = True
+        # ROCm RNN does not issue warning about single contig chunk of memory, so don't assert it
+        first_warn = False if torch.version.hip else True
         for rnn in rnns:
             rnn.cuda()
             input = torch.randn(5, 4, 10, requires_grad=True, device="cuda")
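Note: `torch.version.hip` is the gate used above; it is a version string on ROCm builds and `None` on CUDA builds, so its truthiness alone distinguishes the backends. A sketch of the pattern (illustration only):

```python
import torch

# torch.version.hip is a string on ROCm wheels and None on CUDA wheels.
first_warn = False if torch.version.hip else True  # as written in the diff
first_warn = torch.version.hip is None             # equivalent, more direct
```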
