Commit 94309d5

Add flag for backwards compat
1 parent 64f1c97 commit 94309d5

File tree: 2 files changed (+23, -10 lines)

jax/_src/lax/convolution.py (19 additions & 8 deletions)

@@ -290,7 +290,8 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
                    dimension_numbers: ConvGeneralDilatedDimensionNumbers = None,
                    transpose_kernel: bool = False,
                    precision: lax.PrecisionLike = None,
-                   preferred_element_type: DTypeLike | None = None) -> Array:
+                   preferred_element_type: DTypeLike | None = None,
+                   use_consistent_padding: bool = False) -> Array:
   """Convenience wrapper for calculating the N-d convolution "transpose".
 
   This function directly calculates a fractionally strided conv rather than
@@ -301,10 +302,13 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
     rhs: a rank `n+2` dimensional array of kernel weights.
     strides: sequence of `n` integers, sets fractional stride.
     padding: 'SAME', 'VALID', or a sequence of `n` integer 2-tuples describing before-and-after
-      padding for each spatial dimension in the corresponding forward conv. This effectively adds
+      padding for each spatial dimension. If `use_consistent_padding=True`, this is interpreted
+      as the padding of the corresponding forward conv, which effectively adds
       `dilation * (kernel_size - 1) - padding` zero padding to each side
       of the input so that `conv_transpose` becomes the gradient of `conv` when given the same padding
-      and stride arguments.
+      and stride arguments. This is the behavior in PyTorch. If `use_consistent_padding=False`,
+      the 'SAME' and 'VALID' strings are interpreted as the padding of the corresponding forward conv,
+      but integer tuples are interpreted as padding for the transposed convolution.
     rhs_dilation: `None`, or a sequence of `n` integers, giving the
       dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
       is also known as atrous convolution.
@@ -322,7 +326,10 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
     preferred_element_type: Optional. Either ``None``, which means the default
       accumulation type for the input types, or a datatype, indicating to
       accumulate results to and return a result with that datatype.
-
+    use_consistent_padding: In older versions of JAX, the `padding` argument was interpreted
+      differently depending on whether it was a string or a sequence of integers. Strings were
+      interpreted as padding for the forward convolution, while integers were interpreted as padding
+      for the transposed convolution. If `use_consistent_padding` is False, this inconsistent
+      behavior is preserved for backwards compatibility.
   Returns:
     Transposed N-d convolution, with output padding following the conventions of
     keras.layers.Conv2DTranspose.
@@ -348,10 +355,14 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
   # Calculate correct output shape given padding and strides.
   if rhs_dilation is None:
     rhs_dilation = (1,) * (rhs.ndim - 2)
-  effective_k_size = map(lambda k, r: core.dilate_dim(k, r), k_sdims, rhs_dilation)
-  replicated_padding = [padding] * len(strides) if isinstance(padding, str) else padding
-  pads = [_conv_transpose_padding(k, s, p)
-          for k,s,p in zip(effective_k_size, strides, replicated_padding)]
+  pads: str | Sequence[tuple[int, int]]
+  if use_consistent_padding or padding in {'SAME', 'VALID'}:
+    effective_k_size = map(lambda k, r: core.dilate_dim(k, r), k_sdims, rhs_dilation)
+    replicated_padding = [padding] * len(strides) if isinstance(padding, str) else padding
+    pads = [_conv_transpose_padding(k, s, p)
+            for k,s,p in zip(effective_k_size, strides, replicated_padding)]
+  else:
+    pads = padding
   if transpose_kernel:
     # flip spatial dims and swap input / output channel axes
     rhs = _flip_axes(rhs, np.array(dn.rhs_spec)[2:])
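
To make the flag's effect concrete, here is a minimal sketch comparing the two interpretations of integer padding. It assumes a JAX build that includes this commit (the `use_consistent_padding` keyword does not exist in earlier releases); the input shapes and dimension numbers are arbitrary choices for illustration, and the expected shapes in the comments follow the `dilation * (kernel_size - 1) - padding` formula from the docstring above.

import jax.numpy as jnp
from jax import lax

y = jnp.ones((1, 8, 1))   # NHC input to the transposed conv
w = jnp.ones((3, 1, 1))   # HIO kernel
strides = (2,)
padding = [(0, 0)]
dn = ('NHC', 'HIO', 'NHC')

# Legacy behavior (the default): integer tuples pad the fractionally
# strided input directly, so the spatial size should be
# (8 - 1) * 2 + 1 - (3 - 1) = 13.
legacy = lax.conv_transpose(y, w, strides, padding, dimension_numbers=dn)

# Consistent behavior: the integers are read as the *forward* conv's
# padding, adding dilation * (kernel_size - 1) - padding = 2 zeros per
# side, as in PyTorch's ConvTranspose1d: (8 - 1) * 2 - 2 * 0 + 3 = 17.
consistent = lax.conv_transpose(y, w, strides, padding, dimension_numbers=dn,
                                use_consistent_padding=True)

print(legacy.shape, consistent.shape)  # expect (1, 13, 1) (1, 17, 1)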

tests/lax_test.py (4 additions & 2 deletions)

@@ -946,7 +946,8 @@ def fun(lhs, rhs):
       return lax.conv_transpose(lhs, rhs, strides, padding,
                                 rhs_dilation=rhs_dilation,
                                 dimension_numbers=dspec,
-                                transpose_kernel=True)
+                                transpose_kernel=True,
+                                use_consistent_padding=True)
 
     def fun_via_grad(lhs, rhs):
       return self._conv_transpose_via_grad(lhs, rhs, strides, padding,
@@ -986,7 +987,8 @@ def fun(lhs, rhs):
       return lax.conv_transpose(lhs, rhs, strides, padding,
                                 rhs_dilation=rhs_dilation,
                                 dimension_numbers=dspec,
-                                transpose_kernel=False)
+                                transpose_kernel=False,
+                                use_consistent_padding=True)
 
     def fun_via_grad(lhs, rhs):
       rhs_t = self._transpose_conv_kernel(lhs, rhs, dimension_numbers=dspec)
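
These tests compare `conv_transpose` against a gradient-based reference (`_conv_transpose_via_grad`, not shown in this diff). Below is a hedged standalone sketch of the property being checked, again assuming a build with this commit: with `transpose_kernel=True` and `use_consistent_padding=True`, the transposed conv should reproduce the VJP of the forward conv for the same integer padding and strides. The shapes, dimension numbers, and use of `jax.vjp` here are illustrative stand-ins for the test suite's actual harness.

import jax
import jax.numpy as jnp
from jax import lax

kx, kw = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (1, 9, 1))   # NHC input to the forward conv
w = jax.random.normal(kw, (3, 1, 1))   # HIO kernel
strides = (2,)
padding = [(1, 1)]
dn = ('NHC', 'HIO', 'NHC')

def forward(x):
  return lax.conv_general_dilated(x, w, window_strides=strides,
                                  padding=padding, dimension_numbers=dn)

y = forward(x)                          # shape (1, 5, 1)
_, vjp_fn = jax.vjp(forward, x)
(grad_x,) = vjp_fn(y)                   # cotangent pulled back to x's shape

# transpose_kernel=True flips the spatial axes and swaps the input/output
# channel axes, so this should match the gradient computed above.
out = lax.conv_transpose(y, w, strides, padding, dimension_numbers=dn,
                         transpose_kernel=True,
                         use_consistent_padding=True)
print(jnp.allclose(out, grad_x, atol=1e-5))  # expect True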
