Skip to content

feat: Add GPU support to experimental JAX inference framework #272

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

tohaowu
Copy link

@tohaowu tohaowu commented May 21, 2025

This commit introduces GPU support for the JAX-based experimental inference framework located in experimental/jax.

Key changes include:

  • Modified experimental/jax/requirements.txt to use jax[cuda-pip] allowing JAX to utilize NVIDIA GPUs.
  • Refined experimental/jax/inference/parallel/mesh.py to correctly handle GPU devices during mesh creation, ensuring robust platform detection alongside existing TPU support.
  • Verified that experimental/jax/inference/runtime/offline_inference.py correctly uses jax.devices() and is compatible with the new GPU handling in the mesh creation logic.
  • Updated experimental/jax/README.md to include instructions for setting up JAX with GPU support and to reflect that NVIDIA GPUs are now a supported backend.
  • Added a new test script experimental/jax/inference/entrypoint/run_gpu_test.py and instructions for you to verify GPU functionality with a small number of prompts.

These changes allow you, if you have compatible NVIDIA GPUs and CUDA setups, to run the experimental JAX inference framework, expanding its usability beyond TPUs.

This commit introduces GPU support for the JAX-based experimental inference framework located in `experimental/jax`.

Key changes include:

- Modified `experimental/jax/requirements.txt` to use `jax[cuda-pip]` allowing JAX to utilize NVIDIA GPUs.
- Refined `experimental/jax/inference/parallel/mesh.py` to correctly handle GPU devices during mesh creation, ensuring robust platform detection alongside existing TPU support.
- Verified that `experimental/jax/inference/runtime/offline_inference.py` correctly uses `jax.devices()` and is compatible with the new GPU handling in the mesh creation logic.
- Updated `experimental/jax/README.md` to include instructions for setting up JAX with GPU support and to reflect that NVIDIA GPUs are now a supported backend.
- Added a new test script `experimental/jax/inference/entrypoint/run_gpu_test.py` and instructions for you to verify GPU functionality with a small number of prompts.

These changes allow you, if you have compatible NVIDIA GPUs and CUDA setups, to run the experimental JAX inference framework, expanding its usability beyond TPUs.
@tohaowu tohaowu requested a review from vipannalla as a code owner May 21, 2025 21:12
@tohaowu tohaowu closed this May 21, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant