
Commit d53350b

[PT FE] Support fft operations (#29507)

### Details:
- Support fft operations

### Tickets:
- *ticket-id*

Signed-off-by: Maxim Vafin <[email protected]>

1 parent df40180 commit d53350b
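In short, the PyTorch frontend now translates the whole `aten::fft_*` family (`fft`/`ifft`, `fft2`/`ifft2`, `fftn`/`ifftn`, `rfft`/`rfft2`/`rfftn`, `irfft`/`irfft2`/`irfftn`) to OpenVINO's `DFT`/`IDFT`/`RDFT`/`IRDFT` operations. A minimal smoke test of the new coverage, assuming a build that contains this patch; the module and shapes below are illustrative, not from the PR:

```python
import torch
import openvino as ov

class Spectral(torch.nn.Module):
    def forward(self, x):
        # Complex round trip through the newly supported ops.
        return torch.fft.ifftn(torch.fft.fftn(x, norm="ortho"), norm="ortho").real

# convert_model traces the module with the given example input.
ov_model = ov.convert_model(Spectral(), example_input=torch.randn(2, 8, 8))
```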

File tree: 6 files changed, +347 -174 lines changed


src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py (+2 -2)

@@ -75,8 +75,8 @@ def __init__(
             help_msg = ("Tracing sometimes provide better results, "
                         "please provide valid 'example_input' argument.\n")
             raise RuntimeError(
-                f"Couldn't get TorchScript module by {msg}.\n{help_msg} "
-                "You can also provide TorchScript module that you obtained"
+                f"Couldn't get TorchScript module by {msg}.\nException:\n{e}\n"
+                f"{help_msg} You can also provide TorchScript module that you obtained"
                 " yourself, please refer to PyTorch documentation: "
                 "https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html."
             ) from e
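The `ts_decoder.py` tweak is small but practical: the text of the original scripting/tracing exception is now embedded directly in the raised `RuntimeError` instead of being visible only through the exception chain. A sketch of the resulting pattern, with a hypothetical stand-in for the failing call:

```python
import torch

def get_torchscript_module(model, msg="scripting"):
    try:
        return torch.jit.script(model)  # may raise for unsupported constructs
    except Exception as e:
        # New behavior: the original exception text is inlined into the message,
        # while `raise ... from e` still preserves the full chain for debugging.
        raise RuntimeError(
            f"Couldn't get TorchScript module by {msg}.\nException:\n{e}\n"
            "Tracing sometimes provides better results; please provide a valid "
            "'example_input' argument."
        ) from e
```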

src/frontends/pytorch/src/op/fft.cpp (+158 -120)

@@ -4,9 +4,11 @@
 
 #include "openvino/frontend/complex_type_mark.hpp"
 #include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/op/dft.hpp"
 #include "openvino/op/divide.hpp"
 #include "openvino/op/equal.hpp"
 #include "openvino/op/gather.hpp"
+#include "openvino/op/idft.hpp"
 #include "openvino/op/irdft.hpp"
 #include "openvino/op/multiply.hpp"
 #include "openvino/op/range.hpp"
@@ -28,18 +30,56 @@ namespace op {
 
 using namespace ov::op;
 
-OutputVector translate_fft_rfftn(const NodeContext& context) {
-    // aten::fft_rfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
-    num_inputs_check(context, 1, 4);
-    auto input = context.get_input(0);
-
-    auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
+namespace {
+Output<Node> normalize(const NodeContext& context,
+                       const Output<Node>& node,
+                       const Output<Node>& s,
+                       const std::string& norm,
+                       bool inverse) {
+    if (norm == "backward") {
+        // No normalization
+        return node;
+    }
+    // Apply normalizations
     auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
-    auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
+    auto n_int = context.mark_node(std::make_shared<v1::ReduceProd>(s, const_0));
+    auto n = context.mark_node(std::make_shared<v1::ConvertLike>(n_int, node));
+    Output<Node> normalized;
+    if (norm == "forward") {
+        // Normalize by 1/n
+        if (inverse) {
+            normalized = context.mark_node(std::make_shared<v1::Multiply>(node, n));
+        } else {
+            normalized = context.mark_node(std::make_shared<v1::Divide>(node, n));
+        }
+    } else if (norm == "ortho") {
+        // Normalize by 1/sqrt(n)
+        auto sqrt_n = context.mark_node(std::make_shared<v0::Sqrt>(n));
+        if (inverse) {
+            normalized = context.mark_node(std::make_shared<v1::Multiply>(node, sqrt_n));
+        } else {
+            normalized = context.mark_node(std::make_shared<v1::Divide>(node, sqrt_n));
+        }
+    } else {
+        FRONT_END_THROW("Unrecognized normalization mode " + norm +
+                        ". Only forward, backward and ortho are supported.");
+    }
+    return normalized;
+}
 
+std::tuple<Output<Node>, Output<Node>> get_dim_s(const NodeContext& context,
+                                                 const Output<Node>& x,
+                                                 int size,
+                                                 bool is_irfft) {
     Output<Node> input_shape;
     Output<Node> input_rank_scalar;
-    std::tie(input_shape, input_rank_scalar) = get_shape_rank(context, input, true);
+    std::tie(input_shape, input_rank_scalar) = get_shape_rank(context, x, true);
+
+    auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
+    auto const_neg_1_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
+    auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
+    auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
+    auto const_2 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {2}));
 
     Output<Node> raw_s;
     // Inputs can be either none or List. Check whether input values should be used or should be set to default values.
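The extracted `normalize` helper mirrors the `norm` semantics of `torch.fft`: `backward` leaves the forward transform unscaled, `forward` scales it by `1/n`, and `ortho` by `1/sqrt(n)`, where `n` is the product of the transformed signal sizes `s`; inverse transforms multiply instead of divide, so each forward/inverse pair remains an identity. A quick check of those conventions against PyTorch itself (a sketch, not part of the PR):

```python
import torch

x = torch.randn(4, 8)
n = x.numel()  # product of the transformed signal sizes s (here: all dims)
full = torch.fft.fftn(x)  # norm="backward": unscaled forward transform
assert torch.allclose(torch.fft.fftn(x, norm="forward"), full / n)
assert torch.allclose(torch.fft.fftn(x, norm="ortho"), full / n ** 0.5)
# The inverse applies the reciprocal factor, so fftn/ifftn is an identity.
assert torch.allclose(torch.fft.ifftn(full).real, x, atol=1e-5)
```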
@@ -48,33 +88,89 @@ OutputVector translate_fft_rfftn(const NodeContext& context) {
         raw_s = get_input_concat_if_list(context, 1);
         raw_s = context.mark_node(std::make_shared<v0::Convert>(raw_s, element::i32));
     }
-    Output<Node> dim;
     // Handle dim parameter containing vector of integers indicating dimensions to be transformed.
+    Output<Node> dim;
     if (!context.input_is_none(2)) {
         // dim is provided, load from input.
         dim = get_input_concat_if_list(context, 2);
         dim = context.mark_node(std::make_shared<v0::Convert>(dim, element::i32));
     } else if (!context.input_is_none(1)) {
         // If dim is default and s is provided, use last s_len dimensions where s_len is length of s.
         auto s_len = context.mark_node(std::make_shared<v3::ShapeOf>(raw_s, element::i32));
-        auto slice_start = context.mark_node(std::make_shared<v1::Subtract>(input_rank_scalar, s_len));
-        auto slice_start_scalar = context.mark_node(std::make_shared<v0::Squeeze>(slice_start));
-        dim = context.mark_node(
-            std::make_shared<v4::Range>(slice_start_scalar, input_rank_scalar, const_1, element::i32));
+        auto start = context.mark_node(std::make_shared<v1::Subtract>(input_rank_scalar, s_len));
+        auto start_scalar = context.mark_node(std::make_shared<v0::Squeeze>(start));
+        dim = context.mark_node(std::make_shared<v4::Range>(start_scalar, input_rank_scalar, const_1, element::i32));
     } else {
-        // Dim and s are set to default, use all of dimensions.
-        dim = context.mark_node(std::make_shared<v4::Range>(const_0, input_rank_scalar, const_1, element::i32));
+        // Dim and s are set to default.
+        switch (size) {
+        case 1:
+            dim = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
+            break;
+        case 2:
+            dim = context.mark_node(v0::Constant::create(element::i32, Shape{2}, {-2, -1}));
+            break;
+        case -1:
+            dim = context.mark_node(std::make_shared<v4::Range>(const_0, input_rank_scalar, const_1, element::i32));
+            break;
+        default:
+            FRONT_END_THROW("Invalid FFT size: " + std::to_string(size));
+        }
+    }
+    if (dim.get_partial_shape().rank().is_dynamic() || dim.get_partial_shape().rank().get_length() == 0) {
+        dim = context.mark_node(std::make_shared<v1::Reshape>(dim, const_neg_1_1d, false));
     }
 
+    Output<Node> default_s;
+    if (is_irfft) {
+        // Calculate default s values. Use full available size except last element, which is set to even value in last
+        // dimension: s[-1] = 2 * (complex_input_shape[dim[-1]] - 1).
+        auto default_s_raw = context.mark_node(std::make_shared<v8::Gather>(input_shape, dim, const_0));
+        auto last_s = context.mark_node(std::make_shared<v8::Gather>(default_s_raw, const_neg_1, const_0));
+        auto last_s_m_1 = context.mark_node(std::make_shared<v1::Subtract>(last_s, const_1));
+        auto s_upd = context.mark_node(std::make_shared<v1::Multiply>(last_s_m_1, const_2));
+        auto s_shape = context.mark_node(std::make_shared<v3::ShapeOf>(default_s_raw, element::i32));
+        auto last_s_idx = context.mark_node(std::make_shared<v1::Subtract>(s_shape, const_1));
+        s_upd = context.mark_node(std::make_shared<v1::Reshape>(s_upd, const_neg_1_1d, false));
+        default_s = context.mark_node(std::make_shared<v3::ScatterUpdate>(default_s_raw, last_s_idx, s_upd, const_0));
+    } else {
+        default_s = context.mark_node(std::make_shared<v8::Gather>(input_shape, dim, const_0));
+    }
     Output<Node> s;
     if (context.input_is_none(1)) {
         // Value for s was set to default, use full size for all dimensions.
-        s = context.mark_node(std::make_shared<v8::Gather>(input_shape, dim, const_0));
+        s = default_s;
     } else {
         // Values for s were provided. Replace -1 values with default full size in given dimension.
         auto full_s_cond = context.mark_node(std::make_shared<v1::Equal>(raw_s, const_neg_1));
-        auto full_s_values = context.mark_node(std::make_shared<v8::Gather>(input_shape, dim, const_0));
-        s = context.mark_node(std::make_shared<v1::Select>(full_s_cond, full_s_values, raw_s));
+        s = context.mark_node(std::make_shared<v1::Select>(full_s_cond, default_s, raw_s));
+    }
+    return {dim, s};
+}
+
+template <typename T>
+OutputVector translate_fft_base(const NodeContext& context,
+                                int size,
+                                bool complex_input,
+                                bool complex_output,
+                                bool inverse = false,
+                                bool is_irfft = false) {
+    num_inputs_check(context, 1, 4, true);
+    auto input = context.get_input(0);
+
+    Output<Node> dim;
+    Output<Node> s;
+    std::tie(dim, s) = get_dim_s(context, input, size, is_irfft);
+
+    auto complex_type_mark = as_type_ptr<ComplexTypeMark>(input.get_node_shared_ptr());
+    if (complex_type_mark) {
+        PYTORCH_OP_CONVERSION_CHECK(complex_input, "Operation does not support complex type tensor on input.");
+        input = complex_type_mark->get_data();
+    } else {
+        if (complex_input) {
+            auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
+            const_0 = context.mark_node(std::make_shared<v1::ConvertLike>(const_0, input));
+            input = std::make_shared<ComplexTypeMark>(input, const_0)->get_data();
+        }
     }
 
     // Handle norm parameter indicating normalization mode to use. Defaults to "backward".
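`get_dim_s` encodes the defaulting rules of the `torch.fft` family: with both `dim` and `s` omitted, the 1-D variants transform the last dimension, the 2-D variants the last two, and the `*n` variants (`size == -1`) every dimension. For the `irfft*` ops, the default output length of the last transformed axis is `2 * (m - 1)` for a one-sided complex input of length `m`, since `rfft` of a length-`2*(m-1)` real signal produces `m` bins. This behavior is easy to confirm in PyTorch (a sketch):

```python
import torch

x = torch.randn(3, 4, 5)
# Default dims: fft -> last dim, fft2 -> last two, fftn -> all of them.
assert torch.fft.fft(x).shape == (3, 4, 5)
assert torch.allclose(torch.fft.fft2(x), torch.fft.fftn(x, dim=(-2, -1)))
# irfftn default size: the last axis becomes 2 * (m - 1) for m one-sided bins.
c = torch.fft.rfftn(torch.randn(3, 4, 8))
assert c.shape == (3, 4, 5)                    # 8 // 2 + 1 = 5 bins
assert torch.fft.irfftn(c).shape == (3, 4, 8)  # 2 * (5 - 1) = 8
```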
@@ -83,123 +179,65 @@ OutputVector translate_fft_rfftn(const NodeContext& context) {
         norm = context.const_input<std::string>(3);
     }
 
-    auto rdft = context.mark_node(std::make_shared<v9::RDFT>(input, dim, s));
+    auto node = context.mark_node(std::make_shared<T>(input, dim, s));
 
     // Apply normalizations
-    auto n_int = context.mark_node(std::make_shared<v1::ReduceProd>(s, const_0));
-    auto n = context.mark_node(std::make_shared<v1::ConvertLike>(n_int, rdft));
-    Output<Node> normalized_rfftn;
-    if (norm == "forward") {
-        // Normalize by 1/n
-        normalized_rfftn = context.mark_node(std::make_shared<v1::Divide>(rdft, n));
-    } else if (norm == "backward") {
-        // No normalization
-        normalized_rfftn = rdft;
-    } else if (norm == "ortho") {
-        // Normalize by 1/sqrt(n)
-        auto sqrt_n = context.mark_node(std::make_shared<v0::Sqrt>(n));
-        normalized_rfftn = context.mark_node(std::make_shared<v1::Divide>(rdft, sqrt_n));
-    } else {
-        FRONT_END_THROW(
-            "aten::fft_rfftn: unrecognized normalization mode. Only forward, backward and ortho are supported.");
+    Output<Node> normalized = normalize(context, node, s, norm, inverse);
+    if (complex_output) {
+        normalized = std::make_shared<ComplexTypeMark>(normalized, normalized.get_element_type());
     }
+    return {normalized};
+}
+}  // namespace
 
-    return {std::make_shared<ComplexTypeMark>(normalized_rfftn, normalized_rfftn.get_element_type())};
+OutputVector translate_fft_fft(const NodeContext& context) {
+    return translate_fft_base<v7::DFT>(context, 1, true, true);
 }
 
-OutputVector translate_fft_irfftn(const NodeContext& context) {
-    // aten::fft_irfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
-    num_inputs_check(context, 1, 4, true);
-    auto input = context.get_input(0);
+OutputVector translate_fft_fft2(const NodeContext& context) {
+    return translate_fft_base<v7::DFT>(context, 2, true, true);
+}
 
-    auto complex_type_mark = as_type_ptr<ComplexTypeMark>(input.get_node_shared_ptr());
-    PYTORCH_OP_CONVERSION_CHECK(complex_type_mark, "aten::fft_irfftn operation expects complex type tensor on input.");
-    input = complex_type_mark->get_data();
+OutputVector translate_fft_fftn(const NodeContext& context) {
+    return translate_fft_base<v7::DFT>(context, -1, true, true);
+}
 
-    auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
-    auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
-    auto const_scalar_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
-    auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
-    auto const_scalar_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
-    auto const_2 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {2}));
+OutputVector translate_fft_rfft(const NodeContext& context) {
+    return translate_fft_base<v9::RDFT>(context, 1, false, true);
+}
 
-    // Input shape of complex number (excluding dimension created by concatenation of real and imag)
-    auto complex_input_shape = get_complex_shape(context, input);
-    auto input_rank = context.mark_node(std::make_shared<v3::ShapeOf>(complex_input_shape, element::i32));
-    auto input_rank_scalar = context.mark_node(std::make_shared<v0::Squeeze>(input_rank));
+OutputVector translate_fft_rfft2(const NodeContext& context) {
+    return translate_fft_base<v9::RDFT>(context, 2, false, true);
+}
 
-    Output<Node> raw_s;
-    // Inputs can be either none or List. Check whether input values should be used or should be set to default values.
-    if (!context.input_is_none(1)) {
-        // s is provided, load from input.
-        raw_s = get_input_concat_if_list(context, 1);
-        raw_s = context.mark_node(std::make_shared<v0::Convert>(raw_s, element::i32));
-    }
+OutputVector translate_fft_rfftn(const NodeContext& context) {
+    // aten::fft_rfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
+    return translate_fft_base<v9::RDFT>(context, -1, false, true);
+}
 
-    // Handle dim parameter containing vector of integers indicating dimensions to be transformed.
-    Output<Node> dim;
-    if (!context.input_is_none(2)) {
-        // Dim values is provided, load from input.
-        dim = get_input_concat_if_list(context, 2);
-        dim = context.mark_node(std::make_shared<v0::Convert>(dim, element::i32));
-    } else if (!context.input_is_none(1)) {
-        // If dim is default and s is provided, use last s_len dimensions where s_len is length of s.
-        auto s_len = context.mark_node(std::make_shared<v3::ShapeOf>(raw_s, element::i32));
-        auto range_start = context.mark_node(std::make_shared<v1::Subtract>(input_rank, s_len));
-        auto range_start_scalar = context.mark_node(std::make_shared<v0::Squeeze>(range_start));
-        dim = context.mark_node(
-            std::make_shared<v4::Range>(range_start_scalar, input_rank_scalar, const_scalar_1, element::i32));
-    } else {
-        // Dim and s are set to default, use all of dimensions.
-        dim = context.mark_node(
-            std::make_shared<v4::Range>(const_scalar_0, input_rank_scalar, const_scalar_1, element::i32));
-    }
+OutputVector translate_fft_ifft(const NodeContext& context) {
+    return translate_fft_base<v7::IDFT>(context, 1, true, true, true);
+}
 
-    // Calculate default s values. Use full available size except last element, which is set to even value in last
-    // dimension: s[-1] = 2 * (complex_input_shape[dim[-1]])
-    auto default_s_raw = context.mark_node(std::make_shared<v8::Gather>(complex_input_shape, dim, const_0));
-    auto last_s = context.mark_node(std::make_shared<v8::Gather>(default_s_raw, const_neg_1, const_0));
-    auto last_s_m_1 = context.mark_node(std::make_shared<v1::Subtract>(last_s, const_1));
-    auto s_upd = context.mark_node(std::make_shared<v1::Multiply>(last_s_m_1, const_2));
-    auto s_shape = context.mark_node(std::make_shared<v3::ShapeOf>(default_s_raw, element::i32));
-    auto last_s_idx = context.mark_node(std::make_shared<v1::Subtract>(s_shape, const_1));
-    auto default_s = context.mark_node(std::make_shared<v3::ScatterUpdate>(default_s_raw, last_s_idx, s_upd, const_0));
-
-    // Handle s parameter containing vector of intigers indicating signal sizes for dimensions.
-    Output<Node> s;
-    if (!context.input_is_none(1)) {
-        // Values for s were provided. Replace -1 values with default full size in given dimension.
-        auto full_s_cond = context.mark_node(std::make_shared<v1::Equal>(raw_s, const_neg_1));
-        s = context.mark_node(std::make_shared<v1::Select>(full_s_cond, default_s, raw_s));
-    } else {
-        // Value for s was set to default.
-        s = default_s;
-    }
+OutputVector translate_fft_ifft2(const NodeContext& context) {
+    return translate_fft_base<v7::IDFT>(context, 2, true, true, true);
+}
 
-    // Handle norm parameter indicating normalization mode to use. Defaults to "backward".
-    std::string norm = "backward";
-    if (!context.input_is_none(3)) {
-        norm = context.const_input<std::string>(3);
-    }
+OutputVector translate_fft_ifftn(const NodeContext& context) {
+    return translate_fft_base<v7::IDFT>(context, -1, true, true, true);
+}
 
-    auto irdft = context.mark_node(std::make_shared<v9::IRDFT>(input, dim, s));
+OutputVector translate_fft_irfft(const NodeContext& context) {
+    return translate_fft_base<v9::IRDFT>(context, 1, true, false, true, true);
+}
 
-    // Apply normalizations.
-    auto n_int = context.mark_node(std::make_shared<v1::ReduceProd>(s, const_0));
-    auto n = context.mark_node(std::make_shared<v1::ConvertLike>(n_int, irdft));
-    Output<Node> normalized_irfftn;
-    if (norm == "forward") {
-        normalized_irfftn = context.mark_node(std::make_shared<v1::Multiply>(irdft, n));
-    } else if (norm == "backward") {
-        normalized_irfftn = irdft;
-    } else if (norm == "ortho") {
-        auto sqrt_n = context.mark_node(std::make_shared<v0::Sqrt>(n));
-        normalized_irfftn = context.mark_node(std::make_shared<v1::Multiply>(irdft, sqrt_n));
-    } else {
-        FRONT_END_THROW(
-            "aten::fft_irfftn: unrecognized normalization mode. Only forward, backward and ortho are supported.");
-    }
-    return {normalized_irfftn};
+OutputVector translate_fft_irfft2(const NodeContext& context) {
+    return translate_fft_base<v9::IRDFT>(context, 2, true, false, true, true);
+}
+
+OutputVector translate_fft_irfftn(const NodeContext& context) {
+    // aten::fft_irfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
+    return translate_fft_base<v9::IRDFT>(context, -1, true, false, true, true);
 }
 
 } // namespace op
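Taken together, every `translate_fft_*` entry point is now a one-line dispatch into `translate_fft_base`, parameterized by the OpenVINO spectral op (`DFT`, `IDFT`, `RDFT`, `IRDFT`), the default-dim arity, the complex input/output handling, and the inverse/irfft flags. An end-to-end parity check of the new path, assuming a build with this patch and the standard `openvino` Python API (a sketch, not a test from the PR):

```python
import numpy as np
import torch
import openvino as ov

class RoundTrip(torch.nn.Module):
    def forward(self, x):
        # rfftn -> irfftn round trip; input and output are both real tensors.
        return torch.fft.irfftn(torch.fft.rfftn(x, norm="ortho"), norm="ortho")

x = torch.randn(2, 8, 8)
compiled = ov.compile_model(ov.convert_model(RoundTrip(), example_input=x), "CPU")
ov_out = compiled(x.numpy())[0]
assert np.allclose(ov_out, RoundTrip()(x).numpy(), atol=1e-5)
```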
