ENH: add pad #69

Open
ev-br opened this issue Dec 19, 2024 · 1 comment · Fixed by #71 · May be fixed by #72
Labels
enhancement (New feature or request), new function

Comments

ev-br (Contributor) commented Dec 19, 2024

np.pad is not in the array API spec and is unlikely to be added. Still, it is sometimes useful, and it is available in some backends (numpy, cupy, jax.numpy) but not in others (torch has only torch.nn.functional.pad, which takes its padding argument in a different format). It would therefore be great to add it to array-api-extra.
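For reference, this is what np.pad does in its default mode="constant": pad_width can be a single int (same padding on both sides of every axis), one (before, after) pair applied to every axis, or one pair per axis, and constant_values sets the fill value (default 0). The snippet below only uses NumPy's documented API:

```python
import numpy as np

x = np.array([[1, 2],
              [3, 4]])

# A single int pads every axis by 1 on both sides, filling with 0.
a = np.pad(x, 1)
# a is [[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]]

# One (before, after) pair per axis, filled with 9 instead of 0:
# axis 0 gets 0 rows before and 1 after; axis 1 gets 2 columns before and 0 after.
b = np.pad(x, ((0, 1), (2, 0)), mode="constant", constant_values=9)
# b is [[9, 9, 1, 2], [9, 9, 3, 4], [9, 9, 9, 9]]
```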

Implementing the full set of mode options is somewhat tricky, but the most useful one, mode="constant", is easy to implement even with PyTorch. A ready implementation is available in the SciPy PR:
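One way to see why mode="constant" is easy: allocate an output of the padded shape filled with the constant, then write x into the interior slice. A minimal sketch, assuming the backend supports xp.full, basic slicing and in-place __setitem__ (so it would not work as-is on JAX arrays, which are immutable); pad_constant is a hypothetical name, not the array-api-extra API, device handling is omitted, and NumPy stands in for xp:

```python
import numpy as np


def pad_constant(x, pad_width, constant_values=0, *, xp=np):
    # Hypothetical helper: constant padding using only xp.full, slicing and
    # __setitem__. Not the array-api-extra API; device placement is ignored.

    # Normalize pad_width like np.pad does: an int, a single (before, after)
    # pair for every axis, or one (before, after) pair per axis.
    if isinstance(pad_width, int):
        pairs = [(pad_width, pad_width)] * x.ndim
    elif len(pad_width) == 2 and all(isinstance(p, int) for p in pad_width):
        pairs = [tuple(pad_width)] * x.ndim
    else:
        pairs = [tuple(p) for p in pad_width]
    if len(pairs) != x.ndim:
        raise ValueError("pad_width needs one (before, after) pair per axis")

    # Each axis grows by before + after.
    shape = tuple(b + n + a for (b, a), n in zip(pairs, x.shape))
    out = xp.full(shape, constant_values, dtype=x.dtype)

    # Copy x into the interior; the border keeps the constant value.
    interior = tuple(slice(b, b + n) for (b, _), n in zip(pairs, x.shape))
    out[interior] = x
    return out


x = np.array([[1, 2], [3, 4]])
y = pad_constant(x, ((0, 1), (2, 0)), constant_values=9)
```

The same construction carries over to torch (torch.full and slice assignment), which is why only the constant mode is cheap to support everywhere.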
https://github.com/scipy/scipy/pull/22122/files#diff-351836adc98d076c1552d17a57c52e6aa8ca43760bae44bea9190e55b4769b7fR873

The torch branch is adapted from torch._numpy: https://github.com/pytorch/pytorch/blob/main/torch/_numpy/_funcs_impl.py#L2045
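The only subtle part of the torch branch is the argument order. torch.nn.functional.pad takes a flat sequence that starts with the last axis, (before_last, after_last, before_second_to_last, after_second_to_last, ...), whereas NumPy's pad_width lists axes from first to last. A sketch of that conversion, mirroring the broadcast/flip/flatten steps of the torch branch but using NumPy so it runs without torch installed (to_torch_pad is a hypothetical name):

```python
import numpy as np


def to_torch_pad(pad_width, ndim):
    # Normalize to shape (ndim, 2): row i is the (before, after) pair for axis i.
    pairs = np.broadcast_to(np.asarray(pad_width), (ndim, 2))
    # Reverse the axis order (last axis first), flatten, and convert to
    # plain Python ints, the form torch.nn.functional.pad expects.
    return tuple(np.flip(pairs, axis=0).ravel().tolist())


# Axis 0 padded (1, 2), axis 1 padded (3, 4): torch wants axis 1 first.
pad = to_torch_pad(((1, 2), (3, 4)), ndim=2)  # (3, 4, 1, 2)
```

Converting to Python ints with tolist() avoids passing 0-d tensors as padding amounts.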

The version from the SciPy PR is reproduced below (it was collapsed in the original issue).

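```python
import numpy as np
# SciPy-internal (private) helpers that identify the backend from xp.__name__
from scipy._lib._array_api import is_array_api_strict, is_torch


def xp_pad(x, pad_width, mode='constant', *, xp, **kwargs):
    # xp.pad is available on numpy, cupy and jax.numpy; on torch, reuse
    # http://github.com/pytorch/pytorch/blob/main/torch/_numpy/_funcs_impl.py#L2045
    # for mode = 'constant'
    if mode != 'constant':
        raise NotImplementedError(f"mode={mode!r} is not supported; only 'constant' is")

    value = kwargs.get("constant_values", 0)

    if is_array_api_strict(xp):
        # array_api_strict has no pad: round-trip through NumPy
        np_x = np.asarray(x)
        padded = np.pad(np_x, pad_width, mode=mode, **kwargs)
        return xp.asarray(padded)
    elif is_torch(xp):
        # Normalize pad_width to shape (ndim, 2): one (before, after) pair per axis
        pad_width = xp.asarray(pad_width)
        pad_width = xp.broadcast_to(pad_width, (x.ndim, 2))
        # torch.nn.functional.pad expects a flat sequence starting from the
        # last axis: (before_last, after_last, before_second_to_last, ...)
        pad_width = xp.flip(pad_width, axis=(0,)).flatten()
        return xp.nn.functional.pad(x, tuple(pad_width.tolist()), value=value)
    else:
        return xp.pad(x, pad_width, mode=mode, **kwargs)
```
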
lucascolley added the enhancement label Dec 19, 2024
lucascolley changed the title from "add pad" to "ENH: add pad" Dec 19, 2024
ev-br (Contributor, Author) commented Dec 26, 2024

Closing as completed feels a bit premature.

lucascolley reopened this Dec 26, 2024
lucascolley linked a pull request that will close this issue Dec 26, 2024

2 participants