Skip to content

Conversation

@SIVALANAGASHANKARNIVAS
Copy link

Implements gradient computation for numpy.take function.

  • Adds untake_along_axis primitive for scattering gradients back
  • Handles both axis=None (flattened) and specific axis cases
  • Uses numpy.add.at for proper gradient accumulation with repeated indices

Fixes #743

Changes Made

This PR adds VJP (Vector-Jacobian Product) support for numpy.take, enabling gradient computation through this function.

Implementation Details

  • Created untake_along_axis primitive that scatters gradients back to original array positions
  • The VJP handles both axis=None (flattened array) and specific axis cases
  • Uses numpy.add.at for proper gradient accumulation when indices are repeated

Testing

With this change, the following code now works:

import autograd.numpy as anp
import autograd as ag

x = anp.array([[1., 2., 3.], [4., 5., 6.]])
idx = anp.array([0, 2])

def foo(x):
    return anp.take(x, idx, axis=1).sum()

grad_fn = ag.grad(foo)
print(grad_fn(x))  # Now works!

Implements gradient computation for numpy.take function.

- Adds untake_along_axis primitive for scattering gradients back
- Handles both axis=None (flattened) and specific axis cases
- Uses numpy.add.at for proper gradient accumulation with repeated indices

Fixes HIPS#743
@agriyakhetarpal
Copy link
Collaborator

Thanks @SIVALANAGASHANKARNIVAS – could you please add a few tests? 🙏🏻

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

support numpy.take

2 participants