@@ -116,7 +116,11 @@ def forward(self, inputs: List[torch.Tensor]):
         It has a shape of (batch, last_operand_size), where
         `last_operand_size` is the size of the last operand in the descriptor.
         """
-        if not torch.jit.is_scripting() and not torch.compiler.is_compiling():
+        if (
+            not torch.jit.is_scripting()
+            and not torch.jit.is_tracing()
+            and not torch.compiler.is_compiling()
+        ):
             if not isinstance(inputs, (list, tuple)):
                 raise ValueError("inputs should be a list of tensors")
             if len(inputs) != self.num_operands - 1:
@@ -374,7 +378,11 @@ def __init__(self, module: torch.nn.Module, descriptor: stp.SegmentedTensorProdu
         self.descriptor = descriptor

     def forward(self, args: List[torch.Tensor]):
-        if not torch.jit.is_scripting() and not torch.compiler.is_compiling():
+        if (
+            not torch.jit.is_scripting()
+            and not torch.jit.is_tracing()
+            and not torch.compiler.is_compiling()
+        ):
             for oid, arg in enumerate(args):
                 torch._assert(
                     arg.ndim == 2,
@@ -506,7 +514,11 @@ def __repr__(self) -> str:
     def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
         x0, x1 = self._perm(inputs[0], inputs[1])

-        if not torch.jit.is_scripting() and not torch.compiler.is_compiling():
+        if (
+            not torch.jit.is_scripting()
+            and not torch.jit.is_tracing()
+            and not torch.compiler.is_compiling()
+        ):
             logger.debug(
                 f"Calling FusedTensorProductOp3: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}"
             )
@@ -567,7 +579,11 @@ def __repr__(self) -> str:
     def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
         x0, x1, x2 = self._perm(inputs[0], inputs[1], inputs[2])

-        if not torch.jit.is_scripting() and not torch.compiler.is_compiling():
+        if (
+            not torch.jit.is_scripting()
+            and not torch.jit.is_tracing()
+            and not torch.compiler.is_compiling()
+        ):
             logger.debug(
                 f"Calling FusedTensorProductOp4: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}, {x2.shape}"
             )
@@ -626,7 +642,11 @@ def __repr__(self):
     def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
         x0, x1 = inputs

-        if not torch.jit.is_scripting() and not torch.compiler.is_compiling():
+        if (
+            not torch.jit.is_scripting()
+            and not torch.jit.is_tracing()
+            and not torch.compiler.is_compiling()
+        ):
             logger.debug(
                 f"Calling TensorProductUniform3x1d: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}"
             )
@@ -647,7 +667,11 @@ def __repr__(self):
     def forward(self, inputs: List[torch.Tensor]):
        x0, x1, x2 = inputs

-        if not torch.jit.is_scripting() and not torch.compiler.is_compiling():
+        if (
+            not torch.jit.is_scripting()
+            and not torch.jit.is_tracing()
+            and not torch.compiler.is_compiling()
+        ):
             logger.debug(
                 f"Calling TensorProductUniform4x1d: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}, {x2.shape}"
             )
0 commit comments