diff --git a/jax/_src/lax/convolution.py b/jax/_src/lax/convolution.py index 649cb2d91379..4af21c666d09 100644 --- a/jax/_src/lax/convolution.py +++ b/jax/_src/lax/convolution.py @@ -247,14 +247,13 @@ def conv_with_general_padding(lhs: Array, rhs: Array, rhs_dilation=rhs_dilation, precision=precision, preferred_element_type=preferred_element_type) - def _conv_transpose_padding(k, s, padding): """Calculate before and after padding for a dim of transposed convolution. Args: k: int: kernel dimension. s: int: dimension stride value. - padding: 'same' or 'valid' padding mode for original forward conv. + padding: 2-tuple of ints giving the (before, after) padding of the original forward conv for this dim, or 'SAME' or 'VALID' padding mode. Returns: 2-tuple: ints: before and after padding for transposed convolution. """ @@ -268,12 +267,15 @@ def _conv_transpose_padding(k, s, padding): elif padding == 'VALID': pad_len = k + s - 2 + max(k - s, 0) pad_a = k - 1 + elif isinstance(padding, tuple): + pads = tuple(k - p - 1 for p in padding) + pad_a = pads[0] + pad_len = sum(pads) else: - raise ValueError('Padding mode must be `SAME` or `VALID`.') + raise ValueError(f"Invalid padding mode: {padding}") pad_b = pad_len - pad_a return pad_a, pad_b - def _flip_axes(x, axes): """Flip ndarray 'x' along each axis specified in axes tuple.""" for axis in axes: @@ -297,9 +299,11 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int], lhs: a rank `n+2` dimensional input array. rhs: a rank `n+2` dimensional array of kernel weights. strides: sequence of `n` integers, sets fractional stride. - padding: 'SAME', 'VALID' will set as transpose of corresponding forward - conv, or a sequence of `n` integer 2-tuples describing before-and-after - padding for each `n` spatial dimension. + padding: 'SAME', 'VALID', or a sequence of `n` integer 2-tuples describing before-and-after + padding for each spatial dimension in the corresponding forward conv. 
This effectively adds + `dilation * (kernel_size - 1) - padding` zero padding to each side + of the input so that `conv_transpose` becomes the gradient of `conv` when given the same padding + and stride arguments. rhs_dilation: `None`, or a sequence of `n` integers, giving the dilation factor to apply in each spatial dimension of `rhs`. RHS dilation is also known as atrous convolution. @@ -341,15 +345,12 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int], k_shape = np.take(rhs.shape, dn.rhs_spec) k_sdims = k_shape[2:] # Calculate correct output shape given padding and strides. - pads: str | Sequence[tuple[int, int]] - if isinstance(padding, str) and padding in {'SAME', 'VALID'}: - if rhs_dilation is None: - rhs_dilation = (1,) * (rhs.ndim - 2) - effective_k_size = map(lambda k, r: core.dilate_dim(k, r), k_sdims, rhs_dilation) - pads = [_conv_transpose_padding(k, s, padding) - for k,s in zip(effective_k_size, strides)] - else: - pads = padding + if rhs_dilation is None: + rhs_dilation = (1,) * (rhs.ndim - 2) + effective_k_size = map(lambda k, r: core.dilate_dim(k, r), k_sdims, rhs_dilation) + replicated_padding = [padding] * len(strides) if isinstance(padding, str) else padding + pads = [_conv_transpose_padding(k, s, p) + for k,s,p in zip(effective_k_size, strides, replicated_padding)] if transpose_kernel: # flip spatial dims and swap input / output channel axes rhs = _flip_axes(rhs, np.array(dn.rhs_spec)[2:]) diff --git a/tests/lax_test.py b/tests/lax_test.py index 19f3bda81c1e..19d6cbbe0177 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -887,6 +887,9 @@ def _conv_transpose_via_grad(data, kernel, strides, padding, for i in range(nspatial)] elif padding == 'SAME': o_sdims = [in_sdims[i]*strides[i] for i in range(nspatial)] + else: + o_sdims = [in_sdims[i]*strides[i] + max(e_k_sdims[i]-strides[i],0) - np.sum(p) + for i, p in enumerate(padding)] o_shape = [in_shape[0], k_shape[1]] + o_sdims out_spec_inv = [x[0] for x in 
sorted(enumerate(dn.out_spec), key=lambda x: x[1])] @@ -922,7 +925,9 @@ def _transpose_conv_kernel(data, kernel, dimension_numbers): ], dtype=lax_test_util.float_dtypes, strides=[(1, 1), (1, 2), (2, 1), (2, 2), (3, 3)], - padding=["VALID", "SAME"], + padding=list(itertools.product( + itertools.product([0,1,2], [0,1,2]), + itertools.product([0,1,2], [0,1,2]))) + ["VALID", "SAME"], dspec=[ ("NHWC", "HWIO", "NHWC"), ], @@ -962,7 +967,9 @@ def fun_via_grad(lhs, rhs): ], dtype=lax_test_util.float_dtypes, strides=[(1, 1), (1, 2), (2, 1), (2, 2), (3, 3)], - padding=["VALID", "SAME"], + padding=list(itertools.product( + itertools.product([0,1,2], [0,1,2]), + itertools.product([0,1,2], [0,1,2]))) + ["VALID", "SAME"], dspec=[ ("NHWC", "HWIO", "NHWC"), ],