Skip to content

Commit 21927fe

Browse files
committed
Add test_reducemax.py
1 parent 55c7a83 commit 21927fe

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import pytest
2+
import torch
3+
4+
from onnx2pytorch.operations import ReduceMax
5+
6+
7+
@pytest.fixture
8+
def tensor():
9+
return torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
10+
11+
12+
def test_reduce_max_with_dim(tensor):
13+
reduce_max = ReduceMax(dim=0, keepdim=True)
14+
output = reduce_max(tensor)
15+
expected_output = torch.tensor([[7, 8, 9]])
16+
17+
assert output.ndim == tensor.ndim
18+
assert torch.equal(output, expected_output)
19+
20+
21+
def test_reduce_max(tensor):
22+
reduce_max = ReduceMax(keepdim=False)
23+
output = reduce_max(tensor)
24+
expected_output = torch.tensor(9)
25+
26+
assert output.ndim == 0
27+
assert torch.equal(output, expected_output)

0 commit comments

Comments
 (0)