fix sh l=0 #49

Merged: 4 commits, Dec 19, 2024
Changes from all commits
@@ -195,7 +195,13 @@ def __init__(
use_fallback=use_fallback,
)

if any(d.num_operands != e.num_inputs + 1 for d in e.ds):
if (
len(e.ds) > 1
or any(d.num_operands != e.num_inputs + 1 for d in e.ds)
or any(
d.num_operands == 2 for d in e.ds
) # special case for Spherical Harmonics ls = [1]
):
if e.num_inputs == 1:
self.tp = SymmetricTPDispatcher(
cuet.SymmetricTensorProduct(
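
With the widened condition, a single spherical-harmonics descriptor with only two operands (one input plus the output, as for ls = [1]) and the pure l = 0 case are both routed through SymmetricTPDispatcher instead of falling through to the pairwise path. A minimal sketch of the user-facing calls this enables, mirroring the tests added further down in this PR (fallback path, so no GPU is required):

import torch
import cuequivariance_torch as cuet

# Illustrative only: degree-0 and degree-1 spherical harmonics, which this PR
# makes dispatchable; the trailing dimension is 2*ell + 1.
vec = torch.randn(3)
y0 = cuet.spherical_harmonics([0], vec, False, use_fallback=True)
y1 = cuet.spherical_harmonics([1], vec, False, use_fallback=True)
print(y0.shape[-1], y1.shape[-1])  # 1 3
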
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
from typing import Optional

import torch
@@ -48,20 +49,6 @@ def __init__(

self.descriptors = descriptors

if any(d.num_operands < 2 for d in descriptors):
d0 = next(d for d in descriptors if d.num_operands == 1)
descriptors = [d for d in descriptors if d.num_operands >= 2]
assert len(descriptors) + 1 == len(self.descriptors)
self.f0 = cuet.TensorProduct(
d0,
device=device,
math_dtype=math_dtype,
use_fallback=use_fallback,
optimize_fallback=optimize_fallback,
)
else:
self.f0 = None

descriptors = [
stp.SegmentedTensorProduct(
operands=[stp.Operand.empty_segments(1)] + d.operands,
@@ -72,13 +59,10 @@ def __init__(
)
for d in descriptors
]
try:
d = next(d for d in descriptors if d.num_operands >= 1)
except StopIteration:
raise ValueError("At least one STP must have at least 2 operands.")
d_max = max(descriptors, key=lambda d: d.num_operands)

self.x0_size = d.operands[0].size
self.x1_size = d.operands[1].size
self.x0_size = d_max.operands[0].size
self.x1_size = d_max.operands[1].size if d_max.num_operands >= 3 else 1

self.f = cuet.IWeightedSymmetricTensorProduct(
descriptors,
@@ -103,14 +87,14 @@ def forward(self, x0: torch.Tensor) -> torch.Tensor:
The output tensor resulting from the indexed symmetric tensor product operation.
It will have the shape (batch, x1_size).
"""
out = self.f(
torch._assert(
x0.ndim == 2, f"Expected 2 dims (batch, x0_size), got shape {x0.shape}"
)
return self.f(
torch.ones((1, 1), dtype=x0.dtype, device=x0.device),
torch.zeros((x0.shape[0],), dtype=torch.int32, device=x0.device),
x0,
)
if self.f0 is not None:
out += self.f0([])
return out
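
The ones/zeros arguments above express the un-indexed product as a special case of the indexed one: a single constant row of ones plays the role of x0, and an all-zero index selects that row for every batch element. A standalone sketch of that gather semantics (plain PyTorch, not library code):

import torch

# Standalone illustration of the trick in forward() above.
batch = 4
x0 = torch.ones((1, 1))                        # one constant "weight" row
i0 = torch.zeros((batch,), dtype=torch.int32)  # every sample selects row 0
print(torch.index_select(x0, 0, i0).shape)     # torch.Size([4, 1])
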


class IWeightedSymmetricTensorProduct(torch.nn.Module):
@@ -145,31 +129,26 @@ def __init__(
_check_descriptors(descriptors)
self.descriptors = descriptors

d = next(d for d in descriptors if d.num_operands >= 3)
d = max(descriptors, key=lambda d: d.num_operands)
self.x0_size = d.operands[0].size
self.x1_size = d.operands[1].size
self.x1_size = d.operands[1].size if d.num_operands >= 3 else 1
self.x2_size = d.operands[-1].size

self.has_cuda = False

self.f = None

if use_fallback is None or use_fallback is False:
if use_fallback is False:
self.f = CUDAKernel(descriptors, device, math_dtype)
self.has_cuda = True
elif use_fallback is None:
try:
self.f = CUDAKernel(descriptors, device, math_dtype)
self.has_cuda = True
return
except NotImplementedError as e:
logger.info(f"Failed to initialize CUDA implementation: {e}")
except ImportError as e:
logger.warning(f"Failed to initialize CUDA implementation: {e}")

if use_fallback is False and not self.has_cuda:
raise RuntimeError(
"`use_fallback` is `False` and no CUDA kernel is available!"
)

if self.f is None:
if not self.has_cuda:
self.f = FallbackImpl(
descriptors,
device,
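
The restructured branches above (and the matching change in tensor_product.py further down) implement a three-way contract for use_fallback. A paraphrase of that control flow as a standalone sketch, not library code:

# use_fallback=False -> the CUDA kernel is mandatory; construction fails if it
#                       cannot be built.
# use_fallback=None  -> try the CUDA kernel first, fall back on failure.
# use_fallback=True  -> (implicit else) use the PyTorch fallback implementation.
def select_impl(build_cuda, build_fallback, use_fallback):
    if use_fallback is False:
        return build_cuda()        # any error propagates to the caller
    if use_fallback is None:
        try:
            return build_cuda()
        except (NotImplementedError, ImportError):
            pass                   # logged in the real code, then fall through
    return build_fallback()
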
@@ -228,42 +207,40 @@ def _check_descriptors(descriptors: list[stp.SegmentedTensorProduct]):
if len(descriptors) == 0:
raise ValueError("stps must contain at least one STP.")

try:
d = next(d for d in descriptors if d.num_operands >= 3)
except StopIteration:
raise ValueError("At least one STP must have at least 3 operands.")

x0 = d.operands[0]
x1 = d.operands[1]
x2 = d.operands[-1]
d_max = max(descriptors, key=lambda d: d.num_operands)
assert d_max.num_operands >= 2 # at least x0 and x2

for d in descriptors:
if d.operands[0].size != x0.size:
if d.operands[0].size != d_max.operands[0].size:
raise ValueError("All STPs must have the same first operand (x0).")

if any(ope.size != x1.size for ope in d.operands[1:-1]):
if any(ope.size != d_max.operands[1].size for ope in d.operands[1:-1]):
raise ValueError("All STPs must have the operands[1:-1] identical (x1).")

if d.operands[-1].size != x2.size:
if d.operands[-1].size != d_max.operands[-1].size:
raise ValueError("All STPs must have the same last operand (x2, output).")


class CUDAKernel(torch.nn.Module):
def __init__(
self,
stps: list[stp.SegmentedTensorProduct],
ds: list[stp.SegmentedTensorProduct],
device: Optional[torch.device],
math_dtype: torch.dtype,
):
super().__init__()

max_degree = max(d.num_operands - 2 for d in stps)
max_degree = max(d.num_operands - 2 for d in ds)

if max_degree > 6:
raise NotImplementedError("Correlation > 6 is not implemented.")
if min(d.num_operands for d in stps) == 2:
raise NotImplementedError(
"Only STPs with at least 3 operands are supported."
)

if len({d.operands[0].num_segments for d in ds}) != 1:
raise ValueError("All STPs must have the same number of segments in x0.")
if len({ope.num_segments for d in ds for ope in d.operands[1:-1]}) > 1:
raise ValueError("All STPs must have the same number of segments in x1.")
if len({d.operands[-1].num_segments for d in ds}) != 1:
raise ValueError("All STPs must have the same number of segments in x2.")

def f(d: stp.SegmentedTensorProduct) -> stp.SegmentedTensorProduct:
d = d.move_operand(0, -2)
@@ -276,44 +253,59 @@ def f(d: stp.SegmentedTensorProduct) -> stp.SegmentedTensorProduct:
]
)
d = d.consolidate_modes()
if d.subscripts.modes() == []:
d = d.append_modes_to_all_operands("u", dict(u=1))

# ops.SymmetricTensorContraction will "symmetrize" for the derivatives so we can sort for the forward pass
d = d.sort_indices_for_identical_operands(range(0, d.num_operands - 2))

if len(set(ope.subscripts for ope in d.operands)) != 1:
if len(d.subscripts.modes()) != 1:
raise NotImplementedError("Different modes are not supported.")

m = d.subscripts.modes()[0]

if not all(ope.subscripts == m for ope in d.operands):
raise NotImplementedError("Different subscripts are not supported.")
return d

ds = [f(d) for d in stps]

if (
len(
set(
(
d.operands[0].num_segments,
d.operands[-2].num_segments,
d.operands[-1].num_segments,
)
for d in ds
)
)
!= 1
):
raise ValueError("All STPs must have the same number of segments.")
d = d.split_mode(m, math.gcd(*d.get_dims(m)))

return d

ds_ = [f(d) for d in ds]
import cuequivariance_ops_torch as ops

d_max = max(ds_, key=lambda d: d.num_operands)

path_segment_indices = sum((d.indices.tolist() for d in ds_), [])
path_coefficients = sum((d.stacked_coefficients.tolist() for d in ds_), [])
num_in_segments = (
d_max.operands[0].num_segments if d_max.num_operands >= 3 else 1
)
num_couplings = d_max.operands[-2].num_segments
num_out_segments = d_max.operands[-1].num_segments
correlation = max(1, max_degree)
math_dtype = math_dtype
logger.debug(f"""cuequivariance_ops_torch.SymmetricTensorContraction(
path_segment_indices={path_segment_indices},
path_coefficients={path_coefficients},
num_in_segments={num_in_segments},
num_couplings={num_couplings},
num_out_segments={num_out_segments},
correlation={correlation},
math_dtype={math_dtype},
)""")

self.f = ops.SymmetricTensorContraction(
sum((d.indices.tolist() for d in ds), []),
sum((d.stacked_coefficients.tolist() for d in ds), []),
ds[0].operands[0].num_segments,
ds[0].operands[-2].num_segments,
ds[0].operands[-1].num_segments,
max_degree,
math_dtype,
path_segment_indices=path_segment_indices,
path_coefficients=path_coefficients,
num_in_segments=num_in_segments,
num_couplings=num_couplings,
num_out_segments=num_out_segments,
correlation=correlation,
math_dtype=math_dtype,
).to(device=device)
self.u = ds[0].operands[0].size // ds[0].operands[0].num_segments
self.descriptors = ds
self.u = d_max.operands[0].size // d_max.operands[0].num_segments
self.descriptors = ds_
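
In the reworked helper above, split_mode re-cuts the shared mode m to the greatest common divisor of the sizes that mode takes within the descriptor, so every segment ends up with one common width u. A small standalone illustration of the gcd step (toy numbers, not library objects):

import math

# If mode "u" appears with sizes 32, 64 and 96 inside one descriptor,
# split_mode(m, gcd) re-segments everything into chunks of width 32.
# (math.gcd accepts more than two arguments on Python 3.9+.)
dims = [32, 64, 96]
print(math.gcd(*dims))  # 32
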

def forward(
self, x0: torch.Tensor, i0: torch.Tensor, x1: torch.Tensor
@@ -324,14 +316,18 @@ def forward(
x_2[j_{n+1}] = val x_0[i_0][j_0] \prod_{k=1}^{n} x_1[j_k]

"""
torch._assert(x0.ndim == 2, f"Expected shape (num_x0, x0_size), got {x0.shape}")
torch._assert(x1.ndim == 2, f"Expected shape (batch, x1_size), got {x1.shape}")
torch._assert(i0.ndim == 1, f"Expected shape (batch,), got {i0.shape}")

i0 = i0.to(torch.int32)
x0 = x0.reshape(x0.shape[0], x0.shape[1] // self.u, self.u)
x1 = x1.reshape(x1.shape[0], x1.shape[1] // self.u, self.u)
if not torch.jit.is_scripting() and not torch.compiler.is_compiling():
logger.debug(
f"Calling SymmetricTensorContraction: {self.descriptors}, input shapes: {x0.shape}, {i0.shape}, {x1.shape}"
)
out = self.f(x1, x0, i0)
out: torch.Tensor = self.f(x1, x0, i0)
out = out.reshape(out.shape[0], out.shape[1] * self.u)
return out
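
The reshapes above fold the flat (batch, size) inputs into (batch, num_segments, u) for the kernel and unfold the result afterwards. A standalone sketch of the arithmetic, assuming u = 32 and an operand of total size 96:

import torch

u = 32
x1 = torch.randn(8, 96)                            # (batch, x1_size)
x1 = x1.reshape(x1.shape[0], x1.shape[1] // u, u)  # (batch, 3, 32)
out = torch.randn(8, 5, u)                         # kernel output (batch, segments, u)
out = out.reshape(out.shape[0], out.shape[1] * u)  # back to (batch, 160)
print(x1.shape, out.shape)
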

@@ -101,11 +101,14 @@ def __init__(
self.descriptor = descriptor
if math_dtype is None:
math_dtype = torch.get_default_dtype()
self.f = None

self.has_cuda = False
self.num_operands = descriptor.num_operands

if use_fallback is None or use_fallback is False:
if use_fallback is False:
self.f = _tensor_product_cuda(descriptor, device, math_dtype)
self.has_cuda = True
elif use_fallback is None:
try:
self.f = _tensor_product_cuda(descriptor, device, math_dtype)
self.has_cuda = True
@@ -120,12 +123,7 @@ def __init__(
"pip install cuequivariance-ops-torch-cu12"
)

if use_fallback is False and not self.has_cuda:
raise RuntimeError(
"`use_fallback` is `False` and no CUDA kernel is available!"
)

if self.f is None:
if not self.has_cuda:
if optimize_fallback is None:
optimize_fallback = False
warnings.warn(
15 changes: 10 additions & 5 deletions cuequivariance_torch/tests/operations/rotation_test.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch

import cuequivariance as cue
@@ -47,11 +48,15 @@ def test_vector_to_euler_angles():
assert torch.allclose(A, B)


def test_inversion():
@pytest.mark.parametrize("use_fallback", [False, True])
def test_inversion(use_fallback: bool):
if use_fallback is False and not torch.cuda.is_available():
pytest.skip("CUDA is not available")

irreps = cue.Irreps("O3", "2x1e + 1o")
torch.testing.assert_close(
cuet.Inversion(irreps, layout=cue.ir_mul)(
torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
),
torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0]),
cuet.Inversion(
irreps, layout=cue.ir_mul, device=device, use_fallback=use_fallback
)(torch.tensor([[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]], device=device)),
torch.tensor([[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0]], device=device),
)
30 changes: 22 additions & 8 deletions cuequivariance_torch/tests/operations/spherical_harmonics_test.py
@@ -26,14 +26,18 @@
"dtype, tol",
[(torch.float64, 1e-6), (torch.float32, 1e-4)],
)
@pytest.mark.parametrize("ell", [1, 2, 3])
def test_spherical_harmonics(ell: int, dtype, tol):
@pytest.mark.parametrize("ell", [0, 1, 2, 3])
@pytest.mark.parametrize("use_fallback", [False, True])
def test_spherical_harmonics_equivariance(use_fallback: bool, ell: int, dtype, tol):
if use_fallback is False and not torch.cuda.is_available():
pytest.skip("CUDA is not available")

vec = torch.randn(3, dtype=dtype, device=device)
axis = np.random.randn(3)
angle = np.random.rand()
scale = 1.3

yl = cuet.spherical_harmonics([ell], vec, False)
yl = cuet.spherical_harmonics([ell], vec, False, use_fallback=use_fallback)

R = torch.from_numpy(cue.SO3(1).rotation(axis, angle)).to(dtype).to(device)
Rl = torch.from_numpy(cue.SO3(ell).rotation(axis, angle)).to(dtype).to(device)
@@ -44,9 +48,19 @@ def test_spherical_harmonics(ell: int, dtype, tol):
torch.testing.assert_close(yl1, yl2, rtol=tol, atol=tol)


def test_spherical_harmonics_full():
vec = torch.randn(3, device=device)
ls = [0, 1, 2, 3]
yl = cuet.spherical_harmonics(ls, vec, False)
data_types = [torch.float32, torch.float64]

if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
data_types += [torch.float16, torch.bfloat16]


@pytest.mark.parametrize("dtype", data_types)
@pytest.mark.parametrize("ls", [[0], [1], [2], [0, 1], [0, 1, 2]])
@pytest.mark.parametrize("use_fallback", [False, True])
def test_spherical_harmonics_full(dtype, ls: list[int], use_fallback: bool):
if use_fallback is False and not torch.cuda.is_available():
pytest.skip("CUDA is not available")

assert abs(yl[0] - 1.0) < 1e-6
vec = torch.randn(3, device=device, dtype=dtype)
yl = cuet.spherical_harmonics(ls, vec, False, use_fallback=use_fallback)
assert yl.shape[-1] == sum(2 * ell + 1 for ell in ls)
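
The reworked test also covers mixed degree lists and, on GPUs with compute capability 8.0 or higher, half-precision dtypes. A quick usage sketch of the output-width formula checked by the final assertion, run on the fallback path so no GPU is required:

import torch
import cuequivariance_torch as cuet

# sum(2*ell + 1) = 1 + 3 + 5 = 9 components for ls = [0, 1, 2].
vec = torch.randn(3, dtype=torch.float64)
yl = cuet.spherical_harmonics([0, 1, 2], vec, False, use_fallback=True)
assert yl.shape[-1] == 9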