diff --git a/autograd/scipy/signal.py b/autograd/scipy/signal.py
index 49b76194..61980bc9 100644
--- a/autograd/scipy/signal.py
+++ b/autograd/scipy/signal.py
@@ -9,14 +9,16 @@
 @primitive
 def convolve(A, B, axes=None, dot_axes=[(), ()], mode="full"):
-    assert mode in ["valid", "full"], f"Mode {mode} not yet implemented"
+    assert mode in ["valid", "full", "same"], (
+        f"Mode {mode} not supported; must be one of 'valid', 'full', or 'same'"
+    )
     if axes is None:
         axes = [list(range(A.ndim)), list(range(A.ndim))]
     wrong_order = any([B.shape[ax_B] < A.shape[ax_A] for ax_A, ax_B in zip(*axes)])
     if wrong_order:
         if mode == "valid" and not all([B.shape[ax_B] <= A.shape[ax_A] for ax_A, ax_B in zip(*axes)]):
             raise Exception("One array must be larger than the other along all convolved dimensions")
-        elif mode != "full" or B.size <= A.size:  # Tie breaker
+        elif mode == "valid" or (mode == "full" and B.size <= A.size):  # Tie breaker
             i1 = B.ndim - len(dot_axes[1]) - len(axes[1])  # B ignore
             i2 = i1 + A.ndim - len(dot_axes[0]) - len(axes[0])  # A ignore
             i3 = i2 + len(axes[0])
@@ -27,8 +29,12 @@ def convolve(A, B, axes=None, dot_axes=[(), ()], mode="full"):
                 ignore_A + ignore_B + conv
             )
-    if mode == "full":
-        B = pad_to_full(B, A, axes[::-1])
+    if mode == "same":
+        # "same" output matches A's shape, so swap operand roles before padding.
+        B, A = A, B
+        axes = axes[::-1]
+        dot_axes = dot_axes[::-1]
+    if mode != "valid":
+        B = pad(B, A, axes[::-1], mode=mode)
     B_view_shape = list(B.shape)
     B_view_strides = list(B.strides)
     flipped_idxs = [slice(None)] * A.ndim
@@ -40,7 +46,16 @@ def convolve(A, B, axes=None, dot_axes=[(), ()], mode="full"):
     B_view = as_strided(B, B_view_shape, B_view_strides)
     A_view = A[tuple(flipped_idxs)]
     all_axes = [list(axes[i]) + list(dot_axes[i]) for i in [0, 1]]
-    return einsum_tensordot(A_view, B_view, all_axes)
+    if mode == "same":
+        # Contract in the swapped order, then permute the result back to the
+        # (A ignore, B ignore, conv) axis layout the caller expects.
+        i1 = B.ndim - len(dot_axes[1]) - len(axes[1])  # B ignore
+        i2 = i1 + len(axes[0])
+        i3 = i2 + A.ndim - len(dot_axes[0]) - len(axes[0])  # A ignore
+        ignore_A = list(range(i1))
+        conv = list(range(i1, i2))
+        ignore_B = list(range(i2, i3))
+        return einsum_tensordot(B_view, A_view, all_axes[::-1]).transpose(ignore_A + ignore_B + conv)
+    else:
+        return einsum_tensordot(A_view, B_view, all_axes)


 def einsum_tensordot(A, B, axes, reverse=False):
@@ -54,10 +69,15 @@ def einsum_tensordot(A, B, axes, reverse=False):
     return npo.einsum(A, A_axnums, B, B_axnums)


-def pad_to_full(A, B, axes):
+def pad(A, B, axes, mode="full"):
     A_pad = [(0, 0)] * A.ndim
-    for ax_A, ax_B in zip(*axes):
-        A_pad[ax_A] = (B.shape[ax_B] - 1,) * 2
+    if mode == "full":
+        for ax_A, ax_B in zip(*axes):
+            A_pad[ax_A] = (B.shape[ax_B] - 1,) * 2
+    elif mode == "same":
+        # Split the total padding of B.shape - 1 with the extra sample on the
+        # left, so the "same" output is centred the way scipy centres it.
+        for ax_A, ax_B in zip(*axes):
+            right_bound = (B.shape[ax_B] - 1) // 2
+            A_pad[ax_A] = (B.shape[ax_B] - 1 - right_bound, right_bound)
     return npo.pad(A, A_pad, mode="constant")
@@ -124,7 +144,9 @@ def flipped_idxs(ndim, axes):

 def grad_convolve(argnum, ans, A, B, axes=None, dot_axes=[(), ()], mode="full"):
-    assert mode in ["valid", "full"], f"Grad for mode {mode} not yet implemented"
+    assert mode in ["valid", "full", "same"], (
+        f"Mode {mode} not supported; must be one of 'valid', 'full', or 'same'"
+    )
     axes, shapes = parse_axes(A.shape, B.shape, axes, dot_axes, mode)
     if argnum == 0:
         X, Y = A, B
@@ -139,22 +161,46 @@ def grad_convolve(argnum, ans, A, B, axes=None, dot_axes=[(), ()], mode="full"):
     if mode == "full":
         new_mode = "valid"
-    else:
+    elif mode == "valid":
         if any([x_size > y_size for x_size, y_size in zip(shapes[_X_]["conv"], shapes[_Y_]["conv"])]):
             new_mode = "full"
         else:
             new_mode = "valid"

-    def vjp(g):
-        result = convolve(
-            g,
-            Y[flipped_idxs(Y.ndim, axes[_Y_]["conv"])],
-            axes=[axes["out"]["conv"], axes[_Y_]["conv"]],
-            dot_axes=[axes["out"][ignore_Y], axes[_Y_]["ignore"]],
-            mode=new_mode,
-        )
-        new_order = npo.argsort(axes[_X_]["ignore"] + axes[_X_]["dot"] + axes[_X_]["conv"])
-        return np.transpose(result, new_order)
+    if mode == "same":
+
+        def vjp(g):
+            # Pad the incoming cotangent back up to "full" size, inverting the
+            # centring used in pad(), then convolve in "valid" mode.
+            g_pad = [(0, 0)] * g.ndim
+            for ax, ax_B in zip(axes["out"]["conv"], axes["B"]["conv"]):
+                left_bound = (B.shape[ax_B] - 1) // 2
+                g_pad[ax] = (left_bound, B.shape[ax_B] - 1 - left_bound)
+            g = np.pad(g, g_pad, mode="constant")
+            result = convolve(
+                g,
+                Y[flipped_idxs(Y.ndim, axes[_Y_]["conv"])],
+                axes=[axes["out"]["conv"], axes[_Y_]["conv"]],
+                dot_axes=[axes["out"][ignore_Y], axes[_Y_]["ignore"]],
+                mode="valid",
+            )
+            new_order = npo.argsort(axes[_X_]["ignore"] + axes[_X_]["dot"] + axes[_X_]["conv"])
+            return np.transpose(result, new_order)
+    else:
+
+        def vjp(g):
+            result = convolve(
+                g,
+                Y[flipped_idxs(Y.ndim, axes[_Y_]["conv"])],
+                axes=[axes["out"]["conv"], axes[_Y_]["conv"]],
+                dot_axes=[axes["out"][ignore_Y], axes[_Y_]["ignore"]],
+                mode=new_mode,
+            )
+            new_order = npo.argsort(axes[_X_]["ignore"] + axes[_X_]["dot"] + axes[_X_]["conv"])
+            return np.transpose(result, new_order)

     return vjp
diff --git a/tests/test_scipy.py b/tests/test_scipy.py
index ba563eff..b636b055 100644
--- a/tests/test_scipy.py
+++ b/tests/test_scipy.py
@@ -301,12 +301,12 @@ def test_convolve_generalization():

 def test_convolve():
     combo_check(autograd.scipy.signal.convolve, [0, 1])(
-        [R(4), R(5), R(6)], [R(2), R(3), R(4)], mode=["full", "valid"]
+        [R(4), R(5), R(6)], [R(2), R(3), R(4)], mode=["full", "valid", "same"]
     )


 def test_convolve_2d():
     combo_check(autograd.scipy.signal.convolve, [0, 1])(
-        [R(4, 3), R(5, 4), R(6, 7)], [R(2, 2), R(3, 2), R(4, 2), R(4, 1)], mode=["full", "valid"]
+        [R(4, 3), R(5, 4), R(6, 7)], [R(2, 2), R(3, 2), R(4, 2), R(4, 1)], mode=["full", "valid", "same"]
     )

@@ -314,7 +314,7 @@ def test_convolve_ignore():
     combo_check(autograd.scipy.signal.convolve, [0, 1])(
         [R(4, 3)],
         [R(3, 2)],
         axes=[([0], [0]), ([1], [1]), ([0], [1]), ([1], [0]), ([0, 1], [0, 1]), ([1, 0], [1, 0])],
-        mode=["full", "valid"],
+        mode=["full", "valid", "same"],
     )

@@ -323,7 +323,7 @@ def test_convolve_ignore_dot():
     combo_check(autograd.scipy.signal.convolve, [0, 1])(
         [R(3, 2, 3)],
         [R(3, 2, 3)],
         axes=[([1], [1])],
         dot_axes=[([0], [2]), ([0], [0])],
-        mode=["full", "valid"],
+        mode=["full", "valid", "same"],
     )

 ### Special ###
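
Review note: a quick way to exercise the new mode end to end, assuming this
patch is applied to a local autograd checkout. The array sizes are arbitrary;
check_grads is autograd's numerical gradient checker from autograd.test_util,
and the forward-pass comparison assumes the pad() split above reproduces
scipy's centring for "same" output. A sketch, not part of the patch:

    import numpy as npo  # raw numpy for test data, matching the module's alias
    import scipy.signal

    from autograd.scipy.signal import convolve
    from autograd.test_util import check_grads

    A = npo.random.randn(6)  # input
    B = npo.random.randn(3)  # kernel

    # Forward pass: the output shape follows the first argument, and the
    # values should match scipy's centred "same" output.
    out = convolve(A, B, mode="same")
    assert out.shape == A.shape
    assert npo.allclose(out, scipy.signal.convolve(A, B, mode="same"))

    # Reverse-mode checks of the new vjp, one argument at a time.
    check_grads(lambda x: convolve(x, B, mode="same"), modes=["rev"], order=1)(A)
    check_grads(lambda y: convolve(A, y, mode="same"), modes=["rev"], order=1)(B)

The "same"-mode vjp works because padding the cotangent by
(left_bound, B.shape - 1 - left_bound) is the adjoint of cropping the "full"
result down to "same" size, which reduces the gradient to the existing
valid-mode machinery.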