|
9 | 9 | import onnx_graphsurgeon as gs |
10 | 10 |
|
11 | 11 | from Deeploy.DeeployTypes import AttrDesc, IoDesc, OperatorDescriptor, VariadicIoDesc |
| 12 | +from Deeploy.Logging import DEFAULT_LOGGER as log |
12 | 13 |
|
13 | 14 |
|
14 | 15 | def IntUnpack(value: Any) -> int: |
@@ -499,13 +500,24 @@ def canonicalize(self, node: gs.Node, opset: int) -> bool: |
499 | 500 | class SqueezeDescriptor(OperatorDescriptor): |
500 | 501 |
|
501 | 502 | def canonicalize(self, node: gs.Node, opset: int) -> bool: |
502 | | - if opset >= 13: |
503 | | - assert len(node.inputs) == 2, f"Expected 2 inputs but received {len(node.inputs)}" |
| 503 | + if len(node.inputs) == 2: |
504 | 504 | axes = node.inputs[1] |
505 | | - assert isinstance(axes, |
506 | | - gs.Constant), f"Expected axes to be a constant but received axes of type {type(axes)}" |
| 505 | + assert isinstance(axes, gs.Constant), \ |
| 506 | + f"Expected axes to be a constant but received axes of type {type(axes)}" |
507 | 507 | node.attrs["axes"] = axes.values |
508 | 508 | axes.outputs.clear() |
| 509 | + |
| 510 | + if opset >= 13 and len(node.inputs) != 2: |
| 511 | + log.warning( |
| 512 | + "Squeeze operation expects 2 inputs for opset >= 13. " |
| 513 | + f"Received node {node.name} with {len(node.inputs)} input{'s' if len(node.inputs) > 1 else ''} and opset {opset}" |
| 514 | + ) |
| 515 | + elif opset < 13 and len(node.inputs) != 1: |
| 516 | + log.warning( |
| 517 | + "Squeeze operation expects 1 input for opset < 13. " |
| 518 | + f"Received node {node.name} with {len(node.inputs)} input{'s' if len(node.inputs) > 1 else ''} and opset {opset}" |
| 519 | + ) |
| 520 | + |
509 | 521 | return super().canonicalize(node, opset) |
510 | 522 |
|
511 | 523 |
|
|
0 commit comments