Indexed Weights Linear #78

Open · wants to merge 8 commits into base: main
13 changes: 7 additions & 6 deletions cuequivariance_torch/cuequivariance_torch/__init__.py
@@ -18,10 +18,10 @@
importlib.resources.files(__package__).joinpath("VERSION").read_text().strip()
)

from .primitives.tensor_product import TensorProduct, _Wrapper
from .primitives.tensor_product import TensorProduct, _Wrapper, _BatchLinear
from .primitives.symmetric_tensor_product import (
SymmetricTensorProduct,
IWeightedSymmetricTensorProduct,
_SymmetricTensorProduct,
_IWeightedSymmetricTensorProduct,
)
from .primitives.transpose import TransposeSegments, TransposeIrrepsLayout

@@ -42,8 +42,10 @@

__all__ = [
"TensorProduct",
"SymmetricTensorProduct",
"IWeightedSymmetricTensorProduct",
"_Wrapper",
"_BatchLinear",
"_SymmetricTensorProduct",
"_IWeightedSymmetricTensorProduct",
"TransposeSegments",
"TransposeIrrepsLayout",
"EquivariantTensorProduct",
@@ -57,5 +59,4 @@
"vector_to_euler_angles",
"SphericalHarmonics",
"layers",
"_Wrapper",
]
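
Note: the symmetric tensor product primitives are re-exported under underscore-prefixed (private) names, and the new _BatchLinear wrapper joins them in __all__. A minimal sketch of how the primitives are reached after this change, assuming the cuet alias used throughout the library's examples:

import cuequivariance_torch as cuet

# The primitives now live under private names (the updated tests below use them the same way):
sym_tp_cls = cuet._SymmetricTensorProduct
iweighted_cls = cuet._IWeightedSymmetricTensorProduct
batch_linear_cls = cuet._BatchLinear  # new indexed-weights linear primitive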
@@ -153,6 +153,7 @@ def __init__(
device=device,
math_dtype=math_dtype or dtype,
use_fallback=use_fallback,
index_first_input=True,
)

def extra_repr(self) -> str:
@@ -161,11 +162,7 @@ def extra_repr(self) -> str:
f", weight_shape={self.weight_shape}"
)

def forward(
self,
x: torch.Tensor,
indices: torch.Tensor,
) -> torch.Tensor:
def forward(self, x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
"""
Perform the forward pass of the symmetric contraction operation.

@@ -113,6 +113,7 @@ class EquivariantTensorProduct(torch.nn.Module):

Args:
e (cuequivariance.EquivariantTensorProduct): Equivariant tensor product.
index_first_input (bool, optional): If `True`, the first input tensor is indexed by `indices`. Default is `False`.
layout (IrrepsLayout): layout for inputs and output.
layout_in (IrrepsLayout): layout for inputs.
layout_out (IrrepsLayout): layout for output.
Expand All @@ -136,6 +137,7 @@ class EquivariantTensorProduct(torch.nn.Module):

You can optionally index the first input tensor:

>>> tp = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul, device=device, index_first_input=True)
>>> w = torch.ones(3, e.inputs[0].dim, device=device)
>>> indices = torch.randint(3, (17,))
>>> tp(w, x1, x2, indices=indices)
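
For context, passing indices gathers rows of the first input along its leading dimension, so the indexed call above is expected to match building the module without index_first_input and gathering the weights by hand. A minimal sketch of that equivalence check, reusing the objects from the example (tp_ref and the tolerance handling are illustrative):

tp_ref = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul, device=device)
out_indexed = tp(w, x1, x2, indices=indices)  # indexed weights
out_gathered = tp_ref(w[indices], x1, x2)     # weights gathered up front
torch.testing.assert_close(out_indexed, out_gathered)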
@@ -146,6 +148,7 @@ def __init__(
self,
e: cue.EquivariantTensorProduct,
*,
index_first_input: bool = False,
layout: Optional[cue.IrrepsLayout] = None,
layout_in: Optional[
Union[cue.IrrepsLayout, tuple[Optional[cue.IrrepsLayout], ...]]
@@ -209,8 +212,12 @@ def __init__(
) # special case for Spherical Harmonics ls = [1]
):
if e.num_inputs == 1:
if index_first_input:
raise NotImplementedError(
"Indexing the first input is not supported for a single input"
)
self.tp = SymmetricTPDispatcher(
cuet.SymmetricTensorProduct(
cuet._SymmetricTensorProduct(
e.ds,
device=device,
math_dtype=math_dtype,
Expand All @@ -219,7 +226,7 @@ def __init__(
)
elif e.num_inputs == 2:
self.tp = IWeightedSymmetricTPDispatcher(
cuet.IWeightedSymmetricTensorProduct(
cuet._IWeightedSymmetricTensorProduct(
e.ds,
device=device,
math_dtype=math_dtype,
@@ -229,14 +236,35 @@
else:
raise NotImplementedError("This should not happen")
else:
tp = cuet.TensorProduct(
e.ds[0],
device=device,
math_dtype=math_dtype,
use_fallback=use_fallback,
)
self.tp = TPDispatcher(tp, tp.descriptor)
assert len(e.ds) == 1

tp = None
if (
index_first_input
and e.d.subscripts.canonicalize() in ["uv,u,v", "uv,v,u"]
and use_fallback is not False
):
try:
tp = cuet._BatchLinear(
e.ds[0], device=device, math_dtype=math_dtype
)
except NotImplementedError:
pass

if tp is None:
tp = TPDispatcher(
cuet.TensorProduct(
e.ds[0],
device=device,
math_dtype=math_dtype,
use_fallback=use_fallback,
),
e.ds[0],
)

self.tp = tp

self.index_first_input = index_first_input
self.operands_dims = [op.dim for op in e.operands]

def extra_repr(self) -> str:
@@ -253,6 +281,10 @@ def forward(
"""
If ``indices`` is not None, the first input is indexed by ``indices``.
"""
torch._assert(
indices is None or self.index_first_input,
"indices can only be used with index_first_input=True",
)

if x3 is not None and x2 is not None and x1 is not None:
inputs = [x0, x1, x2, x3]
@@ -25,7 +25,7 @@
logger = logging.getLogger(__name__)


class SymmetricTensorProduct(torch.nn.Module):
class _SymmetricTensorProduct(torch.nn.Module):
"""
PyTorch module

@@ -61,7 +61,7 @@ def __init__(
self.x0_size = d_max.operands[0].size
self.x1_size = d_max.operands[1].size if d_max.num_operands >= 3 else 1

self.f = cuet.IWeightedSymmetricTensorProduct(
self.f = cuet._IWeightedSymmetricTensorProduct(
descriptors,
device=device,
math_dtype=math_dtype,
@@ -93,7 +93,7 @@ def forward(self, x0: torch.Tensor) -> torch.Tensor:
)


class IWeightedSymmetricTensorProduct(torch.nn.Module):
class _IWeightedSymmetricTensorProduct(torch.nn.Module):
"""
PyTorch module

@@ -689,3 +689,83 @@ def _permutation_module(permutation: Tuple[int, ...]):
inputs = [graph.placeholder(f"input_{i}") for i in range(len(permutation))]
graph.output([inputs[i] for i in permutation])
return torch.fx.GraphModule(dict(), graph, class_name="perm")


class _BatchLinear(torch.nn.Module):
def __init__(
self,
descriptor: stp.SegmentedTensorProduct,
*,
device: Optional[torch.device],
math_dtype: torch.dtype,
):
super().__init__()
try:
import cuequivariance_ops_torch as ops
except ImportError:
raise NotImplementedError()

if not torch.cuda.is_available():
raise NotImplementedError()

if descriptor.num_operands != 3:
raise NotImplementedError()

self.descriptor = descriptor
self.x0_size = descriptor.operands[0].size
self.x1_size = descriptor.operands[1].size

descriptor = descriptor.canonicalize_subscripts()
if descriptor.subscripts == "uv,u,v":
descriptor = descriptor.permute_operands([1, 0, 2])
self._perm = _permutation_module([1, 0])
elif descriptor.subscripts == "u,vu,v":
raise NotImplementedError()
descriptor = descriptor.permute_operands([1, 0, 2])
self._perm = _permutation_module([0, 1])
elif descriptor.subscripts == "u,uv,v":
raise NotImplementedError()
self._perm = _permutation_module([0, 1])
elif descriptor.subscripts == "uv,v,u":
self._perm = _permutation_module([1, 0])
else:
raise NotImplementedError()

descriptor = descriptor.canonicalize_subscripts()

assert descriptor.subscripts in ["u,uv,v", "uv,v,u"]
assert descriptor.coefficient_subscripts == ""

self._f = ops.BatchLinear(
operand_segment_modes=[ope.subscripts for ope in descriptor.operands],
operand_segment_offsets=[
[s.start for s in ope.segment_slices()] for ope in descriptor.operands
],
operand_segment_shapes=[ope.segments for ope in descriptor.operands],
path_indices=[path.indices for path in descriptor.paths],
path_coefficients=[path.coefficients.item() for path in descriptor.paths],
math_dtype=math_dtype,
).to(device=device)

def forward(
self, inputs: List[torch.Tensor], indices: torch.Tensor
) -> torch.Tensor:
[x0, x1] = inputs
torch._assert(x0.shape[1] == self.x0_size, "input 0 has wrong size")
torch._assert(x1.shape[1] == self.x1_size, "input 1 has wrong size")

if (
not torch.jit.is_scripting()
and not torch.jit.is_tracing()
and not torch.compiler.is_compiling()
):
logger.debug(
f"Calling BatchedLinear: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}, {indices.shape}"
)

torch._assert(x0.ndim == 2, "input should be dim=2")
torch._assert(x1.ndim == 2, "input should be dim=2")
torch._assert(indices.ndim == 1, "indices should be (batch,)")

x0, x1 = self._perm(x0, x1)
return self._f(x0, x1, indices)
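
For intuition, _BatchLinear evaluates a per-sample linear map whose weight matrix is selected by indices. A minimal pure-PyTorch sketch of that semantics for a single "uv,u,v" segment with a unit path coefficient (an illustrative reference, not the kernel's implementation):

import torch

def indexed_linear_reference(w: torch.Tensor, x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # w: (num_weights, u, v) stack of weight matrices
    # x: (batch, u) inputs; indices: (batch,) picks one weight matrix per sample
    # out[b, v] = sum_u w[indices[b], u, v] * x[b, u]
    return torch.einsum("buv,bu->bv", w[indices], x)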
@@ -63,7 +63,7 @@ def test_primitive_indexed_symmetric_tensor_product_cuda_vs_fx(
):
use_fallback = not torch.cuda.is_available()

m = cuet.IWeightedSymmetricTensorProduct(
m = cuet._IWeightedSymmetricTensorProduct(
ds, math_dtype=math_dtype, device=device, use_fallback=use_fallback
)

@@ -76,7 +76,7 @@
x1_ = x1.clone().to(torch.float64)

out1 = m(x0, i0, x1)
m = cuet.IWeightedSymmetricTensorProduct(
m = cuet._IWeightedSymmetricTensorProduct(
ds, math_dtype=torch.float64, device=device, use_fallback=True
)
out2 = m(x0_, i0, x1_)
@@ -120,7 +120,7 @@ def test_math_dtype(dtype: torch.dtype, math_dtype: torch.dtype, use_fallback: b
ds = descriptors.symmetric_contraction(
cue.Irreps("SO3", "0 + 1 + 2"), cue.Irreps("SO3", "0"), [1, 2, 3]
).ds
m = cuet.IWeightedSymmetricTensorProduct(
m = cuet._IWeightedSymmetricTensorProduct(
ds, math_dtype=math_dtype, device=device, use_fallback=use_fallback
)
x0 = torch.randn((20, m.x0_size), dtype=dtype, device=device)
@@ -165,7 +165,7 @@ def test_export(
if use_fallback is True and mode in ["trt"]:
pytest.skip(f"{mode} not supported for the fallback!")

m = cuet.IWeightedSymmetricTensorProduct(
m = cuet._IWeightedSymmetricTensorProduct(
ds, math_dtype=math_dtype, device=device, use_fallback=use_fallback
)
x0 = torch.randn((2, m.x0_size), device=device, dtype=dtype, requires_grad=True)