
Batched meshgrid or alternative #25696

Open
Qazalbash opened this issue Dec 30, 2024 · 1 comment
Labels
enhancement New feature or request

Comments


Qazalbash commented Dec 30, 2024

I am working with the meshgrid function, which works well for 1D arrays. However, extending it to higher-dimensional arrays is challenging, and I haven’t found suitable alternatives.

My use case involves the following shapes:

```python
>>> x.shape  # Shape of the first input
(a1, a2, ..., p)
>>> y.shape  # Shape of the second input
(a1, a2, ..., q)
>>> [g.shape for g in jnp.meshgrid(x, y, axis=(-1, -1))]  # Desired functionality (hypothetical `axis` argument); meshgrid returns a list
[(a1, a2, ..., p, q), (a1, a2, ..., p, q)]
```

I currently use vmap over the batched axes, but this approach seems insufficient when batching spans multiple axes.

Is there an alternative or more efficient way to achieve this functionality? If not, could this behaviour be considered a potential feature enhancement for meshgrid?
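For context, the single-batch-axis case the question alludes to can be handled with one vmap; a minimal sketch with made-up shapes (not from the issue) shows why nesting one vmap per batch axis gets unwieldy:

```python
import jax
import jax.numpy as jnp

# Hypothetical shapes: one leading batch axis of size 4.
x = jnp.ones((4, 5))   # (batch, p)
y = jnp.ones((4, 6))   # (batch, q)

# One vmap handles one batch axis; each extra batch axis would need
# another nested vmap around this call.
grids = jax.vmap(lambda a, b: jnp.meshgrid(a, b, indexing="ij"))(x, y)
print([g.shape for g in grids])  # [(4, 5, 6), (4, 5, 6)]
```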

Update

> I currently use vmap over the batched axes, but this approach seems insufficient when batching spans multiple axes.

I reshaped the arrays to avoid multiple nested vmaps. My workaround code is:

```python
from functools import partial
from math import prod
from typing import List

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


def batched_meshgrid(
    *xi: ArrayLike, copy: bool = True, sparse: bool = False, indexing: str = "xy"
) -> List[jax.Array]:
    args = [jnp.asarray(x) for x in xi]
    shapes = [jnp.shape(x) for x in args]
    batch_shape = shapes[0][:-1]
    if batch_shape == ():
        # No batch axes: fall back to the stock meshgrid.
        return jnp.meshgrid(*xi, copy=copy, sparse=sparse, indexing=indexing)
    if not all(shape[:-1] == batch_shape for shape in shapes):
        raise ValueError(
            "batched_meshgrid expects all input arrays to have the "
            "same shape except for the last dimension."
        )
    last_dims = [shape[-1] for shape in shapes]
    if indexing == "xy" and len(args) >= 2:
        # meshgrid swaps the first two output axes under "xy" indexing.
        last_dims[0], last_dims[1] = last_dims[1], last_dims[0]
    nargs = len(args)
    # Use a static Python int (not jnp.prod) so the reshapes stay
    # jit-compatible.
    batch_size = prod(batch_shape)
    args = [jnp.reshape(x, (batch_size, -1)) for x in args]
    output = jax.vmap(
        partial(jnp.meshgrid, copy=copy, sparse=sparse, indexing=indexing),
        in_axes=tuple(0 for _ in range(nargs)),
        out_axes=0,
    )(*args)
    # Note: the final reshape assumes dense outputs; sparse=True is not
    # handled here.
    return [jnp.reshape(x, (*batch_shape, *last_dims)) for x in output]
```


## Testing code

```python
batch_shape = (99, 12)
x_batched = jnp.linspace(
    jnp.full(batch_shape, fill_value=5.0),
    jnp.full(batch_shape, fill_value=80.0),
    2,
    axis=-1,
)
y_batched = jnp.linspace(jnp.zeros(batch_shape), jnp.ones(batch_shape), 3, axis=-1)

print(x_batched.shape, y_batched.shape)

yy = batched_meshgrid(x_batched, y_batched, indexing="ij")
print([y.shape for y in yy])
```

Output:

```
(99, 12, 2) (99, 12, 3)
[(99, 12, 2, 3), (99, 12, 2, 3)]
```
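For the two-input `indexing="ij"` case, the same batched grids can also be built with plain broadcasting, avoiding vmap entirely. This is a sketch, not part of the original workaround; the shapes mirror the test above:

```python
import jax.numpy as jnp

# Batched "ij" meshgrid for two inputs via broadcasting alone: expand
# each array against an inserted singleton axis.
x = jnp.zeros((99, 12, 2))
y = jnp.zeros((99, 12, 3))

xx = jnp.broadcast_to(x[..., :, None], (*x.shape, y.shape[-1]))
yy = jnp.broadcast_to(y[..., None, :], (*y.shape[:-1], x.shape[-1], y.shape[-1]))
print(xx.shape, yy.shape)  # (99, 12, 2, 3) (99, 12, 2, 3)
```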
Qazalbash added the enhancement label Dec 30, 2024
jakevdp (Collaborator) commented Dec 30, 2024

This looks like a job for jnp.vectorize, though unfortunately vectorize is only designed to work for functions with one output. You can work around this by vectorizing each output separately; for example:

```python
def batched_meshgrid(x, y, *, indexing='xy'):
  signature = "(n),(m)->(m,n)" if indexing == 'xy' else "(n),(m)->(n,m)"
  f1 = jnp.vectorize(lambda x, y: jnp.meshgrid(x, y, indexing=indexing)[0], signature=signature)
  f2 = jnp.vectorize(lambda x, y: jnp.meshgrid(x, y, indexing=indexing)[1], signature=signature)
  return f1(x, y), f2(x, y)
```

Plugging your test cases into this gives the same output as with your approach. You could probably modify this approach to be more general if you wish.
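A quick sanity check of this `jnp.vectorize` approach against the shapes from the issue; a sketch that redefines the function from the comment so the snippet is self-contained:

```python
import jax.numpy as jnp

def batched_meshgrid(x, y, *, indexing='xy'):
  # Core signature matches meshgrid's output shape for each indexing mode.
  signature = "(n),(m)->(m,n)" if indexing == 'xy' else "(n),(m)->(n,m)"
  f1 = jnp.vectorize(lambda x, y: jnp.meshgrid(x, y, indexing=indexing)[0], signature=signature)
  f2 = jnp.vectorize(lambda x, y: jnp.meshgrid(x, y, indexing=indexing)[1], signature=signature)
  return f1(x, y), f2(x, y)

# Same batch and trailing dims as the issue's test case.
x = jnp.zeros((99, 12, 2))
y = jnp.zeros((99, 12, 3))
xx, yy = batched_meshgrid(x, y, indexing='ij')
print(xx.shape, yy.shape)  # (99, 12, 2, 3) (99, 12, 2, 3)
```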
