From b5c7be86098f39ac6f125828ba86cee13b71ea68 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 24 Jan 2025 08:03:33 -0800 Subject: [PATCH] rename --- .../cuequivariance/tensor_product_execution.py | 6 ++++-- .../primitives/tensor_product.py | 18 +++++++++--------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/cuequivariance/cuequivariance/tensor_product_execution.py b/cuequivariance/cuequivariance/tensor_product_execution.py index b8fcff8..4a9c2d0 100644 --- a/cuequivariance/cuequivariance/tensor_product_execution.py +++ b/cuequivariance/cuequivariance/tensor_product_execution.py @@ -24,11 +24,13 @@ class Buffer(int): class InBuffer(Buffer): - pass + def __repr__(self): + return f"InBuffer({int(self)})" class OutBuffer(Buffer): - pass + def __repr__(self): + return f"OutBuffer({int(self)})" T = TypeVar("T") diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py index d948fc0..4662538 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py @@ -22,7 +22,7 @@ import jax.numpy as jnp from jax.interpreters import ad, batching, mlir, xla -from cuequivariance import segmented_tensor_product as stp +import cuequivariance as cue from cuequivariance.tensor_product_execution import ( Computation, InBuffer, @@ -40,7 +40,7 @@ def tensor_product( - d: stp.SegmentedTensorProduct, + d: cue.SegmentedTensorProduct, *inputs: jax.Array, dtype_output: jnp.dtype | None = None, dtype_math: jnp.dtype | None = None, @@ -190,7 +190,7 @@ def clean_inputs( def tensor_product_prim( *inputs: jax.Array, # input buffers output_shapes: tuple[tuple[int, ...] | None, ...], # shapes of the operands - d: stp.SegmentedTensorProduct, + d: cue.SegmentedTensorProduct, exe: TensorProductExecution, **options, ) -> tuple[jax.Array, ...]: # output buffers @@ -215,7 +215,7 @@ def tensor_product_impl( platform: str | None, *inputs: jax.Array, output_shapes: tuple[tuple[int, ...] | None, ...], - d: stp.SegmentedTensorProduct, + d: cue.SegmentedTensorProduct, exe: TensorProductExecution, **options, ) -> tuple[jax.Array, ...]: @@ -225,7 +225,7 @@ def tensor_product_impl( def dispatch( inputs: list[jax.Array], - d: stp.SegmentedTensorProduct, + d: cue.SegmentedTensorProduct, exe: TensorProductExecution, ) -> list[jax.Array]: if platform == "cuda" and use_custom_kernels: @@ -261,7 +261,7 @@ def dispatch( def tensor_product_abstract_eval( *inputs: jax.core.ShapedArray, output_shapes: tuple[tuple[int, ...] | None, ...], - d: stp.SegmentedTensorProduct, + d: cue.SegmentedTensorProduct, exe: TensorProductExecution, **options, ) -> tuple[jax.core.ShapedArray, ...]: @@ -292,7 +292,7 @@ def tensor_product_jvp( tangents: tuple[jax.Array | ad.Zero, ...], *, output_shapes: tuple[tuple[int, ...] | None, ...], - d: stp.SegmentedTensorProduct, + d: cue.SegmentedTensorProduct, exe: TensorProductExecution, **options, ) -> tuple[tuple[jax.Array, ...], tuple[jax.Array | ad.Zero, ...]]: @@ -325,7 +325,7 @@ def tensor_product_transpose( cotangents: tuple[jax.Array | ad.Zero, ...], *inputs: jax.Array | ad.UndefinedPrimal, output_shapes: tuple[tuple[int, ...] | None, ...], - d: stp.SegmentedTensorProduct, + d: cue.SegmentedTensorProduct, exe: TensorProductExecution, **options, ) -> tuple[jax.Array | ad.Zero | None, ...]: @@ -375,7 +375,7 @@ def tensor_product_batching( batch_axes: tuple[int | None, ...], *, output_shapes: tuple[tuple[int, ...] | None, ...], - d: stp.SegmentedTensorProduct, + d: cue.SegmentedTensorProduct, exe: TensorProductExecution, **options, ) -> tuple[tuple[jax.Array, ...], tuple[int, ...]]: