Remove duplicate normalisation in FFT methods and enable relevant tests #1469

Open · wants to merge 3 commits into main
98 changes: 28 additions & 70 deletions src/ATen/native/xpu/mkl/SpectralOps.cpp
@@ -57,18 +57,10 @@ void _mkl_dft(
int64_t idist = istrides[0];
int64_t odist = ostrides[0];

std::vector<int64_t> fwd_strides(1 + signal_ndim, 0),
bwd_strides(1 + signal_ndim, 0);

for (int64_t i = 1; i <= signal_ndim; i++) {
if (!inverse) {
fwd_strides[i] = istrides[i];
bwd_strides[i] = ostrides[i];
} else {
fwd_strides[i] = ostrides[i];
bwd_strides[i] = istrides[i];
}
}
std::vector<int64_t> fwd_strides(istrides.cbegin(), istrides.cbegin() + signal_ndim + 1),
bwd_strides(ostrides.cbegin(), ostrides.cbegin() + signal_ndim + 1);
fwd_strides[0] = 0;
bwd_strides[0] = 0;

auto desc = descriptor<prec, signal_type>(mkl_signal_sizes);
desc.set_value(config_param::PLACEMENT, config_value::NOT_INPLACE);
@@ -77,16 +69,15 @@
if (!inverse) {
desc.set_value(config_param::FWD_DISTANCE, idist);
desc.set_value(config_param::BWD_DISTANCE, odist);

desc.set_value(config_param::FWD_STRIDES, fwd_strides.data());
desc.set_value(config_param::BWD_STRIDES, bwd_strides.data());
} else {
desc.set_value(config_param::FWD_DISTANCE, odist);
desc.set_value(config_param::BWD_DISTANCE, idist);
}

if (!fwd_strides.empty()) {
desc.set_value(config_param::FWD_STRIDES, fwd_strides.data());
}
if (!bwd_strides.empty()) {
desc.set_value(config_param::BWD_STRIDES, bwd_strides.data());
desc.set_value(config_param::FWD_STRIDES, bwd_strides.data());
desc.set_value(config_param::BWD_STRIDES, fwd_strides.data());
}

if (!complex_input || !complex_output) {
@@ -136,10 +127,10 @@ void _fft_with_size(
// real/imag dimensions must be aligned when viewed as a complex type

if (complex_input) {
bool need_contiguous = input_.stride(-1) != 1;

const auto strides = input_.strides();
bool need_contiguous = strides.back() != 1;
for (int64_t i = 0; !need_contiguous && i <= signal_ndim; i++) {
need_contiguous |= input_.stride(i) % 2 != 0;
need_contiguous |= strides[i] % 2;
}

if (need_contiguous) {
@@ -230,12 +221,13 @@ Tensor& _exec_fft(
batched_sizes.begin() + 1);
input = input.reshape(batched_sizes);

const auto batch_size = input.sizes()[0];
const auto in_sizes = input.sizes();
const auto batch_size = in_sizes[0];
DimVector signal_size(signal_ndim + 1);
signal_size[0] = batch_size;

for (const auto i : c10::irange(signal_ndim)) {
auto in_size = input.sizes()[i + 1];
auto in_size = in_sizes[i + 1];
auto out_size = out_sizes[dim[i]];
signal_size[i + 1] = std::max(in_size, out_size);
TORCH_INTERNAL_ASSERT(
@@ -272,12 +264,12 @@ Tensor& _exec_fft(
int64_t batch_numel = 1;

for (int64_t i = batch_dims - 1; i >= 0; --i) {
out_strides[dim_permute[i]] = batch_numel * out.strides()[0];
out_strides[dim_permute[i]] = batch_numel * out.stride(0);
batch_numel *= out_sizes[dim_permute[i]];
}

for (const auto i : c10::irange(batch_dims, ndim)) {
out_strides[dim_permute[i]] = out.strides()[1 + (i - batch_dims)];
out_strides[dim_permute[i]] = out.stride(1 + (i - batch_dims));
}

out.as_strided_(out_sizes, out_strides, out.storage_offset());
@@ -287,8 +279,7 @@

double _dft_scale(
IntArrayRef dim,
IntArrayRef input_sizes,
IntArrayRef out_sizes,
IntArrayRef norm_sizes,
int64_t normalization) {
const auto norm = static_cast<fft_norm_mode>(normalization);
double double_scale = 1.0;
@@ -297,21 +288,10 @@ double _dft_scale(
return double_scale;
}

const int64_t signal_ndim = dim.size();
int64_t signal_numel = 1;

for (int64_t i = 0; i < signal_ndim; ++i) {
auto in_size = input_sizes[dim[i]];
auto out_size = out_sizes[dim[i]];
auto signal_size = std::max(in_size, out_size);

signal_numel *= signal_size;
TORCH_INTERNAL_ASSERT(
in_size == signal_size || in_size == (signal_size / 2) + 1);
TORCH_INTERNAL_ASSERT(
out_size == signal_size || out_size == (signal_size / 2) + 1);
for (const int64_t& d : dim) {
signal_numel *= norm_sizes[d];
}

if (norm == fft_norm_mode::by_root_n) {
double_scale = 1.0 / std::sqrt(signal_numel);
} else {
@@ -324,22 +304,12 @@
const Tensor& _fft_apply_normalization(
const Tensor& self,
int64_t normalization,
IntArrayRef sizes,
IntArrayRef norm_sizes,
IntArrayRef dims) {
auto scale = _dft_scale(dims, sizes, self.sizes(), normalization);
auto scale = _dft_scale(dims, norm_sizes, normalization);
return (scale == 1.0) ? self : self.mul_(scale);
}

Tensor& _fft_apply_normalization_out(
Tensor& out,
const Tensor& self,
int64_t normalization,
IntArrayRef sizes,
IntArrayRef dims) {
auto scale = _dft_scale(dims, sizes, self.sizes(), normalization);
return at::mul_out(out, self, c10::scalar_to_tensor(scale));
}

} // namespace impl

Tensor _fft_c2c_mkl(
@@ -399,8 +369,8 @@ Tensor& _fft_c2c_mkl_out(
auto result = _fft_c2c_mkl(
self, dim, static_cast<int64_t>(fft_norm_mode::none), forward);
at::native::resize_output(out, result.sizes());
return impl::_fft_apply_normalization_out(
out, result, normalization, result.sizes(), dim);
out.copy_(result);
return out;
}

void HermitSymmImpl(Tensor& input, int64_t dim, int pos) {
@@ -475,8 +445,8 @@ Tensor& _fft_c2r_mkl_out(
auto result = _fft_c2r_mkl(
self, dim, static_cast<int64_t>(fft_norm_mode::none), last_dim_size);
at::native::resize_output(out, result.sizes());
return impl::_fft_apply_normalization_out(
out, result, normalization, result.sizes(), dim);
out.copy_(result);
return out;
}

REGISTER_XPU_DISPATCH(
@@ -573,20 +543,8 @@ Tensor& _fft_r2c_mkl_out(
auto result = _fft_r2c_mkl(
self, dim, static_cast<int64_t>(fft_norm_mode::none), /*onesided=*/true);

if (onesided) {
return impl::_fft_apply_normalization_out(
out, result, normalization, self.sizes(), dim);
}

at::native::resize_output(out, self.sizes());

auto last_dim = dim.back();
auto last_dim_halfsize = result.sizes()[last_dim];
auto out_slice = out.slice(last_dim, 0, last_dim_halfsize);

impl::_fft_apply_normalization_out(
out_slice, result, normalization, self.sizes(), dim);
at::native::_fft_fill_with_conjugate_symmetry_(out, dim);
at::native::resize_output(out, result.sizes());
out.copy_(result);
return out;
}

8 changes: 5 additions & 3 deletions src/ATen/native/xpu/sycl/FFTKernelFunctor.cpp
@@ -28,11 +28,13 @@ struct HermitianSymmetryOffsetCalculator {
TORCH_INTERNAL_ASSERT(sizes.size() <= XPU_MAX_TENSORINFO_DIMS);
dims = sizes.size();

for (dim_type i = 0; i < XPU_MAX_TENSORINFO_DIMS; ++i) {
if (i < dims) {
{
dim_type i;
for (i = 0; i < dims; ++i) {
sizes_[i] = at::detail::IntDivider<index_t>(sizes[i]);
strides_[i] = strides[i] / element_size;
} else {
}
for (; i < XPU_MAX_TENSORINFO_DIMS; ++i) {
sizes_[i] = at::detail::IntDivider<index_t>(1);
strides_[i] = 0;
}
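
The restructured constructor above replaces a per-slot branch with two straight loops: the first dims entries copy the real sizes and element strides, and the remaining slots up to XPU_MAX_TENSORINFO_DIMS are padded with size 1 and stride 0. A minimal sketch of the same fill-then-pad pattern, using simplified stand-in names (kMaxDims, fill_then_pad) rather than the actual kernel types:

#include <array>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr std::size_t kMaxDims = 12; // stand-in for XPU_MAX_TENSORINFO_DIMS

// Fill the used slots from the tensor metadata, then pad the rest so the
// offset calculator can index all kMaxDims entries unconditionally.
void fill_then_pad(
    const std::vector<int64_t>& sizes,
    const std::vector<int64_t>& strides,
    int64_t element_size,
    std::array<int64_t, kMaxDims>& sizes_,
    std::array<int64_t, kMaxDims>& strides_) {
  std::size_t i = 0;
  for (; i < sizes.size(); ++i) {
    sizes_[i] = sizes[i];
    strides_[i] = strides[i] / element_size; // byte strides to element strides
  }
  for (; i < kMaxDims; ++i) {
    sizes_[i] = 1;   // size-1 padding
    strides_[i] = 0; // padded dims contribute no offset
  }
}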
19 changes: 0 additions & 19 deletions test/xpu/skip_list_common.py
@@ -635,25 +635,6 @@
"test_python_ref_torch_fallback__refs_div_trunc_rounding_xpu_float64",
# TODO: passed from source code building version, investigate
"test_python_ref__refs_log2_xpu_complex128",
# The following dtypes did not work in backward but are listed by the OpInfo: {torch.bfloat16}.
"test_dtypes_fft_fft2_xpu",
"test_dtypes_fft_fft_xpu",
"test_dtypes_fft_fftn_xpu",
"test_dtypes_fft_hfft2_xpu",
"test_dtypes_fft_hfft_xpu",
"test_dtypes_fft_hfftn_xpu",
"test_dtypes_fft_ifft2_xpu",
"test_dtypes_fft_ifft_xpu",
"test_dtypes_fft_ifftn_xpu",
"test_dtypes_fft_ihfft2_xpu",
"test_dtypes_fft_ihfft_xpu",
"test_dtypes_fft_ihfftn_xpu",
"test_dtypes_fft_irfft2_xpu",
"test_dtypes_fft_irfft_xpu",
"test_dtypes_fft_irfftn_xpu",
"test_dtypes_fft_rfft2_xpu",
"test_dtypes_fft_rfft_xpu",
"test_dtypes_fft_rfftn_xpu",
),
"test_binary_ufuncs_xpu.py": (
"test_fmod_remainder_by_zero_integral_xpu_int64", # zero division is an undefined behavior: different handles on different backends
6 changes: 6 additions & 0 deletions yaml/xpu_functions.yaml
@@ -747,3 +747,9 @@ supported:
- take.out
- segment_reduce
- _segment_reduce_backward
- _fft_c2c
- _fft_c2c.out
- _fft_c2r
- _fft_c2r.out
- _fft_r2c
- _fft_r2c.out