diff --git a/autograd/numpy/linalg.py b/autograd/numpy/linalg.py index 10e1e96e4..884c496c7 100644 --- a/autograd/numpy/linalg.py +++ b/autograd/numpy/linalg.py @@ -74,10 +74,18 @@ def check_implemented(): unroll = lambda a: anp.rollaxis(anp.rollaxis(a, a.ndim-2, row_axis), a.ndim-1, col_axis) + # Used for returning zero gradient of zero norms + replace_zero = lambda x, val: anp.where(x, x, val) + # For manually set the second derivative of norm to zero, to match np.abs() + replace_zero_ans = lambda x, val: anp.where(expand(ans), x, val) + check_implemented() def vjp(g): if ord is None or ord == 2 or ord is 'fro': - return expand(g / ans) * x + # The gradient is 1 / ans * x + # FIXME: when x is complex vector, the gradient seems to be: + # 1 / ans * conj(x) + return expand(g / replace_zero(ans, 1.)) * replace_zero_ans(x, 0.) elif ord == 'nuc': dot = anp.dot if x.ndim == 2 else partial(anp.einsum, '...ij,...jk->...ik') x_rolled = roll(x) @@ -88,8 +96,14 @@ def vjp(g): g = expand(g) return g * uvt else: - # see https://en.wikipedia.org/wiki/Norm_(mathematics)#p-norm - return expand(g / ans**(ord-1)) * x * anp.abs(x)**(ord-2) + # See https://en.wikipedia.org/wiki/Norm_(mathematics)#p-norm + # The gradient is 1 / ans**(ord-1) * abs(x)**(ord-1) * sign(x) + # Use `abs(x)**(ord-1) * sign(x)` instead of `abs(x)**(ord-2) * x` + # avoids NaN when x contains zero. + # FIXME: when x is complex vector, the gradient seems to be: + # 1 / ans**(ord-1) * abs(x)**(ord-1) * conj(x) / abs(x) + return expand(g / replace_zero(ans**(ord-1), 1.0)) \ + * anp.abs(x)**(ord-1) * anp.sign(x) return vjp defvjp(norm, grad_norm) diff --git a/tests/test_linalg.py b/tests/test_linalg.py index a79ea6c0d..42b7b560e 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -152,6 +152,26 @@ def fun(x): return np.linalg.norm(x, ord='nuc', axis=(0, 1)) mat = npr.randn(D, D-1, D-2) check_grads(fun)(mat) +def test_vector_zero_norm(): + def helper(size, ord): + def fun(x): return np.linalg.norm(x, ord=ord) + vec = np.zeros(size) + check_grads(fun, order=1)(vec) + for ord in [1.1, 2, 3]: + for size in [1, 2]: + yield helper, size, ord + +def test_vector_mix_zero_norm_axis(): + def helper(axis, ord): + def fun(x): return np.linalg.norm(x, ord=ord, axis=axis) + size = (2,2) + vecs = np.zeros(size) + vecs[0,0] = npr.randn(1) + check_grads(fun, order=1)(vecs) + for axis in [0, 1]: + for ord in [1.1, 2]: + yield helper, axis, ord + def test_eigvalh_lower(): def fun(x): w, v = np.linalg.eigh(x)