JAX can create str-dtyped tracers under eval_shape with numpy 2 #25707
Labels
bug
Description
MWE:
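The snippet itself is missing from this copy of the issue. What follows is a hedged reconstruction from the surrounding description, not the original code: it assumes a plain Python string is passed as the argument (the actual value is unknown), and `probe` is a hypothetical helper added only so the script runs to completion on any JAX/numpy combination, whether or not the call raises.

```python
import jax
import numpy as np


def probe(fn, arg):
    """Call fn(arg) and report whether it returned or raised (hypothetical helper)."""
    try:
        return ("ok", fn(arg))
    except Exception as exc:  # broad on purpose: the error type may vary by version
        return ("error", type(exc).__name__)


# Assumption: the argument is a Python str; the original value is not known.
# `lambda x: print(x)` is used instead of bare `print`, as the issue describes.
eval_result = probe(lambda s: jax.eval_shape(lambda x: print(x), s), "hello")
jit_result = probe(lambda s: jax.jit(lambda x: print(x))(s), "hello")

print("numpy", np.__version__, "jax", jax.__version__)
print("eval_shape:", eval_result[0])
print("jit:", jit_result[0])
```

On the versions listed under System info, the report is that the `eval_shape` call goes through and prints a string-dtyped tracer, while the `jax.jit` call raises. Other versions may behave differently, which is why the sketch records the outcome instead of asserting it.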
Expected behaviour: I'm expecting that string-dtyped tracers probably shouldn't be create-able, and the above should throw an error just like the equivalent jax.jit statement does. I'm assuming this happens because numpy now has support for first-class string dtypes.
(NB the above uses a lambda x: print(x) rather than just print so that we get the side-effecting print every time the above is run, without caching.)
System info (python version, jaxlib version, accelerator, etc.)
JAX version 0.4.38
Numpy version 2.2.0