fix sh l=0 (#49)
* fix sh l=1

* fix

* skip

* skip
mariogeiger authored Dec 19, 2024
1 parent f8b2971 commit 88e6596
Showing 5 changed files with 122 additions and 103 deletions.
@@ -195,7 +195,13 @@ def __init__(
use_fallback=use_fallback,
)

if any(d.num_operands != e.num_inputs + 1 for d in e.ds):
if (
len(e.ds) > 1
or any(d.num_operands != e.num_inputs + 1 for d in e.ds)
or any(
d.num_operands == 2 for d in e.ds
) # special case for Spherical Harmonics ls = [1]
):
if e.num_inputs == 1:
self.tp = SymmetricTPDispatcher(
cuet.SymmetricTensorProduct(
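
Note on the hunk above: the dispatcher now takes the symmetric-tensor-product path in three situations instead of one. A minimal restatement of the new rule, pulled out of the diff for readability (the names e and e.ds follow the surrounding code; the helper function itself is hypothetical):

def _needs_symmetric_dispatch(e) -> bool:
    # True when the product cannot be handled as a single fused tensor product:
    # several segmented descriptors, a descriptor whose operand count differs from
    # num_inputs + 1, or a two-operand descriptor (spherical harmonics ls = [1]).
    return (
        len(e.ds) > 1
        or any(d.num_operands != e.num_inputs + 1 for d in e.ds)
        or any(d.num_operands == 2 for d in e.ds)
    )
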
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
from typing import Optional

import torch
@@ -48,20 +49,6 @@ def __init__(

self.descriptors = descriptors

if any(d.num_operands < 2 for d in descriptors):
d0 = next(d for d in descriptors if d.num_operands == 1)
descriptors = [d for d in descriptors if d.num_operands >= 2]
assert len(descriptors) + 1 == len(self.descriptors)
self.f0 = cuet.TensorProduct(
d0,
device=device,
math_dtype=math_dtype,
use_fallback=use_fallback,
optimize_fallback=optimize_fallback,
)
else:
self.f0 = None

descriptors = [
stp.SegmentedTensorProduct(
operands=[stp.Operand.empty_segments(1)] + d.operands,
@@ -72,13 +59,10 @@ )
)
for d in descriptors
]
try:
d = next(d for d in descriptors if d.num_operands >= 1)
except StopIteration:
raise ValueError("At least one STP must have at least 2 operands.")
d_max = max(descriptors, key=lambda d: d.num_operands)

self.x0_size = d.operands[0].size
self.x1_size = d.operands[1].size
self.x0_size = d_max.operands[0].size
self.x1_size = d_max.operands[1].size if d_max.num_operands >= 3 else 1

self.f = cuet.IWeightedSymmetricTensorProduct(
descriptors,
@@ -103,14 +87,14 @@ def forward(self, x0: torch.Tensor) -> torch.Tensor:
The output tensor resulting from the indexed symmetric tensor product operation.
It will have the shape (batch, x1_size).
"""
out = self.f(
torch._assert(
x0.ndim == 2, f"Expected 2 dims (batch, x0_size), got shape {x0.shape}"
)
return self.f(
torch.ones((1, 1), dtype=x0.dtype, device=x0.device),
torch.zeros((x0.shape[0],), dtype=torch.int32, device=x0.device),
x0,
)
if self.f0 is not None:
out += self.f0([])
return out


class IWeightedSymmetricTensorProduct(torch.nn.Module):
@@ -145,31 +129,26 @@ def __init__(
_check_descriptors(descriptors)
self.descriptors = descriptors

d = next(d for d in descriptors if d.num_operands >= 3)
d = max(descriptors, key=lambda d: d.num_operands)
self.x0_size = d.operands[0].size
self.x1_size = d.operands[1].size
self.x1_size = d.operands[1].size if d.num_operands >= 3 else 1
self.x2_size = d.operands[-1].size

self.has_cuda = False

self.f = None

if use_fallback is None or use_fallback is False:
if use_fallback is False:
self.f = CUDAKernel(descriptors, device, math_dtype)
self.has_cuda = True
elif use_fallback is None:
try:
self.f = CUDAKernel(descriptors, device, math_dtype)
self.has_cuda = True
return
except NotImplementedError as e:
logger.info(f"Failed to initialize CUDA implementation: {e}")
except ImportError as e:
logger.warning(f"Failed to initialize CUDA implementation: {e}")

if use_fallback is False and not self.has_cuda:
raise RuntimeError(
"`use_fallback` is `False` and no CUDA kernel is available!"
)

if self.f is None:
if not self.has_cuda:
self.f = FallbackImpl(
descriptors,
device,
@@ -228,42 +207,40 @@ def _check_descriptors(descriptors: list[stp.SegmentedTensorProduct]):
if len(descriptors) == 0:
raise ValueError("stps must contain at least one STP.")

try:
d = next(d for d in descriptors if d.num_operands >= 3)
except StopIteration:
raise ValueError("At least one STP must have at least 3 operands.")

x0 = d.operands[0]
x1 = d.operands[1]
x2 = d.operands[-1]
d_max = max(descriptors, key=lambda d: d.num_operands)
assert d_max.num_operands >= 2 # at least x0 and x2

for d in descriptors:
if d.operands[0].size != x0.size:
if d.operands[0].size != d_max.operands[0].size:
raise ValueError("All STPs must have the same first operand (x0).")

if any(ope.size != x1.size for ope in d.operands[1:-1]):
if any(ope.size != d_max.operands[1].size for ope in d.operands[1:-1]):
raise ValueError("All STPs must have the operands[1:-1] identical (x1).")

if d.operands[-1].size != x2.size:
if d.operands[-1].size != d_max.operands[-1].size:
raise ValueError("All STPs must have the same last operand (x2, output).")


class CUDAKernel(torch.nn.Module):
def __init__(
self,
stps: list[stp.SegmentedTensorProduct],
ds: list[stp.SegmentedTensorProduct],
device: Optional[torch.device],
math_dtype: torch.dtype,
):
super().__init__()

max_degree = max(d.num_operands - 2 for d in stps)
max_degree = max(d.num_operands - 2 for d in ds)

if max_degree > 6:
raise NotImplementedError("Correlation > 6 is not implemented.")
if min(d.num_operands for d in stps) == 2:
raise NotImplementedError(
"Only STPs with at least 3 operands are supported."
)

if len({d.operands[0].num_segments for d in ds}) != 1:
raise ValueError("All STPs must have the same number of segments in x0.")
if len({ope.num_segments for d in ds for ope in d.operands[1:-1]}) > 1:
raise ValueError("All STPs must have the same number of segments in x1.")
if len({d.operands[-1].num_segments for d in ds}) != 1:
raise ValueError("All STPs must have the same number of segments in x2.")

def f(d: stp.SegmentedTensorProduct) -> stp.SegmentedTensorProduct:
d = d.move_operand(0, -2)
@@ -276,44 +253,59 @@ def f(d: stp.SegmentedTensorProduct) -> stp.SegmentedTensorProduct:
]
)
d = d.consolidate_modes()
if d.subscripts.modes() == []:
d = d.append_modes_to_all_operands("u", dict(u=1))

# ops.SymmetricTensorContraction will "symmetrize" for the derivatives so we can sort for the forward pass
d = d.sort_indices_for_identical_operands(range(0, d.num_operands - 2))

if len(set(ope.subscripts for ope in d.operands)) != 1:
if len(d.subscripts.modes()) != 1:
raise NotImplementedError("Different modes are not supported.")

m = d.subscripts.modes()[0]

if not all(ope.subscripts == m for ope in d.operands):
raise NotImplementedError("Different subscripts are not supported.")
return d

ds = [f(d) for d in stps]

if (
len(
set(
(
d.operands[0].num_segments,
d.operands[-2].num_segments,
d.operands[-1].num_segments,
)
for d in ds
)
)
!= 1
):
raise ValueError("All STPs must have the same number of segments.")
d = d.split_mode(m, math.gcd(*d.get_dims(m)))

return d

ds_ = [f(d) for d in ds]
import cuequivariance_ops_torch as ops

d_max = max(ds_, key=lambda d: d.num_operands)

path_segment_indices = sum((d.indices.tolist() for d in ds_), [])
path_coefficients = sum((d.stacked_coefficients.tolist() for d in ds_), [])
num_in_segments = (
d_max.operands[0].num_segments if d_max.num_operands >= 3 else 1
)
num_couplings = d_max.operands[-2].num_segments
num_out_segments = d_max.operands[-1].num_segments
correlation = max(1, max_degree)
math_dtype = math_dtype
logger.debug(f"""cuequivariance_ops_torch.SymmetricTensorContraction(
path_segment_indices={path_segment_indices},
path_coefficients={path_coefficients},
num_in_segments={num_in_segments},
num_couplings={num_couplings},
num_out_segments={num_out_segments},
correlation={correlation},
math_dtype={math_dtype},
)""")

self.f = ops.SymmetricTensorContraction(
sum((d.indices.tolist() for d in ds), []),
sum((d.stacked_coefficients.tolist() for d in ds), []),
ds[0].operands[0].num_segments,
ds[0].operands[-2].num_segments,
ds[0].operands[-1].num_segments,
max_degree,
math_dtype,
path_segment_indices=path_segment_indices,
path_coefficients=path_coefficients,
num_in_segments=num_in_segments,
num_couplings=num_couplings,
num_out_segments=num_out_segments,
correlation=correlation,
math_dtype=math_dtype,
).to(device=device)
self.u = ds[0].operands[0].size // ds[0].operands[0].num_segments
self.descriptors = ds
self.u = d_max.operands[0].size // d_max.operands[0].num_segments
self.descriptors = ds_

def forward(
self, x0: torch.Tensor, i0: torch.Tensor, x1: torch.Tensor
@@ -324,14 +316,18 @@ def forward(
x_2[j_{n+1}] = val x_0[i_0][j_0] \prod_{k=1}^{n} x_1[j_k]
"""
torch._assert(x0.ndim == 2, f"Expected shape (num_x0, x0_size), got {x0.shape}")
torch._assert(x1.ndim == 2, f"Expected shape (batch, x1_size), got {x1.shape}")
torch._assert(i0.ndim == 1, f"Expected shape (batch,), got {i0.shape}")

i0 = i0.to(torch.int32)
x0 = x0.reshape(x0.shape[0], x0.shape[1] // self.u, self.u)
x1 = x1.reshape(x1.shape[0], x1.shape[1] // self.u, self.u)
if not torch.jit.is_scripting() and not torch.compiler.is_compiling():
logger.debug(
f"Calling SymmetricTensorContraction: {self.descriptors}, input shapes: {x0.shape}, {i0.shape}, {x1.shape}"
)
out = self.f(x1, x0, i0)
out: torch.Tensor = self.f(x1, x0, i0)
out = out.reshape(out.shape[0], out.shape[1] * self.u)
return out

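The recurring change in this file swaps next(d for d in descriptors if d.num_operands >= 3) for a maximum over operand counts, so descriptors with only an input and an output operand (the low-degree spherical-harmonics cases that motivated this commit) no longer fail the lookup. Condensed sketch of the pattern (names mirror the code above; nothing here is new API):

d_max = max(descriptors, key=lambda d: d.num_operands)
x0_size = d_max.operands[0].size
x1_size = d_max.operands[1].size if d_max.num_operands >= 3 else 1  # two-operand STP: no x1
x2_size = d_max.operands[-1].size
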
@@ -101,11 +101,14 @@ def __init__(
self.descriptor = descriptor
if math_dtype is None:
math_dtype = torch.get_default_dtype()
self.f = None

self.has_cuda = False
self.num_operands = descriptor.num_operands

if use_fallback is None or use_fallback is False:
if use_fallback is False:
self.f = _tensor_product_cuda(descriptor, device, math_dtype)
self.has_cuda = True
elif use_fallback is None:
try:
self.f = _tensor_product_cuda(descriptor, device, math_dtype)
self.has_cuda = True
@@ -120,12 +123,7 @@ def __init__(
"pip install cuequivariance-ops-torch-cu12"
)

if use_fallback is False and not self.has_cuda:
raise RuntimeError(
"`use_fallback` is `False` and no CUDA kernel is available!"
)

if self.f is None:
if not self.has_cuda:
if optimize_fallback is None:
optimize_fallback = False
warnings.warn(
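
The use_fallback branches above follow the same pattern as in the symmetric tensor product. Summarizing the intended behaviour as inferred from those branches (a sketch, not library documentation):

# use_fallback=False : build the CUDA kernel directly; ImportError/NotImplementedError propagates to the caller
# use_fallback=None  : try the CUDA kernel, log the failure, then fall back to the PyTorch implementation
# use_fallback=True  : skip the CUDA path and use the fallback (optimize_fallback defaults to False, with a warning)
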
15 changes: 10 additions & 5 deletions cuequivariance_torch/tests/operations/rotation_test.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch

import cuequivariance as cue
@@ -47,11 +48,15 @@ def test_vector_to_euler_angles():
assert torch.allclose(A, B)


def test_inversion():
@pytest.mark.parametrize("use_fallback", [False, True])
def test_inversion(use_fallback: bool):
if use_fallback is False and not torch.cuda.is_available():
pytest.skip("CUDA is not available")

irreps = cue.Irreps("O3", "2x1e + 1o")
torch.testing.assert_close(
cuet.Inversion(irreps, layout=cue.ir_mul)(
torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
),
torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0]),
cuet.Inversion(
irreps, layout=cue.ir_mul, device=device, use_fallback=use_fallback
)(torch.tensor([[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]], device=device)),
torch.tensor([[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0]], device=device),
)
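
The inversion check now runs with both use_fallback values and a batched (1, 9) input. A condensed, CPU-only variant of the same assertion (imports and the cue/cuet aliases as in the test file; use_fallback=True is assumed so no CUDA kernel is required):

x = torch.tensor([[1.0] * 9])  # "2x1e + 1o" -> 6 even components + 3 odd components
inv = cuet.Inversion(cue.Irreps("O3", "2x1e + 1o"), layout=cue.ir_mul, use_fallback=True)
torch.testing.assert_close(inv(x), torch.tensor([[1.0] * 6 + [-1.0] * 3]))
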
30 changes: 22 additions & 8 deletions cuequivariance_torch/tests/operations/spherical_harmonics_test.py
@@ -26,14 +26,18 @@
"dtype, tol",
[(torch.float64, 1e-6), (torch.float32, 1e-4)],
)
@pytest.mark.parametrize("ell", [1, 2, 3])
def test_spherical_harmonics(ell: int, dtype, tol):
@pytest.mark.parametrize("ell", [0, 1, 2, 3])
@pytest.mark.parametrize("use_fallback", [False, True])
def test_spherical_harmonics_equivariance(use_fallback: bool, ell: int, dtype, tol):
if use_fallback is False and not torch.cuda.is_available():
pytest.skip("CUDA is not available")

vec = torch.randn(3, dtype=dtype, device=device)
axis = np.random.randn(3)
angle = np.random.rand()
scale = 1.3

yl = cuet.spherical_harmonics([ell], vec, False)
yl = cuet.spherical_harmonics([ell], vec, False, use_fallback=use_fallback)

R = torch.from_numpy(cue.SO3(1).rotation(axis, angle)).to(dtype).to(device)
Rl = torch.from_numpy(cue.SO3(ell).rotation(axis, angle)).to(dtype).to(device)
@@ -44,9 +48,19 @@ def test_spherical_harmonics(ell: int, dtype, tol):
torch.testing.assert_close(yl1, yl2, rtol=tol, atol=tol)


def test_spherical_harmonics_full():
vec = torch.randn(3, device=device)
ls = [0, 1, 2, 3]
yl = cuet.spherical_harmonics(ls, vec, False)
data_types = [torch.float32, torch.float64]

if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
data_types += [torch.float16, torch.bfloat16]


@pytest.mark.parametrize("dtype", data_types)
@pytest.mark.parametrize("ls", [[0], [1], [2], [0, 1], [0, 1, 2]])
@pytest.mark.parametrize("use_fallback", [False, True])
def test_spherical_harmonics_full(dtype, ls: list[int], use_fallback: bool):
if use_fallback is False and not torch.cuda.is_available():
pytest.skip("CUDA is not available")

assert abs(yl[0] - 1.0) < 1e-6
vec = torch.randn(3, device=device, dtype=dtype)
yl = cuet.spherical_harmonics(ls, vec, False, use_fallback=use_fallback)
assert yl.shape[-1] == sum(2 * ell + 1 for ell in ls)
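
A small usage sketch mirroring the updated full test (CPU-only, hence use_fallback=True; the cuet alias for cuequivariance_torch follows the repository's test convention):

import torch
import cuequivariance_torch as cuet

vec = torch.randn(3)
yl = cuet.spherical_harmonics([0, 1, 2], vec, False, use_fallback=True)
assert yl.shape[-1] == 1 + 3 + 5  # sum(2 * ell + 1) over ls = [0, 1, 2]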
