From fb157e224baa59596d5eac0304632e395d2b7c61 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 8 Jan 2025 10:40:51 +0100 Subject: [PATCH 01/13] Update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0d1d80a6..f8e517bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,7 @@ - `cuequivariance_torch.TensorProduct` and `cuequivariance_torch.EquivariantTensorProduct` now require lists of `torch.Tensor` as input. - `cuex.IrrepsArray` is now an alias for `cuex.RepArray` and its `.irreps` attribute and `.segments` are not functions anymore but properties. -## Removed +### Removed - `cuex.IrrepsArray.is_simple` is replaced by `cuex.RepArray.is_irreps_array`. From 1c7ee41a72130c6342207c989dac0e8e2725f9e1 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 8 Jan 2025 10:51:07 +0100 Subject: [PATCH 02/13] Update CHANGELOG.md --- CHANGELOG.md | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f8e517bb..c6ba8646 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,22 +1,21 @@ ## Latest Changes -### Added - -- Partial support of `torch.jit.script` and `torch.compile` -- Added `cuex.RepArray` for representing an array of any kind of representations (not only irreps like before with `IrrepsArray`). - -### Changed +### Breaking Changes -- `cuequivariance_torch.TensorProduct` and `cuequivariance_torch.EquivariantTensorProduct` now require lists of `torch.Tensor` as input. -- `cuex.IrrepsArray` is now an alias for `cuex.RepArray` and its `.irreps` attribute and `.segments` are not functions anymore but properties. +- `cuet.TensorProduct` and `cuet.EquivariantTensorProduct` are no more variadic functions. They now require a list of `torch.Tensor` as input. +- `cuex.IrrepsArray` is an alias for `cuex.RepArray`. +- `cuex.RepArray.irreps` and `cuex.RepArray.segments` are not functions anymore. They are now properties. +- `cuex.IrrepsArray.is_simple` is replaced by `cuex.RepArray.is_irreps_array`. -### Removed +### Added -- `cuex.IrrepsArray.is_simple` is replaced by `cuex.RepArray.is_irreps_array`. +- Support of `torch.jit.script` and `torch.compile`. Known issue: the export in c++ is not working. +- Add `cue.IrrepsAndLayout`: A simple class that inherits from `cue.Rep` and contains a `cue.Irreps` and a `cue.IrrepsLayout`. +- Add `cuex.RepArray` for representing an array of any kind of representations (not only irreps like before with `cuex.IrrepsArray`). ### Fixed -- Add support for empty batch dimension in `cuequivariance-torch`. +- Add support for empty batch dimension in `cuet` (`cuequivariance_torch`). ## 0.1.0 (2024-11-18) From 4d3fe52d3dcf35eabfdbabd5fdd5857ff3bafe21 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 8 Jan 2025 10:54:55 +0100 Subject: [PATCH 03/13] add SphericalHarmonics to changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c6ba8646..9384e08d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ - `cuex.IrrepsArray` is an alias for `cuex.RepArray`. - `cuex.RepArray.irreps` and `cuex.RepArray.segments` are not functions anymore. They are now properties. - `cuex.IrrepsArray.is_simple` is replaced by `cuex.RepArray.is_irreps_array`. +- The function `cuet.spherical_harmonics` is replaced by the Torch Module `cuet.SphericalHarmonics`. This was done to allow the use of `torch.jit.script` and `torch.compile`. ### Added From bea98e852c7d2424c94c49524a546f3b73356173 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 8 Jan 2025 11:06:49 +0100 Subject: [PATCH 04/13] Improve documentation formatting for RepArray class --- .../cuequivariance_jax/rep_array/jax_rep_array.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cuequivariance_jax/cuequivariance_jax/rep_array/jax_rep_array.py b/cuequivariance_jax/cuequivariance_jax/rep_array/jax_rep_array.py index 11a4b4f9..f39be44f 100644 --- a/cuequivariance_jax/cuequivariance_jax/rep_array/jax_rep_array.py +++ b/cuequivariance_jax/cuequivariance_jax/rep_array/jax_rep_array.py @@ -14,11 +14,11 @@ @dataclass(frozen=True, init=False, repr=False) class RepArray: """ - A `jax.Array` decorated with a dict of `Rep` for the axes transforming under a group representation. + A :class:`jax.Array ` decorated with a dict of :class:`cue.Rep ` for the axes transforming under a group representation. Example: - You can create a `RepArray` by specifying the `Rep` for each axis: + You can create a :class:`cuex.RepArray ` by specifying the :class:`cue.Rep ` for each axis: >>> cuex.RepArray({0: cue.SO3(1), 1: cue.SO3(1)}, jnp.eye(3)) {0: 1, 1: 1} @@ -26,7 +26,7 @@ class RepArray: [0. 1. 0.] [0. 0. 1.]] - By default, arguments that are not `Rep` will be automatically converted into `IrrepsAndLayout`: + By default, arguments that are not :class:`cue.Rep ` will be automatically converted into :class:`cue.IrrepsAndLayout `: >>> with cue.assume(cue.SO3, cue.ir_mul): ... x = cuex.RepArray({0: "1", 1: "2"}, jnp.ones((3, 5))) @@ -40,7 +40,7 @@ class RepArray: .. rubric:: IrrepsArray - An ``IrrepsArray`` is just a special case of a ``RepArray`` where the last axis is a `IrrepsAndLayout`: + An ``IrrepsArray`` is just a special case of a ``RepArray`` where the last axis is a :class:`cue.IrrepsAndLayout `: >>> x = cuex.RepArray( ... cue.Irreps("SO3", "2x0"), jnp.zeros((3, 2)), cue.ir_mul From f68ee03b374fbd13c2bc75828192116a359c51e9 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 8 Jan 2025 13:23:29 +0100 Subject: [PATCH 05/13] Refactor GitHub Actions workflows to separate dependency installation and Sphinx build steps --- .github/workflows/doc_build.yml | 4 +++- .github/workflows/doc_build_deploy.yml | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/doc_build.yml b/.github/workflows/doc_build.yml index 59340525..cde5375b 100644 --- a/.github/workflows/doc_build.yml +++ b/.github/workflows/doc_build.yml @@ -14,7 +14,7 @@ jobs: uses: actions/setup-python@v3 with: python-version: "3.12" - - name: Build sphinx + - name: Install dependencies run: | python -m pip install --upgrade pip python -m pip install pytest @@ -22,4 +22,6 @@ jobs: python -m pip install ./cuequivariance_jax python -m pip install ./cuequivariance_torch pip install -r docs/requirements.txt + - name: Build sphinx + run: | sphinx-build -b html docs docs/public diff --git a/.github/workflows/doc_build_deploy.yml b/.github/workflows/doc_build_deploy.yml index 63a071ef..af70fe13 100644 --- a/.github/workflows/doc_build_deploy.yml +++ b/.github/workflows/doc_build_deploy.yml @@ -35,7 +35,7 @@ jobs: uses: actions/setup-python@v3 with: python-version: "3.12" - - name: Build sphinx + - name: Install dependencies run: | python -m pip install --upgrade pip python -m pip install pytest @@ -43,6 +43,8 @@ jobs: python -m pip install ./cuequivariance_jax python -m pip install ./cuequivariance_torch pip install -r docs/requirements.txt + - name: Build sphinx + run: | sphinx-build -b html docs docs/public - name: Upload artifact uses: actions/upload-pages-artifact@v3 From 8065075c0d5818646c4ca9b9ee7006014b35c3b8 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 8 Jan 2025 13:25:22 +0100 Subject: [PATCH 06/13] Remove python-gitlab from documentation requirements --- docs/requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index e11214dc..aac8c51d 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -7,6 +7,5 @@ myst-parser ipykernel matplotlib jupyter-sphinx -python-gitlab e3nn flax \ No newline at end of file From dd239f1f839b2cf534ed380b59cb30149e3e89c6 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 8 Jan 2025 13:37:24 +0100 Subject: [PATCH 07/13] Update docstrings to use class reference format for EquivariantTensorProduct --- cuequivariance/cuequivariance/descriptors/irreps_tp.py | 10 +++++----- .../cuequivariance/descriptors/spherical_harmonics_.py | 2 +- .../descriptors/symmetric_contractions.py | 2 +- .../primitives/equivariant_tensor_product.py | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/cuequivariance/cuequivariance/descriptors/irreps_tp.py b/cuequivariance/cuequivariance/descriptors/irreps_tp.py index a6d7fc4a..9df26e90 100644 --- a/cuequivariance/cuequivariance/descriptors/irreps_tp.py +++ b/cuequivariance/cuequivariance/descriptors/irreps_tp.py @@ -39,7 +39,7 @@ def fully_connected_tensor_product( irreps3 (Irreps): Irreps of the output. Returns: - EquivariantTensorProduct: Descriptor of the fully connected tensor product. + :class:`cue.EquivariantTensorProduct `: Descriptor of the fully connected tensor product. Examples: >>> cue.descriptors.fully_connected_tensor_product( @@ -102,7 +102,7 @@ def full_tensor_product( irreps3_filter (sequence of Irrep, optional): Irreps of the output to consider. Returns: - EquivariantTensorProduct: Descriptor of the full tensor product. + :class:`cue.EquivariantTensorProduct `: Descriptor of the full tensor product. """ G = irreps1.irrep_class @@ -163,7 +163,7 @@ def channelwise_tensor_product( irreps3_filter (sequence of Irrep, optional): Irreps of the output to consider. Returns: - EquivariantTensorProduct: Descriptor of the channelwise tensor product. + :class:`cue.EquivariantTensorProduct `: Descriptor of the channelwise tensor product. """ G = irreps1.irrep_class @@ -254,7 +254,7 @@ def elementwise_tensor_product( irreps3_filter (sequence of Irrep, optional): Irreps of the output to consider. Returns: - EquivariantTensorProduct: Descriptor of the elementwise tensor product. + :class:`cue.EquivariantTensorProduct `: Descriptor of the elementwise tensor product. """ G = irreps1.irrep_class @@ -306,7 +306,7 @@ def linear( irreps_out (Irreps): Irreps of the output. Returns: - EquivariantTensorProduct: Descriptor of the linear transformation. + :class:`cue.EquivariantTensorProduct `: Descriptor of the linear transformation. """ d = stp.SegmentedTensorProduct.from_subscripts("uv_iu_iv") for mul, ir in irreps_in: diff --git a/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py b/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py index af7beb63..899fa49a 100644 --- a/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py +++ b/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py @@ -33,7 +33,7 @@ def spherical_harmonics( layout (IrrepsLayout, optional): layout of the output. Defaults to ``cue.ir_mul``. Returns: - EquivariantTensorProduct: The descriptor. + :class:`cue.EquivariantTensorProduct `: The descriptor. Examples: >>> spherical_harmonics(cue.SO3(1), [0, 1, 2]) diff --git a/cuequivariance/cuequivariance/descriptors/symmetric_contractions.py b/cuequivariance/cuequivariance/descriptors/symmetric_contractions.py index 48f3fab9..d830cc63 100644 --- a/cuequivariance/cuequivariance/descriptors/symmetric_contractions.py +++ b/cuequivariance/cuequivariance/descriptors/symmetric_contractions.py @@ -34,7 +34,7 @@ def symmetric_contraction( degree (int): The degree of the symmetric contraction. Returns: - EquivariantTensorProduct: + :class:`cue.EquivariantTensorProduct `: The descriptor of the symmetric contraction. The operands are the weights, the input degree times and the output. diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py b/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py index b93ffdc7..ac89d813 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py @@ -33,7 +33,7 @@ def equivariant_tensor_product( """Compute the equivariant tensor product of the input arrays. Args: - e (EquivariantTensorProduct): The equivariant tensor product descriptor. + e (:class:`cue.EquivariantTensorProduct `): The equivariant tensor product descriptor. *inputs (RepArray or jax.Array): The input arrays. dtype_output (jnp.dtype, optional): The data type for the output array. Defaults to None. dtype_math (jnp.dtype, optional): The data type for computational operations. Defaults to None. From 2eaa444c1d642679578303f5deccb5de4e74b2e9 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 8 Jan 2025 23:54:33 +0100 Subject: [PATCH 08/13] Update CHANGELOG to reflect breaking changes in TensorProduct and EquivariantTensorProduct input requirements --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9384e08d..2eba490f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ### Breaking Changes +- `cuet.TensorProduct` and `cuet.EquivariantTensorProduct` now require inputs to be of shape `(batch_size, dim)` or `(1, dim)`. Inputs of dimension `(dim,)` are no more allowed. - `cuet.TensorProduct` and `cuet.EquivariantTensorProduct` are no more variadic functions. They now require a list of `torch.Tensor` as input. - `cuex.IrrepsArray` is an alias for `cuex.RepArray`. - `cuex.RepArray.irreps` and `cuex.RepArray.segments` are not functions anymore. They are now properties. From bffb2274f4365708031f6dd6fe1275acd1b6cbd5 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 9 Jan 2025 00:03:28 +0100 Subject: [PATCH 09/13] Enhance input validation for EquivariantTensorProduct and TensorProduct classes to ensure correct tensor dimensions and types --- .../primitives/equivariant_tensor_product.py | 25 +++++++++++++------ .../primitives/tensor_product.py | 20 ++++++++++++++- 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py index 2d5610b6..846f6ee1 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py @@ -251,13 +251,24 @@ def forward( """ If ``indices`` is not None, the first input is indexed by ``indices``. """ - - # assert len(inputs) == len(self.etp.inputs) - for a, dim in zip(inputs, self.operands_dims): - torch._assert( - a.shape[-1] == dim, - f"Expected last dimension of input to be {dim}, got {a.shape[-1]}", - ) + if not torch.jit.is_scripting() and not torch.compiler.is_compiling(): + if not isinstance(inputs, (list, tuple)): + raise ValueError( + "inputs should be a list of tensors followed by optional indices" + ) + if len(inputs) != self.etp.num_inputs: + raise ValueError( + f"Expected {self.etp.num_inputs} inputs, got {len(inputs)}" + ) + for oid, input in enumerate(inputs): + torch._assert( + input.ndim == 2, + f"input {oid} should have ndim=2", + ) + torch._assert( + input.shape[1] == self.operands_dims[oid], + f"input {oid} should have shape (batch, {self.operands_dims[oid]}), got {input.shape}", + ) # Transpose inputs inputs = self.transpose_in(inputs) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index c0d4bf07..7b7b1a17 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -91,6 +91,8 @@ def __init__( if not self.has_cuda: self.f = _tensor_product_fx(descriptor, device, math_dtype, True) + self.operands_dims = [ope.size for ope in descriptor.operands] + @torch.jit.ignore def __repr__(self): has_cuda_kernel = ( @@ -113,8 +115,24 @@ def forward(self, inputs: List[torch.Tensor]): The output tensor resulting from the tensor product. It has a shape of (batch, last_operand_size), where `last_operand_size` is the size of the last operand in the descriptor. - """ + if not torch.jit.is_scripting() and not torch.compiler.is_compiling(): + if not isinstance(inputs, (list, tuple)): + raise ValueError("inputs should be a list of tensors") + if len(inputs) != self.num_operands - 1: + raise ValueError( + f"Expected {self.num_operands - 1} input tensors, got {len(inputs)}" + ) + for oid, input in enumerate(inputs): + torch._assert( + input.ndim == 2, + f"input {oid} should have ndim=2", + ) + torch._assert( + input.shape[1] == self.operands_dims[oid], + f"input {oid} should have shape (batch, {self.operands_dims[oid]}), got {input.shape}", + ) + return self.f(inputs) From d3dc7da7301f600fce7a7820b65228ad6934da3f Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 9 Jan 2025 00:17:01 +0100 Subject: [PATCH 10/13] Refactor input checks in tensor product classes to include tracing condition --- .../primitives/equivariant_tensor_product.py | 6 +++- .../primitives/symmetric_tensor_product.py | 6 +++- .../primitives/tensor_product.py | 36 +++++++++++++++---- 3 files changed, 40 insertions(+), 8 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py index 846f6ee1..4723e2f2 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py @@ -251,7 +251,11 @@ def forward( """ If ``indices`` is not None, the first input is indexed by ``indices``. """ - if not torch.jit.is_scripting() and not torch.compiler.is_compiling(): + if ( + not torch.jit.is_scripting() + and not torch.jit.is_tracing() + and not torch.compiler.is_compiling() + ): if not isinstance(inputs, (list, tuple)): raise ValueError( "inputs should be a list of tensors followed by optional indices" diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py index f56eb37a..d83a9039 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py @@ -320,7 +320,11 @@ def forward( i0 = i0.to(torch.int32) x0 = x0.reshape(x0.shape[0], x0.shape[1] // self.u, self.u) x1 = x1.reshape(x1.shape[0], x1.shape[1] // self.u, self.u) - if not torch.jit.is_scripting() and not torch.compiler.is_compiling(): + if ( + not torch.jit.is_scripting() + and not torch.jit.is_tracing() + and not torch.compiler.is_compiling() + ): logger.debug( f"Calling SymmetricTensorContraction: {self.descriptors}, input shapes: {x0.shape}, {i0.shape}, {x1.shape}" ) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index 7b7b1a17..e92e64ac 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -116,7 +116,11 @@ def forward(self, inputs: List[torch.Tensor]): It has a shape of (batch, last_operand_size), where `last_operand_size` is the size of the last operand in the descriptor. """ - if not torch.jit.is_scripting() and not torch.compiler.is_compiling(): + if ( + not torch.jit.is_scripting() + and not torch.jit.is_tracing() + and not torch.compiler.is_compiling() + ): if not isinstance(inputs, (list, tuple)): raise ValueError("inputs should be a list of tensors") if len(inputs) != self.num_operands - 1: @@ -374,7 +378,11 @@ def __init__(self, module: torch.nn.Module, descriptor: stp.SegmentedTensorProdu self.descriptor = descriptor def forward(self, args: List[torch.Tensor]): - if not torch.jit.is_scripting() and not torch.compiler.is_compiling(): + if ( + not torch.jit.is_scripting() + and not torch.jit.is_tracing() + and not torch.compiler.is_compiling() + ): for oid, arg in enumerate(args): torch._assert( arg.ndim == 2, @@ -506,7 +514,11 @@ def __repr__(self) -> str: def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: x0, x1 = self._perm(inputs[0], inputs[1]) - if not torch.jit.is_scripting() and not torch.compiler.is_compiling(): + if ( + not torch.jit.is_scripting() + and not torch.jit.is_tracing() + and not torch.compiler.is_compiling() + ): logger.debug( f"Calling FusedTensorProductOp3: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}" ) @@ -567,7 +579,11 @@ def __repr__(self) -> str: def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: x0, x1, x2 = self._perm(inputs[0], inputs[1], inputs[2]) - if not torch.jit.is_scripting() and not torch.compiler.is_compiling(): + if ( + not torch.jit.is_scripting() + and not torch.jit.is_tracing() + and not torch.compiler.is_compiling() + ): logger.debug( f"Calling FusedTensorProductOp4: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}, {x2.shape}" ) @@ -626,7 +642,11 @@ def __repr__(self): def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: x0, x1 = inputs - if not torch.jit.is_scripting() and not torch.compiler.is_compiling(): + if ( + not torch.jit.is_scripting() + and not torch.jit.is_tracing() + and not torch.compiler.is_compiling() + ): logger.debug( f"Calling TensorProductUniform3x1d: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}" ) @@ -647,7 +667,11 @@ def __repr__(self): def forward(self, inputs: List[torch.Tensor]): x0, x1, x2 = inputs - if not torch.jit.is_scripting() and not torch.compiler.is_compiling(): + if ( + not torch.jit.is_scripting() + and not torch.jit.is_tracing() + and not torch.compiler.is_compiling() + ): logger.debug( f"Calling TensorProductUniform4x1d: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}, {x2.shape}" ) From bbe7a340140ccdd209a66075ec2346d0f6d1960f Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 15 Jan 2025 18:38:33 +0100 Subject: [PATCH 11/13] Update CHANGELOG.md --- CHANGELOG.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d742fe62..7164ed7b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,6 @@ ### Breaking Changes - `cuet.TensorProduct` and `cuet.EquivariantTensorProduct` now require inputs to be of shape `(batch_size, dim)` or `(1, dim)`. Inputs of dimension `(dim,)` are no more allowed. -- `cuet.TensorProduct` and `cuet.EquivariantTensorProduct` are no more variadic functions. They now require a list of `torch.Tensor` as input. - `cuex.IrrepsArray` is an alias for `cuex.RepArray`. - `cuex.RepArray.irreps` and `cuex.RepArray.segments` are not functions anymore. They are now properties. - `cuex.IrrepsArray.is_simple` is replaced by `cuex.RepArray.is_irreps_array`. @@ -13,7 +12,7 @@ ### Added -- Add an experimental support for `torch.jit.script` and `torch.compile`. Known issue: the export in c++ is not working. +- Add an experimental support for `torch.compile`. Known issue: the export in c++ is not working. - Add `cue.IrrepsAndLayout`: A simple class that inherits from `cue.Rep` and contains a `cue.Irreps` and a `cue.IrrepsLayout`. - Add `cuex.RepArray` for representing an array of any kind of representations (not only irreps like before with `cuex.IrrepsArray`). From b0b287d9d3aaef0fd28e9a917b9766e8304f670e Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 16 Jan 2025 18:04:03 +0100 Subject: [PATCH 12/13] Add missing changes --- CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2eba490f..1c259e12 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ### Breaking Changes +- Minimal python version is now 3.10 in all packages. - `cuet.TensorProduct` and `cuet.EquivariantTensorProduct` now require inputs to be of shape `(batch_size, dim)` or `(1, dim)`. Inputs of dimension `(dim,)` are no more allowed. - `cuet.TensorProduct` and `cuet.EquivariantTensorProduct` are no more variadic functions. They now require a list of `torch.Tensor` as input. - `cuex.IrrepsArray` is an alias for `cuex.RepArray`. @@ -18,6 +19,12 @@ ### Fixed - Add support for empty batch dimension in `cuet` (`cuequivariance_torch`). +- Move `README.md` and `LICENSE` into the source distribution. +- Fix `cue.SegmentedTensorProduct.flop_cost` for the special case of 1 operand. + +### Improved + +- No more special case for degree 0 in `cuet.SymmetricTensorProduct`. ## 0.1.0 (2024-11-18) From 255eaae7183be5e98367785d5b30d6247cd888b5 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 16 Jan 2025 18:08:53 +0100 Subject: [PATCH 13/13] Add section for latest changes in CHANGELOG.md --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 85b64f28..b45538b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,5 @@ +## Latest Changes + ## 0.2.0 (2025-01-24) ### Breaking Changes