From 1e9bb34f0eac9409be6bae0a95c0b225c3b41252 Mon Sep 17 00:00:00 2001
From: Jamie Townsend
Date: Wed, 6 Sep 2017 11:08:56 +0100
Subject: [PATCH 1/3] Adjoint primitives for tensordot

---
 autograd/numpy/numpy_jvps.py |  10 ++-
 autograd/numpy/numpy_vjps.py | 141 +++++++++++++++++++++--------------
 2 files changed, 93 insertions(+), 58 deletions(-)

diff --git a/autograd/numpy/numpy_jvps.py b/autograd/numpy/numpy_jvps.py
index cc3c0bdfc..0b10433d1 100644
--- a/autograd/numpy/numpy_jvps.py
+++ b/autograd/numpy/numpy_jvps.py
@@ -1,6 +1,7 @@
 from . import numpy_wrapper as anp
 from .numpy_vjps import (untake, balanced_eq, match_complex, replace_zero,
-                         dot_0_adjoint, dot_1_adjoint)
+                         dot_adjoint_0, dot_adjoint_1, tensordot_adjoint_0,
+                         tensordot_adjoint_1)
 from autograd.core import (defjvp, defjvps, def_linear_wrt_arg, defjvp_argnum,
                            def_multilinear, vspace)
 from ..util import func
@@ -185,8 +186,11 @@ def fwd_grad_chooser(g, ans, gvs, vs, x, axis=None, keepdims=False):
 
 def_multilinear(anp.tensordot)
 def_multilinear(anp.outer)
-def_multilinear(dot_0_adjoint)
-def_multilinear(dot_1_adjoint)
+def_multilinear(dot_adjoint_0)
+def_multilinear(dot_adjoint_1)
+
+def_multilinear(tensordot_adjoint_0)
+def_multilinear(tensordot_adjoint_1)
 
 def fwd_grad_concatenate_args(argnum, g, ans, gvs, vs, *axis_args, **kwargs):
     result = []
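Since np.tensordot is linear in each argument, its JVP with respect to A is just np.tensordot(dA, B, axes), which is why registering the new adjoint primitives with def_multilinear is all that forward mode needs. A minimal standalone NumPy check of that identity (illustrative only, not part of the patch):

    import numpy as np

    # tensordot is exactly linear in A, so a finite difference in A recovers
    # the JVP up to round-off: f(A + eps*dA) - f(A) == eps * f(dA).
    A, dA = np.random.randn(2, 3, 4), np.random.randn(2, 3, 4)
    B = np.random.randn(3, 4, 5)
    eps = 1e-6
    fd = (np.tensordot(A + eps * dA, B, 2) - np.tensordot(A, B, 2)) / eps
    assert np.allclose(fd, np.tensordot(dA, B, 2), atol=1e-4)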
diff --git a/autograd/numpy/numpy_vjps.py b/autograd/numpy/numpy_vjps.py
index 070d32f6b..e78ac5624 100644
--- a/autograd/numpy/numpy_vjps.py
+++ b/autograd/numpy/numpy_vjps.py
@@ -280,7 +280,10 @@ def grad_inner(argnum, ans, vs, gvs, A, B):
         axes = ([], [])
     else:
         axes = ([A.ndim - 1], [B.ndim - 1])
-    return grad_tensordot(argnum, ans, vs, gvs, A, B, axes=axes)
+    if argnum == 0:
+        return lambda G: tensordot_adjoint_0(B, G, axes, vs)
+    elif argnum == 1:
+        return lambda G: tensordot_adjoint_1(A, G, axes, vs)
 defvjps(anp.inner, grad_inner, [0, 1])
 
 def grad_matmul(argnum, ans, vs, gvs, A, B):
@@ -288,13 +291,16 @@ def grad_matmul(argnum, ans, vs, gvs, A, B):
         raise ValueError("Scalar operands are not allowed, use '*' instead")
     elif anp.ndim(A) == 1 or anp.ndim(B) == 1 or (anp.ndim(A) == 2 and anp.ndim(B) == 2):
         axes = ([A.ndim - 1], [max(0, B.ndim - 2)])
-        return grad_tensordot(argnum, ans, vs, gvs, A, B, axes=axes)
+        if argnum == 0:
+            return lambda G: tensordot_adjoint_0(B, G, axes, vs)
+        elif argnum == 1:
+            return lambda G: tensordot_adjoint_1(A, G, axes, vs)
     else:
         return grad_einsum(argnum + 1, ans, vs, gvs, ("...ij,...jk->...ik", A, B), None)
 defvjps(anp.matmul, grad_matmul, [0, 1])
 
 @primitive
-def dot_0_adjoint(B, G, A_vs):
+def dot_adjoint_0(B, G, A_vs):
     # The adjoint of the operator
     # A |--> np.dot(A, B)
     A_ndim, B_ndim = A_vs.ndim, onp.ndim(B)
@@ -305,7 +311,7 @@ def dot_0_adjoint(B, G, A_vs):
     return onp.tensordot(G, onp.swapaxes(B, -1, -2), B_ndim - 1)
 
 @primitive
-def dot_1_adjoint(A, G, B_vs):
+def dot_adjoint_1(A, G, B_vs):
     # The adjoint of the operator
     # B |--> np.dot(A, B)
     A_ndim, B_ndim = onp.ndim(A), B_vs.ndim
@@ -318,60 +324,85 @@ def dot_1_adjoint(A, G, B_vs):
     return swap(onp.tensordot(
         G, A, [range(-A_ndim - B_ndim + 2, -B_ndim + 1), range(A_ndim - 1)]))
 
-defvjp(anp.dot, lambda ans, vs, gvs, A, B: lambda g: dot_0_adjoint(B, g, vs))
-defvjp(anp.dot, lambda ans, vs, gvs, A, B: lambda g: dot_1_adjoint(A, g, vs), 1)
+defvjp(anp.dot, lambda ans, vs, gvs, A, B: lambda g: dot_adjoint_0(B, g, vs))
+defvjp(anp.dot, lambda ans, vs, gvs, A, B: lambda g: dot_adjoint_1(A, g, vs), 1)
 
-defvjp(dot_0_adjoint, lambda ans, vs, gvs, B, g, A_vs: lambda A: dot_1_adjoint(A, g, vs))
-defvjp(dot_0_adjoint, lambda ans, vs, gvs, B, g, *args: lambda A: anp.dot(A, B), 1)
+defvjp(dot_adjoint_0, lambda ans, vs, gvs, B, g, A_vs: lambda A: dot_adjoint_1(A, g, vs))
+defvjp(dot_adjoint_0, lambda ans, vs, gvs, B, g, *args: lambda A: anp.dot(A, B), 1)
 
-defvjp(dot_1_adjoint, lambda ans, vs, gvs, A, g, B_vs: lambda B: dot_0_adjoint(B, g, vs))
-defvjp(dot_1_adjoint, lambda ans, vs, gvs, A, g, *args: lambda B: anp.dot(A, B), 1)
+defvjp(dot_adjoint_1, lambda ans, vs, gvs, A, g, B_vs: lambda B: dot_adjoint_0(B, g, vs))
+defvjp(dot_adjoint_1, lambda ans, vs, gvs, A, g, *args: lambda B: anp.dot(A, B), 1)
 
-def grad_tensordot(argnum, ans, vs, gvs, A, B, axes=2):
-    def vjp(g):
-        axes_ = axes
-        if anp.size(A) == anp.size(B) == 0:
-            return g * B if argnum == 0 else g * A
-
-        A_ndim = anp.ndim(A)
-        g_axes = onp.arange(anp.ndim(g))
-        if type(axes_) is int:
-            axes_ = max(axes_, 0)
-            if argnum == 0:
-                B_axes = onp.arange(anp.ndim(B))
-                return anp.tensordot(g, B, [g_axes[A_ndim-axes_:], B_axes[axes_:]])
-            else:
-                A_axes = onp.arange(A_ndim)
-                return anp.tensordot(A, g, [A_axes[:A_ndim-axes_], g_axes[:A_ndim-axes_]])
-        elif type(axes_[0]) is int:
-            B_ndim = anp.ndim(B)
-            axes_ = [axes_[0] % A_ndim, axes_[1] % B_ndim]
-            if argnum == 0:
-                B_axes = onp.arange(B_ndim)
-                return anp.tensordot(g, B, [g_axes[A_ndim-1:], onp.delete(B_axes, axes_[1])])
-            else:
-                A_axes = onp.arange(A_ndim)
-                return anp.tensordot(A, g, [onp.delete(A_axes, axes_[0]), g_axes[:A_ndim-1]])
-        else:
-            B_ndim = anp.ndim(B)
-            A_axes = onp.arange(A_ndim)
-            B_axes = onp.arange(B_ndim)
-            summed_axes = [onp.asarray(axes_[0]) % A_ndim,
-                           onp.asarray(axes_[1]) % B_ndim]
-            other_axes = [onp.delete(A_axes, summed_axes[0]),
-                          onp.delete(B_axes, summed_axes[1])]
-            if argnum == 0:
-                out = anp.tensordot(g, B, [g_axes[len(other_axes[0]):], other_axes[1]])
-                perm = onp.argsort(onp.concatenate(
-                    (other_axes[0], summed_axes[0][onp.argsort(summed_axes[1])])))
-                return anp.transpose(out, perm)
-            else:
-                out = anp.tensordot(A, g, [other_axes[0], g_axes[:len(other_axes[0])]])
-                perm = onp.argsort(onp.concatenate(
-                    (summed_axes[1][onp.argsort(summed_axes[0])], other_axes[1])))
-                return anp.transpose(out, perm)
-    return vjp
-defvjps(anp.tensordot, grad_tensordot, [0, 1])
+@primitive
+def tensordot_adjoint_0(B, G, axes, A_vs):
+    # The adjoint of the operator
+    # A |--> np.tensordot(A, B, axes)
+    if onp.ndim(B) == 0:
+        return G * B
+
+    A_ndim = A_vs.ndim
+    G_axes = onp.arange(onp.ndim(G))
+    if type(axes) is int:
+        axes = max(axes, 0)
+        B_axes = onp.arange(onp.ndim(B))
+        return onp.tensordot(G, B, [G_axes[A_ndim-axes:], B_axes[axes:]])
+    elif type(axes[0]) is int:
+        B_ndim = onp.ndim(B)
+        axes = [axes[0] % A_ndim, axes[1] % B_ndim]
+        B_axes = onp.arange(B_ndim)
+        return onp.tensordot(G, B, [G_axes[A_ndim-1:], onp.delete(B_axes, axes[1])])
+    else:
+        B_ndim = onp.ndim(B)
+        A_axes = onp.arange(A_ndim)
+        B_axes = onp.arange(B_ndim)
+        summed_axes = [onp.asarray(axes[0]) % A_ndim,
+                       onp.asarray(axes[1]) % B_ndim]
+        other_axes = [onp.delete(A_axes, summed_axes[0]),
+                      onp.delete(B_axes, summed_axes[1])]
+        out = onp.tensordot(G, B, [G_axes[len(other_axes[0]):], other_axes[1]])
+        perm = onp.argsort(onp.concatenate(
+            (other_axes[0], summed_axes[0][onp.argsort(summed_axes[1])])))
+        return onp.transpose(out, perm)
+
+@primitive
+def tensordot_adjoint_1(A, G, axes, B_vs):
+    # The adjoint of the operator
+    # B |--> np.tensordot(A, B, axes)
+    if onp.ndim(A) == 0:
+        return G * A
+
+    A_ndim = onp.ndim(A)
+    G_axes = onp.arange(onp.ndim(G))
+    if type(axes) is int:
+        axes = max(axes, 0)
+        A_axes = onp.arange(A_ndim)
+        return onp.tensordot(A, G, [A_axes[:A_ndim-axes], G_axes[:A_ndim-axes]])
+    elif type(axes[0]) is int:
+        B_ndim = B_vs.ndim
+        axes = [axes[0] % A_ndim, axes[1] % B_ndim]
+        A_axes = onp.arange(A_ndim)
+        return onp.tensordot(A, G, [onp.delete(A_axes, axes[0]), G_axes[:A_ndim-1]])
+    else:
+        B_ndim = B_vs.ndim
+        A_axes = onp.arange(A_ndim)
+        B_axes = onp.arange(B_ndim)
+        summed_axes = [onp.asarray(axes[0]) % A_ndim,
+                       onp.asarray(axes[1]) % B_ndim]
+        other_axes = [onp.delete(A_axes, summed_axes[0]),
+                      onp.delete(B_axes, summed_axes[1])]
+        out = onp.tensordot(A, G, [other_axes[0], G_axes[:len(other_axes[0])]])
+        perm = onp.argsort(onp.concatenate(
+            (summed_axes[1][onp.argsort(summed_axes[0])], other_axes[1])))
+        return onp.transpose(out, perm)
+
+defvjp(anp.tensordot, lambda ans, vs, gvs, A, B, axes=2: lambda G: tensordot_adjoint_0(B, G, axes, vs))
+defvjp(anp.tensordot, lambda ans, vs, gvs, A, B, axes=2: lambda G: tensordot_adjoint_1(A, G, axes, vs), 1)
+
+defvjp(tensordot_adjoint_0, lambda ans, vs, gvs, B, G, axes, A_vs: lambda A: tensordot_adjoint_1(A, G, axes, vs))
+defvjp(tensordot_adjoint_0, lambda ans, vs, gvs, B, G, axes, A_vs: lambda A: anp.tensordot(A, B, axes), 1)
+
+defvjp(tensordot_adjoint_1, lambda ans, vs, gvs, A, G, axes, B_vs: lambda B: tensordot_adjoint_0(B, G, axes, vs))
+defvjp(tensordot_adjoint_1, lambda ans, vs, gvs, A, G, axes, B_vs: lambda B: anp.tensordot(A, B, axes), 1)
 
 defvjp(anp.outer, lambda ans, vs, gvs, a, b : lambda g: anp.dot(g, b.T))
 defvjp(anp.outer, lambda ans, vs, gvs, a, b : lambda g: anp.dot(a.T, g), argnum=1)
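The new primitives implement the adjoint in the inner-product sense: <G, tensordot(A, B, axes)> equals <tensordot_adjoint_0(B, G, axes, A_vs), A>. A standalone NumPy sketch of that identity for the integer-axes branch (simplified: the vector-space argument is reduced to A's ndim, which is all this branch uses):

    import numpy as np

    def adjoint_0(B, G, axes, A_ndim):
        # Mirrors the integer-axes branch of tensordot_adjoint_0 above.
        G_axes = np.arange(np.ndim(G))
        B_axes = np.arange(np.ndim(B))
        return np.tensordot(G, B, [G_axes[A_ndim - axes:], B_axes[axes:]])

    A = np.random.randn(2, 3, 4)
    B = np.random.randn(3, 4, 5)
    G = np.random.randn(2, 5)          # same shape as np.tensordot(A, B, 2)

    lhs = np.sum(G * np.tensordot(A, B, 2))
    rhs = np.sum(adjoint_0(B, G, 2, A.ndim) * A)
    assert np.allclose(lhs, rhs)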
From d5b45ef56757290f08d1bd4acff151b91257675c Mon Sep 17 00:00:00 2001
From: Jamie Townsend
Date: Wed, 6 Sep 2017 11:33:48 +0100
Subject: [PATCH 2/3] Test tensordot up to order=3

---
 tests/test_systematic.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/tests/test_systematic.py b/tests/test_systematic.py
index ee67edf6f..518233fa0 100644
--- a/tests/test_systematic.py
+++ b/tests/test_systematic.py
@@ -111,25 +111,25 @@ def test_matmul(): combo_check(np.matmul, [0, 1])(
                             [R(3), R(2, 3), R(2, 2, 3)],
                             [R(3), R(3, 4), R(2, 3, 4)])
 def test_matmul_broadcast(): combo_check(np.matmul, [0, 1])([R(1, 2, 2)], [R(3, 2, 1)])
-def test_tensordot_1(): combo_check(np.tensordot, [0, 1])(
+def test_tensordot_1(): combo_check(np.tensordot, [0, 1], order=3)(
                             [R(1, 3), R(2, 3, 2)],
                             [R(3), R(3, 1), R(3, 4, 2)],
                             axes=[ [(1,), (0,)] ])
-def test_tensordot_2(): combo_check(np.tensordot, [0, 1])(
+def test_tensordot_2(): combo_check(np.tensordot, [0, 1], order=3)(
                             [R(3), R(3, 1), R(3, 4, 2)],
                             [R(1, 3), R(2, 3, 2)],
                             axes=[ [(0,), (1,)] ])
-def test_tensordot_3(): combo_check(np.tensordot, [0, 1])(
+def test_tensordot_3(): combo_check(np.tensordot, [0, 1], order=3)(
                             [R(2, 3), R(2, 3, 4)],
                             [R(1, 2, 3), R(2, 2, 3, 4)],
                             axes=[ [(0, 1), (1, 2)] , [(1, 0), (2, 1)] ])
-def test_tensordot_4(): combo_check(np.tensordot, [0, 1])(
+def test_tensordot_4(): combo_check(np.tensordot, [0, 1], order=3)(
                             [R(2, 2), R(4, 2, 2)],
                             [R(2, 2), R(2, 2, 4)],
                             axes=[1, 2])
-def test_tensordot_5(): combo_check(np.tensordot, [0, 1])([R(4)], [R()], axes=[0])
-def test_tensordot_6(): combo_check(np.tensordot, [0, 1])([R(2,6)], [R(6,3)], axes=[[[-1], [0]]])
-def test_tensordot_7(): combo_check(np.tensordot, [0, 1])([R(2,6)], [R(6,3)], axes=[[-1, 0]])
+def test_tensordot_5(): combo_check(np.tensordot, [0, 1], order=3)([R(4)], [R()], axes=[0])
+def test_tensordot_6(): combo_check(np.tensordot, [0, 1], order=3)([R(2,6)], [R(6,3)], axes=[[[-1], [0]]])
+def test_tensordot_7(): combo_check(np.tensordot, [0, 1], order=3)([R(2,6)], [R(6,3)], axes=[[-1, 0]])
 
 # Need custom tests because gradient is undefined when arguments are identical.
 def test_maximum(): combo_check(np.maximum, [0, 1])(
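Because the adjoints are now primitives with VJPs of their own, reverse mode can be nested through np.tensordot; the order=3 flag above makes combo_check exercise exactly that. A rough illustration of the nesting with autograd's grad (a sketch of the idea, not the test helper's actual mechanics):

    import autograd.numpy as np
    import numpy.random as npr
    from autograd import grad

    B = npr.randn(3, 4, 5)

    def f(A):
        # Scalar-valued and nonlinear in A, so higher-order grads are nontrivial.
        return np.sum(np.tensordot(A, B, 2) ** 2)

    A = npr.randn(2, 3, 4)
    g1 = grad(f)(A)                                    # uses tensordot's VJP
    g2 = grad(lambda A_: np.sum(grad(f)(A_) ** 2))(A)  # differentiates the adjoint itself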
From 82068d0e6c916daf8421db7375c22ade6a3d30d6 Mon Sep 17 00:00:00 2001
From: Jamie Townsend
Date: Wed, 6 Sep 2017 11:45:08 +0100
Subject: [PATCH 3/3] Add tensordot benchmarks

---
 benchmarks/bench_numpy_vjps.py | 43 +++++++++++++++++++++++++++++++---
 1 file changed, 40 insertions(+), 3 deletions(-)

diff --git a/benchmarks/bench_numpy_vjps.py b/benchmarks/bench_numpy_vjps.py
index 1921ef374..1647bbcc5 100644
--- a/benchmarks/bench_numpy_vjps.py
+++ b/benchmarks/bench_numpy_vjps.py
@@ -18,9 +18,6 @@
 B = npr.randn(2, 3, 5, 4)
 g = npr.randn(2, 3, 4, 2, 3, 4)
 
-def time_dot():
-    np.dot(A, B)
-
 def time_dot_0():
     dot_0(A, B, g)
 
@@ -44,3 +41,43 @@ def time_dot_1_1():
     dot_1_1(A, B, g)
 
 def time_dot_1_2():
     dot_1_2(A, B, g)
+
+tensordot_0 = lambda A, B, G: make_vjp(np.tensordot)(A, B, 2)[0](G)
+tensordot_1 = lambda A, B, G: make_vjp(np.tensordot, argnum=1)(A, B, 2)[0](G)
+
+tensordot_0_0 = lambda A, B, G: make_vjp(tensordot_0, argnum=0)(A, B, G)[0](A)
+tensordot_0_1 = lambda A, B, G: make_vjp(tensordot_0, argnum=1)(A, B, G)[0](A)
+tensordot_0_2 = lambda A, B, G: make_vjp(tensordot_0, argnum=2)(A, B, G)[0](A)
+
+tensordot_1_0 = lambda A, B, G: make_vjp(tensordot_1, argnum=0)(A, B, G)[0](B)
+tensordot_1_1 = lambda A, B, G: make_vjp(tensordot_1, argnum=1)(A, B, G)[0](B)
+tensordot_1_2 = lambda A, B, G: make_vjp(tensordot_1, argnum=2)(A, B, G)[0](B)
+
+A = npr.randn(2, 3, 5, 4)
+B = npr.randn(5, 4, 2, 3)
+G = npr.randn(2, 3, 2, 3)
+
+def time_tensordot_0():
+    tensordot_0(A, B, G)
+
+def time_tensordot_1():
+    tensordot_1(A, B, G)
+
+def time_tensordot_0_0():
+    tensordot_0_0(A, B, G)
+
+def time_tensordot_0_1():
+    tensordot_0_1(A, B, G)
+
+def time_tensordot_0_2():
+    tensordot_0_2(A, B, G)
+
+def time_tensordot_1_0():
+    tensordot_1_0(A, B, G)
+
+def time_tensordot_1_1():
+    tensordot_1_1(A, B, G)
+
+def time_tensordot_1_2():
+    tensordot_1_2(A, B, G)
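The time_tensordot_<i> and time_tensordot_<i>_<j> names appear to mirror the existing dot benchmarks: <i> is the argument the first-order VJP is taken with respect to, and <j> the argument the second-order VJP differentiates. For a quick spot check outside the benchmark harness, something like this should work (a sketch assuming the same setup as the file above):

    import timeit
    import numpy.random as npr
    import autograd.numpy as np
    from autograd import make_vjp

    tensordot_0 = lambda A, B, G: make_vjp(np.tensordot)(A, B, 2)[0](G)
    tensordot_0_0 = lambda A, B, G: make_vjp(tensordot_0)(A, B, G)[0](A)

    A = npr.randn(2, 3, 5, 4)
    B = npr.randn(5, 4, 2, 3)
    G = npr.randn(2, 3, 2, 3)  # same shape as np.tensordot(A, B, 2)

    # First- and second-order VJP timings, 100 repetitions each.
    print(timeit.timeit(lambda: tensordot_0(A, B, G), number=100))
    print(timeit.timeit(lambda: tensordot_0_0(A, B, G), number=100))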