@@ -247,7 +247,6 @@ def conv_with_general_padding(lhs: Array, rhs: Array,
247
247
rhs_dilation = rhs_dilation , precision = precision ,
248
248
preferred_element_type = preferred_element_type )
249
249
250
-
251
250
def _conv_transpose_padding (k , s , padding ):
252
251
"""Calculate before and after padding for a dim of transposed convolution.
253
252
@@ -268,12 +267,15 @@ def _conv_transpose_padding(k, s, padding):
268
267
elif padding == 'VALID' :
269
268
pad_len = k + s - 2 + max (k - s , 0 )
270
269
pad_a = k - 1
270
+ elif isinstance (padding , tuple ):
271
+ pads = tuple (k - p - 1 for p in padding )
272
+ pad_a = pads [0 ]
273
+ pad_len = sum (pads )
271
274
else :
272
- raise ValueError ('Padding mode must be `SAME` or `VALID`.' )
275
+ raise ValueError (f"Invalid padding mode: { padding } " )
273
276
pad_b = pad_len - pad_a
274
277
return pad_a , pad_b
275
278
276
-
277
279
def _flip_axes (x , axes ):
278
280
"""Flip ndarray 'x' along each axis specified in axes tuple."""
279
281
for axis in axes :
@@ -297,9 +299,11 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
297
299
lhs: a rank `n+2` dimensional input array.
298
300
rhs: a rank `n+2` dimensional array of kernel weights.
299
301
strides: sequence of `n` integers, sets fractional stride.
300
- padding: 'SAME', 'VALID' will set as transpose of corresponding forward
301
- conv, or a sequence of `n` integer 2-tuples describing before-and-after
302
- padding for each `n` spatial dimension.
302
+ padding: 'SAME', 'VALID', or a sequence of `n` integer 2-tuples describing before-and-after
303
+ padding for each spatial dimension in the corresponding forward conv. This effectively adds
304
+ `dilation * (kernel_size - 1) - padding` zero padding to each side
305
+ of the input so that `conv_transpose` becomes the gradient of `conv` when given the same padding
306
+ and stride arguments.
303
307
rhs_dilation: `None`, or a sequence of `n` integers, giving the
304
308
dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
305
309
is also known as atrous convolution.
@@ -342,14 +346,13 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
342
346
k_sdims = k_shape [2 :]
343
347
# Calculate correct output shape given padding and strides.
344
348
pads : str | Sequence [tuple [int , int ]]
345
- if isinstance (padding , str ) and padding in {'SAME' , 'VALID' }:
346
- if rhs_dilation is None :
347
- rhs_dilation = (1 ,) * (rhs .ndim - 2 )
348
- effective_k_size = map (lambda k , r : core .dilate_dim (k , r ), k_sdims , rhs_dilation )
349
- pads = [_conv_transpose_padding (k , s , padding )
350
- for k ,s in zip (effective_k_size , strides )]
351
- else :
352
- pads = padding
349
+ if rhs_dilation is None :
350
+ rhs_dilation = (1 ,) * (rhs .ndim - 2 )
351
+ effective_k_size = map (lambda k , r : core .dilate_dim (k , r ), k_sdims , rhs_dilation )
352
+ if isinstance (padding , str ):
353
+ padding = [padding ] * len (strides )
354
+ pads = [_conv_transpose_padding (k , s , p )
355
+ for k ,s ,p in zip (effective_k_size , strides , padding )]
353
356
if transpose_kernel :
354
357
# flip spatial dims and swap input / output channel axes
355
358
rhs = _flip_axes (rhs , np .array (dn .rhs_spec )[2 :])
0 commit comments