Merged
10 changes: 7 additions & 3 deletions autograd/numpy/numpy_jvps.py
@@ -1,6 +1,7 @@
 from . import numpy_wrapper as anp
 from .numpy_vjps import (untake, balanced_eq, match_complex, replace_zero,
-                         dot_0_adjoint, dot_1_adjoint)
+                         dot_adjoint_0, dot_adjoint_1, tensordot_adjoint_0,
+                         tensordot_adjoint_1)
 from autograd.core import (defjvp, defjvps, def_linear_wrt_arg, defjvp_argnum,
                            def_multilinear, vspace)
 from ..util import func
@@ -185,8 +186,11 @@ def fwd_grad_chooser(g, ans, gvs, vs, x, axis=None, keepdims=False):
 def_multilinear(anp.tensordot)
 def_multilinear(anp.outer)
 
-def_multilinear(dot_0_adjoint)
-def_multilinear(dot_1_adjoint)
+def_multilinear(dot_adjoint_0)
+def_multilinear(dot_adjoint_1)
+
+def_multilinear(tensordot_adjoint_0)
+def_multilinear(tensordot_adjoint_1)
 
 def fwd_grad_concatenate_args(argnum, g, ans, gvs, vs, *axis_args, **kwargs):
     result = []
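Note: registering these adjoints with def_multilinear is justified because np.tensordot (and hence each adjoint) is linear in every argument separately, so the forward-mode derivative with respect to one argument is just the operator applied to that argument's tangent. A minimal standalone sketch of that fact in plain NumPy (illustrative only, not autograd's API):

import numpy as np

rng = np.random.RandomState(0)
A, dA = rng.randn(2, 3, 4), rng.randn(2, 3, 4)   # primal input and tangent
B = rng.randn(3, 4, 5)

# Linearity in A: perturbing A by dA changes the output by tensordot(dA, B).
lhs = np.tensordot(A + dA, B, axes=2) - np.tensordot(A, B, axes=2)
rhs = np.tensordot(dA, B, axes=2)
assert np.allclose(lhs, rhs)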
141 changes: 86 additions & 55 deletions autograd/numpy/numpy_vjps.py
@@ -280,21 +280,27 @@ def grad_inner(argnum, ans, vs, gvs, A, B):
         axes = ([], [])
     else:
         axes = ([A.ndim - 1], [B.ndim - 1])
-    return grad_tensordot(argnum, ans, vs, gvs, A, B, axes=axes)
+    if argnum == 0:
+        return lambda G: tensordot_adjoint_0(B, G, axes, vs)
+    elif argnum == 1:
+        return lambda G: tensordot_adjoint_1(A, G, axes, vs)
 defvjps(anp.inner, grad_inner, [0, 1])
 
 def grad_matmul(argnum, ans, vs, gvs, A, B):
     if anp.ndim(A) == 0 or anp.ndim(B) == 0:
         raise ValueError("Scalar operands are not allowed, use '*' instead")
     elif anp.ndim(A) == 1 or anp.ndim(B) == 1 or (anp.ndim(A) == 2 and anp.ndim(B) == 2):
         axes = ([A.ndim - 1], [max(0, B.ndim - 2)])
-        return grad_tensordot(argnum, ans, vs, gvs, A, B, axes=axes)
+        if argnum == 0:
+            return lambda G: tensordot_adjoint_0(B, G, axes, vs)
+        elif argnum == 1:
+            return lambda G: tensordot_adjoint_1(A, G, axes, vs)
     else:
         return grad_einsum(argnum + 1, ans, vs, gvs, ("...ij,...jk->...ik", A, B), None)
 defvjps(anp.matmul, grad_matmul, [0, 1])
 
 @primitive
-def dot_0_adjoint(B, G, A_vs):
+def dot_adjoint_0(B, G, A_vs):
     # The adjoint of the operator
     # A |--> np.dot(A, B)
     A_ndim, B_ndim = A_vs.ndim, onp.ndim(B)
@@ -305,7 +311,7 @@ def dot_0_adjoint(B, G, A_vs):
     return onp.tensordot(G, onp.swapaxes(B, -1, -2), B_ndim - 1)
 
 @primitive
-def dot_1_adjoint(A, G, B_vs):
+def dot_adjoint_1(A, G, B_vs):
     # The adjoint of the operator
     # B |--> np.dot(A, B)
     A_ndim, B_ndim = onp.ndim(A), B_vs.ndim
@@ -318,60 +324,85 @@ def dot_1_adjoint(A, G, B_vs):
     return swap(onp.tensordot(
         G, A, [range(-A_ndim - B_ndim + 2, -B_ndim + 1), range(A_ndim - 1)]))
 
-defvjp(anp.dot, lambda ans, vs, gvs, A, B: lambda g: dot_0_adjoint(B, g, vs))
-defvjp(anp.dot, lambda ans, vs, gvs, A, B: lambda g: dot_1_adjoint(A, g, vs), 1)
+defvjp(anp.dot, lambda ans, vs, gvs, A, B: lambda g: dot_adjoint_0(B, g, vs))
+defvjp(anp.dot, lambda ans, vs, gvs, A, B: lambda g: dot_adjoint_1(A, g, vs), 1)
 
-defvjp(dot_0_adjoint, lambda ans, vs, gvs, B, g, A_vs: lambda A: dot_1_adjoint(A, g, vs))
-defvjp(dot_0_adjoint, lambda ans, vs, gvs, B, g, *args: lambda A: anp.dot(A, B), 1)
+defvjp(dot_adjoint_0, lambda ans, vs, gvs, B, g, A_vs: lambda A: dot_adjoint_1(A, g, vs))
+defvjp(dot_adjoint_0, lambda ans, vs, gvs, B, g, *args: lambda A: anp.dot(A, B), 1)
 
-defvjp(dot_1_adjoint, lambda ans, vs, gvs, A, g, B_vs: lambda B: dot_0_adjoint(B, g, vs))
-defvjp(dot_1_adjoint, lambda ans, vs, gvs, A, g, *args: lambda B: anp.dot(A, B), 1)
+defvjp(dot_adjoint_1, lambda ans, vs, gvs, A, g, B_vs: lambda B: dot_adjoint_0(B, g, vs))
+defvjp(dot_adjoint_1, lambda ans, vs, gvs, A, g, *args: lambda B: anp.dot(A, B), 1)
 
-def grad_tensordot(argnum, ans, vs, gvs, A, B, axes=2):
-    def vjp(g):
-        axes_ = axes
-        if anp.size(A) == anp.size(B) == 0:
-            return g * B if argnum == 0 else g * A
-
-        A_ndim = anp.ndim(A)
-        g_axes = onp.arange(anp.ndim(g))
-        if type(axes_) is int:
-            axes_ = max(axes_, 0)
-            if argnum == 0:
-                B_axes = onp.arange(anp.ndim(B))
-                return anp.tensordot(g, B, [g_axes[A_ndim-axes_:], B_axes[axes_:]])
-            else:
-                A_axes = onp.arange(A_ndim)
-                return anp.tensordot(A, g, [A_axes[:A_ndim-axes_], g_axes[:A_ndim-axes_]])
-        elif type(axes_[0]) is int:
-            B_ndim = anp.ndim(B)
-            axes_ = [axes_[0] % A_ndim, axes_[1] % B_ndim]
-            if argnum == 0:
-                B_axes = onp.arange(B_ndim)
-                return anp.tensordot(g, B, [g_axes[A_ndim-1:], onp.delete(B_axes, axes_[1])])
-            else:
-                A_axes = onp.arange(A_ndim)
-                return anp.tensordot(A, g, [onp.delete(A_axes, axes_[0]), g_axes[:A_ndim-1]])
-        else:
-            B_ndim = anp.ndim(B)
-            A_axes = onp.arange(A_ndim)
-            B_axes = onp.arange(B_ndim)
-            summed_axes = [onp.asarray(axes_[0]) % A_ndim,
-                           onp.asarray(axes_[1]) % B_ndim]
-            other_axes = [onp.delete(A_axes, summed_axes[0]),
-                          onp.delete(B_axes, summed_axes[1])]
-            if argnum == 0:
-                out = anp.tensordot(g, B, [g_axes[len(other_axes[0]):], other_axes[1]])
-                perm = onp.argsort(onp.concatenate(
-                    (other_axes[0], summed_axes[0][onp.argsort(summed_axes[1])])))
-                return anp.transpose(out, perm)
-            else:
-                out = anp.tensordot(A, g, [other_axes[0], g_axes[:len(other_axes[0])]])
-                perm = onp.argsort(onp.concatenate(
-                    (summed_axes[1][onp.argsort(summed_axes[0])], other_axes[1])))
-                return anp.transpose(out, perm)
-    return vjp
-defvjps(anp.tensordot, grad_tensordot, [0, 1])
+@primitive
+def tensordot_adjoint_0(B, G, axes, A_vs):
+    # The adjoint of the operator
+    # A |--> np.tensordot(A, B, axes)
+    if onp.ndim(B) == 0:
+        return G * B
+
+    A_ndim = A_vs.ndim
+    G_axes = onp.arange(onp.ndim(G))
+    if type(axes) is int:
+        axes = max(axes, 0)
+        B_axes = onp.arange(onp.ndim(B))
+        return onp.tensordot(G, B, [G_axes[A_ndim-axes:], B_axes[axes:]])
+    elif type(axes[0]) is int:
+        B_ndim = onp.ndim(B)
+        axes = [axes[0] % A_ndim, axes[1] % B_ndim]
+        B_axes = onp.arange(B_ndim)
+        return onp.tensordot(G, B, [G_axes[A_ndim-1:], onp.delete(B_axes, axes[1])])
+    else:
+        B_ndim = onp.ndim(B)
+        A_axes = onp.arange(A_ndim)
+        B_axes = onp.arange(B_ndim)
+        summed_axes = [onp.asarray(axes[0]) % A_ndim,
+                       onp.asarray(axes[1]) % B_ndim]
+        other_axes = [onp.delete(A_axes, summed_axes[0]),
+                      onp.delete(B_axes, summed_axes[1])]
+        out = onp.tensordot(G, B, [G_axes[len(other_axes[0]):], other_axes[1]])
+        perm = onp.argsort(onp.concatenate(
+            (other_axes[0], summed_axes[0][onp.argsort(summed_axes[1])])))
+        return onp.transpose(out, perm)
+
+@primitive
+def tensordot_adjoint_1(A, G, axes, B_vs):
+    # The adjoint of the operator
+    # B |--> np.tensordot(A, B, axes)
+    if onp.ndim(A) == 0:
+        return G * A
+
+    A_ndim = onp.ndim(A)
+    G_axes = onp.arange(onp.ndim(G))
+    if type(axes) is int:
+        axes = max(axes, 0)
+        A_axes = onp.arange(A_ndim)
+        return onp.tensordot(A, G, [A_axes[:A_ndim-axes], G_axes[:A_ndim-axes]])
+    elif type(axes[0]) is int:
+        B_ndim = B_vs.ndim
+        axes = [axes[0] % A_ndim, axes[1] % B_ndim]
+        A_axes = onp.arange(A_ndim)
+        return onp.tensordot(A, G, [onp.delete(A_axes, axes[0]), G_axes[:A_ndim-1]])
+    else:
+        B_ndim = B_vs.ndim
+        A_axes = onp.arange(A_ndim)
+        B_axes = onp.arange(B_ndim)
+        summed_axes = [onp.asarray(axes[0]) % A_ndim,
+                       onp.asarray(axes[1]) % B_ndim]
+        other_axes = [onp.delete(A_axes, summed_axes[0]),
+                      onp.delete(B_axes, summed_axes[1])]
+        out = onp.tensordot(A, G, [other_axes[0], G_axes[:len(other_axes[0])]])
+        perm = onp.argsort(onp.concatenate(
+            (summed_axes[1][onp.argsort(summed_axes[0])], other_axes[1])))
+        return onp.transpose(out, perm)
+
+defvjp(anp.tensordot, lambda ans, vs, gvs, A, B, axes=2: lambda G: tensordot_adjoint_0(B, G, axes, vs))
+defvjp(anp.tensordot, lambda ans, vs, gvs, A, B, axes=2: lambda G: tensordot_adjoint_1(A, G, axes, vs), 1)
+
+defvjp(tensordot_adjoint_0, lambda ans, vs, gvs, B, G, axes, A_vs: lambda A: tensordot_adjoint_1(A, G, axes, vs))
+defvjp(tensordot_adjoint_0, lambda ans, vs, gvs, B, G, axes, A_vs: lambda A: anp.tensordot(A, B, axes), 1)
+
+defvjp(tensordot_adjoint_1, lambda ans, vs, gvs, A, G, axes, B_vs: lambda B: tensordot_adjoint_0(B, G, axes, vs))
+defvjp(tensordot_adjoint_1, lambda ans, vs, gvs, A, G, axes, B_vs: lambda B: anp.tensordot(A, B, axes), 1)
 
 defvjp(anp.outer, lambda ans, vs, gvs, a, b : lambda g: anp.dot(g, b.T))
 defvjp(anp.outer, lambda ans, vs, gvs, a, b : lambda g: anp.dot(a.T, g), argnum=1)
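Note: the new tensordot_adjoint_0/tensordot_adjoint_1 primitives take a vector space (A_vs/B_vs) rather than the array being differentiated, so each adjoint needs only the other operand, the output cotangent G, the axes spec, and the differentiated argument's ndim. The property they implement is the adjoint identity <G, tensordot(A, B, axes)> = <tensordot_adjoint_0(B, G, axes, A_vs), A>. A quick standalone check of the integer-axes branch in plain NumPy (a sketch mirroring the code above, not autograd's API):

import numpy as np

rng = np.random.RandomState(0)
A = rng.randn(2, 3, 4)
B = rng.randn(3, 4, 5)
n = 2                    # integer `axes`, as in np.tensordot(A, B, 2)
G = rng.randn(2, 5)      # cotangent with the output's shape

# Mirrors the `type(axes) is int` branch of tensordot_adjoint_0.
G_axes, B_axes = np.arange(G.ndim), np.arange(B.ndim)
A_bar = np.tensordot(G, B, [G_axes[A.ndim - n:], B_axes[n:]])

# Adjoint identity: <G, tensordot(A, B, n)> == <A_bar, A>.
assert np.allclose(np.sum(G * np.tensordot(A, B, n)), np.sum(A_bar * A))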
43 changes: 40 additions & 3 deletions benchmarks/bench_numpy_vjps.py
@@ -18,9 +18,6 @@
 B = npr.randn(2, 3, 5, 4)
 g = npr.randn(2, 3, 4, 2, 3, 4)
 
-def time_dot():
-    np.dot(A, B)
-
 def time_dot_0():
     dot_0(A, B, g)
 
@@ -44,3 +41,43 @@ def time_dot_1_1():
 
 def time_dot_1_2():
     dot_1_2(A, B, g)
+
+tensordot_0 = lambda A, B, G: make_vjp(np.tensordot)(A, B, 2)[0](G)
+tensordot_1 = lambda A, B, G: make_vjp(np.tensordot, argnum=1)(A, B, 2)[0](G)
+
+tensordot_0_0 = lambda A, B, G: make_vjp(tensordot_0)(A, B, G)[0](A)
+tensordot_0_1 = lambda A, B, G: make_vjp(tensordot_0, argnum=1)(A, B, G)[0](A)
+tensordot_0_2 = lambda A, B, G: make_vjp(tensordot_0, argnum=2)(A, B, G)[0](A)
+
+tensordot_1_0 = lambda A, B, G: make_vjp(tensordot_1)(A, B, G)[0](B)
+tensordot_1_1 = lambda A, B, G: make_vjp(tensordot_1, argnum=1)(A, B, G)[0](B)
+tensordot_1_2 = lambda A, B, G: make_vjp(tensordot_1, argnum=2)(A, B, G)[0](B)
+
+A = npr.randn(2, 3, 5, 4)
+B = npr.randn(5, 4, 2, 3)
+G = npr.randn(2, 3, 2, 3)
+
+def time_tensordot_0():
+    tensordot_0(A, B, G)
+
+def time_tensordot_1():
+    tensordot_1(A, B, G)
+
+def time_tensordot_0_0():
+    tensordot_0_0(A, B, G)
+
+def time_tensordot_0_1():
+    tensordot_0_1(A, B, G)
+
+def time_tensordot_0_2():
+    tensordot_0_2(A, B, G)
+
+def time_tensordot_1_0():
+    tensordot_1_0(A, B, G)
+
+def time_tensordot_1_1():
+    tensordot_1_1(A, B, G)
+
+def time_tensordot_1_2():
+    tensordot_1_2(A, B, G)
+
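For context on these helpers: make_vjp(f, argnum)(*args) returns a pair whose first element is the VJP function for the chosen argument, so tensordot_0(A, B, G) evaluates the VJP of np.tensordot with respect to A, and the tensordot_0_* / tensordot_1_* helpers time second-order (VJP-of-VJP) evaluations. A hedged usage sketch (import paths assumed to match this module's header, which is not shown in the diff):

from autograd import make_vjp
import autograd.numpy as np
import numpy.random as npr

A = npr.randn(2, 3, 5, 4)
B = npr.randn(5, 4, 2, 3)
G = npr.randn(2, 3, 2, 3)   # shape of np.tensordot(A, B, 2)

vjp_A, ans = make_vjp(np.tensordot)(A, B, 2)   # VJP w.r.t. argument 0
A_bar = vjp_A(G)
assert A_bar.shape == A.shape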
14 changes: 7 additions & 7 deletions tests/test_systematic.py
@@ -111,25 +111,25 @@ def test_matmul(): combo_check(np.matmul, [0, 1])(
                                   [R(3), R(2, 3), R(2, 2, 3)],
                                   [R(3), R(3, 4), R(2, 3, 4)])
 def test_matmul_broadcast(): combo_check(np.matmul, [0, 1])([R(1, 2, 2)], [R(3, 2, 1)])
-def test_tensordot_1(): combo_check(np.tensordot, [0, 1])(
+def test_tensordot_1(): combo_check(np.tensordot, [0, 1], order=3)(
                             [R(1, 3), R(2, 3, 2)],
                             [R(3), R(3, 1), R(3, 4, 2)],
                             axes=[ [(1,), (0,)] ])
-def test_tensordot_2(): combo_check(np.tensordot, [0, 1])(
+def test_tensordot_2(): combo_check(np.tensordot, [0, 1], order=3)(
                             [R(3), R(3, 1), R(3, 4, 2)],
                             [R(1, 3), R(2, 3, 2)],
                             axes=[ [(0,), (1,)] ])
-def test_tensordot_3(): combo_check(np.tensordot, [0, 1])(
+def test_tensordot_3(): combo_check(np.tensordot, [0, 1], order=3)(
                             [R(2, 3), R(2, 3, 4)],
                             [R(1, 2, 3), R(2, 2, 3, 4)],
                             axes=[ [(0, 1), (1, 2)] , [(1, 0), (2, 1)] ])
-def test_tensordot_4(): combo_check(np.tensordot, [0, 1])(
+def test_tensordot_4(): combo_check(np.tensordot, [0, 1], order=3)(
                             [R(2, 2), R(4, 2, 2)],
                             [R(2, 2), R(2, 2, 4)],
                             axes=[1, 2])
-def test_tensordot_5(): combo_check(np.tensordot, [0, 1])([R(4)], [R()], axes=[0])
-def test_tensordot_6(): combo_check(np.tensordot, [0, 1])([R(2,6)], [R(6,3)], axes=[[[-1], [0]]])
-def test_tensordot_7(): combo_check(np.tensordot, [0, 1])([R(2,6)], [R(6,3)], axes=[[-1, 0]])
+def test_tensordot_5(): combo_check(np.tensordot, [0, 1], order=3)([R(4)], [R()], axes=[0])
+def test_tensordot_6(): combo_check(np.tensordot, [0, 1], order=3)([R(2,6)], [R(6,3)], axes=[[[-1], [0]]])
+def test_tensordot_7(): combo_check(np.tensordot, [0, 1], order=3)([R(2,6)], [R(6,3)], axes=[[-1, 0]])
 
 # Need custom tests because gradient is undefined when arguments are identical.
 def test_maximum(): combo_check(np.maximum, [0, 1])(
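Note: order=3 presumably asks combo_check to verify derivatives up to third order, which exercises the VJPs registered above for tensordot_adjoint_0/1 themselves. A standalone sketch of the same idea using nested grad (illustrative; this is not combo_check's implementation, and f is a made-up test function):

import autograd.numpy as np
import numpy.random as npr
from autograd import grad

B = npr.randn(3, 4, 2)

def f(A):
    # Scalar-valued so grad applies; cubing keeps higher derivatives nonzero.
    return np.sum(np.tensordot(A, B, axes=[(1,), (0,)]) ** 3)

# Each nesting differentiates through the previous gradient, so a
# third-order check also hits the gradients of the adjoint primitives.
g1 = grad(f)
g2 = grad(lambda A: np.sum(g1(A)))
g3 = grad(lambda A: np.sum(g2(A)))

A = npr.randn(1, 3)
print(g3(A).shape)   # (1, 3), same shape as A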