From a141551fbfecdc0397b6e235cb481955e10c818b Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Mon, 16 Sep 2024 15:32:24 +0200 Subject: [PATCH] Add tests for reducel2 --- onnx2pytorch/operations/reducel2.py | 7 +- .../onnx2pytorch/operations/test_reducel2.py | 249 ++++++++++++++++++ 2 files changed, 253 insertions(+), 3 deletions(-) create mode 100644 tests/onnx2pytorch/operations/test_reducel2.py diff --git a/onnx2pytorch/operations/reducel2.py b/onnx2pytorch/operations/reducel2.py index ac29eb8..6a9d452 100644 --- a/onnx2pytorch/operations/reducel2.py +++ b/onnx2pytorch/operations/reducel2.py @@ -1,13 +1,14 @@ import torch from torch import nn + class ReduceL2(nn.Module): def __init__( self, opset_version, dim=None, keepdim=True, noop_with_empty_axes=False ): self.opset_version = opset_version self.dim = dim - self.keepdim = keepdim + self.keepdim = bool(keepdim) self.noop_with_empty_axes = noop_with_empty_axes super().__init__() @@ -21,11 +22,11 @@ def forward(self, data: torch.Tensor, axes: torch.Tensor = None): return data else: dims = tuple(range(data.ndim)) - + if isinstance(dims, int): dim = dims else: - dim=tuple(list(dims)) + dim = tuple(list(dims)) ret = torch.sqrt(torch.sum(torch.square(data), dim=dim, keepdim=self.keepdim)) return ret diff --git a/tests/onnx2pytorch/operations/test_reducel2.py b/tests/onnx2pytorch/operations/test_reducel2.py new file mode 100644 index 0000000..14947d2 --- /dev/null +++ b/tests/onnx2pytorch/operations/test_reducel2.py @@ -0,0 +1,249 @@ +import numpy as np +import onnx +import pytest +import torch + +from onnx2pytorch.convert.operations import convert_operations +from onnx2pytorch.operations import ReduceL2 + + +@pytest.fixture +def tensor(): + return torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + + +def test_reduce_l2_older_opset_version(tensor): + shape = [3, 2, 2] + axes = np.array([2], dtype=np.int64) + keepdims = 0 + + data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape) + op = ReduceL2(opset_version=10, keepdim=keepdims, dim=axes) + + reduced = np.sqrt( + np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1) + ) + + out = op(torch.from_numpy(data), axes=axes) + np.testing.assert_array_equal(out, reduced) + + +def test_do_not_keepdims_older_opset_version() -> None: + opset_version = 10 + shape = [3, 2, 2] + axes = np.array([2], dtype=np.int64) + keepdims = 0 + + node = onnx.helper.make_node( + "ReduceL2", + inputs=["data"], + outputs=["reduced"], + keepdims=keepdims, + axes=axes, + ) + graph = onnx.helper.make_graph([node], "test_reduce_l2_do_not_keepdims", [], []) + + ops = list(convert_operations(graph, opset_version)) + op = ops[0][2] + + assert isinstance(op, ReduceL2) + + data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape) + # print(data) + # [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]], [[9., 10.], [11., 12.]]] + + reduced = np.sqrt( + np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1) + ) + # print(reduced) + # [[2.23606798, 5.], + # [7.81024968, 10.63014581], + # [13.45362405, 16.2788206]] + + out = op(torch.from_numpy(data)) + np.testing.assert_array_equal(out, reduced) + + np.random.seed(0) + data = np.random.uniform(-10, 10, shape).astype(np.float32) + reduced = np.sqrt( + np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1) + ) + + out = op(torch.from_numpy(data)) + np.testing.assert_array_equal(out, reduced) + + +def test_do_not_keepdims() -> None: + shape = [3, 2, 2] + axes = np.array([2], dtype=np.int64) + keepdims = 0 + + node = onnx.helper.make_node( + "ReduceL2", + inputs=["data", "axes"], + outputs=["reduced"], + keepdims=keepdims, + ) + graph = onnx.helper.make_graph([node], "test_reduce_l2_do_not_keepdims", [], []) + ops = list(convert_operations(graph, 18)) + op = ops[0][2] + + assert isinstance(op, ReduceL2) + + data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape) + # print(data) + # [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]], [[9., 10.], [11., 12.]]] + + reduced = np.sqrt( + np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1) + ) + # print(reduced) + # [[2.23606798, 5.], + # [7.81024968, 10.63014581], + # [13.45362405, 16.2788206]] + + out = op(torch.from_numpy(data), axes=axes) + np.testing.assert_array_equal(out, reduced) + + np.random.seed(0) + data = np.random.uniform(-10, 10, shape).astype(np.float32) + reduced = np.sqrt( + np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1) + ) + + out = op(torch.from_numpy(data), axes=axes) + np.testing.assert_array_equal(out, reduced) + + +def test_export_keepdims() -> None: + shape = [3, 2, 2] + axes = np.array([2], dtype=np.int64) + keepdims = 1 + + node = onnx.helper.make_node( + "ReduceL2", + inputs=["data", "axes"], + outputs=["reduced"], + keepdims=keepdims, + ) + graph = onnx.helper.make_graph([node], "test_reduce_l2_do_not_keepdims", [], []) + ops = list(convert_operations(graph, 18)) + op = ops[0][2] + + data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape) + # print(data) + # [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]], [[9., 10.], [11., 12.]]] + + reduced = np.sqrt( + np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1) + ) + # print(reduced) + # [[[2.23606798], [5.]] + # [[7.81024968], [10.63014581]] + # [[13.45362405], [16.2788206 ]]] + + out = op(torch.from_numpy(data), axes=axes) + np.testing.assert_array_equal(out, reduced) + + np.random.seed(0) + data = np.random.uniform(-10, 10, shape).astype(np.float32) + reduced = np.sqrt( + np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1) + ) + + out = op(torch.from_numpy(data), axes=axes) + np.testing.assert_array_equal(out, reduced) + + +def test_export_default_axes_keepdims() -> None: + shape = [3, 2, 2] + axes = np.array([], dtype=np.int64) + keepdims = 1 + + node = onnx.helper.make_node( + "ReduceL2", inputs=["data", "axes"], outputs=["reduced"], keepdims=keepdims + ) + graph = onnx.helper.make_graph([node], "test_reduce_l2_do_not_keepdims", [], []) + ops = list(convert_operations(graph, 18)) + op = ops[0][2] + + data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape) + # print(data) + # [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]], [[9., 10.], [11., 12.]]] + + reduced = np.sqrt(np.sum(a=np.square(data), axis=None, keepdims=keepdims == 1)) + # print(reduced) + # [[[25.49509757]]] + + out = op(torch.from_numpy(data), axes=axes) + np.testing.assert_array_equal(out, reduced) + + np.random.seed(0) + data = np.random.uniform(-10, 10, shape).astype(np.float32) + reduced = np.sqrt(np.sum(a=np.square(data), axis=None, keepdims=keepdims == 1)) + + out = op(torch.from_numpy(data), axes=axes) + np.testing.assert_array_equal(out, reduced) + + +def test_export_negative_axes_keepdims() -> None: + shape = [3, 2, 2] + axes = np.array([-1], dtype=np.int64) + keepdims = 1 + + node = onnx.helper.make_node( + "ReduceL2", + inputs=["data", "axes"], + outputs=["reduced"], + keepdims=keepdims, + ) + graph = onnx.helper.make_graph([node], "test_reduce_l2_do_not_keepdims", [], []) + ops = list(convert_operations(graph, 18)) + op = ops[0][2] + + data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape) + # print(data) + # [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]], [[9., 10.], [11., 12.]]] + + reduced = np.sqrt( + np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1) + ) + # print(reduced) + # [[[2.23606798], [5.]] + # [[7.81024968], [10.63014581]] + # [[13.45362405], [16.2788206 ]]] + + out = op(torch.from_numpy(data), axes=axes) + np.testing.assert_array_equal(out, reduced) + + np.random.seed(0) + data = np.random.uniform(-10, 10, shape).astype(np.float32) + reduced = np.sqrt( + np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1) + ) + + out = op(torch.from_numpy(data), axes=axes) + np.testing.assert_array_equal(out, reduced) + + +def test_export_empty_set() -> None: + shape = [2, 0, 4] + keepdims = 1 + reduced_shape = [2, 1, 4] + + node = onnx.helper.make_node( + "ReduceL2", + inputs=["data", "axes"], + outputs=["reduced"], + keepdims=keepdims, + ) + graph = onnx.helper.make_graph([node], "test_reduce_l2_do_not_keepdims", [], []) + ops = list(convert_operations(graph, 18)) + op = ops[0][2] + + data = np.array([], dtype=np.float32).reshape(shape) + axes = np.array([1], dtype=np.int64) + reduced = np.array(np.zeros(reduced_shape, dtype=np.float32)) + + out = op(torch.from_numpy(data), axes=axes) + np.testing.assert_array_equal(out, reduced)