From 8ee68bd03ff2095e8d0fb06b11108e1324c70fb7 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 22 Jan 2025 12:43:35 +0100 Subject: [PATCH] Refactor tensor_product to use jax.extend.core for Primitive (#68) --- .pre-commit-config.yaml | 2 +- .../primitives/equivariant_tensor_product.py | 18 +++++++++--------- .../primitives/tensor_product.py | 12 +++++++----- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 62bcfec..c8c64b1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.2 # Use the latest stable version of Black + rev: v0.9.1 hooks: - id: ruff args: ["--fix"] diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py b/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py index 7a9ec81..d5d61f5 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py @@ -87,16 +87,16 @@ def equivariant_tensor_product( for i, (x, rep) in enumerate(zip(inputs, e.inputs)): if isinstance(x, cuex.RepArray): - assert ( - x.rep(-1) == rep - ), f"Input {i} should have representation {rep}, got {x.rep(-1)}." + assert x.rep(-1) == rep, ( + f"Input {i} should have representation {rep}, got {x.rep(-1)}." + ) else: - assert ( - x.ndim >= 1 - ), f"Input {i} should have at least one dimension, got {x.ndim}." - assert ( - x.shape[-1] == rep.dim - ), f"Input {i} should have dimension {rep.dim}, got {x.shape[-1]}." + assert x.ndim >= 1, ( + f"Input {i} should have at least one dimension, got {x.ndim}." + ) + assert x.shape[-1] == rep.dim, ( + f"Input {i} should have dimension {rep.dim}, got {x.shape[-1]}." + ) if not rep.is_scalar(): raise ValueError( f"Input {i} should be a RepArray unless the input is scalar. Got {type(x)} for {rep}." diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py index bc52ea6..cb81c35 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py @@ -16,9 +16,10 @@ from functools import partial import jax +import jax.core +import jax.extend import jax.lax import jax.numpy as jnp -from jax import core from jax.interpreters import ad, batching, mlir, xla from cuequivariance import segmented_tensor_product as stp @@ -160,7 +161,7 @@ def _partial(*remaining_inputs: jax.Array) -> jax.Array: ################################################################################ -tensor_product_p = core.Primitive("tensor_product") +tensor_product_p = jax.extend.core.Primitive("tensor_product") tensor_product_p.multiple_results = True @@ -249,12 +250,12 @@ def dispatch( def tensor_product_abstract_eval( - *inputs: core.ShapedArray, + *inputs: jax.core.ShapedArray, shapes: tuple[tuple[int, ...], ...], d: stp.SegmentedTensorProduct, exe: TensorProductExecution, **options, -) -> tuple[core.ShapedArray, ...]: +) -> tuple[jax.core.ShapedArray, ...]: # assert that all input/output are used assert exe.max_in_buffer + 1 == len(exe.in_buffers) == len(inputs) assert exe.max_out_buffer + 1 == len(exe.out_buffers) @@ -269,7 +270,7 @@ def tensor_product_abstract_eval( outputs = [None] * len(exe.out_buffers) for c in exe.computations: - out = core.ShapedArray( + out = jax.core.ShapedArray( shape=shapes[c.out_operand] + (d.operands[c.out_operand].size,), dtype=options["dtype_output"], ) @@ -291,6 +292,7 @@ def tensor_product_jvp( out_tangents = [ad.Zero(p.aval) for p in out_primals] jvp = exe.jvp([not isinstance(t, ad.Zero) for t in tangents]) + del exe permutations: list[tuple[int, ...]] = d.symmetries() for multiplicator, exe in jvp.group_by_symmetries(permutations):