improve tuto on stp and printing functions

mariogeiger committed Jan 24, 2025
1 parent 25e588e commit 6805733
Showing 4 changed files with 142 additions and 58 deletions.
@@ -14,8 +14,13 @@
 # limitations under the License.


+def format_set(s: set[int]) -> str:
+    if len(s) == 0:
+        return ""
+    if len(s) == 1:
+        return str(next(iter(s)))
+    return "{" + ", ".join(str(i) for i in sorted(s)) + "}"
+
+
 def format_dimensions_dict(dims: dict[str, set[int]]) -> str:
-    return " ".join(
-        f"{m}={next(iter(dd))}" if len(dd) == 1 else f"{m}={dd}"
-        for m, dd in sorted(dims.items())
-    )
+    return " ".join(f"{m}={format_set(dd)}" for m, dd in sorted(dims.items()))
@@ -266,47 +266,62 @@ def to_text(self, coefficient_formatter=lambda x: f"{x}") -> str:
         >>> d = d.flatten_coefficient_modes()
         >>> print(d.to_text())
         uvw,u,v,w sizes=320,16,16,16 num_segments=5,4,4,4 num_paths=16 u=4 v=4 w=4
+        operand #0 subscripts=uvw
+         | u: [4] * 5
+         | v: [4] * 5
+         | w: [4] * 5
+        operand #1 subscripts=u
+         | u: [4] * 4
+        operand #2 subscripts=v
+         | v: [4] * 4
+        operand #3 subscripts=w
+         | w: [4] * 4
         Flop cost: 0->1344 1->2368 2->2368 3->2368
-        Memory cost from 368 to 1216
+        Memory cost: 368
         Path indices: 0 0 0 0, 1 0 1 1, 1 0 2 2, 1 0 3 3, 2 1 0 1, 2 2 0 2, ...
         Path coefficients: [0.17...]
         """
         out = f"{self}"
         dims = self.get_dimensions_dict()
         for oid, operand in enumerate(self.operands):
+            out += f"\noperand #{oid} subscripts={operand.subscripts}"
             for i, ch in enumerate(operand.subscripts):
                 if len(dims[ch]) == 1:
-                    continue
-
-                out += (
-                    f"\noperand #{oid} subscripts={operand.subscripts} {ch}: ["
-                    + ", ".join(str(s[i]) for s in operand.segments)
-                    + "]"
-                )
+                    out += f"\n | {ch}: [{operand.segments[0][i]}] * {len(operand.segments)}"
+                else:
+                    out += (
+                        f"\n | {ch}: ["
+                        + ", ".join(str(s[i]) for s in operand.segments)
+                        + "]"
+                    )

         out += f"\nFlop cost: {' '.join(f'{oid}->{self.flop_cost(oid)}' for oid in range(self.num_operands))}"
-        out += f"\nMemory cost from {self.memory_cost('global')} to {self.memory_cost('sequential')}"
+        out += f"\nMemory cost: {self.memory_cost('global')}"

-        out += "\nPath indices: " + ", ".join(
-            " ".join(str(i) for i in row) for row in self.indices
-        )
-
-        if coefficient_formatter is not None:
-            formatter = {"float": coefficient_formatter}
-            if all(len(dims[ch]) == 1 for ch in self.coefficient_subscripts):
-                out += "\nPath coefficients: " + np.array2string(
-                    self.stacked_coefficients,
-                    max_line_width=np.inf,
-                    formatter=formatter,
-                    threshold=np.inf,
-                )
-            else:
-                out += "\nPath coefficients:\n" + "\n".join(
-                    np.array2string(
-                        path.coefficients, formatter=formatter, threshold=np.inf
-                    )
-                    for path in self.paths
-                )
+        if len(self.paths) > 0:
+            out += "\nPath indices: " + ", ".join(
+                " ".join(str(i) for i in row) for row in self.indices
+            )
+
+            if coefficient_formatter is not None:
+                formatter = {"float": coefficient_formatter}
+                if all(len(dims[ch]) == 1 for ch in self.coefficient_subscripts):
+                    out += "\nPath coefficients: " + np.array2string(
+                        self.stacked_coefficients,
+                        max_line_width=np.inf,
+                        formatter=formatter,
+                        threshold=np.inf,
+                    )
+                else:
+                    out += "\nPath coefficients:\n" + "\n".join(
+                        np.array2string(
+                            path.coefficients, formatter=formatter, threshold=np.inf
+                        )
+                        for path in self.paths
+                    )
+        else:
+            out += "\nNo paths."

         return out

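A side note on the coefficient_formatter argument in the signature above: per the code, passing None suppresses the "Path coefficients" block entirely, and any float-formatting callable can be plugged in. A minimal sketch, assuming a populated descriptor d:

    # Render path coefficients with three decimals instead of the default str().
    print(d.to_text(coefficient_formatter=lambda x: f"{x:.3f}"))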
     def to_dict(self, extended: bool = False) -> dict[str, Any]:

@@ -372,10 +387,10 @@ def to_base64(self, extended: bool = False) -> str:

     def get_dimensions_dict(self) -> dict[str, set[int]]:
         """Get the dimensions of the tensor product."""
-        dims: dict[str, set[int]] = dict()
+        dims: dict[str, set[int]] = {ch: set() for ch in self.subscripts.modes()}
         for operand in self.operands:
             for m, dd in operand.get_dimensions_dict().items():
-                dims.setdefault(m, set()).update(dd)
+                dims[m].update(dd)
         # Note: no need to go through the coefficients since they must be contracted with the operands
         return dims
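To illustrate what the rewritten get_dimensions_dict returns, a small hypothetical example (the values mirror the add_path docstring below):

    # Every mode of the subscripts gets an entry; each entry collects all
    # dimensions that the mode takes across the operands' segments.
    d.get_dimensions_dict()
    # -> {"i": {3, 10}, "j": {5, 10}, "u": {2}, "v": {2}}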

@@ -687,13 +702,13 @@ def add_path(
         >>> d.add_path(0, None, None, c=np.ones((10, 10)))
         1
         >>> d
-        uv,ui,vj+ij sizes=4,26,30 num_segments=1,2,2 num_paths=2 i={10, 3} j={10, 5} u=2 v=2
+        uv,ui,vj+ij sizes=4,26,30 num_segments=1,2,2 num_paths=2 i={3, 10} j={5, 10} u=2 v=2

         When the dimensions of the modes cannot be inferred, we can provide them:

         >>> d.add_path(None, None, None, c=np.ones((2, 2)), dims={"u": 2, "v": 2})
         2
         >>> d
-        uv,ui,vj+ij sizes=8,30,34 num_segments=2,3,3 num_paths=3 i={2, 10, 3} j={10, 2, 5} u=2 v=2
+        uv,ui,vj+ij sizes=8,30,34 num_segments=2,3,3 num_paths=3 i={2, 3, 10} j={2, 5, 10} u=2 v=2
         """
         return self.insert_path(len(self.paths), *segments, c=c, dims=dims)

@@ -776,7 +791,7 @@ def canonicalize_subscripts(self) -> SegmentedTensorProduct:
         Examples:
             >>> d = cue.SegmentedTensorProduct.from_subscripts("ab,ax,by+xy")
             >>> d.canonicalize_subscripts()
-            uv,ui,vj+ij sizes=0,0,0 num_segments=0,0,0 num_paths=0
+            uv,ui,vj+ij sizes=0,0,0 num_segments=0,0,0 num_paths=0 i= j= u= v=

         This is useful to identify equivalent descriptors.
         """
16 changes: 11 additions & 5 deletions cuequivariance/tests/segmented_tensor_product/descriptor_test.py
@@ -20,7 +20,10 @@

 def test_user_friendly():
     d = stp.SegmentedTensorProduct.from_subscripts("ia_jb_kab+ijk")
-    assert str(d) == "ia,jb,kab+ijk sizes=0,0,0 num_segments=0,0,0 num_paths=0"
+    assert (
+        str(d)
+        == "ia,jb,kab+ijk sizes=0,0,0 num_segments=0,0,0 num_paths=0 a= b= i= j= k="
+    )

     with pytest.raises(ValueError):
         d.add_path(0, 0, 0, c=np.ones((2, 2, 3)))  # need to add segments first
@@ -233,11 +236,14 @@ def test_to_text():
     assert (
         text
         == """u,u,u sizes=50,64,50 num_segments=4,5,4 num_paths=29 u={12, 14}
-operand #0 subscripts=u u: [12, 12, 12, 14]
-operand #1 subscripts=u u: [12, 12, 12, 14, 14]
-operand #2 subscripts=u u: [12, 12, 12, 14]
+operand #0 subscripts=u
+ | u: [12, 12, 12, 14]
+operand #1 subscripts=u
+ | u: [12, 12, 12, 14, 14]
+operand #2 subscripts=u
+ | u: [12, 12, 12, 14]
 Flop cost: 0->704 1->704 2->704
-Memory cost from 164 to 1056
+Memory cost: 164
 Path indices: 0 0 0, 0 0 1, 0 0 2, 0 1 0, 0 1 1, 0 1 2, 0 2 0, 0 2 1, 0 2 2, 1 0 0, 1 0 1, 1 0 2, 1 1 0, 1 1 1, 1 1 2, 1 2 0, 1 2 1, 1 2 2, 2 0 0, 2 0 1, 2 0 2, 2 1 0, 2 1 1, 2 1 2, 2 2 0, 2 2 1, 2 2 2, 3 3 3, 3 4 3
 Path coefficients: [1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0]"""
     )
92 changes: 75 additions & 17 deletions docs/tutorials/stp.rst
@@ -16,42 +16,100 @@
 Segmented Tensor Product
 ========================

-In this example, we will show how to create a custom tensor product descriptor and execute it.
-First we need to import the necessary modules.
+In this example, we are showing how to create a custom tensor product descriptor and execute it.
+First, we need to import the necessary modules.

 .. jupyter-execute::

     import itertools
     import numpy as np
     import torch
     import jax
     import jax.numpy as jnp

     import cuequivariance as cue
-    import cuequivariance.segmented_tensor_product as stp
-    import cuequivariance_torch as cuet # to execute the tensor product with PyTorch
-    import cuequivariance_jax as cuex # to execute the tensor product with JAX
+    import cuequivariance_torch as cuet  # to execute the tensor product with PyTorch
+    import cuequivariance_jax as cuex  # to execute the tensor product with JAX

-.. currentmodule:: cuequivariance
+Basic Tools
+-----------

-Now, we will create a custom tensor product descriptor that represents the tensor product of the two representations. See :ref:`tuto_irreps` for more information on irreps.
+Creating a tensor product descriptor using the :class:`cue.SegmentedTensorProduct <cuequivariance.SegmentedTensorProduct>` class.
+
+.. jupyter-execute::
+
+    d = cue.SegmentedTensorProduct.from_subscripts("a,ia,ja+ij")
+    print(d.to_text())
+
+This descriptor has 3 operands.
+
+.. jupyter-execute::
+
+    d.num_operands
+
+Its coefficients have indices "ij".
+
+.. jupyter-execute::
+
+    d.coefficient_subscripts
+
+Adding segments to the operands.
+
+.. jupyter-execute::
+
+    d.add_segment(0, (200,))
+    d.add_segments(1, [(3, 100), (5, 200)])
+    d.add_segments(2, [(1, 200), (1, 100)])
+    print(d.to_text())
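For orientation, this is roughly what the print above should produce with the new to_text format (a sketch derived from the printing code in this commit; the exact rendering may differ):

    a,ia,ja+ij sizes=200,1300,300 num_segments=1,2,2 num_paths=0 a={100, 200} i={3, 5} j=1
    operand #0 subscripts=a
     | a: [200]
    operand #1 subscripts=ia
     | i: [3, 5]
     | a: [100, 200]
    operand #2 subscripts=ja
     | j: [1] * 2
     | a: [200, 100]
    Flop cost: 0->0 1->0 2->0
    Memory cost: 1800
    No paths.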

+Observing that "j" is always set to 1, squeezing it.
+
+.. jupyter-execute::
+
+    d = d.squeeze_modes("j")
+    print(d.to_text())
+
+Adding paths between the segments.
+
+.. jupyter-execute::
+
+    d.add_path(0, 1, 0, c=np.array([1.0, 2.0, 0.0, 0.0, 0.0]))
+    print(d.to_text())
+
+Flattening the index "i" of the coefficients.
+
+.. jupyter-execute::
+
+    d = d.flatten_modes("i")
+    # d = d.flatten_coefficient_modes()
+    print(d.to_text())
+
+Equivalently, :meth:`flatten_coefficient_modes <cuequivariance.SegmentedTensorProduct.flatten_coefficient_modes>` can be used.
+
+
+Equivariant Linear Layer
+------------------------
+
+Now, we are creating a custom tensor product descriptor that represents the tensor product of the two representations. See :ref:`tuto_irreps` for more information on irreps.

 .. jupyter-execute::

     irreps1 = cue.Irreps("O3", "32x0e + 32x1o")
     irreps2 = cue.Irreps("O3", "16x0e + 48x1o")

-The tensor product descriptor is created step by step. First, we create an empty descriptor given its subscripts.
+The tensor product descriptor is created step by step. First, we are creating an empty descriptor given its subscripts.
 In the case of the linear layer, we have 3 operands: the weight, the input, and the output.
 The subscripts of this tensor product are "uv,iu,iv" where "uv" represents the modes of the weight, "iu" represents the modes of the input, and "iv" represents the modes of the output.

 .. jupyter-execute::

-    d = stp.SegmentedTensorProduct.from_subscripts("uv,iu,iv")
+    d = cue.SegmentedTensorProduct.from_subscripts("uv,iu,iv")
     d

 Each operand of the tensor product descriptor has a list of segments.
 We can add segments to the descriptor using the :meth:`add_segment <cuequivariance.SegmentedTensorProduct.add_segment>` method.
-We can add the segments of the input and output representations to the descriptor.
+We are adding the segments of the input and output representations to the descriptor.

 .. jupyter-execute::

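The body of this cell is collapsed on the commit page. A hypothetical sketch of what such a cell contains (the loop and the unpacking of (mul, ir) pairs are assumptions, not shown by the commit):

    # One segment per (multiplicity, irrep) pair; with subscripts "iu"/"iv",
    # a segment shape is (irrep dimension, multiplicity).
    for mul, ir in irreps1:
        d.add_segment(1, (ir.dim, mul))
    for mul, ir in irreps2:
        d.add_segment(2, (ir.dim, mul))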
@@ -62,7 +120,7 @@ We can add the segments of the input and output representations to the descriptor.

     d

-Now we can enumerate all the possible pairs of irreps and add weight segments and paths between them when the irreps are the same.
+Enumerating all the possible pairs of irreps and adding weight segments and paths between them when the irreps are the same.

 .. jupyter-execute::

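The enumeration cell itself is collapsed on the commit page. A hypothetical reconstruction (the None first segment and the scalar coefficient are assumptions based on the add_path docstring above):

    # Add a path wherever input and output irreps match; passing None for the
    # weight operand lets add_path create the missing weight segment, with u
    # and v inferred from the connected input and output segments.
    for (i1, (mul1, ir1)), (i2, (mul2, ir2)) in itertools.product(
        enumerate(irreps1), enumerate(irreps2)
    ):
        if ir1 == ir2:
            d.add_path(None, i1, i2, c=1.0)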
@@ -74,23 +132,23 @@ Now we can enumerate all the possible pairs of irreps and add weight segments and paths between them when the irreps are the same.

     d

-We can see the two paths we added:
+Printing the paths.

 .. jupyter-execute::

     d.paths


-Finally, we can normalize the paths for the last operand such that the output is normalized to variance 1.
+Normalizing the paths for the last operand such that the output is normalized to variance 1.

 .. jupyter-execute::

     d = d.normalize_paths_for_operand(-1)
     d.paths

-As we can see, the paths coefficients has been normalized.
+As we can see, the paths coefficients have been normalized.

-Now we can create a tensor product from the descriptor and execute it. In PyTorch, we can use the :class:`cuet.TensorProduct <cuequivariance_torch.TensorProduct>` class.
+Now we are creating a tensor product from the descriptor and executing it. In PyTorch, we can use the :class:`cuet.TensorProduct <cuequivariance_torch.TensorProduct>` class.

 .. jupyter-execute::

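The PyTorch instantiation cell is collapsed here; a plausible sketch, given that the paragraph above names the class (the exact cell contents are not shown):

    # Wrap the descriptor in a module that executes it.
    linear_torch = cuet.TensorProduct(d)
    linear_torch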
@@ -105,7 +163,7 @@ In JAX, we can use the :func:`cuex.tensor_product <cuequivariance_jax.tensor_product>` function.
     linear_jax = cuex.tensor_product(d)
     linear_jax

-Now we can execute the linear layer with random input and weight tensors.
+Now we are executing the linear layer with random input and weight tensors.

 .. jupyter-execute::

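The body of this cell is collapsed; a hypothetical version consistent with the assert kept below (tensor names, shapes, and the call convention are assumptions):

    w = torch.randn(1, d.operands[0].size)  # flattened weights
    x1 = torch.randn(3000, irreps1.dim)     # a batch of random inputs
    x2 = linear_torch(w, x1)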
@@ -116,7 +174,7 @@ Now we can execute the linear layer with random input and weight tensors.

     assert x2.shape == (3000, irreps2.dim)

-Now we can verify that the output is well normalized.
+Now we are verifying that the output is well normalized.

 .. jupyter-execute::
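The final cell is collapsed as well; a minimal check in the same spirit (an assumption, not shown by the commit):

    # After normalize_paths_for_operand(-1), the output variance should be
    # close to 1.
    print(x2.var())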
