ENH: new canonicalize DType function? #151

Closed
NeilGirdhar opened this issue Feb 28, 2025 · 10 comments
Labels
enhancement (New feature or request), new function


@NeilGirdhar
Contributor

NeilGirdhar commented Feb 28, 2025

JAX has canonicalize_dtype, and PyTorch also has a notion of default dtypes.

Can we provide canonicalize_dtype for all libraries?

Something like:

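from __future__ import annotations  # keeps Namespace/DType as unevaluated hints
from array_api_compat import is_jax_namespace

def canonicalize_dtype(xp: Namespace, dtype: DType | type[complex]) -> DType:
    if is_jax_namespace(xp):
        from jax.dtypes import canonicalize_dtype
        return canonicalize_dtype(dtype)  # Suppresses the warning jnp.empty emits when x64 is off.
    # Other libraries: allocate a 0-d array and read back the dtype it resolved to.
    return xp.empty((), dtype=dtype).dtype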
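A minimal usage sketch of the proposal with NumPy, assuming only NumPy is installed. To stay self-contained it does not import array_api_compat; instead it detects JAX by checking whether the namespace's module name starts with "jax", which is a hypothetical stand-in for is_jax_namespace, not how array_api_compat does it.

```python
import numpy as np

def canonicalize_dtype(xp, dtype):
    # Hypothetical stand-in for array_api_compat.is_jax_namespace:
    # treat any namespace whose module name starts with "jax" as JAX.
    if getattr(xp, "__name__", "").startswith("jax"):
        from jax.dtypes import canonicalize_dtype as jax_canonicalize
        return jax_canonicalize(dtype)
    # Everything else: let the library resolve the dtype via a 0-d array.
    return xp.empty((), dtype=dtype).dtype

# With NumPy, Python scalar types map to NumPy's 64-bit defaults.
print(canonicalize_dtype(np, float))       # float64
print(canonicalize_dtype(np, complex))     # complex128
# Explicit dtypes pass through unchanged.
print(canonicalize_dtype(np, np.float32))  # float32
```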
NeilGirdhar changed the title from "Canonicalize dtype?" to "Canonicalize DType?" on Feb 28, 2025
@lucascolley
Member

lucascolley changed the title from "Canonicalize DType?" to "ENH: new canonicalize DType function?" on Feb 28, 2025
lucascolley added the labels "enhancement" (New feature or request) and "new function" on Feb 28, 2025
@lucascolley
Member

Do you have some example code motivating this inclusion?

@NeilGirdhar
Contributor Author

@lucascolley
Member

https://github.com/data-apis/array-api/pull/848/files is in the same general area.

@NeilGirdhar
Contributor Author

I'm not sure I understand how that solves the problem above.

@lucascolley
Member

You're effectively asking for the default float dtype, right? So xp.astype(xp.empty(()), 'real floating').dtype should achieve the same thing as canonicalize_dtype(float)?
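For the narrower case of the default real floating dtype, a sketch that does not depend on astype accepting the 'real floating' kind string. It assumes the namespace may implement the inspection API from the 2023.12 revision of the array API standard (__array_namespace_info__().default_dtypes()); if not, it falls back to calling empty() with dtype omitted, which the standard specifies must use the default real floating dtype. The helper name default_real_floating is hypothetical.

```python
import numpy as np

def default_real_floating(xp):
    # Prefer the inspection API (array API standard 2023.12), if present.
    info = getattr(xp, "__array_namespace_info__", None)
    if info is not None:
        return info().default_dtypes()["real floating"]
    # Fallback: empty() with no dtype uses the default real floating dtype.
    return xp.empty(()).dtype

print(default_real_floating(np))  # float64
```

On JAX with x64 disabled, either path should report float32 rather than float64, which is the behavior the original request was after.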

@NeilGirdhar
Contributor Author


[utM] In [1]: import jax.numpy as jnp

[utM] In [2]: x = jnp.zeros(())
WARNING:2025-02-28 12:12:44,915:jax._src.xla_bridge:966: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.

[utM] In [3]: from array_api_compat import *

[utM] In [4]: xp = array_namespace(x)

[utM] In [5]: xp.astype(xp.empty(()), 'real floating').dtype
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[5], line 1
----> 1 xp.astype(xp.empty(()), 'real floating').dtype

File ~/src/tjax/.venv/lib/python3.13/site-packages/jax/_src/numpy/lax_numpy.py:5632, in astype(x, dtype, copy, device)
   5630 if dtype is None:
   5631   dtype = dtypes.canonicalize_dtype(dtypes.float_)
-> 5632 dtypes.check_user_dtype_supported(dtype, "astype")
   5633 if issubdtype(x_arr.dtype, np.complexfloating):
   5634   if dtypes.isdtype(dtype, ("integral", "real floating")):

File ~/src/tjax/.venv/lib/python3.13/site-packages/jax/_src/dtypes.py:902, in check_user_dtype_supported(dtype, fun_name)
    898   warnings.warn("Passing an array as a dtype argument is deprecated; "
    899                 "instead of dtype=arr use dtype=arr.dtype.",
    900                 category=DeprecationWarning, stacklevel=3)
    901   return  # no further check needed, as array dtypes have already been validated.
--> 902 if issubdtype(dtype, extended):
    903   return
    904 # Avoid using `dtype in [...]` because of numpy dtype equality overloading.

File ~/src/tjax/.venv/lib/python3.13/site-packages/jax/_src/dtypes.py:431, in issubdtype(a, b)
    414 """Returns True if first argument is a typecode lower/equal in type hierarchy.
    415 
    416 This is like :func:`numpy.issubdtype`, but can handle dtype extensions such as
    417 :obj:`jax.dtypes.bfloat16` and `jax.dtypes.prng_key`.
    418 """
    419 # Main departures from np.issubdtype are:
    420 # - "extended" dtypes (like prng key types) are not normal numpy dtypes, so we
    421 #   need to handle them specifically. However, their scalar types do conform to
   (...)    428 # unhashable (e.g. custom objects with a dtype attribute). The following check is
    429 # fast and covers the majority of calls to this function within JAX library code.
    430 return _issubdtype_cached(
--> 431   a if isinstance(a, _types_for_issubdtype) else np.dtype(a),  # type: ignore[arg-type]
    432   b if isinstance(b, _types_for_issubdtype) else np.dtype(b),  # type: ignore[arg-type]
    433 )

TypeError: data type 'real floating' not understood

@lucascolley
Member

TypeError: data type 'real floating' not understood

That's what data-apis/array-api#848 is changing!

@NeilGirdhar
Contributor Author

Oh, my mistake! I guess if you're okay with this, you can just do xp.empty((), dtype=dtype).dtype. I'll close since this is a one-liner, although it's a bit abstruse.
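The one-liner can be made less abstruse by giving it a name. This is a sketch using NumPy; resolve_dtype is a hypothetical name, not an existing array_api_compat function. It shows that Python scalar types resolve to concrete dtypes, explicit dtypes pass through, and resolving twice is a no-op.

```python
import numpy as np

def resolve_dtype(xp, dtype):
    """Hypothetical helper: let xp pick the concrete dtype for `dtype`."""
    return xp.empty((), dtype=dtype).dtype

print(resolve_dtype(np, float))       # float64 (Python type -> library default)
print(resolve_dtype(np, np.float32))  # float32 (explicit dtype passes through)
# Resolving an already-resolved dtype returns it unchanged.
d = resolve_dtype(np, complex)
print(resolve_dtype(np, d) == d)      # True
```

Note that this allocates a tiny array, and per the comment in the original proposal, on JAX it can trigger a warning that jax.dtypes.canonicalize_dtype avoids.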

@lucascolley
Member

I'd probably be open to adding an alias if you would like to reopen in the future! The difficulty is probably just deciding exactly which shorthands should be in the API, and what the best name would be.

2 participants