86 changes: 66 additions & 20 deletions autograd/scipy/signal.py
@@ -9,14 +9,16 @@
 @primitive
 def convolve(A, B, axes=None, dot_axes=[(), ()], mode="full"):
-    assert mode in ["valid", "full"], f"Mode {mode} not yet implemented"
+    assert mode in ["valid", "full", "same"], (
+        f"Mode {mode} undefined, it can be one of 'valid', 'full', and 'same'"
+    )
     if axes is None:
         axes = [list(range(A.ndim)), list(range(A.ndim))]
     wrong_order = any([B.shape[ax_B] < A.shape[ax_A] for ax_A, ax_B in zip(*axes)])
     if wrong_order:
         if mode == "valid" and not all([B.shape[ax_B] <= A.shape[ax_A] for ax_A, ax_B in zip(*axes)]):
             raise Exception("One array must be larger than the other along all convolved dimensions")
-        elif mode != "full" or B.size <= A.size:  # Tie breaker
+        elif mode == "valid" or (mode == "full" and B.size <= A.size):  # Tie breaker
             i1 = B.ndim - len(dot_axes[1]) - len(axes[1])  # B ignore
             i2 = i1 + A.ndim - len(dot_axes[0]) - len(axes[0])  # A ignore
             i3 = i2 + len(axes[0])
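Note: the rewritten tie-breaker still relies on full-mode convolution being commutative, so the implementation is free to swap operands and transpose afterwards. A minimal plain-NumPy illustration (not part of this patch):

import numpy as np

A, B = np.random.randn(6), np.random.randn(3)

# Swapping the operands of a full convolution leaves the result unchanged,
# which is what makes the swap-and-transpose dispatch above safe.
assert np.allclose(np.convolve(A, B, mode="full"), np.convolve(B, A, mode="full"))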
@@ -27,8 +29,12 @@ def convolve(A, B, axes=None, dot_axes=[(), ()], mode="full"):
                 ignore_A + ignore_B + conv
             )

-    if mode == "full":
-        B = pad_to_full(B, A, axes[::-1])
+    if mode == "same":
+        B, A = A, B
+        axes = axes[::-1]
+        dot_axes = dot_axes[::-1]
+    if mode != "valid":
+        B = pad(B, A, axes[::-1], mode=mode)
     B_view_shape = list(B.shape)
     B_view_strides = list(B.strides)
     flipped_idxs = [slice(None)] * A.ndim
@@ -40,7 +46,16 @@ def convolve(A, B, axes=None, dot_axes=[(), ()], mode="full"):
     B_view = as_strided(B, B_view_shape, B_view_strides)
     A_view = A[tuple(flipped_idxs)]
     all_axes = [list(axes[i]) + list(dot_axes[i]) for i in [0, 1]]
-    return einsum_tensordot(A_view, B_view, all_axes)
+    if mode == "same":
+        i1 = B.ndim - len(dot_axes[1]) - len(axes[1])  # B ignore
+        i2 = i1 + len(axes[0])
+        i3 = i2 + A.ndim - len(dot_axes[0]) - len(axes[0])  # A ignore
+        ignore_A = list(range(i1))
+        conv = list(range(i1, i2))
+        ignore_B = list(range(i2, i3))
+        return einsum_tensordot(B_view, A_view, all_axes[::-1]).transpose(ignore_A + ignore_B + conv)
+    else:
+        return einsum_tensordot(A_view, B_view, all_axes)


 def einsum_tensordot(A, B, axes, reverse=False):
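Note: the "same" branch above produces an output sized like the first input by padding and then reusing the existing valid-mode machinery. The identity it relies on, checked here against SciPy's reference implementation (plain SciPy, independent of this patch): the "same" output is the centered crop of the "full" output.

import numpy as np
from scipy import signal

A, B = np.random.randn(6), np.random.randn(3)

full = signal.convolve(A, B, mode="full")  # length 6 + 3 - 1 = 8
same = signal.convolve(A, B, mode="same")  # length 6, sized like A

start = (len(B) - 1) // 2                  # offset of the centered crop
assert np.allclose(same, full[start : start + len(A)])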
@@ -54,10 +69,15 @@ def einsum_tensordot(A, B, axes, reverse=False):
     return npo.einsum(A, A_axnums, B, B_axnums)


-def pad_to_full(A, B, axes):
+def pad(A, B, axes, mode="full"):
     A_pad = [(0, 0)] * A.ndim
-    for ax_A, ax_B in zip(*axes):
-        A_pad[ax_A] = (B.shape[ax_B] - 1,) * 2
+    if mode == "full":
+        for ax_A, ax_B in zip(*axes):
+            A_pad[ax_A] = (B.shape[ax_B] - 1,) * 2
+    elif mode == "same":
+        for ax_A, ax_B in zip(*axes):
+            right_bound = (B.shape[ax_B] - 1) // 2
+            A_pad[ax_A] = (B.shape[ax_B] - 1 - right_bound, right_bound)
     return npo.pad(A, A_pad, mode="constant")


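Note: for a kernel of width w along a convolved axis, the "same" branch splits the total padding of w - 1 so the extra zero lands on the left when w is even. A hedged sketch of that arithmetic (the helper name same_pad is illustrative, not from this patch):

import numpy as np

def same_pad(w):
    right = (w - 1) // 2          # mirrors right_bound above
    return (w - 1 - right, right)

assert same_pad(3) == (1, 1)      # odd kernel: symmetric padding
assert same_pad(4) == (2, 1)      # even kernel: extra zero on the left

# Padding the signal this way turns a "valid" convolution into "same":
A, B = np.random.randn(6), np.random.randn(4)
assert np.allclose(np.convolve(np.pad(A, same_pad(len(B))), B, mode="valid"),
                   np.convolve(A, B, mode="same"))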
@@ -124,7 +144,9 @@ def flipped_idxs(ndim, axes):


 def grad_convolve(argnum, ans, A, B, axes=None, dot_axes=[(), ()], mode="full"):
-    assert mode in ["valid", "full"], f"Grad for mode {mode} not yet implemented"
+    assert mode in ["valid", "full", "same"], (
+        f"Mode {mode} undefined, it can be one of 'valid', 'full', and 'same'"
+    )
     axes, shapes = parse_axes(A.shape, B.shape, axes, dot_axes, mode)
     if argnum == 0:
         X, Y = A, B
@@ -139,22 +161,46 @@ def grad_convolve(argnum, ans, A, B, axes=None, dot_axes=[(), ()], mode="full"):

     if mode == "full":
         new_mode = "valid"
-    else:
+    elif mode == "valid":
         if any([x_size > y_size for x_size, y_size in zip(shapes[_X_]["conv"], shapes[_Y_]["conv"])]):
             new_mode = "full"
         else:
             new_mode = "valid"

-    def vjp(g):
-        result = convolve(
-            g,
-            Y[flipped_idxs(Y.ndim, axes[_Y_]["conv"])],
-            axes=[axes["out"]["conv"], axes[_Y_]["conv"]],
-            dot_axes=[axes["out"][ignore_Y], axes[_Y_]["ignore"]],
-            mode=new_mode,
-        )
-        new_order = npo.argsort(axes[_X_]["ignore"] + axes[_X_]["dot"] + axes[_X_]["conv"])
-        return np.transpose(result, new_order)
+    if mode == "same":
+
+        def vjp(g):
+            g_pad = [(0, 0)] * g.ndim
+            for ax, ax_B in zip(axes["out"]["conv"], axes["B"]["conv"]):
+                left_bound = (B.shape[ax_B] - 1) // 2
+                g_pad[ax] = (
+                    left_bound,
+                    B.shape[ax_B] - 1 - left_bound,
+                )
+            g = np.pad(g, g_pad, mode="constant")
+
+            result = convolve(
+                g,
+                Y[flipped_idxs(Y.ndim, axes[_Y_]["conv"])],
+                axes=[axes["out"]["conv"], axes[_Y_]["conv"]],
+                dot_axes=[axes["out"][ignore_Y], axes[_Y_]["ignore"]],
+                mode="valid",
+            )
+
+            new_order = np.argsort(axes[_X_]["ignore"] + axes[_X_]["dot"] + axes[_X_]["conv"])
+            return np.transpose(result, new_order)
+    else:
+
+        def vjp(g):
+            result = convolve(
+                g,
+                Y[flipped_idxs(Y.ndim, axes[_Y_]["conv"])],
+                axes=[axes["out"]["conv"], axes[_Y_]["conv"]],
+                dot_axes=[axes["out"][ignore_Y], axes[_Y_]["ignore"]],
+                mode=new_mode,
+            )
+            new_order = npo.argsort(axes[_X_]["ignore"] + axes[_X_]["dot"] + axes[_X_]["conv"])
+            return np.transpose(result, new_order)

     return vjp

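Note: the new "same" VJP pads the incoming cotangent with the mirrored split (left_bound on the left) and then runs a valid-mode convolution. A plain-NumPy, 1-D sanity check of that identity against the explicit Jacobian (not part of the patch):

import numpy as np

A, B = np.random.randn(6), np.random.randn(3)
g = np.random.randn(6)              # cotangent, shaped like the "same" output

left = (len(B) - 1) // 2            # matches left_bound above
g_padded = np.pad(g, (left, len(B) - 1 - left))
vjp_A = np.convolve(g_padded, B[::-1], mode="valid")

# Explicit Jacobian of y = convolve(A, B, mode="same") with respect to A.
J = np.zeros((6, 6))
for i in range(6):
    e = np.zeros(6)
    e[i] = 1.0
    J[:, i] = np.convolve(e, B, mode="same")

assert np.allclose(vjp_A, g @ J)    # g @ J is the vector-Jacobian product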
8 changes: 4 additions & 4 deletions tests/test_scipy.py
@@ -301,20 +301,20 @@ def test_convolve_generalization():

 def test_convolve():
     combo_check(autograd.scipy.signal.convolve, [0, 1])(
-        [R(4), R(5), R(6)], [R(2), R(3), R(4)], mode=["full", "valid"]
+        [R(4), R(5), R(6)], [R(2), R(3), R(4)], mode=["full", "valid", "same"]
     )

 def test_convolve_2d():
     combo_check(autograd.scipy.signal.convolve, [0, 1])(
-        [R(4, 3), R(5, 4), R(6, 7)], [R(2, 2), R(3, 2), R(4, 2), R(4, 1)], mode=["full", "valid"]
+        [R(4, 3), R(5, 4), R(6, 7)], [R(2, 2), R(3, 2), R(4, 2), R(4, 1)], mode=["full", "valid", "same"]
     )

 def test_convolve_ignore():
     combo_check(autograd.scipy.signal.convolve, [0, 1])(
         [R(4, 3)],
         [R(3, 2)],
         axes=[([0], [0]), ([1], [1]), ([0], [1]), ([1], [0]), ([0, 1], [0, 1]), ([1, 0], [1, 0])],
-        mode=["full", "valid"],
+        mode=["full", "valid", "same"],
     )

 def test_convolve_ignore_dot():
@@ -323,7 +323,7 @@ def test_convolve_ignore_dot():
         [R(3, 2, 3)],
         axes=[([1], [1])],
         dot_axes=[([0], [2]), ([0], [0])],
-        mode=["full", "valid"],
+        mode=["full", "valid", "same"],
     )

 ### Special ###
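Note: beyond combo_check, the new mode can also be exercised directly with autograd's gradient checker; a usage sketch (assumes this patch is applied):

import autograd.numpy as np
from autograd.scipy.signal import convolve
from autograd.test_util import check_grads

A, B = np.random.randn(6), np.random.randn(3)

# Reverse-mode gradient checks of same-mode convolution w.r.t. each argument.
check_grads(lambda x: convolve(x, B, mode="same"), modes=["rev"], order=2)(A)
check_grads(lambda x: convolve(A, x, mode="same"), modes=["rev"], order=2)(B)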