diff --git a/autograd/numpy/numpy_vjps.py b/autograd/numpy/numpy_vjps.py index 8ae357a6..f07e31be 100644 --- a/autograd/numpy/numpy_vjps.py +++ b/autograd/numpy/numpy_vjps.py @@ -973,3 +973,37 @@ def pad_vjp(ans, array, pad_width, mode, **kwargs): defvjp(anp.pad, pad_vjp) + +# ----- VJP for numpy.take ----- + + +@primitive +def untake_along_axis(x, indices, shape, axis): + """Inverse of take along axis - scatters values back to original positions.""" + if axis is None: + # When axis is None, take flattens the array + result = onp.zeros(shape, dtype=x.dtype).ravel() + onp.add.at(result, indices, x) + return result.reshape(shape) + else: + # Handle negative axis + if axis < 0: + axis = len(shape) + axis + result = onp.zeros(shape, dtype=x.dtype) + # Create index tuple for add.at + idx = [slice(None)] * len(shape) + idx[axis] = indices + onp.add.at(result, tuple(idx), x) + return result + + +def grad_take(ans, a, indices, axis=None, out=None, mode="raise"): + shape = anp.shape(a) + + def vjp(g): + return untake_along_axis(g, indices, shape, axis) + + return vjp + + +defvjp(anp.take, grad_take)