Skip to content

Commit

Permalink
tensor_product_p: shapes -> output_shapes (#73)
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger authored Jan 23, 2025
1 parent 8ee68bd commit bcc2847
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 39 deletions.
80 changes: 45 additions & 35 deletions cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,15 +147,17 @@ def _partial(*remaining_inputs: jax.Array) -> jax.Array:
jnp.reshape(input, (1,) * (len(output_shape) + 1 - input.ndim) + input.shape)
for input in inputs
]
shapes = tuple(input.shape[:-1] for input in inputs) + (output_shape,)
output_shapes = tuple(None for _ in inputs) + (output_shape,)
exe = TensorProductExecution(
[
Computation(
[InBuffer(oid) for oid in range(d.num_operands - 1)] + [OutBuffer(0)]
)
]
)
(output,) = tensor_product_prim(*inputs, shapes=shapes, d=d, exe=exe, **options)
(output,) = tensor_product_prim(
*inputs, output_shapes=output_shapes, d=d, exe=exe, **options
)
return output


Expand Down Expand Up @@ -184,7 +186,7 @@ def clean_inputs(

def tensor_product_prim(
*inputs: jax.Array, # input buffers
shapes: tuple[tuple[int, ...], ...], # shapes of the operands
output_shapes: tuple[tuple[int, ...] | None, ...], # shapes of the operands
d: stp.SegmentedTensorProduct,
exe: TensorProductExecution,
**options,
Expand All @@ -198,18 +200,18 @@ def tensor_product_prim(

if options.pop("use_custom_primitive", True):
return tensor_product_p.bind(
*unique_inputs, shapes=shapes, d=d, exe=exe, **options
*unique_inputs, output_shapes=output_shapes, d=d, exe=exe, **options
)
else:
return tensor_product_vanilla_impl(
*unique_inputs, shapes=shapes, d=d, exe=exe, **options
*unique_inputs, output_shapes=output_shapes, d=d, exe=exe, **options
)


def tensor_product_impl(
platform: str | None,
*inputs: jax.Array,
shapes: tuple[tuple[int, ...], ...],
output_shapes: tuple[tuple[int, ...] | None, ...],
d: stp.SegmentedTensorProduct,
exe: TensorProductExecution,
**options,
Expand All @@ -228,7 +230,7 @@ def dispatch(
pass

return tensor_product_vanilla_impl(
*inputs, shapes=shapes, d=d, exe=exe, **options
*inputs, output_shapes=output_shapes, d=d, exe=exe, **options
)

outputs = [0] * len(exe.out_buffers)
Expand All @@ -251,7 +253,7 @@ def dispatch(

def tensor_product_abstract_eval(
*inputs: jax.core.ShapedArray,
shapes: tuple[tuple[int, ...], ...],
output_shapes: tuple[tuple[int, ...] | None, ...],
d: stp.SegmentedTensorProduct,
exe: TensorProductExecution,
**options,
Expand All @@ -262,16 +264,15 @@ def tensor_product_abstract_eval(

for c in exe.computations:
for oid, x in zip(c.in_operands, c.map_inputs(inputs)):
expected_shape = shapes[oid] + (d.operands[oid].size,)
if x.shape != expected_shape:
if x.shape[-1] != d.operands[oid].size:
raise ValueError(
f"cuex.tensor_product: expected input to have shape {expected_shape}, got {x.shape}"
f"cuex.tensor_product: expected input to have size {d.operands[oid].size}, got {x.shape[-1]}"
)

outputs = [None] * len(exe.out_buffers)
for c in exe.computations:
out = jax.core.ShapedArray(
shape=shapes[c.out_operand] + (d.operands[c.out_operand].size,),
shape=output_shapes[c.out_operand] + (d.operands[c.out_operand].size,),
dtype=options["dtype_output"],
)
assert outputs[c.out_buffer] is None or outputs[c.out_buffer] == out
Expand All @@ -283,12 +284,14 @@ def tensor_product_jvp(
primals: tuple[jax.Array, ...],
tangents: tuple[jax.Array | ad.Zero, ...],
*,
shapes: tuple[tuple[int, ...], ...],
output_shapes: tuple[tuple[int, ...] | None, ...],
d: stp.SegmentedTensorProduct,
exe: TensorProductExecution,
**options,
) -> tuple[tuple[jax.Array, ...], tuple[jax.Array | ad.Zero, ...]]:
out_primals = tensor_product_prim(*primals, shapes=shapes, d=d, exe=exe, **options)
out_primals = tensor_product_prim(
*primals, output_shapes=output_shapes, d=d, exe=exe, **options
)
out_tangents = [ad.Zero(p.aval) for p in out_primals]

jvp = exe.jvp([not isinstance(t, ad.Zero) for t in tangents])
Expand All @@ -300,7 +303,7 @@ def tensor_product_jvp(
tmp = tensor_product_prim(
*primals,
*[t for t in tangents if not isinstance(t, ad.Zero)],
shapes=shapes,
output_shapes=output_shapes,
d=multiplicator * d,
exe=exe.map_buffers(None, lambda b: exe.out_buffers.index(b)),
**options,
Expand All @@ -314,19 +317,35 @@ def tensor_product_jvp(
def tensor_product_transpose(
cotangents: tuple[jax.Array | ad.Zero, ...],
*inputs: jax.Array | ad.UndefinedPrimal,
shapes: tuple[tuple[int, ...], ...],
output_shapes: tuple[tuple[int, ...] | None, ...],
d: stp.SegmentedTensorProduct,
exe: TensorProductExecution,
**options,
) -> tuple[jax.Array | ad.Zero | None, ...]:
# The cotangents replace the outputs as inputs
# The undefined primal inputs become outputs
del output_shapes
output_shapes = [None] * d.num_operands
for comp in exe.computations:
for oid, x in zip(comp.in_operands, comp.map_inputs(inputs)):
if ad.is_undefined_primal(x):
undefined_primal_shape = x.aval.shape[:-1]
# if the following assert fails, we need to change the internal API of the primitive
assert (
output_shapes[oid] is None
or output_shapes[oid] == undefined_primal_shape
)
output_shapes[oid] = undefined_primal_shape
output_shapes = tuple(output_shapes)

tr = exe.transpose(
[ad.is_undefined_primal(x) for x in inputs],
[not isinstance(x, ad.Zero) for x in cotangents],
)
tmp = tensor_product_prim(
*[x for x in inputs if not ad.is_undefined_primal(x)],
*[x for x in cotangents if not isinstance(x, ad.Zero)],
shapes=shapes,
output_shapes=output_shapes,
d=d,
exe=tr.map_buffers(None, lambda b: tr.out_buffers.index(b)),
**options,
Expand All @@ -348,7 +367,7 @@ def tensor_product_batching(
batched_inputs: tuple[jax.Array, ...],
batch_axes: tuple[int | None, ...],
*,
shapes: tuple[tuple[int, ...], ...],
output_shapes: tuple[tuple[int, ...] | None, ...],
d: stp.SegmentedTensorProduct,
exe: TensorProductExecution,
**options,
Expand All @@ -359,33 +378,24 @@ def prepare(input: jax.Array, axis: int | None) -> jax.Array:
else:
return jnp.moveaxis(input, axis, 0)

assert len(batched_inputs) == len(batch_axes)
batched_inputs = [
prepare(input, axis) for input, axis in zip(batched_inputs, batch_axes)
]
new_dim = max(input.shape[0] for input in batched_inputs)

new_shapes = [None] * d.num_operands
new_output_shapes = [None] * d.num_operands
for comp in exe.computations:
# inputs
for oid, input in zip(comp.in_operands, comp.map_inputs(batched_inputs)):
expected = input.shape[:-1]
if new_shapes[oid] is None:
new_shapes[oid] = expected
assert new_shapes[oid] == expected

# output
oid = comp.out_operand
expected = (new_dim,) + shapes[oid]
if new_shapes[oid] is None:
new_shapes[oid] = expected
assert new_shapes[oid] == expected

new_shapes = tuple(new_shapes)
assert all(s is not None for s in new_shapes)
expected = (new_dim,) + output_shapes[oid]
if new_output_shapes[oid] is None:
new_output_shapes[oid] = expected
assert new_output_shapes[oid] == expected
new_output_shapes = tuple(new_output_shapes)

outputs = tensor_product_prim(
*batched_inputs,
shapes=new_shapes,
output_shapes=new_output_shapes,
d=d,
exe=exe,
**options,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

def tensor_product_vanilla_impl(
*inputs: jax.Array, # input buffers
shapes: tuple[tuple[int, ...], ...], # shapes of the operands
output_shapes: tuple[tuple[int, ...] | None, ...], # shapes of the operands
d: stp.SegmentedTensorProduct,
exe: TensorProductExecution,
**options,
Expand All @@ -39,18 +39,20 @@ def tensor_product_vanilla_impl(
outputs = [0] * len(exe.out_buffers)

for c in exe.computations:
shape = output_shapes[c.out_operand]
assert shape is not None
out = sum_cat_list_list(
d.operands[c.out_operand],
tp_list_list(
*c.map_inputs(inputs),
shape=shapes[c.out_operand],
shape=shape,
d=d.move_operand_last(c.out_operand),
**options,
),
shapes[c.out_operand],
shape,
options["dtype_output"],
)
assert out.shape == shapes[c.out_operand] + (d.operands[c.out_operand].size,)
assert out.shape == shape + (d.operands[c.out_operand].size,)
outputs[c.out_buffer] += out

return tuple(outputs)
Expand Down
11 changes: 11 additions & 0 deletions cuequivariance_jax/tests/primitives/tensor_product_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,14 @@ def f(w, x):
return jnp.sum(a) + jnp.sum(b)

jax.jit(jax.grad(f, 0))(w, x)


def test_multiple_operand_shape_bug():
    # Regression test for a past bug: an operand used as an input could
    # not have a shape different from the output of the same operand,
    # which broke differentiation through the tensor product.
    def sph(v):
        descriptor = cue.descriptors.spherical_harmonics(cue.SO3(1), [2]).d
        return cuex.tensor_product(descriptor, v, v)

    point = jnp.array([1.0, 0.0, 0.0])
    jac = jax.jacobian(sph)(point)
    assert jac.shape == (5, 3)

0 comments on commit bcc2847

Please sign in to comment.