Commit fe279c5
add logger info messages
mariogeiger committed Jan 24, 2025
1 parent 69cf172 commit fe279c5
Showing 2 changed files with 21 additions and 14 deletions.
File 1 of 2

@@ -233,8 +233,8 @@ def dispatch(
             return tensor_product_ops_impl(
                 *inputs, output_shapes=output_shapes, d=d, exe=exe, **options
             )
-        except NotImplementedError as e:
-            logger.info(f"{e}. Falling back to JAX.")
+        except NotImplementedError:
+            pass
 
     return tensor_product_vanilla_impl(
         *inputs, output_shapes=output_shapes, d=d, exe=exe, **options
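In dispatch(), the per-exception log line is gone: the call now tries the CUDA-backed tensor_product_ops_impl, silently swallows NotImplementedError, and falls back to tensor_product_vanilla_impl. The reasons for rejecting the fast path are logged inside tensor_product_ops_impl itself (second file below). A minimal sketch of how one might surface those messages with the standard logging module; the exact logger setup inside the library is an assumption, not something taken from this commit:

    import logging

    # Route INFO-level records to stderr. Module loggers created with
    # logging.getLogger(__name__) propagate to the root logger by default,
    # so this is usually enough to see the 🎉 / 🛶 messages emitted above.
    logging.basicConfig(level=logging.INFO)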
File 2 of 2

@@ -33,24 +33,28 @@ def tensor_product_ops_impl(
 ) -> tuple[jax.Array, ...]:  # output buffers
     assert exe.max_out_buffer + 1 == len(exe.out_buffers)
 
+    detail_str = f"\n{d}\n{exe}".replace("\n", "\n | ")
+
     if not d.all_same_segment_shape():
-        raise NotImplementedError("Only supports operands with the same shape")
+        logger.info("🛶 can't use tensor_product_uniform_1d for" + detail_str)
+        raise NotImplementedError()
 
     try:
         from cuequivariance_ops_jax import tensor_product_uniform_1d
     except ImportError:
-        raise NotImplementedError("Cannot import cuequivariance_ops_jax")
+        logger.info("🛶 can't import cuequivariance_ops_jax")
+        raise NotImplementedError()
 
     modes = d.subscripts.modes()
     if len(modes) > 1:
-        raise NotImplementedError("cuequivariance_ops_jax only supports 1D modes")
+        logger.info("🛶 can't use tensor_product_uniform_1d for" + detail_str)
+        raise NotImplementedError()
 
     if len(modes) == 1:
         dims: set[int] = d.get_dims(modes[0])
         if len(dims) != 1:
-            raise NotImplementedError(
-                "cuequivariance_ops_jax only supports uniform 1D modes"
-            )
+            logger.info("🛶 can't use tensor_product_uniform_1d for" + detail_str)
+            raise NotImplementedError()
 
     batch_size = 1
     for shape in [input.shape[:-1] for input in inputs] + [
@@ -59,9 +63,8 @@
         n = math.prod(shape)
         if n > 1:
             if n != batch_size and batch_size != 1:
-                raise NotImplementedError(
-                    "cuequivariance_ops_jax does not support broadcasting"
-                )
+                logger.info("🛶 can't use tensor_product_uniform_1d for" + detail_str)
+                raise NotImplementedError()
             batch_size = n
 
     reshaped_inputs = []
@@ -70,7 +73,9 @@
             (d.operands[op].size, d.operands[op].num_segments)
             for op in exe.get_in_buffer_operands(index)
         }
-        assert len(operands) == 1
+        if len(operands) != 1:
+            logger.info("🛶 can't use tensor_product_uniform_1d for" + detail_str)
+            raise NotImplementedError()
         size, num_segments = operands.pop()
         reshaped_inputs.append(
             input.reshape(
@@ -82,7 +87,9 @@
     outputs = []
     for index in exe.out_buffers:
         operands = exe.get_out_buffer_operands(index)
-        assert len(operands) == 1
+        if len(operands) != 1:
+            logger.info("🛶 can't use tensor_product_uniform_1d for" + detail_str)
+            raise NotImplementedError()
         ope = operands.pop()
         size, num_segments = d.operands[ope].size, d.operands[ope].num_segments
 
@@ -94,7 +101,7 @@
             )
         )
 
-    logger.info("Executing tensor_product_uniform_1d")
+    logger.info("🎉 use tensor_product_uniform_1d for" + detail_str)
 
     outputs = tensor_product_uniform_1d(
         options["dtype_math"],
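All of the new 🛶 messages share the same detail_str suffix, which indents the string representations of the descriptor d and the execution plan exe with a leading " | ". A standalone illustration of that formatting, using placeholder strings for d and exe since only their text representation matters here:

    # Placeholder values; in the library these are the tensor-product
    # descriptor and the execution plan, not plain strings.
    d = "<descriptor>"
    exe = "<execution plan>"

    detail_str = f"\n{d}\n{exe}".replace("\n", "\n | ")
    print("🛶 can't use tensor_product_uniform_1d for" + detail_str)
    # 🛶 can't use tensor_product_uniform_1d for
    #  | <descriptor>
    #  | <execution plan>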
