diff --git a/cuequivariance/cuequivariance/irreps_array/irreps.py b/cuequivariance/cuequivariance/irreps_array/irreps.py index ee26c79..1bb7792 100644 --- a/cuequivariance/cuequivariance/irreps_array/irreps.py +++ b/cuequivariance/cuequivariance/irreps_array/irreps.py @@ -16,7 +16,7 @@ import dataclasses import re -from typing import NamedTuple, Union, Type, Any, Sequence, Callable, Optional +from typing import Any, Callable, NamedTuple, Optional, Sequence, Type, Union import cuequivariance as cue diff --git a/cuequivariance/cuequivariance/irreps_array/misc_ui.py b/cuequivariance/cuequivariance/irreps_array/misc_ui.py index 4bf7297..d473eda 100644 --- a/cuequivariance/cuequivariance/irreps_array/misc_ui.py +++ b/cuequivariance/cuequivariance/irreps_array/misc_ui.py @@ -15,7 +15,7 @@ from __future__ import annotations import warnings -from typing import Generator, Optional, Union, Any +from typing import Any, Generator, Optional, Union import cuequivariance as cue diff --git a/cuequivariance/cuequivariance/irreps_array/reduced_tensor_product.py b/cuequivariance/cuequivariance/irreps_array/reduced_tensor_product.py index 9c24d82..4f48ba2 100644 --- a/cuequivariance/cuequivariance/irreps_array/reduced_tensor_product.py +++ b/cuequivariance/cuequivariance/irreps_array/reduced_tensor_product.py @@ -16,7 +16,7 @@ import itertools import logging from math import prod -from typing import FrozenSet, List, Optional, Sequence, Tuple, Iterator, Union +from typing import FrozenSet, Iterator, List, Optional, Sequence, Tuple, Union import numpy as np diff --git a/cuequivariance/cuequivariance/misc/sympy_utils.py b/cuequivariance/cuequivariance/misc/sympy_utils.py index 2596dba..c285d67 100644 --- a/cuequivariance/cuequivariance/misc/sympy_utils.py +++ b/cuequivariance/cuequivariance/misc/sympy_utils.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from fractions import Fraction -import numpy as np +import numpy as np import sympy diff --git a/cuequivariance/cuequivariance/representation/irrep_so3.py b/cuequivariance/cuequivariance/representation/irrep_so3.py index 6345d40..5cd9cce 100644 --- a/cuequivariance/cuequivariance/representation/irrep_so3.py +++ b/cuequivariance/cuequivariance/representation/irrep_so3.py @@ -22,7 +22,7 @@ import numpy as np from cuequivariance.misc.linalg import round_to_sqrt_rational -from cuequivariance.representation import Irrep, SU2 +from cuequivariance.representation import SU2, Irrep # This function is copied from https://github.com/lie-nn/lie-nn/blob/70adebce44e3197ee17f780585c6570d836fc2fe/lie_nn/_src/irreps/so3_real.py diff --git a/cuequivariance/cuequivariance/representation/irrep_su2.py b/cuequivariance/cuequivariance/representation/irrep_su2.py index 9b3865e..08e3cc9 100644 --- a/cuequivariance/cuequivariance/representation/irrep_su2.py +++ b/cuequivariance/cuequivariance/representation/irrep_su2.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from __future__ import annotations + import itertools import re from dataclasses import dataclass diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/operand.py b/cuequivariance/cuequivariance/segmented_tensor_product/operand.py index 0511145..12770e1 100644 --- a/cuequivariance/cuequivariance/segmented_tensor_product/operand.py +++ b/cuequivariance/cuequivariance/segmented_tensor_product/operand.py @@ -16,7 +16,7 @@ import dataclasses import math -from typing import Optional, Union, Sequence +from typing import Optional, Sequence, Union from cuequivariance import segmented_tensor_product as stp diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py b/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py index 3059d8f..3d468d3 100644 --- a/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py +++ b/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py @@ -25,7 +25,7 @@ import math import re import zlib -from typing import Any, Optional, Union, Callable, Sequence +from typing import Any, Callable, Optional, Sequence, Union import numpy as np import opt_einsum diff --git a/cuequivariance/tests/context_test.py b/cuequivariance/tests/context_test.py index 11ffe8f..5cc2d01 100644 --- a/cuequivariance/tests/context_test.py +++ b/cuequivariance/tests/context_test.py @@ -12,10 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np import pytest import cuequivariance as cue -import numpy as np def test_rep_collection_context(): diff --git a/cuequivariance/tests/equivariant_tensor_products_test.py b/cuequivariance/tests/equivariant_tensor_products_test.py index ee9e192..b231844 100644 --- a/cuequivariance/tests/equivariant_tensor_products_test.py +++ b/cuequivariance/tests/equivariant_tensor_products_test.py @@ -16,8 +16,8 @@ import pytest import cuequivariance as cue -from cuequivariance import descriptors import cuequivariance.segmented_tensor_product as stp +from cuequivariance import descriptors def test_commutativity_squeeze_flatten(): diff --git a/cuequivariance_torch/cuequivariance_torch/__init__.py b/cuequivariance_torch/cuequivariance_torch/__init__.py index e1b524e..b045128 100644 --- a/cuequivariance_torch/cuequivariance_torch/__init__.py +++ b/cuequivariance_torch/cuequivariance_torch/__init__.py @@ -36,7 +36,7 @@ vector_to_euler_angles, Inversion, ) -from .operations.spherical_harmonics import spherical_harmonics +from .operations.spherical_harmonics import SphericalHarmonics from cuequivariance_torch import layers @@ -55,6 +55,6 @@ "Inversion", "encode_rotation_angle", "vector_to_euler_angles", - "spherical_harmonics", + "SphericalHarmonics", "layers", ] diff --git a/cuequivariance_torch/cuequivariance_torch/layers/tp_conv_fully_connected.py b/cuequivariance_torch/cuequivariance_torch/layers/tp_conv_fully_connected.py index b6a6fb0..a5b2e1a 100644 --- a/cuequivariance_torch/cuequivariance_torch/layers/tp_conv_fully_connected.py +++ b/cuequivariance_torch/cuequivariance_torch/layers/tp_conv_fully_connected.py @@ -59,7 +59,6 @@ class FullyConnectedTensorProductConv(nn.Module): use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. 
If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. - optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None. Examples: >>> in_irreps = cue.Irreps("O3", "4x0e + 4x1o") @@ -106,7 +105,6 @@ def __init__( mlp_activation: Union[nn.Module, Sequence[nn.Module], None] = nn.GELU(), layout: cue.IrrepsLayout = None, # e3nn_compat_mode use_fallback: Optional[bool] = None, - optimize_fallback: Optional[bool] = None, ): super().__init__() @@ -127,7 +125,6 @@ def __init__( layout=self.layout, shared_weights=False, use_fallback=use_fallback, - optimize_fallback=optimize_fallback, ) self.batch_norm = ( diff --git a/cuequivariance_torch/cuequivariance_torch/operations/linear.py b/cuequivariance_torch/cuequivariance_torch/operations/linear.py index 33788c9..d7fbcea 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/linear.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/linear.py @@ -35,7 +35,6 @@ class Linear(torch.nn.Module): use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. - optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None. """ def __init__( @@ -52,7 +51,6 @@ def __init__( dtype: Optional[torch.dtype] = None, math_dtype: Optional[torch.dtype] = None, use_fallback: Optional[bool] = None, - optimize_fallback: Optional[bool] = None, ): super().__init__() irreps_in, irreps_out = default_irreps(irreps_in, irreps_out) @@ -77,7 +75,7 @@ def __init__( if not self.shared_weights: raise ValueError("Internal weights should be shared") self.weight = torch.nn.Parameter( - torch.randn(self.weight_numel, device=device, dtype=dtype) + torch.randn(1, self.weight_numel, device=device, dtype=dtype) ) else: self.weight = None @@ -90,7 +88,6 @@ def __init__( device=device, math_dtype=math_dtype, use_fallback=use_fallback, - optimize_fallback=optimize_fallback, ) def extra_repr(self) -> str: @@ -122,9 +119,7 @@ def forward( weight = self.weight - if self.shared_weights and weight.ndim != 1: - raise ValueError("Shared weights should be 1D tensor") - if not self.shared_weights and weight.ndim != 2: - raise ValueError("Weights should be 2D tensor") + if weight is None: + raise ValueError("Weights should not be None") return self.f([weight, x]) diff --git a/cuequivariance_torch/cuequivariance_torch/operations/rotation.py b/cuequivariance_torch/cuequivariance_torch/operations/rotation.py index cc2356f..62c72f8 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/rotation.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/rotation.py @@ -41,7 +41,6 @@ def __init__( device: Optional[torch.device] = None, math_dtype: Optional[torch.dtype] = None, use_fallback: Optional[bool] = None, - optimize_fallback: Optional[bool] = None, ): super().__init__() (irreps,) = default_irreps(irreps) @@ -62,7 +61,6 @@ def __init__( device=device, math_dtype=math_dtype, use_fallback=use_fallback, - optimize_fallback=optimize_fallback, ) def forward( @@ -158,8 +156,6 @@ class Inversion(torch.nn.Module): use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. 
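The `use_fallback` flag replaces the removed `optimize_fallback` argument throughout these modules. A minimal sketch of the three settings described in the docstrings above, assuming a typical `cuet.Linear` construction (the irreps and dtype below are illustrative, not taken from the patch):

import torch
import cuequivariance as cue
import cuequivariance_torch as cuet

irreps = cue.Irreps("O3", "32x0e + 32x1o")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# use_fallback=None (default): use the CUDA kernel when available, else the PyTorch fallback.
linear_auto = cuet.Linear(irreps, irreps, layout=cue.ir_mul, device=device, use_fallback=None)

# use_fallback=False: require the CUDA kernel; raises if it is not available (so this line
# only makes sense on a CUDA machine).
# linear_cuda = cuet.Linear(irreps, irreps, layout=cue.ir_mul, device=device, use_fallback=False)

# use_fallback=True: always use the PyTorch fallback, even when the CUDA kernel exists.
linear_torch = cuet.Linear(irreps, irreps, layout=cue.ir_mul, device=device, use_fallback=True)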
- - optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None. """ def __init__( @@ -172,7 +168,6 @@ def __init__( device: Optional[torch.device] = None, math_dtype: Optional[torch.dtype] = None, use_fallback: Optional[bool] = None, - optimize_fallback: Optional[bool] = None, ): super().__init__() (irreps,) = default_irreps(irreps) @@ -191,7 +186,6 @@ def __init__( device=device, math_dtype=math_dtype, use_fallback=use_fallback, - optimize_fallback=optimize_fallback, ) def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/cuequivariance_torch/cuequivariance_torch/operations/spherical_harmonics.py b/cuequivariance_torch/cuequivariance_torch/operations/spherical_harmonics.py index da6be5d..0cad723 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/spherical_harmonics.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/spherical_harmonics.py @@ -15,52 +15,57 @@ from typing import Optional import torch +import torch.nn as nn import cuequivariance as cue import cuequivariance_torch as cuet from cuequivariance import descriptors -def spherical_harmonics( - ls: list[int], - vectors: torch.Tensor, - normalize: bool = True, - use_fallback: Optional[bool] = None, - optimize_fallback: Optional[bool] = None, -) -> torch.Tensor: - r"""Compute the spherical harmonics of the input vectors. +class SphericalHarmonics(nn.Module): + r"""Compute the spherical harmonics of the input vectors as a torch module.""" - Args: - ls (list of int): List of spherical harmonic degrees. - vectors (torch.Tensor): Input vectors of shape (..., 3). - normalize (bool, optional): Whether to normalize the input vectors. Defaults to True. - use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. - If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. - If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. + def __init__( + self, + ls: list[int], + normalize: bool = True, + device: Optional[torch.device] = None, + math_dtype: Optional[torch.dtype] = None, + use_fallback: Optional[bool] = None, + ): + """ + Args: + ls (list of int): List of spherical harmonic degrees. + normalize (bool, optional): Whether to normalize the input vectors. Defaults to True. + use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. + If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. + If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. + """ + super().__init__() + self.ls = ls if isinstance(ls, list) else [ls] + assert self.ls == sorted(set(self.ls)) + self.normalize = normalize - optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None. + self.f = cuet.EquivariantTensorProduct( + descriptors.spherical_harmonics(cue.SO3(1), self.ls), + layout=cue.ir_mul, + device=device, + math_dtype=math_dtype, + use_fallback=use_fallback, + ) - Returns: - torch.Tensor: The spherical harmonics of the input vectors of shape (..., dim) - where dim is the sum of 2*l+1 for l in ls. - """ - if isinstance(ls, int): - ls = [ls] - assert ls == sorted(set(ls)) - assert vectors.shape[-1] == 3 + def forward(self, vectors: torch.Tensor) -> torch.Tensor: + """ + Args: + vectors (torch.Tensor): Input vectors of shape (batch, 3). 
- if normalize: - vectors = torch.nn.functional.normalize(vectors, dim=-1) + Returns: + torch.Tensor: The spherical harmonics of the input vectors of shape (batch, dim), + where dim is the sum of 2*l+1 for l in ls. + """ + torch._assert(vectors.ndim == 2, "Input must have shape (batch, 3)") - x = vectors.reshape(-1, 3) - m = cuet.EquivariantTensorProduct( - descriptors.spherical_harmonics(cue.SO3(1), ls), - layout=cue.ir_mul, - device=x.device, - math_dtype=x.dtype, - use_fallback=use_fallback, - optimize_fallback=optimize_fallback, - ) - y = m([x]) - y = y.reshape(vectors.shape[:-1] + (y.shape[-1],)) - return y + if self.normalize: + vectors = torch.nn.functional.normalize(vectors, dim=1) + + return self.f([vectors]) diff --git a/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py b/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py index 7f41775..38f81f0 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py @@ -41,18 +41,17 @@ class SymmetricContraction(torch.nn.Module): use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. - optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None. Examples: + >>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") >>> irreps_in = cue.Irreps("O3", "32x0e + 32x1o") >>> irreps_out = cue.Irreps("O3", "32x0e") - >>> layer = SymmetricContraction(irreps_in, irreps_out, contraction_degree=3, num_elements=5, layout=cue.ir_mul, dtype=torch.float32) + >>> layer = SymmetricContraction(irreps_in, irreps_out, contraction_degree=3, num_elements=5, layout=cue.ir_mul, dtype=torch.float32, device=device) Now `layer` can be used as part of a PyTorch model. The argument `original_mace` can be set to `True` to emulate the original MACE implementation. - >>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") >>> feats_irreps = cue.Irreps("O3", "32x0e + 32x1o + 32x2e") >>> target_irreps = cue.Irreps("O3", "32x0e + 32x1o") >>> # OLD FUNCTION DEFINITION: @@ -109,7 +108,6 @@ def __init__( math_dtype: Optional[torch.dtype] = None, original_mace: bool = False, use_fallback: Optional[bool] = None, - optimize_fallback: Optional[bool] = None, ): super().__init__() @@ -155,7 +153,6 @@ def __init__( device=device, math_dtype=math_dtype or dtype, use_fallback=use_fallback, - optimize_fallback=optimize_fallback, ) def extra_repr(self) -> str: @@ -180,10 +177,6 @@ def forward( Returns: torch.Tensor: The output tensor. It has shape (batch, irreps_out.dim). 
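With the functional `spherical_harmonics` replaced by the `SphericalHarmonics` module above, callers now construct the module once and pass `(batch, 3)` vectors to `forward`. A sketch of the migration, with illustrative degrees and batch size:

import torch
import cuequivariance_torch as cuet

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Old call (removed): yl = cuet.spherical_harmonics([0, 1, 2], vectors, normalize=False)
sh = cuet.SphericalHarmonics([0, 1, 2], normalize=False, device=device)

vectors = torch.randn(10, 3, device=device)   # forward now requires a 2D (batch, 3) input
yl = sh(vectors)                              # shape (10, 1 + 3 + 5) = (10, 9)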
""" - torch._assert( - x.shape[-1] == self.irreps_in.dim, - f"Input tensor must have shape (..., {self.irreps_in.dim}), got {x.shape}", - ) if self.projection is not None: weight = torch.einsum("zau,ab->zbu", self.weight, self.projection) diff --git a/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py b/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py index a6ac80f..026b666 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py @@ -36,7 +36,6 @@ class ChannelWiseTensorProduct(torch.nn.Module): use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. - optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None. Note: In e3nn there was a irrep_normalization and path_normalization parameters. @@ -59,7 +58,6 @@ def __init__( dtype: Optional[torch.dtype] = None, math_dtype: Optional[torch.dtype] = None, use_fallback: Optional[bool] = None, - optimize_fallback: Optional[bool] = None, ): super().__init__() irreps_in1, irreps_in2 = default_irreps(irreps_in1, irreps_in2) @@ -88,7 +86,7 @@ def __init__( if not self.shared_weights: raise ValueError("Internal weights should be shared") self.weight = torch.nn.Parameter( - torch.randn(self.weight_numel, device=device, dtype=dtype) + torch.randn(1, self.weight_numel, device=device, dtype=dtype) ) else: self.weight = None @@ -101,7 +99,6 @@ def __init__( device=device, math_dtype=math_dtype, use_fallback=use_fallback, - optimize_fallback=optimize_fallback, ) def extra_repr(self) -> str: @@ -137,15 +134,14 @@ def forward( or if shared weights are used and weight is not a 1D tensor, or if shared weights are not used and weight is not a 2D tensor. """ - if self.internal_weights: + if self.weight is not None: if weight is not None: raise ValueError("Internal weights are used, weight should be None") - - weight = self.weight - - if self.shared_weights and weight.ndim != 1: - raise ValueError("Shared weights should be 1D tensor") - if not self.shared_weights and weight.ndim != 2: - raise ValueError("Weights should be 2D tensor") - - return self.f([weight, x1, x2]) + return self.f([self.weight, x1, x2]) + else: + if weight is None: + raise ValueError( + "Internal weights are not used, weight should not be None" + ) + else: + return self.f([weight, x1, x2]) diff --git a/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py b/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py index 44c781d..e1e3122 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py @@ -36,7 +36,6 @@ class FullyConnectedTensorProduct(torch.nn.Module): use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. - optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None. Note: In e3nn there was a irrep_normalization and path_normalization parameters. 
@@ -59,7 +58,6 @@ def __init__( dtype: Optional[torch.dtype] = None, math_dtype: Optional[torch.dtype] = None, use_fallback: Optional[bool] = None, - optimize_fallback: Optional[bool] = None, ): super().__init__() irreps_in1, irreps_in2, irreps_out = default_irreps( @@ -89,7 +87,7 @@ def __init__( if not self.shared_weights: raise ValueError("Internal weights should be shared") self.weight = torch.nn.Parameter( - torch.randn(self.weight_numel, device=device, dtype=dtype) + torch.randn(1, self.weight_numel, device=device, dtype=dtype) ) else: self.weight = None @@ -101,7 +99,7 @@ def __init__( layout_out=layout_out, device=device, math_dtype=math_dtype, - optimize_fallback=optimize_fallback, + use_fallback=use_fallback, ) def extra_repr(self) -> str: @@ -137,15 +135,14 @@ def forward( or if shared weights are used and weight is not a 1D tensor, or if shared weights are not used and weight is not a 2D tensor. """ - if self.internal_weights: + if self.weight is not None: if weight is not None: raise ValueError("Internal weights are used, weight should be None") - - weight = self.weight - - if self.shared_weights and weight.ndim != 1: - raise ValueError("Shared weights should be 1D tensor") - if not self.shared_weights and weight.ndim != 2: - raise ValueError("Weights should be 2D tensor") - - return self.f([weight, x1, x2]) + return self.f([self.weight, x1, x2]) + else: + if weight is None: + raise ValueError( + "Internal weights are not used, weight should not be None" + ) + else: + return self.f([weight, x1, x2]) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py index 5e99746..2d5610b 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py @@ -119,7 +119,6 @@ class EquivariantTensorProduct(torch.nn.Module): device (torch.device): device of the Module. math_dtype (torch.dtype): dtype for internal computations. use_fallback (bool, optional): Determines the computation method. If `None` (default), a CUDA kernel will be used if available. If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. - optimize_fallback (bool): whether to optimize the fallback implementation. Raises: RuntimeError: If `use_fallback` is `False` and no CUDA kernel is available. @@ -128,7 +127,7 @@ class EquivariantTensorProduct(torch.nn.Module): >>> e = cue.descriptors.fully_connected_tensor_product( ... cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1") ... 
) - >>> w = torch.ones(e.inputs[0].dim, device=device) + >>> w = torch.ones(1, e.inputs[0].dim, device=device) >>> x1 = torch.ones(17, e.inputs[1].dim, device=device) >>> x2 = torch.ones(17, e.inputs[2].dim, device=device) >>> tp = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul, device=device) @@ -155,7 +154,6 @@ def __init__( device: Optional[torch.device] = None, math_dtype: Optional[torch.dtype] = None, use_fallback: Optional[bool] = None, - optimize_fallback: Optional[bool] = None, ): super().__init__() if not isinstance(layout_in, tuple): @@ -217,7 +215,6 @@ def __init__( device=device, math_dtype=math_dtype, use_fallback=use_fallback, - optimize_fallback=optimize_fallback, ) ) elif e.num_inputs == 2: @@ -227,7 +224,6 @@ def __init__( device=device, math_dtype=math_dtype, use_fallback=use_fallback, - optimize_fallback=optimize_fallback, ) ) else: @@ -239,7 +235,6 @@ def __init__( device=device, math_dtype=math_dtype, use_fallback=use_fallback, - optimize_fallback=optimize_fallback, ) ) @@ -259,7 +254,10 @@ def forward( # assert len(inputs) == len(self.etp.inputs) for a, dim in zip(inputs, self.operands_dims): - assert a.shape[-1] == dim + torch._assert( + a.shape[-1] == dim, + f"Expected last dimension of input to be {dim}, got {a.shape[-1]}", + ) # Transpose inputs inputs = self.transpose_in(inputs) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py index 0c9821f..f56eb37 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py @@ -14,14 +14,13 @@ # limitations under the License. import logging import math -from typing import Optional +from typing import List, Optional import torch import torch.fx import cuequivariance.segmented_tensor_product as stp import cuequivariance_torch as cuet -from cuequivariance_torch.primitives.tensor_product import broadcast_shapes, prod logger = logging.getLogger(__name__) @@ -33,7 +32,6 @@ class SymmetricTensorProduct(torch.nn.Module): Args: descriptors (list of SegmentedTensorProduct): The list of SegmentedTensorProduct descriptors. math_dtype (torch.dtype, optional): The data type of the coefficients and calculations. - optimize_fallback (bool, optional): If `True`, the torch.fx graph will be optimized before execution. Because the optimization takes time, it is turned off by default. """ def __init__( @@ -43,7 +41,6 @@ def __init__( device: Optional[torch.device] = None, math_dtype: Optional[torch.dtype] = None, use_fallback: Optional[bool] = None, - optimize_fallback: Optional[bool] = None, ): super().__init__() @@ -69,7 +66,6 @@ def __init__( device=device, math_dtype=math_dtype, use_fallback=use_fallback, - optimize_fallback=optimize_fallback, ) def forward(self, x0: torch.Tensor) -> torch.Tensor: @@ -107,9 +103,6 @@ class IWeightedSymmetricTensorProduct(torch.nn.Module): The list of SegmentedTensorProduct descriptors math_dtype : torch.dtype, optional The data type of the coefficients and calculations - optimize_fallback : bool, optional - If `True`, the torch.fx graph will be optimized before execution - Because the optimization takes time, it is turned off by default. 
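Following the updated doctest, every operand of `EquivariantTensorProduct`, including the weight, now carries a leading batch dimension, and `forward` checks each operand's last dimension against the descriptor via `torch._assert`. A condensed version of the doctest above:

import torch
import cuequivariance as cue
import cuequivariance_torch as cuet

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
e = cue.descriptors.fully_connected_tensor_product(
    cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1")
)
tp = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul, device=device)

w = torch.ones(1, e.inputs[0].dim, device=device)    # shared weight: (1, dim), broadcast over the batch
x1 = torch.ones(17, e.inputs[1].dim, device=device)
x2 = torch.ones(17, e.inputs[2].dim, device=device)
out = tp([w, x1, x2])                                # a wrong last dimension now trips torch._assert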
""" def __init__( @@ -119,7 +112,6 @@ def __init__( device: Optional[torch.device] = None, math_dtype: Optional[torch.dtype] = None, use_fallback: Optional[bool] = None, - optimize_fallback: Optional[bool] = None, ): super().__init__() @@ -153,9 +145,9 @@ def __init__( descriptors, device, math_dtype=math_dtype, - optimize_fallback=optimize_fallback, ) + @torch.jit.ignore def __repr__(self): has_cuda_kernel = ( "(with CUDA kernel)" @@ -179,9 +171,9 @@ def forward( x0 : torch.Tensor The input tensor for the first operand. It should have the shape (i0.max() + 1, x0_size). i0 : torch.Tensor - The index tensor for the first operand. It should have the shape (...). + The index tensor for the first operand. It should have the shape (batch). x1 : torch.Tensor - The repeated input tensor. It should have the shape (..., x1_size). + The repeated input tensor. It should have the shape (batch, x1_size). Returns ------- @@ -194,13 +186,15 @@ def forward( x0.ndim == 2, f"Expected 2 dims (i0.max() + 1, x0_size), got shape {x0.shape}", ) - shape = broadcast_shapes([i0.shape, x1.shape[:-1]]) - i0 = i0.expand(shape).reshape((prod(shape),)) - x1 = x1.expand(shape + (x1.shape[-1],)).reshape((prod(shape), x1.shape[-1])) - - out = self.f(x0, i0, x1) - out = out.reshape(shape + (self.x2_size,)) - return out + torch._assert( + i0.ndim == 1, + f"Expected 1 dim (batch), got shape {i0.shape}", + ) + torch._assert( + x1.ndim == 2, + f"Expected 2 dims (batch, x1_size), got shape {x1.shape}", + ) + return self.f(x0, i0, x1) def _check_descriptors(descriptors: list[stp.SegmentedTensorProduct]): @@ -230,6 +224,9 @@ def __init__( ): super().__init__() + if not torch.cuda.is_available(): + raise NotImplementedError("CUDA is not available.") + max_degree = max(d.num_operands - 2 for d in ds) if max_degree > 6: @@ -338,7 +335,6 @@ def __init__( stps: list[stp.SegmentedTensorProduct], device: Optional[torch.device], math_dtype: Optional[torch.dtype], - optimize_fallback: Optional[bool], ): super().__init__() self.fs = torch.nn.ModuleList( @@ -348,7 +344,6 @@ def __init__( device=device, math_dtype=math_dtype, use_fallback=True, - optimize_fallback=optimize_fallback, ) for d in stps ] @@ -357,6 +352,7 @@ def __init__( def forward( self, x0: torch.Tensor, i0: torch.Tensor, x1: torch.Tensor ) -> torch.Tensor: - return sum( - f([x0[i0]] + [x1] * (f.descriptor.num_operands - 2)) for f in self.fs - ) + fs: List[torch.Tensor] = [ + f([x0[i0]] + [x1] * (f.num_operands - 2)) for f in self.fs + ] + return torch.sum(torch.stack(fs), dim=0) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index 91af5b0..c0d4bf0 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -15,6 +15,7 @@ import logging import math import warnings +from functools import partial from typing import List, Optional, OrderedDict, Tuple import torch @@ -26,50 +27,16 @@ def prod(numbers: List[int]): - product = 1 - for num in numbers: - product *= num - return product - - -def broadcast_shapes(shapes: List[List[int]]): + """ + This method is a workaround for script() not recognizing math.prod() + """ if torch.jit.is_scripting(): - max_len = 0 - for shape in shapes: - if isinstance(shape, int): - if max_len < 1: - max_len = 1 - elif isinstance(shape, (tuple, list)): - s = len(shape) - if max_len < s: - max_len = s - result = [1] * max_len - for shape in 
shapes: - if isinstance(shape, int): - shape = (shape,) - if isinstance(shape, (tuple, list)): - for i in range(-1, -1 - len(shape), -1): - if shape[i] < 0: - raise RuntimeError( - "Trying to create tensor with negative dimension ({}): ({})".format( - shape[i], shape[i] - ) - ) - if shape[i] == 1 or shape[i] == result[i]: - continue - if result[i] != 1: - raise RuntimeError( - "Shape mismatch: objects cannot be broadcast to a single shape" - ) - result[i] = shape[i] - else: - raise RuntimeError( - "Input shapes should be of type ints, a tuple of ints, or a list of ints, got ", - shape, - ) - return torch.Size(result) + product = 1 + for num in numbers: + product *= num + return product else: - return torch.functional.broadcast_shapes(*shapes) + return math.prod(numbers) class TensorProduct(torch.nn.Module): @@ -82,7 +49,6 @@ class TensorProduct(torch.nn.Module): device (torch.device, optional): The device on which the calculations are performed. use_fallback (bool, optional): Determines the computation method. If `None` (default), a CUDA kernel will be used if available. If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. - optimize_fallback (bool, optional): If `True`, the fallback method is optimized. If `False`, the fallback method is used without optimization. Raises: RuntimeError: If `use_fallback` is `False` and no CUDA kernel is available. @@ -95,7 +61,6 @@ def __init__( device: Optional[torch.device] = None, math_dtype: Optional[torch.dtype] = None, use_fallback: Optional[bool] = None, - optimize_fallback: Optional[bool] = None, ): super().__init__() self.descriptor = descriptor @@ -124,18 +89,9 @@ def __init__( ) if not self.has_cuda: - if optimize_fallback is None: - optimize_fallback = False - warnings.warn( - "The fallback method is used but it has not been optimized. " - "Consider setting optimize_fallback=True when creating the TensorProduct module." - ) - - self.f = _tensor_product_fx( - descriptor, device, math_dtype, optimize_fallback - ) - self._optimize_fallback = optimize_fallback + self.f = _tensor_product_fx(descriptor, device, math_dtype, True) + @torch.jit.ignore def __repr__(self): has_cuda_kernel = ( "(with CUDA kernel)" if self.has_cuda else "(without CUDA kernel)" @@ -148,8 +104,9 @@ def forward(self, inputs: List[torch.Tensor]): Args: inputs (list of torch.Tensor): The input tensors. The number of input tensors should match the number of operands in the descriptor minus one. - Each input tensor should have a shape of ((batch,) operand_size), where `operand_size` corresponds to the size - of each operand as defined in the tensor product descriptor. + Each input tensor should have a shape of (batch, operand_size) or (1, operand_size) + where `operand_size` corresponds to the size of each operand as defined in + the tensor product descriptor. Returns: torch.Tensor: @@ -158,12 +115,27 @@ def forward(self, inputs: List[torch.Tensor]): `last_operand_size` is the size of the last operand in the descriptor. 
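With the broadcasting helpers (`broadcast_shapes`, `_reshape`) removed, `TensorProduct` and the fused/uniform kernels require strictly 2D operands, either `(batch, dim)` or `(1, dim)`. Callers with arbitrary leading dimensions are expected to flatten and restore them; a sketch using `cuet.Linear` as a stand-in for any module built on these primitives (irreps are illustrative):

import torch
import cuequivariance as cue
import cuequivariance_torch as cuet

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
irreps = cue.Irreps("O3", "8x0e + 8x1o")
linear = cuet.Linear(irreps, irreps, layout=cue.ir_mul, device=device, dtype=torch.float32)

x = torch.randn(4, 5, irreps.dim, device=device)   # extra leading dims are no longer broadcast away
x2d = x.reshape(-1, irreps.dim)                    # flatten to (batch, dim) before the call
out = linear(x2d).reshape(4, 5, -1)                # restore the leading dims afterwards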
""" - # if any(x.numel() == 0 for x in inputs): - # use_fallback = True # Empty tensors are not supported by the CUDA kernel - return self.f(inputs) +def to_notypeconv(t, *args, **kwargs): + new_kwargs = kwargs.copy() + new_kwargs.pop("dtype", None) + new_args = [None if isinstance(a, torch.dtype) else a for a in args] + result = t.__original_to(*new_args, **new_kwargs) + return result + + +def disable_type_conv(t): + """ + This modifier can be used on Tensors or whole Modules + to prevent them from being modified during to(dtype=x) calls + """ + t.__original_to = t.to + t.to = partial(to_notypeconv, t) + return t + + def _tensor_product_fx( descriptor: stp.SegmentedTensorProduct, device: Optional[torch.device], @@ -204,28 +176,23 @@ def _tensor_product_fx( outputs = [] for path_idx, path in enumerate(descriptor.paths): - segments = [ - inputs[oid][..., slices[oid][path.indices[oid]]] - .reshape( - inputs[oid].shape[:-1] + descriptor.get_segment_shape(oid, path) - ) - .to(dtype=math_dtype) - for oid in range(num_inputs) - ] - constants[f"c{path_idx}"] = torch.tensor( - path.coefficients, dtype=math_dtype, device=device - ).view( - { - 2: torch.int16, - 4: torch.int32, - 8: torch.int64, - }[math_dtype.itemsize] - ) - c = ( - torch.fx.Proxy(graph.get_attr(f"c{path_idx}"), tracer=tracer) - .view(math_dtype) - .clone() + segments = [] + for oid in range(num_inputs): + seg_shape = descriptor.get_segment_shape(oid, path) + inp = inputs[oid][..., slices[oid][path.indices[oid]]] + if len(seg_shape) > 0: + inp = inp.reshape(inputs[oid].shape[:-1] + seg_shape) + else: + inp = inp.reshape(inputs[oid].shape[:-1]) + + segments.append(inp.to(dtype=math_dtype)) + + c_tensor = disable_type_conv( + torch.tensor(path.coefficients, dtype=math_dtype, device=device) ) + constants[f"c{path_idx}"] = c_tensor + + c = torch.fx.Proxy(graph.get_attr(f"c{path_idx}"), tracer=tracer).clone() out = torch.einsum(formula, c, *segments) out = out.to(dtype=inputs[0].dtype) @@ -237,6 +204,14 @@ def _tensor_product_fx( if len(outputs) == 0: raise NotImplementedError("No FX implementation for empty paths") + def _sum(tensors, *, shape=None, like=None): + if len(tensors) == 0: + return like.new_zeros(shape) + out = tensors[0] + for t in tensors[1:]: + out = torch.add(out, t) + return out + batch_shape = outputs[0].shape[:-1] output = torch.cat( [ @@ -384,34 +359,15 @@ def forward(self, args: List[torch.Tensor]): if not torch.jit.is_scripting() and not torch.compiler.is_compiling(): for oid, arg in enumerate(args): torch._assert( - arg.shape[-1] == self.descriptor.operands[oid].size, - "input shape[-1] does not match operand size", + arg.ndim == 2, + f"input {oid} should have ndim=2", ) - - shape = broadcast_shapes([arg.shape[:-1] for arg in args]) - - args = [ - ( - arg.expand(shape + (arg.shape[-1],)).reshape( - (prod(shape), arg.shape[-1]) + torch._assert( + arg.shape[1] == self.descriptor.operands[oid].size, + f"input {oid} should have shape (batch, {self.descriptor.operands[oid].size})", ) - if prod(arg.shape[:-1]) > 1 - else arg.reshape((prod(arg.shape[:-1]), arg.shape[-1])) - ) - for arg in args - ] - out = self.module(args) - - return out.reshape(shape + (out.shape[-1],)) - -def _sum(tensors, *, shape=None, like=None): - if len(tensors) == 0: - return like.new_zeros(shape) - out = tensors[0] - for t in tensors[1:]: - out += t - return out + return self.module(args) def _tensor_product_cuda( @@ -492,16 +448,6 @@ def _tensor_product_cuda( return FusedTensorProductOp4(descriptor, perm[:3], device, math_dtype) -def 
_reshape(x: torch.Tensor, leading_shape: List[int]) -> torch.Tensor: - # Make x have shape (Z, x.shape[-1]) or (x.shape[-1],) - if prod(leading_shape) > 1 and prod(x.shape[:-1]) == 1: - return x.reshape((x.shape[-1],)) - else: - return x.expand(leading_shape + (x.shape[-1],)).reshape( - (prod(leading_shape), x.shape[-1]) - ) - - class FusedTensorProductOp3(torch.nn.Module): def __init__( self, @@ -535,26 +481,32 @@ def __init__( math_dtype=math_dtype, ).to(device=device) + @torch.jit.ignore def __repr__(self) -> str: return f"FusedTensorProductOp3({self.descriptor} (output last operand))" def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: x0, x1 = self._perm(inputs[0], inputs[1]) - assert x0.ndim >= 1, x0.ndim - assert x1.ndim >= 1, x1.ndim - - shape = broadcast_shapes([x0.shape[:-1], x1.shape[:-1]]) - x0 = _reshape(x0, shape) - x1 = _reshape(x1, shape) if not torch.jit.is_scripting() and not torch.compiler.is_compiling(): logger.debug( f"Calling FusedTensorProductOp3: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}" ) - out = self._f(x0, x1) + torch._assert(x0.ndim == 2, "input should be (batch, dim) or (1, dim)") + torch._assert(x1.ndim == 2, "input should be (batch, dim) or (1, dim)") + + batch = max(x0.shape[0], x1.shape[0]) + + if batch > 1: + if x0.shape[0] == 1: + x0 = x0.squeeze(0) + if x1.shape[0] == 1: + x1 = x1.squeeze(0) - return out.reshape(shape + (out.shape[-1],)) + # ops.FusedTensorProductOp3 expects inputs + # of shape (Z, dim) or (dim,) + return self._f(x0, x1) class FusedTensorProductOp4(torch.nn.Module): @@ -590,28 +542,35 @@ def __init__( math_dtype=math_dtype, ).to(device=device) + @torch.jit.ignore def __repr__(self) -> str: return f"FusedTensorProductOp4({self.descriptor} (output last operand))" def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: x0, x1, x2 = self._perm(inputs[0], inputs[1], inputs[2]) - assert x0.ndim >= 1, x0.ndim - assert x1.ndim >= 1, x1.ndim - assert x2.ndim >= 1, x2.ndim - - shape = broadcast_shapes([x0.shape[:-1], x1.shape[:-1], x2.shape[:-1]]) - x0 = _reshape(x0, shape) - x1 = _reshape(x1, shape) - x2 = _reshape(x2, shape) if not torch.jit.is_scripting() and not torch.compiler.is_compiling(): logger.debug( f"Calling FusedTensorProductOp4: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}, {x2.shape}" ) - out = self._f(x0, x1, x2) + torch._assert(x0.ndim == 2, "input should be (batch, dim) or (1, dim)") + torch._assert(x1.ndim == 2, "input should be (batch, dim) or (1, dim)") + torch._assert(x2.ndim == 2, "input should be (batch, dim) or (1, dim)") + + batch = max(x0.shape[0], x1.shape[0], x2.shape[0]) + + if batch > 1: + if x0.shape[0] == 1: + x0 = x0.squeeze(0) + if x1.shape[0] == 1: + x1 = x1.squeeze(0) + if x2.shape[0] == 1: + x2 = x2.squeeze(0) - return out.reshape(shape + (out.shape[-1],)) + # ops.FusedTensorProductOp4 expects inputs + # of shape (Z, dim) or (dim,) + return self._f(x0, x1, x2) class TensorProductUniform1d(torch.nn.Module): @@ -642,63 +601,46 @@ def __init__( class TensorProductUniform3x1d(TensorProductUniform1d): + @torch.jit.ignore def __repr__(self): return f"TensorProductUniform3x1d({self.descriptor} (output last operand))" def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: x0, x1 = inputs - assert x0.ndim >= 1, x0.ndim - assert x1.ndim >= 1, x1.ndim - - shape = broadcast_shapes([x0.shape[:-1], x1.shape[:-1]]) - x0 = _reshape(x0, shape) - x1 = _reshape(x1, shape) - - if x0.ndim == 1: - x0 = x0.unsqueeze(0) - if x1.ndim == 1: - x1 = x1.unsqueeze(0) if not 
torch.jit.is_scripting() and not torch.compiler.is_compiling(): logger.debug( f"Calling TensorProductUniform3x1d: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}" ) - out = self._f(x0, x1, x0) + torch._assert(x0.ndim == 2, "input should be (batch, dim) or (1, dim)") + torch._assert(x1.ndim == 2, "input should be (batch, dim) or (1, dim)") - return out.reshape(shape + (out.shape[-1],)) + # ops.TensorProductUniform1d expects inputs + # of shape (Z, dim) or (1, dim) + return self._f(x0, x1) class TensorProductUniform4x1d(TensorProductUniform1d): + @torch.jit.ignore def __repr__(self): return f"TensorProductUniform4x1d({self.descriptor} (output last operand))" def forward(self, inputs: List[torch.Tensor]): x0, x1, x2 = inputs - assert x0.ndim >= 1, x0.ndim - assert x1.ndim >= 1, x1.ndim - assert x2.ndim >= 1, x2.ndim - - shape = broadcast_shapes([x0.shape[:-1], x1.shape[:-1], x2.shape[:-1]]) - x0 = _reshape(x0, shape) - x1 = _reshape(x1, shape) - x2 = _reshape(x2, shape) - - if x0.ndim == 1: - x0 = x0.unsqueeze(0) - if x1.ndim == 1: - x1 = x1.unsqueeze(0) - if x2.ndim == 1: - x2 = x2.unsqueeze(0) if not torch.jit.is_scripting() and not torch.compiler.is_compiling(): logger.debug( f"Calling TensorProductUniform4x1d: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}, {x2.shape}" ) - out = self._f(x0, x1, x2) + torch._assert(x0.ndim == 2, "input should be (batch, dim) or (1, dim)") + torch._assert(x1.ndim == 2, "input should be (batch, dim) or (1, dim)") + torch._assert(x2.ndim == 2, "input should be (batch, dim) or (1, dim)") - return out.reshape(shape + (out.shape[-1],)) + # ops.TensorProductUniform1d expects inputs + # of shape (Z, dim) or (1, dim) + return self._f(x0, x1, x2) def _permutation_module(permutation: Tuple[int, ...]): diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py b/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py index b23777a..bc6144e 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py @@ -36,7 +36,7 @@ def __init__( source: cue.IrrepsLayout, target: cue.IrrepsLayout, device: Optional[torch.device] = None, - use_fallback: Optional[bool] = False, + use_fallback: Optional[bool] = None, ): super().__init__() @@ -85,7 +85,7 @@ def __init__( self, segments: list[tuple[int, int]], device: Optional[torch.device] = None, - use_fallback: Optional[bool] = False, + use_fallback: Optional[bool] = None, ): super().__init__() @@ -99,7 +99,8 @@ def __init__( except ImportError: pass else: - self.f = _transpose(info).to(device=device) + if torch.cuda.is_available(): + self.f = _transpose(info).to(device=device) if use_fallback is False and self.f is None: raise RuntimeError("CUDA kernel not available for TransposeSegments.") diff --git a/cuequivariance_torch/tests/__init__.py b/cuequivariance_torch/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cuequivariance_torch/tests/layers/tp_conv_fully_connected_test.py b/cuequivariance_torch/tests/layers/tp_conv_fully_connected_test.py index 96602a3..1c1f72c 100644 --- a/cuequivariance_torch/tests/layers/tp_conv_fully_connected_test.py +++ b/cuequivariance_torch/tests/layers/tp_conv_fully_connected_test.py @@ -51,6 +51,7 @@ def test_tensor_product_conv_equivariance( mlp_activation=mlp_activation, batch_norm=batch_norm, layout=layout, + use_fallback=not torch.cuda.is_available(), ).to(device) num_src_nodes, num_dst_nodes = 9, 7 diff --git 
a/cuequivariance_torch/tests/operations/linear_test.py b/cuequivariance_torch/tests/operations/linear_test.py index 2b78ff1..7e904d6 100644 --- a/cuequivariance_torch/tests/operations/linear_test.py +++ b/cuequivariance_torch/tests/operations/linear_test.py @@ -16,6 +16,9 @@ import pytest import torch +from tests.utils import ( + module_with_mode, +) import cuequivariance as cue import cuequivariance_torch as cuet @@ -169,3 +172,49 @@ def test_linear_copy( ).to(device) copy.deepcopy(linear) + + +export_modes = ["compile", "script", "jit"] + + +@pytest.mark.parametrize("irreps_in", list_of_irreps) +@pytest.mark.parametrize("irreps_out", list_of_irreps) +@pytest.mark.parametrize("layout", [cue.mul_ir, cue.ir_mul]) +@pytest.mark.parametrize("shared_weights", [True, False]) +@pytest.mark.parametrize("mode", export_modes) +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_export( + irreps_in: cue.Irreps, + irreps_out: cue.Irreps, + layout: cue.IrrepsLayout, + shared_weights: bool, + mode: str, + use_fallback: bool, + tmp_path: str, +): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + torch.manual_seed(0) + m = cuet.Linear( + irreps_in, + irreps_out, + layout=layout, + shared_weights=shared_weights, + device=device, + dtype=torch.float32, + use_fallback=use_fallback, + ) + + x = torch.randn(10, irreps_in.dim, dtype=torch.float32).cuda() + + if shared_weights: + inputs = (x,) + else: + w = torch.randn(10, m.weight_numel, dtype=torch.float32).cuda() + inputs = (x, w) + + out1 = m(*inputs) + m = module_with_mode(mode, m, inputs, torch.float32, tmp_path) + out2 = m(*inputs) + torch.testing.assert_close(out1, out2) diff --git a/cuequivariance_torch/tests/operations/rotation_test.py b/cuequivariance_torch/tests/operations/rotation_test.py index cc70899..dd8721f 100644 --- a/cuequivariance_torch/tests/operations/rotation_test.py +++ b/cuequivariance_torch/tests/operations/rotation_test.py @@ -14,6 +14,9 @@ # limitations under the License. 
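Several of the changes above (`torch._assert` in place of Python asserts, `@torch.jit.ignore` on `__repr__`, explicit `List[...]` annotations, the strictly 2D calling convention) are what make these modules exportable; the new tests exercise this through the `module_with_mode` helper in `tests/utils`. A reduced sketch of the same pattern using plain `torch.jit.script` (the new tests skip this path when CUDA is unavailable):

import torch
import cuequivariance as cue
import cuequivariance_torch as cuet

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
irreps = cue.Irreps("O3", "32x0e + 32x1o")
m = cuet.Linear(irreps, irreps, layout=cue.ir_mul, device=device, dtype=torch.float32)

x = torch.randn(10, irreps.dim, device=device)
out_eager = m(x)

m_scripted = torch.jit.script(m)   # "script" mode; "compile" uses torch.compile(m, fullgraph=True)
torch.testing.assert_close(out_eager, m_scripted(x))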
import pytest import torch +from tests.utils import ( + module_with_mode, +) import cuequivariance as cue import cuequivariance_torch as cuet @@ -23,9 +26,9 @@ def test_rotation(): irreps = cue.Irreps("SO3", "3x0 + 1 + 0 + 4x2 + 4") - alpha = torch.tensor(0.3).to(device) - beta = torch.tensor(0.4).to(device) - gamma = torch.tensor(-0.5).to(device) + alpha = torch.tensor([0.3]).to(device) + beta = torch.tensor([0.4]).to(device) + gamma = torch.tensor([-0.5]).to(device) rot = cuet.Rotation(irreps, layout=cue.ir_mul).to(device) @@ -42,8 +45,10 @@ def test_vector_to_euler_angles(): A = torch.nn.functional.normalize(A, dim=-1) beta, alpha = cuet.vector_to_euler_angles(A) - ey = torch.tensor([0.0, 1.0, 0.0]) - B = cuet.Rotation(cue.Irreps("SO3", "1"), layout=cue.ir_mul)(0.0, beta, alpha, ey) + ey = torch.tensor([[0.0, 1.0, 0.0]]) + B = cuet.Rotation(cue.Irreps("SO3", "1"), layout=cue.ir_mul)( + torch.tensor([0.0]), beta, alpha, ey + ) assert torch.allclose(A, B) @@ -60,3 +65,28 @@ def test_inversion(use_fallback: bool): )(torch.tensor([[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]], device=device)), torch.tensor([[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0]], device=device), ) + + +export_modes = ["compile", "script", "jit"] + + +@pytest.mark.parametrize("mode", export_modes) +def test_export(mode: str, tmp_path: str): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + irreps = cue.Irreps("SO3", "3x0 + 1 + 0 + 4x2 + 4") + dtype = torch.float32 + alpha = torch.tensor([0.3]).to(device) + beta = torch.tensor([0.4]).to(device) + gamma = torch.tensor([-0.5]).to(device) + + m = cuet.Rotation(irreps, layout=cue.ir_mul).to(device) + + x = torch.randn(10, irreps.dim).to(device) + inputs = (gamma, beta, alpha, x) + + out1 = m(*inputs) + m = module_with_mode(mode, m, inputs, dtype, tmp_path) + out2 = m(*inputs) + torch.testing.assert_close(out1, out2) diff --git a/cuequivariance_torch/tests/operations/spherical_harmonics_test.py b/cuequivariance_torch/tests/operations/spherical_harmonics_test.py index 2b8db35..1404ef7 100644 --- a/cuequivariance_torch/tests/operations/spherical_harmonics_test.py +++ b/cuequivariance_torch/tests/operations/spherical_harmonics_test.py @@ -15,6 +15,9 @@ import numpy as np import pytest import torch +from tests.utils import ( + module_with_mode, +) import cuequivariance as cue import cuequivariance_torch as cuet @@ -24,7 +27,7 @@ @pytest.mark.parametrize( "dtype, tol", - [(torch.float64, 1e-6), (torch.float32, 1e-4)], + [(torch.float64, 1e-5), (torch.float32, 1e-4)], ) @pytest.mark.parametrize("ell", [0, 1, 2, 3]) @pytest.mark.parametrize("use_fallback", [False, True]) @@ -37,12 +40,14 @@ def test_spherical_harmonics_equivariance(use_fallback: bool, ell: int, dtype, t angle = np.random.rand() scale = 1.3 - yl = cuet.spherical_harmonics([ell], vec, False, use_fallback=use_fallback) + m = cuet.SphericalHarmonics([ell], False, device=device, use_fallback=use_fallback) + + yl = m(vec.unsqueeze(0)).squeeze(0) R = torch.from_numpy(cue.SO3(1).rotation(axis, angle)).to(dtype).to(device) Rl = torch.from_numpy(cue.SO3(ell).rotation(axis, angle)).to(dtype).to(device) - yl1 = cuet.spherical_harmonics([ell], scale * R @ vec, False) + yl1 = m((scale * R @ vec).unsqueeze(0)).squeeze(0) yl2 = scale**ell * Rl @ yl torch.testing.assert_close(yl1, yl2, rtol=tol, atol=tol) @@ -61,6 +66,34 @@ def test_spherical_harmonics_full(dtype, ls: list[int], use_fallback: bool): if use_fallback is False and not torch.cuda.is_available(): pytest.skip("CUDA is not available") - 
vec = torch.randn(3, device=device, dtype=dtype) - yl = cuet.spherical_harmonics(ls, vec, False, use_fallback=use_fallback) - assert yl.shape[-1] == sum(2 * ell + 1 for ell in ls) + m = cuet.SphericalHarmonics(ls, False, use_fallback=use_fallback, device=device) + + vec = torch.randn(10, 3, device=device, dtype=dtype) + yl = m(vec) + assert yl.shape[0] == 10 + assert yl.shape[1] == sum(2 * ell + 1 for ell in ls) + + +export_modes = ["compile", "script", "jit"] + + +@pytest.mark.parametrize("dtype", data_types) +@pytest.mark.parametrize("ls", [[0], [1], [2], [0, 1], [0, 1, 2]]) +@pytest.mark.parametrize("use_fallback", [False, True]) +@pytest.mark.parametrize("mode", export_modes) +def test_export(dtype, ls: list[int], use_fallback: bool, mode: str, tmp_path: str): + if use_fallback is False and not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + tol = 1e-5 + if dtype in [torch.float16, torch.bfloat16]: + tol = 1e-2 + m = cuet.SphericalHarmonics(ls, False, use_fallback=use_fallback, device=device) + + vec = torch.randn(10, 3, device=device, dtype=dtype) + inputs = (vec,) + out1 = m(vec) + + m = module_with_mode(mode, m, inputs, dtype, tmp_path) + out2 = m(*inputs) + torch.testing.assert_close(out1, out2, atol=tol, rtol=tol) diff --git a/cuequivariance_torch/tests/operations/symmetric_contraction_test.py b/cuequivariance_torch/tests/operations/symmetric_contraction_test.py index 80a4065..3bf8467 100644 --- a/cuequivariance_torch/tests/operations/symmetric_contraction_test.py +++ b/cuequivariance_torch/tests/operations/symmetric_contraction_test.py @@ -17,6 +17,9 @@ import numpy as np import pytest import torch +from tests.utils import ( + module_with_mode, +) import cuequivariance as cue import cuequivariance_torch as cuet @@ -34,6 +37,9 @@ @pytest.mark.parametrize("original_mace", [True, False]) @pytest.mark.parametrize("batch", [1, 32]) def test_symmetric_contraction(dtype, layout, original_mace, batch): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + mul = 64 irreps_in = mul * cue.Irreps("O3", "0e + 1o + 2e") irreps_out = mul * cue.Irreps("O3", "0e + 1o") @@ -100,6 +106,7 @@ def test_mace_compatibility(): device=device, dtype=torch.float32, math_dtype=torch.float64, + use_fallback=not torch.cuda.is_available(), ) n_sc.weight.data = from64( (2, 164 // mul, mul), @@ -108,3 +115,51 @@ def test_mace_compatibility(): output = n_sc(x, i) torch.testing.assert_close(output, expected_output, atol=1e-5, rtol=1e-5) + + +export_modes = ["compile", "script", "jit"] + + +@pytest.mark.parametrize( + "dtype, math_dtype, atol, rtol", + [ + (torch.float64, torch.float64, 1e-10, 1e-10), + (torch.float32, torch.float32, 1e-5, 1e-5), + ], +) +@pytest.mark.parametrize("layout", [cue.ir_mul, cue.mul_ir]) +@pytest.mark.parametrize("original_mace", [True, False]) +@pytest.mark.parametrize("batch", [1, 32]) +@pytest.mark.parametrize("mode", export_modes) +def test_export( + dtype, math_dtype, atol, rtol, layout, original_mace, batch, mode, tmp_path +): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + mul = 64 + irreps_in = mul * cue.Irreps("O3", "0e + 1o + 2e") + irreps_out = mul * cue.Irreps("O3", "0e + 1o") + + m = cuet.SymmetricContraction( + irreps_in, + irreps_out, + 3, + 5, + layout_in=layout, + layout_out=layout, + dtype=dtype, + math_dtype=math_dtype, + device=device, + original_mace=original_mace, + ) + + x = torch.randn((batch, irreps_in.dim), dtype=dtype).to(device) + indices = torch.randint(0, 5, (batch,), 
dtype=torch.int32).to(device) + + out = m(x, indices) + assert out.shape == (batch, irreps_out.dim) + + m_script = module_with_mode(mode, m, [x, indices], dtype, tmp_path) + out_script = m_script(x, indices) + torch.testing.assert_close(out, out_script, atol=atol, rtol=rtol) diff --git a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py index d3e7cdd..4d275f0 100644 --- a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py +++ b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py @@ -14,6 +14,9 @@ # limitations under the License. import pytest import torch +from tests.utils import ( + module_with_mode, +) import cuequivariance as cue import cuequivariance_torch as cuet @@ -21,20 +24,25 @@ device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") -list_of_irreps = [ - cue.Irreps("O3", "32x0e + 32x1o"), - cue.Irreps("O3", "2x1o + 5x0e + 2e + 1e + 1o"), - cue.Irreps("O3", "2e + 0x0e + 0o + 0x1e + 1e"), +irreps = [ + ( + cue.Irreps("O3", "32x0e + 32x1o"), + cue.Irreps("O3", "0e + 1o + 2e"), + cue.Irreps("O3", "32x0e + 32x1o"), + ), + ( + cue.Irreps("O3", "2x1o + 3x0e + 4x2e + 3x1e + 2x1o"), + cue.Irreps("O3", "1o + 2e"), + cue.Irreps("O3", "2x1o + 5x0e + 1e + 1o"), + ), ] -@pytest.mark.parametrize("irreps1", list_of_irreps) -@pytest.mark.parametrize("irreps2", [irreps.set_mul(1) for irreps in list_of_irreps]) -@pytest.mark.parametrize("irreps3", list_of_irreps) +@pytest.mark.parametrize("irreps1, irreps2, irreps3", irreps) @pytest.mark.parametrize("layout", [cue.ir_mul, cue.mul_ir]) @pytest.mark.parametrize("use_fallback", [False, True]) @pytest.mark.parametrize("batch", [1, 32]) -def test_channel_wise( +def test_channel_wise_fwd( irreps1: cue.Irreps, irreps2: cue.Irreps, irreps3: cue.Irreps, @@ -72,13 +80,63 @@ def test_channel_wise( torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5) -def test_channel_wise_bwd_bwd(): +export_modes = ["compile", "script", "jit"] + + +@pytest.mark.parametrize("irreps1, irreps2, irreps3", irreps) +@pytest.mark.parametrize("layout", [cue.ir_mul, cue.mul_ir]) +@pytest.mark.parametrize("internal_weights", [False, True]) +@pytest.mark.parametrize("use_fallback", [False, True]) +@pytest.mark.parametrize("batch", [1, 32]) +@pytest.mark.parametrize("mode", export_modes) +def test_export( + irreps1: cue.Irreps, + irreps2: cue.Irreps, + irreps3: cue.Irreps, + layout: cue.IrrepsLayout, + internal_weights: bool, + use_fallback: bool, + batch: int, + mode: str, + tmp_path: str, +): + dtype = torch.float32 + if use_fallback is False and not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + m1 = cuet.ChannelWiseTensorProduct( + irreps1, + irreps2, + irreps3, + shared_weights=True, + internal_weights=internal_weights, + layout=layout, + device=device, + dtype=dtype, + use_fallback=use_fallback, + ) + x1 = torch.randn(batch, irreps1.dim, dtype=dtype).to(device) + x2 = torch.randn(batch, irreps2.dim, dtype=dtype).to(device) + if internal_weights: + inputs = (x1, x2) + else: + weights = torch.randn(1, m1.weight_numel, device=device, dtype=dtype) + inputs = (x1, x2, weights) + out1 = m1(*inputs) + + m1 = module_with_mode(mode, m1, inputs, dtype, tmp_path) + out2 = m1(*inputs) + torch.testing.assert_close(out1, out2) + + +@pytest.mark.parametrize("irreps", ["32x0", "2x0 + 3x1"]) +def test_channel_wise_bwd_bwd(irreps: cue.Irreps): if not torch.cuda.is_available(): pytest.skip("CUDA is not available") - irreps1 = cue.Irreps("SO3", 
"2x0 + 3x1") + irreps1 = cue.Irreps("SO3", irreps) irreps2 = cue.Irreps("SO3", "0 + 1") - irreps3 = cue.Irreps("SO3", "0 + 1") + irreps3 = cue.Irreps("SO3", irreps) x1 = torch.randn( 32, irreps1.dim, device=device, requires_grad=True, dtype=torch.float64 @@ -103,7 +161,7 @@ def test_channel_wise_bwd_bwd(): torch.manual_seed(0) w = torch.randn( - m.weight_numel, device="cuda", requires_grad=True, dtype=torch.float64 + 1, m.weight_numel, device="cuda", requires_grad=True, dtype=torch.float64 ) (grad1, grad2, grad3) = torch.autograd.grad( diff --git a/cuequivariance_torch/tests/operations/tp_fully_connected_test.py b/cuequivariance_torch/tests/operations/tp_fully_connected_test.py index 832904b..57a77b8 100644 --- a/cuequivariance_torch/tests/operations/tp_fully_connected_test.py +++ b/cuequivariance_torch/tests/operations/tp_fully_connected_test.py @@ -14,6 +14,9 @@ # limitations under the License. import pytest import torch +from tests.utils import ( + module_with_mode, +) import cuequivariance as cue import cuequivariance_torch as cuet @@ -21,16 +24,23 @@ device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") -list_of_irreps = [ - cue.Irreps("O3", "4x0e + 4x1o"), - cue.Irreps("O3", "2x1o + 5x0e + 2e + 1e + 1o"), - cue.Irreps("O3", "2e + 0x0e + 0o + 0x1e + 1e"), +export_modes = ["compile", "script", "jit"] + +irreps = [ + ( + cue.Irreps("O3", "4x0e + 4x1o"), + cue.Irreps("O3", "4x0e + 4x1o"), + cue.Irreps("O3", "4x0e + 4x1o"), + ), + ( + cue.Irreps("O3", "2e + 0x0e + 0o + 0x1e + 1e"), + cue.Irreps("O3", "4x0e + 4x1o"), + cue.Irreps("O3", "2e + 0x0e + 0o + 0x1e + 1e"), + ), ] -@pytest.mark.parametrize("irreps1", list_of_irreps) -@pytest.mark.parametrize("irreps2", list_of_irreps) -@pytest.mark.parametrize("irreps3", list_of_irreps) +@pytest.mark.parametrize("irreps1, irreps2, irreps3", irreps) @pytest.mark.parametrize("layout", [cue.ir_mul, cue.mul_ir]) @pytest.mark.parametrize("use_fallback", [False, True]) def test_fully_connected( @@ -71,21 +81,47 @@ def test_fully_connected( torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5) +@pytest.mark.parametrize("irreps1, irreps2, irreps3", irreps) +@pytest.mark.parametrize("layout", [cue.ir_mul, cue.mul_ir]) +@pytest.mark.parametrize("internal_weights", [False, True]) @pytest.mark.parametrize("use_fallback", [False, True]) -def test_compile(use_fallback: bool): +@pytest.mark.parametrize("mode", export_modes) +def test_export( + irreps1: cue.Irreps, + irreps2: cue.Irreps, + irreps3: cue.Irreps, + layout: cue.IrrepsLayout, + internal_weights: bool, + use_fallback: bool, + mode: str, + tmp_path: str, +): if use_fallback is False and not torch.cuda.is_available(): pytest.skip("CUDA is not available") - - m = cuet.FullyConnectedTensorProduct( - irreps_in1=cue.Irreps("O3", "32x0e + 32x1o"), - irreps_in2=cue.Irreps("O3", "32x0e + 32x1o"), - irreps_out=cue.Irreps("O3", "32x0e + 32x1o"), - layout=cue.mul_ir, + dtype = torch.float32 + m1 = cuet.FullyConnectedTensorProduct( + irreps1, + irreps2, + irreps3, + shared_weights=True, + internal_weights=internal_weights, + layout=layout, device=device, + dtype=dtype, use_fallback=use_fallback, ) - m_compile = torch.compile(m, fullgraph=True) - input1 = torch.randn(100, m.irreps_in1.dim, device=device) - input2 = torch.randn(100, m.irreps_in2.dim, device=device) - m_compile(input1, input2) + x1 = torch.randn(32, irreps1.dim, dtype=dtype).to(device) + x2 = torch.randn(32, irreps2.dim, dtype=dtype).to(device) + + if internal_weights: + inputs = (x1, x2) + else: + weights = 
torch.randn(1, m1.weight_numel, device=device, dtype=dtype) + inputs = (x1, x2, weights) + + out1 = m1(*inputs) + + m1 = module_with_mode(mode, m1, inputs, dtype, tmp_path) + out2 = m1(*inputs) + torch.testing.assert_close(out1, out2) diff --git a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py index 071cb2a..dd4389b 100644 --- a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py @@ -12,207 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os import timeit import pytest import torch import torch._dynamo +from tests.utils import ( + module_with_mode, +) import cuequivariance as cue import cuequivariance_torch as cuet from cuequivariance import descriptors -torch._dynamo.config.cache_size_limit = 100 - - -try: - import cuequivariance_ops_torch.onnx # noqa: F401 - import onnx # noqa: F401 - import onnxruntime # noqa: F401 - import onnxscript # noqa: F401 - from cuequivariance_ops_torch.tensorrt import register_plugins - - ONNX_AVAILABLE = True -except Exception: - ONNX_AVAILABLE = False - - -try: - import torch_tensorrt - - TORCH_TRT_AVAILABLE = True -except Exception: - TORCH_TRT_AVAILABLE = False - - -def verify_onnx(module, onnx_module, inputs, dtype): - if dtype != torch.float32: - pytest.skip("onnxrt only checked for float32") - from onnxruntime import SessionOptions - from onnxruntime_extensions import get_library_path - from torch.onnx.verification import ( - VerificationOptions, - _compare_onnx_pytorch_model, - ) - - original_init = SessionOptions.__init__ - - def new_init(self): - original_init(self) - try: - self.register_custom_ops_library(get_library_path()) - except Exception: - pass - - SessionOptions.__init__ = new_init - _compare_onnx_pytorch_model( - module, onnx_module, tuple(inputs), None, None, VerificationOptions() - ) - SessionOptions.__init__ = original_init - torch.cuda.synchronize() - torch.cuda.empty_cache() - - -def verify_trt(module, onnx_module, inputs, dtype): - import tensorrt - from pkg_resources import parse_version - - if parse_version(tensorrt.__version__) < parse_version("10.3.0"): - pytest.skip("TRT < 10.3.0 is not supported!") - if dtype == torch.float64: - pytest.skip("TRT does not support float64") - - from onnxruntime import InferenceSession, SessionOptions - from onnxruntime_extensions import get_library_path - from polygraphy.backend.onnxrt import OnnxrtRunner - from polygraphy.backend.trt import ( - CreateConfig, - TrtRunner, - engine_from_network, - network_from_onnx_path, - ) - from polygraphy.comparator import Comparator, DataLoader - - register_plugins() - - network = network_from_onnx_path(onnx_module) - trt_engine = engine_from_network(network, config=CreateConfig()) - - if dtype != torch.float32: - pytest.skip("Comparator only supports float32") - - # Create runners for ONNX and TRT models - trt_runner = TrtRunner(trt_engine) - - options = SessionOptions() - options.register_custom_ops_library(get_library_path()) - onnx_runner = OnnxrtRunner(InferenceSession(onnx_module, sess_options=options)) - - results = Comparator.run([trt_runner, onnx_runner], data_loader=DataLoader()) - Comparator.compare_accuracy(results) - torch.cuda.synchronize() - torch.cuda.empty_cache() - - -def module_with_mode( - mode, - 
module, - inputs, - math_dtype, - tmp_path, - grad_modes=["eager", "compile", "jit", "export"], -): - if isinstance(inputs[0], list): - dtype = inputs[0][0].dtype - else: - dtype = inputs[0].dtype - if mode in ["trt", "torch_trt", "onnx", "onnx_dynamo", "export"]: - if not ONNX_AVAILABLE: - pytest.skip("ONNX not available!") - if dtype == torch.float64 or math_dtype == torch.float64: - pytest.skip("TRT/ORT do not support float64") - - with torch.set_grad_enabled(mode in grad_modes): - if mode == "compile": - import sys - - if sys.version_info.major == 3 and sys.version_info.minor >= 12: - pytest.skip("torch dynamo needs cpy <= 3.11") - module = torch.compile(module) - elif mode == "fx": - module = torch.fx.symbolic_trace(module) - elif mode == "jit": - module = torch.jit.trace(module, inputs) - fname = os.path.join(tmp_path, "test.ts") - torch.jit.save(module, fname) - module = torch.jit.load(fname) - elif mode == "export": - exp_program = torch.export.export(module, tuple(inputs)) - fname = os.path.join(tmp_path, "test.pt2") - torch.export.save(exp_program, fname) - del exp_program - module = torch.export.load(fname).module() - elif mode == "torch_trt": - if not TORCH_TRT_AVAILABLE: - pytest.skip("torch_tensorrt is not installed!") - register_plugins() - exp_program = torch_tensorrt.dynamo.trace(module, inputs) - module = torch_tensorrt.dynamo.compile( - exp_program, - inputs=inputs, - require_full_compilation=True, - min_block_size=1, - enabled_precisions={torch.float32, dtype}, - # dryrun=True - ) - elif mode == "onnx" or mode == "trt": - try: - onnx_path = os.path.join(tmp_path, "test.onnx") - torch.onnx.export( - module, tuple(inputs), onnx_path, opset_version=17, verbose=False - ) - if mode == "trt": - verify_trt(module, onnx_path, inputs, dtype) - else: - verify_onnx(module, onnx_path, inputs, dtype) - except ImportError: - pytest.skip("ONNX/TRT is not available") - - elif mode == "onnx_dynamo": - try: - from cuequivariance_ops_torch.onnx import ( - cuequivariance_ops_torch_onnx_registry, - ) - - export_options = torch.onnx.ExportOptions( - onnx_registry=cuequivariance_ops_torch_onnx_registry - ) - onnx_program = torch.onnx.dynamo_export( - module, *inputs, export_options=export_options - ) - onnx_path = os.path.join(tmp_path, "test.onnx") - onnx_program.save(onnx_path) - verify_onnx(module, onnx_path, inputs, dtype) - except ImportError: - pytest.skip("ONNX is not available") - elif mode == "eager": - pass - else: - raise ValueError(f"No such mode: {mode}") - - torch.cuda.synchronize() - torch.cuda.empty_cache() - - return module - - -def maybe_detach_and_to(tensor, *args, **kwargs): - if tensor is not None: - return tensor.clone().detach().to(*args, **kwargs) - return None - - device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") @@ -279,7 +91,6 @@ def test_performance_cuda_vs_fx( device=device, math_dtype=math_dtype, use_fallback=True, - optimize_fallback=True, ) inputs = [ @@ -345,7 +156,6 @@ def test_precision_cuda_vs_fx( device=device, math_dtype=torch.float64, use_fallback=True, - optimize_fallback=True, ) inputs = [x.to(torch.float64) for x in inputs] y1 = m(inputs).to(dtype) @@ -353,60 +163,13 @@ def test_precision_cuda_vs_fx( torch.testing.assert_close(y0, y1, atol=atol, rtol=rtol) -@pytest.mark.parametrize("e", make_descriptors()) -@pytest.mark.parametrize("dtype, math_dtype, atol, rtol", settings2) -def test_compile( - e: cue.EquivariantTensorProduct, - dtype: torch.dtype, - math_dtype: torch.dtype, - atol: float, - rtol: float, -): - if 
not torch.cuda.is_available(): - pytest.skip("CUDA is not available") - - m = cuet.EquivariantTensorProduct( - e, layout=cue.mul_ir, use_fallback=False, device=device, math_dtype=math_dtype - ) - inputs = [ - torch.randn((1024, inp.dim), device=device, dtype=dtype) for inp in e.inputs - ] - res = m(inputs) - m_compile = torch.compile(m, fullgraph=True) - res_script = m_compile(inputs) - torch.testing.assert_close(res, res_script, atol=atol, rtol=rtol) - - -@pytest.mark.parametrize("e", make_descriptors()) -@pytest.mark.parametrize("dtype, math_dtype, atol, rtol", settings2) -def test_script( - e: cue.EquivariantTensorProduct, - dtype: torch.dtype, - math_dtype: torch.dtype, - atol: float, - rtol: float, -): - if not torch.cuda.is_available(): - pytest.skip("CUDA is not available") - - m = cuet.EquivariantTensorProduct( - e, layout=cue.mul_ir, use_fallback=False, device=device, math_dtype=math_dtype - ) - inputs = [ - torch.randn((1024, inp.dim), device=device, dtype=dtype) for inp in e.inputs - ] - res = m(inputs) - m_script = torch.jit.script(m) - res_script = m_script(inputs) - torch.testing.assert_close(res, res_script, atol=atol, rtol=rtol) - - -# export_modes = ["onnx", "onnx_dynamo", "trt", "torch_trt", "jit"] -export_modes = ["trt", "onnx"] +export_modes = ["compile", "script", "jit"] +# "export" does not support the change of batch size @pytest.mark.parametrize("e", make_descriptors()) @pytest.mark.parametrize("dtype, math_dtype, atol, rtol", settings2) +@pytest.mark.parametrize("use_fallback", [True, False]) @pytest.mark.parametrize("mode", export_modes) def test_export( e: cue.EquivariantTensorProduct, @@ -415,18 +178,23 @@ def test_export( atol: float, rtol: float, mode: str, + use_fallback: bool, tmp_path, ): if not torch.cuda.is_available(): pytest.skip("CUDA is not available") m = cuet.EquivariantTensorProduct( - e, layout=cue.mul_ir, math_dtype=math_dtype, use_fallback=False, device=device + e, + layout=cue.mul_ir, + math_dtype=math_dtype, + use_fallback=use_fallback, + device=device, ) inputs = [ - torch.randn((1024, inp.dim), device=device, dtype=dtype) for inp in e.inputs + torch.randn((512, inp.dim), device=device, dtype=dtype) for inp in e.inputs ] res = m(inputs) - m_script = module_with_mode(mode, m, [inputs], math_dtype, tmp_path) - res_script = m_script(inputs) + m = module_with_mode(mode, m, [inputs], math_dtype, tmp_path) + res_script = m(inputs) torch.testing.assert_close(res, res_script, atol=atol, rtol=rtol) diff --git a/cuequivariance_torch/tests/primitives/script_test.py b/cuequivariance_torch/tests/primitives/primitive_export_test.py similarity index 62% rename from cuequivariance_torch/tests/primitives/script_test.py rename to cuequivariance_torch/tests/primitives/primitive_export_test.py index 4706bff..72592a9 100644 --- a/cuequivariance_torch/tests/primitives/script_test.py +++ b/cuequivariance_torch/tests/primitives/primitive_export_test.py @@ -1,5 +1,8 @@ import pytest import torch +from tests.utils import ( + module_with_mode, +) import cuequivariance as cue from cuequivariance_torch.primitives.symmetric_tensor_product import ( @@ -14,8 +17,11 @@ device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +export_modes = ["script", "export"] + -def test_script_symmetric_contraction(): +@pytest.mark.parametrize("mode", export_modes) +def test_script_symmetric_contraction(mode, tmp_path): if not torch.cuda.is_available(): pytest.skip("CUDA is not available") @@ -28,13 +34,16 @@ def test_script_symmetric_contraction(): i0 = 
torch.zeros(batch, device=device, dtype=torch.int32) x1 = torch.randn(batch, ds[0].operands[1].size, device=device, dtype=torch.float32) - module = SymmetricTensorProduct(ds, device, torch.float32) - module = torch.jit.script(module) + m = SymmetricTensorProduct(ds, device, torch.float32) + inputs = (x0, i0, x1) + module = module_with_mode(mode, m, inputs, torch.float32, tmp_path) + out1 = m(*inputs) + out2 = module(*inputs) + torch.testing.assert_close(out1, out2) - assert module(x0, i0, x1).shape == (batch, ds[0].operands[-1].size) - -def test_script_fused_tp_3(): +@pytest.mark.parametrize("mode", export_modes) +def test_script_fused_tp_3(mode, tmp_path): if not torch.cuda.is_available(): pytest.skip("CUDA is not available") @@ -49,14 +58,16 @@ def test_script_fused_tp_3(): batch = 12 x0 = torch.randn(batch, d.operands[0].size, device=device, dtype=torch.float32) x1 = torch.randn(batch, d.operands[1].size, device=device, dtype=torch.float32) + inputs = [x0, x1] + m = FusedTensorProductOp3(d, (0, 1), device, torch.float32) + module = module_with_mode(mode, m, (inputs,), torch.float32, tmp_path) + out1 = m(inputs) + out2 = module(inputs) + torch.testing.assert_close(out1, out2) - module = FusedTensorProductOp3(d, (0, 1), device, torch.float32) - module = torch.jit.script(module) - - assert module([x0, x1]).shape == (batch, d.operands[2].size) - -def test_script_fused_tp_4(): +@pytest.mark.parametrize("mode", export_modes) +def test_script_fused_tp_4(mode, tmp_path): if not torch.cuda.is_available(): pytest.skip("CUDA is not available") @@ -74,13 +85,16 @@ def test_script_fused_tp_4(): x1 = torch.randn(batch, d.operands[1].size, device=device, dtype=torch.float32) x2 = torch.randn(batch, d.operands[2].size, device=device, dtype=torch.float32) - module = FusedTensorProductOp4(d, (0, 1, 2), device, torch.float32) - module = torch.jit.script(module) - - assert module([x0, x1, x2]).shape == (batch, d.operands[3].size) + inputs = [x0, x1, x2] + m = FusedTensorProductOp4(d, [0, 1, 2], device, torch.float32) + module = module_with_mode(mode, m, (inputs,), torch.float32, tmp_path) + out1 = m(inputs) + out2 = module(inputs) + torch.testing.assert_close(out1, out2) -def test_script_uniform_tp_3(): +@pytest.mark.parametrize("mode", export_modes) +def test_script_uniform_tp_3(mode, tmp_path): if not torch.cuda.is_available(): pytest.skip("CUDA is not available") @@ -95,14 +109,17 @@ def test_script_uniform_tp_3(): batch = 12 x0 = torch.randn(batch, d.operands[0].size, device=device, dtype=torch.float32) x1 = torch.randn(batch, d.operands[1].size, device=device, dtype=torch.float32) + inputs = [x0, x1] - module = TensorProductUniform3x1d(d, device, torch.float32) - module = torch.jit.script(module) + m = TensorProductUniform3x1d(d, device, torch.float32) + module = module_with_mode(mode, m, (inputs,), torch.float32, tmp_path) + out1 = m(inputs) + out2 = module(inputs) + torch.testing.assert_close(out1, out2) - assert module([x0, x1]).shape == (batch, d.operands[2].size) - -def test_script_uniform_tp_4(): +@pytest.mark.parametrize("mode", export_modes) +def test_script_uniform_tp_4(mode, tmp_path): if not torch.cuda.is_available(): pytest.skip("CUDA is not available") @@ -118,8 +135,10 @@ def test_script_uniform_tp_4(): x0 = torch.randn(batch, d.operands[0].size, device=device, dtype=torch.float32) x1 = torch.randn(batch, d.operands[1].size, device=device, dtype=torch.float32) x2 = torch.randn(batch, d.operands[2].size, device=device, dtype=torch.float32) + inputs = [x0, x1, x2] - module = 
TensorProductUniform4x1d(d, device, torch.float32) - module = torch.jit.script(module) - - assert module([x0, x1, x2]).shape == (batch, d.operands[3].size) + m = TensorProductUniform4x1d(d, device, torch.float32) + module = module_with_mode(mode, m, (inputs,), torch.float32, tmp_path) + out1 = m(inputs) + out2 = module(inputs) + torch.testing.assert_close(out1, out2) diff --git a/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py b/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py index 7858576..9662e85 100644 --- a/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py @@ -14,6 +14,9 @@ # limitations under the License. import pytest import torch +from tests.utils import ( + module_with_mode, +) import cuequivariance as cue import cuequivariance.segmented_tensor_product as stp @@ -25,13 +28,7 @@ def make_descriptors(): yield descriptors.symmetric_contraction( - cue.Irreps("SO3", "0 + 1 + 2"), cue.Irreps("SO3", "0"), [3] - ).ds - yield descriptors.symmetric_contraction( - cue.Irreps("O3", "0e + 1o + 2e"), cue.Irreps("O3", "0e + 1o"), [4] - ).ds - yield descriptors.symmetric_contraction( - cue.Irreps("SU2", "0 + 1/2"), cue.Irreps("SU2", "0 + 1/2"), [5] + cue.Irreps("SO3", "0 + 1 + 2"), cue.Irreps("SO3", "0"), [1, 2, 3] ).ds d1 = stp.SegmentedTensorProduct.from_subscripts(",,") @@ -60,22 +57,14 @@ def make_descriptors(): @pytest.mark.parametrize("ds", make_descriptors()) @pytest.mark.parametrize("dtype, math_dtype, tol", settings1) -@pytest.mark.parametrize("use_fallback", [False, True]) def test_primitive_indexed_symmetric_tensor_product_cuda_vs_fx( - ds: list[stp.SegmentedTensorProduct], - dtype, - math_dtype, - tol: float, - use_fallback: bool, + ds: list[stp.SegmentedTensorProduct], dtype, math_dtype, tol: float ): - if use_fallback is False and not torch.cuda.is_available(): - pytest.skip("CUDA is not available") + use_fallback = not torch.cuda.is_available() m = cuet.IWeightedSymmetricTensorProduct( ds, math_dtype=math_dtype, device=device, use_fallback=use_fallback ) - if use_fallback is False: - m = torch.jit.script(m) x0 = torch.randn((2, m.x0_size), device=device, dtype=dtype, requires_grad=True) i0 = torch.tensor([0, 1, 0], dtype=torch.int32, device=device) @@ -87,11 +76,7 @@ def test_primitive_indexed_symmetric_tensor_product_cuda_vs_fx( out1 = m(x0, i0, x1) m = cuet.IWeightedSymmetricTensorProduct( - ds, - math_dtype=torch.float64, - device=device, - use_fallback=True, - optimize_fallback=True, + ds, math_dtype=torch.float64, device=device, use_fallback=True ) out2 = m(x0_, i0, x1_) @@ -146,6 +131,9 @@ def test_math_dtype(dtype: torch.dtype, math_dtype: torch.dtype, use_fallback: b # .to should have no effect for param in m.parameters(): assert False # no parameters + + m = m.to(device) + m = m.to(torch.float32) m = m.to(torch.float64) out2 = m(x0, i0, x1) @@ -153,3 +141,39 @@ def test_math_dtype(dtype: torch.dtype, math_dtype: torch.dtype, use_fallback: b assert out1.dtype == dtype assert out2.dtype == dtype assert (out1 == out2).all() + + +export_modes = ["compile", "script", "jit"] + + +@pytest.mark.parametrize("ds", make_descriptors()) +@pytest.mark.parametrize("mode", export_modes) +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_export( + ds: list[stp.SegmentedTensorProduct], + mode: str, + use_fallback: bool, + tmp_path, +): + if not use_fallback and not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + 
dtype = torch.float32 + math_dtype = torch.float32 + + if use_fallback is True and mode in ["trt"]: + pytest.skip(f"{mode} not supported for the fallback!") + + m = cuet.IWeightedSymmetricTensorProduct( + ds, math_dtype=math_dtype, device=device, use_fallback=use_fallback + ) + x0 = torch.randn((2, m.x0_size), device=device, dtype=dtype, requires_grad=True) + i0 = torch.tensor([0, 1, 0], dtype=torch.int32, device=device) + x1 = torch.randn( + (i0.size(0), m.x1_size), device=device, dtype=dtype, requires_grad=True + ) + inputs = (x0, i0, x1) + out1 = m(*inputs) + m = module_with_mode(mode, m, inputs, torch.float32, tmp_path) + out2 = m(*inputs) + torch.testing.assert_close(out1, out2) diff --git a/cuequivariance_torch/tests/primitives/tensor_product_test.py b/cuequivariance_torch/tests/primitives/tensor_product_test.py index d8c26ef..a54a530 100644 --- a/cuequivariance_torch/tests/primitives/tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/tensor_product_test.py @@ -12,13 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import itertools import pytest import torch +from tests.utils import ( + module_with_mode, +) import cuequivariance as cue -import cuequivariance.segmented_tensor_product as stp import cuequivariance_torch as cuet from cuequivariance import descriptors @@ -62,7 +63,7 @@ def make_descriptors(): "u,,u", ",v,v", ]: - d = stp.SegmentedTensorProduct.from_subscripts(subscripts) + d = cue.SegmentedTensorProduct.from_subscripts(subscripts) for i in range(3): d.add_path( *[None] * d.num_operands, @@ -91,9 +92,9 @@ def make_descriptors(): @pytest.mark.parametrize("d", make_descriptors()) @pytest.mark.parametrize("dtype, math_dtype, tol", settings) -@pytest.mark.parametrize("use_fallback", [False, True]) +@pytest.mark.parametrize("use_fallback", [True, False]) def test_primitive_tensor_product_cuda_vs_fx( - d: stp.SegmentedTensorProduct, + d: cue.SegmentedTensorProduct, dtype: torch.dtype, math_dtype: torch.dtype, tol: float, @@ -102,48 +103,76 @@ def test_primitive_tensor_product_cuda_vs_fx( if use_fallback is False and not torch.cuda.is_available(): pytest.skip("CUDA is not available") - for batches in itertools.product([(16,), (), (4, 1)], repeat=d.num_operands - 1): - inputs = [ - torch.randn( - batches[i] + (d.operands[i].size,), - device=device, - dtype=dtype, - requires_grad=True, - ) - for i in range(d.num_operands - 1) - ] - - m = cuet.TensorProduct( - d, device=device, math_dtype=math_dtype, use_fallback=use_fallback + inputs = [ + torch.randn( + (12, d.operands[i].size), + device=device, + dtype=dtype, + requires_grad=True, ) - if not use_fallback: - m = torch.jit.script(m) + for i in range(d.num_operands - 1) + ] - out1 = m(inputs) + m = cuet.TensorProduct( + d, + device=device, + math_dtype=math_dtype, + use_fallback=use_fallback, + ) - m = cuet.TensorProduct( - d, - device=device, - math_dtype=torch.float64, - use_fallback=True, - optimize_fallback=False, - ) - inputs_ = [inp.clone().to(torch.float64) for inp in inputs] - out2 = m(inputs_) + out1 = m(inputs) + + m = cuet.TensorProduct( + d, + device=device, + math_dtype=torch.float64, + use_fallback=True, + ) + + inputs_ = [inp.to(torch.float64) for inp in inputs] + out2 = m(inputs_) + + assert out1.shape[:-1] == (12,) + assert out1.dtype == dtype - assert out1.shape[:-1] == torch.broadcast_shapes(*batches) - assert out1.dtype == dtype + 
torch.testing.assert_close(out1, out2.to(dtype), atol=tol, rtol=tol) - torch.testing.assert_close(out1, out2.to(dtype), atol=tol, rtol=tol) + grad1 = torch.autograd.grad(out1.sum(), inputs, create_graph=True) + grad2 = torch.autograd.grad(out2.sum(), inputs_, create_graph=True) - grad1 = torch.autograd.grad(out1.sum(), inputs, create_graph=True) - grad2 = torch.autograd.grad(out2.sum(), inputs_, create_graph=True) + for g1, g2 in zip(grad1, grad2): + torch.testing.assert_close(g1, g2.to(dtype), atol=10 * tol, rtol=10 * tol) - for g1, g2 in zip(grad1, grad2): - torch.testing.assert_close(g1, g2.to(dtype), atol=10 * tol, rtol=10 * tol) + double_grad1 = torch.autograd.grad(sum(g.sum() for g in grad1), inputs) + double_grad2 = torch.autograd.grad(sum(g.sum() for g in grad2), inputs_) + + for g1, g2 in zip(double_grad1, double_grad2): + torch.testing.assert_close(g1, g2.to(dtype), atol=100 * tol, rtol=100 * tol) + + +export_modes = ["compile", "script", "jit"] + + +@pytest.mark.parametrize("d", make_descriptors()) +@pytest.mark.parametrize("mode", export_modes) +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_export(d: cue.SegmentedTensorProduct, mode, use_fallback, tmp_path): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + batch = 12 + inputs = [ + torch.randn(batch, ope.size, device=device, dtype=torch.float32) + for ope in d.operands[:-1] + ] - double_grad1 = torch.autograd.grad(sum(g.sum() for g in grad1), inputs) - double_grad2 = torch.autograd.grad(sum(g.sum() for g in grad2), inputs_) + if use_fallback is True and mode in ["trt"]: + pytest.skip(f"{mode} not supported for the fallback!") - for g1, g2 in zip(double_grad1, double_grad2): - torch.testing.assert_close(g1, g2.to(dtype), atol=100 * tol, rtol=100 * tol) + module = cuet.TensorProduct( + d, device=device, math_dtype=torch.float32, use_fallback=use_fallback + ) + out1 = module(inputs) + module = module_with_mode(mode, module, (inputs,), torch.float32, tmp_path) + out2 = module(inputs) + torch.testing.assert_close(out1, out2) diff --git a/cuequivariance_torch/tests/primitives/transpose_test.py b/cuequivariance_torch/tests/primitives/transpose_test.py index 31eb271..f1b32d7 100644 --- a/cuequivariance_torch/tests/primitives/transpose_test.py +++ b/cuequivariance_torch/tests/primitives/transpose_test.py @@ -14,6 +14,9 @@ # limitations under the License. 
import pytest import torch +from tests.utils import ( + module_with_mode, +) import cuequivariance_torch as cuet @@ -48,3 +51,24 @@ def test_transpose(use_fallback: bool, dtype: torch.dtype): m = cuet.TransposeSegments(segments, device, use_fallback=use_fallback) torch.testing.assert_close(m(x), xt) + + +export_modes = ["compile", "script", "jit"] + + +@pytest.mark.parametrize("mode", export_modes) +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_export(mode, use_fallback, tmp_path): + if not use_fallback and not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + dtype = torch.float32 + x = torch.tensor( + [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 10, 11, 12, 13]], dtype=dtype, device=device + ) + segments = [(2, 3), (2, 2)] + m = cuet.TransposeSegments(segments, device, use_fallback=use_fallback) + out1 = m(x) + m = module_with_mode(mode, m, (x,), dtype, tmp_path) + out2 = m(x) + torch.testing.assert_close(out1, out2) diff --git a/cuequivariance_torch/tests/utils.py b/cuequivariance_torch/tests/utils.py new file mode 100644 index 0000000..5102e1e --- /dev/null +++ b/cuequivariance_torch/tests/utils.py @@ -0,0 +1,273 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +from typing import Sequence + +import pytest +import torch +import torch._dynamo + +torch._dynamo.config.cache_size_limit = 100 + +try: + import cuequivariance_ops_torch.onnx # noqa: F401 + import onnx # noqa: F401 + import onnxruntime # noqa: F401 + import onnxscript # noqa: F401 + from cuequivariance_ops_torch.tensorrt import register_plugins + + ONNX_AVAILABLE = True +except Exception: + ONNX_AVAILABLE = False + + +try: + import torch_tensorrt + + TORCH_TRT_AVAILABLE = True +except Exception: + TORCH_TRT_AVAILABLE = False + + +def verify_onnx(module, onnx_module, inputs, dtype): + if dtype != torch.float32: + pytest.skip("onnxrt only checked for float32") + from onnxruntime import SessionOptions + from onnxruntime_extensions import get_library_path + from torch.onnx.verification import ( + VerificationOptions, + _compare_onnx_pytorch_model, + ) + + original_init = SessionOptions.__init__ + + def new_init(self): + original_init(self) + try: + self.register_custom_ops_library(get_library_path()) + except Exception: + pass + + SessionOptions.__init__ = new_init + _compare_onnx_pytorch_model( + module, onnx_module, tuple(inputs), None, None, VerificationOptions() + ) + SessionOptions.__init__ = original_init + torch.cuda.synchronize() + torch.cuda.empty_cache() + + +def verify_trt(module, onnx_module, inputs, dtype): + import tensorrt + from pkg_resources import parse_version + + if parse_version(tensorrt.__version__) < parse_version("10.3.0"): + pytest.skip("TRT < 10.3.0 is not supported!") + if dtype == torch.float64: + pytest.skip("TRT does not support float64") + + from onnxruntime import InferenceSession, SessionOptions + from onnxruntime_extensions import get_library_path + from polygraphy.backend.onnxrt import OnnxrtRunner + from polygraphy.backend.trt import ( + CreateConfig, + 
TrtRunner, + engine_from_network, + network_from_onnx_path, + ) + from polygraphy.comparator import Comparator, DataLoader + + register_plugins() + + network = network_from_onnx_path(onnx_module) + trt_engine = engine_from_network(network, config=CreateConfig()) + + if dtype != torch.float32: + pytest.skip("Comparator only supports float32") + + # Create runners for ONNX and TRT models + trt_runner = TrtRunner(trt_engine) + + options = SessionOptions() + options.register_custom_ops_library(get_library_path()) + onnx_runner = OnnxrtRunner(InferenceSession(onnx_module, sess_options=options)) + + results = Comparator.run([trt_runner, onnx_runner], data_loader=DataLoader()) + Comparator.compare_accuracy(results) + torch.cuda.synchronize() + torch.cuda.empty_cache() + + +def module_with_mode( + mode: str, + module: torch.nn.Module, + inputs: list[torch.Tensor] | list[list[torch.Tensor]], + math_dtype: torch.dtype, + tmp_path: str, + grad_modes: list[str] = ["eager", "compile", "jit", "export"], +) -> torch.nn.Module: + if isinstance(inputs[0], list): + dtype = inputs[0][0].dtype + else: + dtype = inputs[0].dtype + if mode in ["trt", "torch_trt", "onnx", "onnx_dynamo"]: + if not ONNX_AVAILABLE: + pytest.skip("ONNX not available!") + if dtype == torch.float64 or math_dtype == torch.float64: + pytest.skip("TRT/ORT do not support float64") + + with torch.set_grad_enabled(mode in grad_modes): + if mode == "compile": + module = torch.compile(module, fullgraph=True) + elif mode == "fx": + module = torch.fx.symbolic_trace(module) + elif mode == "script": + module = torch.jit.script(module) + fname = os.path.join(tmp_path, "test.ts") + torch.jit.save(module, fname) + module = torch.jit.load(fname) + elif mode == "jit": + module = torch.jit.trace(module, inputs) + fname = os.path.join(tmp_path, "test.ts") + torch.jit.save(module, fname) + module = torch.jit.load(fname) + elif mode == "export": + exp_program = torch.export.export(module, tuple(inputs)) + fname = os.path.join(tmp_path, "test.pt2") + torch.export.save(exp_program, fname) + del exp_program + module = torch.export.load(fname).module() + elif mode == "torch_trt": + if not TORCH_TRT_AVAILABLE: + pytest.skip("torch_tensorrt is not installed!") + register_plugins() + exp_program = torch.export.export(module, tuple(inputs)) + module = torch_tensorrt.dynamo.compile( + exp_program, + inputs=inputs, + require_full_compilation=True, + min_block_size=1, + enabled_precisions={torch.float32, dtype}, + # dryrun=True + ) + elif mode == "onnx" or mode == "trt": + try: + onnx_path = os.path.join(tmp_path, "test.onnx") + torch.onnx.export( + module, tuple(inputs), onnx_path, opset_version=17, verbose=False + ) + if mode == "trt": + verify_trt(module, onnx_path, inputs, dtype) + else: + verify_onnx(module, onnx_path, inputs, dtype) + except ImportError: + pytest.skip("ONNX/TRT is not available") + + elif mode == "onnx_dynamo": + try: + from cuequivariance_ops_torch.onnx import ( + cuequivariance_ops_torch_onnx_registry, + ) + + export_options = torch.onnx.ExportOptions( + onnx_registry=cuequivariance_ops_torch_onnx_registry + ) + onnx_program = torch.onnx.dynamo_export( + module, *inputs, export_options=export_options + ) + onnx_path = os.path.join(tmp_path, "test.onnx") + onnx_program.save(onnx_path) + verify_onnx(module, onnx_path, inputs, dtype) + except ImportError: + pytest.skip("ONNX is not available") + elif mode == "eager": + pass + else: + raise ValueError(f"No such mode: {mode}") + + if torch.cuda.is_available(): + torch.cuda.synchronize() + 
torch.cuda.empty_cache() + + return module + + +def create_random_tensor_2d(batch_size, stride, requires_grad, dtype, is_shared): + data = torch.randn( + (stride,) if is_shared else (batch_size, stride), + dtype=dtype, + device="cuda", + ).requires_grad_(requires_grad) + + return data + + +def maybe_detach_and_to(tensor, *args, **kwargs): + if tensor is not None: + return tensor.clone().detach().to(*args, **kwargs) + return None + + +def run_fwd_test(module, x: Sequence): + with torch.no_grad(): + out = module(*x) + test_output = [maybe_detach_and_to(out, dtype=torch.float32)] + return test_output + + +def run_fwd_bwd_test(module, x: Sequence): + out = module(*x) + + loss = out.sum() + loss.backward() + + test_output = [maybe_detach_and_to(out, dtype=torch.float32)] + test_output.extend([maybe_detach_and_to(t.grad, dtype=torch.float32) for t in x]) + + return test_output + + +def run_bwd_bwd_test(module, x: Sequence): + test_outputs = [] + out = module(*x) + grads = torch.autograd.grad(out.pow(2).sum(), x, create_graph=True) + test_outputs.extend([maybe_detach_and_to(g, dtype=torch.float32) for g in grads]) + loss = sum([g.sum() for g in grads]) + loss.backward() + test_outputs.extend([maybe_detach_and_to(t.grad, dtype=torch.float32) for t in x]) + return test_outputs + + +def assert_close_modules(m_test, m_ref, inputs_test, procedure, tol_dict): + outs_test = procedure(m_test, inputs_test) + + inputs_ref = [ + x.clone() + .detach() + .to(device="cuda", dtype=torch.float32) + .requires_grad_(x.requires_grad) + for x in inputs_test + ] + outs_ref = procedure(m_ref, inputs_ref) + for out_test, out_ref in zip(outs_test, outs_ref): + torch.testing.assert_close(out_test, out_ref, **tol_dict) + + +tol_dict = { + # we compare against double for precision reasons + # hence FP64 and FP32 threshold are the same + (torch.float64, torch.float64): {"atol": 1e-9, "rtol": 1e-5}, + (torch.float32, torch.float64): {"atol": 1e-4, "rtol": 1e-5}, + (torch.float64, torch.float32): {"atol": 1e-4, "rtol": 1e-5}, + (torch.float32, torch.float32): {"atol": 1e-4, "rtol": 1e-5}, + (torch.bfloat16, torch.float32): {"atol": 4.0, "rtol": 1e-2}, + (torch.float16, torch.float32): {"atol": 0.25, "rtol": 1e-2}, +} diff --git a/docs/api/cuequivariance_torch.rst b/docs/api/cuequivariance_torch.rst index ff33bfa..2785cf6 100644 --- a/docs/api/cuequivariance_torch.rst +++ b/docs/api/cuequivariance_torch.rst @@ -41,12 +41,7 @@ Special Cases of Tensor Products Linear SymmetricContraction TransposeIrrepsLayout - -.. autosummary:: - :toctree: generated/ - :template: function_template.rst - - spherical_harmonics + SphericalHarmonics Euclidean Operations -------------------- diff --git a/docs/conf.py b/docs/conf.py index b00ec7e..94992cb 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -21,8 +21,8 @@ # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information import datetime -import nvidia_sphinx_theme +import nvidia_sphinx_theme # noqa current_year = datetime.datetime.now().year