Skip to content

Commit fcfeec2

Browse files
committed
Fixes #32267
1 parent 47d933c commit fcfeec2

File tree

2 files changed

+26
-16
lines changed

2 files changed

+26
-16
lines changed

jax/_src/lax/convolution.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,6 @@ def conv_with_general_padding(lhs: Array, rhs: Array,
247247
rhs_dilation=rhs_dilation, precision=precision,
248248
preferred_element_type=preferred_element_type)
249249

250-
251250
def _conv_transpose_padding(k, s, padding):
252251
"""Calculate before and after padding for a dim of transposed convolution.
253252
@@ -268,12 +267,15 @@ def _conv_transpose_padding(k, s, padding):
268267
elif padding == 'VALID':
269268
pad_len = k + s - 2 + max(k - s, 0)
270269
pad_a = k - 1
270+
elif isinstance(padding, tuple):
271+
pads = tuple(k - p - 1 for p in padding)
272+
pad_a = pads[0]
273+
pad_len = sum(pads)
271274
else:
272-
raise ValueError('Padding mode must be `SAME` or `VALID`.')
275+
raise ValueError(f"Invalid padding mode: {padding}")
273276
pad_b = pad_len - pad_a
274277
return pad_a, pad_b
275278

276-
277279
def _flip_axes(x, axes):
278280
"""Flip ndarray 'x' along each axis specified in axes tuple."""
279281
for axis in axes:
@@ -297,9 +299,11 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
297299
lhs: a rank `n+2` dimensional input array.
298300
rhs: a rank `n+2` dimensional array of kernel weights.
299301
strides: sequence of `n` integers, sets fractional stride.
300-
padding: 'SAME', 'VALID' will set as transpose of corresponding forward
301-
conv, or a sequence of `n` integer 2-tuples describing before-and-after
302-
padding for each `n` spatial dimension.
302+
padding: 'SAME', 'VALID', or a sequence of `n` integer 2-tuples describing before-and-after
303+
padding for each spatial dimension in the corresponding forward conv. This effectively adds
304+
`dilation * (kernel_size - 1) - padding` zero padding to each side
305+
of the input so that `conv_transpose` becomes the gradient of `conv` when given the same padding
306+
and stride arguments.
303307
rhs_dilation: `None`, or a sequence of `n` integers, giving the
304308
dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
305309
is also known as atrous convolution.
@@ -342,14 +346,13 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
342346
k_sdims = k_shape[2:]
343347
# Calculate correct output shape given padding and strides.
344348
pads: str | Sequence[tuple[int, int]]
345-
if isinstance(padding, str) and padding in {'SAME', 'VALID'}:
346-
if rhs_dilation is None:
347-
rhs_dilation = (1,) * (rhs.ndim - 2)
348-
effective_k_size = map(lambda k, r: core.dilate_dim(k, r), k_sdims, rhs_dilation)
349-
pads = [_conv_transpose_padding(k, s, padding)
350-
for k,s in zip(effective_k_size, strides)]
351-
else:
352-
pads = padding
349+
if rhs_dilation is None:
350+
rhs_dilation = (1,) * (rhs.ndim - 2)
351+
effective_k_size = map(lambda k, r: core.dilate_dim(k, r), k_sdims, rhs_dilation)
352+
if isinstance(padding, str):
353+
padding = [padding] * len(strides)
354+
pads = [_conv_transpose_padding(k, s, p)
355+
for k,s,p in zip(effective_k_size, strides, padding)]
353356
if transpose_kernel:
354357
# flip spatial dims and swap input / output channel axes
355358
rhs = _flip_axes(rhs, np.array(dn.rhs_spec)[2:])

tests/lax_test.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -887,6 +887,9 @@ def _conv_transpose_via_grad(data, kernel, strides, padding,
887887
for i in range(nspatial)]
888888
elif padding == 'SAME':
889889
o_sdims = [in_sdims[i]*strides[i] for i in range(nspatial)]
890+
else:
891+
o_sdims = [in_sdims[i]*strides[i] + max(e_k_sdims[i]-strides[i],0) - np.sum(p)
892+
for i, p in enumerate(padding)]
890893
o_shape = [in_shape[0], k_shape[1]] + o_sdims
891894
out_spec_inv = [x[0] for x in
892895
sorted(enumerate(dn.out_spec), key=lambda x: x[1])]
@@ -922,7 +925,9 @@ def _transpose_conv_kernel(data, kernel, dimension_numbers):
922925
],
923926
dtype=lax_test_util.float_dtypes,
924927
strides=[(1, 1), (1, 2), (2, 1), (2, 2), (3, 3)],
925-
padding=["VALID", "SAME"],
928+
padding=list(itertools.product(
929+
itertools.product([0,1,2], [0,1,2]),
930+
itertools.product([0,1,2], [0,1,2]))) + ["VALID", "SAME"],
926931
dspec=[
927932
("NHWC", "HWIO", "NHWC"),
928933
],
@@ -962,7 +967,9 @@ def fun_via_grad(lhs, rhs):
962967
],
963968
dtype=lax_test_util.float_dtypes,
964969
strides=[(1, 1), (1, 2), (2, 1), (2, 2), (3, 3)],
965-
padding=["VALID", "SAME"],
970+
padding=list(itertools.product(
971+
itertools.product([0,1,2], [0,1,2]),
972+
itertools.product([0,1,2], [0,1,2]))) + ["VALID", "SAME"],
966973
dspec=[
967974
("NHWC", "HWIO", "NHWC"),
968975
],

0 commit comments

Comments
 (0)