JAX fix a (rare) bug related to vmap of tensor_product #73

Merged: 1 commit, Jan 23, 2025
cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py (45 additions, 35 deletions)
```diff
@@ -147,15 +147,17 @@ def _partial(*remaining_inputs: jax.Array) -> jax.Array:
         jnp.reshape(input, (1,) * (len(output_shape) + 1 - input.ndim) + input.shape)
         for input in inputs
     ]
-    shapes = tuple(input.shape[:-1] for input in inputs) + (output_shape,)
+    output_shapes = tuple(None for _ in inputs) + (output_shape,)
     exe = TensorProductExecution(
         [
             Computation(
                 [InBuffer(oid) for oid in range(d.num_operands - 1)] + [OutBuffer(0)]
             )
         ]
     )
-    (output,) = tensor_product_prim(*inputs, shapes=shapes, d=d, exe=exe, **options)
+    (output,) = tensor_product_prim(
+        *inputs, output_shapes=output_shapes, d=d, exe=exe, **options
+    )
     return output

@@ -184,7 +186,7 @@ def clean_inputs(

 def tensor_product_prim(
     *inputs: jax.Array,  # input buffers
-    shapes: tuple[tuple[int, ...], ...],  # shapes of the operands
+    output_shapes: tuple[tuple[int, ...] | None, ...],  # shapes of the operands
     d: stp.SegmentedTensorProduct,
     exe: TensorProductExecution,
     **options,
@@ -198,18 +200,18 @@ def tensor_product_prim(

     if options.pop("use_custom_primitive", True):
         return tensor_product_p.bind(
-            *unique_inputs, shapes=shapes, d=d, exe=exe, **options
+            *unique_inputs, output_shapes=output_shapes, d=d, exe=exe, **options
         )
     else:
         return tensor_product_vanilla_impl(
-            *unique_inputs, shapes=shapes, d=d, exe=exe, **options
+            *unique_inputs, output_shapes=output_shapes, d=d, exe=exe, **options
         )


 def tensor_product_impl(
     platform: str | None,
     *inputs: jax.Array,
-    shapes: tuple[tuple[int, ...], ...],
+    output_shapes: tuple[tuple[int, ...] | None, ...],
     d: stp.SegmentedTensorProduct,
     exe: TensorProductExecution,
     **options,
@@ -228,7 +230,7 @@ def dispatch(
             pass

         return tensor_product_vanilla_impl(
-            *inputs, shapes=shapes, d=d, exe=exe, **options
+            *inputs, output_shapes=output_shapes, d=d, exe=exe, **options
         )

     outputs = [0] * len(exe.out_buffers)
@@ -251,7 +253,7 @@

 def tensor_product_abstract_eval(
     *inputs: jax.core.ShapedArray,
-    shapes: tuple[tuple[int, ...], ...],
+    output_shapes: tuple[tuple[int, ...] | None, ...],
     d: stp.SegmentedTensorProduct,
     exe: TensorProductExecution,
     **options,
@@ -262,16 +264,15 @@ def tensor_product_abstract_eval(

     for c in exe.computations:
         for oid, x in zip(c.in_operands, c.map_inputs(inputs)):
-            expected_shape = shapes[oid] + (d.operands[oid].size,)
-            if x.shape != expected_shape:
+            if x.shape[-1] != d.operands[oid].size:
                 raise ValueError(
-                    f"cuex.tensor_product: expected input to have shape {expected_shape}, got {x.shape}"
+                    f"cuex.tensor_product: expected input to have size {d.operands[oid].size}, got {x.shape[-1]}"
                 )

     outputs = [None] * len(exe.out_buffers)
     for c in exe.computations:
         out = jax.core.ShapedArray(
-            shape=shapes[c.out_operand] + (d.operands[c.out_operand].size,),
+            shape=output_shapes[c.out_operand] + (d.operands[c.out_operand].size,),
             dtype=options["dtype_output"],
         )
         assert outputs[c.out_buffer] is None or outputs[c.out_buffer] == out
@@ -283,12 +284,14 @@ def tensor_product_jvp(
     primals: tuple[jax.Array, ...],
     tangents: tuple[jax.Array | ad.Zero, ...],
     *,
-    shapes: tuple[tuple[int, ...], ...],
+    output_shapes: tuple[tuple[int, ...] | None, ...],
     d: stp.SegmentedTensorProduct,
     exe: TensorProductExecution,
     **options,
 ) -> tuple[tuple[jax.Array, ...], tuple[jax.Array | ad.Zero, ...]]:
-    out_primals = tensor_product_prim(*primals, shapes=shapes, d=d, exe=exe, **options)
+    out_primals = tensor_product_prim(
+        *primals, output_shapes=output_shapes, d=d, exe=exe, **options
+    )
     out_tangents = [ad.Zero(p.aval) for p in out_primals]

     jvp = exe.jvp([not isinstance(t, ad.Zero) for t in tangents])
@@ -300,7 +303,7 @@
             tmp = tensor_product_prim(
                 *primals,
                 *[t for t in tangents if not isinstance(t, ad.Zero)],
-                shapes=shapes,
+                output_shapes=output_shapes,
                 d=multiplicator * d,
                 exe=exe.map_buffers(None, lambda b: exe.out_buffers.index(b)),
                 **options,
@@ -314,19 +317,35 @@
 def tensor_product_transpose(
     cotangents: tuple[jax.Array | ad.Zero, ...],
     *inputs: jax.Array | ad.UndefinedPrimal,
-    shapes: tuple[tuple[int, ...], ...],
+    output_shapes: tuple[tuple[int, ...] | None, ...],
     d: stp.SegmentedTensorProduct,
     exe: TensorProductExecution,
     **options,
 ) -> tuple[jax.Array | ad.Zero | None, ...]:
+    # The cotangents replace the outputs as inputs
+    # The undefined primal inputs become outputs
+    del output_shapes
+    output_shapes = [None] * d.num_operands
+    for comp in exe.computations:
+        for oid, x in zip(comp.in_operands, comp.map_inputs(inputs)):
+            if ad.is_undefined_primal(x):
+                undefined_primal_shape = x.aval.shape[:-1]
+                # if the following assert fails, we need to change the internal API of the primitive
+                assert (
+                    output_shapes[oid] is None
+                    or output_shapes[oid] == undefined_primal_shape
+                )
+                output_shapes[oid] = undefined_primal_shape
+    output_shapes = tuple(output_shapes)
+
     tr = exe.transpose(
         [ad.is_undefined_primal(x) for x in inputs],
         [not isinstance(x, ad.Zero) for x in cotangents],
     )
     tmp = tensor_product_prim(
         *[x for x in inputs if not ad.is_undefined_primal(x)],
         *[x for x in cotangents if not isinstance(x, ad.Zero)],
-        shapes=shapes,
+        output_shapes=output_shapes,
         d=d,
         exe=tr.map_buffers(None, lambda b: tr.out_buffers.index(b)),
         **options,
@@ -348,7 +367,7 @@ def tensor_product_batching(
     batched_inputs: tuple[jax.Array, ...],
     batch_axes: tuple[int | None, ...],
     *,
-    shapes: tuple[tuple[int, ...], ...],
+    output_shapes: tuple[tuple[int, ...] | None, ...],
     d: stp.SegmentedTensorProduct,
     exe: TensorProductExecution,
     **options,
@@ -359,33 +378,24 @@ def prepare(input: jax.Array, axis: int | None) -> jax.Array:
         else:
             return jnp.moveaxis(input, axis, 0)

+    assert len(batched_inputs) == len(batch_axes)
     batched_inputs = [
         prepare(input, axis) for input, axis in zip(batched_inputs, batch_axes)
     ]
     new_dim = max(input.shape[0] for input in batched_inputs)

-    new_shapes = [None] * d.num_operands
+    new_output_shapes = [None] * d.num_operands
     for comp in exe.computations:
-        # inputs
-        for oid, input in zip(comp.in_operands, comp.map_inputs(batched_inputs)):
-            expected = input.shape[:-1]
-            if new_shapes[oid] is None:
-                new_shapes[oid] = expected
-            assert new_shapes[oid] == expected
-
-        # output
         oid = comp.out_operand
-        expected = (new_dim,) + shapes[oid]
-        if new_shapes[oid] is None:
-            new_shapes[oid] = expected
-        assert new_shapes[oid] == expected
-
-    new_shapes = tuple(new_shapes)
-    assert all(s is not None for s in new_shapes)
+        expected = (new_dim,) + output_shapes[oid]
+        if new_output_shapes[oid] is None:
+            new_output_shapes[oid] = expected
+        assert new_output_shapes[oid] == expected
+    new_output_shapes = tuple(new_output_shapes)

     outputs = tensor_product_prim(
         *batched_inputs,
-        shapes=new_shapes,
+        output_shapes=new_output_shapes,
         d=d,
         exe=exe,
         **options,
```
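Taken together, the changes above replace the primitive's `shapes` argument, which recorded one shared batch shape per operand, with `output_shapes`, which records batch shapes only for output operands (`None` for inputs). Inputs are now validated only on their last (segment) dimension, so an input buffer may carry a different batch shape than the output of the same operand as long as the shapes broadcast, and the transpose rule recovers the missing output shapes from the undefined primals' avals. Below is a minimal standalone sketch of the left-padding pattern `_partial` uses to make inputs broadcastable; the concrete shapes are illustrative, not from the PR:

```python
import jax.numpy as jnp

output_shape = (5,)  # batch shape of the output operand
inputs = [jnp.ones((3,)), jnp.ones((5, 3))]  # one unbatched input, one batched

# Left-pad each input with size-1 axes so it broadcasts against the
# (*batch, segments) layout implied by the output shape: (3,) -> (1, 3).
inputs = [
    jnp.reshape(x, (1,) * (len(output_shape) + 1 - x.ndim) + x.shape)
    for x in inputs
]
assert inputs[0].shape == (1, 3)
assert inputs[1].shape == (5, 3)
```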
```diff
@@ -29,7 +29,7 @@

 def tensor_product_vanilla_impl(
     *inputs: jax.Array,  # input buffers
-    shapes: tuple[tuple[int, ...], ...],  # shapes of the operands
+    output_shapes: tuple[tuple[int, ...] | None, ...],  # shapes of the operands
     d: stp.SegmentedTensorProduct,
     exe: TensorProductExecution,
     **options,
@@ -39,18 +39,20 @@ def tensor_product_vanilla_impl(
     outputs = [0] * len(exe.out_buffers)

     for c in exe.computations:
+        shape = output_shapes[c.out_operand]
+        assert shape is not None
         out = sum_cat_list_list(
             d.operands[c.out_operand],
             tp_list_list(
                 *c.map_inputs(inputs),
-                shape=shapes[c.out_operand],
+                shape=shape,
                 d=d.move_operand_last(c.out_operand),
                 **options,
             ),
-            shapes[c.out_operand],
+            shape,
             options["dtype_output"],
         )
-        assert out.shape == shapes[c.out_operand] + (d.operands[c.out_operand].size,)
+        assert out.shape == shape + (d.operands[c.out_operand].size,)
         outputs[c.out_buffer] += out

     return tuple(outputs)
```
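The vanilla implementation makes the new contract explicit: every computation writes one output operand, whose batch shape must be present in `output_shapes`, while operands that are only read stay `None` and keep whatever batch shapes their buffers carry. A small sketch of the convention for the two-input, one-output execution that `_partial` builds; the values are illustrative, not from the PR:

```python
# One entry per operand of the segmented tensor product: None for
# operands that are only read, a batch shape for the output operand.
output_shapes = (None, None, (5,))  # inputs x0, x1 -> output with batch shape (5,)

# The abstract eval above then gives the output buffer the shape
# output_shapes[oid] + (d.operands[oid].size,), e.g. (5, segment_size).
```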
cuequivariance_jax/tests/primitives/tensor_product_test.py (11 additions, 0 deletions)
```diff
@@ -178,3 +178,14 @@ def f(w, x):
         return jnp.sum(a) + jnp.sum(b)

     jax.jit(jax.grad(f, 0))(w, x)
+
+
+def test_multiple_operand_shape_bug():
+    # This was causing an issue in the past.
+    # Before, it was not possible to have an input
+    # with a different shape than the output of the same operand.
+    def h(x):
+        d = cue.descriptors.spherical_harmonics(cue.SO3(1), [2]).d
+        return cuex.tensor_product(d, x, x)
+
+    assert jax.jacobian(h)(jnp.array([1.0, 0.0, 0.0])).shape == (5, 3)
```
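The test reaches the fix through `jax.jacobian`, which in reverse mode vmaps a VJP over one basis cotangent per output component; under that `vmap`, the cotangent buffer carries a batch axis that the primal input `x` lacks, exactly the mixed-batch-shape situation the old per-operand `shapes` argument could not express. A sketch of that mechanism spelled out with `jax.vjp` and `jax.vmap` directly, assuming the same `cue`/`cuex` API as the test above:

```python
import jax
import jax.numpy as jnp
import cuequivariance as cue
import cuequivariance_jax as cuex

d = cue.descriptors.spherical_harmonics(cue.SO3(1), [2]).d  # x, x -> degree-2 output (size 5)

def h(x):
    return cuex.tensor_product(d, x, x)

x = jnp.array([1.0, 0.0, 0.0])
_, h_vjp = jax.vjp(h, x)
cotangents = jnp.eye(5)                # one basis cotangent per output component
(rows,) = jax.vmap(h_vjp)(cotangents)  # vmap of the transpose rule: the path this PR fixes
assert rows.shape == (5, 3)            # matches jax.jacobian(h)(x).shape
```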