Skip to content

Commit d865898

Browse files
committed
Relax opset check on squeeze operations to a warning
1 parent 16bc463 commit d865898

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

Deeploy/OperatorDescriptor.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import onnx_graphsurgeon as gs
1010

1111
from Deeploy.DeeployTypes import AttrDesc, IoDesc, OperatorDescriptor, VariadicIoDesc
12+
from Deeploy.Logging import DEFAULT_LOGGER as log
1213

1314

1415
def IntUnpack(value: Any) -> int:
@@ -499,13 +500,24 @@ def canonicalize(self, node: gs.Node, opset: int) -> bool:
499500
class SqueezeDescriptor(OperatorDescriptor):
500501

501502
def canonicalize(self, node: gs.Node, opset: int) -> bool:
502-
if opset >= 13:
503-
assert len(node.inputs) == 2, f"Expected 2 inputs but received {len(node.inputs)}"
503+
if len(node.inputs) == 2:
504504
axes = node.inputs[1]
505-
assert isinstance(axes,
506-
gs.Constant), f"Expected axes to be a constant but received axes of type {type(axes)}"
505+
assert isinstance(axes, gs.Constant), \
506+
f"Expected axes to be a constant but received axes of type {type(axes)}"
507507
node.attrs["axes"] = axes.values
508508
axes.outputs.clear()
509+
510+
if opset >= 13 and len(node.inputs) != 2:
511+
log.warning(
512+
"Squeeze operation expects 2 inputs for opset >= 13. "
513+
f"Received node {node.name} with {len(node.inputs)} input{'s' if len(node.inputs) > 1 else ''} and opset {opset}"
514+
)
515+
elif opset < 13 and len(node.inputs) != 1:
516+
log.warning(
517+
"Squeeze operation expects 1 input for opset < 13. "
518+
f"Received node {node.name} with {len(node.inputs)} input{'s' if len(node.inputs) > 1 else ''} and opset {opset}"
519+
)
520+
509521
return super().canonicalize(node, opset)
510522

511523

0 commit comments

Comments
 (0)