Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 10 additions & 18 deletions autograd/numpy/numpy_jvps.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from . import numpy_wrapper as anp
from .numpy_vjps import (untake, balanced_eq, match_complex, replace_zero,
dot_adjoint_0, dot_adjoint_1, tensordot_adjoint_0,
tensordot_adjoint_1)
tensordot_adjoint_1, broadcast, _broadcast,
unbroadcast, _unbroadcast)
from autograd.core import (defjvp, defjvps, def_linear_wrt_arg, defjvp_argnum,
def_multilinear, vspace)
from ..util import func
Expand All @@ -22,10 +23,10 @@
def_linear_wrt_arg(anp.true_divide)

# ----- Binary ufuncs -----
defjvp(anp.add, lambda g, ans, x, y : broadcast(g, ans))
defjvp(anp.add, lambda g, ans, x, y : broadcast(g, ans), argnum=1)
defjvp(anp.subtract, lambda g, ans, x, y : broadcast(g, ans))
defjvp(anp.subtract, lambda g, ans, x, y : broadcast(-g, ans), argnum=1)
defjvp(anp.add, lambda g, ans, x, y : broadcast(g, anp.metadata(ans)))
defjvp(anp.add, lambda g, ans, x, y : broadcast(g, anp.metadata(ans)), argnum=1)
defjvp(anp.subtract, lambda g, ans, x, y : broadcast(g, anp.metadata(ans)))
defjvp(anp.subtract, lambda g, ans, x, y : broadcast(-g, anp.metadata(ans)), argnum=1)
defjvp(anp.divide, lambda g, ans, x, y : - g * x / y**2, argnum=1)
defjvp(anp.maximum, lambda g, ans, x, y : g * balanced_eq(x, ans, y))
defjvp(anp.maximum, lambda g, ans, x, y : g * balanced_eq(y, ans, x), argnum=1)
Expand All @@ -40,8 +41,8 @@
defjvp(anp.logaddexp2, lambda g, ans, x, y : g * 2**(x-ans))
defjvp(anp.logaddexp2, lambda g, ans, x, y : g * 2**(y-ans), argnum=1)
defjvp(anp.true_divide,lambda g, ans, x, y : - g * x / y**2, argnum=1)
defjvp(anp.mod, lambda g, ans, x, y : broadcast(g, ans))
defjvp(anp.remainder, lambda g, ans, x, y : broadcast(g, ans))
defjvp(anp.mod, lambda g, ans, x, y : broadcast(g, anp.metadata(ans)))
defjvp(anp.remainder, lambda g, ans, x, y : broadcast(g, anp.metadata(ans)))
defjvp(anp.mod, lambda g, ans, x, y : -g * anp.floor(x/y), argnum=1)
defjvp(anp.remainder, lambda g, ans, x, y : -g * anp.floor(x/y), argnum=1)
defjvp(anp.power, lambda g, ans, x, y : g * y * x ** anp.where(y, y - 1, 1.))
Expand Down Expand Up @@ -220,14 +221,5 @@ def jvp(g, ans, *arys):

def_multilinear(anp.einsum)

# TODO(mattjj): can we call np.broadcast_to or a related function instead?
def broadcast(x, target):
    """Broadcast `x` up to the shape/dtype metadata of `target`.

    Prepends singleton axes until `x` has the target's ndim, then
    materializes each size-1 axis via `repeat` to the target's size.
    NOTE(review): unlike `np.broadcast_to`, `repeat` makes a real copy
    per expanded axis.
    """
    # metadata(...) yields (shape, ndim, dtype, iscomplex); dtype is unused here.
    target_shape, target_ndim, target_dtype, target_iscomplex = anp.metadata(target)
    # Align ranks by adding leading axes (numpy's broadcasting convention).
    while anp.ndim(x) < target_ndim:
        x = anp.expand_dims(x, 0)
    # Expand every singleton axis to the target's extent.
    for axis, size in enumerate(anp.shape(x)):
        if size == 1:
            x = anp.repeat(x, target_shape[axis], axis=axis)
    # Promote to complex if the target is complex but x is not.
    if target_iscomplex and not anp.iscomplexobj(x):
        x = x + 0j # TODO(mattjj): this might promote the dtype
    return x
def_linear_wrt_arg(_broadcast)
def_linear_wrt_arg(_unbroadcast)
60 changes: 48 additions & 12 deletions autograd/numpy/numpy_vjps.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as onp
from numpy.core.einsumfunc import _parse_einsum_input
from ..util import func
from autograd.tracer import primitive, getval
from autograd.tracer import primitive, notrace_primitive, getval
from autograd.vspace import vspace
from autograd.core import defvjp, defvjps, defvjp_is_zero, defvjp_argnum, SparseObject
from . import numpy_wrapper as anp
Expand Down Expand Up @@ -519,11 +519,11 @@ def vjp(g):
defvjp_argnum(anp.einsum, grad_einsum)

defvjp(anp.diagonal,
lambda ans, A, offset=0, axis1=0, axis2=1 :
lambda g: anp.make_diagonal(g, offset, axis1, axis2))
lambda ans, A, offset=0, axis1=0, axis2=1 :
lambda g: anp.make_diagonal(g, offset, axis1, axis2))
defvjp(anp.make_diagonal,
lambda ans, D, offset=0, axis1=0, axis2=1 :
lambda g: anp.diagonal(g, offset, axis1, axis2))
lambda ans, D, offset=0, axis1=0, axis2=1 :
lambda g: anp.diagonal(g, offset, axis1, axis2))

def match_complex(target, x):
target_iscomplex = anp.iscomplexobj(target)
Expand All @@ -535,17 +535,53 @@ def match_complex(target, x):
else:
return x

@notrace_primitive
def _needs_broadcast(x, target_meta):
    """Return True iff `x` must change to match the target metadata.

    A change is needed when the shapes differ, or when exactly one of
    `x` / target is complex.
    """
    target_shape, _, _, target_iscomplex = target_meta
    if onp.shape(x) != target_shape:
        return True
    return onp.iscomplexobj(x) != target_iscomplex

def broadcast(x, target_meta, broadcast_idx=0):
    """Broadcast `x` to `target_meta` via the `_broadcast` primitive.

    Fast path: return `x` untouched when it already matches the target,
    avoiding a traced primitive call.
    """
    if not _needs_broadcast(x, target_meta):
        return x
    return _broadcast(x, target_meta, broadcast_idx)

@primitive
def _broadcast(x, target_meta, broadcast_idx=0):
    """Primitive: broadcast `x` to the target shape and complexness.

    `broadcast_idx` is unused here; it is threaded through so the VJP
    (`_unbroadcast`) knows which axis leading dims were added at.
    """
    target_shape, _, _, target_iscomplex = target_meta
    result = onp.broadcast_to(x, target_shape)
    if target_iscomplex and not onp.iscomplexobj(result):
        result = result + 0j  # TODO(mattjj): this might promote the dtype
    return result

def grad_broadcast(ans, x, target_meta, broadcast_idx=0):
    """VJP of `_broadcast`: sum the cotangent back down to x's metadata."""
    x_meta = anp.metadata(x)

    def vjp(g):
        return _unbroadcast(g, x_meta, broadcast_idx)

    return vjp

defvjp(_broadcast, grad_broadcast)

def unbroadcast(x, target_meta, broadcast_idx=0):
    """Reduce `x` down to `target_meta` via the `_unbroadcast` primitive.

    First sums away extra leading dims (at `broadcast_idx`) with traced
    ops, then delegates to the primitive only if a change is still needed.
    """
    _, target_ndim, _, _ = target_meta
    while anp.ndim(x) > target_ndim:
        x = anp.sum(x, axis=broadcast_idx)
    if not _needs_broadcast(x, target_meta):
        return x
    return _unbroadcast(x, target_meta, broadcast_idx)

@primitive
def _unbroadcast(x, target_meta, broadcast_idx=0):
target_shape, target_ndim, _, target_iscomplex = target_meta
x_shape = onp.shape(x)
while onp.ndim(x) > target_ndim:
x = onp.sum(x, axis=broadcast_idx)
Copy link
Collaborator

@j-towns j-towns Sep 13, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering if we should replace the above two lines with:

x = onp.sum(x, axis=range(broadcast_idx, broadcast_idx + onp.ndim(x) - target_ndim))

or similar. Am I right that only calling sum once might lead to better performance, basically because only one output array has to be allocated?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I briefly tried something like that (though if broadcast_idx is -1, which is the only nonzero use case I noticed in the code, then I think we want something different) and it didn't seem to make a speed difference, so I dropped it. Now is a good time to make sure it's performant, though!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doing a few timings it looks like there is a benefit for small arrays but it's not massive:

In [15]: a = np.ones((5, 5, 5))

In [16]: %timeit np.sum(a, axis=(0, 1))
5.38 µs ± 112 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

In [17]: %timeit x = np.sum(a, axis=0); x = np.sum(x, axis=0)
8.62 µs ± 124 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

and for slightly bigger arrays it's the other way round (maybe I've made some mistake?):

In [18]: a = np.ones((50, 50, 50))

In [19]: %timeit np.sum(a, axis=(0, 1))
118 µs ± 930 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [20]: %timeit x = np.sum(a, axis=0); x = np.sum(x, axis=0)
81.6 µs ± 1.54 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wow, I got similar timings. That seems weird for the bigger arrays...

for axis, size in enumerate(target_shape):
if size == 1:
x = anp.sum(x, axis=axis, keepdims=True)
if anp.iscomplexobj(x) and not target_iscomplex:
x = anp.real(x)
if size == 1: # TODO(mattjj): bug here w/ passing through scalars?
x = onp.sum(x, axis=axis, keepdims=True)
Copy link
Collaborator

@j-towns j-towns Sep 13, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could do a similar thing for this one.

if onp.iscomplexobj(x) and not target_iscomplex:
x = onp.real(x)
return x

def grad_unbroadcast(ans, x, target_meta, broadcast_idx=0):
    """VJP of `_unbroadcast`: re-broadcast the cotangent to x's metadata."""
    x_meta = anp.metadata(x)

    def vjp(g):
        return _broadcast(g, x_meta, broadcast_idx)

    return vjp

defvjp(_unbroadcast, grad_unbroadcast)

def unbroadcast_f(target, f):
    """Wrap `f` so its result is unbroadcast to `target`'s metadata.

    The metadata is captured once here rather than recomputed per call.
    """
    meta = anp.metadata(target)

    def wrapped(g):
        return unbroadcast(f(g), meta)

    return wrapped
Expand Down