Skip to content

Commit

Permalink
Implement frontend to call backend JAX bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Jan 23, 2025
1 parent bcc2847 commit 1b0af24
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 2 deletions.
16 changes: 16 additions & 0 deletions cuequivariance/cuequivariance/tensor_product_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,22 @@ def num_inputs_per_operand(self) -> tuple[int, ...]:
def num_outputs_per_operand(self) -> tuple[int, ...]:
    """Return, per operand, how many output buffers are attached to it."""
    return tuple(map(len, self.out_buffers_per_operand))

def get_in_buffer_operands(self, buffer: int) -> set[int]:
    """Return the operand positions where input buffer *buffer* appears.

    Scans every computation and collects the index of each slot holding
    an ``InBuffer`` equal to *buffer*.
    """
    positions: set[int] = set()
    for computation in self.computations:
        for position, slot in enumerate(computation):
            if isinstance(slot, InBuffer) and slot == buffer:
                positions.add(position)
    return positions

def get_out_buffer_operands(self, buffer: int) -> set[int]:
    """Return the operand positions where output buffer *buffer* appears.

    Scans every computation and collects the index of each slot holding
    an ``OutBuffer`` equal to *buffer*.
    """
    positions: set[int] = set()
    for computation in self.computations:
        for position, slot in enumerate(computation):
            if isinstance(slot, OutBuffer) and slot == buffer:
                positions.add(position)
    return positions

def map_buffers(
self,
f_in: Optional[Callable[[int], int]],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
OutBuffer,
TensorProductExecution,
)
from cuequivariance_jax.primitives.tensor_product_ops_impl import (
tensor_product_ops_impl,
)
from cuequivariance_jax.primitives.tensor_product_vanilla_impl import (
tensor_product_vanilla_impl,
)
Expand Down Expand Up @@ -226,8 +229,12 @@ def dispatch(
exe: TensorProductExecution,
) -> list[jax.Array]:
if platform == "cuda" and use_custom_kernels:
# TODO: call custom kernels here
pass
try:
return tensor_product_ops_impl(
*inputs, output_shapes=output_shapes, d=d, exe=exe, **options
)
except NotImplementedError as e:
logger.info(f"{e}. Falling back to JAX.")

return tensor_product_vanilla_impl(
*inputs, output_shapes=output_shapes, d=d, exe=exe, **options
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math

import jax
import jax.numpy as jnp

import cuequivariance as cue
from cuequivariance.tensor_product_execution import InBuffer

logger = logging.getLogger(__name__)


def tensor_product_ops_impl(
    *inputs: jax.Array,  # input buffers, last axis is the operand's flat size
    output_shapes: tuple[tuple[int, ...] | None, ...],  # batch shape per operand
    d: cue.SegmentedTensorProduct,
    exe: cue.TensorProductExecution,
    **options,
) -> tuple[jax.Array, ...]:  # output buffers
    """Execute *d* with the ``cuequivariance_ops_jax`` custom CUDA kernel.

    Raises ``NotImplementedError`` whenever the descriptor/execution falls
    outside what ``tensor_product_uniform_1d`` supports, so the caller can
    fall back to the pure-JAX implementation.

    Args:
        inputs: One array per input buffer; leading axes are batch axes.
        output_shapes: Per-operand batch shape; ``None`` for input operands.
        d: The segmented tensor product descriptor to execute.
        exe: Mapping of buffers to operand slots for each computation.
        options: Must contain ``dtype_output`` and ``dtype_math``.

    Returns:
        One array per output buffer, reshaped back to its batch shape.
    """
    # Output buffers must be numbered densely 0..N-1.
    assert exe.max_out_buffer + 1 == len(exe.out_buffers)

    if not d.all_same_segment_shape():
        raise NotImplementedError("Only supports operands with the same shape")

    try:
        from cuequivariance_ops_jax import tensor_product_uniform_1d
    except ImportError as e:
        # Chain the cause so the original import failure stays visible.
        raise NotImplementedError("Cannot import cuequivariance_ops_jax") from e

    modes = d.subscripts.modes()
    if len(modes) > 1:
        raise NotImplementedError("cuequivariance_ops_jax only supports 1D modes")

    if len(modes) == 1:
        # A single mode must also have a single (uniform) dimension.
        dims: set[int] = d.get_dims(modes[0])
        if len(dims) != 1:
            raise NotImplementedError(
                "cuequivariance_ops_jax only supports uniform 1D modes"
            )

    # All batch shapes must agree (or be scalar); no broadcasting support.
    batch_size = 1
    for shape in [buf.shape[:-1] for buf in inputs] + [
        shape for shape in output_shapes if shape is not None
    ]:
        n = math.prod(shape)
        if n > 1:
            if n != batch_size and batch_size != 1:
                raise NotImplementedError(
                    "cuequivariance_ops_jax does not support broadcasting"
                )
            batch_size = n

    # Flatten each input to (batch, num_segments, segment_size).
    reshaped_inputs = []
    for index, buf in enumerate(inputs):
        operands = {
            (d.operands[op].size, d.operands[op].num_segments)
            for op in exe.get_in_buffer_operands(index)
        }
        # Every operand fed by this buffer must agree on its segmentation.
        assert len(operands) == 1
        size, num_segments = operands.pop()
        reshaped_inputs.append(
            buf.reshape((math.prod(buf.shape[:-1]), num_segments, size // num_segments))
        )

    # Allocate zero-initialized outputs with the same flattened layout.
    output_operands = []
    outputs = []
    for index in exe.out_buffers:
        operands = exe.get_out_buffer_operands(index)
        # Each output buffer must map to exactly one operand.
        assert len(operands) == 1
        ope = operands.pop()
        size, num_segments = d.operands[ope].size, d.operands[ope].num_segments

        output_operands.append(ope)
        outputs.append(
            jnp.zeros(
                (math.prod(output_shapes[ope]), num_segments, size // num_segments),
                dtype=options["dtype_output"],
            )
        )

    logger.info("Executing tensor_product_uniform_1d")

    # Encode each computation's buffer slots: inputs as b, outputs as -1-b.
    outputs = tensor_product_uniform_1d(
        options["dtype_math"],
        [ope.num_segments for ope in d.operands],
        [path.indices for path in d.paths],
        [path.coefficients.item() for path in d.paths],
        reshaped_inputs,
        outputs,
        [
            tuple(
                int(b) if isinstance(b, InBuffer) else -1 - int(b) for b in computation
            )
            for computation in exe.computations
        ],
    )

    # Restore each output's batch shape, keeping the flat operand axis last.
    outputs = [
        output.reshape(output_shapes[ope] + (-1,))
        for ope, output in zip(output_operands, outputs)
    ]
    return tuple(outputs)

0 comments on commit 1b0af24

Please sign in to comment.