diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py
index 16701ff..123c044 100644
--- a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py
+++ b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py
@@ -195,7 +195,13 @@ def __init__(
             use_fallback=use_fallback,
         )

-        if any(d.num_operands != e.num_inputs + 1 for d in e.ds):
+        if (
+            len(e.ds) > 1
+            or any(d.num_operands != e.num_inputs + 1 for d in e.ds)
+            or any(
+                d.num_operands == 2 for d in e.ds
+            )  # special case for Spherical Harmonics ls = [1]
+        ):
             if e.num_inputs == 1:
                 self.tp = SymmetricTPDispatcher(
                     cuet.SymmetricTensorProduct(
diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py
index dc61230..0c9821f 100644
--- a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py
+++ b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import math
 from typing import Optional

 import torch
@@ -48,20 +49,6 @@ def __init__(

         self.descriptors = descriptors

-        if any(d.num_operands < 2 for d in descriptors):
-            d0 = next(d for d in descriptors if d.num_operands == 1)
-            descriptors = [d for d in descriptors if d.num_operands >= 2]
-            assert len(descriptors) + 1 == len(self.descriptors)
-            self.f0 = cuet.TensorProduct(
-                d0,
-                device=device,
-                math_dtype=math_dtype,
-                use_fallback=use_fallback,
-                optimize_fallback=optimize_fallback,
-            )
-        else:
-            self.f0 = None
-
         descriptors = [
             stp.SegmentedTensorProduct(
                 operands=[stp.Operand.empty_segments(1)] + d.operands,
@@ -72,13 +59,10 @@ def __init__(
             )
             for d in descriptors
         ]

-        try:
-            d = next(d for d in descriptors if d.num_operands >= 1)
-        except StopIteration:
-            raise ValueError("At least one STP must have at least 2 operands.")
+        d_max = max(descriptors, key=lambda d: d.num_operands)

-        self.x0_size = d.operands[0].size
-        self.x1_size = d.operands[1].size
+        self.x0_size = d_max.operands[0].size
+        self.x1_size = d_max.operands[1].size if d_max.num_operands >= 3 else 1
         self.f = cuet.IWeightedSymmetricTensorProduct(
             descriptors,
@@ -103,14 +87,14 @@ def forward(self, x0: torch.Tensor) -> torch.Tensor:
             The output tensor resulting from the indexed symmetric tensor product operation.
             It will have the shape (batch, x1_size).
""" - out = self.f( + torch._assert( + x0.ndim == 2, f"Expected 2 dims (batch, x0_size), got shape {x0.shape}" + ) + return self.f( torch.ones((1, 1), dtype=x0.dtype, device=x0.device), torch.zeros((x0.shape[0],), dtype=torch.int32, device=x0.device), x0, ) - if self.f0 is not None: - out += self.f0([]) - return out class IWeightedSymmetricTensorProduct(torch.nn.Module): @@ -145,31 +129,26 @@ def __init__( _check_descriptors(descriptors) self.descriptors = descriptors - d = next(d for d in descriptors if d.num_operands >= 3) + d = max(descriptors, key=lambda d: d.num_operands) self.x0_size = d.operands[0].size - self.x1_size = d.operands[1].size + self.x1_size = d.operands[1].size if d.num_operands >= 3 else 1 self.x2_size = d.operands[-1].size self.has_cuda = False - self.f = None - - if use_fallback is None or use_fallback is False: + if use_fallback is False: + self.f = CUDAKernel(descriptors, device, math_dtype) + self.has_cuda = True + elif use_fallback is None: try: self.f = CUDAKernel(descriptors, device, math_dtype) self.has_cuda = True - return except NotImplementedError as e: logger.info(f"Failed to initialize CUDA implementation: {e}") except ImportError as e: logger.warning(f"Failed to initialize CUDA implementation: {e}") - if use_fallback is False and not self.has_cuda: - raise RuntimeError( - "`use_fallback` is `False` and no CUDA kernel is available!" - ) - - if self.f is None: + if not self.has_cuda: self.f = FallbackImpl( descriptors, device, @@ -228,42 +207,40 @@ def _check_descriptors(descriptors: list[stp.SegmentedTensorProduct]): if len(descriptors) == 0: raise ValueError("stps must contain at least one STP.") - try: - d = next(d for d in descriptors if d.num_operands >= 3) - except StopIteration: - raise ValueError("At least one STP must have at least 3 operands.") - - x0 = d.operands[0] - x1 = d.operands[1] - x2 = d.operands[-1] + d_max = max(descriptors, key=lambda d: d.num_operands) + assert d_max.num_operands >= 2 # at least x0 and x2 for d in descriptors: - if d.operands[0].size != x0.size: + if d.operands[0].size != d_max.operands[0].size: raise ValueError("All STPs must have the same first operand (x0).") - if any(ope.size != x1.size for ope in d.operands[1:-1]): + if any(ope.size != d_max.operands[1].size for ope in d.operands[1:-1]): raise ValueError("All STPs must have the operands[1:-1] identical (x1).") - if d.operands[-1].size != x2.size: + if d.operands[-1].size != d_max.operands[-1].size: raise ValueError("All STPs must have the same last operand (x2, output).") class CUDAKernel(torch.nn.Module): def __init__( self, - stps: list[stp.SegmentedTensorProduct], + ds: list[stp.SegmentedTensorProduct], device: Optional[torch.device], math_dtype: torch.dtype, ): super().__init__() - max_degree = max(d.num_operands - 2 for d in stps) + max_degree = max(d.num_operands - 2 for d in ds) + if max_degree > 6: raise NotImplementedError("Correlation > 6 is not implemented.") - if min(d.num_operands for d in stps) == 2: - raise NotImplementedError( - "Only STPs with at least 3 operands are supported." 
-            )
+
+        if len({d.operands[0].num_segments for d in ds}) != 1:
+            raise ValueError("All STPs must have the same number of segments in x0.")
+        if len({ope.num_segments for d in ds for ope in d.operands[1:-1]}) > 1:
+            raise ValueError("All STPs must have the same number of segments in x1.")
+        if len({d.operands[-1].num_segments for d in ds}) != 1:
+            raise ValueError("All STPs must have the same number of segments in x2.")

         def f(d: stp.SegmentedTensorProduct) -> stp.SegmentedTensorProduct:
             d = d.move_operand(0, -2)
@@ -276,44 +253,59 @@ def f(d: stp.SegmentedTensorProduct) -> stp.SegmentedTensorProduct:
                 ]
             )
             d = d.consolidate_modes()
+            if d.subscripts.modes() == []:
+                d = d.append_modes_to_all_operands("u", dict(u=1))

             # ops.SymmetricTensorContraction will "symmetrize" for the derivatives so we can sort for the forward pass
             d = d.sort_indices_for_identical_operands(range(0, d.num_operands - 2))

-            if len(set(ope.subscripts for ope in d.operands)) != 1:
+            if len(d.subscripts.modes()) != 1:
+                raise NotImplementedError("Different modes are not supported.")
+
+            m = d.subscripts.modes()[0]
+
+            if not all(ope.subscripts == m for ope in d.operands):
                 raise NotImplementedError("Different subscripts are not supported.")

-            return d
-
-        ds = [f(d) for d in stps]
-
-        if (
-            len(
-                set(
-                    (
-                        d.operands[0].num_segments,
-                        d.operands[-2].num_segments,
-                        d.operands[-1].num_segments,
-                    )
-                    for d in ds
-                )
-            )
-            != 1
-        ):
-            raise ValueError("All STPs must have the same number of segments.")
+            d = d.split_mode(m, math.gcd(*d.get_dims(m)))
+            return d
+
+        ds_ = [f(d) for d in ds]

         import cuequivariance_ops_torch as ops

+        d_max = max(ds_, key=lambda d: d.num_operands)
+
+        path_segment_indices = sum((d.indices.tolist() for d in ds_), [])
+        path_coefficients = sum((d.stacked_coefficients.tolist() for d in ds_), [])
+        num_in_segments = (
+            d_max.operands[0].num_segments if d_max.num_operands >= 3 else 1
+        )
+        num_couplings = d_max.operands[-2].num_segments
+        num_out_segments = d_max.operands[-1].num_segments
+        correlation = max(1, max_degree)
+        math_dtype = math_dtype
+        logger.debug(f"""cuequivariance_ops_torch.SymmetricTensorContraction(
+            path_segment_indices={path_segment_indices},
+            path_coefficients={path_coefficients},
+            num_in_segments={num_in_segments},
+            num_couplings={num_couplings},
+            num_out_segments={num_out_segments},
+            correlation={correlation},
+            math_dtype={math_dtype},
+        )""")
+
         self.f = ops.SymmetricTensorContraction(
-            sum((d.indices.tolist() for d in ds), []),
-            sum((d.stacked_coefficients.tolist() for d in ds), []),
-            ds[0].operands[0].num_segments,
-            ds[0].operands[-2].num_segments,
-            ds[0].operands[-1].num_segments,
-            max_degree,
-            math_dtype,
+            path_segment_indices=path_segment_indices,
+            path_coefficients=path_coefficients,
+            num_in_segments=num_in_segments,
+            num_couplings=num_couplings,
+            num_out_segments=num_out_segments,
+            correlation=correlation,
+            math_dtype=math_dtype,
         ).to(device=device)
-        self.u = ds[0].operands[0].size // ds[0].operands[0].num_segments
-        self.descriptors = ds
+        self.u = d_max.operands[0].size // d_max.operands[0].num_segments
+        self.descriptors = ds_

     def forward(
         self, x0: torch.Tensor, i0: torch.Tensor, x1: torch.Tensor
@@ -324,6 +316,10 @@ def forward(

         x_2[j_{n+1}] = val x_0[i_0][j_0] \prod_{k=1}^{n} x_1[j_k]
         """
+        torch._assert(x0.ndim == 2, f"Expected shape (num_x0, x0_size), got {x0.shape}")
+        torch._assert(x1.ndim == 2, f"Expected shape (batch, x1_size), got {x1.shape}")
+        torch._assert(i0.ndim == 1, f"Expected shape (batch,), got {i0.shape}")
+
         i0 = i0.to(torch.int32)
         x0 = x0.reshape(x0.shape[0], x0.shape[1] // self.u, self.u)
         x1 = x1.reshape(x1.shape[0], x1.shape[1] // self.u, self.u)
@@ -331,7 +327,7 @@

         logger.debug(
             f"Calling SymmetricTensorContraction: {self.descriptors}, input shapes: {x0.shape}, {i0.shape}, {x1.shape}"
         )
-        out = self.f(x1, x0, i0)
+        out: torch.Tensor = self.f(x1, x0, i0)
         out = out.reshape(out.shape[0], out.shape[1] * self.u)
         return out
diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py
index b436161..91af5b0 100644
--- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py
+++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py
@@ -101,11 +101,14 @@ def __init__(
         self.descriptor = descriptor
         if math_dtype is None:
             math_dtype = torch.get_default_dtype()
-        self.f = None
         self.has_cuda = False
         self.num_operands = descriptor.num_operands

-        if use_fallback is None or use_fallback is False:
+        if use_fallback is False:
+            self.f = _tensor_product_cuda(descriptor, device, math_dtype)
+            self.has_cuda = True
+        elif use_fallback is None:
             try:
                 self.f = _tensor_product_cuda(descriptor, device, math_dtype)
                 self.has_cuda = True
@@ -120,12 +123,7 @@ def __init__(
                     "pip install cuequivariance-ops-torch-cu12"
                 )

-        if use_fallback is False and not self.has_cuda:
-            raise RuntimeError(
-                "`use_fallback` is `False` and no CUDA kernel is available!"
-            )
-
-        if self.f is None:
+        if not self.has_cuda:
             if optimize_fallback is None:
                 optimize_fallback = False
                 warnings.warn(
diff --git a/cuequivariance_torch/tests/operations/rotation_test.py b/cuequivariance_torch/tests/operations/rotation_test.py
index 86d0230..cc70899 100644
--- a/cuequivariance_torch/tests/operations/rotation_test.py
+++ b/cuequivariance_torch/tests/operations/rotation_test.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import pytest
 import torch

 import cuequivariance as cue
@@ -47,11 +48,15 @@ def test_vector_to_euler_angles():
     assert torch.allclose(A, B)


-def test_inversion():
+@pytest.mark.parametrize("use_fallback", [False, True])
+def test_inversion(use_fallback: bool):
+    if use_fallback is False and not torch.cuda.is_available():
+        pytest.skip("CUDA is not available")
+
     irreps = cue.Irreps("O3", "2x1e + 1o")
     torch.testing.assert_close(
-        cuet.Inversion(irreps, layout=cue.ir_mul)(
-            torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
-        ),
-        torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0]),
+        cuet.Inversion(
+            irreps, layout=cue.ir_mul, device=device, use_fallback=use_fallback
+        )(torch.tensor([[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]], device=device)),
+        torch.tensor([[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0]], device=device),
     )
diff --git a/cuequivariance_torch/tests/operations/spherical_harmonics_test.py b/cuequivariance_torch/tests/operations/spherical_harmonics_test.py
index 955ee87..2b8db35 100644
--- a/cuequivariance_torch/tests/operations/spherical_harmonics_test.py
+++ b/cuequivariance_torch/tests/operations/spherical_harmonics_test.py
@@ -26,14 +26,18 @@
     "dtype, tol",
     [(torch.float64, 1e-6), (torch.float32, 1e-4)],
 )
-@pytest.mark.parametrize("ell", [1, 2, 3])
-def test_spherical_harmonics(ell: int, dtype, tol):
+@pytest.mark.parametrize("ell", [0, 1, 2, 3])
+@pytest.mark.parametrize("use_fallback", [False, True])
+def test_spherical_harmonics_equivariance(use_fallback: bool, ell: int, dtype, tol):
+    if use_fallback is False and not torch.cuda.is_available():
+        pytest.skip("CUDA is not available")
+
     vec = torch.randn(3, dtype=dtype, device=device)
     axis = np.random.randn(3)
     angle = np.random.rand()
     scale = 1.3

-    yl = cuet.spherical_harmonics([ell], vec, False)
+    yl = cuet.spherical_harmonics([ell], vec, False, use_fallback=use_fallback)

     R = torch.from_numpy(cue.SO3(1).rotation(axis, angle)).to(dtype).to(device)
     Rl = torch.from_numpy(cue.SO3(ell).rotation(axis, angle)).to(dtype).to(device)
@@ -44,9 +48,19 @@
     torch.testing.assert_close(yl1, yl2, rtol=tol, atol=tol)


-def test_spherical_harmonics_full():
-    vec = torch.randn(3, device=device)
-    ls = [0, 1, 2, 3]
-    yl = cuet.spherical_harmonics(ls, vec, False)
+data_types = [torch.float32, torch.float64]
+
+if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
+    data_types += [torch.float16, torch.bfloat16]
+
+
+@pytest.mark.parametrize("dtype", data_types)
+@pytest.mark.parametrize("ls", [[0], [1], [2], [0, 1], [0, 1, 2]])
+@pytest.mark.parametrize("use_fallback", [False, True])
+def test_spherical_harmonics_full(dtype, ls: list[int], use_fallback: bool):
+    if use_fallback is False and not torch.cuda.is_available():
+        pytest.skip("CUDA is not available")

-    assert abs(yl[0] - 1.0) < 1e-6
+    vec = torch.randn(3, device=device, dtype=dtype)
+    yl = cuet.spherical_harmonics(ls, vec, False, use_fallback=use_fallback)
+    assert yl.shape[-1] == sum(2 * ell + 1 for ell in ls)