Get "invalid value (nan) encountered in jit" even when jit disabled globally #25701

Open
SUSYUSTC opened this issue Dec 31, 2024 · 7 comments
Labels
bug Something isn't working

@SUSYUSTC

Description

In the following code there is clearly no jit anywhere, but the error message suggests that the issue comes from jit.

import jax
jax.disable_jit(True)
jax.config.update('jax_platform_name', 'cpu')
jax.config.update("jax_debug_nans", True)
import jax.numpy as jnp
jnp.log(-10)

Error information:

...
...
Traceback (most recent call last):
  File "/home/jiace/shadow/TN/shadow_vmc/debug.py", line 6, in <module>
    jnp.log(-10)
FloatingPointError: invalid value (nan) encountered in jit(log). Because jax_config.debug_nans.value and/or config.jax_debug_infs is set, the de-optimized function (i.e., the function as if the `jit` decorator were removed) was called in an attempt to get a more precise error message. However, the de-optimized function did not produce invalid values during its execution. This behavior can result from `jit` optimizations causing the invalid value to be produced. It may also arise from having nan/inf constants as outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`.

It may be possible to avoid the invalid value by removing the `jit` decorator, at the cost of losing optimizations.

If you see this error, consider opening a bug report at https://github.com/jax-ml/jax.

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.36
jaxlib: 0.4.36
numpy: 1.26.1
python: 3.11.5 | packaged by conda-forge | (main, Aug 27 2023, 03:34:09) [GCC 12.3.0]
device info: NVIDIA GeForce RTX 2060-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='jiace-XPS-8930', release='5.4.0-150-generic', version='#167~18.04.1-Ubuntu SMP Wed May 24 00:51:42 UTC 2023', machine='x86_64')

$ nvidia-smi
Mon Dec 30 22:29:29 2024
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.147.05   Driver Version: 525.147.05   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  On   | 00000000:01:00.0 Off |                  N/A |
| 32%   28C    P2    15W / 160W |    131MiB /  6144MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      2745      G   /usr/lib/xorg/Xorg                 16MiB |
|    0   N/A  N/A     13172      C   ...envs/py311/bin/python3.11      110MiB |
+-----------------------------------------------------------------------------+

@SUSYUSTC SUSYUSTC added the bug Something isn't working label Dec 31, 2024
@jakevdp
Collaborator

jakevdp commented Dec 31, 2024

Thanks for the report! Just to be clear, jax.disable_jit is a context manager, and does not affect the global configuration. If you want to disable JIT globally, you should do something like this:

jax.config.update('jax_disable_jit', True)

Even with this change, however, the same error message appears. This is because jax_disable_jit is more of a debugging tool than something we expect users to configure as a matter of course, so the error message is written assuming more typical configurations. We assume that a user advanced enough to use this setting is also advanced enough to understand the context of the error message.
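A minimal sketch of the difference described above, assuming a recent JAX release: calling jax.disable_jit(True) as a bare statement only creates a context manager and discards it, whereas entering it in a with block, or setting the jax_disable_jit config option, actually turns jit off. The names relu_py and jit_is_bypassed are hypothetical; the check relies on the fact that Python control flow on a traced value fails under jit but works on concrete values.

```python
import jax
import jax.numpy as jnp


@jax.jit
def relu_py(x):
    # Python `if` on a traced value raises under jit; with jit disabled,
    # x is a concrete array and the branch works normally.
    return x if x > 0 else jnp.zeros_like(x)


def jit_is_bypassed():
    # Hypothetical helper: True if relu_py ran without being traced.
    try:
        relu_py(jnp.float32(1.0))
        return True
    except jax.errors.ConcretizationTypeError:
        return False


jax.disable_jit(True)  # builds a context manager and throws it away: no effect
after_bare_call = jit_is_bypassed()

with jax.disable_jit():  # disables jit only inside this block
    inside_context = jit_is_bypassed()

jax.config.update("jax_disable_jit", True)  # disables jit globally
after_global_update = jit_is_bypassed()
jax.config.update("jax_disable_jit", False)  # restore the default

print(after_bare_call, inside_context, after_global_update)  # False True True
```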

The fix, if we were to do it, would be to add a specific check for jax_disable_jit and manually create an alternative error message in this location. In my view, this isn't worth the added complexity.

What do you think?

@SUSYUSTC
Author

SUSYUSTC commented Jan 1, 2025

Thanks for your response! I guess my question is more about the other side: the message seems to suggest that there's an optimized function that generates a nan and a de-optimized function that does not. However, in this case log(-10) is always nan, so why does it say that some de-optimized function does not generate a nan? In my real use case I have a complicated function that hits the same error (maybe related to the autograd of svd, which is naturally unstable), and I hoped I could disable jit globally to get rid of the nan, but I can't.

@mattjj
Collaborator

mattjj commented Jan 1, 2025

Sorry, that error message is busted. #25519 will fix it, but for now the "de-optimized function doesn't generate a nan" part of the message erroneously appears every time.
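A quick way to see why the old message was misleading, as a sketch assuming default settings (jax_debug_nans off): log(-10) evaluates to nan whether or not jit is involved, so no de-optimized version of the computation can avoid the nan.

```python
import math

import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", False)  # default: nothing raises below

eager = jnp.log(-10.0)            # jax.numpy wraps log in jit internally (hence "jit(log)")
jitted = jax.jit(jnp.log)(-10.0)  # explicit outer jit
with jax.disable_jit():
    de_optimized = jnp.log(-10.0)  # no jit at all

print([math.isnan(float(v)) for v in (eager, jitted, de_optimized)])  # [True, True, True]
```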

@mattjj
Collaborator

mattjj commented Jan 1, 2025

Here's what I get when I run on the #25519 branch:

Invalid nan value encountered in the output of a jax.jit function. Calling the de-optimized version.
Traceback (most recent call last):
  File "/usr/local/google/home/mattjj/packages/jax/25701.py", line 6, in <module>
    jnp.log(-10)
  File "/usr/local/google/home/mattjj/packages/jax/jax/_src/numpy/ufuncs.py", line 489, in log
    return lax.log(*promote_args_inexact('log', x))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/google/home/mattjj/packages/jax/jax/_src/lax/lax.py", line 342, in log
    return log_p.bind(x)
           ^^^^^^^^^^^^^
FloatingPointError: invalid value (nan) encountered in log

I think this is a decent error message, but the phrasing is still a bit confusing with disable_jit=True present. @emilyfertig let's think about disable_jit=True, either in #25519 or a follow-up.

@SUSYUSTC sorry for the confusion! We hope to land that PR soon. In the meantime, if you think debug_nans isn't giving you a useful error message where it should, you could try patching that branch.
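Whichever version of the message you get, debug_nans surfaces the problem as an ordinary Python FloatingPointError (as in the tracebacks above), so a sketch like the following can catch it and restore the flag afterwards. The variable name message is only illustrative, and the exact wording differs between JAX versions.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)
try:
    jnp.log(-10.0)
    message = None
except FloatingPointError as err:
    message = str(err)  # wording depends on the JAX version
finally:
    jax.config.update("jax_debug_nans", False)  # restore the default

print(message)
```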

@SUSYUSTC
Author

SUSYUSTC commented Jan 1, 2025

Thanks a lot! That is indeed what I expected. Another question is about nans that appear in autograd. Say I have a function whose value is valid but whose grad is not. Can JAX give me a traceback that tells me exactly which line generates the nan in a very complicated function? Is this feature currently available on this branch, or even somewhere in the main branch? Here's an example:

import jax
jax.config.update('jax_platform_name', 'cpu')
jax.config.update("jax_debug_nans", True)
import jax.numpy as jnp


def func(x):
    y = x ** 2
    z = jnp.sqrt(y)
    return z


print(func(0.0))
print(jax.grad(func)(0.0))

Clearly the issue is that sqrt is not differentiable at 0.0, so I hoped the error would point to the z = jnp.sqrt(y) line. However, the error message is

...
Traceback (most recent call last):
  File "debug.py", line 15, in <module>
    print(jax.grad(func)(0.0))
          ^^^^^^^^^^^^^^^^^^^
FloatingPointError: invalid value (nan) encountered in jit(mul).
...
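For reference, a minimal sketch of the same example with jax_debug_nans left off: the forward value is fine, and the gradient silently comes back as nan, which is exactly what debug_nans is meant to flag.

```python
import math

import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", False)  # default: no error, just a nan


def func(x):
    y = x ** 2
    z = jnp.sqrt(y)
    return z


value = func(0.0)
grad = jax.grad(func)(0.0)
print(value, grad)  # 0.0 nan
```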

@mattjj
Collaborator

mattjj commented Jan 1, 2025

Here's the traceback on the branch:

0.0
Traceback (most recent call last):
  File "/usr/local/google/home/mattjj/packages/jax/25701_2.py", line 14, in <module>
    print(jax.grad(func)(0.0))
  File "/usr/local/google/home/mattjj/packages/jax/25701_2.py", line 8, in func
    y = x ** 2
  File "/usr/local/google/home/mattjj/packages/jax/jax/_src/numpy/array_methods.py", line 1049, in op
    return getattr(self.aval, f"_{name}")(self, *args)
  File "/usr/local/google/home/mattjj/packages/jax/jax/_src/numpy/array_methods.py", line 573, in deferring_binary_op
    return binary_op(*args)
  File "/usr/local/google/home/mattjj/packages/jax/jax/_src/numpy/ufuncs.py", line 2645, in power
    return lax.integer_pow(x1, x2)
  File "/usr/local/google/home/mattjj/packages/jax/jax/_src/lax/lax.py", line 405, in integer_pow
    return integer_pow_p.bind(x, y=y)
  File "/usr/local/google/home/mattjj/packages/jax/jax/_src/lax/lax.py", line 2672, in _integer_pow_jvp
    return _zeros(g) if y == 0 else mul(g, mul(_const(x, y), integer_pow(x, y - 1)))
  File "/usr/local/google/home/mattjj/packages/jax/jax/_src/lax/lax.py", line 453, in mul
    return mul_p.bind(x, y)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: FloatingPointError: invalid value (nan) encountered in mul
When differentiating the code at the top of the callstack:
/usr/local/google/home/mattjj/packages/jax/25701_2.py:8:8 (func)

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/google/home/mattjj/packages/jax/25701_2.py", line 14, in <module>
    print(jax.grad(func)(0.0))
          ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/google/home/mattjj/packages/jax/jax/_src/lax/lax.py", line 2815, in _mul_transpose
    return [_unbroadcast(x.aval, mul(ct, y)), None]
                                 ^^^^^^^^^^
  File "/usr/local/google/home/mattjj/packages/jax/jax/_src/lax/lax.py", line 453, in mul
    return mul_p.bind(x, y)
           ^^^^^^^^^^^^^^^^
FloatingPointError: invalid value (nan) encountered in mul
When differentiating the code at the top of the callstack:
/usr/local/google/home/mattjj/packages/jax/25701_2.py:8:8 (func)
--------------------

We'd like to improve this further since it's not as easy to read as we'd like, but there's interesting information there:

  • the first traceback shows you the primal code that was problematic, and is pointing to the differentiation rule for the x ** 2 line
  • the second traceback shows you the backward pass problem (and since it's the backward pass, it doesn't include any of your code in the stacktrace), namely the transpose rule _mul_transpose.
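To see the forward-pass and backward-pass operations these tracebacks refer to, one option is to print the jaxpr of the gradient function with jax.make_jaxpr. This is a sketch; the exact primitives and their printed form vary across JAX versions, but the sqrt, div, and mul from the rules discussed here should appear.

```python
import jax
import jax.numpy as jnp


def func(x):
    y = x ** 2
    z = jnp.sqrt(y)
    return z


# Trace (without executing) the full gradient computation at x = 0.0.
grad_jaxpr = jax.make_jaxpr(jax.grad(func))(0.0)
jaxpr_text = str(grad_jaxpr)
print(jaxpr_text)
```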

The reason the issue in this particular example shows up in the mul and not in the sqrt is that, effectively, the sqrt VJP rule produces an inf cotangent, and the nan only arises when that inf is multiplied against x == 0 while applying the VJP of the mul.
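The arithmetic behind this can be checked piece by piece, as a sketch with both debug flags off: the sqrt VJP at y = 0 yields 0.5 / sqrt(0) = inf, the derivative of x ** 2 at x = 0 is 2 * 0 = 0, and the chain rule multiplies the two.

```python
import math

import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", False)
jax.config.update("jax_debug_infs", False)

# Cotangent from the sqrt VJP at y = 0: 0.5 / sqrt(0) = inf.
sqrt_cotangent = jax.grad(jnp.sqrt)(0.0)

# Derivative of y = x ** 2 at x = 0: 2 * x = 0.
square_derivative = jax.grad(lambda x: x ** 2)(0.0)

# Chain rule: inf * 0 = nan, the value the mul transpose ends up producing.
chained = sqrt_cotangent * square_derivative

print(sqrt_cotangent, square_derivative, chained)  # inf 0.0 nan
```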

If we instead set jax.config.update('jax_debug_infs', True), we get:

0.0
Invalid nan value encountered in the output of a jax.jit function. Calling the de-optimized version.
Traceback (most recent call last):
  File "/usr/local/google/home/mattjj/packages/jax/25701_2.py", line 14, in <module>
    print(jax.grad(func)(0.0))
          ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/google/home/mattjj/packages/jax/25701_2.py", line 9, in func
    z = jnp.sqrt(y)
        ^^^^^^^^^^^
  File "/usr/local/google/home/mattjj/packages/jax/jax/_src/numpy/ufuncs.py", line 1146, in sqrt
    return lax.sqrt(*promote_args_inexact('sqrt', x))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/google/home/mattjj/packages/jax/jax/_src/lax/lax.py", line 409, in sqrt
    return sqrt_p.bind(x)
           ^^^^^^^^^^^^^^
  File "/usr/local/google/home/mattjj/packages/jax/jax/_src/lax/lax.py", line 2582, in <lambda>
    ad.defjvp2(sqrt_p, lambda g, ans, x: mul(g, div(_const(x, 0.5), ans)))
                                                ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/google/home/mattjj/packages/jax/jax/_src/lax/lax.py", line 462, in div
    return div_p.bind(x, y)
           ^^^^^^^^^^^^^^^^
FloatingPointError: invalid value (inf) encountered in div

That points to the divide-by-zero that happens in the JVP of sqrt, and in particular to "25701_2.py, line 9, in func", as we might expect.
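A small sketch of using jax_debug_infs the same way on jnp.sqrt alone: differentiating at 0.0 produces the inf from the sqrt JVP's div, so the call raises FloatingPointError. The flag is reset in a finally block so that it does not leak into later code.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_debug_infs", True)
try:
    jax.grad(jnp.sqrt)(0.0)
    raised = False
except FloatingPointError as err:
    raised = True
    print(err)  # message text differs between JAX versions
finally:
    jax.config.update("jax_debug_infs", False)  # restore the default

print(raised)  # True
```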

WDYT?

@SUSYUSTC
Author

SUSYUSTC commented Jan 1, 2025

Thanks a lot for the explanation! It is super clear. Actually, I get the same message on the main branch too, but I didn't understand what it meant.


4 participants