[PT FE] Support fft operations #29507

Merged: 9 commits, Mar 19, 2025
Changes from 2 commits
@@ -75,8 +75,8 @@ def __init__(
help_msg = ("Tracing sometimes provide better results, "
"please provide valid 'example_input' argument.\n")
raise RuntimeError(
f"Couldn't get TorchScript module by {msg}.\n{help_msg} "
"You can also provide TorchScript module that you obtained"
f"Couldn't get TorchScript module by {msg}.\nException:\n{e}\n"
f"{help_msg} You can also provide TorchScript module that you obtained"
" yourself, please refer to PyTorch documentation: "
"https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html."
) from e
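
The 'example_input' referred to in this message is the argument of openvino.convert_model that lets the frontend fall back to tracing when scripting fails. A minimal, illustrative call (the module and shape below are made up):

import torch
import openvino as ov

model = torch.nn.Linear(8, 4).eval()   # placeholder module
example = torch.randn(1, 8)            # drives tracing inside the frontend
ov_model = ov.convert_model(model, example_input=example)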
245 changes: 138 additions & 107 deletions src/frontends/pytorch/src/op/fft.cpp
@@ -4,9 +4,11 @@

#include "openvino/frontend/complex_type_mark.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/dft.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/equal.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/idft.hpp"
#include "openvino/op/irdft.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/range.hpp"
@@ -28,18 +30,50 @@ namespace op {

using namespace ov::op;

OutputVector translate_fft_rfftn(const NodeContext& context) {
// aten::fft_rfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
num_inputs_check(context, 1, 4);
auto input = context.get_input(0);

auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
namespace {
Output<Node> normalize(const NodeContext& context,
const Output<Node>& node,
const Output<Node>& s,
std::string norm,
bool inverse) {
if (norm == "backward") {
// No normalization
return node;
}
// Apply normalizations
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
auto n_int = context.mark_node(std::make_shared<v1::ReduceProd>(s, const_0));
auto n = context.mark_node(std::make_shared<v1::ConvertLike>(n_int, node));
Output<Node> normalized;
if (norm == "forward") {
// Normalize by 1/n
if (inverse) {
normalized = context.mark_node(std::make_shared<v1::Multiply>(node, n));
} else {
normalized = context.mark_node(std::make_shared<v1::Divide>(node, n));
}
} else if (norm == "ortho") {
// Normalize by 1/sqrt(n)
auto sqrt_n = context.mark_node(std::make_shared<v0::Sqrt>(n));
if (inverse) {
normalized = context.mark_node(std::make_shared<v1::Multiply>(node, sqrt_n));
} else {
normalized = context.mark_node(std::make_shared<v1::Divide>(node, sqrt_n));
}
} else {
FRONT_END_THROW("Unrecognized normalization mode. Only forward, backward and ortho are supported.");
}
return normalized;
}
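
The scaling above targets the semantics of torch.fft's norm argument; the inverse branch multiplies because IDFT/IRDFT already divide by n internally. Roughly, in Python terms (n is the product of the transformed signal sizes):

import numpy as np

def apply_norm(x, n, norm, inverse):
    # "backward": forward transform unscaled; the 1/n on the inverse is done by IDFT/IRDFT itself
    if norm == "backward":
        return x
    # "forward": forward scaled by 1/n, inverse compensated back by n
    if norm == "forward":
        return x * n if inverse else x / n
    # "ortho": both directions end up scaled by 1/sqrt(n)
    if norm == "ortho":
        return x * np.sqrt(n) if inverse else x / np.sqrt(n)
    raise ValueError("norm must be 'forward', 'backward' or 'ortho'")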

std::tuple<Output<Node>, Output<Node>> get_dim_s(const NodeContext& context, const Output<Node>& x, bool is_irfft) {
Output<Node> input_shape;
Output<Node> input_rank_scalar;
std::tie(input_shape, input_rank_scalar) = get_shape_rank(context, input, true);
std::tie(input_shape, input_rank_scalar) = get_shape_rank(context, x, true);

auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));

Output<Node> raw_s;
// Inputs can be either none or List. Check whether input values should be used or should be set to default values.
Expand All @@ -48,33 +82,70 @@ OutputVector translate_fft_rfftn(const NodeContext& context) {
raw_s = get_input_concat_if_list(context, 1);
raw_s = context.mark_node(std::make_shared<v0::Convert>(raw_s, element::i32));
}
Output<Node> dim;
// Handle dim parameter containing vector of integers indicating dimensions to be transformed.
Output<Node> dim;
if (!context.input_is_none(2)) {
// dim is provided, load from input.
dim = get_input_concat_if_list(context, 2);
dim = context.mark_node(std::make_shared<v0::Convert>(dim, element::i32));
} else if (!context.input_is_none(1)) {
// If dim is default and s is provided, use last s_len dimensions where s_len is length of s.
auto s_len = context.mark_node(std::make_shared<v3::ShapeOf>(raw_s, element::i32));
auto slice_start = context.mark_node(std::make_shared<v1::Subtract>(input_rank_scalar, s_len));
auto slice_start_scalar = context.mark_node(std::make_shared<v0::Squeeze>(slice_start));
dim = context.mark_node(
std::make_shared<v4::Range>(slice_start_scalar, input_rank_scalar, const_1, element::i32));
auto start = context.mark_node(std::make_shared<v1::Subtract>(input_rank_scalar, s_len));
auto start_scalar = context.mark_node(std::make_shared<v0::Squeeze>(start));
dim = context.mark_node(std::make_shared<v4::Range>(start_scalar, input_rank_scalar, const_1, element::i32));
} else {
// Dim and s are set to default, use all dimensions.
dim = context.mark_node(std::make_shared<v4::Range>(const_0, input_rank_scalar, const_1, element::i32));
}
if (dim.get_partial_shape().rank().is_dynamic() || dim.get_partial_shape().rank().get_length() == 0) {
auto const_neg_1_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
dim = context.mark_node(std::make_shared<v1::Reshape>(dim, const_neg_1_1d, false));
}

Output<Node> default_s;
if (is_irfft) {
auto const_2_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {2}));
// Calculate default s values. Use the full available size for each dimension except the last, which is set to the
// even value s[-1] = 2 * (complex_input_shape[dim[-1]] - 1).
auto default_s_raw = context.mark_node(std::make_shared<v8::Gather>(input_shape, dim, const_0));
auto last_s = context.mark_node(std::make_shared<v8::Gather>(default_s_raw, const_neg_1, const_0));
auto last_s_m_1 = context.mark_node(std::make_shared<v1::Subtract>(last_s, const_1));
auto s_upd = context.mark_node(std::make_shared<v1::Multiply>(last_s_m_1, const_2_1d));
auto s_shape = context.mark_node(std::make_shared<v3::ShapeOf>(default_s_raw, element::i32));
auto last_s_idx = context.mark_node(std::make_shared<v1::Subtract>(s_shape, const_1));
default_s = context.mark_node(std::make_shared<v3::ScatterUpdate>(default_s_raw, last_s_idx, s_upd, const_0));
} else {
default_s = context.mark_node(std::make_shared<v8::Gather>(input_shape, dim, const_0));
}
Output<Node> s;
if (context.input_is_none(1)) {
// Value for s was set to default, use full size for all dimensions.
s = context.mark_node(std::make_shared<v8::Gather>(input_shape, dim, const_0));
s = default_s;
} else {
// Values for s were provided. Replace -1 values with default full size in given dimension.
auto full_s_cond = context.mark_node(std::make_shared<v1::Equal>(raw_s, const_neg_1));
auto full_s_values = context.mark_node(std::make_shared<v8::Gather>(input_shape, dim, const_0));
s = context.mark_node(std::make_shared<v1::Select>(full_s_cond, full_s_values, raw_s));
s = context.mark_node(std::make_shared<v1::Select>(full_s_cond, default_s, raw_s));
}
return {dim, s};
}
} // namespace
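
The defaulting rules above mirror torch.fft: if only s is given, the last len(s) dimensions are transformed; if neither is given, all dimensions are; and in the irfft case the default size of the last transformed axis is rebuilt from the one-sided spectrum. A small NumPy check of that last rule (sizes are illustrative):

import numpy as np

x = np.random.rand(4, 6)          # real signal, last axis of length 6
spec = np.fft.rfftn(x)            # one-sided spectrum, last axis: 6 // 2 + 1 = 4
# default s[-1] = 2 * (spec.shape[-1] - 1) = 6, the same rule as s_upd above
restored = np.fft.irfftn(spec)
assert restored.shape == x.shape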

OutputVector translate_fft_fftn(const NodeContext& context) {
num_inputs_check(context, 1, 4, true);
auto input = context.get_input(0);

Output<Node> dim;
Output<Node> s;
std::tie(dim, s) = get_dim_s(context, input, false);

auto complex_type_mark = as_type_ptr<ComplexTypeMark>(input.get_node_shared_ptr());
if (complex_type_mark) {
input = complex_type_mark->get_data();
} else {
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
const_0 = context.mark_node(std::make_shared<v1::ConvertLike>(const_0, input));
input = std::make_shared<ComplexTypeMark>(input, const_0)->get_data();
}

// Handle norm parameter indicating normalization mode to use. Defaults to "backward".
@@ -83,123 +154,83 @@ OutputVector translate_fft_rfftn(const NodeContext& context) {
norm = context.const_input<std::string>(3);
}

auto rdft = context.mark_node(std::make_shared<v9::RDFT>(input, dim, s));
auto node = context.mark_node(std::make_shared<v7::DFT>(input, dim, s));

// Apply normalizations
auto n_int = context.mark_node(std::make_shared<v1::ReduceProd>(s, const_0));
auto n = context.mark_node(std::make_shared<v1::ConvertLike>(n_int, rdft));
Output<Node> normalized_rfftn;
if (norm == "forward") {
// Normalize by 1/n
normalized_rfftn = context.mark_node(std::make_shared<v1::Divide>(rdft, n));
} else if (norm == "backward") {
// No normalization
normalized_rfftn = rdft;
} else if (norm == "ortho") {
// Normalize by 1/sqrt(n)
auto sqrt_n = context.mark_node(std::make_shared<v0::Sqrt>(n));
normalized_rfftn = context.mark_node(std::make_shared<v1::Divide>(rdft, sqrt_n));
} else {
FRONT_END_THROW(
"aten::fft_rfftn: unrecognized normalization mode. Only forward, backward and ortho are supported.");
Output<Node> normalized = normalize(context, node, s, norm, false);
return {std::make_shared<ComplexTypeMark>(normalized, normalized.get_element_type())};
}
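
A real input here is paired with a zero imaginary part before the DFT, which matches how torch.fft.fftn promotes real tensors to complex output. A quick torch sketch of that equivalence:

import torch

x = torch.randn(3, 5)                                          # real input
y = torch.fft.fftn(x)                                          # complex output
y_ref = torch.fft.fftn(torch.complex(x, torch.zeros_like(x)))  # explicit zero-imag promotion
assert torch.allclose(y, y_ref)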

OutputVector translate_fft_rfftn(const NodeContext& context) {
// aten::fft_rfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
num_inputs_check(context, 1, 4);
auto input = context.get_input(0);

Output<Node> dim;
Output<Node> s;
std::tie(dim, s) = get_dim_s(context, input, false);

// Handle norm parameter indicating normalization mode to use. Defaults to "backward".
std::string norm = "backward";
if (!context.input_is_none(3)) {
norm = context.const_input<std::string>(3);
}

return {std::make_shared<ComplexTypeMark>(normalized_rfftn, normalized_rfftn.get_element_type())};
auto node = context.mark_node(std::make_shared<v9::RDFT>(input, dim, s));

// Apply normalizations
Output<Node> normalized = normalize(context, node, s, norm, false);
return {std::make_shared<ComplexTypeMark>(normalized, normalized.get_element_type())};
}
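
For rfftn, the s values resolved above crop or pad each transformed axis, and only the one-sided spectrum is kept along the last transformed axis. For example, in torch:

import torch

x = torch.randn(4, 10)
y = torch.fft.rfftn(x, s=[4, 6])   # last axis cropped to 6 before the transform
assert y.shape == (4, 6 // 2 + 1)  # one-sided output: 6 // 2 + 1 == 4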

OutputVector translate_fft_irfftn(const NodeContext& context) {
OutputVector translate_fft_ifftn(const NodeContext& context) {
// aten::fft_ifftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
num_inputs_check(context, 1, 4, true);
auto input = context.get_input(0);

Output<Node> dim;
Output<Node> s;
std::tie(dim, s) = get_dim_s(context, input, false);

auto complex_type_mark = as_type_ptr<ComplexTypeMark>(input.get_node_shared_ptr());
PYTORCH_OP_CONVERSION_CHECK(complex_type_mark, "aten::fft_irfftn operation expects complex type tensor on input.");
PYTORCH_OP_CONVERSION_CHECK(complex_type_mark, "Operation expects complex type tensor on input.");
input = complex_type_mark->get_data();

auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
auto const_scalar_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
auto const_scalar_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
auto const_2 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {2}));
// Handle norm parameter indicating normalization mode to use. Defaults to "backward".
std::string norm = "backward";
if (!context.input_is_none(3)) {
norm = context.const_input<std::string>(3);
}

// Input shape of complex number (excluding dimension created by concatenation of real and imag)
auto complex_input_shape = get_complex_shape(context, input);
auto input_rank = context.mark_node(std::make_shared<v3::ShapeOf>(complex_input_shape, element::i32));
auto input_rank_scalar = context.mark_node(std::make_shared<v0::Squeeze>(input_rank));
auto node = context.mark_node(std::make_shared<v7::IDFT>(input, dim, s));

Output<Node> raw_s;
// Inputs can be either none or List. Check whether input values should be used or should be set to default values.
if (!context.input_is_none(1)) {
// s is provided, load from input.
raw_s = get_input_concat_if_list(context, 1);
raw_s = context.mark_node(std::make_shared<v0::Convert>(raw_s, element::i32));
}
Output<Node> normalized = normalize(context, node, s, norm, true);
return {std::make_shared<ComplexTypeMark>(normalized, normalized.get_element_type())};
}
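
The conversion check above restricts aten::fft_ifftn handling to inputs already marked as complex; in torch, the typical round trip on such data looks like this:

import torch

x = torch.randn(2, 8)
spec = torch.fft.fftn(x)        # complex spectrum
back = torch.fft.ifftn(spec)    # inverse, scaled by 1/n ("backward" norm)
assert torch.allclose(back.real, x, atol=1e-6)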

// Handle dim parameter containing vector of integers indicating dimensions to be transformed.
Output<Node> dim;
if (!context.input_is_none(2)) {
// Dim values is provided, load from input.
dim = get_input_concat_if_list(context, 2);
dim = context.mark_node(std::make_shared<v0::Convert>(dim, element::i32));
} else if (!context.input_is_none(1)) {
// If dim is default and s is provided, use last s_len dimensions where s_len is length of s.
auto s_len = context.mark_node(std::make_shared<v3::ShapeOf>(raw_s, element::i32));
auto range_start = context.mark_node(std::make_shared<v1::Subtract>(input_rank, s_len));
auto range_start_scalar = context.mark_node(std::make_shared<v0::Squeeze>(range_start));
dim = context.mark_node(
std::make_shared<v4::Range>(range_start_scalar, input_rank_scalar, const_scalar_1, element::i32));
} else {
// Dim and s are set to default, use all dimensions.
dim = context.mark_node(
std::make_shared<v4::Range>(const_scalar_0, input_rank_scalar, const_scalar_1, element::i32));
}
OutputVector translate_fft_irfftn(const NodeContext& context) {
// aten::fft_irfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
num_inputs_check(context, 1, 4, true);
auto input = context.get_input(0);

// Calculate default s values. Use full available size except last element, which is set to the even value
// s[-1] = 2 * (complex_input_shape[dim[-1]] - 1)
auto default_s_raw = context.mark_node(std::make_shared<v8::Gather>(complex_input_shape, dim, const_0));
auto last_s = context.mark_node(std::make_shared<v8::Gather>(default_s_raw, const_neg_1, const_0));
auto last_s_m_1 = context.mark_node(std::make_shared<v1::Subtract>(last_s, const_1));
auto s_upd = context.mark_node(std::make_shared<v1::Multiply>(last_s_m_1, const_2));
auto s_shape = context.mark_node(std::make_shared<v3::ShapeOf>(default_s_raw, element::i32));
auto last_s_idx = context.mark_node(std::make_shared<v1::Subtract>(s_shape, const_1));
auto default_s = context.mark_node(std::make_shared<v3::ScatterUpdate>(default_s_raw, last_s_idx, s_upd, const_0));

// Handle s parameter containing vector of integers indicating signal sizes for dimensions.
Output<Node> dim;
Output<Node> s;
if (!context.input_is_none(1)) {
// Values for s were provided. Replace -1 values with default full size in given dimension.
auto full_s_cond = context.mark_node(std::make_shared<v1::Equal>(raw_s, const_neg_1));
s = context.mark_node(std::make_shared<v1::Select>(full_s_cond, default_s, raw_s));
} else {
// Value for s was set to default.
s = default_s;
}
std::tie(dim, s) = get_dim_s(context, input, true);

auto complex_type_mark = as_type_ptr<ComplexTypeMark>(input.get_node_shared_ptr());
PYTORCH_OP_CONVERSION_CHECK(complex_type_mark, "Operation expects complex type tensor on input.");
input = complex_type_mark->get_data();

// Handle norm parameter indicating normalization mode to use. Defaults to "backward".
std::string norm = "backward";
if (!context.input_is_none(3)) {
norm = context.const_input<std::string>(3);
}

auto irdft = context.mark_node(std::make_shared<v9::IRDFT>(input, dim, s));
auto node = context.mark_node(std::make_shared<v9::IRDFT>(input, dim, s));

// Apply normalizations.
auto n_int = context.mark_node(std::make_shared<v1::ReduceProd>(s, const_0));
auto n = context.mark_node(std::make_shared<v1::ConvertLike>(n_int, irdft));
Output<Node> normalized_irfftn;
if (norm == "forward") {
normalized_irfftn = context.mark_node(std::make_shared<v1::Multiply>(irdft, n));
} else if (norm == "backward") {
normalized_irfftn = irdft;
} else if (norm == "ortho") {
auto sqrt_n = context.mark_node(std::make_shared<v0::Sqrt>(n));
normalized_irfftn = context.mark_node(std::make_shared<v1::Multiply>(irdft, sqrt_n));
} else {
FRONT_END_THROW(
"aten::fft_irfftn: unrecognized normalization mode. Only forward, backward and ortho are supported.");
}
return {normalized_irfftn};
Output<Node> normalized = normalize(context, node, s, norm, true);
return {normalized};
}
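
Because the one-sided spectrum does not record whether the original last axis was even or odd, the default s reconstructed above always yields an even length; passing s explicitly recovers odd sizes, e.g. in torch:

import torch

x = torch.randn(3, 7)                           # odd last axis
spec = torch.fft.rfftn(x)                       # last axis: 7 // 2 + 1 = 4
back_default = torch.fft.irfftn(spec)           # default s[-1] = 2 * (4 - 1) = 6
back_exact = torch.fft.irfftn(spec, s=x.shape)  # recovers the original length 7
assert back_default.shape == (3, 6) and back_exact.shape == (3, 7)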

} // namespace op
13 changes: 13 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
@@ -92,6 +92,8 @@ OP_CONVERTER(translate_expm1);
OP_CONVERTER(translate_eye);
OP_CONVERTER(translate_fake_quantize_per_channel_affine);
OP_CONVERTER(translate_fake_quantize_per_tensor_affine);
OP_CONVERTER(translate_fft_fftn);
OP_CONVERTER(translate_fft_ifftn);
OP_CONVERTER(translate_fft_irfftn);
OP_CONVERTER(translate_fft_rfftn);
OP_CONVERTER(translate_fill);
@@ -484,7 +486,17 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::fake_quantize_per_channel_affine", op::translate_fake_quantize_per_channel_affine},
{"aten::fake_quantize_per_tensor_affine", op::translate_fake_quantize_per_tensor_affine},
{"aten::feature_dropout", op::skip_node},
{"aten::fft_fft", op::translate_fft_fftn},
{"aten::fft_fft2", op::translate_fft_fftn},
{"aten::fft_fftn", op::translate_fft_fftn},
{"aten::fft_ifft", op::translate_fft_ifftn},
{"aten::fft_ifft2", op::translate_fft_ifftn},
{"aten::fft_ifftn", op::translate_fft_ifftn},
{"aten::fft_irfft", op::translate_fft_irfftn},
{"aten::fft_irfft2", op::translate_fft_irfftn},
{"aten::fft_irfftn", op::translate_fft_irfftn},
{"aten::fft_rfft", op::translate_fft_rfftn},
{"aten::fft_rfft2", op::translate_fft_rfftn},
{"aten::fft_rfftn", op::translate_fft_rfftn},
{"aten::fill", op::translate_fill},
{"aten::fill_diagonal", op::translate_fill_diagonal},
@@ -964,6 +976,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.relu.default", op::translate_1to1_match_1_inputs<opset10::Relu>},
{"aten.relu_.default", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Relu>>},
{"aten.repeat.default", op::translate_repeat_fx},
{"aten.repeat_interleave.Tensor", op::translate_repeat_interleave},
{"aten.rms_norm.default", op::translate_rms_norm},
{"aten.roll.default", op::translate_roll},
{"aten.rsqrt.default", op::translate_rsqrt},