I am working with the meshgrid function, which works well for 1D arrays. However, extending it to higher-dimensional arrays is challenging, and I haven’t found suitable alternatives.
My use case involves the following shapes:
>>> x.shape # Shape of the first input
(a1, a2, ..., p)
>>> y.shape # Shape of the second input
(a1, a2, ..., q)
>>> jnp.meshgrid(x, y, axis=(-1, -1)).shape # Desired functionality and output shape
(a1, a2, ..., p, q)
I currently use vmap over the batched axes, but this approach seems insufficient when batching spans multiple axes.
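For reference, the nested-vmap approach described above might look like the following minimal sketch. The helper name `meshgrid_last_axis` is hypothetical (it is not a JAX API), and the sketch assumes `x` and `y` share all leading axes and that `'ij'` indexing is wanted so the result has shape `(a1, a2, ..., p, q)`.

```python
import jax
import jax.numpy as jnp

def meshgrid_last_axis(x, y):
    """Hypothetical helper: meshgrid over the last axis, batched over leading axes."""
    # jnp.meshgrid returns a list; convert to a tuple for a fixed output structure.
    fn = lambda a, b: tuple(jnp.meshgrid(a, b, indexing="ij"))
    # Wrap the function in one vmap per leading (batch) axis.
    for _ in range(x.ndim - 1):
        fn = jax.vmap(fn)
    return fn(x, y)

x = jnp.arange(2 * 3 * 4.0).reshape(2, 3, 4)  # (a1, a2, p)
y = jnp.arange(2 * 3 * 5.0).reshape(2, 3, 5)  # (a1, a2, q)
X, Y = meshgrid_last_axis(x, y)
print(X.shape, Y.shape)  # (2, 3, 4, 5) (2, 3, 4, 5)
```

Each additional batch axis adds another layer of `jax.vmap`, which is why this gets awkward as the number of batch axes grows.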
Is there an alternative or more efficient way to achieve this functionality? If not, could this behaviour be considered a potential feature enhancement for meshgrid?
This looks like a job for jnp.vectorize, though unfortunately vectorize is only designed to work for functions with one output. You can work around this by vectorizing each output separately; for example:
import jax.numpy as jnp

def batched_meshgrid(x, y, *, indexing='xy'):
    # 'xy' indexing puts the length of y first; 'ij' puts the length of x first.
    signature = "(n),(m)->(m,n)" if indexing == 'xy' else "(n),(m)->(n,m)"
    f1 = jnp.vectorize(lambda x, y: jnp.meshgrid(x, y, indexing=indexing)[0], signature=signature)
    f2 = jnp.vectorize(lambda x, y: jnp.meshgrid(x, y, indexing=indexing)[1], signature=signature)
    return f1(x, y), f2(x, y)
Plugging your test cases into this gives the same output as with your approach. You could probably modify this approach to be more general if you wish.
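As a quick check with the shapes from the question, the sketch below calls `batched_meshgrid` with `indexing='ij'` and compares it against a plain broadcasting version. The `broadcast_meshgrid` helper is hypothetical (not part of JAX) and is a different technique from `jnp.vectorize`; it assumes both inputs share their leading axes.

```python
import jax.numpy as jnp

def batched_meshgrid(x, y, *, indexing='xy'):
    # Same approach as in the answer: vectorize each output separately.
    signature = "(n),(m)->(m,n)" if indexing == 'xy' else "(n),(m)->(n,m)"
    f1 = jnp.vectorize(lambda x, y: jnp.meshgrid(x, y, indexing=indexing)[0], signature=signature)
    f2 = jnp.vectorize(lambda x, y: jnp.meshgrid(x, y, indexing=indexing)[1], signature=signature)
    return f1(x, y), f2(x, y)

def broadcast_meshgrid(x, y):
    # Hypothetical helper: 'ij'-style grids built with broadcasting only.
    # x gains a trailing axis, y gains a second-to-last axis, then both broadcast.
    X, Y = jnp.broadcast_arrays(x[..., :, None], y[..., None, :])
    return X, Y

x = jnp.arange(24.0).reshape(2, 3, 4)  # (a1, a2, p)
y = jnp.arange(30.0).reshape(2, 3, 5)  # (a1, a2, q)

X1, Y1 = batched_meshgrid(x, y, indexing='ij')
X2, Y2 = broadcast_meshgrid(x, y)
print(X1.shape, Y1.shape)  # (2, 3, 4, 5) (2, 3, 4, 5)
print(bool(jnp.array_equal(X1, X2)), bool(jnp.array_equal(Y1, Y2)))
```

For `'xy'` indexing, swapping the last two axes of the broadcast result (for example with `jnp.swapaxes(X, -1, -2)`) should give the matching layout.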
Update
I reshaped the arrays to avoid multiple nested vmaps as a workaround.
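A reshape-based workaround along these lines could be sketched as follows. This is an assumption about the general idea, not the code from the original post: the helper name `meshgrid_via_reshape` is hypothetical, and it assumes `x` and `y` share all leading axes and that `'ij'` indexing is wanted.

```python
import jax
import jax.numpy as jnp

def meshgrid_via_reshape(x, y):
    """Hypothetical helper: flatten batch axes, vmap once, restore the batch shape."""
    batch_shape = x.shape[:-1]
    p, q = x.shape[-1], y.shape[-1]
    # Collapse all leading axes into a single batch axis.
    x_flat = x.reshape(-1, p)
    y_flat = y.reshape(-1, q)
    # A single vmap now covers every batch element.
    grid = jax.vmap(lambda a, b: tuple(jnp.meshgrid(a, b, indexing="ij")))
    X, Y = grid(x_flat, y_flat)
    # Restore the original batch axes in front of the (p, q) grid axes.
    return X.reshape(*batch_shape, p, q), Y.reshape(*batch_shape, p, q)

x = jnp.arange(2 * 3 * 4.0).reshape(2, 3, 4)  # (a1, a2, p)
y = jnp.arange(2 * 3 * 5.0).reshape(2, 3, 5)  # (a1, a2, q)
X, Y = meshgrid_via_reshape(x, y)
print(X.shape, Y.shape)  # (2, 3, 4, 5) (2, 3, 4, 5)
```

Only one `jax.vmap` is needed regardless of how many batch axes there are, at the cost of two reshapes.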