Skip to content

Commit

Permalink
Fix tests for test_lstm by changing assert.equal assert_allclose.
Browse files Browse the repository at this point in the history
  • Loading branch information
Talmaj committed Sep 15, 2024
1 parent 20d2d64 commit 49ee5a7
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/onnx2pytorch/convert/test_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,15 @@ def test_single_layer_lstm(
o2p_lstm = ConvertModel(onnx_lstm, experimental=True)
with torch.no_grad():
o2p_output, o2p_h_n, o2p_c_n = o2p_lstm(input, h_0, c_0)
assert torch.equal(o2p_output, output)
torch.testing.assert_allclose(o2p_output, output, rtol=1e-6, atol=1e-6)
assert torch.equal(o2p_h_n, h_n)
assert torch.equal(o2p_c_n, c_n)

onnx_lstm = onnx.ModelProto.FromString(bitstream_data)
o2p_lstm = ConvertModel(onnx_lstm, experimental=True)
with torch.no_grad():
o2p_output, o2p_h_n, o2p_c_n = o2p_lstm(h_0=h_0, input=input, c_0=c_0)
assert torch.equal(o2p_output, output)
torch.testing.assert_allclose(o2p_output, output, rtol=1e-6, atol=1e-6)
assert torch.equal(o2p_h_n, h_n)
assert torch.equal(o2p_c_n, c_n)
with pytest.raises(KeyError):
Expand Down

0 comments on commit 49ee5a7

Please sign in to comment.