From 1303f4dc4f0adc32528a7d53fa75c0b38e542aea Mon Sep 17 00:00:00 2001
From: Piyush kumar
Date: Sat, 20 Sep 2025 10:58:26 +0530
Subject: [PATCH 1/2] feat(pt_frontend): Add support for prim::data operation

---
 src/frontends/pytorch/src/op/data.cpp        | 27 ++++++++++++
 src/frontends/pytorch/src/op_table.cpp       |  3 +-
 tests/layer_tests/pytorch_tests/test_data.py | 45 ++++++++++++++++++++
 3 files changed, 74 insertions(+), 1 deletion(-)
 create mode 100644 src/frontends/pytorch/src/op/data.cpp
 create mode 100644 tests/layer_tests/pytorch_tests/test_data.py

diff --git a/src/frontends/pytorch/src/op/data.cpp b/src/frontends/pytorch/src/op/data.cpp
new file mode 100644
index 00000000000000..55c0213ceefd1a
--- /dev/null
+++ b/src/frontends/pytorch/src/op/data.cpp
@@ -0,0 +1,27 @@
+
+#include "openvino/frontend/complex_type_mark.hpp"
+#include "openvino/frontend/pytorch/node_context.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace op {
+
+OutputVector translate_data(const NodeContext& context) {
+    num_inputs_check(context, 1, 1);
+
+    auto input = context.get_input(0);
+
+    if (context.get_decoder()->get_output_type(0).is()) {
+        auto mark = std::make_shared<ComplexTypeMark>(input, input.get_element_type());
+        return {context.mark_node(mark)};
+    }
+
+    return {input};
+}
+
+}  // namespace op
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index 2f72ba91eb651c..faeb3c497d19d9 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -360,7 +360,7 @@ OP_CONVERTER(translate_embedding_ext);
 OP_CONVERTER(translate_linear_awq);
 OP_CONVERTER(translate_linear_bitnet);
 OP_CONVERTER(translate_linear_ext);
-
+OP_CONVERTER(translate_data);
 }  // namespace op
 
 // Supported ops for TorchScript
@@ -795,6 +795,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
         {"prim::TupleIndex", op::translate_tuple_index},
         // prim::TupleUnpack - Supported in limited set of patterns
         {"prim::type", op::skip_node},  // Used with prim::device, pass PtFrameworkNode.
+ {"prim::data", op::translate_data}, {"quantized::add", op::translate_quantized_add}, {"quantized::add_relu", op::translate_quantized_add_relu}, {"quantized::cat", op::translate_quantized_cat}, diff --git a/tests/layer_tests/pytorch_tests/test_data.py b/tests/layer_tests/pytorch_tests/test_data.py new file mode 100644 index 00000000000000..b24e7c0e32f231 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_data.py @@ -0,0 +1,45 @@ +# Copyright (C) 2018-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest +import torch + +from pytorch_layer_test_class import PytorchLayerTest + +class Model(torch.nn.Module): + def forward(self, x): + return x.data + +class ModelGrad(torch.nn.Module): + def forward(self, x): + y = x * 2.5 + return y.data + + +class TestPrimData(PytorchLayerTest): + def _prepare_input(self): + np.random.seed(self.seed) + data = (np.random.randn(*self.shape) * 10).astype(np.float32) + tensor = torch.from_numpy(data).to(self.dtype) + return (tensor.numpy(),) + + @pytest.mark.parametrize("dtype", [torch.float32, torch.float64, torch.int32, torch.int64]) + @pytest.mark.parametrize("shape", [[2, 3, 4], [1, 5], [10]]) + def test_data_basic(self, shape, dtype, ie_device, precision, ir_version): + self.shape = shape + self.dtype = dtype + self.seed = 0 + + + self._test(Model(), None, "prim::data", ie_device, precision, ir_version) + + @pytest.mark.parametrize("dtype", [torch.float32, torch.int32]) + def test_data_requires_grad(self, dtype, ie_device, precision, ir_version): + self.shape = (3, 2) + self.dtype = dtype + self.seed = 1 + + + self._test(ModelGrad(), None, "prim::data", ie_device, precision, ir_version) + From 41edc34e3d96af9c83f0b6db0daf0d842a0b15b9 Mon Sep 17 00:00:00 2001 From: Piyush kumar Date: Mon, 22 Sep 2025 22:46:19 +0530 Subject: [PATCH 2/2] Support prim::data by skipping node; remove translate_data impl; mark complex tests xfail --- src/frontends/pytorch/src/op/data.cpp | 27 -------------- src/frontends/pytorch/src/op_table.cpp | 3 +- tests/layer_tests/pytorch_tests/test_data.py | 38 ++++++++++++++++---- 3 files changed, 33 insertions(+), 35 deletions(-) delete mode 100644 src/frontends/pytorch/src/op/data.cpp diff --git a/src/frontends/pytorch/src/op/data.cpp b/src/frontends/pytorch/src/op/data.cpp deleted file mode 100644 index 55c0213ceefd1a..00000000000000 --- a/src/frontends/pytorch/src/op/data.cpp +++ /dev/null @@ -1,27 +0,0 @@ - -#include "openvino/frontend/complex_type_mark.hpp" -#include "openvino/frontend/pytorch/node_context.hpp" -#include "utils.hpp" - -namespace ov { -namespace frontend { -namespace pytorch { -namespace op { - -OutputVector translate_data(const NodeContext& context) { - num_inputs_check(context, 1, 1); - - auto input = context.get_input(0); - - if (context.get_decoder()->get_output_type(0).is()) { - auto mark = std::make_shared(input, input.get_element_type()); - return {context.mark_node(mark)}; - } - - return {input}; -} - -} // namespace op -} // namespace pytorch -} // namespace frontend -} // namespace ov diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index faeb3c497d19d9..a6b342f9560b5a 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -360,7 +360,6 @@ OP_CONVERTER(translate_embedding_ext); OP_CONVERTER(translate_linear_awq); OP_CONVERTER(translate_linear_bitnet); OP_CONVERTER(translate_linear_ext); -OP_CONVERTER(translate_data); } // namespace op // Supported ops for TorchScript @@ 
@@ -795,7 +794,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
         {"prim::TupleIndex", op::translate_tuple_index},
         // prim::TupleUnpack - Supported in limited set of patterns
         {"prim::type", op::skip_node},  // Used with prim::device, pass PtFrameworkNode.
-        {"prim::data", op::translate_data},
+        {"prim::data", op::skip_node},
         {"quantized::add", op::translate_quantized_add},
         {"quantized::add_relu", op::translate_quantized_add_relu},
         {"quantized::cat", op::translate_quantized_cat},
diff --git a/tests/layer_tests/pytorch_tests/test_data.py b/tests/layer_tests/pytorch_tests/test_data.py
index b24e7c0e32f231..afd65cef57688d 100644
--- a/tests/layer_tests/pytorch_tests/test_data.py
+++ b/tests/layer_tests/pytorch_tests/test_data.py
@@ -7,12 +7,14 @@
 
 from pytorch_layer_test_class import PytorchLayerTest
 
+
 class Model(torch.nn.Module):
     def forward(self, x):
         return x.data
 
+
 class ModelGrad(torch.nn.Module):
-    def forward(self, x): 
+    def forward(self, x):
         y = x * 2.5
         return y.data
 
@@ -20,7 +22,13 @@ def forward(self, x):
 class TestPrimData(PytorchLayerTest):
     def _prepare_input(self):
         np.random.seed(self.seed)
-        data = (np.random.randn(*self.shape) * 10).astype(np.float32)
+        if self.dtype in (torch.complex64, torch.complex128):
+            real = (np.random.randn(*self.shape) * 10).astype(np.float32)
+            imag = (np.random.randn(*self.shape) * 10).astype(np.float32)
+            data = real + 1j * imag
+            data = data.astype(np.complex128 if self.dtype == torch.complex128 else np.complex64)
+        else:
+            data = (np.random.randn(*self.shape) * 10).astype(np.float32)
         tensor = torch.from_numpy(data).to(self.dtype)
         return (tensor.numpy(),)
 
@@ -30,8 +38,6 @@ def test_data_basic(self, shape, dtype, ie_device, precision, ir_version):
         self.shape = shape
         self.dtype = dtype
         self.seed = 0
-
-
         self._test(Model(), None, "prim::data", ie_device, precision, ir_version)
 
     @pytest.mark.parametrize("dtype", [torch.float32, torch.int32])
@@ -39,7 +45,27 @@ def test_data_requires_grad(self, dtype, ie_device, precision, ir_version):
         self.shape = (3, 2)
         self.dtype = dtype
         self.seed = 1
-
-
         self._test(ModelGrad(), None, "prim::data", ie_device, precision, ir_version)
 
+    @pytest.mark.parametrize("dtype", [torch.complex64, torch.complex128])
+    @pytest.mark.parametrize("shape", [[2, 3], [1, 5], [4]])
+    @pytest.mark.xfail(
+        reason="OpenVINO frontend does not yet support complex tensor inputs",
+        raises=AssertionError,
+    )
+    def test_data_complex(self, shape, dtype, ie_device, precision, ir_version):
+        self.shape = shape
+        self.dtype = dtype
+        self.seed = 2
+        self._test(Model(), None, "prim::data", ie_device, precision, ir_version)
+
+    @pytest.mark.parametrize("dtype", [torch.complex64, torch.complex128])
+    @pytest.mark.xfail(
+        reason="OpenVINO frontend does not yet support complex tensor inputs",
+        raises=AssertionError,
+    )
+    def test_data_complex_requires_grad(self, dtype, ie_device, precision, ir_version):
+        self.shape = (2, 3)
+        self.dtype = dtype
+        self.seed = 3
+        self._test(ModelGrad(), None, "prim::data", ie_device, precision, ir_version)