
Commit f769b7a

MAINT: first pass cleanup the conversion artefacts
1 parent e81b19c commit f769b7a

File tree

13 files changed: +175 −171 lines changed


docs/source/_static/images/JAX_ecosystem.svg

Lines changed: 1 addition & 0 deletions

docs/source/_static/images/Tunix_diagram.svg

Lines changed: 1 addition & 0 deletions

docs/source/_static/images/async_checkpointing.svg

Lines changed: 1 addition & 0 deletions

docs/source/_static/images/programming_TPUS.svg

Lines changed: 1 addition & 0 deletions

docs/source/_static/images/serving_orbax_nse.svg

Lines changed: 1 addition & 0 deletions

docs/source/ecosystem_overview/architectural.md

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@ If this trend accelerates, all high-level frameworks as they exist today risk be
 
 For TPUs to provide a clear path to this level of performance, the ecosystem must expose an API layer that is closer to the hardware, enabling the development of these highly specialized kernels. As this report will detail, the JAX stack is designed to solve this by offering a continuum of abstraction (See Figure 2), from the automated, high-level optimizations of the XLA compiler to the fine-grained, manual control of the Pallas kernel-authoring library.
 
-![][image3]
+![](../_static/images/programming_TPUS.svg)
 
 **Figure 2: The JAX continuum of abstraction**
 

docs/source/ecosystem_overview/comparative.md

Lines changed: 7 additions & 5 deletions
@@ -4,13 +4,13 @@ The modern Machine Learning landscape offers many excellent, mature toolchains.
 
 While many frameworks offer a wide array of features, the JAX AI Stack provides specific, powerful differentiators in key areas of the development lifecycle:
 
-* **A Simpler, More Powerful Developer Experience:** The "chainable gradient transformation paradigm" of [**Optax**](#optax:-composable-gradient-processing-and-optimization-strategies) allows for more powerful and flexible optimization strategies that are declared once, rather than imperatively managed in the training loop.1 At the system level, the "simpler single controller interface" of **Pathways** abstracts away the complexity of multi-pod, multi-slice training, a significant simplification for researchers.
-* **Engineered for "Hero-Scale" Resilience:** The JAX stack is designed for extreme-scale training. **Orbax** provides "hero-scale training resilience" features like emergency and multi-tier checkpointing. This is complemented by **Grain**, which offers "full support for reproducibility with deterministic global shuffles and checkpointable data loaders". The ability to atomically checkpoint the data pipeline state (Grain) with the model state (Orbax) is a critical capability for guaranteeing reproducibility in long-running jobs.
-* **A Complete, End-to-End Ecosystem:** The stack provides a cohesive, end-to-end solution. Developers can use [**MaxText**](https://maxtext.readthedocs.io/en/latest/) as a SOTA reference for training, [**Tunix**](https://tunix.readthedocs.io/en/latest/) for alignment, and follow a clear, dual-path to production with **vLLM** (for vLLM compatibility) and **NSE** (for native JAX performance).
+* **A Simpler, More Powerful Developer Experience:** The "chainable gradient transformation paradigm" of [**Optax**](#optax:composable) allows for more powerful and flexible optimization strategies that are declared once, rather than imperatively managed in the training loop.1 At the system level, the "simpler single controller interface" of **Pathways** abstracts away the complexity of multi-pod, multi-slice training, a significant simplification for researchers.
+* **Engineered for "Hero-Scale" Resilience:** The JAX stack is designed for extreme-scale training. **Orbax** provides "hero-scale training resilience" features like emergency and multi-tier checkpointing. This is complemented by **Grain**, which offers "full support for reproducibility with deterministic global shuffles and checkpointable data loaders". The ability to atomically checkpoint the data pipeline state (Grain) with the model state (Orbax) is a critical capability for guaranteeing reproducibility in long-running jobs.
+* **A Complete, End-to-End Ecosystem:** The stack provides a cohesive, end-to-end solution. Developers can use [**MaxText**](https://maxtext.readthedocs.io/en/latest/) as a SOTA reference for training, [**Tunix**](https://tunix.readthedocs.io/en/latest/) for alignment, and follow a clear, dual-path to production with **vLLM-TPU** (for vLLM compatibility) and **NSE** (for native JAX performance).
 
 While many stacks are vastly similar from a high-level software standpoint, the deciding factor often comes down to **Performance/TCO**, which is where the co-design of JAX and TPUs provides a distinct advantage. This Performance/TCO benefit is a direct result of the "vertical integration across software and TPU hardware". The ability of the **XLA** compiler to fuse operations specifically for the TPU architecture, or for the **XProf** profiler to leverage hardware hooks for \<1% overhead profiling, are tangible benefits of this deep integration.
 
-For organizations adopting this stack, the "full featured nature" of the JAX AI Stack minimizes the cost of migration. For customers employing popular open model architectures, a shift from other frameworks to [MaxText](#foundation-model-training:-maxtext-and-maxdiffusion) is often a matter of setting up config files. Furthermore, the stack's ability to ingest popular checkpoint formats like safetensors allows existing checkpoints to be migrated over without needing costly re-training.
+For organizations adopting this stack, the "full featured nature" of the JAX AI Stack minimizes the cost of migration. For customers employing popular open model architectures, a shift from other frameworks to [MaxText](#foundational-model-maxtext-and) is often a matter of setting up config files. Furthermore, the stack's ability to ingest popular checkpoint formats like safetensors allows existing checkpoints to be migrated over without needing costly re-training.
 
 The table below provides a mapping of the components provided by the JAX AI stack and their equivalents in other frameworks or libraries.
 
@@ -28,6 +28,8 @@ The table below provides a mapping of the components provided by the JAX AI stac
 | Post training / tuning | Tunix | VERL, NeMoRL |
 | Profiling | XProf | PyTorch profiler, NSight systems, NSight Compute |
 | Foundation model Training | MaxText, MaxDiffusion | NeMo-Megatron, DeepSpeed, TorchTitan |
-| LLM inference | vLLM | vLLM, SGLang |
+| LLM inference | vLLM-TPU | vLLM, SGLang |
 | Non-LLM Inference | NSE | Triton Inference Server, RayServe |
 
+
+[^8]: Some of the equivalents here are not true 1:1 comparisons because other frameworks draw API boundaries differently compared to JAX. The list of equivalents is not exhaustive and there are new libraries appearing frequently.

docs/source/ecosystem_overview/conclusion.md

Lines changed: 1 addition & 17 deletions
@@ -6,21 +6,5 @@ The JAX AI stack offers a compelling and robust solution for training and deploy
 
 By building on battle-tested internal systems, the stack has evolved to provide inherent reliability and scalability, enabling users to confidently develop and deploy even the largest models. Its modular and composable design, rooted in the JAX ecosystem philosophy, grants users unparalleled freedom and control, allowing them to tailor the stack to their specific needs without the constraints of a monolithic framework.
 
-With XLA and Pathways providing a scalable and fault-tolerant base, JAX providing a performant and expressive numerics library, powerful core development libraries like [Flax](https://flax.readthedocs.io/en/stable/), Optax, [Grain](https://google-grain.readthedocs.io/en/latest/), and [Orbax](#orbax-/-tensorstore---large-scale-distributed-checkpointing), advanced performance tools like Pallas, Tokamax, and Qwix, and a robust application and production layer in [MaxText](#foundation-model-training:-maxtext-and-maxdiffusion), vLLM, and NSE, the JAX AI stack provides a durable foundation for users to build on and rapidly bring state-of-the-art research to production.
-
-[^1]: Included in the [jax-ai-stack Python package](https://docs.jaxstack.ai/en/latest/install.html)
-
-[^2]: Included in the [jax-ai-stack Python package](https://docs.jaxstack.ai/en/latest/install.html)
-
-[^3]: Image diffusion models are a typical example of this and can commonly be divided logically into a separately trained prompt encoder and a diffusion backbone.
-
-[^4]: We say effectively free since there could be other bottlenecks such as the DMA engines, HBM bandwidth contention etc. that still incur a performance penalty.
-
-[^5]: In Section 5.1 of the [Palm paper](https://dl.acm.org/doi/10.5555/3648699.3648939), the authors note that they observed very large loss spikes despite having gradient clipping enabled and the solution was to remove the offending data batches and restart training from a checkpoint before the loss spike. This is only possible with a fully deterministic and reproducible training setup.
-
-[^6]: This is indeed how multimodal data pipelines would need to operate \- image and audio tokenizers for example are models themselves which run in their own clusters on their own accelerators and the input pipelines would make RPCs out to convert data examples into streams of tokens.
-
-[^7]: This is a well established paradigm and has precedent in the CPU world, where compiled code forms the bulk of the program with developers dropping down to intrinsics or inline assembly to optimize performance critical sections.
-
-[^8]: Some of the equivalents here are not true 1:1 comparisons because other frameworks draw API boundaries differently compared to JAX. The list of equivalents is not exhaustive and there are new libraries appearing frequently.
+With XLA and Pathways providing a scalable and fault-tolerant base, JAX providing a performant and expressive numerics library, powerful core development libraries like [Flax](https://flax.readthedocs.io/en/stable/), Optax, [Grain](https://google-grain.readthedocs.io/en/latest/), and [Orbax](#orbax:tensorstore), advanced performance tools like Pallas, Tokamax, and Qwix, and a robust application and production layer in [MaxText](#foundational-model-maxtext-and), vLLM-TPU, and NSE, the JAX AI stack provides a durable foundation for users to build on and rapidly bring state-of-the-art research to production.
 
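The Performance/TCO argument in these docs rests on XLA compiling and fusing operations for the target hardware. A tiny illustrative sketch of the mechanism, using the real `jax.jit` entry point (the function itself is hypothetical; the same code drives XLA on CPU, GPU, or TPU):

```python
import jax
import jax.numpy as jnp

# Under jax.jit, XLA traces the whole function and can fuse the
# multiply, tanh, and add into a small number of device kernels,
# instead of dispatching each elementwise op separately.
@jax.jit
def activation(x):
    return jnp.tanh(x * 2.0) + 1.0

out = activation(jnp.arange(4.0))  # compiled on first call, cached after
```

The fusion itself is automatic; the developer only opts a function into compilation.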
