Support for MLX #299

Open
gabrieldemarmiesse opened this issue Feb 19, 2025 · 4 comments · May be fixed by #301

@gabrieldemarmiesse

gabrieldemarmiesse commented Feb 19, 2025

Hello,
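I'm opening an issue to request support for MLX. Currently, using `mx.array` in a jaxtyping annotation fails with `TypeError: type 'nanobind.nb_type_0' is not an acceptable base type` as soon as the annotation `Int[mx.array, "8"]` is evaluated. Hopefully this will gather interest from the community.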

```python
import mlx.core as mx
from beartype import beartype
from jaxtyping import jaxtyped, Int


@jaxtyped(typechecker=beartype)
def hello(a: Int[mx.array, "8"]):
    pass

hello(mx.zeros((8,), dtype=mx.int32))
```
```
Traceback (most recent call last):
  File "/Users/gabrieldemarmiesse/projects/V2LLMs/moshi_mlx/./moshi_mlx/a.py", line 7, in <module>
    def hello(a: Int[mx.array, "8"]):
                 ~~~^^^^^^^^^^^^^^^
  File "/Users/gabrieldemarmiesse/projects/V2LLMs/moshi_mlx/.venv/lib/python3.12/site-packages/jaxtyping/_array_types.py", line 678, in __getitem__
    out = _make_array(array_type, dim_str, cls)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/gabrieldemarmiesse/projects/V2LLMs/moshi_mlx/.venv/lib/python3.12/site-packages/jaxtyping/_array_types.py", line 617, in _make_array
    else _make_metaclass(type(array_type))
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/gabrieldemarmiesse/projects/V2LLMs/moshi_mlx/.venv/lib/python3.12/site-packages/jaxtyping/_array_types.py", line 327, in _make_metaclass
    class MetaAbstractArray(_MetaAbstractArray, base_metaclass):
TypeError: type 'nanobind.nb_type_0' is not an acceptable base type
```
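The last frame shows jaxtyping deriving a new metaclass with `type(array_type)` (here `nanobind.nb_type_0`) as a base. This is the error CPython raises when a base type lacks the `Py_TPFLAGS_BASETYPE` flag. The sketch below reproduces the same failure in pure Python, assuming `bool` (which CPython does not allow to be subclassed) as a stand-in for the nanobind metaclass. `is_subclassable` and `try_make_metaclass` are hypothetical helpers; the latter mimics only the single failing `class MetaAbstractArray(...)` line from the traceback, not jaxtyping's actual code.

```python
# Py_TPFLAGS_BASETYPE from CPython's object.h: set on types that allow subclassing.
Py_TPFLAGS_BASETYPE = 1 << 10


def is_subclassable(tp: type) -> bool:
    """Return True if CPython accepts `tp` as a base class."""
    return bool(tp.__flags__ & Py_TPFLAGS_BASETYPE)


def try_make_metaclass(base_metaclass: type):
    """Hypothetical helper mimicking the failing line from the traceback.

    Returns the new metaclass on success, or the TypeError raised by CPython.
    """

    class _MetaAbstractArray(type):
        pass

    try:
        # Same shape as the traceback: derive from a local metaclass plus the
        # metaclass of the user-supplied array type.
        class MetaAbstractArray(_MetaAbstractArray, base_metaclass):
            pass
    except TypeError as err:
        return err
    return MetaAbstractArray


print(is_subclassable(type))  # True
print(is_subclassable(bool))  # False
print(try_make_metaclass(bool))  # type 'bool' is not an acceptable base type
```

With MLX installed, `is_subclassable(type(mx.array))` would be expected to return `False`, consistent with the traceback; that has not been verified here.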

Cheers!

@patrick-kidger
Owner

Looks like this might be an MLX issue, but happy to take a PR here if it would help.

@gabrieldemarmiesse
Author

Do you have a hint on how to get started with this? I notice that jaxtyping tries to subclass whatever type is passed in the type hint (https://github.com/patrick-kidger/jaxtyping/blob/main/jaxtyping/_array_types.py#L620), but `mx.array` can't be subclassed, likely because it comes from nanobind. Per the traceback, the failure happens when a metaclass is derived from `type(mx.array)`, i.e. `nanobind.nb_type_0`. Is there any way to skip this step without using `Any`? Using `Any` would also accept, for example, NumPy arrays, which is not what we want.

@patrick-kidger
Owner

FWIW I don't completely remember why I went to the effort of subclassing the provided type, instead of treating this as a completely independent type. Does removing that break any tests?

gabrieldemarmiesse linked a pull request on Feb 21, 2025 that will close this issue
@gabrieldemarmiesse
Author

@patrick-kidger #301 should fix this
