Skip to content

Commit

Permalink
rename
Browse files Browse the repository at this point in the history
mariogeiger committed Jan 24, 2025
1 parent fe279c5 commit b5c7be8
Showing 2 changed files with 13 additions and 11 deletions.
6 changes: 4 additions & 2 deletions cuequivariance/cuequivariance/tensor_product_execution.py
Original file line number Diff line number Diff line change
@@ -24,11 +24,13 @@ class Buffer(int):


class InBuffer(Buffer):
pass
def __repr__(self):
return f"InBuffer({int(self)})"


class OutBuffer(Buffer):
pass
def __repr__(self):
return f"OutBuffer({int(self)})"


T = TypeVar("T")
Original file line number Diff line number Diff line change
@@ -22,7 +22,7 @@
import jax.numpy as jnp
from jax.interpreters import ad, batching, mlir, xla

from cuequivariance import segmented_tensor_product as stp
import cuequivariance as cue
from cuequivariance.tensor_product_execution import (
Computation,
InBuffer,
@@ -40,7 +40,7 @@


def tensor_product(
d: stp.SegmentedTensorProduct,
d: cue.SegmentedTensorProduct,
*inputs: jax.Array,
dtype_output: jnp.dtype | None = None,
dtype_math: jnp.dtype | None = None,
@@ -190,7 +190,7 @@ def clean_inputs(
def tensor_product_prim(
*inputs: jax.Array, # input buffers
output_shapes: tuple[tuple[int, ...] | None, ...], # shapes of the operands
d: stp.SegmentedTensorProduct,
d: cue.SegmentedTensorProduct,
exe: TensorProductExecution,
**options,
) -> tuple[jax.Array, ...]: # output buffers
@@ -215,7 +215,7 @@ def tensor_product_impl(
platform: str | None,
*inputs: jax.Array,
output_shapes: tuple[tuple[int, ...] | None, ...],
d: stp.SegmentedTensorProduct,
d: cue.SegmentedTensorProduct,
exe: TensorProductExecution,
**options,
) -> tuple[jax.Array, ...]:
@@ -225,7 +225,7 @@ def tensor_product_impl(

def dispatch(
inputs: list[jax.Array],
d: stp.SegmentedTensorProduct,
d: cue.SegmentedTensorProduct,
exe: TensorProductExecution,
) -> list[jax.Array]:
if platform == "cuda" and use_custom_kernels:
@@ -261,7 +261,7 @@ def dispatch(
def tensor_product_abstract_eval(
*inputs: jax.core.ShapedArray,
output_shapes: tuple[tuple[int, ...] | None, ...],
d: stp.SegmentedTensorProduct,
d: cue.SegmentedTensorProduct,
exe: TensorProductExecution,
**options,
) -> tuple[jax.core.ShapedArray, ...]:
@@ -292,7 +292,7 @@ def tensor_product_jvp(
tangents: tuple[jax.Array | ad.Zero, ...],
*,
output_shapes: tuple[tuple[int, ...] | None, ...],
d: stp.SegmentedTensorProduct,
d: cue.SegmentedTensorProduct,
exe: TensorProductExecution,
**options,
) -> tuple[tuple[jax.Array, ...], tuple[jax.Array | ad.Zero, ...]]:
@@ -325,7 +325,7 @@ def tensor_product_transpose(
cotangents: tuple[jax.Array | ad.Zero, ...],
*inputs: jax.Array | ad.UndefinedPrimal,
output_shapes: tuple[tuple[int, ...] | None, ...],
d: stp.SegmentedTensorProduct,
d: cue.SegmentedTensorProduct,
exe: TensorProductExecution,
**options,
) -> tuple[jax.Array | ad.Zero | None, ...]:
@@ -375,7 +375,7 @@ def tensor_product_batching(
batch_axes: tuple[int | None, ...],
*,
output_shapes: tuple[tuple[int, ...] | None, ...],
d: stp.SegmentedTensorProduct,
d: cue.SegmentedTensorProduct,
exe: TensorProductExecution,
**options,
) -> tuple[tuple[jax.Array, ...], tuple[int, ...]]:

0 comments on commit b5c7be8

Please sign in to comment.