Skip to content

Commit

Permalink
Refactor tensor_product to use jax.extend.core for Primitive (#68)
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger authored Jan 22, 2025
1 parent c433a86 commit 8ee68bd
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.2 # Use the latest stable version of Black
rev: v0.9.1
hooks:
- id: ruff
args: ["--fix"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,16 @@ def equivariant_tensor_product(

for i, (x, rep) in enumerate(zip(inputs, e.inputs)):
if isinstance(x, cuex.RepArray):
assert (
x.rep(-1) == rep
), f"Input {i} should have representation {rep}, got {x.rep(-1)}."
assert x.rep(-1) == rep, (
f"Input {i} should have representation {rep}, got {x.rep(-1)}."
)
else:
assert (
x.ndim >= 1
), f"Input {i} should have at least one dimension, got {x.ndim}."
assert (
x.shape[-1] == rep.dim
), f"Input {i} should have dimension {rep.dim}, got {x.shape[-1]}."
assert x.ndim >= 1, (
f"Input {i} should have at least one dimension, got {x.ndim}."
)
assert x.shape[-1] == rep.dim, (
f"Input {i} should have dimension {rep.dim}, got {x.shape[-1]}."
)
if not rep.is_scalar():
raise ValueError(
f"Input {i} should be a RepArray unless the input is scalar. Got {type(x)} for {rep}."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
from functools import partial

import jax
import jax.core
import jax.extend
import jax.lax
import jax.numpy as jnp
from jax import core
from jax.interpreters import ad, batching, mlir, xla

from cuequivariance import segmented_tensor_product as stp
Expand Down Expand Up @@ -160,7 +161,7 @@ def _partial(*remaining_inputs: jax.Array) -> jax.Array:

################################################################################

tensor_product_p = core.Primitive("tensor_product")
tensor_product_p = jax.extend.core.Primitive("tensor_product")
tensor_product_p.multiple_results = True


Expand Down Expand Up @@ -249,12 +250,12 @@ def dispatch(


def tensor_product_abstract_eval(
*inputs: core.ShapedArray,
*inputs: jax.core.ShapedArray,
shapes: tuple[tuple[int, ...], ...],
d: stp.SegmentedTensorProduct,
exe: TensorProductExecution,
**options,
) -> tuple[core.ShapedArray, ...]:
) -> tuple[jax.core.ShapedArray, ...]:
# assert that all input/output are used
assert exe.max_in_buffer + 1 == len(exe.in_buffers) == len(inputs)
assert exe.max_out_buffer + 1 == len(exe.out_buffers)
Expand All @@ -269,7 +270,7 @@ def tensor_product_abstract_eval(

outputs = [None] * len(exe.out_buffers)
for c in exe.computations:
out = core.ShapedArray(
out = jax.core.ShapedArray(
shape=shapes[c.out_operand] + (d.operands[c.out_operand].size,),
dtype=options["dtype_output"],
)
Expand All @@ -291,6 +292,7 @@ def tensor_product_jvp(
out_tangents = [ad.Zero(p.aval) for p in out_primals]

jvp = exe.jvp([not isinstance(t, ad.Zero) for t in tangents])
del exe

permutations: list[tuple[int, ...]] = d.symmetries()
for multiplicator, exe in jvp.group_by_symmetries(permutations):
Expand Down

0 comments on commit 8ee68bd

Please sign in to comment.