@@ -55,11 +55,10 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
5555 # transpose input if needed, no need to record shapes on input
5656 for idx in input_indices :
5757 parent = node .inputs [idx ]
58- if node .inputs [idx ].is_const ():
59- # if input is a constant, transpose that one
60- if not parent .data_format :
61- val = parent .get_tensor_value (as_list = False )
62- parent .set_tensor_value (val .transpose (constants .NHWC_TO_NCHW ))
58+ if node .inputs [idx ].is_const () and len (ctx .find_output_consumers (node .input [1 ])) == 1 :
59+ # if input is a constant, transpose that one if we are the only consumer
60+ val = parent .get_tensor_value (as_list = False )
61+ parent .set_tensor_value (val .transpose (constants .NHWC_TO_NCHW ))
6362 else :
6463 # if input comes from a op, insert transpose op
6564 input_name = node .input [idx ]
@@ -70,33 +69,27 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
7069 if shape is not None :
7170 new_shape = spatial_map (shape , constants .NHWC_TO_NCHW )
7271 ctx .set_shape (transpose .output [0 ], new_shape )
73- parent .data_format = "NCHW"
7472
7573 # kernel must be transposed
7674 if with_kernel :
7775 parent = node .inputs [1 ]
7876 need_transpose = True
7977 if node .inputs [1 ].is_const ():
8078 # kernel is const - transpose the const if we are the only consumer of const
81- # TODO: maybe we should make a copy of the const, or look at the other consumers
82- # if they'd want a transose as well.
8379 consumers = ctx .find_output_consumers (node .input [1 ])
8480 if len (consumers ) == 1 :
8581 val = parent .get_tensor_value (as_list = False )
8682 val = val .transpose (constants .HWCN_TO_NCHW )
8783 parent .set_tensor_value (val )
88- parent .data_format = "NCHW"
8984 need_transpose = False
9085
9186 if need_transpose :
9287 input_name = node .input [1 ]
9388 transpose = ctx .insert_new_node_on_input (node , "Transpose" , input_name )
9489 transpose .set_attr ("perm" , constants .HWCN_TO_NCHW )
9590 transpose .skip_conversion = True
96- ctx .copy_shape (input_name , transpose .output [0 ])
9791 new_shape = spatial_map (ctx .get_shape (input_name ), constants .HWCN_TO_NCHW )
9892 ctx .set_shape (transpose .output [0 ], new_shape )
99- parent .data_format = "NCHW"
10093
10194 # some onnx conv ops require the reshape the kernel (ie. depthwise_conv2d)
10295 if new_kernel_shape :
@@ -129,7 +122,7 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
129122 ctx .set_shape (transpose .output [0 ], output_shape )
130123 # Transpose TF NHWC shape back to NCHW shape for current ONNX conv node output
131124 ctx .set_shape (output_name , spatial_map (output_shape , constants .NHWC_TO_NCHW ))
132- node .data_format = "NCHW"
125+ node .data_format = "NCHW"
133126
134127
135128def add_padding (ctx , node , kernel_shape , strides , dilations = None , spatial = 2 ):
0 commit comments