
Commit 1e460b4

Fixes #32267

1 parent: 15e9723

2 files changed: 26 additions, 18 deletions

jax/_src/lax/convolution.py

Lines changed: 17 additions & 16 deletions
@@ -248,14 +248,13 @@ def conv_with_general_padding(lhs: Array, rhs: Array,
       rhs_dilation=rhs_dilation, precision=precision,
       preferred_element_type=preferred_element_type)
 
-
 def _conv_transpose_padding(k, s, padding):
   """Calculate before and after padding for a dim of transposed convolution.
 
   Args:
     k: int: kernel dimension.
     s: int: dimension stride value.
-    padding: 'same' or 'valid' padding mode for original forward conv.
+    padding: tuple of ints or 'same' or 'valid' padding mode for original forward conv.
 
   Returns:
     2-tuple: ints: before and after padding for transposed convolution.
@@ -269,12 +268,15 @@ def _conv_transpose_padding(k, s, padding):
   elif padding == 'VALID':
     pad_len = k + s - 2 + max(k - s, 0)
     pad_a = k - 1
+  elif isinstance(padding, tuple):
+    pads = tuple(k - p - 1 for p in padding)
+    pad_a = pads[0]
+    pad_len = sum(pads)
   else:
-    raise ValueError('Padding mode must be `SAME` or `VALID`.')
+    raise ValueError(f"Invalid padding mode: {padding}")
   pad_b = pad_len - pad_a
   return pad_a, pad_b
 
-
 def _flip_axes(x, axes):
   """Flip ndarray 'x' along each axis specified in axes tuple."""
   for axis in axes:
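
The new branch mirrors the string modes: for a kernel of effective size k and forward-conv padding p on one side, the transposed conv pads that side with k - p - 1 zeros. A minimal standalone sketch of just that arithmetic (the helper name and the sample values are illustrative, not part of the commit):

# Per-dimension padding for the transposed conv, given explicit forward padding.
# Mirrors `pads = tuple(k - p - 1 for p in padding)` from the diff above.
def explicit_transpose_padding(k, padding):
  pads = tuple(k - p - 1 for p in padding)
  return pads[0], sum(pads) - pads[0]   # (pad_before, pad_after)

print(explicit_transpose_padding(3, (0, 0)))  # (2, 2) -- matches the 'VALID' branch at stride 1
print(explicit_transpose_padding(3, (1, 1)))  # (1, 1) -- matches the 'SAME' branch at stride 1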
@@ -298,9 +300,11 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
     lhs: a rank `n+2` dimensional input array.
     rhs: a rank `n+2` dimensional array of kernel weights.
     strides: sequence of `n` integers, sets fractional stride.
-    padding: 'SAME', 'VALID' will set as transpose of corresponding forward
-      conv, or a sequence of `n` integer 2-tuples describing before-and-after
-      padding for each `n` spatial dimension.
+    padding: 'SAME', 'VALID', or a sequence of `n` integer 2-tuples describing
+      before-and-after padding for each spatial dimension in the corresponding
+      forward conv. This effectively adds `dilation * (kernel_size - 1) - padding`
+      zero padding to each side of the input so that `conv_transpose` becomes the
+      gradient of `conv` when given the same padding and stride arguments.
     rhs_dilation: `None`, or a sequence of `n` integers, giving the
       dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
       is also known as atrous convolution.
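
The docstring addition can be checked numerically: with the same explicit padding and strides, `conv_transpose(..., transpose_kernel=True)` should match the input gradient of the corresponding forward conv. A hedged sketch along those lines (the shapes, seed, and padding values are made up for illustration; the spatial sizes are chosen so the strided forward conv loses no positions and the shapes line up):

import jax
import jax.numpy as jnp
import numpy as np

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
lhs = jax.random.normal(key1, (1, 8, 8, 4), jnp.float32)   # NHWC input
rhs = jax.random.normal(key2, (3, 3, 4, 2), jnp.float32)   # HWIO kernel
strides = (2, 2)
padding = ((1, 0), (0, 1))          # explicit per-dimension forward padding
dn = ('NHWC', 'HWIO', 'NHWC')

def forward(x):
  return jax.lax.conv_general_dilated(x, rhs, strides, padding, dimension_numbers=dn)

cotangent = jnp.ones(forward(lhs).shape, jnp.float32)
via_grad = jax.vjp(forward, lhs)[1](cotangent)[0]
via_transpose = jax.lax.conv_transpose(cotangent, rhs, strides, padding,
                                       dimension_numbers=dn, transpose_kernel=True)
print(np.allclose(via_grad, via_transpose, atol=1e-4))  # expected True with this fix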
@@ -342,15 +346,12 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
   k_shape = np.take(rhs.shape, dn.rhs_spec)
   k_sdims = k_shape[2:]
   # Calculate correct output shape given padding and strides.
-  pads: str | Sequence[tuple[int, int]]
-  if isinstance(padding, str) and padding in {'SAME', 'VALID'}:
-    if rhs_dilation is None:
-      rhs_dilation = (1,) * (rhs.ndim - 2)
-    effective_k_size = map(lambda k, r: core.dilate_dim(k, r), k_sdims, rhs_dilation)
-    pads = [_conv_transpose_padding(k, s, padding)
-            for k,s in zip(effective_k_size, strides)]
-  else:
-    pads = padding
+  if rhs_dilation is None:
+    rhs_dilation = (1,) * (rhs.ndim - 2)
+  effective_k_size = map(lambda k, r: core.dilate_dim(k, r), k_sdims, rhs_dilation)
+  replicated_padding = [padding] * len(strides) if isinstance(padding, str) else padding
+  pads = [_conv_transpose_padding(k, s, p)
+          for k,s,p in zip(effective_k_size, strides, replicated_padding)]
   if transpose_kernel:
     # flip spatial dims and swap input / output channel axes
     rhs = _flip_axes(rhs, np.array(dn.rhs_spec)[2:])
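
The other change in `conv_transpose` is the normalization step: a string mode is now simply repeated once per spatial dimension before the per-dimension helper runs, while an explicit sequence passes through unchanged. A tiny standalone sketch of that one-liner (the values are illustrative):

strides = (2, 2)
for padding in ('SAME', ((1, 2), (0, 1))):
  replicated_padding = [padding] * len(strides) if isinstance(padding, str) else padding
  print(replicated_padding)
# ['SAME', 'SAME']
# ((1, 2), (0, 1))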

tests/lax_test.py

Lines changed: 9 additions & 2 deletions
@@ -887,6 +887,9 @@ def _conv_transpose_via_grad(data, kernel, strides, padding,
                for i in range(nspatial)]
   elif padding == 'SAME':
     o_sdims = [in_sdims[i]*strides[i] for i in range(nspatial)]
+  else:
+    o_sdims = [in_sdims[i]*strides[i] + max(e_k_sdims[i]-strides[i],0) - np.sum(p)
+               for i, p in enumerate(padding)]
   o_shape = [in_shape[0], k_shape[1]] + o_sdims
   out_spec_inv = [x[0] for x in
                   sorted(enumerate(dn.out_spec), key=lambda x: x[1])]
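
The new `else` branch gives the expected output spatial size when explicit padding is supplied. A worked instance of that formula (numbers chosen for illustration, not taken from the test): for a conv_transpose input of spatial size 8, stride 2, effective kernel size 3, and forward padding (1, 0), the output size is 8*2 + max(3-2, 0) - (1+0) = 16.

import numpy as np

# One spatial dimension of the formula added to _conv_transpose_via_grad above.
in_sdim, stride, e_k_sdim, p = 8, 2, 3, (1, 0)
o_sdim = in_sdim * stride + max(e_k_sdim - stride, 0) - np.sum(p)
print(o_sdim)  # 16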
@@ -922,7 +925,9 @@ def _transpose_conv_kernel(data, kernel, dimension_numbers):
       ],
       dtype=lax_test_util.float_dtypes,
       strides=[(1, 1), (1, 2), (2, 1), (2, 2), (3, 3)],
-      padding=["VALID", "SAME"],
+      padding=list(itertools.product(
+          itertools.product([0,1,2], [0,1,2]),
+          itertools.product([0,1,2], [0,1,2]))) + ["VALID", "SAME"],
       dspec=[
           ("NHWC", "HWIO", "NHWC"),
       ],
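
The new parametrization sweeps every combination of before/after padding in {0, 1, 2} for both spatial dimensions, alongside the original string modes. A quick standalone check of what that expression produces (stdlib only):

import itertools

padding = list(itertools.product(
    itertools.product([0, 1, 2], [0, 1, 2]),
    itertools.product([0, 1, 2], [0, 1, 2]))) + ["VALID", "SAME"]
print(len(padding))  # 83 cases: 81 explicit ((lo, hi), (lo, hi)) pairs plus the two string modes
print(padding[0])    # ((0, 0), (0, 0))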
@@ -962,7 +967,9 @@ def fun_via_grad(lhs, rhs):
       ],
       dtype=lax_test_util.float_dtypes,
       strides=[(1, 1), (1, 2), (2, 1), (2, 2), (3, 3)],
-      padding=["VALID", "SAME"],
+      padding=list(itertools.product(
+          itertools.product([0,1,2], [0,1,2]),
+          itertools.product([0,1,2], [0,1,2]))) + ["VALID", "SAME"],
       dspec=[
           ("NHWC", "HWIO", "NHWC"),
       ],
