
Update CHANGELOG.md and fix a few documentation details (#57)
* Update CHANGELOG.md

* Update CHANGELOG.md

* add SphericalHarmonics to changelog

* Improve documentation formatting for RepArray class

* Refactor GitHub Actions workflows to separate dependency installation and Sphinx build steps

* Remove python-gitlab from documentation requirements

* Update docstrings to use class reference format for EquivariantTensorProduct

* Update CHANGELOG to reflect breaking changes in TensorProduct and EquivariantTensorProduct input requirements

* Enhance input validation for EquivariantTensorProduct and TensorProduct classes to ensure correct tensor dimensions and types

* Refactor input checks in tensor product classes to include tracing condition
mariogeiger authored Jan 9, 2025
1 parent 31b3c48 commit b5e44e2
Showing 12 changed files with 105 additions and 40 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/doc_build.yml
@@ -14,12 +14,14 @@ jobs:
         uses: actions/setup-python@v3
         with:
           python-version: "3.12"
-      - name: Build sphinx
+      - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
           python -m pip install pytest
           python -m pip install ./cuequivariance
           python -m pip install ./cuequivariance_jax
           python -m pip install ./cuequivariance_torch
           pip install -r docs/requirements.txt
+      - name: Build sphinx
+        run: |
           sphinx-build -b html docs docs/public
4 changes: 3 additions & 1 deletion .github/workflows/doc_build_deploy.yml
@@ -35,14 +35,16 @@ jobs:
         uses: actions/setup-python@v3
         with:
           python-version: "3.12"
-      - name: Build sphinx
+      - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
           python -m pip install pytest
           python -m pip install ./cuequivariance
           python -m pip install ./cuequivariance_jax
           python -m pip install ./cuequivariance_torch
           pip install -r docs/requirements.txt
+      - name: Build sphinx
+        run: |
           sphinx-build -b html docs docs/public
       - name: Upload artifact
         uses: actions/upload-pages-artifact@v3
23 changes: 12 additions & 11 deletions CHANGELOG.md
@@ -1,22 +1,23 @@
 ## Latest Changes
 
-### Added
-
-- Partial support of `torch.jit.script` and `torch.compile`
-- Added `cuex.RepArray` for representing an array of any kind of representations (not only irreps like before with `IrrepsArray`).
-
-### Changed
+### Breaking Changes
 
-- `cuequivariance_torch.TensorProduct` and `cuequivariance_torch.EquivariantTensorProduct` now require lists of `torch.Tensor` as input.
-- `cuex.IrrepsArray` is now an alias for `cuex.RepArray` and its `.irreps` attribute and `.segments` are not functions anymore but properties.
+- `cuet.TensorProduct` and `cuet.EquivariantTensorProduct` now require inputs of shape `(batch_size, dim)` or `(1, dim)`. Inputs of shape `(dim,)` are no longer allowed.
+- `cuet.TensorProduct` and `cuet.EquivariantTensorProduct` are no longer variadic functions. They now take a list of `torch.Tensor` as input.
+- `cuex.IrrepsArray` is an alias for `cuex.RepArray`.
+- `cuex.RepArray.irreps` and `cuex.RepArray.segments` are no longer functions. They are now properties.
+- `cuex.IrrepsArray.is_simple` is replaced by `cuex.RepArray.is_irreps_array`.
+- The function `cuet.spherical_harmonics` is replaced by the Torch module `cuet.SphericalHarmonics`. This was done to allow the use of `torch.jit.script` and `torch.compile`.
 
-## Removed
+### Added
 
-- `cuex.IrrepsArray.is_simple` is replaced by `cuex.RepArray.is_irreps_array`.
+- Support for `torch.jit.script` and `torch.compile`. Known issue: export to C++ is not working.
+- Add `cue.IrrepsAndLayout`: a simple class that inherits from `cue.Rep` and contains a `cue.Irreps` and a `cue.IrrepsLayout`.
+- Add `cuex.RepArray` for representing an array of any kind of representation (not only irreps as before with `cuex.IrrepsArray`).
 
 ### Fixed
 
-- Add support for empty batch dimension in `cuequivariance-torch`.
+- Add support for empty batch dimension in `cuet` (`cuequivariance_torch`).
 
 ## 0.1.0 (2024-11-18)
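For readers updating call sites, the two `cuet` breaking changes above combine as follows. This is a minimal sketch, not code from this commit: the descriptor choice, the irreps, and the `layout` argument are illustrative assumptions.

    import torch
    import cuequivariance as cue
    import cuequivariance_torch as cuet

    # Illustrative descriptor: a linear map between two irreps spaces.
    e = cue.descriptors.linear(cue.Irreps("SO3", "8x0"), cue.Irreps("SO3", "8x0"))
    m = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul)

    w = torch.randn(1, e.inputs[0].dim)   # shared weights: shape (1, dim)
    x = torch.randn(32, e.inputs[1].dim)  # batched input: shape (batch_size, dim)

    y = m([w, x])   # new calling convention: one list of 2-D tensors
    # m(w, x)       # old variadic form: no longer accepted
    # m([w[0], x])  # 1-D inputs of shape (dim,): no longer accepted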
10 changes: 5 additions & 5 deletions cuequivariance/cuequivariance/descriptors/irreps_tp.py
@@ -39,7 +39,7 @@ def fully_connected_tensor_product(
         irreps3 (Irreps): Irreps of the output.
 
     Returns:
-        EquivariantTensorProduct: Descriptor of the fully connected tensor product.
+        :class:`cue.EquivariantTensorProduct <cuequivariance.EquivariantTensorProduct>`: Descriptor of the fully connected tensor product.
 
     Examples:
         >>> cue.descriptors.fully_connected_tensor_product(
@@ -102,7 +102,7 @@ def full_tensor_product(
         irreps3_filter (sequence of Irrep, optional): Irreps of the output to consider.
 
     Returns:
-        EquivariantTensorProduct: Descriptor of the full tensor product.
+        :class:`cue.EquivariantTensorProduct <cuequivariance.EquivariantTensorProduct>`: Descriptor of the full tensor product.
     """
     G = irreps1.irrep_class
 
@@ -163,7 +163,7 @@ def channelwise_tensor_product(
         irreps3_filter (sequence of Irrep, optional): Irreps of the output to consider.
 
     Returns:
-        EquivariantTensorProduct: Descriptor of the channelwise tensor product.
+        :class:`cue.EquivariantTensorProduct <cuequivariance.EquivariantTensorProduct>`: Descriptor of the channelwise tensor product.
     """
     G = irreps1.irrep_class
 
@@ -254,7 +254,7 @@ def elementwise_tensor_product(
         irreps3_filter (sequence of Irrep, optional): Irreps of the output to consider.
 
     Returns:
-        EquivariantTensorProduct: Descriptor of the elementwise tensor product.
+        :class:`cue.EquivariantTensorProduct <cuequivariance.EquivariantTensorProduct>`: Descriptor of the elementwise tensor product.
     """
     G = irreps1.irrep_class
 
@@ -306,7 +306,7 @@ def linear(
         irreps_out (Irreps): Irreps of the output.
 
     Returns:
-        EquivariantTensorProduct: Descriptor of the linear transformation.
+        :class:`cue.EquivariantTensorProduct <cuequivariance.EquivariantTensorProduct>`: Descriptor of the linear transformation.
     """
     d = stp.SegmentedTensorProduct.from_subscripts("uv_iu_iv")
     for mul, ir in irreps_in:
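For context on the docstrings touched here, a sketch of how one of these factory functions is used; the irreps are arbitrary examples, not taken from this commit.

    import cuequivariance as cue

    # Descriptor of a fully connected tensor product (weights, input1,
    # input2 -> output). This only *describes* the operation; the torch
    # and jax modules execute it.
    e = cue.descriptors.fully_connected_tensor_product(
        cue.Irreps("SO3", "4x0 + 4x1"),  # irreps1: irreps of the first input
        cue.Irreps("SO3", "0 + 1"),      # irreps2: irreps of the second input
        cue.Irreps("SO3", "4x0 + 4x1"),  # irreps3: irreps of the output
    )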
@@ -33,7 +33,7 @@ def spherical_harmonics(
         layout (IrrepsLayout, optional): layout of the output. Defaults to ``cue.ir_mul``.
 
     Returns:
-        EquivariantTensorProduct: The descriptor.
+        :class:`cue.EquivariantTensorProduct <cuequivariance.EquivariantTensorProduct>`: The descriptor.
 
     Examples:
         >>> spherical_harmonics(cue.SO3(1), [0, 1, 2])
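Relatedly, the CHANGELOG above replaces the function `cuet.spherical_harmonics` with the module `cuet.SphericalHarmonics`. A sketch of the new call pattern; the constructor arguments are an assumption inferred from the descriptor example above, not verified against the released API.

    import torch
    import cuequivariance_torch as cuet

    # Assumed constructor: the list of harmonic degrees, mirroring the
    # descriptor above. Being a Module rather than a function, it can go
    # through torch.jit.script and torch.compile.
    sh = cuet.SphericalHarmonics([0, 1, 2])
    vec = torch.randn(32, 3)  # batch of SO3 vectors
    out = sh(vec)             # expected shape (32, 1 + 3 + 5)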
@@ -34,7 +34,7 @@ def symmetric_contraction(
         degree (int): The degree of the symmetric contraction.
 
     Returns:
-        EquivariantTensorProduct:
+        :class:`cue.EquivariantTensorProduct <cuequivariance.EquivariantTensorProduct>`:
             The descriptor of the symmetric contraction.
             The operands are the weights, the input (repeated ``degree`` times), and the output.
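For orientation, the operand layout this docstring describes, sketched with illustrative irreps; the argument order is an assumption, with `degree` as documented above.

    import cuequivariance as cue

    e = cue.descriptors.symmetric_contraction(
        cue.Irreps("SO3", "16x0 + 16x1"),  # irreps of the input (assumed)
        cue.Irreps("SO3", "16x0"),         # irreps of the output (assumed)
        3,                                 # degree of the contraction
    )
    # Per the docstring: the operands are the weights, the input repeated
    # `degree` times, and the output.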
@@ -33,7 +33,7 @@ def equivariant_tensor_product(
     """Compute the equivariant tensor product of the input arrays.
 
     Args:
-        e (EquivariantTensorProduct): The equivariant tensor product descriptor.
+        e (:class:`cue.EquivariantTensorProduct <cuequivariance.EquivariantTensorProduct>`): The equivariant tensor product descriptor.
         *inputs (RepArray or jax.Array): The input arrays.
         dtype_output (jnp.dtype, optional): The data type for the output array. Defaults to None.
         dtype_math (jnp.dtype, optional): The data type for computational operations. Defaults to None.
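A sketch of this JAX entry point end to end; the descriptor and shapes are illustrative, and the `RepArray` construction follows the docstring shown in the next diff.

    import jax.numpy as jnp
    import cuequivariance as cue
    import cuequivariance_jax as cuex

    # Spherical harmonics of degrees 0, 1, 2 of an SO3 vector.
    e = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2])

    # Batch of 8 SO3 vectors; the last axis carries the irreps.
    x = cuex.RepArray(cue.Irreps("SO3", "1"), jnp.ones((8, 3)), cue.ir_mul)

    y = cuex.equivariant_tensor_product(e, x)  # RepArray with irreps 0+1+2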
@@ -14,19 +14,19 @@
 @dataclass(frozen=True, init=False, repr=False)
 class RepArray:
     """
-    A `jax.Array` decorated with a dict of `Rep` for the axes transforming under a group representation.
+    A :class:`jax.Array <jax.Array>` decorated with a dict of :class:`cue.Rep <cuequivariance.Rep>` for the axes transforming under a group representation.
 
     Example:
-        You can create a `RepArray` by specifying the `Rep` for each axis:
+        You can create a :class:`cuex.RepArray <cuequivariance_jax.RepArray>` by specifying the :class:`cue.Rep <cuequivariance.Rep>` for each axis:
 
         >>> cuex.RepArray({0: cue.SO3(1), 1: cue.SO3(1)}, jnp.eye(3))
         {0: 1, 1: 1}
        [[1. 0. 0.]
         [0. 1. 0.]
         [0. 0. 1.]]
 
-    By default, arguments that are not `Rep` will be automatically converted into `IrrepsAndLayout`:
+    By default, arguments that are not :class:`cue.Rep <cuequivariance.Rep>` will be automatically converted into :class:`cue.IrrepsAndLayout <cuequivariance.IrrepsAndLayout>`:
 
         >>> with cue.assume(cue.SO3, cue.ir_mul):
        ...     x = cuex.RepArray({0: "1", 1: "2"}, jnp.ones((3, 5)))
@@ -40,7 +40,7 @@ class RepArray:
 
     .. rubric:: IrrepsArray
 
-    An ``IrrepsArray`` is just a special case of a ``RepArray`` where the last axis is a `IrrepsAndLayout`:
+    An ``IrrepsArray`` is just a special case of a ``RepArray`` where the last axis is a :class:`cue.IrrepsAndLayout <cuequivariance.IrrepsAndLayout>`:
 
         >>> x = cuex.RepArray(
        ...     cue.Irreps("SO3", "2x0"), jnp.zeros((3, 2)), cue.ir_mul
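Tying this to the CHANGELOG entry above: `.irreps` and `.segments` are now properties, not methods. A minimal sketch reusing the docstring's own example:

    import jax.numpy as jnp
    import cuequivariance as cue
    import cuequivariance_jax as cuex

    x = cuex.RepArray(cue.Irreps("SO3", "2x0"), jnp.zeros((3, 2)), cue.ir_mul)

    irreps = x.irreps      # property access; was `x.irreps()` before this release
    segments = x.segments  # likewise now a property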
@@ -251,13 +251,28 @@ def forward(
         """
         If ``indices`` is not None, the first input is indexed by ``indices``.
         """
-
-        # assert len(inputs) == len(self.etp.inputs)
-        for a, dim in zip(inputs, self.operands_dims):
-            torch._assert(
-                a.shape[-1] == dim,
-                f"Expected last dimension of input to be {dim}, got {a.shape[-1]}",
-            )
+        if (
+            not torch.jit.is_scripting()
+            and not torch.jit.is_tracing()
+            and not torch.compiler.is_compiling()
+        ):
+            if not isinstance(inputs, (list, tuple)):
+                raise ValueError(
+                    "inputs should be a list of tensors followed by optional indices"
+                )
+            if len(inputs) != self.etp.num_inputs:
+                raise ValueError(
+                    f"Expected {self.etp.num_inputs} inputs, got {len(inputs)}"
+                )
+            for oid, input in enumerate(inputs):
+                torch._assert(
+                    input.ndim == 2,
+                    f"input {oid} should have ndim=2",
+                )
+                torch._assert(
+                    input.shape[1] == self.operands_dims[oid],
+                    f"input {oid} should have shape (batch, {self.operands_dims[oid]}), got {input.shape}",
+                )
 
         # Transpose inputs
         inputs = self.transpose_in(inputs)
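The effect of the new eager-mode checks, sketched against any `cuet.EquivariantTensorProduct` with two inputs; the module `m` here is hypothetical.

    import torch

    # m = cuet.EquivariantTensorProduct(...)  # any 2-input descriptor
    x0 = torch.randn(16, m.operands_dims[0])
    x1 = torch.randn(16, m.operands_dims[1])

    m([x0, x1])       # OK: one list of 2-D tensors of the right widths
    # m(x0)           # ValueError: inputs should be a list of tensors ...
    # m([x0])         # ValueError: Expected 2 inputs, got 1
    # m([x0[0], x1])  # fails the ndim == 2 assert: (dim,) inputs are rejected

Note the guard: none of these checks run while scripting, tracing, or compiling, so they cost nothing in captured graphs.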
@@ -320,7 +320,11 @@ def forward(
         i0 = i0.to(torch.int32)
         x0 = x0.reshape(x0.shape[0], x0.shape[1] // self.u, self.u)
         x1 = x1.reshape(x1.shape[0], x1.shape[1] // self.u, self.u)
-        if not torch.jit.is_scripting() and not torch.compiler.is_compiling():
+        if (
+            not torch.jit.is_scripting()
+            and not torch.jit.is_tracing()
+            and not torch.compiler.is_compiling()
+        ):
             logger.debug(
                 f"Calling SymmetricTensorContraction: {self.descriptors}, input shapes: {x0.shape}, {i0.shape}, {x1.shape}"
             )
@@ -91,6 +91,8 @@ def __init__(
         if not self.has_cuda:
             self.f = _tensor_product_fx(descriptor, device, math_dtype, True)
 
+        self.operands_dims = [ope.size for ope in descriptor.operands]
+
     @torch.jit.ignore
     def __repr__(self):
         has_cuda_kernel = (
@@ -113,8 +115,28 @@ def forward(self, inputs: List[torch.Tensor]):
             The output tensor resulting from the tensor product.
             It has a shape of (batch, last_operand_size), where
             `last_operand_size` is the size of the last operand in the descriptor.
         """
+        if (
+            not torch.jit.is_scripting()
+            and not torch.jit.is_tracing()
+            and not torch.compiler.is_compiling()
+        ):
+            if not isinstance(inputs, (list, tuple)):
+                raise ValueError("inputs should be a list of tensors")
+            if len(inputs) != self.num_operands - 1:
+                raise ValueError(
+                    f"Expected {self.num_operands - 1} input tensors, got {len(inputs)}"
+                )
+            for oid, input in enumerate(inputs):
+                torch._assert(
+                    input.ndim == 2,
+                    f"input {oid} should have ndim=2",
+                )
+                torch._assert(
+                    input.shape[1] == self.operands_dims[oid],
+                    f"input {oid} should have shape (batch, {self.operands_dims[oid]}), got {input.shape}",
+                )
+
         return self.f(inputs)
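A sketch of migrating an older call site to the shapes this check enforces; `tp` stands for any existing `cuet.TensorProduct`.

    import torch

    # A TensorProduct consumes every operand except the last (the output).
    xs = [torch.randn(dim) for dim in tp.operands_dims[:-1]]  # old 1-D style

    # out = tp(*xs)                          # old: variadic 1-D tensors
    out = tp([x.unsqueeze(0) for x in xs])   # new: one list of (1, dim) tensors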


@@ -356,7 +378,11 @@ def __init__(self, module: torch.nn.Module, descriptor: stp.SegmentedTensorProduct
         self.descriptor = descriptor
 
     def forward(self, args: List[torch.Tensor]):
-        if not torch.jit.is_scripting() and not torch.compiler.is_compiling():
+        if (
+            not torch.jit.is_scripting()
+            and not torch.jit.is_tracing()
+            and not torch.compiler.is_compiling()
+        ):
             for oid, arg in enumerate(args):
                 torch._assert(
                     arg.ndim == 2,
@@ -488,7 +514,11 @@ def __repr__(self) -> str:
     def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
         x0, x1 = self._perm(inputs[0], inputs[1])
 
-        if not torch.jit.is_scripting() and not torch.compiler.is_compiling():
+        if (
+            not torch.jit.is_scripting()
+            and not torch.jit.is_tracing()
+            and not torch.compiler.is_compiling()
+        ):
             logger.debug(
                 f"Calling FusedTensorProductOp3: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}"
             )
@@ -549,7 +579,11 @@ def __repr__(self) -> str:
     def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
         x0, x1, x2 = self._perm(inputs[0], inputs[1], inputs[2])
 
-        if not torch.jit.is_scripting() and not torch.compiler.is_compiling():
+        if (
+            not torch.jit.is_scripting()
+            and not torch.jit.is_tracing()
+            and not torch.compiler.is_compiling()
+        ):
             logger.debug(
                 f"Calling FusedTensorProductOp4: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}, {x2.shape}"
             )
@@ -608,7 +642,11 @@ def __repr__(self):
     def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
         x0, x1 = inputs
 
-        if not torch.jit.is_scripting() and not torch.compiler.is_compiling():
+        if (
+            not torch.jit.is_scripting()
+            and not torch.jit.is_tracing()
+            and not torch.compiler.is_compiling()
+        ):
             logger.debug(
                 f"Calling TensorProductUniform3x1d: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}"
             )
@@ -629,7 +667,11 @@ def __repr__(self):
     def forward(self, inputs: List[torch.Tensor]):
         x0, x1, x2 = inputs
 
-        if not torch.jit.is_scripting() and not torch.compiler.is_compiling():
+        if (
+            not torch.jit.is_scripting()
+            and not torch.jit.is_tracing()
+            and not torch.compiler.is_compiling()
+        ):
             logger.debug(
                 f"Calling TensorProductUniform4x1d: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}, {x2.shape}"
             )
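The guard that recurs throughout this file, with `torch.jit.is_tracing()` newly added per the commit message, amounts to the following; the helper name is illustrative, and the rationale comment reflects the usual reason for such guards, presumably why the tracing condition was added.

    import torch

    def _eager_checks_enabled() -> bool:
        # Skip Python-level validation and debug logging while TorchScript
        # is scripting, torch.jit.trace is recording, or torch.compile is
        # capturing the graph: data-dependent branches and f-strings over
        # shapes would otherwise break, or get baked into, the captured
        # program.
        return (
            not torch.jit.is_scripting()
            and not torch.jit.is_tracing()
            and not torch.compiler.is_compiling()
        )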
1 change: 0 additions & 1 deletion docs/requirements.txt
@@ -7,6 +7,5 @@ myst-parser
 ipykernel
 matplotlib
 jupyter-sphinx
-python-gitlab
 e3nn
 flax
