Skip to content

Commit 8d18cd6

Browse files
NingLi670rkazants
andauthored
[TF FE]: Support complex tensors for Equal operation (#29339)
### Details: - Support complex tensors for Equal operation - Add corresponding layer test ### Tickets: - #22947 --------- Co-authored-by: Roman Kazantsev <[email protected]>
1 parent a9faa94 commit 8d18cd6

File tree

6 files changed

+121
-1
lines changed

6 files changed

+121
-1
lines changed

Diff for: src/frontends/common_translators/include/common_translators.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ COMMON_OP_CONVERTER(translate_atan2);
2020
COMMON_OP_CONVERTER(translate_angle);
2121
COMMON_OP_CONVERTER(translate_erfc);
2222

23+
COMMON_OP_CONVERTER(translate_equal);
24+
2325
OutputVector translate_atan2_util(const NodeContext& context, const Output<Node>& lhs, const Output<Node>& rhs);
2426
OutputVector translate_erfc_util(const NodeContext& context, const Output<Node>& data);
2527

Diff for: src/frontends/common_translators/src/op/equal.cpp

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Copyright (C) 2018-2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "openvino/op/equal.hpp"
6+
7+
#include "common_translators.hpp"
8+
#include "openvino/frontend/complex_type_mark.hpp"
9+
#include "openvino/frontend/exception.hpp"
10+
#include "openvino/op/constant.hpp"
11+
#include "openvino/op/reduce_logical_and.hpp"
12+
#include "utils.hpp"
13+
14+
using namespace std;
15+
using namespace ov::op;
16+
17+
namespace ov {
18+
namespace frontend {
19+
namespace common_translators {
20+
21+
OutputVector translate_equal(const NodeContext& node) {
22+
num_inputs_check(node, 2, 2, true);
23+
auto lhs = node.get_input(0);
24+
auto rhs = node.get_input(1);
25+
26+
auto lhs_complex = as_type_ptr<ComplexTypeMark>(lhs.get_node_shared_ptr());
27+
auto rhs_complex = as_type_ptr<ComplexTypeMark>(rhs.get_node_shared_ptr());
28+
29+
auto op_type = node.get_op_type();
30+
FRONT_END_OP_CONVERSION_CHECK(!(lhs_complex && !rhs_complex) && !(!lhs_complex && rhs_complex),
31+
op_type + " operation expects both operands to be of the same type.");
32+
33+
// both operands are of complex type
34+
if (lhs_complex && rhs_complex) {
35+
auto lhs_data = lhs_complex->get_data();
36+
auto rhs_data = rhs_complex->get_data();
37+
auto equal = make_shared<v1::Equal>(lhs_data, rhs_data);
38+
39+
// reduce along the last dimension using ReduceAnd
40+
auto reduce_axes = make_shared<v0::Constant>(element::i32, Shape{1}, std::vector<int32_t>{-1});
41+
auto equal_reduced = make_shared<v1::ReduceLogicalAnd>(equal, reduce_axes, false);
42+
43+
return {equal_reduced};
44+
}
45+
46+
// both operands are real
47+
auto result = make_shared<v1::Equal>(lhs, rhs);
48+
return {result};
49+
};
50+
51+
} // namespace common_translators
52+
} // namespace frontend
53+
} // namespace ov

Diff for: src/frontends/tensorflow/src/op_table.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
177177
{"RightShift", CreatorFunction(translate_binary_op<v15::BitwiseRightShift>)},
178178
{"LeftShift", CreatorFunction(translate_binary_op<v15::BitwiseLeftShift>)},
179179
{"Div", CreatorFunction(translate_div_op)},
180-
{"Equal", CreatorFunction(translate_binary_op<v1::Equal>)},
180+
{"Equal", CreatorFunction(translate_equal_op)},
181181
{"FloorMod", CreatorFunction(translate_binary_op<v1::FloorMod>)},
182182
{"Greater", CreatorFunction(translate_binary_op<v1::Greater>)},
183183
{"GreaterEqual", CreatorFunction(translate_binary_op<v1::GreaterEqual>)},

Diff for: src/frontends/tensorflow_common/include/common_op_table.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ OP_CONVERTER(translate_mul_op);
7474
OP_CONVERTER(translate_dynamic_partition_op);
7575
OP_CONVERTER(translate_einsum_op);
7676
OP_CONVERTER(translate_elu_op);
77+
OP_CONVERTER(translate_equal_op);
7778
OP_CONVERTER(translate_erfc_op);
7879
OP_CONVERTER(translate_expm1_op);
7980
OP_CONVERTER(translate_expand_dims_op);

Diff for: src/frontends/tensorflow_common/src/op/binary_op.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
//
44

55
#include "common_op_table.hpp"
6+
#include "common_translators.hpp"
67
#include "helper_ops/complex_type_mark.hpp"
78
#include "openvino/op/add.hpp"
89
#include "openvino/op/bitwise_and.hpp"
@@ -127,6 +128,15 @@ OutputVector translate_sub_op(const NodeContext& node) {
127128
return {result};
128129
}
129130

131+
OutputVector translate_equal_op(const NodeContext& node) {
132+
default_op_checks(node, 2, {"Equal"}, true);
133+
134+
auto result = common_translators::translate_equal(node);
135+
136+
set_node_name(node.get_name(), result[0].get_node_shared_ptr());
137+
return result;
138+
}
139+
130140
template OutputVector translate_binary_op<v1::Add>(const NodeContext& node);
131141
template OutputVector translate_binary_op<v13::BitwiseAnd>(const NodeContext& node);
132142
template OutputVector translate_binary_op<v13::BitwiseOr>(const NodeContext& node);

Diff for: tests/layer_tests/tensorflow_tests/test_tf_Equal.py

+54
Original file line numberDiff line numberDiff line change
@@ -259,3 +259,57 @@ def test_equal_str(self, x_shape, y_shape,
259259
self._test(*self.create_equal_net(x_shape=x_shape, y_shape=y_shape),
260260
ie_device, precision, ir_version, temp_dir=temp_dir,
261261
use_legacy_frontend=use_legacy_frontend)
262+
263+
264+
class TestComplexEqual(CommonTFLayerTest):
265+
def _prepare_input(self, inputs_info):
266+
rng = np.random.default_rng()
267+
assert 'param_real_1:0' in inputs_info
268+
assert 'param_imag_1:0' in inputs_info
269+
assert 'param_real_2:0' in inputs_info
270+
assert 'param_imag_2:0' in inputs_info
271+
param_real_shape_1 = inputs_info['param_real_1:0']
272+
param_imag_shape_1 = inputs_info['param_imag_1:0']
273+
param_real_shape_2 = inputs_info['param_real_2:0']
274+
param_imag_shape_2 = inputs_info['param_imag_2:0']
275+
inputs_data = {}
276+
inputs_data['param_real_1:0'] = 4 * rng.random(param_real_shape_1).astype(np.float32) - 2
277+
inputs_data['param_imag_1:0'] = 4 * rng.random(param_imag_shape_1).astype(np.float32) - 2
278+
inputs_data['param_real_2:0'] = 4 * rng.random(param_real_shape_2).astype(np.float32) - 2
279+
inputs_data['param_imag_2:0'] = 4 * rng.random(param_imag_shape_2).astype(np.float32) - 2
280+
return inputs_data
281+
282+
def create_complex_equal_net(self, input_shape):
283+
tf.compat.v1.reset_default_graph()
284+
# Create the graph and model
285+
with tf.compat.v1.Session() as sess:
286+
param_real1 = tf.compat.v1.placeholder(np.float32, input_shape, 'param_real_1')
287+
param_imag1 = tf.compat.v1.placeholder(np.float32, input_shape, 'param_imag_1')
288+
param_real2 = tf.compat.v1.placeholder(np.float32, input_shape, 'param_real_2')
289+
param_imag2 = tf.compat.v1.placeholder(np.float32, input_shape, 'param_imag_2')
290+
complex1 = tf.raw_ops.Complex(real=param_real1, imag=param_imag1)
291+
complex2 = tf.raw_ops.Complex(real=param_real2, imag=param_imag2)
292+
tf.raw_ops.Equal(x=complex1, y=complex2, name="complex_equal")
293+
tf.compat.v1.global_variables_initializer()
294+
tf_net = sess.graph_def
295+
296+
return tf_net, None
297+
298+
test_data_basic = [
299+
dict(input_shape=[]),
300+
dict(input_shape=[2]),
301+
dict(input_shape=[1, 3]),
302+
dict(input_shape=[2, 3, 4]),
303+
dict(input_shape=[3, 4, 5, 6]),
304+
]
305+
306+
@pytest.mark.parametrize("params", test_data_basic)
307+
@pytest.mark.precommit
308+
@pytest.mark.nightly
309+
def test_complex_equal(self, params, ie_device, precision, ir_version, temp_dir,
310+
use_legacy_frontend):
311+
self._test(
312+
*self.create_complex_equal_net(**params),
313+
ie_device, precision, ir_version, temp_dir=temp_dir,
314+
use_legacy_frontend=use_legacy_frontend
315+
)

0 commit comments

Comments
 (0)