JAX can create str-dtyped tracers under eval_shape with numpy 2 #25707

Open
patrick-kidger opened this issue Jan 2, 2025 · 1 comment
Labels
bug Something isn't working


@patrick-kidger
Collaborator

Description

MWE:

import jax
import numpy as np

jax.eval_shape(lambda x: print(x), np.str_('hi'))
# Traced<ShapedArray(str64[])>with<DynamicJaxprTrace>

Expected behaviour: string-dtyped tracers probably shouldn't be creatable, so the above should raise an error, just as the equivalent jax.jit call does.

I'm assuming this happens because numpy now has support for first-class string dtypes.
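As a quick check of where the str64 in the tracer repr might come from, the sketch below inspects the dtype that NumPy attaches to np.str_('hi'). It uses only NumPy; linking dtype.name to the tracer repr is an inference, not something confirmed in this thread.

```python
import numpy as np

value = np.str_("hi")
dtype = value.dtype

# np.str_ scalars carry the fixed-width unicode dtype (kind "U"),
# which stores 4 bytes per code point.
print(dtype.kind)      # U
print(dtype.itemsize)  # 8  (2 characters * 4 bytes)
print(dtype.name)      # str64  (item size expressed in bits)
```

Note that this is the fixed-width unicode dtype rather than the variable-width StringDType introduced in NumPy 2, so the new string dtype may not be the whole explanation.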

(NB the above uses a lambda x: print(x) rather than just print so that we get the side-effecting print every time the above is run, without caching.)
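To compare the two entry points side by side, a minimal sketch along these lines may help. The helper error_name is hypothetical (not part of JAX), and the script deliberately makes no claim about what any particular JAX version prints; it only reports what happened.

```python
import jax
import numpy as np


def error_name(thunk):
    """Run thunk() and return the raised exception's type name, or None."""
    try:
        thunk()
    except Exception as exc:  # broad on purpose: we only report the outcome
        return type(exc).__name__
    return None


x = np.str_("hi")
fn = lambda v: None  # same shape as the MWE, minus the print

results = {
    "jit": error_name(lambda: jax.jit(fn)(x)),
    "eval_shape": error_name(lambda: jax.eval_shape(fn, x)),
}
print(jax.__version__, results)
```

If the two entries differ (an exception name for jit but None for eval_shape), that matches the inconsistency described above.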

System info (python version, jaxlib version, accelerator, etc.)

JAX version 0.4.38
Numpy version 2.2.0

@jakevdp
Collaborator

jakevdp commented Jan 2, 2025

Thanks for the report! I can reproduce this in JAX v0.4.38, but on HEAD I get the expected error:

In [1]: import numpy as np

In [2]: import jax

In [3]: jax.eval_shape(lambda x: print(x), np.str_('hi'))
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[3], line 1
----> 1 jax.eval_shape(lambda x: print(x), np.str_('hi'))

    [... skipping hidden 5 frame]

File ~/github/google/jax/jax/_src/pjit.py:760, in _infer_input_type(fun, dbg, explicit_args)
    757 except TypeError:
    758   arg_description = (f"path {dbg.arg_names[i]}" if dbg  # type: ignore
    759                      else f"flattened argument number {i}")  # type: ignore
--> 760   raise TypeError(
    761     f"Error interpreting argument to {fun} as an abstract array."
    762     f" The problematic value is of type {type(x)} and was passed to"  # type: ignore
    763     f" the function at {arg_description}.\n"
    764     "This typically means that a jit-wrapped function was called with a non-array"
    765     " argument, and this argument was not marked as static using the"
    766     " static_argnums or static_argnames parameters of jax.jit."
    767   ) from None
    768 if config.mutable_array_checks.value:
    769   _check_no_aliased_ref_args(dbg, avals, explicit_args)

TypeError: Error interpreting argument to <function <lambda> at 0x10367f4c0> as an abstract array. The problematic value is of type <class 'numpy.str_'> and was passed to the function at path x.
This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.

I suspect this was fixed in the course of refactoring the abstract value utilities after the 0.4.38 release.
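For anyone pinned to a release that still shows the old behaviour, one possible stopgap is to reject string-dtyped inputs before calling jax.eval_shape. This is only a sketch: checked_eval_shape and _STRING_KINDS are hypothetical names, not JAX APIs, and treating dtype kinds "U", "S" and "T" as invalid is an assumption about what should count as a string input.

```python
import jax
import numpy as np

# Assumed set of NumPy dtype kinds to reject:
# "U" fixed-width unicode, "S" bytes, "T" NumPy 2 StringDType.
_STRING_KINDS = frozenset("UST")


def checked_eval_shape(fun, *args, **kwargs):
    """Hypothetical wrapper: reject string-dtyped leaves, then call jax.eval_shape."""
    for leaf in jax.tree_util.tree_leaves((args, kwargs)):
        # Leaves without a dtype (or with a dtype lacking .kind) pass through.
        kind = getattr(getattr(leaf, "dtype", None), "kind", None)
        if kind in _STRING_KINDS:
            raise TypeError(
                f"string-dtyped argument of type {type(leaf)} is not a valid JAX input"
            )
    return jax.eval_shape(fun, *args, **kwargs)


# Numeric inputs are forwarded unchanged.
out = checked_eval_shape(lambda x: x * 2, np.ones((2, 3), np.float32))
print(out.shape, out.dtype)

# String inputs are rejected before any tracing happens.
try:
    checked_eval_shape(lambda x: None, np.str_("hi"))
except TypeError as exc:
    print("rejected:", exc)
```

The check runs before tracing, so it behaves the same regardless of how the installed JAX version handles such arguments internally.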

@jakevdp jakevdp self-assigned this Jan 2, 2025