@@ -371,6 +371,25 @@ def _padsDefault(node: gs.Node) -> Tuple[int, ...]:
371371 ],
372372)
373373
374+ convTransposeDesc = OperatorDescriptor (
375+ inputDescriptor = IoDesc (["data_in" , "weight" ], optional = "bias" ),
376+ outputDescriptor = IoDesc ("data_out" ),
377+ attrDescriptors = [
378+ AttrDesc ("auto_pad" , AutoPad , default = AutoPad .NOTSET ),
379+ AttrDesc ("dilations" , IntTupleUnpack , default = _dilationsDefault ),
380+ AttrDesc ("group" , IntUnpack , default = 1 ),
381+ AttrDesc ("kernel_shape" , IntTupleUnpack , default = _kernelShapeDefault ),
382+ # TODO: Add output_shape and output_padding default functions.
383+ # Docs:
384+ # - ONNX: https://onnx.ai/onnx/operators/onnx__ConvTranspose.html
385+ # - PyTorch: https://docs.pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
386+ # AttrDesc("output_shape", IntTupleUnpack, default = _outputShapeDefault),
387+ # AttrDesc("output_padding", IntTupleUnpack, default = _outputPaddingDefault),
388+ AttrDesc ("pads" , IntTupleUnpack , default = _padsDefault ),
389+ AttrDesc ("strides" , IntTupleUnpack , default = _stridesDefault ),
390+ ],
391+ )
392+
374393
375394class RequantizedOperatorDescriptor (OperatorDescriptor ):
376395
@@ -750,6 +769,7 @@ def canonicalize(self, node: gs.Node, opset: int) -> bool:
750769 "CLCA" : clcaDesc ,
751770 "Concat" : concatDesc ,
752771 "Conv" : convDesc ,
772+ "ConvTranspose" : convTransposeDesc ,
753773 "DebugPrint" : debugPrintDesc ,
754774 "Dequant" : dequantDesc ,
755775 "Div" : divDesc ,
0 commit comments