4
4
5
5
#include " openvino/frontend/complex_type_mark.hpp"
6
6
#include " openvino/frontend/pytorch/node_context.hpp"
7
+ #include " openvino/op/dft.hpp"
7
8
#include " openvino/op/divide.hpp"
8
9
#include " openvino/op/equal.hpp"
9
10
#include " openvino/op/gather.hpp"
11
+ #include " openvino/op/idft.hpp"
10
12
#include " openvino/op/irdft.hpp"
11
13
#include " openvino/op/multiply.hpp"
12
14
#include " openvino/op/range.hpp"
@@ -28,18 +30,56 @@ namespace op {
28
30
29
31
using namespace ov ::op;
30
32
31
- OutputVector translate_fft_rfftn (const NodeContext& context) {
32
- // aten::fft_rfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
33
- num_inputs_check (context, 1 , 4 );
34
- auto input = context.get_input (0 );
35
-
36
- auto const_neg_1 = context.mark_node (v0::Constant::create (element::i32, Shape{}, {-1 }));
33
+ namespace {
34
+ Output<Node> normalize (const NodeContext& context,
35
+ const Output<Node>& node,
36
+ const Output<Node>& s,
37
+ const std::string& norm,
38
+ bool inverse) {
39
+ if (norm == " backward" ) {
40
+ // No normalization
41
+ return node;
42
+ }
43
+ // Apply normalizations
37
44
auto const_0 = context.mark_node (v0::Constant::create (element::i32, Shape{}, {0 }));
38
- auto const_1 = context.mark_node (v0::Constant::create (element::i32, Shape{}, {1 }));
45
+ auto n_int = context.mark_node (std::make_shared<v1::ReduceProd>(s, const_0));
46
+ auto n = context.mark_node (std::make_shared<v1::ConvertLike>(n_int, node));
47
+ Output<Node> normalized;
48
+ if (norm == " forward" ) {
49
+ // Normalize by 1/n
50
+ if (inverse) {
51
+ normalized = context.mark_node (std::make_shared<v1::Multiply>(node, n));
52
+ } else {
53
+ normalized = context.mark_node (std::make_shared<v1::Divide>(node, n));
54
+ }
55
+ } else if (norm == " ortho" ) {
56
+ // Normalize by 1/sqrt(n)
57
+ auto sqrt_n = context.mark_node (std::make_shared<v0::Sqrt>(n));
58
+ if (inverse) {
59
+ normalized = context.mark_node (std::make_shared<v1::Multiply>(node, sqrt_n));
60
+ } else {
61
+ normalized = context.mark_node (std::make_shared<v1::Divide>(node, sqrt_n));
62
+ }
63
+ } else {
64
+ FRONT_END_THROW (" Unrecognized normalization mode " + norm +
65
+ " . Only forward, backward and ortho are supported." );
66
+ }
67
+ return normalized;
68
+ }
39
69
70
+ std::tuple<Output<Node>, Output<Node>> get_dim_s (const NodeContext& context,
71
+ const Output<Node>& x,
72
+ int size,
73
+ bool is_irfft) {
40
74
Output<Node> input_shape;
41
75
Output<Node> input_rank_scalar;
42
- std::tie (input_shape, input_rank_scalar) = get_shape_rank (context, input, true );
76
+ std::tie (input_shape, input_rank_scalar) = get_shape_rank (context, x, true );
77
+
78
+ auto const_neg_1 = context.mark_node (v0::Constant::create (element::i32, Shape{}, {-1 }));
79
+ auto const_neg_1_1d = context.mark_node (v0::Constant::create (element::i32, Shape{1 }, {-1 }));
80
+ auto const_0 = context.mark_node (v0::Constant::create (element::i32, Shape{}, {0 }));
81
+ auto const_1 = context.mark_node (v0::Constant::create (element::i32, Shape{}, {1 }));
82
+ auto const_2 = context.mark_node (v0::Constant::create (element::i32, Shape{}, {2 }));
43
83
44
84
Output<Node> raw_s;
45
85
// Inputs can be either none or List. Check whether input values should be used or should be set to default values.
@@ -48,33 +88,89 @@ OutputVector translate_fft_rfftn(const NodeContext& context) {
48
88
raw_s = get_input_concat_if_list (context, 1 );
49
89
raw_s = context.mark_node (std::make_shared<v0::Convert>(raw_s, element::i32));
50
90
}
51
- Output<Node> dim;
52
91
// Handle dim parameter containing vector of integers indicating dimensions to be transformed.
92
+ Output<Node> dim;
53
93
if (!context.input_is_none (2 )) {
54
94
// dim is provided, load from input.
55
95
dim = get_input_concat_if_list (context, 2 );
56
96
dim = context.mark_node (std::make_shared<v0::Convert>(dim, element::i32));
57
97
} else if (!context.input_is_none (1 )) {
58
98
// If dim is default and s is provided, use last s_len dimensions where s_len is length of s.
59
99
auto s_len = context.mark_node (std::make_shared<v3::ShapeOf>(raw_s, element::i32));
60
- auto slice_start = context.mark_node (std::make_shared<v1::Subtract>(input_rank_scalar, s_len));
61
- auto slice_start_scalar = context.mark_node (std::make_shared<v0::Squeeze>(slice_start));
62
- dim = context.mark_node (
63
- std::make_shared<v4::Range>(slice_start_scalar, input_rank_scalar, const_1, element::i32));
100
+ auto start = context.mark_node (std::make_shared<v1::Subtract>(input_rank_scalar, s_len));
101
+ auto start_scalar = context.mark_node (std::make_shared<v0::Squeeze>(start));
102
+ dim = context.mark_node (std::make_shared<v4::Range>(start_scalar, input_rank_scalar, const_1, element::i32));
64
103
} else {
65
- // Dim and s are set to default, use all of dimensions.
66
- dim = context.mark_node (std::make_shared<v4::Range>(const_0, input_rank_scalar, const_1, element::i32));
104
+ // Dim and s are set to default.
105
+ switch (size) {
106
+ case 1 :
107
+ dim = context.mark_node (v0::Constant::create (element::i32, Shape{1 }, {-1 }));
108
+ break ;
109
+ case 2 :
110
+ dim = context.mark_node (v0::Constant::create (element::i32, Shape{2 }, {-2 , -1 }));
111
+ break ;
112
+ case -1 :
113
+ dim = context.mark_node (std::make_shared<v4::Range>(const_0, input_rank_scalar, const_1, element::i32));
114
+ break ;
115
+ default :
116
+ FRONT_END_THROW (" Invalid FFT size: " + std::to_string (size));
117
+ }
118
+ }
119
+ if (dim.get_partial_shape ().rank ().is_dynamic () || dim.get_partial_shape ().rank ().get_length () == 0 ) {
120
+ dim = context.mark_node (std::make_shared<v1::Reshape>(dim, const_neg_1_1d, false ));
67
121
}
68
122
123
+ Output<Node> default_s;
124
+ if (is_irfft) {
125
+ // Calculate default s values. Use full available size except last element, which is set to even value in last
126
+ // dimension: s[-1] = 2 * (complex_input_shape[dim[-1]] - 1).
127
+ auto default_s_raw = context.mark_node (std::make_shared<v8::Gather>(input_shape, dim, const_0));
128
+ auto last_s = context.mark_node (std::make_shared<v8::Gather>(default_s_raw, const_neg_1, const_0));
129
+ auto last_s_m_1 = context.mark_node (std::make_shared<v1::Subtract>(last_s, const_1));
130
+ auto s_upd = context.mark_node (std::make_shared<v1::Multiply>(last_s_m_1, const_2));
131
+ auto s_shape = context.mark_node (std::make_shared<v3::ShapeOf>(default_s_raw, element::i32));
132
+ auto last_s_idx = context.mark_node (std::make_shared<v1::Subtract>(s_shape, const_1));
133
+ s_upd = context.mark_node (std::make_shared<v1::Reshape>(s_upd, const_neg_1_1d, false ));
134
+ default_s = context.mark_node (std::make_shared<v3::ScatterUpdate>(default_s_raw, last_s_idx, s_upd, const_0));
135
+ } else {
136
+ default_s = context.mark_node (std::make_shared<v8::Gather>(input_shape, dim, const_0));
137
+ }
69
138
Output<Node> s;
70
139
if (context.input_is_none (1 )) {
71
140
// Value for s was set to default, use full size for all dimensions.
72
- s = context. mark_node (std::make_shared<v8::Gather>(input_shape, dim, const_0)) ;
141
+ s = default_s ;
73
142
} else {
74
143
// Values for s were provided. Replace -1 values with default full size in given dimension.
75
144
auto full_s_cond = context.mark_node (std::make_shared<v1::Equal>(raw_s, const_neg_1));
76
- auto full_s_values = context.mark_node (std::make_shared<v8::Gather>(input_shape, dim, const_0));
77
- s = context.mark_node (std::make_shared<v1::Select>(full_s_cond, full_s_values, raw_s));
145
+ s = context.mark_node (std::make_shared<v1::Select>(full_s_cond, default_s, raw_s));
146
+ }
147
+ return {dim, s};
148
+ }
149
+
150
+ template <typename T>
151
+ OutputVector translate_fft_base (const NodeContext& context,
152
+ int size,
153
+ bool complex_input,
154
+ bool complex_output,
155
+ bool inverse = false ,
156
+ bool is_irfft = false ) {
157
+ num_inputs_check (context, 1 , 4 , true );
158
+ auto input = context.get_input (0 );
159
+
160
+ Output<Node> dim;
161
+ Output<Node> s;
162
+ std::tie (dim, s) = get_dim_s (context, input, size, is_irfft);
163
+
164
+ auto complex_type_mark = as_type_ptr<ComplexTypeMark>(input.get_node_shared_ptr ());
165
+ if (complex_type_mark) {
166
+ PYTORCH_OP_CONVERSION_CHECK (complex_input, " Operation does not support complex type tensor on input." );
167
+ input = complex_type_mark->get_data ();
168
+ } else {
169
+ if (complex_input) {
170
+ auto const_0 = context.mark_node (v0::Constant::create (element::i32, Shape{}, {0 }));
171
+ const_0 = context.mark_node (std::make_shared<v1::ConvertLike>(const_0, input));
172
+ input = std::make_shared<ComplexTypeMark>(input, const_0)->get_data ();
173
+ }
78
174
}
79
175
80
176
// Handle norm parameter indicating normalization mode to use. Defaults to "backward".
@@ -83,123 +179,65 @@ OutputVector translate_fft_rfftn(const NodeContext& context) {
83
179
norm = context.const_input <std::string>(3 );
84
180
}
85
181
86
- auto rdft = context.mark_node (std::make_shared<v9::RDFT >(input, dim, s));
182
+ auto node = context.mark_node (std::make_shared<T >(input, dim, s));
87
183
88
184
// Apply normalizations
89
- auto n_int = context.mark_node (std::make_shared<v1::ReduceProd>(s, const_0));
90
- auto n = context.mark_node (std::make_shared<v1::ConvertLike>(n_int, rdft));
91
- Output<Node> normalized_rfftn;
92
- if (norm == " forward" ) {
93
- // Normalize by 1/n
94
- normalized_rfftn = context.mark_node (std::make_shared<v1::Divide>(rdft, n));
95
- } else if (norm == " backward" ) {
96
- // No normalization
97
- normalized_rfftn = rdft;
98
- } else if (norm == " ortho" ) {
99
- // Normalize by 1/sqrt(n)
100
- auto sqrt_n = context.mark_node (std::make_shared<v0::Sqrt>(n));
101
- normalized_rfftn = context.mark_node (std::make_shared<v1::Divide>(rdft, sqrt_n));
102
- } else {
103
- FRONT_END_THROW (
104
- " aten::fft_rfftn: unrecognized normalization mode. Only forward, backward and ortho are supported." );
185
+ Output<Node> normalized = normalize (context, node, s, norm, inverse);
186
+ if (complex_output) {
187
+ normalized = std::make_shared<ComplexTypeMark>(normalized, normalized.get_element_type ());
105
188
}
189
+ return {normalized};
190
+ }
191
+ } // namespace
106
192
107
- return {std::make_shared<ComplexTypeMark>(normalized_rfftn, normalized_rfftn.get_element_type ())};
193
+ OutputVector translate_fft_fft (const NodeContext& context) {
194
+ return translate_fft_base<v7::DFT>(context, 1 , true , true );
108
195
}
109
196
110
- OutputVector translate_fft_irfftn (const NodeContext& context) {
111
- // aten::fft_irfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
112
- num_inputs_check (context, 1 , 4 , true );
113
- auto input = context.get_input (0 );
197
+ OutputVector translate_fft_fft2 (const NodeContext& context) {
198
+ return translate_fft_base<v7::DFT>(context, 2 , true , true );
199
+ }
114
200
115
- auto complex_type_mark = as_type_ptr<ComplexTypeMark>(input. get_node_shared_ptr ());
116
- PYTORCH_OP_CONVERSION_CHECK (complex_type_mark, " aten::fft_irfftn operation expects complex type tensor on input. " );
117
- input = complex_type_mark-> get_data ();
201
+ OutputVector translate_fft_fftn ( const NodeContext& context) {
202
+ return translate_fft_base<v7::DFT>(context, - 1 , true , true );
203
+ }
118
204
119
- auto const_neg_1 = context.mark_node (v0::Constant::create (element::i32, Shape{1 }, {-1 }));
120
- auto const_0 = context.mark_node (v0::Constant::create (element::i32, Shape{1 }, {0 }));
121
- auto const_scalar_0 = context.mark_node (v0::Constant::create (element::i32, Shape{}, {0 }));
122
- auto const_1 = context.mark_node (v0::Constant::create (element::i32, Shape{1 }, {1 }));
123
- auto const_scalar_1 = context.mark_node (v0::Constant::create (element::i32, Shape{}, {1 }));
124
- auto const_2 = context.mark_node (v0::Constant::create (element::i32, Shape{1 }, {2 }));
205
+ OutputVector translate_fft_rfft (const NodeContext& context) {
206
+ return translate_fft_base<v9::RDFT>(context, 1 , false , true );
207
+ }
125
208
126
- // Input shape of complex number (excluding dimension created by concatenation of real and imag)
127
- auto complex_input_shape = get_complex_shape (context, input);
128
- auto input_rank = context.mark_node (std::make_shared<v3::ShapeOf>(complex_input_shape, element::i32));
129
- auto input_rank_scalar = context.mark_node (std::make_shared<v0::Squeeze>(input_rank));
209
+ OutputVector translate_fft_rfft2 (const NodeContext& context) {
210
+ return translate_fft_base<v9::RDFT>(context, 2 , false , true );
211
+ }
130
212
131
- Output<Node> raw_s;
132
- // Inputs can be either none or List. Check whether input values should be used or should be set to default values.
133
- if (!context.input_is_none (1 )) {
134
- // s is provided, load from input.
135
- raw_s = get_input_concat_if_list (context, 1 );
136
- raw_s = context.mark_node (std::make_shared<v0::Convert>(raw_s, element::i32));
137
- }
213
+ OutputVector translate_fft_rfftn (const NodeContext& context) {
214
+ // aten::fft_rfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
215
+ return translate_fft_base<v9::RDFT>(context, -1 , false , true );
216
+ }
138
217
139
- // Handle dim parameter containing vector of integers indicating dimensions to be transformed.
140
- Output<Node> dim;
141
- if (!context.input_is_none (2 )) {
142
- // Dim values is provided, load from input.
143
- dim = get_input_concat_if_list (context, 2 );
144
- dim = context.mark_node (std::make_shared<v0::Convert>(dim, element::i32));
145
- } else if (!context.input_is_none (1 )) {
146
- // If dim is default and s is provided, use last s_len dimensions where s_len is length of s.
147
- auto s_len = context.mark_node (std::make_shared<v3::ShapeOf>(raw_s, element::i32));
148
- auto range_start = context.mark_node (std::make_shared<v1::Subtract>(input_rank, s_len));
149
- auto range_start_scalar = context.mark_node (std::make_shared<v0::Squeeze>(range_start));
150
- dim = context.mark_node (
151
- std::make_shared<v4::Range>(range_start_scalar, input_rank_scalar, const_scalar_1, element::i32));
152
- } else {
153
- // Dim and s are set to default, use all of dimensions.
154
- dim = context.mark_node (
155
- std::make_shared<v4::Range>(const_scalar_0, input_rank_scalar, const_scalar_1, element::i32));
156
- }
218
+ OutputVector translate_fft_ifft (const NodeContext& context) {
219
+ return translate_fft_base<v7::IDFT>(context, 1 , true , true , true );
220
+ }
157
221
158
- // Calculate default s values. Use full available size except last element, which is set to even value in last
159
- // dimension: s[-1] = 2 * (complex_input_shape[dim[-1]])
160
- auto default_s_raw = context.mark_node (std::make_shared<v8::Gather>(complex_input_shape, dim, const_0));
161
- auto last_s = context.mark_node (std::make_shared<v8::Gather>(default_s_raw, const_neg_1, const_0));
162
- auto last_s_m_1 = context.mark_node (std::make_shared<v1::Subtract>(last_s, const_1));
163
- auto s_upd = context.mark_node (std::make_shared<v1::Multiply>(last_s_m_1, const_2));
164
- auto s_shape = context.mark_node (std::make_shared<v3::ShapeOf>(default_s_raw, element::i32));
165
- auto last_s_idx = context.mark_node (std::make_shared<v1::Subtract>(s_shape, const_1));
166
- auto default_s = context.mark_node (std::make_shared<v3::ScatterUpdate>(default_s_raw, last_s_idx, s_upd, const_0));
167
-
168
- // Handle s parameter containing vector of intigers indicating signal sizes for dimensions.
169
- Output<Node> s;
170
- if (!context.input_is_none (1 )) {
171
- // Values for s were provided. Replace -1 values with default full size in given dimension.
172
- auto full_s_cond = context.mark_node (std::make_shared<v1::Equal>(raw_s, const_neg_1));
173
- s = context.mark_node (std::make_shared<v1::Select>(full_s_cond, default_s, raw_s));
174
- } else {
175
- // Value for s was set to default.
176
- s = default_s;
177
- }
222
+ OutputVector translate_fft_ifft2 (const NodeContext& context) {
223
+ return translate_fft_base<v7::IDFT>(context, 2 , true , true , true );
224
+ }
178
225
179
- // Handle norm parameter indicating normalization mode to use. Defaults to "backward".
180
- std::string norm = " backward" ;
181
- if (!context.input_is_none (3 )) {
182
- norm = context.const_input <std::string>(3 );
183
- }
226
+ OutputVector translate_fft_ifftn (const NodeContext& context) {
227
+ return translate_fft_base<v7::IDFT>(context, -1 , true , true , true );
228
+ }
184
229
185
- auto irdft = context.mark_node (std::make_shared<v9::IRDFT>(input, dim, s));
230
+ OutputVector translate_fft_irfft (const NodeContext& context) {
231
+ return translate_fft_base<v9::IRDFT>(context, 1 , true , false , true , true );
232
+ }
186
233
187
- // Apply normalizations.
188
- auto n_int = context.mark_node (std::make_shared<v1::ReduceProd>(s, const_0));
189
- auto n = context.mark_node (std::make_shared<v1::ConvertLike>(n_int, irdft));
190
- Output<Node> normalized_irfftn;
191
- if (norm == " forward" ) {
192
- normalized_irfftn = context.mark_node (std::make_shared<v1::Multiply>(irdft, n));
193
- } else if (norm == " backward" ) {
194
- normalized_irfftn = irdft;
195
- } else if (norm == " ortho" ) {
196
- auto sqrt_n = context.mark_node (std::make_shared<v0::Sqrt>(n));
197
- normalized_irfftn = context.mark_node (std::make_shared<v1::Multiply>(irdft, sqrt_n));
198
- } else {
199
- FRONT_END_THROW (
200
- " aten::fft_irfftn: unrecognized normalization mode. Only forward, backward and ortho are supported." );
201
- }
202
- return {normalized_irfftn};
234
+ OutputVector translate_fft_irfft2 (const NodeContext& context) {
235
+ return translate_fft_base<v9::IRDFT>(context, 2 , true , false , true , true );
236
+ }
237
+
238
+ OutputVector translate_fft_irfftn (const NodeContext& context) {
239
+ // aten::fft_irfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
240
+ return translate_fft_base<v9::IRDFT>(context, -1 , true , false , true , true );
203
241
}
204
242
205
243
} // namespace op
0 commit comments