Remove paxml references from README (#1271)
sbhavani authored Feb 6, 2025
1 parent 8dd4054 commit 578140c
Showing 1 changed file with 2 additions and 51 deletions.
README.md: 53 changes (2 additions & 51 deletions)
@@ -3,15 +3,14 @@
[![License Apache 2.0](https://badgen.net/badge/license/apache2.0/blue)](https://github.com/NVIDIA/JAX-Toolbox/blob/main/LICENSE.md)
[![Build](https://badgen.net/badge/build/check-status/blue)](#build-pipeline-status)

-JAX Toolbox provides a public CI, Docker images for popular JAX libraries, and optimized JAX examples to simplify and enhance your JAX development experience on NVIDIA GPUs. It supports JAX libraries such as [MaxText](https://github.com/google/maxtext), [Paxml](https://github.com/google/paxml), and [Pallas](https://jax.readthedocs.io/en/latest/pallas/quickstart.html).
+JAX Toolbox provides a public CI, Docker images for popular JAX libraries, and optimized JAX examples to simplify and enhance your JAX development experience on NVIDIA GPUs. It supports JAX libraries such as [MaxText](https://github.com/google/maxtext) and [Pallas](https://jax.readthedocs.io/en/latest/pallas/quickstart.html).

## Frameworks and Supported Models
We support and test the following JAX frameworks and model architectures. More details about each model and available containers can be found in their respective READMEs.

| Framework | Models | Use cases | Container |
| :--- | :---: | :---: | :---: |
| [maxtext](./rosetta/rosetta/projects/maxtext)| GPT, LLaMA, Gemma, Mistral, Mixtral | pretraining | `ghcr.io/nvidia/jax:maxtext` |
-| [paxml](./rosetta/rosetta/projects/pax) | GPT, LLaMA, MoE | pretraining, fine-tuning, LoRA | `ghcr.io/nvidia/jax:pax` |
| [t5x](./rosetta/rosetta/projects/t5x) | T5, ViT | pre-training, fine-tuning | `ghcr.io/nvidia/jax:t5x` |
| [t5x](./rosetta/rosetta/projects/imagen) | Imagen | pre-training | `ghcr.io/nvidia/t5x:imagen-2023-10-02.v3` |
| [big vision](./rosetta/rosetta/projects/paligemma) | PaliGemma | fine-tuning, evaluation | `ghcr.io/nvidia/jax:gemma` |
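
As a usage sketch, assuming Docker with the NVIDIA Container Toolkit is installed on the host: the images in the table are published to ghcr.io, so a typical first step is to pull one and open an interactive GPU-enabled shell. The MaxText tag below comes from the table; the flags are standard Docker options.

```bash
# Pull the MaxText container from the table above
docker pull ghcr.io/nvidia/jax:maxtext

# Open an interactive shell with all GPUs visible inside the container;
# --gpus requires the NVIDIA Container Toolkit on the host
docker run --gpus all -it --rm ghcr.io/nvidia/jax:maxtext
```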
@@ -204,54 +203,6 @@ We support and test the following JAX frameworks and model architectures. More d
</a>
</td>
</tr>
-<tr>
-<td>
-<a href="https://github.com/NVIDIA/JAX-Toolbox/blob/main/.github/container/Dockerfile.pax">
-<img style="height:1em;" src="https://img.shields.io/static/v1?label=&color=gray&logo=docker&message=Upstream%20PAX%3D%7Bcore%2Cpaxml%2Cpraxis%7D">
-</a>
-</td>
-<td>
-<code>ghcr.io/nvidia/jax:upstream-pax</code>
-</td>
-<td>
-<a href="https://gist.github.com/nvjax/913c2af68649fe568e9711c2dabb23ae/#file-final-upstream-pax-md">
-<img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-pax-build-amd64.json&logo=docker&label=amd64">
-</a>
-<br>
-<a href="https://gist.github.com/nvjax/913c2af68649fe568e9711c2dabb23ae/#file-final-upstream-pax-md">
-<img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-pax-build-arm64.json&logo=docker&label=arm64">
-</a>
-</td>
-<td>
-<a href="https://gist.github.com/nvjax/913c2af68649fe568e9711c2dabb23ae#file-badge-upstream-pax-mgmn-test-json">
-<img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-upstream-pax-mgmn-test.json&logo=nvidia&label=A100%20distributed">
-</a>
-</td>
-</tr>
-<tr>
-<td>
-<a href="https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta/Dockerfile.pax">
-<img style="height:1em;" src="https://img.shields.io/static/v1?label=&color=gray&logo=docker&message=Rosetta%20PAX%3D%7Bcore%2Cpaxml%2Cpraxis%7D">
-</a>
-</td>
-<td>
-<code>ghcr.io/nvidia/jax:pax</code>
-</td>
-<td>
-<a href="https://gist.github.com/nvjax/913c2af68649fe568e9711c2dabb23ae/#file-final-pax-md">
-<img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-rosetta-build-pax-amd64.json&logo=docker&label=amd64">
-</a>
-<br>
-<a href="https://gist.github.com/nvjax/913c2af68649fe568e9711c2dabb23ae/#file-final-pax-md">
-<img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-rosetta-build-pax-arm64.json&logo=docker&label=arm64">
-</a>
-</td>
-<td>
-<a href="https://gist.github.com/nvjax/913c2af68649fe568e9711c2dabb23ae#file-badge-rosetta-pax-mgmn-test-json">
-<img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-rosetta-pax-mgmn-test.json&logo=nvidia&label=A100%20distributed">
-</a>
-</td>
-</tr>
<tr>
<td>
<a href="https://github.com/NVIDIA/JAX-Toolbox/blob/main/.github/container/Dockerfile.maxtext">
@@ -317,7 +268,7 @@ The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is emb
| -------------------- | ----- | ----------- |
| `NCCL_NVLS_ENABLE` | `0` | Disables NVLink SHARP ([1](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-nvls-enable)). Future releases will re-enable this feature. |
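
Since this default is baked into the image, it can be overridden per run at container launch; a minimal sketch, assuming Docker's standard `-e` flag and the MaxText image from the table above:

```bash
# Override the image default for a single run: NCCL_NVLS_ENABLE=1
# re-enables NVLink SHARP, which the image disables by default
docker run --gpus all -it --rm \
  -e NCCL_NVLS_ENABLE=1 \
  ghcr.io/nvidia/jax:maxtext
```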

-There are various other XLA flags users can set to improve performance. For a detailed explanation of these flags, please refer to the [GPU performance](./rosetta/docs/GPU_performance.md) doc. XLA flags can be tuned per workflow. For example, each script in [contrib/gpu/scripts_gpu](https://github.com/google/paxml/tree/main/paxml/contrib/gpu/scripts_gpu) sets its own [XLA flags](https://github.com/google/paxml/blob/93fbc8010dca95af59ab615c366d912136b7429c/paxml/contrib/gpu/scripts_gpu/benchmark_gpt_multinode.sh#L30-L33).
+There are various other XLA flags users can set to improve performance. For a detailed explanation of these flags, please refer to the [GPU performance](./rosetta/docs/GPU_performance.md) doc. XLA flags can also be tuned per workload. For example, each script includes a directory [xla_flags](./rosetta/rosetta/projects/maxtext/xla_flags).

For a list of previously used XLA flags that are no longer needed, please also refer to the [GPU performance](./rosetta/docs/GPU_performance.md#previously-used-xla-flags) page.
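
A minimal sketch of per-workload tuning, assuming a hypothetical `train.py` entry point; the flag shown is illustrative only, and the values that actually apply to a given model live in the linked `xla_flags` directories and the GPU performance doc:

```bash
# Illustrative flag only; consult GPU_performance.md and the per-project
# xla_flags directories for values tuned to a specific workload
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true"
python train.py
```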
