diff --git a/autograd/numpy/numpy_jvps.py b/autograd/numpy/numpy_jvps.py
index 105e95050..25c7035ec 100644
--- a/autograd/numpy/numpy_jvps.py
+++ b/autograd/numpy/numpy_jvps.py
@@ -1,7 +1,8 @@
 from . import numpy_wrapper as anp
 from .numpy_vjps import (untake, balanced_eq, match_complex, replace_zero,
                          dot_adjoint_0, dot_adjoint_1, tensordot_adjoint_0,
-                         tensordot_adjoint_1)
+                         tensordot_adjoint_1, broadcast, _broadcast,
+                         unbroadcast, _unbroadcast)
 from autograd.core import (defjvp, defjvps, def_linear_wrt_arg, defjvp_argnum,
                            def_multilinear, vspace)
 from ..util import func
@@ -22,10 +23,10 @@
 def_linear_wrt_arg(anp.true_divide)
 
 # ----- Binary ufuncs -----
-defjvp(anp.add,        lambda g, ans, x, y : broadcast(g, ans))
-defjvp(anp.add,        lambda g, ans, x, y : broadcast(g, ans), argnum=1)
-defjvp(anp.subtract,   lambda g, ans, x, y : broadcast(g, ans))
-defjvp(anp.subtract,   lambda g, ans, x, y : broadcast(-g, ans), argnum=1)
+defjvp(anp.add,        lambda g, ans, x, y : broadcast(g, anp.metadata(ans)))
+defjvp(anp.add,        lambda g, ans, x, y : broadcast(g, anp.metadata(ans)), argnum=1)
+defjvp(anp.subtract,   lambda g, ans, x, y : broadcast(g, anp.metadata(ans)))
+defjvp(anp.subtract,   lambda g, ans, x, y : broadcast(-g, anp.metadata(ans)), argnum=1)
 defjvp(anp.divide,     lambda g, ans, x, y : - g * x / y**2, argnum=1)
 defjvp(anp.maximum,    lambda g, ans, x, y : g * balanced_eq(x, ans, y))
 defjvp(anp.maximum,    lambda g, ans, x, y : g * balanced_eq(y, ans, x), argnum=1)
@@ -40,8 +41,8 @@
 defjvp(anp.logaddexp2, lambda g, ans, x, y : g * 2**(x-ans))
 defjvp(anp.logaddexp2, lambda g, ans, x, y : g * 2**(y-ans), argnum=1)
 defjvp(anp.true_divide,lambda g, ans, x, y : - g * x / y**2, argnum=1)
-defjvp(anp.mod,        lambda g, ans, x, y : broadcast(g, ans))
-defjvp(anp.remainder,  lambda g, ans, x, y : broadcast(g, ans))
+defjvp(anp.mod,        lambda g, ans, x, y : broadcast(g, anp.metadata(ans)))
+defjvp(anp.remainder,  lambda g, ans, x, y : broadcast(g, anp.metadata(ans)))
 defjvp(anp.mod,        lambda g, ans, x, y : -g * anp.floor(x/y), argnum=1)
 defjvp(anp.remainder,  lambda g, ans, x, y : -g * anp.floor(x/y), argnum=1)
 defjvp(anp.power,      lambda g, ans, x, y : g * y * x ** anp.where(y, y - 1, 1.))
@@ -220,14 +221,5 @@ def jvp(g, ans, *arys):
 
 def_multilinear(anp.einsum)
 
-# TODO(mattjj): can we call np.broadcast_to or a related function instead?
-def broadcast(x, target):
-    target_shape, target_ndim, target_dtype, target_iscomplex = anp.metadata(target)
-    while anp.ndim(x) < target_ndim:
-        x = anp.expand_dims(x, 0)
-    for axis, size in enumerate(anp.shape(x)):
-        if size == 1:
-            x = anp.repeat(x, target_shape[axis], axis=axis)
-    if target_iscomplex and not anp.iscomplexobj(x):
-        x = x + 0j  # TODO(mattjj): this might promote the dtype
-    return x
+def_linear_wrt_arg(_broadcast)
+def_linear_wrt_arg(_unbroadcast)
diff --git a/autograd/numpy/numpy_vjps.py b/autograd/numpy/numpy_vjps.py
index ab6517f35..3e01f787c 100644
--- a/autograd/numpy/numpy_vjps.py
+++ b/autograd/numpy/numpy_vjps.py
@@ -3,7 +3,7 @@
 import numpy as onp
 from numpy.core.einsumfunc import _parse_einsum_input
 from ..util import func
-from autograd.tracer import primitive, getval
+from autograd.tracer import primitive, notrace_primitive, getval
 from autograd.vspace import vspace
 from autograd.core import defvjp, defvjps, defvjp_is_zero, defvjp_argnum, SparseObject
 from . import numpy_wrapper as anp
@@ -519,11 +519,11 @@ def vjp(g):
 defvjp_argnum(anp.einsum, grad_einsum)
 
 defvjp(anp.diagonal,
-       lambda ans, A, offset=0, axis1=0, axis2=1 :
-       lambda g: anp.make_diagonal(g, offset, axis1, axis2))
+       lambda ans, A, offset=0, axis1=0, axis2=1 :
+       lambda g: anp.make_diagonal(g, offset, axis1, axis2))
 defvjp(anp.make_diagonal,
-       lambda ans, D, offset=0, axis1=0, axis2=1 :
-       lambda g: anp.diagonal(g, offset, axis1, axis2))
+       lambda ans, D, offset=0, axis1=0, axis2=1 :
+       lambda g: anp.diagonal(g, offset, axis1, axis2))
 
 def match_complex(target, x):
     target_iscomplex = anp.iscomplexobj(target)
@@ -535,17 +535,53 @@ def match_complex(target, x):
     else:
         return x
 
+@notrace_primitive
+def _needs_broadcast(x, target_meta):
+    target_shape, _, _, target_iscomplex = target_meta
+    return (onp.shape(x) != target_shape
+            or (target_iscomplex != onp.iscomplexobj(x)))
+
+def broadcast(x, target_meta, broadcast_idx=0):
+    if _needs_broadcast(x, target_meta):
+        return _broadcast(x, target_meta, broadcast_idx)
+    return x
+
+@primitive
+def _broadcast(x, target_meta, broadcast_idx=0):
+    target_shape, _, _, target_iscomplex = target_meta
+    x = onp.broadcast_to(x, target_shape)
+    if target_iscomplex and not onp.iscomplexobj(x):
+        x = x + 0j  # TODO(mattjj): this might promote the dtype
+    return x
+
+def grad_broadcast(ans, x, target_meta, broadcast_idx=0):
+    meta = anp.metadata(x)
+    return lambda g: _unbroadcast(g, meta, broadcast_idx)
+defvjp(_broadcast, grad_broadcast)
+
 def unbroadcast(x, target_meta, broadcast_idx=0):
-    target_shape, target_ndim, dtype, target_iscomplex = target_meta
-    while anp.ndim(x) > target_ndim:
-        x = anp.sum(x, axis=broadcast_idx)
+    if _needs_broadcast(x, target_meta):
+        return _unbroadcast(x, target_meta, broadcast_idx)
+    return x
+
+@primitive
+def _unbroadcast(x, target_meta, broadcast_idx=0):
+    target_shape, target_ndim, _, target_iscomplex = target_meta
+    x_shape = onp.shape(x)
+    while onp.ndim(x) > target_ndim:
+        x = onp.sum(x, axis=broadcast_idx)
     for axis, size in enumerate(target_shape):
-        if size == 1:
-            x = anp.sum(x, axis=axis, keepdims=True)
-    if anp.iscomplexobj(x) and not target_iscomplex:
-        x = anp.real(x)
+        if size == 1:  # TODO(mattjj): bug here w/ passing through scalars?
+            x = onp.sum(x, axis=axis, keepdims=True)
+    if onp.iscomplexobj(x) and not target_iscomplex:
+        x = onp.real(x)
     return x
 
+def grad_unbroadcast(ans, x, target_meta, broadcast_idx=0):
+    meta = anp.metadata(x)
+    return lambda g: _broadcast(g, meta, broadcast_idx)
+defvjp(_unbroadcast, grad_unbroadcast)
+
 def unbroadcast_f(target, f):
     target_meta = anp.metadata(target)
     return lambda g: unbroadcast(f(g), target_meta)
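The upshot of the change: the old `anp.repeat`-based `broadcast` helper in `numpy_jvps.py` is replaced by `broadcast`/`unbroadcast` wrappers around two new primitives, `_broadcast` and `_unbroadcast`, whose VJPs are defined in terms of each other (each is the adjoint of the other as a linear map), so forward mode can simply reuse them via `def_linear_wrt_arg`. Below is a minimal NumPy-only sketch of that adjoint relationship for the real-valued case; the names `my_broadcast` and `my_unbroadcast` are illustrative stand-ins, not part of autograd:

```python
import numpy as onp

def my_broadcast(x, target_shape):
    # forward op: expand x to target_shape (mirrors _broadcast, real case only)
    return onp.broadcast_to(x, target_shape)

def my_unbroadcast(g, orig_shape):
    # adjoint op: sum g back down to orig_shape (mirrors _unbroadcast)
    g = onp.asarray(g)
    while g.ndim > len(orig_shape):          # drop leading broadcast dims
        g = g.sum(axis=0)
    for axis, size in enumerate(orig_shape):  # collapse size-1 dims that were stretched
        if size == 1:
            g = g.sum(axis=axis, keepdims=True)
    return g

# Adjoint check for the linear map: <broadcast(x), g> == <x, unbroadcast(g)>
rng = onp.random.RandomState(0)
x = rng.randn(3, 1)          # input that broadcasts against (2, 3, 4)
g = rng.randn(2, 3, 4)       # cotangent with the broadcast (output) shape
lhs = onp.vdot(my_broadcast(x, (2, 3, 4)), g)
rhs = onp.vdot(x, my_unbroadcast(g, x.shape))
assert onp.allclose(lhs, rhs)
```

The `_needs_broadcast` check is a `notrace_primitive`, so the shape/complexness test itself is not traced, and `_broadcast`/`_unbroadcast` nodes only enter the graph when there is actual broadcasting work to do or undo.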