33 changes: 17 additions & 16 deletions jax/_src/lax/convolution.py
@@ -247,14 +247,13 @@ def conv_with_general_padding(lhs: Array, rhs: Array,
      rhs_dilation=rhs_dilation, precision=precision,
      preferred_element_type=preferred_element_type)


def _conv_transpose_padding(k, s, padding):
  """Calculate before and after padding for a dim of transposed convolution.

  Args:
    k: int: kernel dimension.
    s: int: dimension stride value.
-    padding: 'same' or 'valid' padding mode for original forward conv.
+    padding: tuple of ints, or 'same' or 'valid' padding mode, for the
+      original forward conv.

  Returns:
    2-tuple: ints: before and after padding for transposed convolution.
@@ -268,12 +267,15 @@ def _conv_transpose_padding(k, s, padding):
  elif padding == 'VALID':
    pad_len = k + s - 2 + max(k - s, 0)
    pad_a = k - 1
+  elif isinstance(padding, tuple):
+    pads = tuple(k - p - 1 for p in padding)
+    pad_a = pads[0]
+    pad_len = sum(pads)
  else:
-    raise ValueError('Padding mode must be `SAME` or `VALID`.')
+    raise ValueError(f"Invalid padding mode: {padding}")
  pad_b = pad_len - pad_a
  return pad_a, pad_b


def _flip_axes(x, axes):
  """Flip ndarray 'x' along each axis specified in axes tuple."""
  for axis in axes:
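To make the new tuple branch concrete, here is a minimal standalone sketch of the helper's arithmetic. The 'SAME' branch is not shown in the diff; it follows the existing implementation, and the numbers in the asserts are purely illustrative:

```python
import math

def conv_transpose_padding(k, s, padding):
  """Standalone mirror of _conv_transpose_padding for one spatial dim."""
  if padding == 'SAME':  # unchanged branch, shown for completeness
    pad_len = k + s - 2
    pad_a = k - 1 if s > k - 1 else math.ceil(pad_len / 2)
  elif padding == 'VALID':
    pad_len = k + s - 2 + max(k - s, 0)
    pad_a = k - 1
  elif isinstance(padding, tuple):
    # New: explicit (lo, hi) padding p of the forward conv becomes
    # k - p - 1 zeros on the matching side of the transposed conv.
    pads = tuple(k - p - 1 for p in padding)
    pad_a = pads[0]
    pad_len = sum(pads)
  else:
    raise ValueError(f"Invalid padding mode: {padding}")
  return pad_a, pad_len - pad_a

# A 3-tap kernel at stride 2:
assert conv_transpose_padding(3, 2, 'VALID') == (2, 2)
assert conv_transpose_padding(3, 2, 'SAME') == (2, 1)
assert conv_transpose_padding(3, 2, (1, 1)) == (1, 1)  # forward pad of 1 per side
```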
@@ -297,9 +299,11 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
    lhs: a rank `n+2` dimensional input array.
    rhs: a rank `n+2` dimensional array of kernel weights.
    strides: sequence of `n` integers, sets fractional stride.
-    padding: 'SAME', 'VALID' will set as transpose of corresponding forward
-      conv, or a sequence of `n` integer 2-tuples describing before-and-after
-      padding for each `n` spatial dimension.
+    padding: 'SAME', 'VALID', or a sequence of `n` integer 2-tuples giving
+      the before-and-after padding of each spatial dimension in the
+      corresponding forward conv. This effectively adds
+      `dilation * (kernel_size - 1) - padding` zeros to each side of the
+      input, so that `conv_transpose` becomes the gradient of `conv` when
+      given the same padding and stride arguments.
    rhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
      is also known as atrous convolution.
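The gradient property stated in the new docstring can be checked directly. Below is a sketch (not part of the PR) comparing `conv_transpose` under the new explicit-padding semantics against the VJP of a forward conv; the NHWC/HWIO layouts and all shapes are illustrative, chosen so the forward output size determines the input size unambiguously:

```python
import jax
import jax.numpy as jnp
from jax import lax

key_x, key_k = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (1, 7, 7, 4))   # forward-conv input, NHWC
k = jax.random.normal(key_k, (3, 3, 4, 2))   # forward-conv kernel, HWIO
strides, pad = (2, 2), [(1, 1), (1, 1)]
dn = lax.conv_dimension_numbers(x.shape, k.shape, ('NHWC', 'HWIO', 'NHWC'))

def fwd(x):
  return lax.conv_general_dilated(x, k, strides, pad, dimension_numbers=dn)

y, vjp = jax.vjp(fwd, x)
ct = jnp.ones_like(y)            # cotangent flowing back through the conv
grad_x, = vjp(ct)

# Same strides and (forward-conv) padding, with the kernel flipped and its
# input/output channels swapped, should reproduce the input gradient.
out = lax.conv_transpose(ct, k, strides, pad, transpose_kernel=True,
                         dimension_numbers=dn)
assert jnp.allclose(out, grad_x, atol=1e-5)
```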
@@ -341,15 +345,12 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
  k_shape = np.take(rhs.shape, dn.rhs_spec)
  k_sdims = k_shape[2:]
  # Calculate correct output shape given padding and strides.
-  pads: str | Sequence[tuple[int, int]]
-  if isinstance(padding, str) and padding in {'SAME', 'VALID'}:
-    if rhs_dilation is None:
-      rhs_dilation = (1,) * (rhs.ndim - 2)
-    effective_k_size = map(lambda k, r: core.dilate_dim(k, r), k_sdims, rhs_dilation)
-    pads = [_conv_transpose_padding(k, s, padding)
-            for k,s in zip(effective_k_size, strides)]
-  else:
-    pads = padding
+  if rhs_dilation is None:
+    rhs_dilation = (1,) * (rhs.ndim - 2)
+  effective_k_size = map(lambda k, r: core.dilate_dim(k, r), k_sdims, rhs_dilation)
+  replicated_padding = [padding] * len(strides) if isinstance(padding, str) else padding
+  pads = [_conv_transpose_padding(k, s, p)
+          for k,s,p in zip(effective_k_size, strides, replicated_padding)]
  if transpose_kernel:
    # flip spatial dims and swap input / output channel axes
    rhs = _flip_axes(rhs, np.array(dn.rhs_spec)[2:])
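Note that after this refactor, explicit padding tuples also pass through `_conv_transpose_padding` with the *dilated* kernel extent. `core.dilate_dim` supplies that extent; a standalone mirror of the same quantity (illustrative, not the library's code):

```python
def dilate_dim(k, r):
  # Effective extent of a k-tap kernel under dilation r:
  # r - 1 zeros are inserted between adjacent taps.
  return 0 if k == 0 else (k - 1) * r + 1

assert dilate_dim(3, 1) == 3   # no dilation
assert dilate_dim(3, 2) == 5   # taps land at offsets 0, 2, 4
```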
11 changes: 9 additions & 2 deletions tests/lax_test.py
@@ -887,6 +887,9 @@ def _conv_transpose_via_grad(data, kernel, strides, padding,
               for i in range(nspatial)]
  elif padding == 'SAME':
    o_sdims = [in_sdims[i]*strides[i] for i in range(nspatial)]
+  else:
+    o_sdims = [in_sdims[i]*strides[i] + max(e_k_sdims[i]-strides[i],0) - np.sum(p)
+               for i, p in enumerate(padding)]
  o_shape = [in_shape[0], k_shape[1]] + o_sdims
  out_spec_inv = [x[0] for x in
                  sorted(enumerate(dn.out_spec), key=lambda x: x[1])]
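The new `else` branch expects one output size per spatial dimension of `in * s + max(k_eff - s, 0) - (p_lo + p_hi)`. A quick check with illustrative numbers, consistent with the gradient sketch above (input 4, stride 2, effective kernel 3, forward padding (1, 1)):

```python
def transpose_out_size(in_dim, s, k_eff, pad):
  # Expected spatial size of conv_transpose given the forward conv's padding.
  return in_dim * s + max(k_eff - s, 0) - sum(pad)

# Dilating 4 inputs by stride 2 gives 7 taps; pad k - p - 1 = 1 per side -> 9;
# a valid 3-tap conv over 9 positions -> 7.
assert transpose_out_size(4, 2, 3, (1, 1)) == 7
```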
@@ -922,7 +925,9 @@ def _transpose_conv_kernel(data, kernel, dimension_numbers):
      ],
      dtype=lax_test_util.float_dtypes,
      strides=[(1, 1), (1, 2), (2, 1), (2, 2), (3, 3)],
-      padding=["VALID", "SAME"],
+      padding=list(itertools.product(
+          itertools.product([0,1,2], [0,1,2]),
+          itertools.product([0,1,2], [0,1,2]))) + ["VALID", "SAME"],
      dspec=[
          ("NHWC", "HWIO", "NHWC"),
      ],
@@ -962,7 +967,9 @@ def fun_via_grad(lhs, rhs):
      ],
      dtype=lax_test_util.float_dtypes,
      strides=[(1, 1), (1, 2), (2, 1), (2, 2), (3, 3)],
-      padding=["VALID", "SAME"],
+      padding=list(itertools.product(
+          itertools.product([0,1,2], [0,1,2]),
+          itertools.product([0,1,2], [0,1,2]))) + ["VALID", "SAME"],
      dspec=[
          ("NHWC", "HWIO", "NHWC"),
      ],
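For reference, the nested `itertools.product` in both test parameterizations enumerates every combination of (lo, hi) padding drawn from {0, 1, 2} for the two spatial dimensions, so each test now sweeps 81 explicit paddings on top of the two string modes:

```python
import itertools

explicit = list(itertools.product(
    itertools.product([0, 1, 2], [0, 1, 2]),
    itertools.product([0, 1, 2], [0, 1, 2])))
assert len(explicit) == 81                  # 9 (lo, hi) pairs per dim, squared
assert explicit[0] == ((0, 0), (0, 0))      # first case: no forward padding
```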