
Unexpected NaN when signing ±inf + 0j #25679

Open
lfaucheux opened this issue Dec 23, 2024 · 4 comments
Labels
bug Something isn't working



lfaucheux commented Dec 23, 2024

Description

I get an unexpected NaN from jax.numpy.sign with the following example:

>>> import jax
>>> jax.numpy.sign(x := float('inf') + 0j)
Array(nan+0.j, dtype=complex128, weak_type=True)
# For comparison, NumPy 1.26.4 returns:
>>> import numpy
>>> numpy.sign(x)
(1+0j)

What jax/jaxlib version are you using?

jax v0.4.38, jaxlib v0.4.38, numpy v1.26.4

Which accelerator(s) are you using?

CPU

Additional system info?

Python 3.11, Windows


lfaucheux added the bug (Something isn't working) label Dec 23, 2024
lfaucheux changed the title "Unexpected NaN sign when signing inf + 0j" to "Unexpected NaN when signing inf + 0j" Dec 24, 2024
lfaucheux changed the title "Unexpected NaN when signing inf + 0j" to "Unexpected NaN when signing ±inf + 0j" Dec 24, 2024
Collaborator

jakevdp commented Jan 3, 2025

I find NumPy's output strange here, because the complex sign is defined as x / abs(x), and inf / inf is NaN.
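To make the inf / inf point concrete, here is a minimal pure-Python sketch that applies the definition sign(x) = x / abs(x) by dividing each component by the real magnitude. The function name sign_by_definition is hypothetical, and this is not JAX's implementation; it only shows that the real part becomes inf / inf. The result happens to be consistent with the nan+0.j reported above.

```python
import math

def sign_by_definition(x: complex) -> complex:
    """Complex sign as x / abs(x), dividing each component by the float magnitude."""
    mag = math.hypot(x.real, x.imag)  # abs(x) as a float
    if mag == 0.0:
        return 0j  # convention: sign(0) == 0
    # Real part: inf / inf -> nan for x = inf + 0j; imaginary part: 0 / inf -> 0.
    return complex(x.real / mag, x.imag / mag)

print(sign_by_definition(3 + 4j))                  # (0.6+0.8j)
print(sign_by_definition(complex(math.inf, 0.0)))  # (nan+0j)
```

For finite inputs the definition behaves as expected; only an infinite component turns the magnitude into inf and the quotient into nan.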

Collaborator

jakevdp commented Jan 3, 2025

In any case, it seems NumPy has special handling for this case, so jax.numpy probably should match this.

jakevdp self-assigned this Jan 3, 2025
Author

True. Comparing the result of numpy.sign(float('inf') + 0j) with NumPy's own definition of the complex sign also left me puzzled.

Collaborator

pearu commented Jan 4, 2025

The apparent nan part in complex division or multiplication by a float is a general property of libraries that don't implement mixed-mode arithmetic. When one of the operands in a complex expression is a float, it is converted to a complex value with a zero imaginary part. This zero imaginary part, combined with inf values in the other operand, is the source of the nan values in the results of complex division or multiplication.

For example, the result of

>>> jnp.complex64(jnp.inf+1j) * 2
Array(inf+nanj, dtype=complex64)

forms as follows:

(inf + 1j) * 2
-> (inf + 1j) * (2 + 0j) 
-> (inf * 2 - 1 * 0) + (1 * 2 + inf * 0) * 1j
-> inf + nan * 1j

because inf * 0 is nan.
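The same steps can be replayed in plain Python floats. This is a hedged sketch of non-mixed-mode multiplication only: the helper mul_promoted is a hypothetical name, and it spells out the promotion of the float to complex(r, 0.0) and the textbook product formula term by term rather than relying on any library's complex type.

```python
import math

def mul_promoted(a: complex, r: float) -> complex:
    """Non-mixed-mode product: promote r to complex(r, 0.0), then apply
    (a + bj)(c + dj) = (ac - bd) + (ad + bc)j term by term."""
    c, d = r, 0.0                    # the float after promotion to complex
    re = a.real * c - a.imag * d     # inf * 2 - 1 * 0 -> inf
    im = a.real * d + a.imag * c     # inf * 0 + 1 * 2 -> nan + 2 -> nan
    return complex(re, im)

print(mul_promoted(complex(math.inf, 1.0), 2.0))  # (inf+nanj)
print(mul_promoted(1 + 2j, 3.0))                  # (3+6j), finite inputs are unaffected
```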

Similarly, a nan value appears in division:

>>> jnp.complex64(jnp.inf+1j) / 2
Array(inf+nanj, dtype=complex64)

that is,

(inf + 1j) / 2
-> (inf + 1j) / (2 + 0j)
-> ... (laying out the complex division formula here is left as an exercise to the reader)
-> inf + nan * 1j 
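One way to fill in the elided step, assuming the library applies the textbook division formula directly. Production implementations usually use scaled variants such as Smith's algorithm, but for this divisor those also multiply inf by the zero imaginary part and end up with the same nan.

```latex
\begin{aligned}
\frac{a + b\,i}{c + d\,i} &= \frac{(ac + bd) + (bc - ad)\,i}{c^2 + d^2},
  \qquad a = \infty,\ b = 1,\ c = 2,\ d = 0 \\
&= \frac{(\infty \cdot 2 + 1 \cdot 0) + (1 \cdot 2 - \infty \cdot 0)\,i}{2^2 + 0^2} \\
&= \frac{\infty + (2 - \mathrm{nan})\,i}{4} \\
&= \infty + \mathrm{nan}\,i
\end{aligned}
```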

In libraries that do implement mixed-mode arithmetic, complex multiplication or division by a float would not involve nan values:

(inf + 1j) * 2 -> (inf * 2) + (1 * 2) * 1j -> inf + 2j
(inf + 1j) / 2 -> (inf / 2) + (1 / 2) * 1j -> inf + 0.5j
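The mixed-mode rule above amounts to scaling each component by the real operand. A minimal sketch, where mul_mixed and div_mixed are hypothetical names and not any particular library's API:

```python
import math

def mul_mixed(a: complex, r: float) -> complex:
    """Mixed-mode product: multiply each component by the real r (no inf * 0 term)."""
    return complex(a.real * r, a.imag * r)

def div_mixed(a: complex, r: float) -> complex:
    """Mixed-mode quotient: divide each component by the real r."""
    return complex(a.real / r, a.imag / r)

z = complex(math.inf, 1.0)
print(mul_mixed(z, 2.0))  # (inf+2j)
print(div_mixed(z, 2.0))  # (inf+0.5j)
```

Because the float is never promoted, there is no zero imaginary part to meet the inf, so no nan is produced.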

In practice, however, libraries typically either implement mixed-mode arithmetic or add special handling where the expected result would differ from what non-mixed-mode arithmetic produces. NumPy, for instance, casts floats to complex in complex operations, just as JAX does (or vice versa :) ):

>>> np.complex64(np.inf+1j) * 2
(inf+nanj)

but numpy.sign applies special handling to complex inputs (before numpy 2):

>>> print(np.sign.__doc__)
<snip>
For complex inputs, the `sign` function returns
``sign(x.real) + 0j if x.real != 0 else sign(x.imag) + 0j``.
<snip>
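The pre-NumPy-2 docstring rule can be transcribed into a few lines of Python. This is a sketch of the documented formula, not NumPy's source; legacy_complex_sign and _real_sign are hypothetical helper names. It shows why inf + 0j maps to 1 + 0j under that rule without ever computing abs(x).

```python
import math

def _real_sign(v: float) -> float:
    """Sign of a real number: -1.0, 0.0 or 1.0 (nan propagates)."""
    if math.isnan(v):
        return math.nan
    return float((v > 0) - (v < 0))

def legacy_complex_sign(x: complex) -> complex:
    """sign(x.real) + 0j if x.real != 0 else sign(x.imag) + 0j"""
    if x.real != 0:
        return complex(_real_sign(x.real), 0.0)
    return complex(_real_sign(x.imag), 0.0)

print(legacy_complex_sign(complex(math.inf, 0.0)))   # (1+0j)
print(legacy_complex_sign(complex(-math.inf, 0.0)))  # (-1+0j)
print(legacy_complex_sign(3 + 4j))                   # (1+0j), not x / abs(x)
print(legacy_complex_sign(complex(0.0, -2.0)))       # (-1+0j)
```

Note that this rule always returns a real-valued result, which is why it differs from x / abs(x) even for finite inputs such as 3 + 4j.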

or (starting from numpy 2):

>>> print(np.sign.__doc__)
<snip>
For complex inputs, the `sign` function returns ``x / abs(x)``, ...
<snip>

and

>>> np.sign(np.inf+1j)
np.complex128(1+0j)
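One way to keep the x / abs(x) definition while still returning 1 + 0j for np.inf + 1j is to special-case infinite components before dividing. The sketch below is only that, a sketch under assumptions: sign_with_inf_handling is a hypothetical name, it is not copied from NumPy or JAX, and returning nan when both components are infinite is our own choice, since the direction of such a value is ambiguous.

```python
import math

def sign_with_inf_handling(x: complex) -> complex:
    """x / abs(x), with special cases for nan, infinite and zero inputs."""
    re, im = x.real, x.imag
    if math.isnan(re) or math.isnan(im):
        return complex(math.nan, math.nan)
    if math.isinf(re) and math.isinf(im):
        return complex(math.nan, math.nan)  # assumption: direction is ambiguous
    if math.isinf(re):
        return complex(math.copysign(1.0, re), 0.0)  # e.g. inf + 1j -> 1 + 0j
    if math.isinf(im):
        return complex(0.0, math.copysign(1.0, im))
    mag = math.hypot(re, im)
    if mag == 0.0:
        return 0j  # convention: sign(0) == 0
    return complex(re / mag, im / mag)

print(sign_with_inf_handling(complex(math.inf, 1.0)))   # (1+0j)
print(sign_with_inf_handling(complex(math.inf, 0.0)))   # (1+0j), the case in this issue
print(sign_with_inf_handling(complex(-math.inf, 0.0)))  # (-1+0j)
```

The finite path is unchanged from the plain definition; only the limits along the real or imaginary axis are filled in by hand.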
