Skip to content

Commit 16bc463

Browse files
committed
Add ConvTranspose descriptor
1 parent 7bd7353 commit 16bc463

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

Deeploy/OperatorDescriptor.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,25 @@ def _padsDefault(node: gs.Node) -> Tuple[int, ...]:
371371
],
372372
)
373373

374+
convTransposeDesc = OperatorDescriptor(
375+
inputDescriptor = IoDesc(["data_in", "weight"], optional = "bias"),
376+
outputDescriptor = IoDesc("data_out"),
377+
attrDescriptors = [
378+
AttrDesc("auto_pad", AutoPad, default = AutoPad.NOTSET),
379+
AttrDesc("dilations", IntTupleUnpack, default = _dilationsDefault),
380+
AttrDesc("group", IntUnpack, default = 1),
381+
AttrDesc("kernel_shape", IntTupleUnpack, default = _kernelShapeDefault),
382+
# TODO: Add output_shape and output_padding default functions.
383+
# Docs:
384+
# - ONNX: https://onnx.ai/onnx/operators/onnx__ConvTranspose.html
385+
# - PyTorch: https://docs.pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
386+
# AttrDesc("output_shape", IntTupleUnpack, default = _outputShapeDefault),
387+
# AttrDesc("output_padding", IntTupleUnpack, default = _outputPaddingDefault),
388+
AttrDesc("pads", IntTupleUnpack, default = _padsDefault),
389+
AttrDesc("strides", IntTupleUnpack, default = _stridesDefault),
390+
],
391+
)
392+
374393

375394
class RequantizedOperatorDescriptor(OperatorDescriptor):
376395

@@ -750,6 +769,7 @@ def canonicalize(self, node: gs.Node, opset: int) -> bool:
750769
"CLCA": clcaDesc,
751770
"Concat": concatDesc,
752771
"Conv": convDesc,
772+
"ConvTranspose": convTransposeDesc,
753773
"DebugPrint": debugPrintDesc,
754774
"Dequant": dequantDesc,
755775
"Div": divDesc,

0 commit comments

Comments
 (0)