Working torch.jit.script() and torch.compile() support (#44)
* test and quick fix for zero batch

* trigger uniform 1d in test

* satisfy linter

Signed-off-by: Mario Geiger <[email protected]>

* from typing import

* determine math_dtype earlier

* warning with pip commands

* remove unused argument

* changelog

* list of inputs

* add Fixed subtitle

* changelog

* add test for torch.jit.script

* fix

* remove keyword-only and import in the forward

* low lvl script tests

* TensorProduct working with script()

Signed-off-by: Boris Fomitchev <[email protected]>

* add 4 operands tests

* Unit tests run

Signed-off-by: Boris Fomitchev <[email protected]>

* Restoring debug logging

Signed-off-by: Boris Fomitchev <[email protected]>

* Parameterized script test

Signed-off-by: Boris Fomitchev <[email protected]>

* Fixed transpose for script(), script_test successful

Signed-off-by: Boris Fomitchev <[email protected]>

* Fixed input mutation

Signed-off-by: Boris Fomitchev <[email protected]>

* Fixed tests

Signed-off-by: Boris Fomitchev <[email protected]>

* format with black

* format with black

* fix tests

* fix missing parenthesis

* fix tests: increase torch._dynamo.config.cache_size_limit

* fix docstring tests

* replace == by is

* clean use_fallback conditions

* fix

* fix

* Export test added, scripting fallback attempt

Signed-off-by: Boris Fomitchev <[email protected]>

* enable tests on cpu

* fix tests

* fix ruff

* fix

* fix docstring tests

* add -x to tests

* Working around torch_tensorrt bugs

Signed-off-by: Boris Fomitchev <[email protected]>

* Fixing utils.py import

Signed-off-by: Boris Fomitchev <[email protected]>

* Adding utils.py

Signed-off-by: Boris Fomitchev <[email protected]>

* Style

Signed-off-by: Boris Fomitchev <[email protected]>

* import nvidia_sphinx_theme

* spherical harmonics module

* fix tests

* test SymmetricContraction export

* Fixed symmetric_contraction test

Signed-off-by: Boris Fomitchev <[email protected]>

* add device info

* fix sh

* fix

* skip

* torch._dynamo.config.cache_size_limit = 100

* fix test

* Script compatibility for fallback

Signed-off-by: Boris Fomitchev <[email protected]>

* style

Signed-off-by: Boris Fomitchev <[email protected]>

* Trying to make trace() work

Signed-off-by: Boris Fomitchev <[email protected]>

* Restoring integer cast

Signed-off-by: Boris Fomitchev <[email protected]>

* Skipping failing tests

Signed-off-by: Boris Fomitchev <[email protected]>

* disabling cast for fallback

Signed-off-by: Boris Fomitchev <[email protected]>

* optimize_fallback=use_fallback

* Fixing the reinterpret cast

Signed-off-by: Boris Fomitchev <[email protected]>

* Fixing clone()

Signed-off-by: Boris Fomitchev <[email protected]>

* delete broadcast_shapes

* delete _reshape

* rename

* Using alternative disable type change fixture

Signed-off-by: Boris Fomitchev <[email protected]>

* Restored assert

Signed-off-by: Boris Fomitchev <[email protected]>

* try fix test

* simplify symmetric_tensor_product_test to make test run faster

* try to fix some tests

* Fixing disable_type_conv

Signed-off-by: Boris Fomitchev <[email protected]>

* try fix

* fix strange bug

* Script fixes for uniform

Signed-off-by: Boris Fomitchev <[email protected]>

* add test_script_tensor_product

* Moving all export tests, disabling torch_trt for now

Signed-off-by: Boris Fomitchev <[email protected]>

* more strict input shapes

* add back @pytest.mark.parametrize("mode", export_modes)

* fix

* Fixing noconv bug

Signed-off-by: Boris Fomitchev <[email protected]>

* Really fixing noconv

Signed-off-by: Boris Fomitchev <[email protected]>

* fix linear

* fix rotations

* fix tpfc

* fix tpcw

* less test

* remove unused mode in tensor_product_test

* typo

* disable export

* Reduced export test modes list

Signed-off-by: Boris Fomitchev <[email protected]>

* Added unit tests to operations and the rest of the primitives

Signed-off-by: Boris Fomitchev <[email protected]>

* Fixing script() for non-internal weights

Signed-off-by: Boris Fomitchev <[email protected]>

* skip GPU tests when running on CPU

* Fix: if use_fallback is None and cuda is not available => use fallback

---------

Signed-off-by: Mario Geiger <[email protected]>
Signed-off-by: Boris Fomitchev <[email protected]>
Co-authored-by: Mario Geiger <[email protected]>
Co-authored-by: Mario Geiger <[email protected]>
3 people authored Jan 6, 2025
1 parent 163f5f1 commit dc2b096
Showing 38 changed files with 994 additions and 685 deletions.
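
To orient the reader, here is a minimal sketch of the kind of usage this merge is meant to enable. It is illustrative only: the irreps string and the `layout`/`dtype` keyword arguments are assumptions not taken from this diff, and the `cache_size_limit` line mirrors the test-suite tweak mentioned in the log above.

```python
import torch
import cuequivariance as cue
import cuequivariance_torch as cuet

# Mirrors the test-suite change noted in the log; avoids Dynamo recompilation limits.
torch._dynamo.config.cache_size_limit = 100

irreps = cue.Irreps("O3", "8x0e + 8x1o")  # illustrative irreps
linear = cuet.Linear(irreps, irreps, layout=cue.mul_ir, dtype=torch.float32)

x = torch.randn(16, irreps.dim)
scripted = torch.jit.script(linear)   # TorchScript path exercised by the new tests
compiled = torch.compile(linear)      # torch.compile path exercised by the new tests

assert scripted(x).shape == (16, irreps.dim)
assert compiled(x).shape == (16, irreps.dim)
```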
2 changes: 1 addition & 1 deletion cuequivariance/cuequivariance/irreps_array/irreps.py
@@ -16,7 +16,7 @@

 import dataclasses
 import re
-from typing import NamedTuple, Union, Type, Any, Sequence, Callable, Optional
+from typing import Any, Callable, NamedTuple, Optional, Sequence, Type, Union

 import cuequivariance as cue

2 changes: 1 addition & 1 deletion cuequivariance/cuequivariance/irreps_array/misc_ui.py
@@ -15,7 +15,7 @@
 from __future__ import annotations

 import warnings
-from typing import Generator, Optional, Union, Any
+from typing import Any, Generator, Optional, Union

 import cuequivariance as cue

@@ -16,7 +16,7 @@
 import itertools
 import logging
 from math import prod
-from typing import FrozenSet, List, Optional, Sequence, Tuple, Iterator, Union
+from typing import FrozenSet, Iterator, List, Optional, Sequence, Tuple, Union

 import numpy as np

2 changes: 1 addition & 1 deletion cuequivariance/cuequivariance/misc/sympy_utils.py
@@ -13,8 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from fractions import Fraction
-import numpy as np

+import numpy as np
 import sympy


2 changes: 1 addition & 1 deletion cuequivariance/cuequivariance/representation/irrep_so3.py
@@ -22,7 +22,7 @@
 import numpy as np

 from cuequivariance.misc.linalg import round_to_sqrt_rational
-from cuequivariance.representation import Irrep, SU2
+from cuequivariance.representation import SU2, Irrep


 # This function is copied from https://github.com/lie-nn/lie-nn/blob/70adebce44e3197ee17f780585c6570d836fc2fe/lie_nn/_src/irreps/so3_real.py
1 change: 1 addition & 0 deletions cuequivariance/cuequivariance/representation/irrep_su2.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from __future__ import annotations
+
 import itertools
 import re
 from dataclasses import dataclass
@@ -16,7 +16,7 @@

 import dataclasses
 import math
-from typing import Optional, Union, Sequence
+from typing import Optional, Sequence, Union

 from cuequivariance import segmented_tensor_product as stp

@@ -25,7 +25,7 @@
 import math
 import re
 import zlib
-from typing import Any, Optional, Union, Callable, Sequence
+from typing import Any, Callable, Optional, Sequence, Union

 import numpy as np
 import opt_einsum
2 changes: 1 addition & 1 deletion cuequivariance/tests/context_test.py
@@ -12,10 +12,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import numpy as np
 import pytest

 import cuequivariance as cue
-import numpy as np


 def test_rep_collection_context():
2 changes: 1 addition & 1 deletion cuequivariance/tests/equivariant_tensor_products_test.py
@@ -16,8 +16,8 @@
 import pytest

 import cuequivariance as cue
-from cuequivariance import descriptors
 import cuequivariance.segmented_tensor_product as stp
+from cuequivariance import descriptors


 def test_commutativity_squeeze_flatten():
4 changes: 2 additions & 2 deletions cuequivariance_torch/cuequivariance_torch/__init__.py
@@ -36,7 +36,7 @@
     vector_to_euler_angles,
     Inversion,
 )
-from .operations.spherical_harmonics import spherical_harmonics
+from .operations.spherical_harmonics import SphericalHarmonics

 from cuequivariance_torch import layers

@@ -55,6 +55,6 @@
     "Inversion",
     "encode_rotation_angle",
     "vector_to_euler_angles",
-    "spherical_harmonics",
+    "SphericalHarmonics",
     "layers",
 ]
@@ -59,7 +59,6 @@ class FullyConnectedTensorProductConv(nn.Module):
         use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available.
             If `False`, a CUDA kernel will be used, and an exception is raised if it's not available.
             If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability.
-        optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None.

     Examples:
         >>> in_irreps = cue.Irreps("O3", "4x0e + 4x1o")
@@ -106,7 +105,6 @@ def __init__(
         mlp_activation: Union[nn.Module, Sequence[nn.Module], None] = nn.GELU(),
         layout: cue.IrrepsLayout = None,  # e3nn_compat_mode
         use_fallback: Optional[bool] = None,
-        optimize_fallback: Optional[bool] = None,
     ):
         super().__init__()

@@ -127,7 +125,6 @@
             layout=self.layout,
             shared_weights=False,
             use_fallback=use_fallback,
-            optimize_fallback=optimize_fallback,
         )

         self.batch_norm = (
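
The `use_fallback` docstring above is repeated on every module touched by this PR. As a plain-Python restatement of that rule, together with the fix in the last log entry ("if use_fallback is None and cuda is not available => use fallback"), here is a hypothetical helper; it is not part of the library, only an illustration of the documented dispatch:

```python
from typing import Optional

import torch


def should_use_fallback(use_fallback: Optional[bool], kernel_available: bool) -> bool:
    """Hypothetical restatement of the documented use_fallback dispatch rule."""
    if use_fallback is None:
        # Default: prefer the CUDA kernel; fall back when it (or CUDA itself) is unavailable.
        return not (kernel_available and torch.cuda.is_available())
    if use_fallback is False:
        # Explicitly requesting the CUDA kernel: fail loudly if it cannot be used.
        if not kernel_available:
            raise RuntimeError("CUDA kernel requested (use_fallback=False) but not available")
        return False
    return True  # use_fallback is True: always use the PyTorch fallback
```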
11 changes: 3 additions & 8 deletions cuequivariance_torch/cuequivariance_torch/operations/linear.py
@@ -35,7 +35,6 @@ class Linear(torch.nn.Module):
         use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available.
             If `False`, a CUDA kernel will be used, and an exception is raised if it's not available.
             If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability.
-        optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None.
     """

     def __init__(
@@ -52,7 +51,6 @@ def __init__(
         dtype: Optional[torch.dtype] = None,
         math_dtype: Optional[torch.dtype] = None,
         use_fallback: Optional[bool] = None,
-        optimize_fallback: Optional[bool] = None,
     ):
         super().__init__()
         irreps_in, irreps_out = default_irreps(irreps_in, irreps_out)
@@ -77,7 +75,7 @@ def __init__(
             if not self.shared_weights:
                 raise ValueError("Internal weights should be shared")
             self.weight = torch.nn.Parameter(
-                torch.randn(self.weight_numel, device=device, dtype=dtype)
+                torch.randn(1, self.weight_numel, device=device, dtype=dtype)
             )
         else:
             self.weight = None
@@ -90,7 +88,6 @@ def __init__(
             device=device,
             math_dtype=math_dtype,
             use_fallback=use_fallback,
-            optimize_fallback=optimize_fallback,
         )

     def extra_repr(self) -> str:
@@ -122,9 +119,7 @@ def forward(

         weight = self.weight

-        if self.shared_weights and weight.ndim != 1:
-            raise ValueError("Shared weights should be 1D tensor")
-        if not self.shared_weights and weight.ndim != 2:
-            raise ValueError("Weights should be 2D tensor")
+        if weight is None:
+            raise ValueError("Weights should not be None")

         return self.f([weight, x])
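
One consequence of the `linear.py` change above: shared internal weights are now stored as a 2D tensor with a leading dimension of 1 rather than a flat 1D tensor, so every weight tensor handed to `self.f` is two-dimensional. A minimal check, with constructor keywords beyond those visible in the hunk treated as assumptions:

```python
import torch
import cuequivariance as cue
import cuequivariance_torch as cuet

linear = cuet.Linear(
    cue.Irreps("O3", "4x0e + 4x1o"),
    cue.Irreps("O3", "4x0e + 4x1o"),
    layout=cue.mul_ir,   # assumed keyword, not shown in this hunk
    dtype=torch.float32,
)

# torch.randn(1, self.weight_numel, ...) in the diff above implies:
assert linear.weight.ndim == 2 and linear.weight.shape[0] == 1
```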
@@ -41,7 +41,6 @@
         device: Optional[torch.device] = None,
         math_dtype: Optional[torch.dtype] = None,
         use_fallback: Optional[bool] = None,
-        optimize_fallback: Optional[bool] = None,
     ):
         super().__init__()
         (irreps,) = default_irreps(irreps)
@@ -62,7 +61,6 @@
             device=device,
             math_dtype=math_dtype,
             use_fallback=use_fallback,
-            optimize_fallback=optimize_fallback,
         )

     def forward(
@@ -158,8 +156,6 @@ class Inversion(torch.nn.Module):
         use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available.
             If `False`, a CUDA kernel will be used, and an exception is raised if it's not available.
             If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability.
-        optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None.
-
     """

     def __init__(
@@ -172,7 +168,6 @@
         device: Optional[torch.device] = None,
         math_dtype: Optional[torch.dtype] = None,
         use_fallback: Optional[bool] = None,
-        optimize_fallback: Optional[bool] = None,
     ):
         super().__init__()
         (irreps,) = default_irreps(irreps)
@@ -191,7 +186,6 @@
             device=device,
             math_dtype=math_dtype,
             use_fallback=use_fallback,
-            optimize_fallback=optimize_fallback,
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -15,52 +15,57 @@
 from typing import Optional

 import torch
+import torch.nn as nn

 import cuequivariance as cue
 import cuequivariance_torch as cuet
 from cuequivariance import descriptors


-def spherical_harmonics(
-    ls: list[int],
-    vectors: torch.Tensor,
-    normalize: bool = True,
-    use_fallback: Optional[bool] = None,
-    optimize_fallback: Optional[bool] = None,
-) -> torch.Tensor:
-    r"""Compute the spherical harmonics of the input vectors.
-
-    Args:
-        ls (list of int): List of spherical harmonic degrees.
-        vectors (torch.Tensor): Input vectors of shape (..., 3).
-        normalize (bool, optional): Whether to normalize the input vectors. Defaults to True.
-        use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available.
-            If `False`, a CUDA kernel will be used, and an exception is raised if it's not available.
-            If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability.
-
-        optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None.
-
-    Returns:
-        torch.Tensor: The spherical harmonics of the input vectors of shape (..., dim)
-        where dim is the sum of 2*l+1 for l in ls.
-    """
-    if isinstance(ls, int):
-        ls = [ls]
-    assert ls == sorted(set(ls))
-    assert vectors.shape[-1] == 3
-
-    if normalize:
-        vectors = torch.nn.functional.normalize(vectors, dim=-1)
-
-    x = vectors.reshape(-1, 3)
-    m = cuet.EquivariantTensorProduct(
-        descriptors.spherical_harmonics(cue.SO3(1), ls),
-        layout=cue.ir_mul,
-        device=x.device,
-        math_dtype=x.dtype,
-        use_fallback=use_fallback,
-        optimize_fallback=optimize_fallback,
-    )
-    y = m([x])
-    y = y.reshape(vectors.shape[:-1] + (y.shape[-1],))
-    return y
+class SphericalHarmonics(nn.Module):
+    r"""Compute the spherical harmonics of the input vectors as a torch module."""
+
+    def __init__(
+        self,
+        ls: list[int],
+        normalize: bool = True,
+        device: Optional[torch.device] = None,
+        math_dtype: Optional[torch.dtype] = None,
+        use_fallback: Optional[bool] = None,
+    ):
+        """
+        Args:
+            ls (list of int): List of spherical harmonic degrees.
+            normalize (bool, optional): Whether to normalize the input vectors. Defaults to True.
+            use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available.
+                If `False`, a CUDA kernel will be used, and an exception is raised if it's not available.
+                If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability.
+        """
+        super().__init__()
+        self.ls = ls if isinstance(ls, list) else [ls]
+        assert self.ls == sorted(set(self.ls))
+        self.normalize = normalize
+
+        self.f = cuet.EquivariantTensorProduct(
+            descriptors.spherical_harmonics(cue.SO3(1), self.ls),
+            layout=cue.ir_mul,
+            device=device,
+            math_dtype=math_dtype,
+            use_fallback=use_fallback,
+        )
+
+    def forward(self, vectors: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            vectors (torch.Tensor): Input vectors of shape (batch, 3).
+
+        Returns:
+            torch.Tensor: The spherical harmonics of the input vectors of shape (batch, dim),
+            where dim is the sum of 2*l+1 for l in ls.
+        """
+        torch._assert(vectors.ndim == 2, "Input must have shape (batch, 3)")
+
+        if self.normalize:
+            vectors = torch.nn.functional.normalize(vectors, dim=1)
+
+        return self.f([vectors])
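
For anyone migrating from the removed `spherical_harmonics` function, a short usage sketch of the new module based on the code above (the degree list and batch size are arbitrary, and the `script()` call reflects the purpose of this PR rather than a documented guarantee):

```python
import torch
import cuequivariance_torch as cuet

sh = cuet.SphericalHarmonics([0, 1, 2], normalize=True)

vectors = torch.randn(16, 3)   # inputs must now be 2D: (batch, 3)
y = sh(vectors)                # shape (16, 1 + 3 + 5) = (16, 9)
assert y.shape == (16, 9)

scripted = torch.jit.script(sh)
assert scripted(vectors).shape == (16, 9)
```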
@@ -41,18 +41,17 @@ class SymmetricContraction(torch.nn.Module):
         use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available.
             If `False`, a CUDA kernel will be used, and an exception is raised if it's not available.
             If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability.
-        optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None.

     Examples:
+        >>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
         >>> irreps_in = cue.Irreps("O3", "32x0e + 32x1o")
         >>> irreps_out = cue.Irreps("O3", "32x0e")
-        >>> layer = SymmetricContraction(irreps_in, irreps_out, contraction_degree=3, num_elements=5, layout=cue.ir_mul, dtype=torch.float32)
+        >>> layer = SymmetricContraction(irreps_in, irreps_out, contraction_degree=3, num_elements=5, layout=cue.ir_mul, dtype=torch.float32, device=device)

     Now `layer` can be used as part of a PyTorch model.

     The argument `original_mace` can be set to `True` to emulate the original MACE implementation.

-        >>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
         >>> feats_irreps = cue.Irreps("O3", "32x0e + 32x1o + 32x2e")
         >>> target_irreps = cue.Irreps("O3", "32x0e + 32x1o")
         >>> # OLD FUNCTION DEFINITION:
@@ -109,7 +108,6 @@ def __init__(
         math_dtype: Optional[torch.dtype] = None,
         original_mace: bool = False,
         use_fallback: Optional[bool] = None,
-        optimize_fallback: Optional[bool] = None,
     ):
         super().__init__()

@@ -155,7 +153,6 @@ def __init__(
             device=device,
             math_dtype=math_dtype or dtype,
             use_fallback=use_fallback,
-            optimize_fallback=optimize_fallback,
         )

     def extra_repr(self) -> str:
@@ -180,10 +177,6 @@ def forward(
         Returns:
             torch.Tensor: The output tensor. It has shape (batch, irreps_out.dim).
         """
-        torch._assert(
-            x.shape[-1] == self.irreps_in.dim,
-            f"Input tensor must have shape (..., {self.irreps_in.dim}), got {x.shape}",
-        )

         if self.projection is not None:
             weight = torch.einsum("zau,ab->zbu", self.weight, self.projection)
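
The log above mentions a `SymmetricContraction` export test. The sketch below simply reuses the docstring example from this file and scripts the layer; the forward call is omitted because its signature is not shown in this diff, and `cuet.SymmetricContraction` as the import path is an assumption:

```python
import torch
import cuequivariance as cue
import cuequivariance_torch as cuet

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
irreps_in = cue.Irreps("O3", "32x0e + 32x1o")
irreps_out = cue.Irreps("O3", "32x0e")

layer = cuet.SymmetricContraction(
    irreps_in,
    irreps_out,
    contraction_degree=3,
    num_elements=5,
    layout=cue.ir_mul,
    dtype=torch.float32,
    device=device,
)

scripted = torch.jit.script(layer)  # TorchScript compilation of the module
```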