22 changes: 8 additions & 14 deletions aten/src/ATen/native/miopen/BatchNorm_miopen.cpp
@@ -7,6 +7,7 @@
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/miopen_batch_norm_native.h>
#include <ATen/ops/miopen_batch_norm_backward_native.h>
#endif
@@ -102,7 +103,7 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
mode = miopenBNSpatial;
}

auto output_t = at::empty(input->sizes(), input->options(), input->suggest_memory_format());
auto output_t = at::empty_like(input_t, input_t.options(), input_t.suggest_memory_format());
TensorArg output{ output_t, "output", 0 };

auto handle = getMiopenHandle();
@@ -170,22 +171,15 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
const std::optional<Tensor>& save_var_t_opt,
double epsilon) {
// See [Note: hacky wrapper removal for optional tensor]
const Tensor& running_mean =
running_mean_opt.value_or(Tensor());
const Tensor& running_var =
running_var_opt.value_or(Tensor());
const Tensor& save_mean_t =
save_mean_t_opt.value_or(Tensor());
const Tensor& save_var_t =
save_var_t_opt.value_or(Tensor());
const Tensor& save_mean_t = save_mean_t_opt.value_or(Tensor());
const Tensor& save_var_t = save_var_t_opt.value_or(Tensor());

auto grad_output_contig =
grad_output_t.contiguous(input_t.suggest_memory_format());
TensorArg input{ input_t, "input", 1 },
grad_output{ grad_output_contig, "grad_output", 2 },
weight{ weight_t, "weight", 3 },
save_mean{ save_mean_t, "save_mean", 4 },
save_var{ save_var_t, "save_var", 5 };
TensorArg input{input_t, "input", 1},
grad_output{grad_output_contig, "grad_output", 2},
weight{weight_t, "weight", 3}, save_mean{save_mean_t, "save_mean", 4},
save_var{save_var_t, "save_var", 5};
CheckedFrom c = "miopen_batch_norm_backward";

checkAllDefined(c, {input, grad_output, weight, save_mean, save_var});
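The allocation change above is the core of the file: switching from at::empty(input->sizes(), ...) to at::empty_like(input_t, input_t.options(), input_t.suggest_memory_format()) lets the MIOpen batch-norm output inherit the input's suggested layout (e.g. channels_last) instead of always coming back NCHW-contiguous. A minimal sketch of the behaviour this targets, assuming a ROCm build that dispatches BatchNorm2d to miopen_batch_norm with the NHWC batch-norm switch enabled:

import torch
import torch.nn as nn

# Assumes PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM=1 is set on a ROCm build, so the
# MIOpen batch-norm path sees a channels_last input.
x = torch.randn(8, 32, 14, 14, device="cuda").to(memory_format=torch.channels_last)
bn = nn.BatchNorm2d(32).cuda()
y = bn(x)

# With the empty_like(..., suggest_memory_format()) allocation, the output is
# expected to keep the input's channels_last layout.
print(y.is_contiguous(memory_format=torch.channels_last))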
29 changes: 9 additions & 20 deletions test/nn/test_convolution.py
@@ -30,7 +30,6 @@
skipCUDAIfMiopen,
skipCUDAIfNoCudnn,
skipCUDAIfNoMiopen,
skipCUDAIfNotMiopenSuggestNHWC,
skipCUDAIfRocm,
skipMeta,
skipMPS,
@@ -52,9 +51,7 @@
parametrize as parametrize_test,
run_tests,
set_default_dtype,
skipIfNotMiopenSuggestNHWC,
skipIfRocmArch,
skipIfRocmVersionLessThan,
subtest,
TEST_SCIPY,
TEST_WITH_ROCM,
@@ -66,6 +63,7 @@

if TEST_WITH_ROCM:
os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1"
os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM"] = "1"


if TEST_SCIPY:
@@ -717,7 +715,6 @@ def test_ConvTranspose2d_half_cublas_gemm(self):
# Almost identical to the above `test_Conv2d_naive_groups`
@torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
@tf32_on_and_off(0.001)
@unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
def test_Conv2d_groups_nobias(self):
dev_dtypes = [("cpu", torch.float)]
if TEST_CUDA:
@@ -763,7 +760,6 @@ def test_Conv2d_groups_nobias(self):
# and https://github.com/pytorch/pytorch/pull/18463#issuecomment-477001024
@torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
@tf32_on_and_off(0.001)
@unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
def test_Conv2d_groups_nobias_v2(self):
torch.manual_seed(123)
dev_dtypes = [("cpu", torch.float)]
@@ -898,7 +894,6 @@ def test_conv_tbc(self):

@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@unittest.skipIf(not TEST_CUDNN, "needs cudnn")
@skipIfNotMiopenSuggestNHWC
def test_grouped_conv_cudnn_nhwc_support(self):
# in order to catch the holes in grouped convolution nhwc support for earlier cudnn versions
input = torch.randn((16, 16, 8, 8), dtype=torch.float16, device="cuda").to(
@@ -3147,7 +3142,6 @@ def test_conv_noncontig_weights_and_bias(self, device):

@onlyCUDA
@largeTensorTest("12GB")
@skipIfRocmVersionLessThan((6, 0))
def test_conv_transposed_large(self, device):
dtype = torch.half if self.device_type == "cuda" else torch.float
conv = nn.ConvTranspose2d(1, 1, 1, 1, bias=False).to(device).to(dtype)
@@ -3191,7 +3185,6 @@ def test_conv_transposed_large(self, device):
self.assertEqual(maxdiff3, 0)

@onlyCUDA
@skipCUDAIfRocm
@largeTensorTest("12GB")
def test_conv_large(self, device):
dtype = torch.half if self.device_type == "cuda" else torch.float
@@ -3224,7 +3217,6 @@ def test_conv_large(self, device):
self.assertEqual(grad1, grad2, atol=5e-2, rtol=5e-3)

@onlyCUDA
@skipCUDAIfRocm
@largeTensorTest("20GB", "cpu")
@largeTensorTest("60GB", "cuda")
def test_conv_large_batch_1(self, device):
@@ -3372,7 +3364,6 @@ def test_ConvTranspose3d_size_1_kernel(self, device):
@dtypes(torch.float)
@torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
@tf32_on_and_off(0.001)
@unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
def test_Conv2d_naive_groups(self, device, dtype):
# Check that grouped convolutions matches two half convolutions
m = nn.Conv2d(4, 4, kernel_size=3, groups=2).to(device, dtype)
@@ -3641,19 +3632,21 @@ def helper(
)

@onlyCUDA
@skipCUDAIfNotMiopenSuggestNHWC
@dtypes(torch.half, torch.float, torch.cfloat)
def test_conv_cudnn_nhwc(self, device, dtype):
def helper(n, c, h, w, out_channels, kernel_size, groups):
input = torch.randint(-3, 3, (n, c, h, w), dtype=dtype, device=device).to(
memory_format=torch.channels_last
)
# randint with dtype=torch.cfloat fails with
# RuntimeError: check_random_bounds handles only integral, floating-point and boolean types
# must create randint and randint_like using default int64, then cast to desired
input = torch.randint(
-3, 3, (n, c, h, w), dtype=torch.int64, device=device
).to(dtype, memory_format=torch.channels_last)
input.requires_grad_()
conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups).to(
device="cuda", dtype=dtype, memory_format=torch.channels_last
)
for p in conv.parameters():
p.data = torch.randint_like(p, -3, 3)
p.data = torch.randint_like(p, -3, 3, dtype=torch.int64).to(p.dtype)

# use FP64 channels-first conv as reference
ref_input = input.detach().clone().contiguous().double().requires_grad_()
@@ -3667,7 +3660,7 @@ def helper(n, c, h, w, out_channels, kernel_size, groups):
out = conv(input)
ref_out = ref_conv(ref_input)

grad = torch.randint_like(out, -3, 3)
grad = torch.randint_like(out, -3, 3, dtype=torch.int64).to(out.dtype)
ref_grad = grad.detach().clone().double().contiguous()

out.backward(grad)
@@ -3694,7 +3687,6 @@ def helper(n, c, h, w, out_channels, kernel_size, groups):
helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=16)

@onlyCUDA
@skipCUDAIfRocm
@dtypes(torch.half, torch.float)
def test_conv_cudnn_ndhwc(self, device, dtype):
def helper(n, c, d, h, w, out_channels, kernel_size, groups):
@@ -3824,7 +3816,6 @@ def _test_conv_cudnn_nhwc_nchw(self, layer, n, c, h, w, k, filter_size, device):
)

@onlyCUDA
@skipCUDAIfNotMiopenSuggestNHWC
@tf32_on_and_off(0.05)
def test_conv_cudnn_mismatch_memory_format(self, device):
configs = [
@@ -3958,7 +3949,6 @@ def test_cudnn_convolution_add_relu(self, device, dtype):
self.assertEqual(F.relu(conv2d_out + alpha * z), cudnn_out)

@onlyCUDA
@skipCUDAIfRocm
def test_convert_conv2d_weight_memory_format(self, device):
input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device=device)
model = nn.Sequential(nn.Conv2d(8, 4, 3), nn.BatchNorm2d(4)).to(device).float()
@@ -3978,7 +3968,6 @@ def test_convert_conv2d_weight_memory_format(self, device):
self.assertTrue(out.is_contiguous(memory_format=memory_format))

@onlyCUDA
@skipCUDAIfRocm
def test_convert_conv3d_weight_memory_format(self, device):
input = torch.randint(
1, 10, (2, 8, 4, 4, 4), dtype=torch.float32, device=device
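Both test files opt into the NHWC batch-norm path the same way the existing PYTORCH_MIOPEN_SUGGEST_NHWC switch is handled: by exporting an environment variable before the MIOpen code paths are exercised. A hedged sketch of how a user could apply the same opt-in outside the test suite (variable names are taken from the diff; exactly when the backend reads them is an implementation detail, though the tests set them right after importing torch, which suggests they are consulted lazily at dispatch time):

import os

# Enable MIOpen's NHWC (channels_last) suggestions on ROCm builds; harmless elsewhere.
os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1"
os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM"] = "1"

import torch

if torch.version.hip is not None:
    print("ROCm build: MIOpen NHWC conv/batch-norm suggestions enabled")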
5 changes: 3 additions & 2 deletions test/test_nn.py
@@ -61,6 +61,7 @@

if TEST_WITH_ROCM:
os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1"
os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM"] = "1"

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
@@ -3496,15 +3497,15 @@ def test_cudnn_forward_exception(self):
self.assertRaisesRegex(RuntimeError, re.escape("input.size(-1) must be equal to input_size"), rnn, x_wrong)

@unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
@skipIfRocm
def test_cudnn_weight_format(self):
rnns = [
nn.LSTM(10, 20, batch_first=True),
nn.LSTM(10, 20, batch_first=True, proj_size=10),
nn.GRU(10, 20, batch_first=True),
nn.RNN(10, 20, batch_first=True)
]
first_warn = True
# ROCm RNN does not issue warning about single contig chunk of memory, so don't assert it
first_warn = False if torch.version.hip else True
for rnn in rnns:
rnn.cuda()
input = torch.randn(5, 4, 10, requires_grad=True, device="cuda")
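The first_warn change encodes a backend difference: cuDNN warns when RNN weights are not a single contiguous chunk of memory, while the MIOpen RNN path on ROCm does not, so the test no longer asserts that first warning on HIP builds. A small sketch of a warning-collection helper in the same spirit (this is not the test's exact assertion logic, just the gating idea):

import warnings
import torch

def weight_format_warnings(rnn, x):
    # Run one forward pass and collect any warnings mentioning the
    # non-contiguous weight chunk (the cuDNN-only warning the test checks for).
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        rnn(x)
    return [w for w in caught if "contiguous" in str(w.message)]

# Mirror of the updated gating: only expect the warning on non-HIP (cuDNN) builds.
first_warn = False if torch.version.hip else True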