Skip to content

Commit

Permalink
Merge pull request #66 from Talmaj/talmaj/updates
Browse files Browse the repository at this point in the history
Fix instance norm tests
Update pre-commit config
Update tox.ini
Fix test_clip
Add lint_and_test GitHub action
Remove circleci
  • Loading branch information
Talmaj authored Sep 15, 2024
2 parents 89a3fe7 + 6d2812a commit 03e81a3
Show file tree
Hide file tree
Showing 9 changed files with 84 additions and 61 deletions.
32 changes: 0 additions & 32 deletions .circleci/config.yml

This file was deleted.

43 changes: 43 additions & 0 deletions .github/workflows/lint_and_test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
name: Lint and Test

on: [pull_request]

jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ '3.9', '3.10', '3.11', '3.12' ]

steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Cache dependencies
uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/requirements.txt') }}
restore-keys: |
${{ runner.os }}-pip-${{ matrix.python-version }}-
${{ runner.os }}-pip-
- name: Install dependencies
run: |
pip install tox tox-gh-actions
- name: Run tests
run: |
bash download_fixtures.sh
tox
- name: Upload coverage to GitHub Artifacts
uses: actions/upload-artifact@v4
with:
name: coverage-${{ matrix.python-version }}
path: htmlcov/
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/psf/black
rev: stable
rev: 24.8.0
hooks:
- id: black
language_version: python3.8
language_version: python3.10
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# ONNX to PyTorch
![PyPI - License](https://img.shields.io/pypi/l/onnx2pytorch?color)
[![CircleCI](https://circleci.com/gh/ToriML/onnx2pytorch.svg?style=shield)](https://app.circleci.com/pipelines/github/ToriML/onnx2pytorch)
[![Lint and Test](https://github.com/Talmaj/onnx2pytorch/actions/workflows/lint_and_test.yml/badge.svg)](https://github.com/Talmaj/onnx2pytorch/actions/workflows/lint_and_test.yml)
[![Downloads](https://pepy.tech/badge/onnx2pytorch)](https://pepy.tech/project/onnx2pytorch)
![PyPI](https://img.shields.io/pypi/v/onnx2pytorch)

Expand Down
14 changes: 7 additions & 7 deletions download_fixtures.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ fi

if [[ ! -f shufflenet_v2.onnx ]]; then
echo Downloading shufflenet_v2
curl -LJo shufflenet_v2.onnx https://github.com/onnx/models/blob/master/vision/classification/shufflenet/model/shufflenet-v2-10.onnx\?raw\=true
curl -LJo shufflenet_v2.onnx https://github.com/onnx/models/raw/main/validated/vision/classification/shufflenet/model/shufflenet-v2-10.onnx
fi

if [[ $1 == "--all" ]]; then
Expand All @@ -20,32 +20,32 @@ if [[ $1 == "--all" ]]; then

if [[ ! -f bertsquad-10.onnx ]]; then
echo Downloading bertsquad-10
curl -LJo bertsquad-10.onnx https://github.com/onnx/models/blob/master/text/machine_comprehension/bert-squad/model/bertsquad-10.onnx\?raw\=true
curl -LJo bertsquad-10.onnx https://github.com/onnx/models/raw/main/validated/text/machine_comprehension/bert-squad/model/bertsquad-10.onnx
fi

if [[ ! -f yolo_v4.onnx ]]; then
echo Downloading yolo_v4
curl -LJo yolo_v4.onnx https://github.com/onnx/models/blob/master/vision/object_detection_segmentation/yolov4/model/yolov4.onnx\?raw\=true
curl -LJo yolo_v4.onnx https://github.com/onnx/models/raw/main/validated/vision/object_detection_segmentation/yolov4/model/yolov4.onnx
fi

if [[ ! -f super_res.onnx ]]; then
echo Downloading super_res
curl -LJo super_res.onnx https://github.com/onnx/models/blob/master/vision/super_resolution/sub_pixel_cnn_2016/model/super-resolution-10.onnx\?raw\=true
curl -LJo super_res.onnx https://github.com/onnx/models/raw/main/validated/vision/super_resolution/sub_pixel_cnn_2016/model/super-resolution-10.onnx
fi

if [[ ! -f fast_neural_style.onnx ]]; then
echo Downloading fast_neural_style
curl -LJo fast_neural_style.onnx https://github.com/onnx/models/blob/master/vision/style_transfer/fast_neural_style/model/rain-princess-9.onnx\?raw\=true
curl -LJo fast_neural_style.onnx https://github.com/onnx/models/raw/main/validated/vision/style_transfer/fast_neural_style/model/rain-princess-9.onnx
fi

if [[ ! -f efficientnet-lite4.onnx ]]; then
echo Downloading efficientnet-lite4
curl -LJo efficientnet-lite4.onnx https://github.com/onnx/models/blob/master/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx\?raw\=true
curl -LJo efficientnet-lite4.onnx https://github.com/onnx/models/raw/main/validated/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx
fi

if [[ ! -f mobilenetv2-7.onnx ]]; then
echo Downloading mobilenetv2-7
curl -LJo mobilenetv2-7.onnx https://github.com/onnx/models/raw/master/vision/classification/mobilenet/model/mobilenetv2-7.onnx\?raw\=true
curl -LJo mobilenetv2-7.onnx https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv2-7.onnx
fi

fi
Expand Down
25 changes: 15 additions & 10 deletions onnx2pytorch/operations/instancenorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,13 @@
from torch.nn.modules.batchnorm import _LazyNormBase

class _LazyInstanceNorm(_LazyNormBase, _InstanceNorm):

cls_to_become = _InstanceNorm


except ImportError:
from torch.nn.modules.lazy import LazyModuleMixin
from torch.nn.parameter import UninitializedBuffer, UninitializedParameter

class _LazyInstanceNorm(LazyModuleMixin, _InstanceNorm):

weight: UninitializedParameter # type: ignore[assignment]
bias: UninitializedParameter # type: ignore[assignment]

Expand Down Expand Up @@ -78,24 +75,29 @@ def initialize_parameters(self, input) -> None: # type: ignore[override]
self.reset_parameters()


class LazyInstanceNormUnsafe(_LazyInstanceNorm):
class InstanceNormMixin:
"""Skips dimension check."""

def __init__(self, *args, affine=True, **kwargs):
self.no_batch_dim = None # no_batch_dim has to be set at runtime
super().__init__(*args, affine=affine, **kwargs)

def set_no_dim_batch_dim(self, no_batch_dim):
self.no_batch_dim = no_batch_dim

def _check_input_dim(self, input):
return

def _get_no_batch_dim(self):
return self.no_batch_dim

class InstanceNormUnsafe(_InstanceNorm):
"""Skips dimension check."""

def __init__(self, *args, affine=True, **kwargs):
super().__init__(*args, affine=affine, **kwargs)
class LazyInstanceNormUnsafe(InstanceNormMixin, _LazyInstanceNorm):
pass

def _check_input_dim(self, input):
return

class InstanceNormUnsafe(InstanceNormMixin, _InstanceNorm):
pass


class InstanceNormWrapper(torch.nn.Module):
Expand All @@ -120,4 +122,7 @@ def forward(self, input, scale=None, B=None):
if B is not None:
getattr(self.inu, "bias").data = B

if self.inu.no_batch_dim is None:
self.inu.set_no_dim_batch_dim(input.dim() - 1)

return self.inu.forward(input)
12 changes: 6 additions & 6 deletions tests/onnx2pytorch/convert/test_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,17 @@ def test_single_layer_lstm(
o2p_lstm = ConvertModel(onnx_lstm, experimental=True)
with torch.no_grad():
o2p_output, o2p_h_n, o2p_c_n = o2p_lstm(input, h_0, c_0)
assert torch.equal(o2p_output, output)
assert torch.equal(o2p_h_n, h_n)
assert torch.equal(o2p_c_n, c_n)
torch.testing.assert_allclose(o2p_output, output, rtol=1e-6, atol=1e-6)
torch.testing.assert_allclose(o2p_h_n, h_n, rtol=1e-6, atol=1e-6)
torch.testing.assert_allclose(o2p_c_n, c_n, rtol=1e-6, atol=1e-6)

onnx_lstm = onnx.ModelProto.FromString(bitstream_data)
o2p_lstm = ConvertModel(onnx_lstm, experimental=True)
with torch.no_grad():
o2p_output, o2p_h_n, o2p_c_n = o2p_lstm(h_0=h_0, input=input, c_0=c_0)
assert torch.equal(o2p_output, output)
assert torch.equal(o2p_h_n, h_n)
assert torch.equal(o2p_c_n, c_n)
torch.testing.assert_allclose(o2p_output, output, rtol=1e-6, atol=1e-6)
torch.testing.assert_allclose(o2p_h_n, h_n, rtol=1e-6, atol=1e-6)
torch.testing.assert_allclose(o2p_c_n, c_n, rtol=1e-6, atol=1e-6)
with pytest.raises(KeyError):
o2p_output, o2p_h_n, o2p_c_n = o2p_lstm(h_0=h_0, input=input)
with pytest.raises(Exception):
Expand Down
4 changes: 2 additions & 2 deletions tests/onnx2pytorch/operations/test_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ def test_clip():
assert torch.equal(op(x), exp_y)

op = Clip(max=0)
exp_y_np = np.clip(x_np, np.NINF, 0)
exp_y_np = np.clip(x_np, -np.inf, 0)
exp_y = torch.from_numpy(exp_y_np)
assert torch.equal(op(x), exp_y)

op = Clip()
exp_y_np = np.clip(x_np, np.NINF, np.inf)
exp_y_np = np.clip(x_np, -np.inf, np.inf)
exp_y = torch.from_numpy(exp_y_np)
assert torch.equal(op(x), exp_y)
9 changes: 8 additions & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@
# and then run "tox" from this directory.

[tox]
envlist = clean,py36,py37,py38,py38-torch19,py39
envlist = clean,py39,py310,py311,py312

[gh-actions]
python =
3.9: py39
3.10: py310
3.11: py311
3.12: py312

[testenv]
passenv =
Expand Down

0 comments on commit 03e81a3

Please sign in to comment.