diff --git a/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py b/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py index 6538a418462c97..066dc7c50dcc41 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py @@ -75,8 +75,8 @@ def __init__( help_msg = ("Tracing sometimes provide better results, " "please provide valid 'example_input' argument.\n") raise RuntimeError( - f"Couldn't get TorchScript module by {msg}.\n{help_msg} " - "You can also provide TorchScript module that you obtained" + f"Couldn't get TorchScript module by {msg}.\nException:\n{e}\n" + f"{help_msg} You can also provide TorchScript module that you obtained" " yourself, please refer to PyTorch documentation: " "https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html." ) from e diff --git a/src/frontends/pytorch/src/op/fft.cpp b/src/frontends/pytorch/src/op/fft.cpp index b86b01efb21639..9e07070e568e25 100644 --- a/src/frontends/pytorch/src/op/fft.cpp +++ b/src/frontends/pytorch/src/op/fft.cpp @@ -4,9 +4,11 @@ #include "openvino/frontend/complex_type_mark.hpp" #include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/dft.hpp" #include "openvino/op/divide.hpp" #include "openvino/op/equal.hpp" #include "openvino/op/gather.hpp" +#include "openvino/op/idft.hpp" #include "openvino/op/irdft.hpp" #include "openvino/op/multiply.hpp" #include "openvino/op/range.hpp" @@ -28,18 +30,56 @@ namespace op { using namespace ov::op; -OutputVector translate_fft_rfftn(const NodeContext& context) { - // aten::fft_rfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor - num_inputs_check(context, 1, 4); - auto input = context.get_input(0); - - auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1})); +namespace { +Output normalize(const NodeContext& context, + const Output& node, + const Output& s, + const std::string& norm, + bool inverse) { + if (norm == "backward") { + // No normalization + return node; + } + // Apply normalizations auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0})); - auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1})); + auto n_int = context.mark_node(std::make_shared(s, const_0)); + auto n = context.mark_node(std::make_shared(n_int, node)); + Output normalized; + if (norm == "forward") { + // Normalize by 1/n + if (inverse) { + normalized = context.mark_node(std::make_shared(node, n)); + } else { + normalized = context.mark_node(std::make_shared(node, n)); + } + } else if (norm == "ortho") { + // Normalize by 1/sqrt(n) + auto sqrt_n = context.mark_node(std::make_shared(n)); + if (inverse) { + normalized = context.mark_node(std::make_shared(node, sqrt_n)); + } else { + normalized = context.mark_node(std::make_shared(node, sqrt_n)); + } + } else { + FRONT_END_THROW("Unrecognized normalization mode " + norm + + ". Only forward, backward and ortho are supported."); + } + return normalized; +} +std::tuple, Output> get_dim_s(const NodeContext& context, + const Output& x, + int size, + bool is_irfft) { Output input_shape; Output input_rank_scalar; - std::tie(input_shape, input_rank_scalar) = get_shape_rank(context, input, true); + std::tie(input_shape, input_rank_scalar) = get_shape_rank(context, x, true); + + auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1})); + auto const_neg_1_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1})); + auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0})); + auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1})); + auto const_2 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {2})); Output raw_s; // Inputs can be either none or List. Check whether input values should be used or should be set to default values. @@ -48,8 +88,8 @@ OutputVector translate_fft_rfftn(const NodeContext& context) { raw_s = get_input_concat_if_list(context, 1); raw_s = context.mark_node(std::make_shared(raw_s, element::i32)); } - Output dim; // Handle dim parameter containing vector of integers indicating dimensions to be transformed. + Output dim; if (!context.input_is_none(2)) { // dim is provided, load from input. dim = get_input_concat_if_list(context, 2); @@ -57,24 +97,80 @@ OutputVector translate_fft_rfftn(const NodeContext& context) { } else if (!context.input_is_none(1)) { // If dim is default and s is provided, use last s_len dimensions where s_len is length of s. auto s_len = context.mark_node(std::make_shared(raw_s, element::i32)); - auto slice_start = context.mark_node(std::make_shared(input_rank_scalar, s_len)); - auto slice_start_scalar = context.mark_node(std::make_shared(slice_start)); - dim = context.mark_node( - std::make_shared(slice_start_scalar, input_rank_scalar, const_1, element::i32)); + auto start = context.mark_node(std::make_shared(input_rank_scalar, s_len)); + auto start_scalar = context.mark_node(std::make_shared(start)); + dim = context.mark_node(std::make_shared(start_scalar, input_rank_scalar, const_1, element::i32)); } else { - // Dim and s are set to default, use all of dimensions. - dim = context.mark_node(std::make_shared(const_0, input_rank_scalar, const_1, element::i32)); + // Dim and s are set to default. + switch (size) { + case 1: + dim = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1})); + break; + case 2: + dim = context.mark_node(v0::Constant::create(element::i32, Shape{2}, {-2, -1})); + break; + case -1: + dim = context.mark_node(std::make_shared(const_0, input_rank_scalar, const_1, element::i32)); + break; + default: + FRONT_END_THROW("Invalid FFT size: " + std::to_string(size)); + } + } + if (dim.get_partial_shape().rank().is_dynamic() || dim.get_partial_shape().rank().get_length() == 0) { + dim = context.mark_node(std::make_shared(dim, const_neg_1_1d, false)); } + Output default_s; + if (is_irfft) { + // Calculate default s values. Use full available size except last element, which is set to even value in last + // dimension: s[-1] = 2 * (complex_input_shape[dim[-1]] - 1). + auto default_s_raw = context.mark_node(std::make_shared(input_shape, dim, const_0)); + auto last_s = context.mark_node(std::make_shared(default_s_raw, const_neg_1, const_0)); + auto last_s_m_1 = context.mark_node(std::make_shared(last_s, const_1)); + auto s_upd = context.mark_node(std::make_shared(last_s_m_1, const_2)); + auto s_shape = context.mark_node(std::make_shared(default_s_raw, element::i32)); + auto last_s_idx = context.mark_node(std::make_shared(s_shape, const_1)); + s_upd = context.mark_node(std::make_shared(s_upd, const_neg_1_1d, false)); + default_s = context.mark_node(std::make_shared(default_s_raw, last_s_idx, s_upd, const_0)); + } else { + default_s = context.mark_node(std::make_shared(input_shape, dim, const_0)); + } Output s; if (context.input_is_none(1)) { // Value for s was set to default, use full size for all dimensions. - s = context.mark_node(std::make_shared(input_shape, dim, const_0)); + s = default_s; } else { // Values for s were provided. Replace -1 values with default full size in given dimension. auto full_s_cond = context.mark_node(std::make_shared(raw_s, const_neg_1)); - auto full_s_values = context.mark_node(std::make_shared(input_shape, dim, const_0)); - s = context.mark_node(std::make_shared(full_s_cond, full_s_values, raw_s)); + s = context.mark_node(std::make_shared(full_s_cond, default_s, raw_s)); + } + return {dim, s}; +} + +template +OutputVector translate_fft_base(const NodeContext& context, + int size, + bool complex_input, + bool complex_output, + bool inverse = false, + bool is_irfft = false) { + num_inputs_check(context, 1, 4, true); + auto input = context.get_input(0); + + Output dim; + Output s; + std::tie(dim, s) = get_dim_s(context, input, size, is_irfft); + + auto complex_type_mark = as_type_ptr(input.get_node_shared_ptr()); + if (complex_type_mark) { + PYTORCH_OP_CONVERSION_CHECK(complex_input, "Operation does not support complex type tensor on input."); + input = complex_type_mark->get_data(); + } else { + if (complex_input) { + auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0})); + const_0 = context.mark_node(std::make_shared(const_0, input)); + input = std::make_shared(input, const_0)->get_data(); + } } // Handle norm parameter indicating normalization mode to use. Defaults to "backward". @@ -83,123 +179,65 @@ OutputVector translate_fft_rfftn(const NodeContext& context) { norm = context.const_input(3); } - auto rdft = context.mark_node(std::make_shared(input, dim, s)); + auto node = context.mark_node(std::make_shared(input, dim, s)); // Apply normalizations - auto n_int = context.mark_node(std::make_shared(s, const_0)); - auto n = context.mark_node(std::make_shared(n_int, rdft)); - Output normalized_rfftn; - if (norm == "forward") { - // Normalize by 1/n - normalized_rfftn = context.mark_node(std::make_shared(rdft, n)); - } else if (norm == "backward") { - // No normalization - normalized_rfftn = rdft; - } else if (norm == "ortho") { - // Normalize by 1/sqrt(n) - auto sqrt_n = context.mark_node(std::make_shared(n)); - normalized_rfftn = context.mark_node(std::make_shared(rdft, sqrt_n)); - } else { - FRONT_END_THROW( - "aten::fft_rfftn: unrecognized normalization mode. Only forward, backward and ortho are supported."); + Output normalized = normalize(context, node, s, norm, inverse); + if (complex_output) { + normalized = std::make_shared(normalized, normalized.get_element_type()); } + return {normalized}; +} +} // namespace - return {std::make_shared(normalized_rfftn, normalized_rfftn.get_element_type())}; +OutputVector translate_fft_fft(const NodeContext& context) { + return translate_fft_base(context, 1, true, true); } -OutputVector translate_fft_irfftn(const NodeContext& context) { - // aten::fft_irfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor - num_inputs_check(context, 1, 4, true); - auto input = context.get_input(0); +OutputVector translate_fft_fft2(const NodeContext& context) { + return translate_fft_base(context, 2, true, true); +} - auto complex_type_mark = as_type_ptr(input.get_node_shared_ptr()); - PYTORCH_OP_CONVERSION_CHECK(complex_type_mark, "aten::fft_irfftn operation expects complex type tensor on input."); - input = complex_type_mark->get_data(); +OutputVector translate_fft_fftn(const NodeContext& context) { + return translate_fft_base(context, -1, true, true); +} - auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1})); - auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0})); - auto const_scalar_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0})); - auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1})); - auto const_scalar_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1})); - auto const_2 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {2})); +OutputVector translate_fft_rfft(const NodeContext& context) { + return translate_fft_base(context, 1, false, true); +} - // Input shape of complex number (excluding dimension created by concatenation of real and imag) - auto complex_input_shape = get_complex_shape(context, input); - auto input_rank = context.mark_node(std::make_shared(complex_input_shape, element::i32)); - auto input_rank_scalar = context.mark_node(std::make_shared(input_rank)); +OutputVector translate_fft_rfft2(const NodeContext& context) { + return translate_fft_base(context, 2, false, true); +} - Output raw_s; - // Inputs can be either none or List. Check whether input values should be used or should be set to default values. - if (!context.input_is_none(1)) { - // s is provided, load from input. - raw_s = get_input_concat_if_list(context, 1); - raw_s = context.mark_node(std::make_shared(raw_s, element::i32)); - } +OutputVector translate_fft_rfftn(const NodeContext& context) { + // aten::fft_rfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + return translate_fft_base(context, -1, false, true); +} - // Handle dim parameter containing vector of integers indicating dimensions to be transformed. - Output dim; - if (!context.input_is_none(2)) { - // Dim values is provided, load from input. - dim = get_input_concat_if_list(context, 2); - dim = context.mark_node(std::make_shared(dim, element::i32)); - } else if (!context.input_is_none(1)) { - // If dim is default and s is provided, use last s_len dimensions where s_len is length of s. - auto s_len = context.mark_node(std::make_shared(raw_s, element::i32)); - auto range_start = context.mark_node(std::make_shared(input_rank, s_len)); - auto range_start_scalar = context.mark_node(std::make_shared(range_start)); - dim = context.mark_node( - std::make_shared(range_start_scalar, input_rank_scalar, const_scalar_1, element::i32)); - } else { - // Dim and s are set to default, use all of dimensions. - dim = context.mark_node( - std::make_shared(const_scalar_0, input_rank_scalar, const_scalar_1, element::i32)); - } +OutputVector translate_fft_ifft(const NodeContext& context) { + return translate_fft_base(context, 1, true, true, true); +} - // Calculate default s values. Use full available size except last element, which is set to even value in last - // dimension: s[-1] = 2 * (complex_input_shape[dim[-1]]) - auto default_s_raw = context.mark_node(std::make_shared(complex_input_shape, dim, const_0)); - auto last_s = context.mark_node(std::make_shared(default_s_raw, const_neg_1, const_0)); - auto last_s_m_1 = context.mark_node(std::make_shared(last_s, const_1)); - auto s_upd = context.mark_node(std::make_shared(last_s_m_1, const_2)); - auto s_shape = context.mark_node(std::make_shared(default_s_raw, element::i32)); - auto last_s_idx = context.mark_node(std::make_shared(s_shape, const_1)); - auto default_s = context.mark_node(std::make_shared(default_s_raw, last_s_idx, s_upd, const_0)); - - // Handle s parameter containing vector of intigers indicating signal sizes for dimensions. - Output s; - if (!context.input_is_none(1)) { - // Values for s were provided. Replace -1 values with default full size in given dimension. - auto full_s_cond = context.mark_node(std::make_shared(raw_s, const_neg_1)); - s = context.mark_node(std::make_shared(full_s_cond, default_s, raw_s)); - } else { - // Value for s was set to default. - s = default_s; - } +OutputVector translate_fft_ifft2(const NodeContext& context) { + return translate_fft_base(context, 2, true, true, true); +} - // Handle norm parameter indicating normalization mode to use. Defaults to "backward". - std::string norm = "backward"; - if (!context.input_is_none(3)) { - norm = context.const_input(3); - } +OutputVector translate_fft_ifftn(const NodeContext& context) { + return translate_fft_base(context, -1, true, true, true); +} - auto irdft = context.mark_node(std::make_shared(input, dim, s)); +OutputVector translate_fft_irfft(const NodeContext& context) { + return translate_fft_base(context, 1, true, false, true, true); +} - // Apply normalizations. - auto n_int = context.mark_node(std::make_shared(s, const_0)); - auto n = context.mark_node(std::make_shared(n_int, irdft)); - Output normalized_irfftn; - if (norm == "forward") { - normalized_irfftn = context.mark_node(std::make_shared(irdft, n)); - } else if (norm == "backward") { - normalized_irfftn = irdft; - } else if (norm == "ortho") { - auto sqrt_n = context.mark_node(std::make_shared(n)); - normalized_irfftn = context.mark_node(std::make_shared(irdft, sqrt_n)); - } else { - FRONT_END_THROW( - "aten::fft_irfftn: unrecognized normalization mode. Only forward, backward and ortho are supported."); - } - return {normalized_irfftn}; +OutputVector translate_fft_irfft2(const NodeContext& context) { + return translate_fft_base(context, 2, true, false, true, true); +} + +OutputVector translate_fft_irfftn(const NodeContext& context) { + // aten::fft_irfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + return translate_fft_base(context, -1, true, false, true, true); } } // namespace op diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 018812354a23b6..1b0696305164fd 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -92,7 +92,17 @@ OP_CONVERTER(translate_expm1); OP_CONVERTER(translate_eye); OP_CONVERTER(translate_fake_quantize_per_channel_affine); OP_CONVERTER(translate_fake_quantize_per_tensor_affine); +OP_CONVERTER(translate_fft_fft); +OP_CONVERTER(translate_fft_fft2); +OP_CONVERTER(translate_fft_fftn); +OP_CONVERTER(translate_fft_ifft); +OP_CONVERTER(translate_fft_ifft2); +OP_CONVERTER(translate_fft_ifftn); +OP_CONVERTER(translate_fft_irfft); +OP_CONVERTER(translate_fft_irfft2); OP_CONVERTER(translate_fft_irfftn); +OP_CONVERTER(translate_fft_rfft); +OP_CONVERTER(translate_fft_rfft2); OP_CONVERTER(translate_fft_rfftn); OP_CONVERTER(translate_fill); OP_CONVERTER(translate_fill_diagonal); @@ -484,7 +494,17 @@ const std::unordered_map get_supported_ops_ts() { {"aten::fake_quantize_per_channel_affine", op::translate_fake_quantize_per_channel_affine}, {"aten::fake_quantize_per_tensor_affine", op::translate_fake_quantize_per_tensor_affine}, {"aten::feature_dropout", op::skip_node}, + {"aten::fft_fft", op::translate_fft_fft}, + {"aten::fft_fft2", op::translate_fft_fft2}, + {"aten::fft_fftn", op::translate_fft_fftn}, + {"aten::fft_ifft", op::translate_fft_ifft}, + {"aten::fft_ifft2", op::translate_fft_ifft2}, + {"aten::fft_ifftn", op::translate_fft_ifftn}, + {"aten::fft_irfft", op::translate_fft_irfft}, + {"aten::fft_irfft2", op::translate_fft_irfft2}, {"aten::fft_irfftn", op::translate_fft_irfftn}, + {"aten::fft_rfft", op::translate_fft_rfft}, + {"aten::fft_rfft2", op::translate_fft_rfft2}, {"aten::fft_rfftn", op::translate_fft_rfftn}, {"aten::fill", op::translate_fill}, {"aten::fill_diagonal", op::translate_fill_diagonal}, diff --git a/src/frontends/pytorch/src/utils.cpp b/src/frontends/pytorch/src/utils.cpp index 3789577a07f619..3a64b083f8b923 100644 --- a/src/frontends/pytorch/src/utils.cpp +++ b/src/frontends/pytorch/src/utils.cpp @@ -106,7 +106,13 @@ std::tuple, Output> get_shape_rank(const NodeContext& context const Output& x, bool as_scalar, element::Type output_type) { - auto shape = context.mark_node(std::make_shared(x, output_type)); + auto complex_type_mark = as_type_ptr(x.get_node_shared_ptr()); + Output shape; + if (complex_type_mark) { + shape = get_complex_shape(context, complex_type_mark->get_data()); + } else { + shape = context.mark_node(std::make_shared(x, output_type)); + } Output rank = context.mark_node(std::make_shared(shape, output_type)); if (as_scalar) { auto axis_0 = context.mark_node(v0::Constant::create(output_type, Shape{}, {0})); diff --git a/tests/layer_tests/pytorch_tests/test_fft.py b/tests/layer_tests/pytorch_tests/test_fft.py new file mode 100644 index 00000000000000..099c5036b7ce7b --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_fft.py @@ -0,0 +1,160 @@ +# Copyright (C) 2018-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from sys import platform + +import numpy as np +import pytest +import torch + +from pytorch_layer_test_class import PytorchLayerTest + + +class TestRFFTN(PytorchLayerTest): + def _prepare_input(self): + return (np.random.randn(*self.input_shape).astype(np.float32),) + + def create_model(self, dim, s, norm): + class aten_fft_rfftn(torch.nn.Module): + def __init__(self, dim, s, norm): + super(aten_fft_rfftn, self).__init__() + self.dim = dim + self.s = s + self.norm = norm + + def forward(self, x): + rfftn = torch.fft.rfftn(x, s=self.s, dim=self.dim, norm=self.norm) + r = rfftn.real + i = rfftn.imag + irfftn = torch.fft.irfftn(torch.complex(r, i), s=self.s, dim=self.dim, norm=self.norm) + return irfftn, r, i + + ref_net = None + + return ( + aten_fft_rfftn(dim, s, norm), + ref_net, + ["aten::fft_irfftn", "aten::complex", "aten::fft_rfftn", "aten::real", "aten::imag"], + ) + + @pytest.mark.parametrize("input_shape", [[64, 49], [64, 50], [64, 64, 49]]) + @pytest.mark.parametrize("dim", [[0, -1], [-2, -1], None, [0, 1]]) + @pytest.mark.parametrize("s", [None, [-1, 49], [64, -1], [64, 49], [5, 1]]) + @pytest.mark.parametrize("norm", ["forward", "backward", "ortho", None]) + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.skipif(platform == 'darwin', reason="Ticket - 122182") + def test_rfftn(self, ie_device, precision, ir_version, input_shape, dim, s, norm): + self.input_shape = input_shape + # Unfrozen test would fail due to issues with prim::GetAttr containing lists, strings or none. + self._test(*self.create_model(dim, s, norm), ie_device, precision, ir_version, custom_eps=1e-3) + + +class aten_fft(torch.nn.Module): + def __init__(self, op, n, dim, norm): + super().__init__() + self.n = n + self.dim = dim + self.norm = norm + self.op = op + + def forward(self, x): + if x.shape[-1] == 2: + x = torch.view_as_complex(x) + res = self.op(x, self.n, dim=self.dim, norm=self.norm) + if res.dtype.is_complex: + return torch.view_as_real(res) + return res + + +class TestFFT(PytorchLayerTest): + def _prepare_input(self): + return (np.random.randn(*self.input_shape).astype(np.float32),) + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.parametrize("input_shape", [[67], [80], [12, 14], [9, 6, 3]]) + @pytest.mark.parametrize("n", [None, 50, 6]) + @pytest.mark.parametrize("dim", [-1, 0]) + @pytest.mark.parametrize("norm", [None, "forward", "backward", "ortho"]) + @pytest.mark.parametrize("op,aten_name,in_complex", [ + (torch.fft.fft, "aten::fft_fft", True), + (torch.fft.fft, "aten::fft_fft", False), + pytest.param(torch.fft.hfft, "aten::fft_hfft", True, marks=pytest.mark.skip(reason="Not supported yet.")), + (torch.fft.rfft, "aten::fft_rfft", False), + (torch.fft.ifft, "aten::fft_ifft", True), + pytest.param(torch.fft.ihfft, "aten::fft_ihfft", False, marks=pytest.mark.skip(reason="Not supported yet.")), + (torch.fft.irfft, "aten::fft_irfft", True), + ]) + def test_1d(self, ie_device, precision, ir_version, input_shape, op, n, dim, norm, aten_name, in_complex): + if op in [torch.fft.rfft, torch.fft.irfft] and n is not None and input_shape[dim] < n: + pytest.skip("Signal size greater than input size is not supported yet") + if in_complex: + self.input_shape = input_shape + [2] + else: + self.input_shape = input_shape + m = aten_fft(op, n, dim, norm) + self._test(m, None, aten_name, ie_device, + precision, ir_version, trace_model=True, dynamic_shapes=False, custom_eps=1e-3) + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.parametrize("input_shape", [[20, 30], [15, 20, 30], [10, 15, 20, 30]]) + @pytest.mark.parametrize("s", [None, (10, 10)]) + @pytest.mark.parametrize("dim", [(0, 1), (-2, -1)]) + @pytest.mark.parametrize("norm", [None, "forward", "backward", "ortho"]) + @pytest.mark.parametrize("op,aten_name,in_complex", [ + (torch.fft.fft2, "aten::fft_fft2", True), + (torch.fft.fft2, "aten::fft_fft2", False), + pytest.param(torch.fft.hfft2, "aten::fft_hfft2", True, marks=pytest.mark.skip(reason="Not supported yet.")), + (torch.fft.rfft2, "aten::fft_rfft2", False), + (torch.fft.ifft2, "aten::fft_ifft2", True), + pytest.param(torch.fft.ihfft2, "aten::fft_ihfft2", False, marks=pytest.mark.skip(reason="Not supported yet.")), + (torch.fft.irfft2, "aten::fft_irfft2", True), + ]) + def test_2d(self, ie_device, precision, ir_version, input_shape, op, s, dim, norm, aten_name, in_complex): + if in_complex: + self.input_shape = input_shape + [2] + else: + self.input_shape = input_shape + m = aten_fft(op, s, dim, norm) + self._test(m, None, aten_name, ie_device, + precision, ir_version, trace_model=True, dynamic_shapes=False, custom_eps=1e-3) + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.parametrize("input_shape,s,dim", [ + ((4, 5), None, None), + ((4, 5), None, (0,)), + ((4, 5), None, (0, -1)), + ((4, 5, 6), None, None), + ((4, 5, 6), None, (0,)), + ((4, 5, 6), None, (0, -1)), + ((4, 5, 6, 7), None, None), + ((4, 5, 6, 7), None, (0,)), + ((4, 5, 6, 7), None, (0, -1)), + ((4, 5, 6, 7, 8, 4), None, None), + ((4, 5, 6, 7, 8), None, (1, 3, 4)), + ((4, 5, 6), None, (1,)), + ((4,), None, (0,)), + ((4, 5, 60, 70), (10, 10), None), + ((40, 50, 6, 7), (10, 10), (0, 1)), + ]) + @pytest.mark.parametrize("norm", [None, "forward", "backward", "ortho"]) + @pytest.mark.parametrize("op,aten_name,in_complex", [ + (torch.fft.fftn, "aten::fft_fftn", True), + (torch.fft.fftn, "aten::fft_fftn", False), + pytest.param(torch.fft.hfftn, "aten::fft_hfftn", True, marks=pytest.mark.skip(reason="Not supported yet.")), + (torch.fft.rfftn, "aten::fft_rfftn", False), + (torch.fft.ifftn, "aten::fft_ifftn", True), + pytest.param(torch.fft.ihfftn, "aten::fft_ihfftn", False, marks=pytest.mark.skip(reason="Not supported yet.")), + (torch.fft.irfftn, "aten::fft_irfftn", True), + ]) + def test_nd(self, ie_device, precision, ir_version, input_shape, op, s, dim, norm, aten_name, in_complex): + if in_complex: + self.input_shape = input_shape + (2,) + else: + self.input_shape = input_shape + m = aten_fft(op, s, dim, norm) + self._test(m, None, aten_name, ie_device, + precision, ir_version, trace_model=True, dynamic_shapes=False, custom_eps=1e-3) diff --git a/tests/layer_tests/pytorch_tests/test_rfftn_complex_transforms.py b/tests/layer_tests/pytorch_tests/test_rfftn_complex_transforms.py deleted file mode 100644 index 2e8d9123e48b3e..00000000000000 --- a/tests/layer_tests/pytorch_tests/test_rfftn_complex_transforms.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright (C) 2018-2025 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -from sys import platform - -import numpy as np -import pytest -import torch - -from pytorch_layer_test_class import PytorchLayerTest - - -class TestRFFTN(PytorchLayerTest): - def _prepare_input(self): - return (np.random.randn(*self.input_shape).astype(np.float32),) - - def create_model(self, dim, s, norm): - class aten_fft_rfftn(torch.nn.Module): - def __init__(self, dim, s, norm): - super(aten_fft_rfftn, self).__init__() - self.dim = dim - self.s = s - self.norm = norm - - def forward(self, x): - rfftn = torch.fft.rfftn(x, s=self.s, dim=self.dim, norm=self.norm) - r = rfftn.real - i = rfftn.imag - irfftn = torch.fft.irfftn(torch.complex(r, i), s=self.s, dim=self.dim, norm=self.norm) - return irfftn, r, i - - ref_net = None - - return ( - aten_fft_rfftn(dim, s, norm), - ref_net, - ["aten::fft_irfftn", "aten::complex", "aten::fft_rfftn", "aten::real", "aten::imag"], - ) - - @pytest.mark.parametrize("input_shape", [[64, 49], [64, 50], [64, 64, 49]]) - @pytest.mark.parametrize("dim", [[0, -1], [-2, -1], None, [0, 1]]) - @pytest.mark.parametrize("s", [None, [-1, 49], [64, -1], [64, 49], [5, 1]]) - @pytest.mark.parametrize("norm", ["forward", "backward", "ortho", None]) - @pytest.mark.nightly - @pytest.mark.precommit - @pytest.mark.skipif(platform == 'darwin', reason="Ticket - 122182") - def test_rfftn(self, ie_device, precision, ir_version, input_shape, dim, s, norm): - self.input_shape = input_shape - # Unfrozen test would fail due to issues with prim::GetAttr containing lists, strings or none. - self._test(*self.create_model(dim, s, norm), ie_device, precision, ir_version, custom_eps=1e-3, - freeze_model=True)