diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py
index cb81c35..7dd2f37 100644
--- a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py
+++ b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py
@@ -147,7 +147,7 @@ def _partial(*remaining_inputs: jax.Array) -> jax.Array:
         jnp.reshape(input, (1,) * (len(output_shape) + 1 - input.ndim) + input.shape)
         for input in inputs
     ]
-    shapes = tuple(input.shape[:-1] for input in inputs) + (output_shape,)
+    output_shapes = tuple(None for _ in inputs) + (output_shape,)
     exe = TensorProductExecution(
         [
             Computation(
@@ -155,7 +155,9 @@
             )
         ]
     )
-    (output,) = tensor_product_prim(*inputs, shapes=shapes, d=d, exe=exe, **options)
+    (output,) = tensor_product_prim(
+        *inputs, output_shapes=output_shapes, d=d, exe=exe, **options
+    )
     return output
@@ -184,7 +186,7 @@ def clean_inputs(
 
 def tensor_product_prim(
     *inputs: jax.Array,  # input buffers
-    shapes: tuple[tuple[int, ...], ...],  # shapes of the operands
+    output_shapes: tuple[tuple[int, ...] | None, ...],  # leading shapes of the output operands (None for inputs)
     d: stp.SegmentedTensorProduct,
     exe: TensorProductExecution,
     **options,
@@ -198,18 +200,18 @@
     if options.pop("use_custom_primitive", True):
         return tensor_product_p.bind(
-            *unique_inputs, shapes=shapes, d=d, exe=exe, **options
+            *unique_inputs, output_shapes=output_shapes, d=d, exe=exe, **options
         )
     else:
         return tensor_product_vanilla_impl(
-            *unique_inputs, shapes=shapes, d=d, exe=exe, **options
+            *unique_inputs, output_shapes=output_shapes, d=d, exe=exe, **options
         )
 
 
 def tensor_product_impl(
     platform: str | None,
     *inputs: jax.Array,
-    shapes: tuple[tuple[int, ...], ...],
+    output_shapes: tuple[tuple[int, ...] | None, ...],
     d: stp.SegmentedTensorProduct,
     exe: TensorProductExecution,
     **options,
@@ -228,7 +230,7 @@ def dispatch(
                 pass
 
         return tensor_product_vanilla_impl(
-            *inputs, shapes=shapes, d=d, exe=exe, **options
+            *inputs, output_shapes=output_shapes, d=d, exe=exe, **options
         )
 
     outputs = [0] * len(exe.out_buffers)
@@ -251,7 +253,7 @@
 
 def tensor_product_abstract_eval(
     *inputs: jax.core.ShapedArray,
-    shapes: tuple[tuple[int, ...], ...],
+    output_shapes: tuple[tuple[int, ...] | None, ...],
     d: stp.SegmentedTensorProduct,
     exe: TensorProductExecution,
     **options,
@@ -262,16 +264,15 @@
 
     for c in exe.computations:
         for oid, x in zip(c.in_operands, c.map_inputs(inputs)):
-            expected_shape = shapes[oid] + (d.operands[oid].size,)
-            if x.shape != expected_shape:
+            if x.shape[-1] != d.operands[oid].size:
                 raise ValueError(
-                    f"cuex.tensor_product: expected input to have shape {expected_shape}, got {x.shape}"
+                    f"cuex.tensor_product: expected input to have size {d.operands[oid].size}, got {x.shape[-1]}"
                 )
 
     outputs = [None] * len(exe.out_buffers)
     for c in exe.computations:
         out = jax.core.ShapedArray(
-            shape=shapes[c.out_operand] + (d.operands[c.out_operand].size,),
+            shape=output_shapes[c.out_operand] + (d.operands[c.out_operand].size,),
             dtype=options["dtype_output"],
         )
         assert outputs[c.out_buffer] is None or outputs[c.out_buffer] == out
@@ -283,12 +284,14 @@ def tensor_product_jvp(
     primals: tuple[jax.Array, ...],
     tangents: tuple[jax.Array | ad.Zero, ...],
     *,
-    shapes: tuple[tuple[int, ...], ...],
+    output_shapes: tuple[tuple[int, ...] | None, ...],
     d: stp.SegmentedTensorProduct,
     exe: TensorProductExecution,
     **options,
 ) -> tuple[tuple[jax.Array, ...], tuple[jax.Array | ad.Zero, ...]]:
-    out_primals = tensor_product_prim(*primals, shapes=shapes, d=d, exe=exe, **options)
+    out_primals = tensor_product_prim(
+        *primals, output_shapes=output_shapes, d=d, exe=exe, **options
+    )
     out_tangents = [ad.Zero(p.aval) for p in out_primals]
 
     jvp = exe.jvp([not isinstance(t, ad.Zero) for t in tangents])
@@ -300,7 +303,7 @@
         tmp = tensor_product_prim(
             *primals,
             *[t for t in tangents if not isinstance(t, ad.Zero)],
-            shapes=shapes,
+            output_shapes=output_shapes,
             d=multiplicator * d,
             exe=exe.map_buffers(None, lambda b: exe.out_buffers.index(b)),
             **options,
         )
@@ -314,11 +317,27 @@
 def tensor_product_transpose(
     cotangents: tuple[jax.Array | ad.Zero, ...],
     *inputs: jax.Array | ad.UndefinedPrimal,
-    shapes: tuple[tuple[int, ...], ...],
+    output_shapes: tuple[tuple[int, ...] | None, ...],
     d: stp.SegmentedTensorProduct,
     exe: TensorProductExecution,
     **options,
 ) -> tuple[jax.Array | ad.Zero | None, ...]:
+    # The cotangents replace the outputs as inputs,
+    # and the undefined primal inputs become outputs.
+    del output_shapes
+    output_shapes = [None] * d.num_operands
+    for comp in exe.computations:
+        for oid, x in zip(comp.in_operands, comp.map_inputs(inputs)):
+            if ad.is_undefined_primal(x):
+                undefined_primal_shape = x.aval.shape[:-1]
+                # If the following assert fails, the internal API of the primitive needs to change.
+                assert (
+                    output_shapes[oid] is None
+                    or output_shapes[oid] == undefined_primal_shape
+                )
+                output_shapes[oid] = undefined_primal_shape
+    output_shapes = tuple(output_shapes)
+
     tr = exe.transpose(
         [ad.is_undefined_primal(x) for x in inputs],
         [not isinstance(x, ad.Zero) for x in cotangents],
@@ -326,7 +345,7 @@
     tmp = tensor_product_prim(
         *[x for x in inputs if not ad.is_undefined_primal(x)],
         *[x for x in cotangents if not isinstance(x, ad.Zero)],
-        shapes=shapes,
+        output_shapes=output_shapes,
         d=d,
         exe=tr.map_buffers(None, lambda b: tr.out_buffers.index(b)),
         **options,
     )
@@ -348,7 +367,7 @@ def tensor_product_batching(
     batched_inputs: tuple[jax.Array, ...],
     batch_axes: tuple[int | None, ...],
     *,
-    shapes: tuple[tuple[int, ...], ...],
+    output_shapes: tuple[tuple[int, ...] | None, ...],
     d: stp.SegmentedTensorProduct,
     exe: TensorProductExecution,
     **options,
@@ -359,33 +378,24 @@ def prepare(input: jax.Array, axis: int | None) -> jax.Array:
         else:
             return jnp.moveaxis(input, axis, 0)
 
+    assert len(batched_inputs) == len(batch_axes)
     batched_inputs = [
         prepare(input, axis) for input, axis in zip(batched_inputs, batch_axes)
     ]
     new_dim = max(input.shape[0] for input in batched_inputs)
 
-    new_shapes = [None] * d.num_operands
+    new_output_shapes = [None] * d.num_operands
     for comp in exe.computations:
-        # inputs
-        for oid, input in zip(comp.in_operands, comp.map_inputs(batched_inputs)):
-            expected = input.shape[:-1]
-            if new_shapes[oid] is None:
-                new_shapes[oid] = expected
-            assert new_shapes[oid] == expected
-
-        # output
         oid = comp.out_operand
-        expected = (new_dim,) + shapes[oid]
-        if new_shapes[oid] is None:
-            new_shapes[oid] = expected
-        assert new_shapes[oid] == expected
-
-    new_shapes = tuple(new_shapes)
-    assert all(s is not None for s in new_shapes)
+        expected = (new_dim,) + output_shapes[oid]
+        if new_output_shapes[oid] is None:
+            new_output_shapes[oid] = expected
+        assert new_output_shapes[oid] == expected
+    new_output_shapes = tuple(new_output_shapes)
 
     outputs = tensor_product_prim(
         *batched_inputs,
-        shapes=new_shapes,
+        output_shapes=new_output_shapes,
         d=d,
         exe=exe,
         **options,
diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product_vanilla_impl.py b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product_vanilla_impl.py
index f09a86d..7d94de4 100644
--- a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product_vanilla_impl.py
+++ b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product_vanilla_impl.py
@@ -29,7 +29,7 @@
 def tensor_product_vanilla_impl(
     *inputs: jax.Array,  # input buffers
-    shapes: tuple[tuple[int, ...], ...],  # shapes of the operands
+    output_shapes: tuple[tuple[int, ...] | None, ...],  # leading shapes of the output operands (None for inputs)
     d: stp.SegmentedTensorProduct,
     exe: TensorProductExecution,
     **options,
@@ -39,18 +39,20 @@
     outputs = [0] * len(exe.out_buffers)
     for c in exe.computations:
+        shape = output_shapes[c.out_operand]
+        assert shape is not None
         out = sum_cat_list_list(
             d.operands[c.out_operand],
             tp_list_list(
                 *c.map_inputs(inputs),
-                shape=shapes[c.out_operand],
+                shape=shape,
                 d=d.move_operand_last(c.out_operand),
                 **options,
             ),
-            shapes[c.out_operand],
+            shape,
             options["dtype_output"],
         )
-        assert out.shape == shapes[c.out_operand] + (d.operands[c.out_operand].size,)
+        assert out.shape == shape + (d.operands[c.out_operand].size,)
         outputs[c.out_buffer] += out
 
     return tuple(outputs)
diff --git a/cuequivariance_jax/tests/primitives/tensor_product_test.py b/cuequivariance_jax/tests/primitives/tensor_product_test.py
index 8cbbaa1..f1db5d7 100644
--- a/cuequivariance_jax/tests/primitives/tensor_product_test.py
+++ b/cuequivariance_jax/tests/primitives/tensor_product_test.py
@@ -178,3 +178,14 @@ def f(w, x):
         return jnp.sum(a) + jnp.sum(b)
 
     jax.jit(jax.grad(f, 0))(w, x)
+
+
+def test_multiple_operand_shape_bug():
+    # Regression test: previously, an input could not have a different
+    # leading shape than the output of the same operand.
+    def h(x):
+        d = cue.descriptors.spherical_harmonics(cue.SO3(1), [2]).d
+        return cuex.tensor_product(d, x, x)
+
+    assert jax.jacobian(h)(jnp.array([1.0, 0.0, 0.0])).shape == (5, 3)
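
Note: the regression test above doubles as a minimal reproduction of the bug this patch fixes. The sketch below restates it with the reasoning spelled out in comments; it is a sketch, not part of the patch, and assumes only the public entry points already used in the test (cue.descriptors.spherical_harmonics, cuex.tensor_product) and the usual import aliases from the test suite.

    import jax
    import jax.numpy as jnp

    import cuequivariance as cue
    import cuequivariance_jax as cuex


    def h(x):
        # Both input operands are the same buffer x of shape (3,);
        # the degree-2 spherical-harmonics output operand has size 5.
        d = cue.descriptors.spherical_harmonics(cue.SO3(1), [2]).d
        return cuex.tensor_product(d, x, x)


    # jax.jacobian goes through the transpose rule: a cotangent of shape (5,)
    # comes in as an input while the undefined primal x becomes an output, so
    # an input and the output of the same operand no longer share a leading
    # shape. Tracking per-output `output_shapes` (None for input-only operands)
    # instead of one `shapes` entry per operand makes this case well-defined.
    assert jax.jacobian(h)(jnp.array([1.0, 0.0, 0.0])).shape == (5, 3)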