Skip to content

Commit

Permalink
Add tests for reducel2
Browse files Browse the repository at this point in the history
  • Loading branch information
Talmaj committed Sep 16, 2024
1 parent 8b1a83e commit 5f95c61
Show file tree
Hide file tree
Showing 2 changed files with 253 additions and 3 deletions.
7 changes: 4 additions & 3 deletions onnx2pytorch/operations/reducel2.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import torch
from torch import nn


class ReduceL2(nn.Module):
def __init__(
self, opset_version, dim=None, keepdim=True, noop_with_empty_axes=False
):
self.opset_version = opset_version
self.dim = dim
self.keepdim = keepdim
self.keepdim = bool(keepdim)
self.noop_with_empty_axes = noop_with_empty_axes
super().__init__()

Expand All @@ -21,11 +22,11 @@ def forward(self, data: torch.Tensor, axes: torch.Tensor = None):
return data
else:
dims = tuple(range(data.ndim))

if isinstance(dims, int):
dim = dims
else:
dim=tuple(list(dims))
dim = tuple(list(dims))

ret = torch.sqrt(torch.sum(torch.square(data), dim=dim, keepdim=self.keepdim))
return ret
249 changes: 249 additions & 0 deletions tests/onnx2pytorch/operations/test_reducel2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
import numpy as np
import onnx
import pytest
import torch

from onnx2pytorch.convert.operations import convert_operations
from onnx2pytorch.operations import ReduceL2


@pytest.fixture
def tensor():
return torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])


def test_reduce_l2_older_opset_version(tensor):
shape = [3, 2, 2]
axes = np.array([2], dtype=np.int64)
keepdims = 0

data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape)
op = ReduceL2(opset_version=10, keepdim=keepdims, dim=axes)

reduced = np.sqrt(
np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1)
)

out = op(torch.from_numpy(data), axes=axes)
np.testing.assert_array_equal(out, reduced)


def test_do_not_keepdims_older_opset_version() -> None:
opset_version = 10
shape = [3, 2, 2]
axes = np.array([2], dtype=np.int64)
keepdims = 0

node = onnx.helper.make_node(
"ReduceL2",
inputs=["data"],
outputs=["reduced"],
keepdims=keepdims,
axes=axes,
)
graph = onnx.helper.make_graph([node], "test_reduce_l2_do_not_keepdims", [], [])

ops = list(convert_operations(graph, opset_version))
op = ops[0][2]

assert isinstance(op, ReduceL2)

data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape)
# print(data)
# [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]], [[9., 10.], [11., 12.]]]

reduced = np.sqrt(
np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1)
)
# print(reduced)
# [[2.23606798, 5.],
# [7.81024968, 10.63014581],
# [13.45362405, 16.2788206]]

out = op(torch.from_numpy(data))
np.testing.assert_array_equal(out, reduced)

np.random.seed(0)
data = np.random.uniform(-10, 10, shape).astype(np.float32)
reduced = np.sqrt(
np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1)
)

out = op(torch.from_numpy(data))
np.testing.assert_array_equal(out, reduced)


def test_do_not_keepdims() -> None:
shape = [3, 2, 2]
axes = np.array([2], dtype=np.int64)
keepdims = 0

node = onnx.helper.make_node(
"ReduceL2",
inputs=["data", "axes"],
outputs=["reduced"],
keepdims=keepdims,
)
graph = onnx.helper.make_graph([node], "test_reduce_l2_do_not_keepdims", [], [])
ops = list(convert_operations(graph, 18))
op = ops[0][2]

assert isinstance(op, ReduceL2)

data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape)
# print(data)
# [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]], [[9., 10.], [11., 12.]]]

reduced = np.sqrt(
np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1)
)
# print(reduced)
# [[2.23606798, 5.],
# [7.81024968, 10.63014581],
# [13.45362405, 16.2788206]]

out = op(torch.from_numpy(data), axes=axes)
np.testing.assert_array_equal(out, reduced)

np.random.seed(0)
data = np.random.uniform(-10, 10, shape).astype(np.float32)
reduced = np.sqrt(
np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1)
)

out = op(torch.from_numpy(data), axes=axes)
np.testing.assert_array_equal(out, reduced)


def test_export_keepdims() -> None:
shape = [3, 2, 2]
axes = np.array([2], dtype=np.int64)
keepdims = 1

node = onnx.helper.make_node(
"ReduceL2",
inputs=["data", "axes"],
outputs=["reduced"],
keepdims=keepdims,
)
graph = onnx.helper.make_graph([node], "test_reduce_l2_do_not_keepdims", [], [])
ops = list(convert_operations(graph, 18))
op = ops[0][2]

data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape)
# print(data)
# [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]], [[9., 10.], [11., 12.]]]

reduced = np.sqrt(
np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1)
)
# print(reduced)
# [[[2.23606798], [5.]]
# [[7.81024968], [10.63014581]]
# [[13.45362405], [16.2788206 ]]]

out = op(torch.from_numpy(data), axes=axes)
np.testing.assert_array_equal(out, reduced)

np.random.seed(0)
data = np.random.uniform(-10, 10, shape).astype(np.float32)
reduced = np.sqrt(
np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1)
)

out = op(torch.from_numpy(data), axes=axes)
np.testing.assert_array_equal(out, reduced)


def test_export_default_axes_keepdims() -> None:
shape = [3, 2, 2]
axes = np.array([], dtype=np.int64)
keepdims = 1

node = onnx.helper.make_node(
"ReduceL2", inputs=["data", "axes"], outputs=["reduced"], keepdims=keepdims
)
graph = onnx.helper.make_graph([node], "test_reduce_l2_do_not_keepdims", [], [])
ops = list(convert_operations(graph, 18))
op = ops[0][2]

data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape)
# print(data)
# [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]], [[9., 10.], [11., 12.]]]

reduced = np.sqrt(np.sum(a=np.square(data), axis=None, keepdims=keepdims == 1))
# print(reduced)
# [[[25.49509757]]]

out = op(torch.from_numpy(data), axes=axes)
np.testing.assert_array_equal(out, reduced)

np.random.seed(0)
data = np.random.uniform(-10, 10, shape).astype(np.float32)
reduced = np.sqrt(np.sum(a=np.square(data), axis=None, keepdims=keepdims == 1))

out = op(torch.from_numpy(data), axes=axes)
np.testing.assert_array_equal(out, reduced)


def test_export_negative_axes_keepdims() -> None:
shape = [3, 2, 2]
axes = np.array([-1], dtype=np.int64)
keepdims = 1

node = onnx.helper.make_node(
"ReduceL2",
inputs=["data", "axes"],
outputs=["reduced"],
keepdims=keepdims,
)
graph = onnx.helper.make_graph([node], "test_reduce_l2_do_not_keepdims", [], [])
ops = list(convert_operations(graph, 18))
op = ops[0][2]

data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape)
# print(data)
# [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]], [[9., 10.], [11., 12.]]]

reduced = np.sqrt(
np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1)
)
# print(reduced)
# [[[2.23606798], [5.]]
# [[7.81024968], [10.63014581]]
# [[13.45362405], [16.2788206 ]]]

out = op(torch.from_numpy(data), axes=axes)
np.testing.assert_array_equal(out, reduced)

np.random.seed(0)
data = np.random.uniform(-10, 10, shape).astype(np.float32)
reduced = np.sqrt(
np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1)
)

out = op(torch.from_numpy(data), axes=axes)
np.testing.assert_array_equal(out, reduced)


def test_export_empty_set() -> None:
shape = [2, 0, 4]
keepdims = 1
reduced_shape = [2, 1, 4]

node = onnx.helper.make_node(
"ReduceL2",
inputs=["data", "axes"],
outputs=["reduced"],
keepdims=keepdims,
)
graph = onnx.helper.make_graph([node], "test_reduce_l2_do_not_keepdims", [], [])
ops = list(convert_operations(graph, 18))
op = ops[0][2]

data = np.array([], dtype=np.float32).reshape(shape)
axes = np.array([1], dtype=np.int64)
reduced = np.array(np.zeros(reduced_shape, dtype=np.float32))

out = op(torch.from_numpy(data), axes=axes)
np.testing.assert_array_equal(out, reduced)

0 comments on commit 5f95c61

Please sign in to comment.