diff --git a/cuequivariance/cuequivariance/tensor_product_execution.py b/cuequivariance/cuequivariance/tensor_product_execution.py
index 3564b4f..4a9c2d0 100644
--- a/cuequivariance/cuequivariance/tensor_product_execution.py
+++ b/cuequivariance/cuequivariance/tensor_product_execution.py
@@ -24,11 +24,13 @@ class Buffer(int):
 
 
 class InBuffer(Buffer):
-    pass
+    def __repr__(self):
+        return f"InBuffer({int(self)})"
 
 
 class OutBuffer(Buffer):
-    pass
+    def __repr__(self):
+        return f"OutBuffer({int(self)})"
 
 
 T = TypeVar("T")
@@ -165,6 +167,22 @@ def num_inputs_per_operand(self) -> tuple[int, ...]:
     def num_outputs_per_operand(self) -> tuple[int, ...]:
         return tuple(len(s) for s in self.out_buffers_per_operand)
 
+    def get_in_buffer_operands(self, buffer: int) -> set[int]:
+        return {
+            ope
+            for c in self.computations
+            for ope, b in enumerate(c)
+            if isinstance(b, InBuffer) and b == buffer
+        }
+
+    def get_out_buffer_operands(self, buffer: int) -> set[int]:
+        return {
+            ope
+            for c in self.computations
+            for ope, b in enumerate(c)
+            if isinstance(b, OutBuffer) and b == buffer
+        }
+
     def map_buffers(
         self,
         f_in: Optional[Callable[[int], int]],
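The new __repr__ methods make buffers print as InBuffer(i) / OutBuffer(i) instead of bare integers (produce_minimal_code below relies on this to emit paste-able reproducers), and the two helpers map a buffer index to the operand positions that read or write it. A minimal sketch of the intended behaviour, assuming TensorProductExecution can be built directly from a tuple of computations; the buffer layout below is made up for illustration:

from cuequivariance.tensor_product_execution import (
    InBuffer,
    OutBuffer,
    TensorProductExecution,
)

# One computation: operands 0 and 1 read input buffers 0 and 1, operand 2 writes output buffer 0.
exe = TensorProductExecution(((InBuffer(0), InBuffer(1), OutBuffer(0)),))

exe.get_in_buffer_operands(0)   # {0}: operand positions fed by input buffer 0
exe.get_in_buffer_operands(1)   # {1}
exe.get_out_buffer_operands(0)  # {2}: operand positions written to output buffer 0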
diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py b/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py
index 2578e74..3dabe06 100644
--- a/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py
+++ b/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py
@@ -28,7 +28,8 @@ def equivariant_tensor_product(
     precision: jax.lax.Precision = jax.lax.Precision.HIGHEST,
     algorithm: str = "sliced",
     use_custom_primitive: bool = True,
-    use_custom_kernels: bool = False,
+    use_custom_kernels: bool | None = False,
+    name: str | None = None,
 ) -> cuex.RepArray:
     """Compute the equivariant tensor product of the input arrays.
 
@@ -78,6 +79,7 @@
             algorithm=algorithm,
             use_custom_primitive=use_custom_primitive,
             use_custom_kernels=use_custom_kernels,
+            name=name,
         )
 
     if len(inputs) != e.num_inputs:
@@ -113,6 +115,7 @@
         algorithm=algorithm,
         use_custom_primitive=use_custom_primitive,
         use_custom_kernels=use_custom_kernels,
+        name=name,
     )
 
     return cuex.RepArray(e.output, x)
diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/symmetric_tensor_product.py b/cuequivariance_jax/cuequivariance_jax/primitives/symmetric_tensor_product.py
index 088fc61..fad5d5d 100644
--- a/cuequivariance_jax/cuequivariance_jax/primitives/symmetric_tensor_product.py
+++ b/cuequivariance_jax/cuequivariance_jax/primitives/symmetric_tensor_product.py
@@ -31,7 +31,8 @@ def symmetric_tensor_product(
     precision: jax.lax.Precision = jax.lax.Precision.HIGHEST,
     algorithm: str = "sliced",
     use_custom_primitive: bool = True,
-    use_custom_kernels: bool = False,
+    use_custom_kernels: bool | None = False,
+    name: str | None = None,
 ) -> jax.Array:
     """
     Compute the sum of the STPs evaluated on the input (all input operands are the same).
@@ -54,6 +55,9 @@ def symmetric_tensor_product(
     """
     assert any(d.num_operands >= 2 for d in ds)
 
+    if name is None:
+        name = "symmetric_tensor_product"
+
     # currying
     if len(inputs) == 0:
@@ -67,6 +71,7 @@ def fn(*inputs) -> jax.Array:
                 algorithm=algorithm,
                 use_custom_primitive=use_custom_primitive,
                 use_custom_kernels=use_custom_kernels,
+                name=name,
             )
 
         return fn
@@ -136,6 +141,7 @@ def fn(*inputs) -> jax.Array:
         algorithm=algorithm,
         use_custom_primitive=use_custom_primitive,
         use_custom_kernels=use_custom_kernels,
+        name=name + f"_{n_in - n_un}",
     )
 
     return output
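Both public entry points gain an optional name that is forwarded to the underlying primitive (the symmetric product additionally suffixes it with the contraction degree via name + f"_{n_in - n_un}"), and use_custom_kernels becomes tri-state. A hedged usage sketch; e, ds, and the input arrays are placeholders, and both functions are assumed to be re-exported at the package level as cuex.*:

import cuequivariance_jax as cuex

# e: a cue.EquivariantTensorProduct descriptor, x0/x1: matching cuex.RepArray inputs (placeholders)
y = cuex.equivariant_tensor_product(
    e,
    x0,
    x1,
    name="readout_tp",         # label carried by the primitive, handy in traces and profiles
    use_custom_kernels=False,  # False = vanilla JAX, True = CUDA ops kernel, None = auto-select
)

# ds: list of SegmentedTensorProduct descriptors; with no inputs the call curries
# (see the "# currying" branch above) and returns a callable.
f = cuex.symmetric_tensor_product(ds, name="symmetric_contraction")
out = f(x)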
diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py
index 7dd2f37..c4845d3 100644
--- a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py
+++ b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py
@@ -20,15 +20,20 @@
 import jax.extend
 import jax.lax
 import jax.numpy as jnp
+import numpy as np
+from jax.experimental.mosaic.gpu import profiler
 from jax.interpreters import ad, batching, mlir, xla
 
-from cuequivariance import segmented_tensor_product as stp
+import cuequivariance as cue
 from cuequivariance.tensor_product_execution import (
     Computation,
     InBuffer,
     OutBuffer,
     TensorProductExecution,
 )
+from cuequivariance_jax.primitives.tensor_product_ops_impl import (
+    tensor_product_ops_impl,
+)
 from cuequivariance_jax.primitives.tensor_product_vanilla_impl import (
     tensor_product_vanilla_impl,
 )
@@ -37,14 +42,15 @@
 
 
 def tensor_product(
-    d: stp.SegmentedTensorProduct,
+    d: cue.SegmentedTensorProduct,
     *inputs: jax.Array,
     dtype_output: jnp.dtype | None = None,
     dtype_math: jnp.dtype | None = None,
     precision: jax.lax.Precision = jax.lax.Precision.HIGHEST,
     algorithm: str = "sliced",
     use_custom_primitive: bool = True,
-    use_custom_kernels: bool = False,
+    use_custom_kernels: bool | None = False,
+    name: str | None = None,
 ) -> jax.Array:
     """
     Compute the last operand of a `SegmentedTensorProduct`.
@@ -85,6 +91,9 @@ def tensor_product(
     if isinstance(precision, str):
         precision = jax.lax.Precision[precision]
 
+    if name is None:
+        name = "tensor_product"
+
     options = dict(
         dtype_output=dtype_output,
         dtype_math=dtype_math,
@@ -92,6 +101,7 @@
         algorithm=algorithm,
         use_custom_primitive=use_custom_primitive,
         use_custom_kernels=use_custom_kernels,
+        name=name,
     )
 
     if len(inputs) > d.num_operands - 1:
@@ -140,6 +150,7 @@ def _partial(*remaining_inputs: jax.Array) -> jax.Array:
             algorithm=algorithm,
             use_custom_primitive=use_custom_primitive,
             use_custom_kernels=use_custom_kernels,
+            name=name,
         )
 
     # inputs of shape (..., ope.size) with identical ndim
@@ -187,8 +198,9 @@ def clean_inputs(
 def tensor_product_prim(
     *inputs: jax.Array,  # input buffers
     output_shapes: tuple[tuple[int, ...] | None, ...],  # shapes of the operands
-    d: stp.SegmentedTensorProduct,
+    d: cue.SegmentedTensorProduct,
     exe: TensorProductExecution,
+    use_custom_primitive: bool = True,
     **options,
 ) -> tuple[jax.Array, ...]:  # output buffers
     if exe.is_trivial:
@@ -198,7 +210,7 @@
     unique_inputs, exe = clean_inputs(list(inputs), exe)
 
-    if options.pop("use_custom_primitive", True):
+    if use_custom_primitive:
         return tensor_product_p.bind(
             *unique_inputs, output_shapes=output_shapes, d=d, exe=exe, **options
         )
@@ -208,30 +220,131 @@
+def produce_minimal_code(
+    *inputs: jax.Array,
+    output_shapes: tuple[tuple[int, ...] | None, ...],
+    d: cue.SegmentedTensorProduct,
+    exe: TensorProductExecution,
+    dtype_output: jnp.dtype,
+    dtype_math: jnp.dtype,
+    **unused_options,
+) -> str:
+    def format_dtype(dtype: jnp.dtype) -> str:
+        dtype = jnp.dtype(dtype)
+        return f"jnp.{dtype.name}"
+
+    mincode = """import jax.numpy as jnp
+import cuequivariance as cue
+from cuequivariance.tensor_product_execution import InBuffer, OutBuffer
+from cuequivariance_jax.primitives.tensor_product import tensor_product_prim
+"""
+    mincode += (
+        "inputs = ["
+        + ", ".join([f"jnp.zeros({x.shape}, {format_dtype(x.dtype)})" for x in inputs])
+        + "]\n"
+    )
+    mincode += f"output_shapes = {output_shapes}\n"
+    mincode += f'd = cue.SegmentedTensorProduct.from_base64("{d.to_base64()}")\n'
+    mincode += f"exe = cue.TensorProductExecution({exe.computations})\n"
+    mincode += f"dtype_output = {format_dtype(dtype_output)}\n"
+    mincode += f"dtype_math = {format_dtype(dtype_math)}\n"
+    # tensor_product_prim(
+    #     *inputs,
+    #     output_shapes=output_shapes,
+    #     d=d,
+    #     exe=exe,
+    #     dtype_output=dtype_output,
+    #     dtype_math=dtype_math,
+    #     use_custom_kernels=True,
+    # )
+    mincode += "# " + ", ".join([f"{k}={v}" for k, v in unused_options.items()])
+    return mincode
+
+
+def profile_and_select_implementation(
+    name: str, impls: list[tuple[str, callable]], *inputs: jax.Array
+):
+    # import time
+    # t0 = time.perf_counter()
+
+    with jax.ensure_compile_time_eval():
+        dummy_inputs = [np.random.normal(size=x.shape).astype(x.dtype) for x in inputs]
+        dummy_inputs = [jax.device_put(x) for x in dummy_inputs]
+        ref = None
+        first_runtime: float | None = None
+        best: tuple[str, float, callable] | None = None
+        for impl_name, impl in impls:
+            try:
+                out, runtime = profiler.measure(impl, mode="cupti")(*dummy_inputs)
+            except NotImplementedError:
+                continue
+            else:
+                if ref is None:
+                    ref = out
+                    best = impl_name, runtime, impl
+                    first_runtime = runtime
+                else:
+                    diff = max(
+                        [
+                            np.max(np.abs(a - b))
+                            for a, b in zip(jax.tree.leaves(out), jax.tree.leaves(ref))
+                        ]
+                    )
+                    if diff > 1e-3:
+                        raise ValueError(
+                            f"cuex.tensor_product: {name} implementation {impl_name} produced different results, diff={diff}"
+                        )
+                    if runtime < best[1]:
+                        best = impl_name, runtime, impl
+        assert best is not None
+        impl_name, runtime, impl = best
+
+        # dt = time.perf_counter() - t0
+        speedup = first_runtime / runtime
+        print(
+            f"{name:<50}: {impl_name:<10} with runtime {runtime:.2f} ms, speedup {speedup:.2f}x wrt {first_runtime:.2f} ms"
+        )
+
+    return impl(*inputs)
+
+
 def tensor_product_impl(
     platform: str | None,
     *inputs: jax.Array,
     output_shapes: tuple[tuple[int, ...] | None, ...],
-    d: stp.SegmentedTensorProduct,
+    d: cue.SegmentedTensorProduct,
     exe: TensorProductExecution,
+    name: str = "tensor_product",
+    use_custom_kernels: bool | None = True,
     **options,
 ) -> tuple[jax.Array, ...]:
     assert exe.max_in_buffer + 1 == len(exe.in_buffers) == len(inputs)
     assert exe.max_out_buffer + 1 == len(exe.out_buffers)
-    use_custom_kernels = options.pop("use_custom_kernels", True)
 
     def dispatch(
         inputs: list[jax.Array],
-        d: stp.SegmentedTensorProduct,
+        d: cue.SegmentedTensorProduct,
         exe: TensorProductExecution,
     ) -> list[jax.Array]:
-        if platform == "cuda" and use_custom_kernels:
-            # TODO: call custom kernels here
-            pass
+        kwargs = dict(output_shapes=output_shapes, d=d, exe=exe, **options)
+
+        if platform == "cuda":
+            # print(produce_minimal_code(*inputs, **kwargs))
+            # print()
+            if use_custom_kernels is None:
+                outputs = profile_and_select_implementation(
+                    name,
+                    [
+                        ("vanilla", partial(tensor_product_vanilla_impl, **kwargs)),
+                        ("ops", partial(tensor_product_ops_impl, **kwargs)),
+                    ],
+                    *inputs,
+                )
+                return outputs
+            if use_custom_kernels is True:
+                return tensor_product_ops_impl(*inputs, **kwargs)
 
-        return tensor_product_vanilla_impl(
-            *inputs, output_shapes=output_shapes, d=d, exe=exe, **options
-        )
+        return tensor_product_vanilla_impl(*inputs, **kwargs)
 
     outputs = [0] * len(exe.out_buffers)
     for partition, ex in exe.group_by_identical_buffers():
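With the dispatch above, use_custom_kernels is effectively tri-state on CUDA: False (the public default) keeps the pure-JAX "vanilla" path, True forces tensor_product_ops_impl (which raises NotImplementedError for unsupported descriptors), and None lets profile_and_select_implementation benchmark both candidates on dummy data at trace time, check that they agree, and keep the faster one. A simplified, self-contained stand-in for that selection logic, using wall-clock timing instead of the Mosaic GPU profiler; all names here are illustrative only:

import time

import jax
import numpy as np


def pick_fastest(impls, *inputs):
    """Time each candidate on the same inputs, check agreement, return the fastest."""
    best, ref = None, None
    for impl_name, impl in impls:
        try:
            fn = jax.jit(impl)
            jax.block_until_ready(fn(*inputs))  # warm-up: compile outside the timed region
            t0 = time.perf_counter()
            out = jax.block_until_ready(fn(*inputs))
            runtime = time.perf_counter() - t0
        except NotImplementedError:
            continue  # this candidate does not support the descriptor
        if ref is None:
            ref = out
        else:
            diff = max(
                np.max(np.abs(np.asarray(a) - np.asarray(b)))
                for a, b in zip(jax.tree.leaves(out), jax.tree.leaves(ref))
            )
            if diff > 1e-3:
                raise ValueError(f"{impl_name} disagrees with the reference, diff={diff}")
        if best is None or runtime < best[1]:
            best = (impl_name, runtime, impl)
    return best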
@@ -254,7 +367,7 @@ def dispatch(
 def tensor_product_abstract_eval(
     *inputs: jax.core.ShapedArray,
     output_shapes: tuple[tuple[int, ...] | None, ...],
-    d: stp.SegmentedTensorProduct,
+    d: cue.SegmentedTensorProduct,
     exe: TensorProductExecution,
     **options,
 ) -> tuple[jax.core.ShapedArray, ...]:
@@ -285,7 +398,7 @@ def tensor_product_jvp(
     tangents: tuple[jax.Array | ad.Zero, ...],
     *,
     output_shapes: tuple[tuple[int, ...] | None, ...],
-    d: stp.SegmentedTensorProduct,
+    d: cue.SegmentedTensorProduct,
     exe: TensorProductExecution,
     **options,
 ) -> tuple[tuple[jax.Array, ...], tuple[jax.Array | ad.Zero, ...]]:
@@ -297,6 +410,8 @@ def tensor_product_jvp(
     jvp = exe.jvp([not isinstance(t, ad.Zero) for t in tangents])
     del exe
 
+    options["name"] = options.get("name", "tensor_product") + "->jvp"
+
     permutations: list[tuple[int, ...]] = d.symmetries()
     for multiplicator, exe in jvp.group_by_symmetries(permutations):
         # tensor_product_prim can remove unused inputs
@@ -318,7 +433,7 @@ def tensor_product_transpose(
     cotangents: tuple[jax.Array | ad.Zero, ...],
     *inputs: jax.Array | ad.UndefinedPrimal,
     output_shapes: tuple[tuple[int, ...] | None, ...],
-    d: stp.SegmentedTensorProduct,
+    d: cue.SegmentedTensorProduct,
     exe: TensorProductExecution,
     **options,
 ) -> tuple[jax.Array | ad.Zero | None, ...]:
@@ -338,6 +453,8 @@ def tensor_product_transpose(
         output_shapes[oid] = undefined_primal_shape
     output_shapes = tuple(output_shapes)
 
+    options["name"] = options.get("name", "tensor_product") + "->transpose"
+
     tr = exe.transpose(
         [ad.is_undefined_primal(x) for x in inputs],
         [not isinstance(x, ad.Zero) for x in cotangents],
@@ -368,7 +485,7 @@ def tensor_product_batching(
     batch_axes: tuple[int | None, ...],
     *,
     output_shapes: tuple[tuple[int, ...] | None, ...],
-    d: stp.SegmentedTensorProduct,
+    d: cue.SegmentedTensorProduct,
     exe: TensorProductExecution,
     **options,
 ) -> tuple[tuple[jax.Array, ...], tuple[int, ...]]:
@@ -393,6 +510,8 @@ def prepare(input: jax.Array, axis: int | None) -> jax.Array:
             assert new_output_shapes[oid] == expected
     new_output_shapes = tuple(new_output_shapes)
 
+    options["name"] = options.get("name", "tensor_product") + "->batching"
+
     outputs = tensor_product_prim(
         *batched_inputs,
         output_shapes=new_output_shapes,
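Because the JVP, transpose, and batching rules append "->jvp", "->transpose", and "->batching" to the name option, the primitives created under jax.grad or jax.vmap remain attributable to the call that produced them. A hedged sketch; the descriptor d and the inputs x0, x1 are placeholders:

import jax
import cuequivariance_jax as cuex


def loss(x0, x1):
    # d: a cue.SegmentedTensorProduct with three operands (placeholder)
    return cuex.tensor_product(d, x0, x1, name="my_tp").sum()


# The forward pass binds a primitive carrying name="my_tp"; its derivatives carry
# suffixed names such as "my_tp->jvp" and "my_tp->jvp->transpose", visible as primitive
# parameters when inspecting jax.make_jaxpr(jax.grad(loss))(x0, x1).
grads = jax.grad(loss, argnums=(0, 1))(x0, x1)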
diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product_ops_impl.py b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product_ops_impl.py
new file mode 100644
index 0000000..1636f1e
--- /dev/null
+++ b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product_ops_impl.py
@@ -0,0 +1,125 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import math
+
+import jax
+import jax.numpy as jnp
+
+import cuequivariance as cue
+from cuequivariance.tensor_product_execution import InBuffer
+
+logger = logging.getLogger(__name__)
+
+
+def tensor_product_ops_impl(
+    *inputs: jax.Array,  # input buffers
+    output_shapes: tuple[tuple[int, ...] | None, ...],  # shapes of the operands
+    d: cue.SegmentedTensorProduct,
+    exe: cue.TensorProductExecution,
+    **options,
+) -> tuple[jax.Array, ...]:  # output buffers
+    assert exe.max_out_buffer + 1 == len(exe.out_buffers)
+
+    detail_str = f"\n{d}\n{exe}".replace("\n", "\n | ")
+
+    if not d.all_same_segment_shape():
+        logger.info("🛶 can't use tensor_product_uniform_1d for" + detail_str)
+        raise NotImplementedError()
+
+    try:
+        from cuequivariance_ops_jax import tensor_product_uniform_1d
+    except ImportError:
+        logger.info("🛶 can't import cuequivariance_ops_jax")
+        raise NotImplementedError()
+
+    modes = d.subscripts.modes()
+    if len(modes) > 1:
+        logger.info("🛶 can't use tensor_product_uniform_1d for" + detail_str)
+        raise NotImplementedError()
+
+    if len(modes) == 1:
+        dims: set[int] = d.get_dims(modes[0])
+        if len(dims) != 1:
+            logger.info("🛶 can't use tensor_product_uniform_1d for" + detail_str)
+            raise NotImplementedError()
+
+    batch_size = 1
+    for shape in [input.shape[:-1] for input in inputs] + [
+        shape for shape in output_shapes if shape is not None
+    ]:
+        n = math.prod(shape)
+        if n > 1:
+            if n != batch_size and batch_size != 1:
+                logger.info("🛶 can't use tensor_product_uniform_1d for" + detail_str)
+                raise NotImplementedError()
+            batch_size = n
+
+    reshaped_inputs = []
+    for index, input in enumerate(inputs):
+        operands = {
+            (d.operands[op].size, d.operands[op].num_segments)
+            for op in exe.get_in_buffer_operands(index)
+        }
+        if len(operands) != 1:
+            logger.info("🛶 can't use tensor_product_uniform_1d for" + detail_str)
+            raise NotImplementedError()
+        size, num_segments = operands.pop()
+        reshaped_inputs.append(
+            input.reshape(
+                (math.prod(input.shape[:-1]), num_segments, size // num_segments)
+            )
+        )
+
+    output_operands = []
+    outputs = []
+    for index in exe.out_buffers:
+        operands = exe.get_out_buffer_operands(index)
+        if len(operands) != 1:
+            logger.info("🛶 can't use tensor_product_uniform_1d for" + detail_str)
+            raise NotImplementedError()
+        ope = operands.pop()
+        size, num_segments = d.operands[ope].size, d.operands[ope].num_segments
+
+        output_operands.append(ope)
+        outputs.append(
+            jnp.zeros(
+                (math.prod(output_shapes[ope]), num_segments, size // num_segments),
+                dtype=options["dtype_output"],
+            )
+        )
+
+    logger.info("🎉 use tensor_product_uniform_1d for" + detail_str)
+
+    outputs = tensor_product_uniform_1d(
+        options["dtype_math"],
+        [ope.num_segments for ope in d.operands],
+        [path.indices for path in d.paths],
+        [path.coefficients.item() for path in d.paths],
+        reshaped_inputs,
+        outputs,
+        [
+            tuple(
+                int(b) if isinstance(b, InBuffer) else -1 - int(b) for b in computation
+            )
+            for computation in exe.computations
+        ],
+    )
+
+    outputs = [
+        output.reshape(output_shapes[ope] + (-1,))
+        for ope, output in zip(output_operands, outputs)
+    ]
+    return tuple(outputs)
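This backend only accepts "uniform 1d" descriptors: every segment must share one shape, the subscripts may use at most one mode, that mode must have a single dimension, and all input/output batch sizes must agree (or be 1); anything else raises NotImplementedError so the caller can fall back to the vanilla implementation. The shape part of that eligibility test, restated as a small predicate using the descriptor methods above:

import cuequivariance as cue


def is_uniform_1d(d: cue.SegmentedTensorProduct) -> bool:
    """Mirror of the shape checks in tensor_product_ops_impl, without the logging."""
    if not d.all_same_segment_shape():  # all segments share one shape
        return False
    modes = d.subscripts.modes()
    if len(modes) > 1:  # at most one mode, i.e. "1d"
        return False
    if len(modes) == 1 and len(d.get_dims(modes[0])) != 1:
        return False  # that mode must have a single uniform dimension
    return True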
diff --git a/cuequivariance_jax/cuequivariance_jax/rep_array/jax_rep_array.py b/cuequivariance_jax/cuequivariance_jax/rep_array/jax_rep_array.py
index 4bc6a09..9cee2d6 100644
--- a/cuequivariance_jax/cuequivariance_jax/rep_array/jax_rep_array.py
+++ b/cuequivariance_jax/cuequivariance_jax/rep_array/jax_rep_array.py
@@ -223,13 +223,18 @@ def __getitem__(self, key: Any) -> RepArray:
                 self.array[None],
             )
 
-        # self[jnp.array([0, 1, 2])]
-        assert isinstance(key, jax.Array)
         assert 0 not in self.reps
-        return RepArray(
-            {k + key.ndim - 1: irreps for k, irreps in self.reps.items()},
-            self.array[key],
-        )
+
+        # self[1:4]
+        if isinstance(key, slice):
+            return RepArray(self.reps, self.array[key])
+
+        # self[jnp.array([0, 1, 2])]
+        if isinstance(key, jax.Array):
+            return RepArray(
+                {k + key.ndim - 1: irreps for k, irreps in self.reps.items()},
+                self.array[key],
+            )
 
     @property
     def slice_by_mul(self) -> _MulIndexSliceHelper:
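With the new branches, slicing a RepArray along its leading (non-representation) axis keeps reps untouched, while advanced indexing with a jax.Array still shifts the representation axes by key.ndim - 1. A short sketch, assuming x is a cuex.RepArray whose axis 0 is a plain batch axis:

import jax.numpy as jnp

y = x[1:4]                   # new: slice of the batch axis; y.reps == x.reps
z = x[jnp.array([0, 1, 2])]  # gather rows; rep axes shift by key.ndim - 1 (0 for a 1-D key)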