@@ -248,14 +248,13 @@ def conv_with_general_padding(lhs: Array, rhs: Array,
248248 rhs_dilation = rhs_dilation , precision = precision ,
249249 preferred_element_type = preferred_element_type )
250250
251-
252251def _conv_transpose_padding (k , s , padding ):
253252 """Calculate before and after padding for a dim of transposed convolution.
254253
255254 Args:
256255 k: int: kernel dimension.
257256 s: int: dimension stride value.
258- padding: 'same' or 'valid' padding mode for original forward conv.
257+ padding: 2-tuple of ints (before, after), or 'SAME' or 'VALID' padding mode for the original forward conv.
259258
260259 Returns:
261260 2-tuple: ints: before and after padding for transposed convolution.
@@ -269,12 +268,15 @@ def _conv_transpose_padding(k, s, padding):
269268 elif padding == 'VALID' :
270269 pad_len = k + s - 2 + max (k - s , 0 )
271270 pad_a = k - 1
271+ elif isinstance (padding , tuple ):
272+ pads = tuple (k - p - 1 for p in padding )
273+ pad_a = pads [0 ]
274+ pad_len = sum (pads )
272275 else :
273- raise ValueError ('Padding mode must be `SAME` or `VALID`.' )
276+ raise ValueError (f"Invalid padding mode: { padding } " )
274277 pad_b = pad_len - pad_a
275278 return pad_a , pad_b
276279
277-
278280def _flip_axes (x , axes ):
279281 """Flip ndarray 'x' along each axis specified in axes tuple."""
280282 for axis in axes :
@@ -298,9 +300,11 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
298300 lhs: a rank `n+2` dimensional input array.
299301 rhs: a rank `n+2` dimensional array of kernel weights.
300302 strides: sequence of `n` integers, sets fractional stride.
301- padding: 'SAME', 'VALID' will set as transpose of corresponding forward
302- conv, or a sequence of `n` integer 2-tuples describing before-and-after
303- padding for each `n` spatial dimension.
303+ padding: 'SAME', 'VALID', or a sequence of `n` integer 2-tuples describing before-and-after
304+ padding for each spatial dimension in the corresponding forward conv. This effectively adds
305+ `dilation * (kernel_size - 1) - padding` zero padding to each side
306+ of the input so that `conv_transpose` becomes the gradient of `conv` when given the same padding
307+ and stride arguments.
304308 rhs_dilation: `None`, or a sequence of `n` integers, giving the
305309 dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
306310 is also known as atrous convolution.
@@ -342,15 +346,12 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
342346 k_shape = np .take (rhs .shape , dn .rhs_spec )
343347 k_sdims = k_shape [2 :]
344348 # Calculate correct output shape given padding and strides.
345- pads : str | Sequence [tuple [int , int ]]
346- if isinstance (padding , str ) and padding in {'SAME' , 'VALID' }:
347- if rhs_dilation is None :
348- rhs_dilation = (1 ,) * (rhs .ndim - 2 )
349- effective_k_size = map (lambda k , r : core .dilate_dim (k , r ), k_sdims , rhs_dilation )
350- pads = [_conv_transpose_padding (k , s , padding )
351- for k ,s in zip (effective_k_size , strides )]
352- else :
353- pads = padding
349+ if rhs_dilation is None :
350+ rhs_dilation = (1 ,) * (rhs .ndim - 2 )
351+ effective_k_size = map (lambda k , r : core .dilate_dim (k , r ), k_sdims , rhs_dilation )
352+ replicated_padding = [padding ] * len (strides ) if isinstance (padding , str ) else padding
353+ pads = [_conv_transpose_padding (k , s , p )
354+ for k ,s ,p in zip (effective_k_size , strides , replicated_padding )]
354355 if transpose_kernel :
355356 # flip spatial dims and swap input / output channel axes
356357 rhs = _flip_axes (rhs , np .array (dn .rhs_spec )[2 :])
0 commit comments