[WIP] call backend JAX bindings #74

Open
wants to merge 9 commits into base: main
22 changes: 20 additions & 2 deletions cuequivariance/cuequivariance/tensor_product_execution.py
@@ -24,11 +24,13 @@ class Buffer(int):


class InBuffer(Buffer):
    pass
    def __repr__(self):
        return f"InBuffer({int(self)})"


class OutBuffer(Buffer):
    pass
    def __repr__(self):
        return f"OutBuffer({int(self)})"


T = TypeVar("T")
@@ -165,6 +167,22 @@ def num_inputs_per_operand(self) -> tuple[int, ...]:
    def num_outputs_per_operand(self) -> tuple[int, ...]:
        return tuple(len(s) for s in self.out_buffers_per_operand)

    def get_in_buffer_operands(self, buffer: int) -> set[int]:
        return {
            ope
            for c in self.computations
            for ope, b in enumerate(c)
            if isinstance(b, InBuffer) and b == buffer
        }

    def get_out_buffer_operands(self, buffer: int) -> set[int]:
        return {
            ope
            for c in self.computations
            for ope, b in enumerate(c)
            if isinstance(b, OutBuffer) and b == buffer
        }

    def map_buffers(
        self,
        f_in: Optional[Callable[[int], int]],
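Note (reviewer sketch, not part of the diff): since Buffer subclasses int, the new __repr__ only changes how buffers print; comparisons with plain ints still work, which is what the two new helpers above rely on.

from cuequivariance.tensor_product_execution import InBuffer, OutBuffer

b = InBuffer(2)
print(repr(b))             # InBuffer(2) instead of a bare "2"
print(int(b), b == 2)      # 2 True -- int subclass, so b == buffer works in the helpers
print(repr(OutBuffer(0)))  # OutBuffer(0)

get_in_buffer_operands(k) and get_out_buffer_operands(k) then return the set of operand positions that read or write buffer k across all computations.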
@@ -28,7 +28,8 @@ def equivariant_tensor_product(
    precision: jax.lax.Precision = jax.lax.Precision.HIGHEST,
    algorithm: str = "sliced",
    use_custom_primitive: bool = True,
    use_custom_kernels: bool = False,
    use_custom_kernels: bool | None = False,
    name: str | None = None,
) -> cuex.RepArray:
    """Compute the equivariant tensor product of the input arrays.

@@ -78,6 +79,7 @@ def equivariant_tensor_product(
        algorithm=algorithm,
        use_custom_primitive=use_custom_primitive,
        use_custom_kernels=use_custom_kernels,
        name=name,
    )

    if len(inputs) != e.num_inputs:
@@ -113,6 +115,7 @@ def equivariant_tensor_product(
        algorithm=algorithm,
        use_custom_primitive=use_custom_primitive,
        use_custom_kernels=use_custom_kernels,
        name=name,
    )

    return cuex.RepArray(e.output, x)
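Call-site sketch (not part of the diff; the descriptor, cuex.randn helper, and irreps follow the library's README-style usage and are assumptions): the new name keyword labels the underlying primitive, and use_custom_kernels=None requests the profile-and-pick mode handled in tensor_product_impl below.

import jax
import cuequivariance as cue
import cuequivariance_jax as cuex

e = cue.descriptors.fully_connected_tensor_product(
    cue.Irreps("O3", "8x0e + 8x1o"),
    cue.Irreps("O3", "8x0e + 8x1o"),
    cue.Irreps("O3", "8x0e + 8x1o"),
)
w = cuex.randn(jax.random.key(0), e.inputs[0])
x1 = cuex.randn(jax.random.key(1), e.inputs[1])
x2 = cuex.randn(jax.random.key(2), e.inputs[2])

out = cuex.equivariant_tensor_product(
    e, w, x1, x2,
    name="fctp_example",      # shows up in the primitive name and the profiling printout
    use_custom_kernels=None,  # on CUDA, benchmark both implementations and keep the faster one
)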
@@ -31,7 +31,8 @@ def symmetric_tensor_product(
    precision: jax.lax.Precision = jax.lax.Precision.HIGHEST,
    algorithm: str = "sliced",
    use_custom_primitive: bool = True,
    use_custom_kernels: bool = False,
    use_custom_kernels: bool | None = False,
    name: str | None = None,
) -> jax.Array:
    """
    Compute the sum of the STPs evaluated on the input (all input operands are the same).
@@ -54,6 +55,9 @@ def symmetric_tensor_product(
"""
assert any(d.num_operands >= 2 for d in ds)

if name is None:
name = "symmetric_tensor_product"

# currying
if len(inputs) == 0:

@@ -67,6 +71,7 @@ def fn(*inputs) -> jax.Array:
algorithm=algorithm,
use_custom_primitive=use_custom_primitive,
use_custom_kernels=use_custom_kernels,
name=name,
)

return fn
@@ -136,6 +141,7 @@ def fn(*inputs) -> jax.Array:
algorithm=algorithm,
use_custom_primitive=use_custom_primitive,
use_custom_kernels=use_custom_kernels,
name=name + f"_{n_in - n_un}",
)

return output
155 changes: 137 additions & 18 deletions cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py
@@ -20,15 +20,20 @@
import jax.extend
import jax.lax
import jax.numpy as jnp
import numpy as np
from jax.experimental.mosaic.gpu import profiler
from jax.interpreters import ad, batching, mlir, xla

from cuequivariance import segmented_tensor_product as stp
import cuequivariance as cue
from cuequivariance.tensor_product_execution import (
    Computation,
    InBuffer,
    OutBuffer,
    TensorProductExecution,
)
from cuequivariance_jax.primitives.tensor_product_ops_impl import (
    tensor_product_ops_impl,
)
from cuequivariance_jax.primitives.tensor_product_vanilla_impl import (
    tensor_product_vanilla_impl,
)
@@ -37,14 +42,15 @@


def tensor_product(
    d: stp.SegmentedTensorProduct,
    d: cue.SegmentedTensorProduct,
    *inputs: jax.Array,
    dtype_output: jnp.dtype | None = None,
    dtype_math: jnp.dtype | None = None,
    precision: jax.lax.Precision = jax.lax.Precision.HIGHEST,
    algorithm: str = "sliced",
    use_custom_primitive: bool = True,
    use_custom_kernels: bool = False,
    use_custom_kernels: bool | None = False,
    name: str | None = None,
) -> jax.Array:
    """
    Compute the last operand of a `SegmentedTensorProduct`.
@@ -85,13 +91,17 @@ def tensor_product(
    if isinstance(precision, str):
        precision = jax.lax.Precision[precision]

    if name is None:
        name = "tensor_product"

    options = dict(
        dtype_output=dtype_output,
        dtype_math=dtype_math,
        precision=precision,
        algorithm=algorithm,
        use_custom_primitive=use_custom_primitive,
        use_custom_kernels=use_custom_kernels,
        name=name,
    )

    if len(inputs) > d.num_operands - 1:
@@ -140,6 +150,7 @@ def _partial(*remaining_inputs: jax.Array) -> jax.Array:
algorithm=algorithm,
use_custom_primitive=use_custom_primitive,
use_custom_kernels=use_custom_kernels,
name=name,
)

# inputs of shape (..., ope.size) with identical ndim
@@ -187,8 +198,9 @@ def clean_inputs(
def tensor_product_prim(
    *inputs: jax.Array,  # input buffers
    output_shapes: tuple[tuple[int, ...] | None, ...],  # shapes of the operands
    d: stp.SegmentedTensorProduct,
    d: cue.SegmentedTensorProduct,
    exe: TensorProductExecution,
    use_custom_primitive: bool = True,
    **options,
) -> tuple[jax.Array, ...]:  # output buffers
    if exe.is_trivial:
@@ -198,7 +210,7 @@ def tensor_product_prim(

    unique_inputs, exe = clean_inputs(list(inputs), exe)

    if options.pop("use_custom_primitive", True):
    if use_custom_primitive:
        return tensor_product_p.bind(
            *unique_inputs, output_shapes=output_shapes, d=d, exe=exe, **options
        )
@@ -208,30 +220,131 @@ def tensor_product_prim(
)


def produce_minimal_code(
    *inputs: jax.Array,
    output_shapes: tuple[tuple[int, ...] | None, ...],
    d: cue.SegmentedTensorProduct,
    exe: TensorProductExecution,
    dtype_output: jnp.dtype,
    dtype_math: jnp.dtype,
    **unused_options,
) -> str:
    def format_dtype(dtype: jnp.dtype) -> str:
        dtype = jnp.dtype(dtype)
        return f"jnp.{dtype.name}"

    mincode = """import jax.numpy as jnp
import cuequivariance as cue
from cuequivariance.tensor_product_execution import InBuffer, OutBuffer
from cuequivariance_jax.primitives.tensor_product import tensor_product_prim
"""
    mincode += (
        "inputs = ["
        + ", ".join([f"jnp.zeros({x.shape}, {format_dtype(x.dtype)})" for x in inputs])
        + "]\n"
    )
    mincode += f"output_shapes = {output_shapes}\n"
    mincode += f'd = cue.SegmentedTensorProduct.from_base64("{d.to_base64()}")\n'
    mincode += f"exe = cue.TensorProductExecution({exe.computations})\n"
    mincode += f"dtype_output = {format_dtype(dtype_output)}\n"
    mincode += f"dtype_math = {format_dtype(dtype_math)}\n"
    # tensor_product_prim(
    #     *inputs,
    #     output_shapes=output_shapes,
    #     d=d,
    #     exe=exe,
    #     dtype_output=dtype_output,
    #     dtype_math=dtype_math,
    #     use_custom_kernels=True,
    # )
    mincode += "# " + ", ".join([f"{k}={v}" for k, v in unused_options.items()])
    return mincode
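For orientation (illustrative, with two hypothetical float32 inputs and the base64 payload elided): the returned string is itself a small reproduction script along these lines.

# import jax.numpy as jnp
# import cuequivariance as cue
# from cuequivariance.tensor_product_execution import InBuffer, OutBuffer
# from cuequivariance_jax.primitives.tensor_product import tensor_product_prim
# inputs = [jnp.zeros((16, 128), jnp.float32), jnp.zeros((16, 32), jnp.float32)]
# output_shapes = (None, None, (16,))
# d = cue.SegmentedTensorProduct.from_base64("...")
# exe = cue.TensorProductExecution(...)
# dtype_output = jnp.float32
# dtype_math = jnp.float32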


def profile_and_select_implementation(
    name: str, impls: list[tuple[str, callable]], *inputs: jax.Array
):
    # import time
    # t0 = time.perf_counter()

    with jax.ensure_compile_time_eval():
        dummy_inputs = [np.random.normal(size=x.shape).astype(x.dtype) for x in inputs]
        dummy_inputs = [jax.device_put(x) for x in dummy_inputs]
        ref = None
        first_runtime: float | None = None
        best: tuple[str, float, callable] | None = None
        for impl_name, impl in impls:
            try:
                out, runtime = profiler.measure(impl, mode="cupti")(*dummy_inputs)
            except NotImplementedError:
                continue
            else:
                if ref is None:
                    ref = out
                    best = impl_name, runtime, impl
                    first_runtime = runtime
                else:
                    diff = max(
                        [
                            np.max(np.abs(a - b))
                            for a, b in zip(jax.tree.leaves(out), jax.tree.leaves(ref))
                        ]
                    )
                    if diff > 1e-3:
                        raise ValueError(
                            f"cuex.tensor_product: {name} implementation {impl_name} produced different results, diff={diff}"
                        )
                    if runtime < best[1]:
                        best = impl_name, runtime, impl
        assert best is not None
        impl_name, runtime, impl = best

    # dt = time.perf_counter() - t0
    speedup = first_runtime / runtime
    print(
        f"{name:<50}: {impl_name:<10} with runtime {runtime:.2f} ms, speedup {speedup:.2f}x wrt {first_runtime:.2f} ms"
    )

    return impl(*inputs)
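Minimal sketch of the selector's contract (hypothetical impls; needs a CUDA GPU since profiler.measure(mode="cupti") does the timing): candidates are (label, callable) pairs, they must agree numerically within 1e-3, and a candidate opts out by raising NotImplementedError.

import jax.numpy as jnp
from cuequivariance_jax.primitives.tensor_product import profile_and_select_implementation

def impl_double(x):
    return (x * 2.0,)  # returns a tuple, like the tensor product impls

def impl_add(x):
    return (x + x,)    # numerically identical, so the cross-check passes

x = jnp.ones((4096, 128), jnp.float32)
(out,) = profile_and_select_implementation("demo", [("double", impl_double), ("add", impl_add)], x)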


def tensor_product_impl(
    platform: str | None,
    *inputs: jax.Array,
    output_shapes: tuple[tuple[int, ...] | None, ...],
    d: stp.SegmentedTensorProduct,
    d: cue.SegmentedTensorProduct,
    exe: TensorProductExecution,
    name: str = "tensor_product",
    use_custom_kernels: bool | None = True,
    **options,
) -> tuple[jax.Array, ...]:
    assert exe.max_in_buffer + 1 == len(exe.in_buffers) == len(inputs)
    assert exe.max_out_buffer + 1 == len(exe.out_buffers)
    use_custom_kernels = options.pop("use_custom_kernels", True)

    def dispatch(
        inputs: list[jax.Array],
        d: stp.SegmentedTensorProduct,
        d: cue.SegmentedTensorProduct,
        exe: TensorProductExecution,
    ) -> list[jax.Array]:
        if platform == "cuda" and use_custom_kernels:
            # TODO: call custom kernels here
            pass
        kwargs = dict(output_shapes=output_shapes, d=d, exe=exe, **options)

        if platform == "cuda":
            # print(produce_minimal_code(*inputs, **kwargs))
            # print()
            if use_custom_kernels is None:
                outputs = profile_and_select_implementation(
                    name,
                    [
                        ("vanilla", partial(tensor_product_vanilla_impl, **kwargs)),
                        ("ops", partial(tensor_product_ops_impl, **kwargs)),
                    ],
                    *inputs,
                )
                return outputs
            if use_custom_kernels is True:
                return tensor_product_ops_impl(*inputs, **kwargs)

        return tensor_product_vanilla_impl(
            *inputs, output_shapes=output_shapes, d=d, exe=exe, **options
        )
        return tensor_product_vanilla_impl(*inputs, **kwargs)

    outputs = [0] * len(exe.out_buffers)
    for partition, ex in exe.group_by_identical_buffers():
@@ -254,7 +367,7 @@ def dispatch(
def tensor_product_abstract_eval(
    *inputs: jax.core.ShapedArray,
    output_shapes: tuple[tuple[int, ...] | None, ...],
    d: stp.SegmentedTensorProduct,
    d: cue.SegmentedTensorProduct,
    exe: TensorProductExecution,
    **options,
) -> tuple[jax.core.ShapedArray, ...]:
@@ -285,7 +398,7 @@ def tensor_product_jvp(
    tangents: tuple[jax.Array | ad.Zero, ...],
    *,
    output_shapes: tuple[tuple[int, ...] | None, ...],
    d: stp.SegmentedTensorProduct,
    d: cue.SegmentedTensorProduct,
    exe: TensorProductExecution,
    **options,
) -> tuple[tuple[jax.Array, ...], tuple[jax.Array | ad.Zero, ...]]:
@@ -297,6 +410,8 @@ def tensor_product_jvp(
    jvp = exe.jvp([not isinstance(t, ad.Zero) for t in tangents])
    del exe

    options["name"] = options.get("name", "tensor_product") + "->jvp"

    permutations: list[tuple[int, ...]] = d.symmetries()
    for multiplicator, exe in jvp.group_by_symmetries(permutations):
        # tensor_product_prim can remove unused inputs
@@ -318,7 +433,7 @@ def tensor_product_transpose(
    cotangents: tuple[jax.Array | ad.Zero, ...],
    *inputs: jax.Array | ad.UndefinedPrimal,
    output_shapes: tuple[tuple[int, ...] | None, ...],
    d: stp.SegmentedTensorProduct,
    d: cue.SegmentedTensorProduct,
    exe: TensorProductExecution,
    **options,
) -> tuple[jax.Array | ad.Zero | None, ...]:
@@ -338,6 +453,8 @@ def tensor_product_transpose(
            output_shapes[oid] = undefined_primal_shape
    output_shapes = tuple(output_shapes)

    options["name"] = options.get("name", "tensor_product") + "->transpose"

    tr = exe.transpose(
        [ad.is_undefined_primal(x) for x in inputs],
        [not isinstance(x, ad.Zero) for x in cotangents],
@@ -368,7 +485,7 @@ def tensor_product_batching(
    batch_axes: tuple[int | None, ...],
    *,
    output_shapes: tuple[tuple[int, ...] | None, ...],
    d: stp.SegmentedTensorProduct,
    d: cue.SegmentedTensorProduct,
    exe: TensorProductExecution,
    **options,
) -> tuple[tuple[jax.Array, ...], tuple[int, ...]]:
@@ -393,6 +510,8 @@ def prepare(input: jax.Array, axis: int | None) -> jax.Array:
            assert new_output_shapes[oid] == expected
    new_output_shapes = tuple(new_output_shapes)

    options["name"] = options.get("name", "tensor_product") + "->batching"

    outputs = tensor_product_prim(
        *batched_inputs,
        output_shapes=new_output_shapes,
@@ -0,0 +1,125 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math

import jax
import jax.numpy as jnp

import cuequivariance as cue
from cuequivariance.tensor_product_execution import InBuffer

logger = logging.getLogger(__name__)


def tensor_product_ops_impl(
    *inputs: jax.Array,  # input buffers
    output_shapes: tuple[tuple[int, ...] | None, ...],  # shapes of the operands
    d: cue.SegmentedTensorProduct,
    exe: cue.TensorProductExecution,
    **options,
) -> tuple[jax.Array, ...]:  # output buffers
    assert exe.max_out_buffer + 1 == len(exe.out_buffers)

    detail_str = f"\n{d}\n{exe}".replace("\n", "\n | ")

    if not d.all_same_segment_shape():
        logger.info("🛶 can't use tensor_product_uniform_1d for" + detail_str)
        raise NotImplementedError()

    try:
        from cuequivariance_ops_jax import tensor_product_uniform_1d
    except ImportError:
        logger.info("🛶 can't import cuequivariance_ops_jax")
        raise NotImplementedError()

    modes = d.subscripts.modes()
    if len(modes) > 1:
        logger.info("🛶 can't use tensor_product_uniform_1d for" + detail_str)
        raise NotImplementedError()

    if len(modes) == 1:
        dims: set[int] = d.get_dims(modes[0])
        if len(dims) != 1:
            logger.info("🛶 can't use tensor_product_uniform_1d for" + detail_str)
            raise NotImplementedError()

    batch_size = 1
    for shape in [input.shape[:-1] for input in inputs] + [
        shape for shape in output_shapes if shape is not None
    ]:
        n = math.prod(shape)
        if n > 1:
            if n != batch_size and batch_size != 1:
                logger.info("🛶 can't use tensor_product_uniform_1d for" + detail_str)
                raise NotImplementedError()
            batch_size = n

    reshaped_inputs = []
    for index, input in enumerate(inputs):
        operands = {
            (d.operands[op].size, d.operands[op].num_segments)
            for op in exe.get_in_buffer_operands(index)
        }
        if len(operands) != 1:
            logger.info("🛶 can't use tensor_product_uniform_1d for" + detail_str)
            raise NotImplementedError()
        size, num_segments = operands.pop()
        reshaped_inputs.append(
            input.reshape(
                (math.prod(input.shape[:-1]), num_segments, size // num_segments)
            )
        )

    output_operands = []
    outputs = []
    for index in exe.out_buffers:
        operands = exe.get_out_buffer_operands(index)
        if len(operands) != 1:
            logger.info("🛶 can't use tensor_product_uniform_1d for" + detail_str)
            raise NotImplementedError()
        ope = operands.pop()
        size, num_segments = d.operands[ope].size, d.operands[ope].num_segments

        output_operands.append(ope)
        outputs.append(
            jnp.zeros(
                (math.prod(output_shapes[ope]), num_segments, size // num_segments),
                dtype=options["dtype_output"],
            )
        )

    logger.info("🎉 use tensor_product_uniform_1d for" + detail_str)

    outputs = tensor_product_uniform_1d(
        options["dtype_math"],
        [ope.num_segments for ope in d.operands],
        [path.indices for path in d.paths],
        [path.coefficients.item() for path in d.paths],
        reshaped_inputs,
        outputs,
        [
            tuple(
                int(b) if isinstance(b, InBuffer) else -1 - int(b) for b in computation
            )
            for computation in exe.computations
        ],
    )

    outputs = [
        output.reshape(output_shapes[ope] + (-1,))
        for ope, output in zip(output_operands, outputs)
    ]
    return tuple(outputs)
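For reference, the computation encoding handed to tensor_product_uniform_1d above keeps input-buffer indices as-is and maps output buffer k to -(k + 1); a standalone sketch of that convention (the buffer tuples are illustrative):

from cuequivariance.tensor_product_execution import InBuffer, OutBuffer

def encode(computation):
    # same expression as in the call above
    return tuple(int(b) if isinstance(b, InBuffer) else -1 - int(b) for b in computation)

print(encode((InBuffer(0), InBuffer(1), OutBuffer(0))))  # (0, 1, -1)
print(encode((InBuffer(0), InBuffer(0), OutBuffer(1))))  # (0, 0, -2)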
17 changes: 11 additions & 6 deletions cuequivariance_jax/cuequivariance_jax/rep_array/jax_rep_array.py
@@ -223,13 +223,18 @@ def __getitem__(self, key: Any) -> RepArray:
            self.array[None],
        )

        # self[jnp.array([0, 1, 2])]
        assert isinstance(key, jax.Array)
        assert 0 not in self.reps
        return RepArray(
            {k + key.ndim - 1: irreps for k, irreps in self.reps.items()},
            self.array[key],
        )

        # self[1:4]
        if isinstance(key, slice):
            return RepArray(self.reps, self.array[key])

        # self[jnp.array([0, 1, 2])]
        if isinstance(key, jax.Array):
            return RepArray(
                {k + key.ndim - 1: irreps for k, irreps in self.reps.items()},
                self.array[key],
            )

    @property
    def slice_by_mul(self) -> _MulIndexSliceHelper:
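A short sketch of the two new indexing paths (written as functions so no particular RepArray construction is assumed): slicing keeps the reps on their original axes, while indexing with a jax.Array shifts each rep axis by key.ndim - 1.

import jax
import cuequivariance_jax as cuex

def take_rows(x: cuex.RepArray) -> cuex.RepArray:
    # slice path: same reps dict, sliced underlying array
    return x[1:4]

def gather_rows(x: cuex.RepArray, idx: jax.Array) -> cuex.RepArray:
    # jax.Array path: rep axes shift by idx.ndim - 1 (unchanged for a 1-D index)
    return x[idx]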