From 1b0af24ac2ba6e823025113b96886eb8603cc541 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 23 Jan 2025 14:33:39 -0800 Subject: [PATCH 1/7] Implement frontend to call backend JAX bindings --- .../tensor_product_execution.py | 16 +++ .../primitives/tensor_product.py | 11 +- .../primitives/tensor_product_ops_impl.py | 118 ++++++++++++++++++ 3 files changed, 143 insertions(+), 2 deletions(-) create mode 100644 cuequivariance_jax/cuequivariance_jax/primitives/tensor_product_ops_impl.py diff --git a/cuequivariance/cuequivariance/tensor_product_execution.py b/cuequivariance/cuequivariance/tensor_product_execution.py index 3564b4f..b8fcff8 100644 --- a/cuequivariance/cuequivariance/tensor_product_execution.py +++ b/cuequivariance/cuequivariance/tensor_product_execution.py @@ -165,6 +165,22 @@ def num_inputs_per_operand(self) -> tuple[int, ...]: def num_outputs_per_operand(self) -> tuple[int, ...]: return tuple(len(s) for s in self.out_buffers_per_operand) + def get_in_buffer_operands(self, buffer: int) -> set[int]: + return { + ope + for c in self.computations + for ope, b in enumerate(c) + if isinstance(b, InBuffer) and b == buffer + } + + def get_out_buffer_operands(self, buffer: int) -> set[int]: + return { + ope + for c in self.computations + for ope, b in enumerate(c) + if isinstance(b, OutBuffer) and b == buffer + } + def map_buffers( self, f_in: Optional[Callable[[int], int]], diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py index 7dd2f37..7c77a3e 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py @@ -29,6 +29,9 @@ OutBuffer, TensorProductExecution, ) +from cuequivariance_jax.primitives.tensor_product_ops_impl import ( + tensor_product_ops_impl, +) from cuequivariance_jax.primitives.tensor_product_vanilla_impl import ( tensor_product_vanilla_impl, ) @@ -226,8 +229,12 @@ def dispatch( exe: TensorProductExecution, ) -> list[jax.Array]: if platform == "cuda" and use_custom_kernels: - # TODO: call custom kernels here - pass + try: + return tensor_product_ops_impl( + *inputs, output_shapes=output_shapes, d=d, exe=exe, **options + ) + except NotImplementedError as e: + logger.info(f"{e}. Falling back to JAX.") return tensor_product_vanilla_impl( *inputs, output_shapes=output_shapes, d=d, exe=exe, **options diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product_ops_impl.py b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product_ops_impl.py new file mode 100644 index 0000000..11e77fe --- /dev/null +++ b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product_ops_impl.py @@ -0,0 +1,118 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging +import math + +import jax +import jax.numpy as jnp + +import cuequivariance as cue +from cuequivariance.tensor_product_execution import InBuffer + +logger = logging.getLogger(__name__) + + +def tensor_product_ops_impl( + *inputs: jax.Array, # input buffers + output_shapes: tuple[tuple[int, ...] | None, ...], # shapes of the operands + d: cue.SegmentedTensorProduct, + exe: cue.TensorProductExecution, + **options, +) -> tuple[jax.Array, ...]: # output buffers + assert exe.max_out_buffer + 1 == len(exe.out_buffers) + + if not d.all_same_segment_shape(): + raise NotImplementedError("Only supports operands with the same shape") + + try: + from cuequivariance_ops_jax import tensor_product_uniform_1d + except ImportError: + raise NotImplementedError("Cannot import cuequivariance_ops_jax") + + modes = d.subscripts.modes() + if len(modes) > 1: + raise NotImplementedError("cuequivariance_ops_jax only supports 1D modes") + + if len(modes) == 1: + dims: set[int] = d.get_dims(modes[0]) + if len(dims) != 1: + raise NotImplementedError( + "cuequivariance_ops_jax only supports uniform 1D modes" + ) + + batch_size = 1 + for shape in [input.shape[:-1] for input in inputs] + [ + shape for shape in output_shapes if shape is not None + ]: + n = math.prod(shape) + if n > 1: + if n != batch_size and batch_size != 1: + raise NotImplementedError( + "cuequivariance_ops_jax does not support broadcasting" + ) + batch_size = n + + reshaped_inputs = [] + for index, input in enumerate(inputs): + operands = { + (d.operands[op].size, d.operands[op].num_segments) + for op in exe.get_in_buffer_operands(index) + } + assert len(operands) == 1 + size, num_segments = operands.pop() + reshaped_inputs.append( + input.reshape( + (math.prod(input.shape[:-1]), num_segments, size // num_segments) + ) + ) + + output_operands = [] + outputs = [] + for index in exe.out_buffers: + operands = exe.get_out_buffer_operands(index) + assert len(operands) == 1 + ope = operands.pop() + size, num_segments = d.operands[ope].size, d.operands[ope].num_segments + + output_operands.append(ope) + outputs.append( + jnp.zeros( + (math.prod(output_shapes[ope]), num_segments, size // num_segments), + dtype=options["dtype_output"], + ) + ) + + logger.info("Executing tensor_product_uniform_1d") + + outputs = tensor_product_uniform_1d( + options["dtype_math"], + [ope.num_segments for ope in d.operands], + [path.indices for path in d.paths], + [path.coefficients.item() for path in d.paths], + reshaped_inputs, + outputs, + [ + tuple( + int(b) if isinstance(b, InBuffer) else -1 - int(b) for b in computation + ) + for computation in exe.computations + ], + ) + + outputs = [ + output.reshape(output_shapes[ope] + (-1,)) + for ope, output in zip(output_operands, outputs) + ] + return tuple(outputs) From fe279c547cee56b31fef84ce8d6bbb1976a6cc8b Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 24 Jan 2025 07:20:48 -0800 Subject: [PATCH 2/7] add logger info messages --- .../primitives/tensor_product.py | 4 +-- .../primitives/tensor_product_ops_impl.py | 31 ++++++++++++------- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py index 7c77a3e..d948fc0 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py @@ -233,8 +233,8 @@ def dispatch( return tensor_product_ops_impl( *inputs, output_shapes=output_shapes, 
d=d, exe=exe, **options ) - except NotImplementedError as e: - logger.info(f"{e}. Falling back to JAX.") + except NotImplementedError: + pass return tensor_product_vanilla_impl( *inputs, output_shapes=output_shapes, d=d, exe=exe, **options diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product_ops_impl.py b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product_ops_impl.py index 11e77fe..1636f1e 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product_ops_impl.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product_ops_impl.py @@ -33,24 +33,28 @@ def tensor_product_ops_impl( ) -> tuple[jax.Array, ...]: # output buffers assert exe.max_out_buffer + 1 == len(exe.out_buffers) + detail_str = f"\n{d}\n{exe}".replace("\n", "\n | ") + if not d.all_same_segment_shape(): - raise NotImplementedError("Only supports operands with the same shape") + logger.info("🛶 can't use tensor_product_uniform_1d for" + detail_str) + raise NotImplementedError() try: from cuequivariance_ops_jax import tensor_product_uniform_1d except ImportError: - raise NotImplementedError("Cannot import cuequivariance_ops_jax") + logger.info("🛶 can't import cuequivariance_ops_jax") + raise NotImplementedError() modes = d.subscripts.modes() if len(modes) > 1: - raise NotImplementedError("cuequivariance_ops_jax only supports 1D modes") + logger.info("🛶 can't use tensor_product_uniform_1d for" + detail_str) + raise NotImplementedError() if len(modes) == 1: dims: set[int] = d.get_dims(modes[0]) if len(dims) != 1: - raise NotImplementedError( - "cuequivariance_ops_jax only supports uniform 1D modes" - ) + logger.info("🛶 can't use tensor_product_uniform_1d for" + detail_str) + raise NotImplementedError() batch_size = 1 for shape in [input.shape[:-1] for input in inputs] + [ @@ -59,9 +63,8 @@ def tensor_product_ops_impl( n = math.prod(shape) if n > 1: if n != batch_size and batch_size != 1: - raise NotImplementedError( - "cuequivariance_ops_jax does not support broadcasting" - ) + logger.info("🛶 can't use tensor_product_uniform_1d for" + detail_str) + raise NotImplementedError() batch_size = n reshaped_inputs = [] @@ -70,7 +73,9 @@ def tensor_product_ops_impl( (d.operands[op].size, d.operands[op].num_segments) for op in exe.get_in_buffer_operands(index) } - assert len(operands) == 1 + if len(operands) != 1: + logger.info("🛶 can't use tensor_product_uniform_1d for" + detail_str) + raise NotImplementedError() size, num_segments = operands.pop() reshaped_inputs.append( input.reshape( @@ -82,7 +87,9 @@ def tensor_product_ops_impl( outputs = [] for index in exe.out_buffers: operands = exe.get_out_buffer_operands(index) - assert len(operands) == 1 + if len(operands) != 1: + logger.info("🛶 can't use tensor_product_uniform_1d for" + detail_str) + raise NotImplementedError() ope = operands.pop() size, num_segments = d.operands[ope].size, d.operands[ope].num_segments @@ -94,7 +101,7 @@ def tensor_product_ops_impl( ) ) - logger.info("Executing tensor_product_uniform_1d") + logger.info("🎉 use tensor_product_uniform_1d for" + detail_str) outputs = tensor_product_uniform_1d( options["dtype_math"], From b5c7be86098f39ac6f125828ba86cee13b71ea68 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 24 Jan 2025 08:03:33 -0800 Subject: [PATCH 3/7] rename --- .../cuequivariance/tensor_product_execution.py | 6 ++++-- .../primitives/tensor_product.py | 18 +++++++++--------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/cuequivariance/cuequivariance/tensor_product_execution.py 
b/cuequivariance/cuequivariance/tensor_product_execution.py index b8fcff8..4a9c2d0 100644 --- a/cuequivariance/cuequivariance/tensor_product_execution.py +++ b/cuequivariance/cuequivariance/tensor_product_execution.py @@ -24,11 +24,13 @@ class Buffer(int): class InBuffer(Buffer): - pass + def __repr__(self): + return f"InBuffer({int(self)})" class OutBuffer(Buffer): - pass + def __repr__(self): + return f"OutBuffer({int(self)})" T = TypeVar("T") diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py index d948fc0..4662538 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py @@ -22,7 +22,7 @@ import jax.numpy as jnp from jax.interpreters import ad, batching, mlir, xla -from cuequivariance import segmented_tensor_product as stp +import cuequivariance as cue from cuequivariance.tensor_product_execution import ( Computation, InBuffer, @@ -40,7 +40,7 @@ def tensor_product( - d: stp.SegmentedTensorProduct, + d: cue.SegmentedTensorProduct, *inputs: jax.Array, dtype_output: jnp.dtype | None = None, dtype_math: jnp.dtype | None = None, @@ -190,7 +190,7 @@ def clean_inputs( def tensor_product_prim( *inputs: jax.Array, # input buffers output_shapes: tuple[tuple[int, ...] | None, ...], # shapes of the operands - d: stp.SegmentedTensorProduct, + d: cue.SegmentedTensorProduct, exe: TensorProductExecution, **options, ) -> tuple[jax.Array, ...]: # output buffers @@ -215,7 +215,7 @@ def tensor_product_impl( platform: str | None, *inputs: jax.Array, output_shapes: tuple[tuple[int, ...] | None, ...], - d: stp.SegmentedTensorProduct, + d: cue.SegmentedTensorProduct, exe: TensorProductExecution, **options, ) -> tuple[jax.Array, ...]: @@ -225,7 +225,7 @@ def tensor_product_impl( def dispatch( inputs: list[jax.Array], - d: stp.SegmentedTensorProduct, + d: cue.SegmentedTensorProduct, exe: TensorProductExecution, ) -> list[jax.Array]: if platform == "cuda" and use_custom_kernels: @@ -261,7 +261,7 @@ def dispatch( def tensor_product_abstract_eval( *inputs: jax.core.ShapedArray, output_shapes: tuple[tuple[int, ...] | None, ...], - d: stp.SegmentedTensorProduct, + d: cue.SegmentedTensorProduct, exe: TensorProductExecution, **options, ) -> tuple[jax.core.ShapedArray, ...]: @@ -292,7 +292,7 @@ def tensor_product_jvp( tangents: tuple[jax.Array | ad.Zero, ...], *, output_shapes: tuple[tuple[int, ...] | None, ...], - d: stp.SegmentedTensorProduct, + d: cue.SegmentedTensorProduct, exe: TensorProductExecution, **options, ) -> tuple[tuple[jax.Array, ...], tuple[jax.Array | ad.Zero, ...]]: @@ -325,7 +325,7 @@ def tensor_product_transpose( cotangents: tuple[jax.Array | ad.Zero, ...], *inputs: jax.Array | ad.UndefinedPrimal, output_shapes: tuple[tuple[int, ...] | None, ...], - d: stp.SegmentedTensorProduct, + d: cue.SegmentedTensorProduct, exe: TensorProductExecution, **options, ) -> tuple[jax.Array | ad.Zero | None, ...]: @@ -375,7 +375,7 @@ def tensor_product_batching( batch_axes: tuple[int | None, ...], *, output_shapes: tuple[tuple[int, ...] 
| None, ...], - d: stp.SegmentedTensorProduct, + d: cue.SegmentedTensorProduct, exe: TensorProductExecution, **options, ) -> tuple[tuple[jax.Array, ...], tuple[int, ...]]: From 761e1c546e9c13b84a40f1376868bb6d28392ba5 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 28 Jan 2025 03:18:04 -0800 Subject: [PATCH 4/7] Add optional 'name' parameter to tensor product functions and update usage --- .../primitives/equivariant_tensor_product.py | 5 ++++- .../primitives/symmetric_tensor_product.py | 8 +++++++- .../primitives/tensor_product.py | 14 +++++++++++++- 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py b/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py index 2578e74..3dabe06 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py @@ -28,7 +28,8 @@ def equivariant_tensor_product( precision: jax.lax.Precision = jax.lax.Precision.HIGHEST, algorithm: str = "sliced", use_custom_primitive: bool = True, - use_custom_kernels: bool = False, + use_custom_kernels: bool | None = False, + name: str | None = None, ) -> cuex.RepArray: """Compute the equivariant tensor product of the input arrays. @@ -78,6 +79,7 @@ def equivariant_tensor_product( algorithm=algorithm, use_custom_primitive=use_custom_primitive, use_custom_kernels=use_custom_kernels, + name=name, ) if len(inputs) != e.num_inputs: @@ -113,6 +115,7 @@ def equivariant_tensor_product( algorithm=algorithm, use_custom_primitive=use_custom_primitive, use_custom_kernels=use_custom_kernels, + name=name, ) return cuex.RepArray(e.output, x) diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/symmetric_tensor_product.py b/cuequivariance_jax/cuequivariance_jax/primitives/symmetric_tensor_product.py index 088fc61..fad5d5d 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/symmetric_tensor_product.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/symmetric_tensor_product.py @@ -31,7 +31,8 @@ def symmetric_tensor_product( precision: jax.lax.Precision = jax.lax.Precision.HIGHEST, algorithm: str = "sliced", use_custom_primitive: bool = True, - use_custom_kernels: bool = False, + use_custom_kernels: bool | None = False, + name: str | None = None, ) -> jax.Array: """ Compute the sum of the STPs evaluated on the input (all input operands are the same). 
@@ -54,6 +55,9 @@ def symmetric_tensor_product( """ assert any(d.num_operands >= 2 for d in ds) + if name is None: + name = "symmetric_tensor_product" + # currying if len(inputs) == 0: @@ -67,6 +71,7 @@ def fn(*inputs) -> jax.Array: algorithm=algorithm, use_custom_primitive=use_custom_primitive, use_custom_kernels=use_custom_kernels, + name=name, ) return fn @@ -136,6 +141,7 @@ def fn(*inputs) -> jax.Array: algorithm=algorithm, use_custom_primitive=use_custom_primitive, use_custom_kernels=use_custom_kernels, + name=name + f"_{n_in - n_un}", ) return output diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py index 4662538..62dfe52 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py @@ -47,7 +47,8 @@ def tensor_product( precision: jax.lax.Precision = jax.lax.Precision.HIGHEST, algorithm: str = "sliced", use_custom_primitive: bool = True, - use_custom_kernels: bool = False, + use_custom_kernels: bool | None = False, + name: str | None = None, ) -> jax.Array: """ Compute the last operand of a `SegmentedTensorProduct`. @@ -88,6 +89,9 @@ def tensor_product( if isinstance(precision, str): precision = jax.lax.Precision[precision] + if name is None: + name = "tensor_product" + options = dict( dtype_output=dtype_output, dtype_math=dtype_math, @@ -95,6 +99,7 @@ def tensor_product( algorithm=algorithm, use_custom_primitive=use_custom_primitive, use_custom_kernels=use_custom_kernels, + name=name, ) if len(inputs) > d.num_operands - 1: @@ -143,6 +148,7 @@ def _partial(*remaining_inputs: jax.Array) -> jax.Array: algorithm=algorithm, use_custom_primitive=use_custom_primitive, use_custom_kernels=use_custom_kernels, + name=name, ) # inputs of shape (..., ope.size) with identical ndim @@ -304,6 +310,8 @@ def tensor_product_jvp( jvp = exe.jvp([not isinstance(t, ad.Zero) for t in tangents]) del exe + options["name"] = options.get("name", "tensor_product") + "->jvp" + permutations: list[tuple[int, ...]] = d.symmetries() for multiplicator, exe in jvp.group_by_symmetries(permutations): # tensor_product_prim can remove unused inputs @@ -345,6 +353,8 @@ def tensor_product_transpose( output_shapes[oid] = undefined_primal_shape output_shapes = tuple(output_shapes) + options["name"] = options.get("name", "tensor_product") + "->transpose" + tr = exe.transpose( [ad.is_undefined_primal(x) for x in inputs], [not isinstance(x, ad.Zero) for x in cotangents], @@ -400,6 +410,8 @@ def prepare(input: jax.Array, axis: int | None) -> jax.Array: assert new_output_shapes[oid] == expected new_output_shapes = tuple(new_output_shapes) + options["name"] = options.get("name", "tensor_product") + "->batching" + outputs = tensor_product_prim( *batched_inputs, output_shapes=new_output_shapes, From 06ffd6e398674c82850a84414d7afa5430c38830 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 29 Jan 2025 06:34:30 -0800 Subject: [PATCH 5/7] profile_and_select_implementation --- .../primitives/tensor_product.py | 115 +++++++++++++++++- 1 file changed, 109 insertions(+), 6 deletions(-) diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py index 62dfe52..78bed78 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py @@ -20,6 +20,8 @@ import jax.extend import 
jax.lax import jax.numpy as jnp +import numpy as np +from jax.experimental.mosaic.gpu import profiler from jax.interpreters import ad, batching, mlir, xla import cuequivariance as cue @@ -198,6 +200,7 @@ def tensor_product_prim( output_shapes: tuple[tuple[int, ...] | None, ...], # shapes of the operands d: cue.SegmentedTensorProduct, exe: TensorProductExecution, + use_custom_primitive: bool = True, **options, ) -> tuple[jax.Array, ...]: # output buffers if exe.is_trivial: @@ -207,7 +210,7 @@ def tensor_product_prim( unique_inputs, exe = clean_inputs(list(inputs), exe) - if options.pop("use_custom_primitive", True): + if use_custom_primitive: return tensor_product_p.bind( *unique_inputs, output_shapes=output_shapes, d=d, exe=exe, **options ) @@ -217,30 +220,130 @@ def tensor_product_prim( ) +def produce_minimal_code( + *inputs: jax.Array, + output_shapes: tuple[tuple[int, ...] | None, ...], + d: cue.SegmentedTensorProduct, + exe: TensorProductExecution, + dtype_output: jnp.dtype, + dtype_math: jnp.dtype, + **unused_options, +) -> str: + def format_dtype(dtype: jnp.dtype) -> str: + dtype = jnp.dtype(dtype) + return f"jnp.{dtype.name}" + + mincode = """import jax.numpy as jnp +import cuequivariance as cue +from cuequivariance.tensor_product_execution import InBuffer, OutBuffer +from cuequivariance_jax.primitives.tensor_product import tensor_product_prim +""" + mincode += ( + "inputs = [" + + ", ".join([f"jnp.zeros({x.shape}, {format_dtype(x.dtype)})" for x in inputs]) + + "]\n" + ) + mincode += f"output_shapes = {output_shapes}\n" + mincode += f'd = cue.SegmentedTensorProduct.from_base64("{d.to_base64()}")\n' + mincode += f"exe = cue.TensorProductExecution({exe.computations})\n" + mincode += f"dtype_output = {format_dtype(dtype_output)}\n" + mincode += f"dtype_math = {format_dtype(dtype_math)}\n" + # tensor_product_prim( + # *inputs, + # output_shapes=output_shapes, + # d=d, + # exe=exe, + # dtype_output=dtype_output, + # dtype_math=dtype_math, + # use_custom_kernels=True, + # ) + mincode += "# " + ", ".join([f"{k}={v}" for k, v in unused_options.items()]) + return mincode + + +def profile_and_select_implementation( + name: str, impls: list[tuple[str, callable]], *inputs: jax.Array +): + # import time + # t0 = time.perf_counter() + + with jax.ensure_compile_time_eval(): + dummy_inputs = [np.random.normal(size=x.shape).astype(x.dtype) for x in inputs] + dummy_inputs = [jax.device_put(x) for x in dummy_inputs] + ref = None + # first_runtime: float | None = None + best: tuple[str, float, callable] | None = None + for impl_name, impl in impls: + try: + out, runtime = profiler.measure(impl, mode="cupti")(*dummy_inputs) + except NotImplementedError: + continue + else: + if ref is None: + ref = out + best = impl_name, runtime, impl + # first_runtime = runtime + else: + diff = max( + [ + np.max(np.abs(a - b)) + for a, b in zip(jax.tree.leaves(out), jax.tree.leaves(ref)) + ] + ) + if diff > 1e-3: + raise ValueError( + f"cuex.tensor_product: {name} implementation {impl_name} produced different results, diff={diff}" + ) + if runtime < best[1]: + best = impl_name, runtime, impl + assert best is not None + impl_name, runtime, impl = best + + # dt = time.perf_counter() - t0 + # speedup = first_runtime / runtime + # print( + # f"cuex.tensor_product: {name}: selected {impl_name} with runtime {runtime:.2f} ms, speedup {speedup:.2f}x, (profiled in {dt:.1f} s)" + # ) + + return impl(*inputs) + + def tensor_product_impl( platform: str | None, *inputs: jax.Array, output_shapes: tuple[tuple[int, ...] 
| None, ...], d: cue.SegmentedTensorProduct, exe: TensorProductExecution, + name: str = "tensor_product", + use_custom_kernels: bool | None = True, **options, ) -> tuple[jax.Array, ...]: assert exe.max_in_buffer + 1 == len(exe.in_buffers) == len(inputs) assert exe.max_out_buffer + 1 == len(exe.out_buffers) - use_custom_kernels = options.pop("use_custom_kernels", True) def dispatch( inputs: list[jax.Array], d: cue.SegmentedTensorProduct, exe: TensorProductExecution, ) -> list[jax.Array]: - if platform == "cuda" and use_custom_kernels: - try: + if platform == "cuda": + if use_custom_kernels is None: + kwargs = dict(output_shapes=output_shapes, d=d, exe=exe, **options) + outputs = profile_and_select_implementation( + name, + [ + ("vanilla", partial(tensor_product_vanilla_impl, **kwargs)), + ("ops", partial(tensor_product_ops_impl, **kwargs)), + ], + *inputs, + ) + # print(produce_minimal_code(*inputs, **kwargs)) + # print() + return outputs + if use_custom_kernels is True: return tensor_product_ops_impl( *inputs, output_shapes=output_shapes, d=d, exe=exe, **options ) - except NotImplementedError: - pass return tensor_product_vanilla_impl( *inputs, output_shapes=output_shapes, d=d, exe=exe, **options From a9d1114ef8ea1a22dfa26272acfef8aa81aa5f72 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 30 Jan 2025 12:33:00 -0800 Subject: [PATCH 6/7] Enhance RepArray indexing to support slicing in addition to jax.Array --- .../rep_array/jax_rep_array.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/cuequivariance_jax/cuequivariance_jax/rep_array/jax_rep_array.py b/cuequivariance_jax/cuequivariance_jax/rep_array/jax_rep_array.py index 4bc6a09..9cee2d6 100644 --- a/cuequivariance_jax/cuequivariance_jax/rep_array/jax_rep_array.py +++ b/cuequivariance_jax/cuequivariance_jax/rep_array/jax_rep_array.py @@ -223,13 +223,18 @@ def __getitem__(self, key: Any) -> RepArray: self.array[None], ) - # self[jnp.array([0, 1, 2])] - assert isinstance(key, jax.Array) assert 0 not in self.reps - return RepArray( - {k + key.ndim - 1: irreps for k, irreps in self.reps.items()}, - self.array[key], - ) + + # self[1:4] + if isinstance(key, slice): + return RepArray(self.reps, self.array[key]) + + # self[jnp.array([0, 1, 2])] + if isinstance(key, jax.Array): + return RepArray( + {k + key.ndim - 1: irreps for k, irreps in self.reps.items()}, + self.array[key], + ) @property def slice_by_mul(self) -> _MulIndexSliceHelper: From 5eac58878c345af13cb050f5d51bc307f7f0ef6c Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 30 Jan 2025 12:34:13 -0800 Subject: [PATCH 7/7] print dbg --- .../primitives/tensor_product.py | 27 +++++++++---------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py index 78bed78..c4845d3 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py @@ -271,7 +271,7 @@ def profile_and_select_implementation( dummy_inputs = [np.random.normal(size=x.shape).astype(x.dtype) for x in inputs] dummy_inputs = [jax.device_put(x) for x in dummy_inputs] ref = None - # first_runtime: float | None = None + first_runtime: float | None = None best: tuple[str, float, callable] | None = None for impl_name, impl in impls: try: @@ -282,7 +282,7 @@ def profile_and_select_implementation( if ref is None: ref = out best = impl_name, runtime, impl - # 
first_runtime = runtime + first_runtime = runtime else: diff = max( [ @@ -300,10 +300,10 @@ def profile_and_select_implementation( impl_name, runtime, impl = best # dt = time.perf_counter() - t0 - # speedup = first_runtime / runtime - # print( - # f"cuex.tensor_product: {name}: selected {impl_name} with runtime {runtime:.2f} ms, speedup {speedup:.2f}x, (profiled in {dt:.1f} s)" - # ) + speedup = first_runtime / runtime + print( + f"{name:<50}: {impl_name:<10} with runtime {runtime:.2f} ms, speedup {speedup:.2f}x wrt {first_runtime:.2f} ms" + ) return impl(*inputs) @@ -326,9 +326,12 @@ def dispatch( d: cue.SegmentedTensorProduct, exe: TensorProductExecution, ) -> list[jax.Array]: + kwargs = dict(output_shapes=output_shapes, d=d, exe=exe, **options) + if platform == "cuda": + # print(produce_minimal_code(*inputs, **kwargs)) + # print() if use_custom_kernels is None: - kwargs = dict(output_shapes=output_shapes, d=d, exe=exe, **options) outputs = profile_and_select_implementation( name, [ @@ -337,17 +340,11 @@ def dispatch( ], *inputs, ) - # print(produce_minimal_code(*inputs, **kwargs)) - # print() return outputs if use_custom_kernels is True: - return tensor_product_ops_impl( - *inputs, output_shapes=output_shapes, d=d, exe=exe, **options - ) + return tensor_product_ops_impl(*inputs, **kwargs) - return tensor_product_vanilla_impl( - *inputs, output_shapes=output_shapes, d=d, exe=exe, **options - ) + return tensor_product_vanilla_impl(*inputs, **kwargs) outputs = [0] * len(exe.out_buffers) for partition, ex in exe.group_by_identical_buffers():
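
Usage sketch for PATCH 5/7: when use_custom_kernels=None, dispatch now calls
profile_and_select_implementation, which times each candidate implementation on
dummy inputs, checks that their outputs agree, and keeps the fastest one. The
selection loop can be exercised on its own with the sketch below. It reuses only
the profiler.measure(..., mode="cupti") call that the patch itself uses, assumes
a CUDA device is available, and the two candidate implementations are
placeholders rather than library code.

    import jax.numpy as jnp
    import numpy as np
    from jax.experimental.mosaic.gpu import profiler


    def impl_sum_of_squares(x):
        # reference candidate
        return jnp.sum(x * x)


    def impl_dot(x):
        # second candidate; must produce the same value as the reference
        return jnp.dot(x, x)


    def select_fastest(impls, x):
        best, ref = None, None
        for name, impl in impls:
            # profiler.measure returns (output, runtime in ms), as in the patch
            out, runtime = profiler.measure(impl, mode="cupti")(x)
            if ref is None:
                ref = out
            elif np.max(np.abs(np.asarray(out) - np.asarray(ref))) > 1e-3:
                raise ValueError(f"{name} disagrees with the reference candidate")
            if best is None or runtime < best[1]:
                best = (name, runtime, impl)
        return best


    x = jnp.ones((1_000_000,), dtype=jnp.float32)
    print(select_fastest([("sum_of_squares", impl_sum_of_squares), ("dot", impl_dot)], x))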
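
Usage sketch for PATCH 6/7: RepArray.__getitem__ now accepts a Python slice in
addition to a jax.Array, provided axis 0 carries no rep. A minimal example,
assuming the {axis: rep} constructor form that the patch itself uses internally
and that cue.SO3(1) is available as a 3-dimensional rep, as in the library docs:

    import jax.numpy as jnp
    import cuequivariance as cue
    import cuequivariance_jax as cuex

    rep = cue.SO3(1)  # spin-1 irrep, dimension 3 (assumed available as in the library)
    x = cuex.RepArray({1: rep}, jnp.zeros((10, 3)))  # axis 0 is a plain batch axis

    y = x[2:5]                   # new: slice indexing keeps the rep axes unchanged
    z = x[jnp.array([0, 3, 7])]  # existing: the rep axes shift by key.ndim - 1

    print(y.array.shape)  # (3, 3)
    print(z.array.shape)  # (3, 3)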