Skip to content

Commit

Permalink
test SymmetricContraction export
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Dec 11, 2024
1 parent 01914dd commit 50b75dc
Showing 1 changed file with 46 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
import numpy as np
import pytest
import torch
from tests.utils import (
module_with_mode,
)

import cuequivariance as cue
import cuequivariance_torch as cuet
Expand Down Expand Up @@ -108,3 +111,46 @@ def test_mace_compatibility():
output = n_sc(x, i)

torch.testing.assert_close(output, expected_output, atol=1e-5, rtol=1e-5)


export_modes = ["export", "onnx", "trt", "torch_trt", "jit"]


@pytest.mark.parametrize(
"dtype, math_dtype, atol, rtol",
[
(torch.float64, torch.float64, 1e-10, 1e-10),
(torch.float32, torch.float32, 1e-5, 1e-5),
],
)
@pytest.mark.parametrize("mode", export_modes)
def test_export(
dtype: torch.dtype,
math_dtype: torch.dtype,
atol: float,
rtol: float,
mode: str,
tmp_path,
):
if not torch.cuda.is_available():
pytest.skip("CUDA is not available")

m = cuet.SymmetricContraction(
cue.Irreps("O3", "0e + 1o + 2e"),
cue.Irreps("O3", "0e + 1o"),
3,
5,
layout_in=cue.ir_mul,
layout_out=cue.mul_ir,
dtype=dtype,
math_dtype=math_dtype,
device=device,
)

x = torch.randn((1024, 36), device=device, dtype=dtype)
i = torch.randint(0, 5, (1024,), dtype=torch.int32).to(device)

res = m(x, i)
m_script = module_with_mode(mode, m, [x, i], math_dtype, tmp_path)
res_script = m_script(x, i)
torch.testing.assert_close(res, res_script, atol=atol, rtol=rtol)

0 comments on commit 50b75dc

Please sign in to comment.