Skip to content

Commit

Permalink
cue.IrrepsAndLayout, cue.EquivariantTensorProduct, cuex.RepArray (#46)
Browse files Browse the repository at this point in the history
* add cuex.RepArray

* add IrrepsAndLayout

* fix

* EquivariantTensorProduct use Rep instead of Irreps

* Use RepArray in cuex.equivariant_tensor_product

* rename folder

* cuex use IrrepsAndLayout

* fix

* cuet.EquivariantTensorProduct optional transpose

* IrrepsArray as alias

* fix

* remove is_simple

* fix

* fix docs and call rep.dim instead of rep.irreps.dim

* docs

* add sh to docs

* edit docs of cuex.tensor_product and cuex.equivariant_tensor_product

* back to match case

* move things in experimental and add minimal tests

* clean

* remove todo
  • Loading branch information
mariogeiger authored Dec 19, 2024
1 parent 88e6596 commit 163f5f1
Show file tree
Hide file tree
Showing 48 changed files with 1,655 additions and 1,204 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,16 @@
### Added

- Partial support of `torch.jit.script` and `torch.compile`
- Added `cuex.RepArray` for representing an array of any kind of representations (not only irreps like before with `IrrepsArray`).

### Changed

- `cuequivariance_torch.TensorProduct` and `cuequivariance_torch.EquivariantTensorProduct` now require lists of `torch.Tensor` as input.
- `cuex.IrrepsArray` is now an alias for `cuex.RepArray` and its `.irreps` attribute and `.segments` are not functions anymore but properties.

## Removed

- `cuex.IrrepsArray.is_simple` is replaced by `cuex.RepArray.is_irreps_array`.

### Fixed

Expand Down
2 changes: 2 additions & 0 deletions cuequivariance/cuequivariance/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
IrrepsLayout,
mul_ir,
ir_mul,
IrrepsAndLayout,
get_layout_scope,
assume,
NumpyIrrepsArray,
Expand Down Expand Up @@ -71,6 +72,7 @@
"IrrepsLayout",
"mul_ir",
"ir_mul",
"IrrepsAndLayout",
"get_layout_scope",
"assume",
"NumpyIrrepsArray",
Expand Down
7 changes: 0 additions & 7 deletions cuequivariance/cuequivariance/descriptors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@
yxy_rotation,
inversion,
)
from .escn import escn_tp, escn_tp_compact
from .spherical_harmonics_ import sympy_spherical_harmonics, spherical_harmonics
from .gatr import gatr_linear, gatr_geometric_product, gatr_outer_product

__all__ = [
"transpose",
Expand All @@ -49,11 +47,6 @@
"yx_rotation",
"yxy_rotation",
"inversion",
"escn_tp",
"escn_tp_compact",
"sympy_spherical_harmonics",
"spherical_harmonics",
"gatr_linear",
"gatr_geometric_product",
"gatr_outer_product",
]
37 changes: 28 additions & 9 deletions cuequivariance/cuequivariance/descriptors/irreps_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,12 @@ def fully_connected_tensor_product(
d = d.normalize_paths_for_operand(-1)
return cue.EquivariantTensorProduct(
d,
[irreps1.new_scalars(d.operands[0].size), irreps1, irreps2, irreps3],
layout=cue.ir_mul,
[
cue.IrrepsAndLayout(irreps1.new_scalars(d.operands[0].size), cue.ir_mul),
cue.IrrepsAndLayout(irreps1, cue.ir_mul),
cue.IrrepsAndLayout(irreps2, cue.ir_mul),
cue.IrrepsAndLayout(irreps3, cue.ir_mul),
],
)


Expand Down Expand Up @@ -131,8 +135,11 @@ def full_tensor_product(
d = d.normalize_paths_for_operand(-1)
return cue.EquivariantTensorProduct(
d,
[irreps1, irreps2, irreps3],
layout=cue.ir_mul,
[
cue.IrrepsAndLayout(irreps1, cue.ir_mul),
cue.IrrepsAndLayout(irreps2, cue.ir_mul),
cue.IrrepsAndLayout(irreps3, cue.ir_mul),
],
)


Expand Down Expand Up @@ -193,8 +200,12 @@ def channelwise_tensor_product(
d = d.normalize_paths_for_operand(-1)
return cue.EquivariantTensorProduct(
d,
[irreps1.new_scalars(d.operands[0].size), irreps1, irreps2, irreps3],
layout=cue.ir_mul,
[
cue.IrrepsAndLayout(irreps1.new_scalars(d.operands[0].size), cue.ir_mul),
cue.IrrepsAndLayout(irreps1, cue.ir_mul),
cue.IrrepsAndLayout(irreps2, cue.ir_mul),
cue.IrrepsAndLayout(irreps3, cue.ir_mul),
],
)


Expand Down Expand Up @@ -273,7 +284,12 @@ def elementwise_tensor_product(
irreps3 = cue.Irreps(G, irreps3)
d = d.normalize_paths_for_operand(-1)
return cue.EquivariantTensorProduct(
d, [irreps1, irreps2, irreps3], layout=cue.ir_mul
d,
[
cue.IrrepsAndLayout(irreps1, cue.ir_mul),
cue.IrrepsAndLayout(irreps2, cue.ir_mul),
cue.IrrepsAndLayout(irreps3, cue.ir_mul),
],
)


Expand Down Expand Up @@ -308,6 +324,9 @@ def linear(

return cue.EquivariantTensorProduct(
d,
[irreps_in.new_scalars(d.operands[0].size), irreps_in, irreps_out],
layout=cue.ir_mul,
[
cue.IrrepsAndLayout(irreps_in.new_scalars(d.operands[0].size), cue.ir_mul),
cue.IrrepsAndLayout(irreps_in, cue.ir_mul),
cue.IrrepsAndLayout(irreps_out, cue.ir_mul),
],
)
59 changes: 39 additions & 20 deletions cuequivariance/cuequivariance/descriptors/rotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,13 @@ def fixed_axis_angle_rotation(
)

d = d.flatten_coefficient_modes()
return cue.EquivariantTensorProduct(d, [irreps, irreps], layout=cue.ir_mul)
return cue.EquivariantTensorProduct(
d,
[
cue.IrrepsAndLayout(irreps, cue.ir_mul),
cue.IrrepsAndLayout(irreps, cue.ir_mul),
],
)


def yxy_rotation(
Expand Down Expand Up @@ -70,13 +76,12 @@ def yxy_rotation(
return cue.EquivariantTensorProduct(
cbaio,
[
irreps.new_scalars(cbaio.operands[0].size),
irreps.new_scalars(cbaio.operands[1].size),
irreps.new_scalars(cbaio.operands[2].size),
irreps,
irreps,
cue.IrrepsAndLayout(irreps.new_scalars(cbaio.operands[0].size), cue.ir_mul),
cue.IrrepsAndLayout(irreps.new_scalars(cbaio.operands[1].size), cue.ir_mul),
cue.IrrepsAndLayout(irreps.new_scalars(cbaio.operands[2].size), cue.ir_mul),
cue.IrrepsAndLayout(irreps, cue.ir_mul),
cue.IrrepsAndLayout(irreps, cue.ir_mul),
],
layout=cue.ir_mul,
)


Expand All @@ -95,12 +100,11 @@ def xy_rotation(
return cue.EquivariantTensorProduct(
cbio,
[
irreps.new_scalars(cbio.operands[0].size),
irreps.new_scalars(cbio.operands[1].size),
irreps,
irreps,
cue.IrrepsAndLayout(irreps.new_scalars(cbio.operands[0].size), cue.ir_mul),
cue.IrrepsAndLayout(irreps.new_scalars(cbio.operands[1].size), cue.ir_mul),
cue.IrrepsAndLayout(irreps, cue.ir_mul),
cue.IrrepsAndLayout(irreps, cue.ir_mul),
],
layout=cue.ir_mul,
)


Expand All @@ -119,12 +123,11 @@ def yx_rotation(
return cue.EquivariantTensorProduct(
cbio,
[
irreps.new_scalars(cbio.operands[0].size),
irreps.new_scalars(cbio.operands[1].size),
irreps,
irreps,
cue.IrrepsAndLayout(irreps.new_scalars(cbio.operands[0].size), cue.ir_mul),
cue.IrrepsAndLayout(irreps.new_scalars(cbio.operands[1].size), cue.ir_mul),
cue.IrrepsAndLayout(irreps, cue.ir_mul),
cue.IrrepsAndLayout(irreps, cue.ir_mul),
],
layout=cue.ir_mul,
)


Expand Down Expand Up @@ -188,7 +191,12 @@ def y_rotation(

d = d.flatten_coefficient_modes()
return cue.EquivariantTensorProduct(
d, [irreps.new_scalars(d.operands[0].size), irreps, irreps], layout=cue.ir_mul
d,
[
cue.IrrepsAndLayout(irreps.new_scalars(d.operands[0].size), cue.ir_mul),
cue.IrrepsAndLayout(irreps, cue.ir_mul),
cue.IrrepsAndLayout(irreps, cue.ir_mul),
],
)


Expand All @@ -213,7 +221,12 @@ def x_rotation(
d = stp.dot(stp.dot(dy, dz90, (1, 1)), dz90, (1, 1))

return cue.EquivariantTensorProduct(
d, [irreps.new_scalars(d.operands[0].size), irreps, irreps], layout=cue.ir_mul
d,
[
cue.IrrepsAndLayout(irreps.new_scalars(d.operands[0].size), cue.ir_mul),
cue.IrrepsAndLayout(irreps, cue.ir_mul),
cue.IrrepsAndLayout(irreps, cue.ir_mul),
],
)


Expand All @@ -228,4 +241,10 @@ def inversion(irreps: cue.Irreps) -> cue.EquivariantTensorProduct:
assert np.allclose(H @ H, np.eye(ir.dim), atol=1e-6)
d.add_path(None, None, c=H, dims={"u": mul})
d = d.flatten_coefficient_modes()
return cue.EquivariantTensorProduct(d, [irreps, irreps], layout=cue.ir_mul)
return cue.EquivariantTensorProduct(
d,
[
cue.IrrepsAndLayout(irreps, cue.ir_mul),
cue.IrrepsAndLayout(irreps, cue.ir_mul),
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,13 @@ def spherical_harmonics(
indices = poly_degrees_to_path_indices(degrees)
d.add_path(*indices, i, c=coeff)

return cue.EquivariantTensorProduct([d], [ir_vec, ir], layout=layout)
return cue.EquivariantTensorProduct(
[d],
[
cue.IrrepsAndLayout(cue.Irreps(ir_vec), cue.ir_mul),
cue.IrrepsAndLayout(cue.Irreps(ir), cue.ir_mul),
],
)


def poly_degrees_to_path_indices(degrees: tuple[int, ...]) -> tuple[int, ...]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ def symmetric_contraction(
d = d.append_modes_to_all_operands("u", {"u": mul})
return cue.EquivariantTensorProduct(
[d],
[irreps_in.new_scalars(d.operands[0].size), mul * irreps_in, mul * irreps_out],
layout=cue.ir_mul,
[
cue.IrrepsAndLayout(irreps_in.new_scalars(d.operands[0].size), cue.ir_mul),
cue.IrrepsAndLayout(mul * irreps_in, cue.ir_mul),
cue.IrrepsAndLayout(mul * irreps_out, cue.ir_mul),
],
)
11 changes: 7 additions & 4 deletions cuequivariance/cuequivariance/descriptors/transposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import cuequivariance as cue
from cuequivariance.equivariant_tensor_product import Operand


def transpose(
Expand All @@ -22,12 +21,16 @@ def transpose(
"""Transpose the irreps layout of a tensor."""
d = cue.SegmentedTensorProduct(
operands=[
cue.Operand(subscripts="ui" if source == cue.mul_ir else "iu"),
cue.Operand(subscripts="ui" if target == cue.mul_ir else "iu"),
cue.segmented_tensor_product.Operand(
subscripts="ui" if source == cue.mul_ir else "iu"
),
cue.segmented_tensor_product.Operand(
subscripts="ui" if target == cue.mul_ir else "iu"
),
]
)
for mul, ir in irreps:
d.add_path(None, None, c=1, dims={"u": mul, "i": ir.dim})
return cue.EquivariantTensorProduct(
d, [Operand(irreps, source), Operand(irreps, target)]
d, [cue.IrrepsAndLayout(irreps, source), cue.IrrepsAndLayout(irreps, target)]
)
Loading

0 comments on commit 163f5f1

Please sign in to comment.