From 87ac7c4e6ff29aa8495bdd3bf76369202045c528 Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Sun, 16 Mar 2025 20:18:52 +0100 Subject: [PATCH 1/9] [PT FE] Support fft operations Signed-off-by: Maxim Vafin --- .../openvino/frontend/pytorch/ts_decoder.py | 4 +- src/frontends/pytorch/src/op/fft.cpp | 246 ++++++++++-------- src/frontends/pytorch/src/op_table.cpp | 13 + src/frontends/pytorch/src/utils.cpp | 8 +- .../test_rfftn_complex_transforms.py | 89 ++++++- 5 files changed, 248 insertions(+), 112 deletions(-) diff --git a/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py b/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py index 6538a418462c97..066dc7c50dcc41 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py @@ -75,8 +75,8 @@ def __init__( help_msg = ("Tracing sometimes provide better results, " "please provide valid 'example_input' argument.\n") raise RuntimeError( - f"Couldn't get TorchScript module by {msg}.\n{help_msg} " - "You can also provide TorchScript module that you obtained" + f"Couldn't get TorchScript module by {msg}.\nException:\n{e}\n" + f"{help_msg} You can also provide TorchScript module that you obtained" " yourself, please refer to PyTorch documentation: " "https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html." ) from e diff --git a/src/frontends/pytorch/src/op/fft.cpp b/src/frontends/pytorch/src/op/fft.cpp index b86b01efb21639..95b6dc8ce85475 100644 --- a/src/frontends/pytorch/src/op/fft.cpp +++ b/src/frontends/pytorch/src/op/fft.cpp @@ -4,9 +4,11 @@ #include "openvino/frontend/complex_type_mark.hpp" #include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/dft.hpp" #include "openvino/op/divide.hpp" #include "openvino/op/equal.hpp" #include "openvino/op/gather.hpp" +#include "openvino/op/idft.hpp" #include "openvino/op/irdft.hpp" #include "openvino/op/multiply.hpp" #include "openvino/op/range.hpp" @@ -19,6 +21,7 @@ #include "openvino/op/sqrt.hpp" #include "openvino/op/squeeze.hpp" #include "openvino/op/subtract.hpp" +#include "openvino/op/unsqueeze.hpp" #include "utils.hpp" namespace ov { @@ -28,18 +31,50 @@ namespace op { using namespace ov::op; -OutputVector translate_fft_rfftn(const NodeContext& context) { - // aten::fft_rfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor - num_inputs_check(context, 1, 4); - auto input = context.get_input(0); - - auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1})); +namespace { +Output normalize(const NodeContext& context, + const Output& node, + const Output& s, + std::string norm, + bool inverse) { + if (norm == "backward") { + // No normalization + return node; + } + // Apply normalizations auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0})); - auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1})); + auto n_int = context.mark_node(std::make_shared(s, const_0)); + auto n = context.mark_node(std::make_shared(n_int, node)); + Output normalized; + if (norm == "forward") { + // Normalize by 1/n + if (inverse) { + normalized = context.mark_node(std::make_shared(node, n)); + } else { + normalized = context.mark_node(std::make_shared(node, n)); + } + } else if (norm == "ortho") { + // Normalize by 1/sqrt(n) + auto sqrt_n = context.mark_node(std::make_shared(n)); + if (inverse) { + normalized = context.mark_node(std::make_shared(node, sqrt_n)); + } else { + normalized = context.mark_node(std::make_shared(node, sqrt_n)); + } + } else { + FRONT_END_THROW("Unrecognized normalization mode. Only forward, backward and ortho are supported."); + } + return normalized; +} +std::tuple, Output> get_dim_s(const NodeContext& context, const Output& x, bool is_irfft) { Output input_shape; Output input_rank_scalar; - std::tie(input_shape, input_rank_scalar) = get_shape_rank(context, input, true); + std::tie(input_shape, input_rank_scalar) = get_shape_rank(context, x, true); + + auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1})); + auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0})); + auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1})); Output raw_s; // Inputs can be either none or List. Check whether input values should be used or should be set to default values. @@ -48,8 +83,8 @@ OutputVector translate_fft_rfftn(const NodeContext& context) { raw_s = get_input_concat_if_list(context, 1); raw_s = context.mark_node(std::make_shared(raw_s, element::i32)); } - Output dim; // Handle dim parameter containing vector of integers indicating dimensions to be transformed. + Output dim; if (!context.input_is_none(2)) { // dim is provided, load from input. dim = get_input_concat_if_list(context, 2); @@ -57,24 +92,61 @@ OutputVector translate_fft_rfftn(const NodeContext& context) { } else if (!context.input_is_none(1)) { // If dim is default and s is provided, use last s_len dimensions where s_len is length of s. auto s_len = context.mark_node(std::make_shared(raw_s, element::i32)); - auto slice_start = context.mark_node(std::make_shared(input_rank_scalar, s_len)); - auto slice_start_scalar = context.mark_node(std::make_shared(slice_start)); - dim = context.mark_node( - std::make_shared(slice_start_scalar, input_rank_scalar, const_1, element::i32)); + auto start = context.mark_node(std::make_shared(input_rank_scalar, s_len)); + auto start_scalar = context.mark_node(std::make_shared(start)); + dim = context.mark_node(std::make_shared(start_scalar, input_rank_scalar, const_1, element::i32)); } else { // Dim and s are set to default, use all of dimensions. dim = context.mark_node(std::make_shared(const_0, input_rank_scalar, const_1, element::i32)); } + if (dim.get_partial_shape().rank().is_dynamic() || dim.get_partial_shape().rank().get_length() == 0) { + auto const_neg_1_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1})); + dim = context.mark_node(std::make_shared(dim, const_neg_1_1d, false)); + } + Output default_s; + if (is_irfft) { + auto const_2_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {2})); + // Calculate default s values. Use full available size except last element, which is set to even value in last + // dimension: s[-1] = 2 * (complex_input_shape[dim[-1]]) + auto default_s_raw = context.mark_node(std::make_shared(input_shape, dim, const_0)); + auto last_s = context.mark_node(std::make_shared(default_s_raw, const_neg_1, const_0)); + auto last_s_m_1 = context.mark_node(std::make_shared(last_s, const_1)); + auto s_upd = context.mark_node(std::make_shared(last_s_m_1, const_2_1d)); + auto s_shape = context.mark_node(std::make_shared(default_s_raw, element::i32)); + auto last_s_idx = context.mark_node(std::make_shared(s_shape, const_1)); + default_s = context.mark_node(std::make_shared(default_s_raw, last_s_idx, s_upd, const_0)); + } else { + default_s = context.mark_node(std::make_shared(input_shape, dim, const_0)); + } Output s; if (context.input_is_none(1)) { // Value for s was set to default, use full size for all dimensions. - s = context.mark_node(std::make_shared(input_shape, dim, const_0)); + s = default_s; } else { // Values for s were provided. Replace -1 values with default full size in given dimension. auto full_s_cond = context.mark_node(std::make_shared(raw_s, const_neg_1)); - auto full_s_values = context.mark_node(std::make_shared(input_shape, dim, const_0)); - s = context.mark_node(std::make_shared(full_s_cond, full_s_values, raw_s)); + s = context.mark_node(std::make_shared(full_s_cond, default_s, raw_s)); + } + return {dim, s}; +} +} // namespace + +OutputVector translate_fft_fftn(const NodeContext& context) { + num_inputs_check(context, 1, 4, true); + auto input = context.get_input(0); + + Output dim; + Output s; + std::tie(dim, s) = get_dim_s(context, input, false); + + auto complex_type_mark = as_type_ptr(input.get_node_shared_ptr()); + if (complex_type_mark) { + input = complex_type_mark->get_data(); + } else { + auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0})); + const_0 = context.mark_node(std::make_shared(const_0, input)); + input = std::make_shared(input, const_0)->get_data(); } // Handle norm parameter indicating normalization mode to use. Defaults to "backward". @@ -83,98 +155,72 @@ OutputVector translate_fft_rfftn(const NodeContext& context) { norm = context.const_input(3); } - auto rdft = context.mark_node(std::make_shared(input, dim, s)); + auto node = context.mark_node(std::make_shared(input, dim, s)); // Apply normalizations - auto n_int = context.mark_node(std::make_shared(s, const_0)); - auto n = context.mark_node(std::make_shared(n_int, rdft)); - Output normalized_rfftn; - if (norm == "forward") { - // Normalize by 1/n - normalized_rfftn = context.mark_node(std::make_shared(rdft, n)); - } else if (norm == "backward") { - // No normalization - normalized_rfftn = rdft; - } else if (norm == "ortho") { - // Normalize by 1/sqrt(n) - auto sqrt_n = context.mark_node(std::make_shared(n)); - normalized_rfftn = context.mark_node(std::make_shared(rdft, sqrt_n)); - } else { - FRONT_END_THROW( - "aten::fft_rfftn: unrecognized normalization mode. Only forward, backward and ortho are supported."); + Output normalized = normalize(context, node, s, norm, false); + return {std::make_shared(normalized, normalized.get_element_type())}; +} + +OutputVector translate_fft_rfftn(const NodeContext& context) { + // aten::fft_rfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + num_inputs_check(context, 1, 4); + auto input = context.get_input(0); + + Output dim; + Output s; + std::tie(dim, s) = get_dim_s(context, input, false); + + // Handle norm parameter indicating normalization mode to use. Defaults to "backward". + std::string norm = "backward"; + if (!context.input_is_none(3)) { + norm = context.const_input(3); } - return {std::make_shared(normalized_rfftn, normalized_rfftn.get_element_type())}; + auto node = context.mark_node(std::make_shared(input, dim, s)); + + // Apply normalizations + Output normalized = normalize(context, node, s, norm, false); + return {std::make_shared(normalized, normalized.get_element_type())}; } -OutputVector translate_fft_irfftn(const NodeContext& context) { +OutputVector translate_fft_ifftn(const NodeContext& context) { // aten::fft_irfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor num_inputs_check(context, 1, 4, true); auto input = context.get_input(0); + Output dim; + Output s; + std::tie(dim, s) = get_dim_s(context, input, false); + auto complex_type_mark = as_type_ptr(input.get_node_shared_ptr()); - PYTORCH_OP_CONVERSION_CHECK(complex_type_mark, "aten::fft_irfftn operation expects complex type tensor on input."); + PYTORCH_OP_CONVERSION_CHECK(complex_type_mark, "Operation expects complex type tensor on input."); input = complex_type_mark->get_data(); - auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1})); - auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0})); - auto const_scalar_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0})); - auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1})); - auto const_scalar_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1})); - auto const_2 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {2})); + // Handle norm parameter indicating normalization mode to use. Defaults to "backward". + std::string norm = "backward"; + if (!context.input_is_none(3)) { + norm = context.const_input(3); + } - // Input shape of complex number (excluding dimension created by concatenation of real and imag) - auto complex_input_shape = get_complex_shape(context, input); - auto input_rank = context.mark_node(std::make_shared(complex_input_shape, element::i32)); - auto input_rank_scalar = context.mark_node(std::make_shared(input_rank)); + auto node = context.mark_node(std::make_shared(input, dim, s)); - Output raw_s; - // Inputs can be either none or List. Check whether input values should be used or should be set to default values. - if (!context.input_is_none(1)) { - // s is provided, load from input. - raw_s = get_input_concat_if_list(context, 1); - raw_s = context.mark_node(std::make_shared(raw_s, element::i32)); - } + Output normalized = normalize(context, node, s, norm, true); + return {std::make_shared(normalized, normalized.get_element_type())}; +} - // Handle dim parameter containing vector of integers indicating dimensions to be transformed. - Output dim; - if (!context.input_is_none(2)) { - // Dim values is provided, load from input. - dim = get_input_concat_if_list(context, 2); - dim = context.mark_node(std::make_shared(dim, element::i32)); - } else if (!context.input_is_none(1)) { - // If dim is default and s is provided, use last s_len dimensions where s_len is length of s. - auto s_len = context.mark_node(std::make_shared(raw_s, element::i32)); - auto range_start = context.mark_node(std::make_shared(input_rank, s_len)); - auto range_start_scalar = context.mark_node(std::make_shared(range_start)); - dim = context.mark_node( - std::make_shared(range_start_scalar, input_rank_scalar, const_scalar_1, element::i32)); - } else { - // Dim and s are set to default, use all of dimensions. - dim = context.mark_node( - std::make_shared(const_scalar_0, input_rank_scalar, const_scalar_1, element::i32)); - } +OutputVector translate_fft_irfftn(const NodeContext& context) { + // aten::fft_irfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + num_inputs_check(context, 1, 4, true); + auto input = context.get_input(0); - // Calculate default s values. Use full available size except last element, which is set to even value in last - // dimension: s[-1] = 2 * (complex_input_shape[dim[-1]]) - auto default_s_raw = context.mark_node(std::make_shared(complex_input_shape, dim, const_0)); - auto last_s = context.mark_node(std::make_shared(default_s_raw, const_neg_1, const_0)); - auto last_s_m_1 = context.mark_node(std::make_shared(last_s, const_1)); - auto s_upd = context.mark_node(std::make_shared(last_s_m_1, const_2)); - auto s_shape = context.mark_node(std::make_shared(default_s_raw, element::i32)); - auto last_s_idx = context.mark_node(std::make_shared(s_shape, const_1)); - auto default_s = context.mark_node(std::make_shared(default_s_raw, last_s_idx, s_upd, const_0)); - - // Handle s parameter containing vector of intigers indicating signal sizes for dimensions. + Output dim; Output s; - if (!context.input_is_none(1)) { - // Values for s were provided. Replace -1 values with default full size in given dimension. - auto full_s_cond = context.mark_node(std::make_shared(raw_s, const_neg_1)); - s = context.mark_node(std::make_shared(full_s_cond, default_s, raw_s)); - } else { - // Value for s was set to default. - s = default_s; - } + std::tie(dim, s) = get_dim_s(context, input, true); + + auto complex_type_mark = as_type_ptr(input.get_node_shared_ptr()); + PYTORCH_OP_CONVERSION_CHECK(complex_type_mark, "Operation expects complex type tensor on input."); + input = complex_type_mark->get_data(); // Handle norm parameter indicating normalization mode to use. Defaults to "backward". std::string norm = "backward"; @@ -182,24 +228,10 @@ OutputVector translate_fft_irfftn(const NodeContext& context) { norm = context.const_input(3); } - auto irdft = context.mark_node(std::make_shared(input, dim, s)); + auto node = context.mark_node(std::make_shared(input, dim, s)); - // Apply normalizations. - auto n_int = context.mark_node(std::make_shared(s, const_0)); - auto n = context.mark_node(std::make_shared(n_int, irdft)); - Output normalized_irfftn; - if (norm == "forward") { - normalized_irfftn = context.mark_node(std::make_shared(irdft, n)); - } else if (norm == "backward") { - normalized_irfftn = irdft; - } else if (norm == "ortho") { - auto sqrt_n = context.mark_node(std::make_shared(n)); - normalized_irfftn = context.mark_node(std::make_shared(irdft, sqrt_n)); - } else { - FRONT_END_THROW( - "aten::fft_irfftn: unrecognized normalization mode. Only forward, backward and ortho are supported."); - } - return {normalized_irfftn}; + Output normalized = normalize(context, node, s, norm, true); + return {normalized}; } } // namespace op diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 018812354a23b6..b8cfb824939c38 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -92,6 +92,8 @@ OP_CONVERTER(translate_expm1); OP_CONVERTER(translate_eye); OP_CONVERTER(translate_fake_quantize_per_channel_affine); OP_CONVERTER(translate_fake_quantize_per_tensor_affine); +OP_CONVERTER(translate_fft_fftn); +OP_CONVERTER(translate_fft_ifftn); OP_CONVERTER(translate_fft_irfftn); OP_CONVERTER(translate_fft_rfftn); OP_CONVERTER(translate_fill); @@ -484,7 +486,17 @@ const std::unordered_map get_supported_ops_ts() { {"aten::fake_quantize_per_channel_affine", op::translate_fake_quantize_per_channel_affine}, {"aten::fake_quantize_per_tensor_affine", op::translate_fake_quantize_per_tensor_affine}, {"aten::feature_dropout", op::skip_node}, + {"aten::fft_fft", op::translate_fft_fftn}, + {"aten::fft_fft2", op::translate_fft_fftn}, + {"aten::fft_fftn", op::translate_fft_fftn}, + {"aten::fft_ifft", op::translate_fft_ifftn}, + {"aten::fft_ifft2", op::translate_fft_ifftn}, + {"aten::fft_ifftn", op::translate_fft_ifftn}, + {"aten::fft_irfft", op::translate_fft_irfftn}, + {"aten::fft_irfft2", op::translate_fft_irfftn}, {"aten::fft_irfftn", op::translate_fft_irfftn}, + {"aten::fft_rfft", op::translate_fft_rfftn}, + {"aten::fft_rfft2", op::translate_fft_rfftn}, {"aten::fft_rfftn", op::translate_fft_rfftn}, {"aten::fill", op::translate_fill}, {"aten::fill_diagonal", op::translate_fill_diagonal}, @@ -964,6 +976,7 @@ const std::unordered_map get_supported_ops_fx() { {"aten.relu.default", op::translate_1to1_match_1_inputs}, {"aten.relu_.default", op::inplace_op>}, {"aten.repeat.default", op::translate_repeat_fx}, + {"aten.repeat_interleave.Tensor", op::translate_repeat_interleave}, {"aten.rms_norm.default", op::translate_rms_norm}, {"aten.roll.default", op::translate_roll}, {"aten.rsqrt.default", op::translate_rsqrt}, diff --git a/src/frontends/pytorch/src/utils.cpp b/src/frontends/pytorch/src/utils.cpp index 3789577a07f619..3a64b083f8b923 100644 --- a/src/frontends/pytorch/src/utils.cpp +++ b/src/frontends/pytorch/src/utils.cpp @@ -106,7 +106,13 @@ std::tuple, Output> get_shape_rank(const NodeContext& context const Output& x, bool as_scalar, element::Type output_type) { - auto shape = context.mark_node(std::make_shared(x, output_type)); + auto complex_type_mark = as_type_ptr(x.get_node_shared_ptr()); + Output shape; + if (complex_type_mark) { + shape = get_complex_shape(context, complex_type_mark->get_data()); + } else { + shape = context.mark_node(std::make_shared(x, output_type)); + } Output rank = context.mark_node(std::make_shared(shape, output_type)); if (as_scalar) { auto axis_0 = context.mark_node(v0::Constant::create(output_type, Shape{}, {0})); diff --git a/tests/layer_tests/pytorch_tests/test_rfftn_complex_transforms.py b/tests/layer_tests/pytorch_tests/test_rfftn_complex_transforms.py index 2e8d9123e48b3e..56710f2362bf60 100644 --- a/tests/layer_tests/pytorch_tests/test_rfftn_complex_transforms.py +++ b/tests/layer_tests/pytorch_tests/test_rfftn_complex_transforms.py @@ -47,5 +47,90 @@ def forward(self, x): def test_rfftn(self, ie_device, precision, ir_version, input_shape, dim, s, norm): self.input_shape = input_shape # Unfrozen test would fail due to issues with prim::GetAttr containing lists, strings or none. - self._test(*self.create_model(dim, s, norm), ie_device, precision, ir_version, custom_eps=1e-3, - freeze_model=True) + self._test(*self.create_model(dim, s, norm), ie_device, precision, ir_version, custom_eps=1e-3) + + +class aten_fft(torch.nn.Module): + def __init__(self, op, n, dim, norm): + super().__init__() + self.n = n + self.dim = dim + self.norm = norm + self.op = op + + def forward(self, x): + if x.shape[-1] == 2: + x = torch.view_as_complex(x) + res = self.op(x, self.n, dim=self.dim, norm=self.norm) + if res.dtype.is_complex: + return torch.view_as_real(res) + return res + + +class TestFFT(PytorchLayerTest): + def _prepare_input(self): + return (np.random.randn(*self.input_shape).astype(np.float32),) + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.parametrize("input_shape", [[67], [80], [12, 14], [9, 6, 3]]) + @pytest.mark.parametrize("n", [None, 50, 6]) + @pytest.mark.parametrize("dim", [-1, 0]) + @pytest.mark.parametrize("norm", [None, "forward", "backward", "ortho"]) + @pytest.mark.parametrize("op,aten_name,in_complex", [ + (torch.fft.fft, "aten::fft_fft", True), + (torch.fft.fft, "aten::fft_fft", False), + pytest.param(torch.fft.hfft, "aten::fft_hfft", True, marks=pytest.mark.skip(reason="Not supported yet.")), + (torch.fft.rfft, "aten::fft_rfft", False), + (torch.fft.ifft, "aten::fft_ifft", True), + pytest.param(torch.fft.ihfft, "aten::fft_ihfft", False, marks=pytest.mark.skip(reason="Not supported yet.")), + (torch.fft.irfft, "aten::fft_irfft", True), + ]) + def test_1d(self, ie_device, precision, ir_version, input_shape, op, n, dim, norm, aten_name, in_complex): + if op in [torch.fft.rfft, torch.fft.irfft] and n is not None and input_shape[dim] < n: + pytest.skip("Signal size greater than input size is not supported yet") + if in_complex: + self.input_shape = input_shape + [2] + else: + self.input_shape = input_shape + m = aten_fft(op, n, dim, norm) + self._test(m, None, aten_name, ie_device, + precision, ir_version, trace_model=True, dynamic_shapes=False, custom_eps=1e-3) + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.parametrize("input_shape,s,dim", [ + ((4, 5), None, None), + ((4, 5), None, (0,)), + ((4, 5), None, (0, -1)), + ((4, 5, 6), None, None), + ((4, 5, 6), None, (0,)), + ((4, 5, 6), None, (0, -1)), + ((4, 5, 6, 7), None, None), + ((4, 5, 6, 7), None, (0,)), + ((4, 5, 6, 7), None, (0, -1)), + ((4, 5, 6, 7, 8, 4), None, None), + ((4, 5, 6, 7, 8), None, (1, 3, 4)), + ((4, 5, 6), None, (1,)), + ((4,), None, (0,)), + ((4, 5, 60, 70), (10, 10), None), + ((40, 50, 6, 7), (10, 10), (0, 1)), + ]) + @pytest.mark.parametrize("norm", [None, "forward", "backward", "ortho"]) + @pytest.mark.parametrize("op,aten_name,in_complex", [ + (torch.fft.fftn, "aten::fft_fftn", True), + (torch.fft.fftn, "aten::fft_fftn", False), + pytest.param(torch.fft.hfftn, "aten::fft_hfftn", True, marks=pytest.mark.skip(reason="Not supported yet.")), + (torch.fft.rfftn, "aten::fft_rfftn", False), + (torch.fft.ifftn, "aten::fft_ifftn", True), + pytest.param(torch.fft.ihfftn, "aten::fft_ihfftn", False, marks=pytest.mark.skip(reason="Not supported yet.")), + (torch.fft.irfftn, "aten::fft_irfftn", True), + ]) + def test_nd(self, ie_device, precision, ir_version, input_shape, op, s, dim, norm, aten_name, in_complex): + if in_complex: + self.input_shape = input_shape + (2,) + else: + self.input_shape = input_shape + m = aten_fft(op, s, dim, norm) + self._test(m, None, aten_name, ie_device, + precision, ir_version, trace_model=True, dynamic_shapes=False, custom_eps=1e-3) From b093949ccadf0b1363b9b745df68cde01fdc6ab9 Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Mon, 17 Mar 2025 11:35:27 +0100 Subject: [PATCH 2/9] Update src/frontends/pytorch/src/op/fft.cpp --- src/frontends/pytorch/src/op/fft.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/frontends/pytorch/src/op/fft.cpp b/src/frontends/pytorch/src/op/fft.cpp index 95b6dc8ce85475..54ae514982acc2 100644 --- a/src/frontends/pytorch/src/op/fft.cpp +++ b/src/frontends/pytorch/src/op/fft.cpp @@ -21,7 +21,6 @@ #include "openvino/op/sqrt.hpp" #include "openvino/op/squeeze.hpp" #include "openvino/op/subtract.hpp" -#include "openvino/op/unsqueeze.hpp" #include "utils.hpp" namespace ov { From 4b2d9316b6e4db56136ae27865607c5a4a82821c Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Mon, 17 Mar 2025 11:36:08 +0100 Subject: [PATCH 3/9] Update src/frontends/pytorch/src/op_table.cpp --- src/frontends/pytorch/src/op_table.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index b8cfb824939c38..7636ad54d96cce 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -976,7 +976,6 @@ const std::unordered_map get_supported_ops_fx() { {"aten.relu.default", op::translate_1to1_match_1_inputs}, {"aten.relu_.default", op::inplace_op>}, {"aten.repeat.default", op::translate_repeat_fx}, - {"aten.repeat_interleave.Tensor", op::translate_repeat_interleave}, {"aten.rms_norm.default", op::translate_rms_norm}, {"aten.roll.default", op::translate_roll}, {"aten.rsqrt.default", op::translate_rsqrt}, From 410e93f8687bc45a9f72c7fb08d604cc5b530686 Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Mon, 17 Mar 2025 14:49:40 +0100 Subject: [PATCH 4/9] Fix tests Signed-off-by: Maxim Vafin --- src/frontends/pytorch/src/op/fft.cpp | 7 ++++--- .../{test_rfftn_complex_transforms.py => test_fft.py} | 0 2 files changed, 4 insertions(+), 3 deletions(-) rename tests/layer_tests/pytorch_tests/{test_rfftn_complex_transforms.py => test_fft.py} (100%) diff --git a/src/frontends/pytorch/src/op/fft.cpp b/src/frontends/pytorch/src/op/fft.cpp index 54ae514982acc2..c725b5477ea823 100644 --- a/src/frontends/pytorch/src/op/fft.cpp +++ b/src/frontends/pytorch/src/op/fft.cpp @@ -72,8 +72,10 @@ std::tuple, Output> get_dim_s(const NodeContext& context, con std::tie(input_shape, input_rank_scalar) = get_shape_rank(context, x, true); auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1})); + auto const_neg_1_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1})); auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0})); auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1})); + auto const_2 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {2})); Output raw_s; // Inputs can be either none or List. Check whether input values should be used or should be set to default values. @@ -99,21 +101,20 @@ std::tuple, Output> get_dim_s(const NodeContext& context, con dim = context.mark_node(std::make_shared(const_0, input_rank_scalar, const_1, element::i32)); } if (dim.get_partial_shape().rank().is_dynamic() || dim.get_partial_shape().rank().get_length() == 0) { - auto const_neg_1_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1})); dim = context.mark_node(std::make_shared(dim, const_neg_1_1d, false)); } Output default_s; if (is_irfft) { - auto const_2_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {2})); // Calculate default s values. Use full available size except last element, which is set to even value in last // dimension: s[-1] = 2 * (complex_input_shape[dim[-1]]) auto default_s_raw = context.mark_node(std::make_shared(input_shape, dim, const_0)); auto last_s = context.mark_node(std::make_shared(default_s_raw, const_neg_1, const_0)); auto last_s_m_1 = context.mark_node(std::make_shared(last_s, const_1)); - auto s_upd = context.mark_node(std::make_shared(last_s_m_1, const_2_1d)); + auto s_upd = context.mark_node(std::make_shared(last_s_m_1, const_2)); auto s_shape = context.mark_node(std::make_shared(default_s_raw, element::i32)); auto last_s_idx = context.mark_node(std::make_shared(s_shape, const_1)); + s_upd = context.mark_node(std::make_shared(s_upd, const_neg_1_1d, false)); default_s = context.mark_node(std::make_shared(default_s_raw, last_s_idx, s_upd, const_0)); } else { default_s = context.mark_node(std::make_shared(input_shape, dim, const_0)); diff --git a/tests/layer_tests/pytorch_tests/test_rfftn_complex_transforms.py b/tests/layer_tests/pytorch_tests/test_fft.py similarity index 100% rename from tests/layer_tests/pytorch_tests/test_rfftn_complex_transforms.py rename to tests/layer_tests/pytorch_tests/test_fft.py From 0e538cef968bcbce23e70b4989d62624c6691fa8 Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Mon, 17 Mar 2025 14:52:45 +0100 Subject: [PATCH 5/9] Update src/frontends/pytorch/src/op/fft.cpp --- src/frontends/pytorch/src/op/fft.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/frontends/pytorch/src/op/fft.cpp b/src/frontends/pytorch/src/op/fft.cpp index c725b5477ea823..cb4803845cb0d9 100644 --- a/src/frontends/pytorch/src/op/fft.cpp +++ b/src/frontends/pytorch/src/op/fft.cpp @@ -107,7 +107,7 @@ std::tuple, Output> get_dim_s(const NodeContext& context, con Output default_s; if (is_irfft) { // Calculate default s values. Use full available size except last element, which is set to even value in last - // dimension: s[-1] = 2 * (complex_input_shape[dim[-1]]) + // dimension: s[-1] = 2 * (complex_input_shape[dim[-1]] - 1). auto default_s_raw = context.mark_node(std::make_shared(input_shape, dim, const_0)); auto last_s = context.mark_node(std::make_shared(default_s_raw, const_neg_1, const_0)); auto last_s_m_1 = context.mark_node(std::make_shared(last_s, const_1)); From ad43d2ab1582e923eee91fb754010b32ecd3beb2 Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Mon, 17 Mar 2025 14:54:53 +0100 Subject: [PATCH 6/9] Update src/frontends/pytorch/src/op/fft.cpp --- src/frontends/pytorch/src/op/fft.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/frontends/pytorch/src/op/fft.cpp b/src/frontends/pytorch/src/op/fft.cpp index cb4803845cb0d9..d2d862dd7f19a0 100644 --- a/src/frontends/pytorch/src/op/fft.cpp +++ b/src/frontends/pytorch/src/op/fft.cpp @@ -34,7 +34,7 @@ namespace { Output normalize(const NodeContext& context, const Output& node, const Output& s, - std::string norm, + const std::string& norm, bool inverse) { if (norm == "backward") { // No normalization From 289f5b600644d60d1a8784ae465dcc6129efdc55 Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Mon, 17 Mar 2025 16:16:30 +0100 Subject: [PATCH 7/9] Fix 2d versions Signed-off-by: Maxim Vafin --- src/frontends/pytorch/src/op/fft.cpp | 139 ++++++++++---------- src/frontends/pytorch/src/op_table.cpp | 25 ++-- tests/layer_tests/pytorch_tests/test_fft.py | 24 ++++ 3 files changed, 113 insertions(+), 75 deletions(-) diff --git a/src/frontends/pytorch/src/op/fft.cpp b/src/frontends/pytorch/src/op/fft.cpp index d2d862dd7f19a0..e11c61362d59e0 100644 --- a/src/frontends/pytorch/src/op/fft.cpp +++ b/src/frontends/pytorch/src/op/fft.cpp @@ -66,7 +66,10 @@ Output normalize(const NodeContext& context, return normalized; } -std::tuple, Output> get_dim_s(const NodeContext& context, const Output& x, bool is_irfft) { +std::tuple, Output> get_dim_s(const NodeContext& context, + const Output& x, + int size, + bool is_irfft) { Output input_shape; Output input_rank_scalar; std::tie(input_shape, input_rank_scalar) = get_shape_rank(context, x, true); @@ -97,8 +100,20 @@ std::tuple, Output> get_dim_s(const NodeContext& context, con auto start_scalar = context.mark_node(std::make_shared(start)); dim = context.mark_node(std::make_shared(start_scalar, input_rank_scalar, const_1, element::i32)); } else { - // Dim and s are set to default, use all of dimensions. - dim = context.mark_node(std::make_shared(const_0, input_rank_scalar, const_1, element::i32)); + // Dim and s are set to default. + switch (size) { + case 1: + dim = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1})); + break; + case 2: + dim = context.mark_node(v0::Constant::create(element::i32, Shape{2}, {-2, -1})); + break; + case -1: + dim = context.mark_node(std::make_shared(const_0, input_rank_scalar, const_1, element::i32)); + break; + default: + FRONT_END_THROW("Invalid FFT size."); + } } if (dim.get_partial_shape().rank().is_dynamic() || dim.get_partial_shape().rank().get_length() == 0) { dim = context.mark_node(std::make_shared(dim, const_neg_1_1d, false)); @@ -130,23 +145,31 @@ std::tuple, Output> get_dim_s(const NodeContext& context, con } return {dim, s}; } -} // namespace -OutputVector translate_fft_fftn(const NodeContext& context) { +template +OutputVector translate_fft_fft_base(const NodeContext& context, + int size, + bool complex_input, + bool complex_output, + bool inverse = false, + bool is_irfft = false) { num_inputs_check(context, 1, 4, true); auto input = context.get_input(0); Output dim; Output s; - std::tie(dim, s) = get_dim_s(context, input, false); + std::tie(dim, s) = get_dim_s(context, input, size, is_irfft); auto complex_type_mark = as_type_ptr(input.get_node_shared_ptr()); if (complex_type_mark) { + PYTORCH_OP_CONVERSION_CHECK(complex_input, "Operation does not support complex type tensor on input."); input = complex_type_mark->get_data(); } else { - auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0})); - const_0 = context.mark_node(std::make_shared(const_0, input)); - input = std::make_shared(input, const_0)->get_data(); + if (complex_input) { + auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0})); + const_0 = context.mark_node(std::make_shared(const_0, input)); + input = std::make_shared(input, const_0)->get_data(); + } } // Handle norm parameter indicating normalization mode to use. Defaults to "backward". @@ -155,83 +178,65 @@ OutputVector translate_fft_fftn(const NodeContext& context) { norm = context.const_input(3); } - auto node = context.mark_node(std::make_shared(input, dim, s)); + auto node = context.mark_node(std::make_shared(input, dim, s)); // Apply normalizations - Output normalized = normalize(context, node, s, norm, false); - return {std::make_shared(normalized, normalized.get_element_type())}; + Output normalized = normalize(context, node, s, norm, inverse); + if (complex_output) { + normalized = std::make_shared(normalized, normalized.get_element_type()); + } + return {normalized}; } +} // namespace -OutputVector translate_fft_rfftn(const NodeContext& context) { - // aten::fft_rfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor - num_inputs_check(context, 1, 4); - auto input = context.get_input(0); +OutputVector translate_fft_fft(const NodeContext& context) { + return translate_fft_fft_base(context, 1, true, true); +} - Output dim; - Output s; - std::tie(dim, s) = get_dim_s(context, input, false); +OutputVector translate_fft_fft2(const NodeContext& context) { + return translate_fft_fft_base(context, 2, true, true); +} - // Handle norm parameter indicating normalization mode to use. Defaults to "backward". - std::string norm = "backward"; - if (!context.input_is_none(3)) { - norm = context.const_input(3); - } +OutputVector translate_fft_fftn(const NodeContext& context) { + return translate_fft_fft_base(context, -1, true, true); +} - auto node = context.mark_node(std::make_shared(input, dim, s)); +OutputVector translate_fft_rfft(const NodeContext& context) { + return translate_fft_fft_base(context, 1, false, true); +} - // Apply normalizations - Output normalized = normalize(context, node, s, norm, false); - return {std::make_shared(normalized, normalized.get_element_type())}; +OutputVector translate_fft_rfft2(const NodeContext& context) { + return translate_fft_fft_base(context, 2, false, true); } -OutputVector translate_fft_ifftn(const NodeContext& context) { - // aten::fft_irfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor - num_inputs_check(context, 1, 4, true); - auto input = context.get_input(0); +OutputVector translate_fft_rfftn(const NodeContext& context) { + // aten::fft_rfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + return translate_fft_fft_base(context, -1, false, true); +} - Output dim; - Output s; - std::tie(dim, s) = get_dim_s(context, input, false); +OutputVector translate_fft_ifft(const NodeContext& context) { + return translate_fft_fft_base(context, 1, true, true, true); +} - auto complex_type_mark = as_type_ptr(input.get_node_shared_ptr()); - PYTORCH_OP_CONVERSION_CHECK(complex_type_mark, "Operation expects complex type tensor on input."); - input = complex_type_mark->get_data(); +OutputVector translate_fft_ifft2(const NodeContext& context) { + return translate_fft_fft_base(context, 2, true, true, true); +} - // Handle norm parameter indicating normalization mode to use. Defaults to "backward". - std::string norm = "backward"; - if (!context.input_is_none(3)) { - norm = context.const_input(3); - } +OutputVector translate_fft_ifftn(const NodeContext& context) { + return translate_fft_fft_base(context, -1, true, true, true); +} - auto node = context.mark_node(std::make_shared(input, dim, s)); +OutputVector translate_fft_irfft(const NodeContext& context) { + return translate_fft_fft_base(context, 1, true, false, true, true); +} - Output normalized = normalize(context, node, s, norm, true); - return {std::make_shared(normalized, normalized.get_element_type())}; +OutputVector translate_fft_irfft2(const NodeContext& context) { + return translate_fft_fft_base(context, 2, true, false, true, true); } OutputVector translate_fft_irfftn(const NodeContext& context) { // aten::fft_irfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor - num_inputs_check(context, 1, 4, true); - auto input = context.get_input(0); - - Output dim; - Output s; - std::tie(dim, s) = get_dim_s(context, input, true); - - auto complex_type_mark = as_type_ptr(input.get_node_shared_ptr()); - PYTORCH_OP_CONVERSION_CHECK(complex_type_mark, "Operation expects complex type tensor on input."); - input = complex_type_mark->get_data(); - - // Handle norm parameter indicating normalization mode to use. Defaults to "backward". - std::string norm = "backward"; - if (!context.input_is_none(3)) { - norm = context.const_input(3); - } - - auto node = context.mark_node(std::make_shared(input, dim, s)); - - Output normalized = normalize(context, node, s, norm, true); - return {normalized}; + return translate_fft_fft_base(context, -1, true, false, true, true); } } // namespace op diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 7636ad54d96cce..1dde0df8ac001e 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -92,9 +92,17 @@ OP_CONVERTER(translate_expm1); OP_CONVERTER(translate_eye); OP_CONVERTER(translate_fake_quantize_per_channel_affine); OP_CONVERTER(translate_fake_quantize_per_tensor_affine); +OP_CONVERTER(translate_fft_fft); +OP_CONVERTER(translate_fft_fft2); OP_CONVERTER(translate_fft_fftn); +OP_CONVERTER(translate_fft_ifft); +OP_CONVERTER(translate_fft_ifft2); OP_CONVERTER(translate_fft_ifftn); +OP_CONVERTER(translate_fft_irfft); +OP_CONVERTER(translate_fft_irfft2); OP_CONVERTER(translate_fft_irfftn); +OP_CONVERTER(translate_fft_rfft); +OP_CONVERTER(translate_fft_rfft2); OP_CONVERTER(translate_fft_rfftn); OP_CONVERTER(translate_fill); OP_CONVERTER(translate_fill_diagonal); @@ -486,17 +494,17 @@ const std::unordered_map get_supported_ops_ts() { {"aten::fake_quantize_per_channel_affine", op::translate_fake_quantize_per_channel_affine}, {"aten::fake_quantize_per_tensor_affine", op::translate_fake_quantize_per_tensor_affine}, {"aten::feature_dropout", op::skip_node}, - {"aten::fft_fft", op::translate_fft_fftn}, - {"aten::fft_fft2", op::translate_fft_fftn}, + {"aten::fft_fft", op::translate_fft_fft}, + {"aten::fft_fft2", op::translate_fft_fft2}, {"aten::fft_fftn", op::translate_fft_fftn}, - {"aten::fft_ifft", op::translate_fft_ifftn}, - {"aten::fft_ifft2", op::translate_fft_ifftn}, + {"aten::fft_ifft", op::translate_fft_ifft}, + {"aten::fft_ifft2", op::translate_fft_ifft2}, {"aten::fft_ifftn", op::translate_fft_ifftn}, - {"aten::fft_irfft", op::translate_fft_irfftn}, - {"aten::fft_irfft2", op::translate_fft_irfftn}, + {"aten::fft_irfft", op::translate_fft_irfft}, + {"aten::fft_irfft2", op::translate_fft_irfft2}, {"aten::fft_irfftn", op::translate_fft_irfftn}, - {"aten::fft_rfft", op::translate_fft_rfftn}, - {"aten::fft_rfft2", op::translate_fft_rfftn}, + {"aten::fft_rfft", op::translate_fft_rfft}, + {"aten::fft_rfft2", op::translate_fft_rfft2}, {"aten::fft_rfftn", op::translate_fft_rfftn}, {"aten::fill", op::translate_fill}, {"aten::fill_diagonal", op::translate_fill_diagonal}, @@ -881,6 +889,7 @@ const std::unordered_map get_supported_ops_fx() { {"aten.expand_copy.default", op::translate_expand}, {"aten.eye.m", op::translate_eye_fx}, {"aten.fake_quantize_per_channel_affine_cachemask.default", op::translate_fake_quantize_per_channel_affine_fx}, + {"aten.fft_fft.default", op::translate_fft_fftn}, {"aten.fill.Scalar", op::translate_fill}, {"aten.fill_.Scalar", op::inplace_op}, {"aten.fill.Tensor", op::translate_fill}, diff --git a/tests/layer_tests/pytorch_tests/test_fft.py b/tests/layer_tests/pytorch_tests/test_fft.py index 56710f2362bf60..099c5036b7ce7b 100644 --- a/tests/layer_tests/pytorch_tests/test_fft.py +++ b/tests/layer_tests/pytorch_tests/test_fft.py @@ -97,6 +97,30 @@ def test_1d(self, ie_device, precision, ir_version, input_shape, op, n, dim, nor self._test(m, None, aten_name, ie_device, precision, ir_version, trace_model=True, dynamic_shapes=False, custom_eps=1e-3) + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.parametrize("input_shape", [[20, 30], [15, 20, 30], [10, 15, 20, 30]]) + @pytest.mark.parametrize("s", [None, (10, 10)]) + @pytest.mark.parametrize("dim", [(0, 1), (-2, -1)]) + @pytest.mark.parametrize("norm", [None, "forward", "backward", "ortho"]) + @pytest.mark.parametrize("op,aten_name,in_complex", [ + (torch.fft.fft2, "aten::fft_fft2", True), + (torch.fft.fft2, "aten::fft_fft2", False), + pytest.param(torch.fft.hfft2, "aten::fft_hfft2", True, marks=pytest.mark.skip(reason="Not supported yet.")), + (torch.fft.rfft2, "aten::fft_rfft2", False), + (torch.fft.ifft2, "aten::fft_ifft2", True), + pytest.param(torch.fft.ihfft2, "aten::fft_ihfft2", False, marks=pytest.mark.skip(reason="Not supported yet.")), + (torch.fft.irfft2, "aten::fft_irfft2", True), + ]) + def test_2d(self, ie_device, precision, ir_version, input_shape, op, s, dim, norm, aten_name, in_complex): + if in_complex: + self.input_shape = input_shape + [2] + else: + self.input_shape = input_shape + m = aten_fft(op, s, dim, norm) + self._test(m, None, aten_name, ie_device, + precision, ir_version, trace_model=True, dynamic_shapes=False, custom_eps=1e-3) + @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.parametrize("input_shape,s,dim", [ From 95611452336ff6c55d6891332eb6c00b09f015a0 Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Mon, 17 Mar 2025 16:11:11 +0100 Subject: [PATCH 8/9] Update src/frontends/pytorch/src/op_table.cpp --- src/frontends/pytorch/src/op_table.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 1dde0df8ac001e..1b0696305164fd 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -889,7 +889,6 @@ const std::unordered_map get_supported_ops_fx() { {"aten.expand_copy.default", op::translate_expand}, {"aten.eye.m", op::translate_eye_fx}, {"aten.fake_quantize_per_channel_affine_cachemask.default", op::translate_fake_quantize_per_channel_affine_fx}, - {"aten.fft_fft.default", op::translate_fft_fftn}, {"aten.fill.Scalar", op::translate_fill}, {"aten.fill_.Scalar", op::inplace_op}, {"aten.fill.Tensor", op::translate_fill}, From ebf417c40ffad8dd3ea183abcc48fcbb952767ec Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Tue, 18 Mar 2025 11:05:21 +0100 Subject: [PATCH 9/9] Add details in error message. Signed-off-by: Maxim Vafin --- src/frontends/pytorch/src/op/fft.cpp | 41 ++++++++++++++-------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/src/frontends/pytorch/src/op/fft.cpp b/src/frontends/pytorch/src/op/fft.cpp index e11c61362d59e0..9e07070e568e25 100644 --- a/src/frontends/pytorch/src/op/fft.cpp +++ b/src/frontends/pytorch/src/op/fft.cpp @@ -61,7 +61,8 @@ Output normalize(const NodeContext& context, normalized = context.mark_node(std::make_shared(node, sqrt_n)); } } else { - FRONT_END_THROW("Unrecognized normalization mode. Only forward, backward and ortho are supported."); + FRONT_END_THROW("Unrecognized normalization mode " + norm + + ". Only forward, backward and ortho are supported."); } return normalized; } @@ -112,7 +113,7 @@ std::tuple, Output> get_dim_s(const NodeContext& context, dim = context.mark_node(std::make_shared(const_0, input_rank_scalar, const_1, element::i32)); break; default: - FRONT_END_THROW("Invalid FFT size."); + FRONT_END_THROW("Invalid FFT size: " + std::to_string(size)); } } if (dim.get_partial_shape().rank().is_dynamic() || dim.get_partial_shape().rank().get_length() == 0) { @@ -147,12 +148,12 @@ std::tuple, Output> get_dim_s(const NodeContext& context, } template -OutputVector translate_fft_fft_base(const NodeContext& context, - int size, - bool complex_input, - bool complex_output, - bool inverse = false, - bool is_irfft = false) { +OutputVector translate_fft_base(const NodeContext& context, + int size, + bool complex_input, + bool complex_output, + bool inverse = false, + bool is_irfft = false) { num_inputs_check(context, 1, 4, true); auto input = context.get_input(0); @@ -190,53 +191,53 @@ OutputVector translate_fft_fft_base(const NodeContext& context, } // namespace OutputVector translate_fft_fft(const NodeContext& context) { - return translate_fft_fft_base(context, 1, true, true); + return translate_fft_base(context, 1, true, true); } OutputVector translate_fft_fft2(const NodeContext& context) { - return translate_fft_fft_base(context, 2, true, true); + return translate_fft_base(context, 2, true, true); } OutputVector translate_fft_fftn(const NodeContext& context) { - return translate_fft_fft_base(context, -1, true, true); + return translate_fft_base(context, -1, true, true); } OutputVector translate_fft_rfft(const NodeContext& context) { - return translate_fft_fft_base(context, 1, false, true); + return translate_fft_base(context, 1, false, true); } OutputVector translate_fft_rfft2(const NodeContext& context) { - return translate_fft_fft_base(context, 2, false, true); + return translate_fft_base(context, 2, false, true); } OutputVector translate_fft_rfftn(const NodeContext& context) { // aten::fft_rfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor - return translate_fft_fft_base(context, -1, false, true); + return translate_fft_base(context, -1, false, true); } OutputVector translate_fft_ifft(const NodeContext& context) { - return translate_fft_fft_base(context, 1, true, true, true); + return translate_fft_base(context, 1, true, true, true); } OutputVector translate_fft_ifft2(const NodeContext& context) { - return translate_fft_fft_base(context, 2, true, true, true); + return translate_fft_base(context, 2, true, true, true); } OutputVector translate_fft_ifftn(const NodeContext& context) { - return translate_fft_fft_base(context, -1, true, true, true); + return translate_fft_base(context, -1, true, true, true); } OutputVector translate_fft_irfft(const NodeContext& context) { - return translate_fft_fft_base(context, 1, true, false, true, true); + return translate_fft_base(context, 1, true, false, true, true); } OutputVector translate_fft_irfft2(const NodeContext& context) { - return translate_fft_fft_base(context, 2, true, false, true, true); + return translate_fft_base(context, 2, true, false, true, true); } OutputVector translate_fft_irfftn(const NodeContext& context) { // aten::fft_irfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor - return translate_fft_fft_base(context, -1, true, false, true, true); + return translate_fft_base(context, -1, true, false, true, true); } } // namespace op