Skip to content

Commit

Permalink
Add hardsigmoid with tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
Talmaj committed Sep 15, 2024
1 parent da17e48 commit 3943a43
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 2 deletions.
7 changes: 6 additions & 1 deletion onnx2pytorch/convert/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def extract_attributes(node):
kwargs["negative_slope"] = extract_attr_values(attr)
elif node.op_type in ("Elu", "ThresholdedRelu"):
kwargs["alpha"] = extract_attr_values(attr)
elif node.op_type == "HardSigmoid":
kwargs["alpha"] = extract_attr_values(attr)
else:
kwargs["weight_multiplier"] = extract_attr_values(attr)
elif attr.name == "auto_pad":
Expand All @@ -84,7 +86,10 @@ def extract_attributes(node):
else:
kwargs["dim"] = v
elif attr.name == "beta":
kwargs["bias_multiplier"] = extract_attr_values(attr)
if node.op_type == "HardSigmoid":
kwargs["beta"] = extract_attr_values(attr)
else:
kwargs["bias_multiplier"] = extract_attr_values(attr)
elif attr.name == "body":
kwargs["body"] = extract_attr_values(attr)
elif attr.name == "ceil_mode":
Expand Down
4 changes: 3 additions & 1 deletion onnx2pytorch/convert/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
)
from onnx2pytorch.operations import *
from onnx2pytorch.operations.base import OperatorWrapper
from onnx2pytorch.operations import Resize, Upsample
from onnx2pytorch.operations import Resize, Upsample, Hardsigmoid
from onnx2pytorch.utils import (
get_inputs_names,
get_outputs_names,
Expand Down Expand Up @@ -236,6 +236,8 @@ def convert_operations(onnx_graph, opset_version, batch_dim=0, enable_pruning=Tr
op = Shape()
elif node.op_type == "Sigmoid":
op = nn.Sigmoid()
elif node.op_type == "HardSigmoid":
op = Hardsigmoid(**extract_attributes(node))
elif node.op_type == "Slice":
op = Slice(**extract_attributes(node))
elif node.op_type == "Softmax":
Expand Down
1 change: 1 addition & 0 deletions onnx2pytorch/operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .gather import Gather
from .gathernd import GatherND
from .globalaveragepool import GlobalAveragePool
from .hardsigmoid import Hardsigmoid
from .instancenorm import InstanceNormWrapper
from .loop import Loop
from .lstm import LSTMWrapper
Expand Down
24 changes: 24 additions & 0 deletions onnx2pytorch/operations/hardsigmoid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import math

import torch
from torch import nn


class Hardsigmoid(nn.Module):
def __new__(cls, alpha=0.2, beta=0.5):
"""
If alpha and beta same as default values for torch's Hardsigmoid,
return torch's Hardsigmoid. Else, return custom Hardsigmoid.
"""
if math.isclose(alpha, 1 / 6, abs_tol=1e-2) and beta == 0.5:
return nn.Hardsigmoid()
else:
return super().__new__(cls)

def __init__(self, alpha=0.2, beta=0.5):
super().__init__()
self.alpha = alpha
self.beta = beta

def forward(self, input):
return torch.clip(input * self.alpha + self.beta, 0, 1)
57 changes: 57 additions & 0 deletions tests/onnx2pytorch/operations/test_hardsigmoid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from unittest.mock import MagicMock

import numpy as np
import onnx
import torch
import pytest

from onnx2pytorch.convert.operations import convert_operations
from onnx2pytorch.operations import Hardsigmoid


@pytest.fixture
def x():
return np.random.randn(3, 4, 5).astype(np.float32)


def test_hardsigmoid(x):
alpha = 1 / 6
beta = 1 / 2
op = Hardsigmoid(alpha=alpha, beta=beta)
# For pytorch's default values it should use torch's Hardsigmoid
assert isinstance(op, torch.nn.Hardsigmoid)
x = np.random.randn(3, 4, 5).astype(np.float32)
y = np.clip(x * alpha + beta, 0, 1)
out = op(torch.from_numpy(x))
np.testing.assert_allclose(out, torch.from_numpy(y), rtol=1e-6, atol=1e-6)


def test_hardsigmoid_with_custom_alpha_and_beta(x):
alpha = 0.2
beta = 0.5
op = Hardsigmoid(alpha=alpha, beta=beta)
assert not isinstance(op, torch.nn.Hardsigmoid)
y = np.clip(x * alpha + beta, 0, 1)
out = op(torch.from_numpy(x))
np.testing.assert_allclose(out, torch.from_numpy(y), rtol=1e-6, atol=1e-6)


def test_hardsigmoid_conversion():
alpha = np.float32(0.2)
beta = np.float32(0.5)
node = onnx.helper.make_node(
"HardSigmoid",
inputs=["x"],
outputs=["y"],
alpha=alpha,
beta=beta,
)

graph = MagicMock()
graph.initializers = []
graph.node = [node]
converted_ops = list(convert_operations(graph, 10))
op_id, op_name, op = converted_ops[0]
assert isinstance(op, Hardsigmoid)
assert op.alpha == alpha
assert op.beta == beta

0 comments on commit 3943a43

Please sign in to comment.