Unexpected NaN gradient of jnp.abs at ±inf + 0j #25681

Open
lfaucheux opened this issue Dec 24, 2024 · 0 comments
Labels
bug Something isn't working

lfaucheux commented Dec 24, 2024

Description

I am getting an unexpected NaN gradient with the following example.

>>> import jax
>>> g = jax.grad(jax.numpy.abs)
>>> g(float('inf') + 0j)   # complex input: unexpected NaN
Array(nan+nanj, dtype=complex128)
>>> g(float('inf'))        # real input: expected result
Array(1., dtype=float64, weak_type=True)
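For context (my own reading of the failure mode, not traced through JAX's actual JVP rule for `abs`): for z = x + iy, |z| = sqrt(x² + y²), and ∂|z|/∂x = x/|z|. At x = ±inf with y = 0, a naive derivative rule evaluates inf/inf, which is NaN under IEEE-754 arithmetic. A minimal NumPy sketch of that mechanism:

```python
import numpy as np

# |z| = sqrt(x**2 + y**2) for z = x + i*y; the partial derivative
# with respect to x is x / |z|.  At x = inf, y = 0 this divides
# inf by inf, which IEEE-754 defines as NaN -- a plausible source
# of the reported NaN gradient.
x, y = np.inf, 0.0
r = np.hypot(x, y)        # |z| = inf
with np.errstate(invalid="ignore"):
    dx = x / r            # inf / inf -> nan
print(dx)                 # nan
```

If this is indeed the mechanism, the usual remedies in JAX code are the "double-where" trick (mask the non-finite input before the division so the problematic branch never reaches differentiation) or a custom JVP rule that special-cases non-finite inputs.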

What jax/jaxlib version are you using?

jax v0.4.38, jaxlib v0.4.38, numpy v1.26.4

Which accelerator(s) are you using?

CPU

Additional system info?

Python 3.11, Windows

NVIDIA GPU info

No response

@lfaucheux lfaucheux added the bug Something isn't working label Dec 24, 2024
@lfaucheux lfaucheux changed the title Unexpected NaN gradient of jnp.abs at inf + 0j Unexpected NaN gradient of jnp.abs at ±inf + 0j Dec 24, 2024
@dfm dfm self-assigned this Dec 29, 2024