@@ -290,7 +290,8 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
290290 dimension_numbers : ConvGeneralDilatedDimensionNumbers = None ,
291291 transpose_kernel : bool = False ,
292292 precision : lax .PrecisionLike = None ,
293- preferred_element_type : DTypeLike | None = None ) -> Array :
293+ preferred_element_type : DTypeLike | None = None ,
294+ use_consistent_padding : bool = False ) -> Array :
294295 """Convenience wrapper for calculating the N-d convolution "transpose".
295296
296297 This function directly calculates a fractionally strided conv rather than
@@ -301,10 +302,13 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
301302 rhs: a rank `n+2` dimensional array of kernel weights.
302303 strides: sequence of `n` integers, sets fractional stride.
303304 padding: 'SAME', 'VALID', or a sequence of `n` integer 2-tuples describing before-and-after
304- padding for each spatial dimension in the corresponding forward conv. This effectively adds
305+ padding for each spatial dimension. If `use_consistent_padding=True`, this is interpreted
306+ as the padding of the corresponding forward conv, which effectively adds
305307 `dilation * (kernel_size - 1) - padding` zero padding to each side
306308 of the input so that `conv_transpose` becomes the gradient of `conv` when given the same padding
307- and stride arguments.
309 and stride arguments. This matches the behavior of PyTorch. If `use_consistent_padding=False`,
310 the strings 'SAME' and 'VALID' are still interpreted as the padding of the corresponding forward conv,
311 but explicit sequences of integer pairs are interpreted as padding applied directly to the transposed convolution.
308312 rhs_dilation: `None`, or a sequence of `n` integers, giving the
309313 dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
310314 is also known as atrous convolution.
@@ -322,7 +326,10 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
322326 preferred_element_type: Optional. Either ``None``, which means the default
323327 accumulation type for the input types, or a datatype, indicating to
324328 accumulate results to and return a result with that datatype.
325-
329 use_consistent_padding : In older versions of JAX, the `padding` argument was interpreted differently
330 depending on whether it was a string or a sequence of integer pairs: strings were interpreted as padding
331 for the forward convolution, while explicit pairs were interpreted as padding for the transposed convolution.
332 If `use_consistent_padding` is False (the default), this inconsistent behavior is preserved for backwards compatibility.
326333 Returns:
327334 Transposed N-d convolution, with output padding following the conventions of
328335 keras.layers.Conv2DTranspose.
@@ -348,10 +355,14 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
348355 # Calculate correct output shape given padding and strides.
349356 if rhs_dilation is None :
350357 rhs_dilation = (1 ,) * (rhs .ndim - 2 )
351- effective_k_size = map (lambda k , r : core .dilate_dim (k , r ), k_sdims , rhs_dilation )
352- replicated_padding = [padding ] * len (strides ) if isinstance (padding , str ) else padding
353- pads = [_conv_transpose_padding (k , s , p )
354- for k ,s ,p in zip (effective_k_size , strides , replicated_padding )]
358+ pads : str | Sequence [tuple [int , int ]]
359+ if use_consistent_padding or padding in {'SAME' , 'VALID' }:
360+ effective_k_size = map (lambda k , r : core .dilate_dim (k , r ), k_sdims , rhs_dilation )
361+ replicated_padding = [padding ] * len (strides ) if isinstance (padding , str ) else padding
362+ pads = [_conv_transpose_padding (k , s , p )
363+ for k ,s ,p in zip (effective_k_size , strides , replicated_padding )]
364+ else :
365+ pads = padding
355366 if transpose_kernel :
356367 # flip spatial dims and swap input / output channel axes
357368 rhs = _flip_axes (rhs , np .array (dn .rhs_spec )[2 :])