docs/source/ecosystem_overview/architectural.md (1 addition & 1 deletion)
@@ -6,7 +6,7 @@ If this trend accelerates, all high-level frameworks as they exist today risk be
For TPUs to provide a clear path to this level of performance, the ecosystem must expose an API layer that is closer to the hardware, enabling the development of these highly specialized kernels. As this report will detail, the JAX stack is designed to solve this by offering a continuum of abstraction (See Figure 2), from the automated, high-level optimizations of the XLA compiler to the fine-grained, manual control of the Pallas kernel-authoring library.
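To illustrate the manual-control end of that continuum, here is a minimal Pallas sketch; the elementwise kernel and shapes are illustrative assumptions, not an example from the report:

```python
# A minimal Pallas kernel: the author writes the per-block program
# directly, rather than relying on XLA's automatic fusion.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def double_kernel(x_ref, o_ref):
    # Refs point at on-chip buffers (VMEM on TPU); the body is the
    # hand-written program for one block.
    o_ref[...] = x_ref[...] * 2.0

@jax.jit
def double(x):
    return pl.pallas_call(
        double_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        # Pass interpret=True to run the kernel in pure Python for debugging.
    )(x)

print(double(jnp.arange(8.0)))  # [0. 2. 4. ... 14.]
```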
docs/source/ecosystem_overview/comparative.md (7 additions & 5 deletions)
@@ -4,13 +4,13 @@ The modern Machine Learning landscape offers many excellent, mature toolchains.
While many frameworks offer a wide array of features, the JAX AI Stack provides specific, powerful differentiators in key areas of the development lifecycle:
-* **A Simpler, More Powerful Developer Experience:** The "chainable gradient transformation paradigm" of [**Optax**](#optax:-composable-gradient-processing-and-optimization-strategies) allows for more powerful and flexible optimization strategies that are declared once, rather than imperatively managed in the training loop. At the system level, the "simpler single controller interface" of **Pathways** abstracts away the complexity of multi-pod, multi-slice training, a significant simplification for researchers.
-* **Engineered for "Hero-Scale" Resilience:** The JAX stack is designed for extreme-scale training. **Orbax** provides "hero-scale training resilience" features like emergency and multi-tier checkpointing. This is complemented by **Grain**, which offers "full support for reproducibility with deterministic global shuffles and checkpointable data loaders". The ability to atomically checkpoint the data pipeline state (Grain) with the model state (Orbax) is a critical capability for guaranteeing reproducibility in long-running jobs.
-* **A Complete, End-to-End Ecosystem:** The stack provides a cohesive, end-to-end solution. Developers can use [**MaxText**](https://maxtext.readthedocs.io/en/latest/) as a SOTA reference for training, [**Tunix**](https://tunix.readthedocs.io/en/latest/) for alignment, and follow a clear dual path to production with **vLLM** (for vLLM compatibility) and **NSE** (for native JAX performance).
+* **A Simpler, More Powerful Developer Experience:** The "chainable gradient transformation paradigm" of [**Optax**](#optax:composable) allows for more powerful and flexible optimization strategies that are declared once, rather than imperatively managed in the training loop (see the sketches after this list). At the system level, the "simpler single controller interface" of **Pathways** abstracts away the complexity of multi-pod, multi-slice training, a significant simplification for researchers.
+* **Engineered for "Hero-Scale" Resilience:** The JAX stack is designed for extreme-scale training. **Orbax** provides "hero-scale training resilience" features like emergency and multi-tier checkpointing. This is complemented by **Grain**, which offers "full support for reproducibility with deterministic global shuffles and checkpointable data loaders". The ability to atomically checkpoint the data pipeline state (Grain) with the model state (Orbax) is a critical capability for guaranteeing reproducibility in long-running jobs.
+* **A Complete, End-to-End Ecosystem:** The stack provides a cohesive, end-to-end solution. Developers can use [**MaxText**](https://maxtext.readthedocs.io/en/latest/) as a SOTA reference for training, [**Tunix**](https://tunix.readthedocs.io/en/latest/) for alignment, and follow a clear dual path to production with **vLLM-TPU** (for vLLM compatibility) and **NSE** (for native JAX performance).
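To make the "declared once" paradigm and the checkpointing story concrete, here are two minimal sketches; the learning rate, clipping threshold, loss, state contents, and paths are illustrative assumptions, not values from the report:

```python
# Optax: the optimization strategy is declared once as a chain of
# gradient transformations, then reused unchanged in the training loop.
import jax
import jax.numpy as jnp
import optax

tx = optax.chain(
    optax.clip_by_global_norm(1.0),   # gradient clipping, declared up front
    optax.adamw(learning_rate=1e-3),  # AdamW update rule
)

params = {"w": jnp.ones((3,))}
opt_state = tx.init(params)

def loss_fn(p):
    return jnp.sum(p["w"] ** 2)

# The loop itself stays generic: no imperative clip/step bookkeeping.
grads = jax.grad(loss_fn)(params)
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```

And a bare-bones Orbax save/restore; real long-running jobs would typically drive this with `ocp.CheckpointManager` and store the Grain iterator state alongside the model state:

```python
# Orbax: checkpoint an arbitrary PyTree of model/optimizer state.
import jax.numpy as jnp
import orbax.checkpoint as ocp

state = {"params": {"w": jnp.ones((3,))}, "step": 100}

ckptr = ocp.PyTreeCheckpointer()
ckptr.save("/tmp/ckpt/step_100", state)   # target directory must not already exist
restored = ckptr.restore("/tmp/ckpt/step_100")
```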
While many stacks are broadly similar from a high-level software standpoint, the deciding factor often comes down to **Performance/TCO**, which is where the co-design of JAX and TPUs provides a distinct advantage. This Performance/TCO benefit is a direct result of the "vertical integration across software and TPU hardware". The **XLA** compiler's ability to fuse operations specifically for the TPU architecture, and the **XProf** profiler's use of hardware hooks for \<1% overhead profiling, are tangible benefits of this deep integration.
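As a small, hedged illustration of that fusion at work (the function and shapes are arbitrary), you can inspect the optimized HLO that XLA produces for a jitted function and look for the fused kernel:

```python
# Inspect XLA's optimized HLO for a jitted function; on TPU the
# multiply, tanh, and add below are typically fused into one kernel.
import jax
import jax.numpy as jnp

def f(x):
    return jnp.tanh(x * 2.0) + 1.0

compiled = jax.jit(f).lower(jnp.ones((1024,))).compile()
print(compiled.as_text())  # look for a `fusion` op in the output
```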
-For organizations adopting this stack, the "full featured nature" of the JAX AI Stack minimizes the cost of migration. For customers employing popular open model architectures, a shift from other frameworks to [MaxText](#foundation-model-training:-maxtext-and-maxdiffusion) is often a matter of setting up config files. Furthermore, the stack's ability to ingest popular checkpoint formats like safetensors allows existing checkpoints to be migrated over without needing costly re-training.
+For organizations adopting this stack, the "full featured nature" of the JAX AI Stack minimizes the cost of migration. For customers employing popular open model architectures, a shift from other frameworks to [MaxText](#foundational-model-maxtext-and) is often a matter of setting up config files. Furthermore, the stack's ability to ingest popular checkpoint formats like safetensors allows existing checkpoints to be migrated over without needing costly re-training.
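As a hedged sketch of that checkpoint-ingestion point (the file name is hypothetical, and mapping tensor names onto a model's parameter PyTree is model-specific):

```python
# Load a safetensors checkpoint directly as a dict of JAX arrays.
from safetensors.flax import load_file

params = load_file("llama_weights.safetensors")  # dict[str, jax.Array]
for name, tensor in list(params.items())[:5]:
    print(name, tensor.shape, tensor.dtype)
```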
The table below provides a mapping of the components provided by the JAX AI stack and their equivalents in other frameworks or libraries.
@@ -28,6 +28,8 @@ The table below provides a mapping of the components provided by the JAX AI stac
[^8]: Some of the equivalents here are not true 1:1 comparisons because other frameworks draw API boundaries differently compared to JAX. The list of equivalents is not exhaustive and there are new libraries appearing frequently.
docs/source/ecosystem_overview/conclusion.md (1 addition & 17 deletions)
@@ -6,21 +6,5 @@ The JAX AI stack offers a compelling and robust solution for training and deploy
By building on battle-tested internal systems, the stack has evolved to provide inherent reliability and scalability, enabling users to confidently develop and deploy even the largest models. Its modular and composable design, rooted in the JAX ecosystem philosophy, grants users unparalleled freedom and control, allowing them to tailor the stack to their specific needs without the constraints of a monolithic framework.
-With XLA and Pathways providing a scalable and fault-tolerant base, JAX providing a performant and expressive numerics library, powerful core development libraries like [Flax](https://flax.readthedocs.io/en/stable/), Optax, [Grain](https://google-grain.readthedocs.io/en/latest/), and [Orbax](#orbax-/-tensorstore---large-scale-distributed-checkpointing), advanced performance tools like Pallas, Tokamax, and Qwix, and a robust application and production layer in [MaxText](#foundation-model-training:-maxtext-and-maxdiffusion), vLLM, and NSE, the JAX AI stack provides a durable foundation for users to build on and rapidly bring state-of-the-art research to production.
-
-[^1]: Included in the [jax-ai-stack Python package](https://docs.jaxstack.ai/en/latest/install.html)
-
-[^2]: Included in the [jax-ai-stack Python package](https://docs.jaxstack.ai/en/latest/install.html)
-
-[^3]: Image diffusion models are a typical example of this and can commonly be divided logically into a separately trained prompt encoder and a diffusion backbone.
-
-[^4]: We say effectively free since there could be other bottlenecks, such as the DMA engines or HBM bandwidth contention, that still incur a performance penalty.
-
-[^5]: In Section 5.1 of the [PaLM paper](https://dl.acm.org/doi/10.5555/3648699.3648939), the authors note that they observed very large loss spikes despite having gradient clipping enabled; the solution was to remove the offending data batches and restart training from a checkpoint before the loss spike. This is only possible with a fully deterministic and reproducible training setup.
-
-[^6]: This is indeed how multimodal data pipelines would need to operate: image and audio tokenizers, for example, are models themselves which run in their own clusters on their own accelerators, and the input pipelines would make RPCs out to convert data examples into streams of tokens.
-
-[^7]: This is a well-established paradigm with precedent in the CPU world, where compiled code forms the bulk of the program and developers drop down to intrinsics or inline assembly to optimize performance-critical sections.
-
-[^8]: Some of the equivalents here are not true 1:1 comparisons because other frameworks draw API boundaries differently compared to JAX. The list of equivalents is not exhaustive and there are new libraries appearing frequently.
+With XLA and Pathways providing a scalable and fault-tolerant base, JAX providing a performant and expressive numerics library, powerful core development libraries like [Flax](https://flax.readthedocs.io/en/stable/), Optax, [Grain](https://google-grain.readthedocs.io/en/latest/), and [Orbax](#orbax:tensorstore), advanced performance tools like Pallas, Tokamax, and Qwix, and a robust application and production layer in [MaxText](#foundational-model-maxtext-and), vLLM-TPU, and NSE, the JAX AI stack provides a durable foundation for users to build on and rapidly bring state-of-the-art research to production.