Skip to content

Commit

Permalink
Add test_reducemax.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Talmaj committed Sep 16, 2024
1 parent 55c7a83 commit 21927fe
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions tests/onnx2pytorch/operations/test_reducemax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest
import torch

from onnx2pytorch.operations import ReduceMax


@pytest.fixture
def tensor():
return torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])


def test_reduce_max_with_dim(tensor):
reduce_max = ReduceMax(dim=0, keepdim=True)
output = reduce_max(tensor)
expected_output = torch.tensor([[7, 8, 9]])

assert output.ndim == tensor.ndim
assert torch.equal(output, expected_output)


def test_reduce_max(tensor):
reduce_max = ReduceMax(keepdim=False)
output = reduce_max(tensor)
expected_output = torch.tensor(9)

assert output.ndim == 0
assert torch.equal(output, expected_output)

0 comments on commit 21927fe

Please sign in to comment.