From 50b75dc0504f3f8d1ce74bf0aaab26605d6cbbff Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 11 Dec 2024 15:50:36 -0800 Subject: [PATCH] test SymmetricContraction export --- .../operations/symmetric_contraction_test.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/cuequivariance_torch/tests/operations/symmetric_contraction_test.py b/cuequivariance_torch/tests/operations/symmetric_contraction_test.py index 80a4065..0776492 100644 --- a/cuequivariance_torch/tests/operations/symmetric_contraction_test.py +++ b/cuequivariance_torch/tests/operations/symmetric_contraction_test.py @@ -17,6 +17,9 @@ import numpy as np import pytest import torch +from tests.utils import ( + module_with_mode, +) import cuequivariance as cue import cuequivariance_torch as cuet @@ -108,3 +111,46 @@ def test_mace_compatibility(): output = n_sc(x, i) torch.testing.assert_close(output, expected_output, atol=1e-5, rtol=1e-5) + + +export_modes = ["export", "onnx", "trt", "torch_trt", "jit"] + + +@pytest.mark.parametrize( + "dtype, math_dtype, atol, rtol", + [ + (torch.float64, torch.float64, 1e-10, 1e-10), + (torch.float32, torch.float32, 1e-5, 1e-5), + ], +) +@pytest.mark.parametrize("mode", export_modes) +def test_export( + dtype: torch.dtype, + math_dtype: torch.dtype, + atol: float, + rtol: float, + mode: str, + tmp_path, +): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + m = cuet.SymmetricContraction( + cue.Irreps("O3", "0e + 1o + 2e"), + cue.Irreps("O3", "0e + 1o"), + 3, + 5, + layout_in=cue.ir_mul, + layout_out=cue.mul_ir, + dtype=dtype, + math_dtype=math_dtype, + device=device, + ) + + x = torch.randn((1024, 36), device=device, dtype=dtype) + i = torch.randint(0, 5, (1024,), dtype=torch.int32).to(device) + + res = m(x, i) + m_script = module_with_mode(mode, m, [x, i], math_dtype, tmp_path) + res_script = m_script(x, i) + torch.testing.assert_close(res, res_script, atol=atol, rtol=rtol)