This issue tracks improvements to debug_nans/debug_infs. See the description of #25519 for some examples of status quo behavior.
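For context, here is a minimal sketch of how the NaN checker is usually enabled and what it raises today. It assumes a recent JAX release where `jax.config.update("jax_debug_nans", True)` turns the check on. The function `f` is a hypothetical example, not taken from #25519, and the exact error text differs between JAX versions. When a jitted function produces a NaN, JAX re-runs it op by op so that the `FloatingPointError` can point at the operation responsible.

```python
import jax
import jax.numpy as jnp

# Enable NaN checking globally (the JAX_DEBUG_NANS=True environment variable does the same).
jax.config.update("jax_debug_nans", True)

@jax.jit
def f(x):  # hypothetical example function
    y = x * 2.0
    return jnp.log(y)  # log of a negative number produces NaN

try:
    f(jnp.float32(-1.0))
except FloatingPointError as e:
    # The message and traceback should identify jnp.log as the source of the NaN.
    print("Caught:", e)
```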
NaNs in the forward pass of a shard_map function currently only report "Invalid value in a sharded computation," and the line where the shard_map function was called. The traceback should extend to the lax/jnp op within the shard_map function that produced the NaN.
When debug_nans re-runs a function to locate the NaN, avoid re-running functions that have side effects or contain collectives.
Consider stopping the traceback at the jax.numpy boundary rather than at the lax primitive that produced the NaN inside a jnp function, or using the JAX_TRACEBACK_FILTERING flag to switch between these behaviors.
Improve the error messages for NaNs in the backward pass. They currently report the stack trace as a list of line numbers. A regular Python traceback with the offending code highlighted would be more useful.
For pmap, the Python dispatch path reports "Invalid value in parallel computation" and the traceback stops at the call to the pmap function. It should extend to the arithmetic op where the NaN occurred, like it does for the C++ dispatch path.
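Here is a minimal sketch of a NaN that appears only in the backward pass, which is the case the backward-pass item above concerns. It uses the well-known `jnp.where` pattern from the JAX FAQ. The helper `safe_log` is a hypothetical name, and the example assumes `jax_debug_nans` is enabled through `jax.config.update`. The forward pass produces no NaN, because `log(0.)` is `-inf` and the `where` discards it. The gradient of the discarded branch still evaluates to `0 / 0`, so the checker raises during differentiation.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)

def safe_log(x):  # hypothetical helper
    # Forward pass is NaN-free: jnp.log(0.) is -inf, and jnp.where selects 0. instead.
    return jnp.where(x > 0, jnp.log(x), 0.0)

try:
    jax.grad(safe_log)(0.0)
except FloatingPointError as e:
    # Raised while computing the gradient, not while computing safe_log itself.
    print("NaN in backward pass:", e)
```

With the flag disabled, the same call quietly returns `nan` instead of raising.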