Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ReduceMax and ReduceL2 #69

Merged
merged 5 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion onnx2pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .convert import ConvertModel

__version__ = "0.4.1"
__version__ = "0.5.0"
6 changes: 3 additions & 3 deletions onnx2pytorch/convert/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,7 @@ def convert_operations(onnx_graph, opset_version, batch_dim=0, enable_pruning=Tr
elif node.op_type == "Reciprocal":
op = OperatorWrapper(torch.reciprocal)
elif node.op_type == "ReduceMax":
kwargs = dict(keepdim=True)
kwargs.update(extract_attributes(node))
op = partial(torch.max, **kwargs)
op = ReduceMax(**extract_attributes(node))
elif node.op_type == "ReduceMean":
kwargs = dict(keepdim=True)
kwargs.update(extract_attributes(node))
Expand All @@ -216,6 +214,8 @@ def convert_operations(onnx_graph, opset_version, batch_dim=0, enable_pruning=Tr
op = partial(torch.prod, **kwargs)
elif node.op_type == "ReduceSum":
op = ReduceSum(opset_version=opset_version, **extract_attributes(node))
elif node.op_type == "ReduceL2":
op = ReduceL2(opset_version=opset_version, **extract_attributes(node))
elif node.op_type == "Relu":
op = nn.ReLU(inplace=True)
elif node.op_type == "Reshape":
Expand Down
4 changes: 4 additions & 0 deletions onnx2pytorch/operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
from .pad import Pad
from .prelu import PRelu
from .range import Range
from .reducemax import ReduceMax
from .reducesum import ReduceSum
from .reducel2 import ReduceL2
from .reshape import Reshape
from .resize import Resize, Upsample
from .scatter import Scatter
Expand Down Expand Up @@ -60,7 +62,9 @@
"Pad",
"PRelu",
"Range",
"ReduceMax",
"ReduceSum",
"ReduceL2",
"Reshape",
"Resize",
"Scatter",
Expand Down
32 changes: 32 additions & 0 deletions onnx2pytorch/operations/reducel2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import torch
from torch import nn


class ReduceL2(nn.Module):
def __init__(
self, opset_version, dim=None, keepdim=True, noop_with_empty_axes=False
):
self.opset_version = opset_version
self.dim = dim
self.keepdim = bool(keepdim)
self.noop_with_empty_axes = noop_with_empty_axes
super().__init__()

def forward(self, data: torch.Tensor, axes: torch.Tensor = None):
if self.opset_version < 13:
dims = self.dim
else:
dims = axes
if dims is None:
if self.noop_with_empty_axes:
return data
else:
dims = tuple(range(data.ndim))

if isinstance(dims, int):
dim = dims
else:
dim = tuple(list(dims))

ret = torch.sqrt(torch.sum(torch.square(data), dim=dim, keepdim=self.keepdim))
return ret
15 changes: 15 additions & 0 deletions onnx2pytorch/operations/reducemax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import torch
from torch import nn


class ReduceMax(nn.Module):
def __init__(self, dim=None, keepdim=True):
self.dim = dim
self.keepdim = keepdim
super().__init__()

def forward(self, data: torch.Tensor):
dim = self.dim
if dim is None:
dim = tuple(range(data.ndim))
return torch.amax(data, dim=dim, keepdim=self.keepdim)
249 changes: 249 additions & 0 deletions tests/onnx2pytorch/operations/test_reducel2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
import numpy as np
import onnx
import pytest
import torch

from onnx2pytorch.convert.operations import convert_operations
from onnx2pytorch.operations import ReduceL2


@pytest.fixture
def tensor():
return torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])


def test_reduce_l2_older_opset_version(tensor):
shape = [3, 2, 2]
axes = np.array([2], dtype=np.int64)
keepdims = 0

data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape)
op = ReduceL2(opset_version=10, keepdim=keepdims, dim=axes)

reduced = np.sqrt(
np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1)
)

out = op(torch.from_numpy(data), axes=axes)
np.testing.assert_array_equal(out, reduced)


def test_do_not_keepdims_older_opset_version() -> None:
opset_version = 10
shape = [3, 2, 2]
axes = np.array([2], dtype=np.int64)
keepdims = 0

node = onnx.helper.make_node(
"ReduceL2",
inputs=["data"],
outputs=["reduced"],
keepdims=keepdims,
axes=axes,
)
graph = onnx.helper.make_graph([node], "test_reduce_l2_do_not_keepdims", [], [])

ops = list(convert_operations(graph, opset_version))
op = ops[0][2]

assert isinstance(op, ReduceL2)

data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape)
# print(data)
# [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]], [[9., 10.], [11., 12.]]]

reduced = np.sqrt(
np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1)
)
# print(reduced)
# [[2.23606798, 5.],
# [7.81024968, 10.63014581],
# [13.45362405, 16.2788206]]

out = op(torch.from_numpy(data))
np.testing.assert_array_equal(out, reduced)

np.random.seed(0)
data = np.random.uniform(-10, 10, shape).astype(np.float32)
reduced = np.sqrt(
np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1)
)

out = op(torch.from_numpy(data))
np.testing.assert_array_equal(out, reduced)


def test_do_not_keepdims() -> None:
shape = [3, 2, 2]
axes = np.array([2], dtype=np.int64)
keepdims = 0

node = onnx.helper.make_node(
"ReduceL2",
inputs=["data", "axes"],
outputs=["reduced"],
keepdims=keepdims,
)
graph = onnx.helper.make_graph([node], "test_reduce_l2_do_not_keepdims", [], [])
ops = list(convert_operations(graph, 18))
op = ops[0][2]

assert isinstance(op, ReduceL2)

data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape)
# print(data)
# [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]], [[9., 10.], [11., 12.]]]

reduced = np.sqrt(
np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1)
)
# print(reduced)
# [[2.23606798, 5.],
# [7.81024968, 10.63014581],
# [13.45362405, 16.2788206]]

out = op(torch.from_numpy(data), axes=axes)
np.testing.assert_array_equal(out, reduced)

np.random.seed(0)
data = np.random.uniform(-10, 10, shape).astype(np.float32)
reduced = np.sqrt(
np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1)
)

out = op(torch.from_numpy(data), axes=axes)
np.testing.assert_array_equal(out, reduced)


def test_export_keepdims() -> None:
shape = [3, 2, 2]
axes = np.array([2], dtype=np.int64)
keepdims = 1

node = onnx.helper.make_node(
"ReduceL2",
inputs=["data", "axes"],
outputs=["reduced"],
keepdims=keepdims,
)
graph = onnx.helper.make_graph([node], "test_reduce_l2_do_not_keepdims", [], [])
ops = list(convert_operations(graph, 18))
op = ops[0][2]

data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape)
# print(data)
# [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]], [[9., 10.], [11., 12.]]]

reduced = np.sqrt(
np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1)
)
# print(reduced)
# [[[2.23606798], [5.]]
# [[7.81024968], [10.63014581]]
# [[13.45362405], [16.2788206 ]]]

out = op(torch.from_numpy(data), axes=axes)
np.testing.assert_array_equal(out, reduced)

np.random.seed(0)
data = np.random.uniform(-10, 10, shape).astype(np.float32)
reduced = np.sqrt(
np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1)
)

out = op(torch.from_numpy(data), axes=axes)
np.testing.assert_array_equal(out, reduced)


def test_export_default_axes_keepdims() -> None:
shape = [3, 2, 2]
axes = np.array([], dtype=np.int64)
keepdims = 1

node = onnx.helper.make_node(
"ReduceL2", inputs=["data", "axes"], outputs=["reduced"], keepdims=keepdims
)
graph = onnx.helper.make_graph([node], "test_reduce_l2_do_not_keepdims", [], [])
ops = list(convert_operations(graph, 18))
op = ops[0][2]

data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape)
# print(data)
# [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]], [[9., 10.], [11., 12.]]]

reduced = np.sqrt(np.sum(a=np.square(data), axis=None, keepdims=keepdims == 1))
# print(reduced)
# [[[25.49509757]]]

out = op(torch.from_numpy(data), axes=axes)
np.testing.assert_array_equal(out, reduced)

np.random.seed(0)
data = np.random.uniform(-10, 10, shape).astype(np.float32)
reduced = np.sqrt(np.sum(a=np.square(data), axis=None, keepdims=keepdims == 1))

out = op(torch.from_numpy(data), axes=axes)
np.testing.assert_array_equal(out, reduced)


def test_export_negative_axes_keepdims() -> None:
shape = [3, 2, 2]
axes = np.array([-1], dtype=np.int64)
keepdims = 1

node = onnx.helper.make_node(
"ReduceL2",
inputs=["data", "axes"],
outputs=["reduced"],
keepdims=keepdims,
)
graph = onnx.helper.make_graph([node], "test_reduce_l2_do_not_keepdims", [], [])
ops = list(convert_operations(graph, 18))
op = ops[0][2]

data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape)
# print(data)
# [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]], [[9., 10.], [11., 12.]]]

reduced = np.sqrt(
np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1)
)
# print(reduced)
# [[[2.23606798], [5.]]
# [[7.81024968], [10.63014581]]
# [[13.45362405], [16.2788206 ]]]

out = op(torch.from_numpy(data), axes=axes)
np.testing.assert_array_equal(out, reduced)

np.random.seed(0)
data = np.random.uniform(-10, 10, shape).astype(np.float32)
reduced = np.sqrt(
np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1)
)

out = op(torch.from_numpy(data), axes=axes)
np.testing.assert_array_equal(out, reduced)


def test_export_empty_set() -> None:
shape = [2, 0, 4]
keepdims = 1
reduced_shape = [2, 1, 4]

node = onnx.helper.make_node(
"ReduceL2",
inputs=["data", "axes"],
outputs=["reduced"],
keepdims=keepdims,
)
graph = onnx.helper.make_graph([node], "test_reduce_l2_do_not_keepdims", [], [])
ops = list(convert_operations(graph, 18))
op = ops[0][2]

data = np.array([], dtype=np.float32).reshape(shape)
axes = np.array([1], dtype=np.int64)
reduced = np.array(np.zeros(reduced_shape, dtype=np.float32))

out = op(torch.from_numpy(data), axes=axes)
np.testing.assert_array_equal(out, reduced)
27 changes: 27 additions & 0 deletions tests/onnx2pytorch/operations/test_reducemax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest
import torch

from onnx2pytorch.operations import ReduceMax


@pytest.fixture
def tensor():
return torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])


def test_reduce_max_with_dim(tensor):
reduce_max = ReduceMax(dim=0, keepdim=True)
output = reduce_max(tensor)
expected_output = torch.tensor([[7, 8, 9]])

assert output.ndim == tensor.ndim
assert torch.equal(output, expected_output)


def test_reduce_max(tensor):
reduce_max = ReduceMax(keepdim=False)
output = reduce_max(tensor)
expected_output = torch.tensor(9)

assert output.ndim == 0
assert torch.equal(output, expected_output)