Skip to content

Commit d3dc7da

Browse files
committed
Refactor input checks in tensor product classes to include tracing condition
1 parent bffb227 commit d3dc7da

3 files changed

Lines changed: 40 additions & 8 deletions

File tree

cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,11 @@ def forward(
251251
"""
252252
If ``indices`` is not None, the first input is indexed by ``indices``.
253253
"""
254-
if not torch.jit.is_scripting() and not torch.compiler.is_compiling():
254+
if (
255+
not torch.jit.is_scripting()
256+
and not torch.jit.is_tracing()
257+
and not torch.compiler.is_compiling()
258+
):
255259
if not isinstance(inputs, (list, tuple)):
256260
raise ValueError(
257261
"inputs should be a list of tensors followed by optional indices"

cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,11 @@ def forward(
320320
i0 = i0.to(torch.int32)
321321
x0 = x0.reshape(x0.shape[0], x0.shape[1] // self.u, self.u)
322322
x1 = x1.reshape(x1.shape[0], x1.shape[1] // self.u, self.u)
323-
if not torch.jit.is_scripting() and not torch.compiler.is_compiling():
323+
if (
324+
not torch.jit.is_scripting()
325+
and not torch.jit.is_tracing()
326+
and not torch.compiler.is_compiling()
327+
):
324328
logger.debug(
325329
f"Calling SymmetricTensorContraction: {self.descriptors}, input shapes: {x0.shape}, {i0.shape}, {x1.shape}"
326330
)

cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,11 @@ def forward(self, inputs: List[torch.Tensor]):
116116
It has a shape of (batch, last_operand_size), where
117117
`last_operand_size` is the size of the last operand in the descriptor.
118118
"""
119-
if not torch.jit.is_scripting() and not torch.compiler.is_compiling():
119+
if (
120+
not torch.jit.is_scripting()
121+
and not torch.jit.is_tracing()
122+
and not torch.compiler.is_compiling()
123+
):
120124
if not isinstance(inputs, (list, tuple)):
121125
raise ValueError("inputs should be a list of tensors")
122126
if len(inputs) != self.num_operands - 1:
@@ -374,7 +378,11 @@ def __init__(self, module: torch.nn.Module, descriptor: stp.SegmentedTensorProdu
374378
self.descriptor = descriptor
375379

376380
def forward(self, args: List[torch.Tensor]):
377-
if not torch.jit.is_scripting() and not torch.compiler.is_compiling():
381+
if (
382+
not torch.jit.is_scripting()
383+
and not torch.jit.is_tracing()
384+
and not torch.compiler.is_compiling()
385+
):
378386
for oid, arg in enumerate(args):
379387
torch._assert(
380388
arg.ndim == 2,
@@ -506,7 +514,11 @@ def __repr__(self) -> str:
506514
def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
507515
x0, x1 = self._perm(inputs[0], inputs[1])
508516

509-
if not torch.jit.is_scripting() and not torch.compiler.is_compiling():
517+
if (
518+
not torch.jit.is_scripting()
519+
and not torch.jit.is_tracing()
520+
and not torch.compiler.is_compiling()
521+
):
510522
logger.debug(
511523
f"Calling FusedTensorProductOp3: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}"
512524
)
@@ -567,7 +579,11 @@ def __repr__(self) -> str:
567579
def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
568580
x0, x1, x2 = self._perm(inputs[0], inputs[1], inputs[2])
569581

570-
if not torch.jit.is_scripting() and not torch.compiler.is_compiling():
582+
if (
583+
not torch.jit.is_scripting()
584+
and not torch.jit.is_tracing()
585+
and not torch.compiler.is_compiling()
586+
):
571587
logger.debug(
572588
f"Calling FusedTensorProductOp4: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}, {x2.shape}"
573589
)
@@ -626,7 +642,11 @@ def __repr__(self):
626642
def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
627643
x0, x1 = inputs
628644

629-
if not torch.jit.is_scripting() and not torch.compiler.is_compiling():
645+
if (
646+
not torch.jit.is_scripting()
647+
and not torch.jit.is_tracing()
648+
and not torch.compiler.is_compiling()
649+
):
630650
logger.debug(
631651
f"Calling TensorProductUniform3x1d: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}"
632652
)
@@ -647,7 +667,11 @@ def __repr__(self):
647667
def forward(self, inputs: List[torch.Tensor]):
648668
x0, x1, x2 = inputs
649669

650-
if not torch.jit.is_scripting() and not torch.compiler.is_compiling():
670+
if (
671+
not torch.jit.is_scripting()
672+
and not torch.jit.is_tracing()
673+
and not torch.compiler.is_compiling()
674+
):
651675
logger.debug(
652676
f"Calling TensorProductUniform4x1d: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}, {x2.shape}"
653677
)

0 commit comments

Comments
 (0)