[WIP] call backend JAX bindings #74

Open

wants to merge 9 commits into base: main
22 changes: 20 additions & 2 deletions cuequivariance/cuequivariance/tensor_product_execution.py
@@ -24,11 +24,13 @@ class Buffer(int):


class InBuffer(Buffer):
pass
def __repr__(self):
return f"InBuffer({int(self)})"


class OutBuffer(Buffer):
pass
def __repr__(self):
return f"OutBuffer({int(self)})"


T = TypeVar("T")
@@ -165,6 +167,22 @@ def num_inputs_per_operand(self) -> tuple[int, ...]:
def num_outputs_per_operand(self) -> tuple[int, ...]:
return tuple(len(s) for s in self.out_buffers_per_operand)

def get_in_buffer_operands(self, buffer: int) -> set[int]:
return {
ope
for c in self.computations
for ope, b in enumerate(c)
if isinstance(b, InBuffer) and b == buffer
}

def get_out_buffer_operands(self, buffer: int) -> set[int]:
return {
ope
for c in self.computations
for ope, b in enumerate(c)
if isinstance(b, OutBuffer) and b == buffer
}
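
As a quick illustration of the two new queries, a minimal sketch; it assumes TensorProductExecution accepts a tuple of computations whose entries are the buffers in operand order, mirroring the constructor call emitted by produce_minimal_code further down in this PR:

from cuequivariance.tensor_product_execution import (
    InBuffer,
    OutBuffer,
    TensorProductExecution,
)

# A single computation: operands 0 and 1 read input buffers 0 and 1,
# operand 2 writes output buffer 0.
exe = TensorProductExecution(((InBuffer(0), InBuffer(1), OutBuffer(0)),))

exe.get_in_buffer_operands(0)   # {0}: input buffer 0 feeds operand 0
exe.get_in_buffer_operands(1)   # {1}: input buffer 1 feeds operand 1
exe.get_out_buffer_operands(0)  # {2}: output buffer 0 is written by operand 2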

def map_buffers(
self,
f_in: Optional[Callable[[int], int]],
@@ -28,7 +28,8 @@ def equivariant_tensor_product(
precision: jax.lax.Precision = jax.lax.Precision.HIGHEST,
algorithm: str = "sliced",
use_custom_primitive: bool = True,
use_custom_kernels: bool = False,
use_custom_kernels: bool | None = False,
name: str | None = None,
) -> cuex.RepArray:
"""Compute the equivariant tensor product of the input arrays.

@@ -78,6 +79,7 @@ def equivariant_tensor_product(
algorithm=algorithm,
use_custom_primitive=use_custom_primitive,
use_custom_kernels=use_custom_kernels,
name=name,
)

if len(inputs) != e.num_inputs:
@@ -113,6 +115,7 @@ def equivariant_tensor_product(
algorithm=algorithm,
use_custom_primitive=use_custom_primitive,
use_custom_kernels=use_custom_kernels,
name=name,
)

return cuex.RepArray(e.output, x)
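
A hedged sketch of how a caller could exercise the two new keyword arguments; `e` and `x` stand in for an equivariant tensor product descriptor and a cuex.RepArray built elsewhere and are not defined in this diff:

import cuequivariance_jax as cuex

# `e` (descriptor) and `x` (cuex.RepArray) are placeholders built elsewhere.
y = cuex.equivariant_tensor_product(
    e,
    x,
    name="readout",           # forwarded to the underlying primitive, useful for logs/profiles
    use_custom_kernels=None,  # None: profile CUDA and vanilla implementations, keep the faster one
)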
@@ -31,7 +31,8 @@ def symmetric_tensor_product(
precision: jax.lax.Precision = jax.lax.Precision.HIGHEST,
algorithm: str = "sliced",
use_custom_primitive: bool = True,
use_custom_kernels: bool = False,
use_custom_kernels: bool | None = False,
name: str | None = None,
) -> jax.Array:
"""
Compute the sum of the STPs evaluated on the input (all input operands are the same).
@@ -54,6 +55,9 @@
"""
assert any(d.num_operands >= 2 for d in ds)

if name is None:
name = "symmetric_tensor_product"

# currying
if len(inputs) == 0:

@@ -67,6 +71,7 @@ def fn(*inputs) -> jax.Array:
algorithm=algorithm,
use_custom_primitive=use_custom_primitive,
use_custom_kernels=use_custom_kernels,
name=name,
)

return fn
@@ -136,6 +141,7 @@ def fn(*inputs) -> jax.Array:
algorithm=algorithm,
use_custom_primitive=use_custom_primitive,
use_custom_kernels=use_custom_kernels,
name=name + f"_{n_in - n_un}",
)

return output
155 changes: 137 additions & 18 deletions cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py
@@ -20,15 +20,20 @@
import jax.extend
import jax.lax
import jax.numpy as jnp
import numpy as np
from jax.experimental.mosaic.gpu import profiler
from jax.interpreters import ad, batching, mlir, xla

from cuequivariance import segmented_tensor_product as stp
import cuequivariance as cue
from cuequivariance.tensor_product_execution import (
Computation,
InBuffer,
OutBuffer,
TensorProductExecution,
)
from cuequivariance_jax.primitives.tensor_product_ops_impl import (
tensor_product_ops_impl,
)
from cuequivariance_jax.primitives.tensor_product_vanilla_impl import (
tensor_product_vanilla_impl,
)
@@ -37,14 +42,15 @@


def tensor_product(
d: stp.SegmentedTensorProduct,
d: cue.SegmentedTensorProduct,
*inputs: jax.Array,
dtype_output: jnp.dtype | None = None,
dtype_math: jnp.dtype | None = None,
precision: jax.lax.Precision = jax.lax.Precision.HIGHEST,
algorithm: str = "sliced",
use_custom_primitive: bool = True,
use_custom_kernels: bool = False,
use_custom_kernels: bool | None = False,
name: str | None = None,
) -> jax.Array:
"""
Compute the last operand of a `SegmentedTensorProduct`.
@@ -85,13 +91,17 @@ def tensor_product(
if isinstance(precision, str):
precision = jax.lax.Precision[precision]

if name is None:
name = "tensor_product"

options = dict(
dtype_output=dtype_output,
dtype_math=dtype_math,
precision=precision,
algorithm=algorithm,
use_custom_primitive=use_custom_primitive,
use_custom_kernels=use_custom_kernels,
name=name,
)

if len(inputs) > d.num_operands - 1:
@@ -140,6 +150,7 @@ def _partial(*remaining_inputs: jax.Array) -> jax.Array:
algorithm=algorithm,
use_custom_primitive=use_custom_primitive,
use_custom_kernels=use_custom_kernels,
name=name,
)

# inputs of shape (..., ope.size) with identical ndim
@@ -187,8 +198,9 @@ def clean_inputs(
def tensor_product_prim(
*inputs: jax.Array, # input buffers
output_shapes: tuple[tuple[int, ...] | None, ...], # shapes of the operands
d: stp.SegmentedTensorProduct,
d: cue.SegmentedTensorProduct,
exe: TensorProductExecution,
use_custom_primitive: bool = True,
**options,
) -> tuple[jax.Array, ...]: # output buffers
if exe.is_trivial:
@@ -198,7 +210,7 @@

unique_inputs, exe = clean_inputs(list(inputs), exe)

if options.pop("use_custom_primitive", True):
if use_custom_primitive:
return tensor_product_p.bind(
*unique_inputs, output_shapes=output_shapes, d=d, exe=exe, **options
)
@@ -208,30 +220,131 @@
)


def produce_minimal_code(
*inputs: jax.Array,
output_shapes: tuple[tuple[int, ...] | None, ...],
d: cue.SegmentedTensorProduct,
exe: TensorProductExecution,
dtype_output: jnp.dtype,
dtype_math: jnp.dtype,
**unused_options,
) -> str:
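# Debugging aid: build a self-contained Python snippet that re-creates the
# zero-filled inputs, the SegmentedTensorProduct (via base64), the
# TensorProductExecution and the dtypes of this call, so a problematic
# configuration can be replayed in isolation; the remaining options are
# appended as a trailing comment.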
def format_dtype(dtype: jnp.dtype) -> str:
dtype = jnp.dtype(dtype)
return f"jnp.{dtype.name}"

mincode = """import jax.numpy as jnp
import cuequivariance as cue
from cuequivariance.tensor_product_execution import InBuffer, OutBuffer
from cuequivariance_jax.primitives.tensor_product import tensor_product_prim
"""
mincode += (
"inputs = ["
+ ", ".join([f"jnp.zeros({x.shape}, {format_dtype(x.dtype)})" for x in inputs])
+ "]\n"
)
mincode += f"output_shapes = {output_shapes}\n"
mincode += f'd = cue.SegmentedTensorProduct.from_base64("{d.to_base64()}")\n'
mincode += f"exe = cue.TensorProductExecution({exe.computations})\n"
mincode += f"dtype_output = {format_dtype(dtype_output)}\n"
mincode += f"dtype_math = {format_dtype(dtype_math)}\n"
# tensor_product_prim(
# *inputs,
# output_shapes=output_shapes,
# d=d,
# exe=exe,
# dtype_output=dtype_output,
# dtype_math=dtype_math,
# use_custom_kernels=True,
# )
mincode += "# " + ", ".join([f"{k}={v}" for k, v in unused_options.items()])
return mincode


def profile_and_select_implementation(
name: str, impls: list[tuple[str, callable]], *inputs: jax.Array
):
# import time
# t0 = time.perf_counter()

with jax.ensure_compile_time_eval():
dummy_inputs = [np.random.normal(size=x.shape).astype(x.dtype) for x in inputs]
dummy_inputs = [jax.device_put(x) for x in dummy_inputs]
ref = None
first_runtime: float | None = None
best: tuple[str, float, callable] | None = None
for impl_name, impl in impls:
try:
out, runtime = profiler.measure(impl, mode="cupti")(*dummy_inputs)
except NotImplementedError:
continue
else:
if ref is None:
ref = out
best = impl_name, runtime, impl
first_runtime = runtime
else:
diff = max(
[
np.max(np.abs(a - b))
for a, b in zip(jax.tree.leaves(out), jax.tree.leaves(ref))
]
)
if diff > 1e-3:
raise ValueError(
f"cuex.tensor_product: {name} implementation {impl_name} produced different results, diff={diff}"
)
if runtime < best[1]:
best = impl_name, runtime, impl
assert best is not None
impl_name, runtime, impl = best

# dt = time.perf_counter() - t0
speedup = first_runtime / runtime
print(
f"{name:<50}: {impl_name:<10} with runtime {runtime:.2f} ms, speedup {speedup:.2f}x wrt {first_runtime:.2f} ms"
)

return impl(*inputs)
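
A usage sketch of the selection helper, under the assumption that a CUDA device with CUPTI profiling is available; the two implementations and their names are made up for illustration:

import jax.numpy as jnp

# Two interchangeable implementations of the same computation. The helper
# times both on random data of matching shapes/dtypes, checks they agree
# to within 1e-3, then evaluates the faster one on the real inputs.
def impl_matmul(a, b):
    return a @ b

def impl_einsum(a, b):
    return jnp.einsum("ij,jk->ik", a, b)

a = jnp.ones((512, 512), jnp.float32)
b = jnp.ones((512, 512), jnp.float32)
out = profile_and_select_implementation(
    "matmul_demo", [("matmul", impl_matmul), ("einsum", impl_einsum)], a, b
)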


def tensor_product_impl(
platform: str | None,
*inputs: jax.Array,
output_shapes: tuple[tuple[int, ...] | None, ...],
d: stp.SegmentedTensorProduct,
d: cue.SegmentedTensorProduct,
exe: TensorProductExecution,
name: str = "tensor_product",
use_custom_kernels: bool | None = True,
**options,
) -> tuple[jax.Array, ...]:
assert exe.max_in_buffer + 1 == len(exe.in_buffers) == len(inputs)
assert exe.max_out_buffer + 1 == len(exe.out_buffers)
use_custom_kernels = options.pop("use_custom_kernels", True)

def dispatch(
inputs: list[jax.Array],
d: stp.SegmentedTensorProduct,
d: cue.SegmentedTensorProduct,
exe: TensorProductExecution,
) -> list[jax.Array]:
if platform == "cuda" and use_custom_kernels:
# TODO: call custom kernels here
pass
kwargs = dict(output_shapes=output_shapes, d=d, exe=exe, **options)
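# Dispatch policy for the tri-state use_custom_kernels flag:
#   False -> always lower to the pure-JAX "vanilla" implementation
#   True  -> on CUDA, lower to the fused tensor_product_ops_impl kernels
#   None  -> on CUDA, profile both implementations on dummy data and keep the faster one
# On non-CUDA platforms the vanilla implementation is always used.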

if platform == "cuda":
# print(produce_minimal_code(*inputs, **kwargs))
# print()
if use_custom_kernels is None:
outputs = profile_and_select_implementation(
name,
[
("vanilla", partial(tensor_product_vanilla_impl, **kwargs)),
("ops", partial(tensor_product_ops_impl, **kwargs)),
],
*inputs,
)
return outputs
if use_custom_kernels is True:
return tensor_product_ops_impl(*inputs, **kwargs)

return tensor_product_vanilla_impl(
*inputs, output_shapes=output_shapes, d=d, exe=exe, **options
)
return tensor_product_vanilla_impl(*inputs, **kwargs)

outputs = [0] * len(exe.out_buffers)
for partition, ex in exe.group_by_identical_buffers():
@@ -254,7 +367,7 @@ def dispatch(
def tensor_product_abstract_eval(
*inputs: jax.core.ShapedArray,
output_shapes: tuple[tuple[int, ...] | None, ...],
d: stp.SegmentedTensorProduct,
d: cue.SegmentedTensorProduct,
exe: TensorProductExecution,
**options,
) -> tuple[jax.core.ShapedArray, ...]:
@@ -285,7 +398,7 @@ def tensor_product_jvp(
tangents: tuple[jax.Array | ad.Zero, ...],
*,
output_shapes: tuple[tuple[int, ...] | None, ...],
d: stp.SegmentedTensorProduct,
d: cue.SegmentedTensorProduct,
exe: TensorProductExecution,
**options,
) -> tuple[tuple[jax.Array, ...], tuple[jax.Array | ad.Zero, ...]]:
@@ -297,6 +410,8 @@
jvp = exe.jvp([not isinstance(t, ad.Zero) for t in tangents])
del exe

options["name"] = options.get("name", "tensor_product") + "->jvp"

permutations: list[tuple[int, ...]] = d.symmetries()
for multiplicator, exe in jvp.group_by_symmetries(permutations):
# tensor_product_prim can remove unused inputs
@@ -318,7 +433,7 @@ def tensor_product_transpose(
cotangents: tuple[jax.Array | ad.Zero, ...],
*inputs: jax.Array | ad.UndefinedPrimal,
output_shapes: tuple[tuple[int, ...] | None, ...],
d: stp.SegmentedTensorProduct,
d: cue.SegmentedTensorProduct,
exe: TensorProductExecution,
**options,
) -> tuple[jax.Array | ad.Zero | None, ...]:
@@ -338,6 +453,8 @@
output_shapes[oid] = undefined_primal_shape
output_shapes = tuple(output_shapes)

options["name"] = options.get("name", "tensor_product") + "->transpose"

tr = exe.transpose(
[ad.is_undefined_primal(x) for x in inputs],
[not isinstance(x, ad.Zero) for x in cotangents],
@@ -368,7 +485,7 @@ def tensor_product_batching(
batch_axes: tuple[int | None, ...],
*,
output_shapes: tuple[tuple[int, ...] | None, ...],
d: stp.SegmentedTensorProduct,
d: cue.SegmentedTensorProduct,
exe: TensorProductExecution,
**options,
) -> tuple[tuple[jax.Array, ...], tuple[int, ...]]:
@@ -393,6 +510,8 @@ def prepare(input: jax.Array, axis: int | None) -> jax.Array:
assert new_output_shapes[oid] == expected
new_output_shapes = tuple(new_output_shapes)

options["name"] = options.get("name", "tensor_product") + "->batching"

outputs = tensor_product_prim(
*batched_inputs,
output_shapes=new_output_shapes,