# Conclusion: A Durable, Production-Ready Platform for the Future of AI
The data in the table above leads to a simple conclusion: these stacks each have strengths and weaknesses in a small number of areas, but overall they are broadly similar from a software standpoint. Both provide turnkey, out-of-the-box solutions for pre-training, post-training adaptation, and deployment of foundation models.
The JAX AI stack offers a compelling and robust solution for training and deploying foundation models.
By building on battle-tested internal systems, the stack has evolved to provide inherent reliability and scalability, enabling users to confidently develop and deploy even the largest models. Its modular and composable design, rooted in the JAX ecosystem philosophy, grants users unparalleled freedom and control, allowing them to tailor the stack to their specific needs without the constraints of a monolithic framework.
XLA and Pathways provide a scalable, fault-tolerant base; JAX provides a performant and expressive numerics library; powerful core development libraries like [Flax](https://flax.readthedocs.io/en/stable/), Optax, [Grain](https://google-grain.readthedocs.io/en/latest/), and [Orbax](#orbax:tensorstore) cover model development; advanced performance tools like Pallas, Tokamax, and Qwix push hardware to its limits; and [MaxText](#foundational-model-maxtext-and), vLLM-TPU, and NSE form a robust application and production layer. Together, these make the JAX AI stack a durable foundation for users to build on and rapidly bring state-of-the-art research to production.
As can be seen in the example above, setting up an optimizer with a custom learning rate schedule is straightforward.
[Orbax](https://orbax.readthedocs.io/en/latest/) is a checkpointing library for JAX designed for any scale, from single-device to large-scale distributed training. It aims to unify fragmented checkpointing implementations and deliver critical performance features, such as asynchronous and multi-tier checkpointing, to a wider audience. Orbax enables the resilience required for massive training jobs and provides a flexible format for publishing checkpoints.
Unlike generalized checkpoint/restore systems that snapshot the entire system state, ML checkpointing with Orbax selectively persists only the information essential for resuming training: model weights, optimizer state, and data loader state. This targeted approach minimizes accelerator downtime. Orbax achieves this by overlapping I/O operations with computation, a critical feature for large workloads. The time accelerators are halted is thus reduced to the duration of the device-to-host data transfer, which can be further overlapped with the next training step, making checkpointing nearly free from a performance perspective.
At its core, Orbax uses [TensorStore](https://google.github.io/tensorstore/) for efficient, parallel reading and writing of array data. The [Orbax API](https://orbax.readthedocs.io/en/latest/index.html) abstracts this complexity, offering a user-friendly interface for handling [PyTrees](https://docs.jax.dev/en/latest/pytrees.html), which are the standard representation of models in JAX.
### Key Strengths
* **[Widespread Adoption](https://orbax.readthedocs.io/en/latest/guides/checkpoint/async_checkpointing.html):** With millions of monthly downloads, Orbax serves as a common medium for sharing ML artifacts.
* **Easy to Use:** Orbax abstracts away the complexities of distributed checkpointing, including asynchronous saving, atomicity, and filesystem details.
* **Flexible:** While offering simple APIs for common use cases, Orbax allows for customization to handle specialized requirements.
* **Performant and Scalable:** Features like asynchronous checkpointing, an efficient storage format ([OCDBT](https://orbax.readthedocs.io/en/latest/guides/checkpoint/optimized_checkpointing.html)), and intelligent data loading strategies ensure that Orbax scales to training runs involving tens of thousands of nodes.
Out of the box, [Grain](https://google-grain.readthedocs.io/en/latest/) supports a variety of data formats and storage backends.
### Key Strengths
* **Deterministic data feeding:** Colocating the data worker with the accelerator and coupling it with a stable global shuffle and [checkpointable iterators](https://google-grain.readthedocs.io/en/latest/tutorials/data_loader_tutorial.html#checkpointing) allows the model state and data pipeline state to be checkpointed together in a consistent snapshot using [Orbax](https://orbax.readthedocs.io/en/latest/), enhancing the determinism of the training process.
* **Flexible APIs to enable powerful data transformations:** A flexible pure Python [transformations](https://google-grain.readthedocs.io/en/latest/data_loader/transformations.html) API allows for extensive data transformations within the input processing pipeline.
* **Extensible support for multiple formats and backends:** An extensible [data sources](https://google-grain.readthedocs.io/en/latest/tutorials/data_sources/index.html) API supports popular storage formats and backends and allows users to easily add support for new formats and backends.
* **Powerful debugging interface:** Data pipeline [visualization tools](https://google-grain.readthedocs.io/en/latest/tutorials/dataset_debugging_tutorial.html) and a debug mode allow users to introspect, debug and optimize the performance of their data pipelines.
[^5]: In Section 5.1 of the [PaLM paper](https://dl.acm.org/doi/10.5555/3648699.3648939), the authors note that they observed very large loss spikes despite having gradient clipping enabled; their solution was to remove the offending data batches and restart training from a checkpoint taken before the spike. This is only possible with a fully deterministic and reproducible training setup.
[^6]: This is indeed how multimodal data pipelines would need to operate: image and audio tokenizers, for example, are themselves models that run in their own clusters on their own accelerators, and the input pipelines make outbound RPCs to convert data examples into streams of tokens.
Our compiler-first design is a fundamental architectural choice that creates a durable advantage.
#### Design
XLA works by Just-In-Time (JIT) compiling the computation graphs that JAX generates during its tracing process (e.g., when a function is decorated with `@jax.jit`).
* **From JAX Graph to HLO:** The captured JAX computation graph is converted into XLA's HLO representation. At this high level, powerful, hardware-agnostic optimizations like operator fusion and efficient memory management are applied. The **StableHLO** dialect serves as a durable, versioned interface for this stage.
* **From HLO to LLO:** After high-level optimizations, hardware-specific backends take over, lowering the HLO representation into a machine-oriented LLO.
* **From LLO to Hardware Code:** The LLO is finally compiled into highly efficient machine code. For TPUs, this code is bundled as **Very Long Instruction Word (VLIW)** packets that are sent directly to the hardware.
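The first stage of this pipeline can be observed directly from JAX: `jax.jit`'s ahead-of-time API exposes the StableHLO module for a traced function before XLA compiles it. The function and shapes below are illustrative:

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    return jnp.tanh(x @ w)

# Trace the function and lower it to StableHLO (no hardware code is generated yet).
lowered = jax.jit(predict).lower(jnp.ones((4, 2)), jnp.ones((3, 4)))
stablehlo_text = lowered.as_text()   # human-readable StableHLO module

# Hand the module to XLA's backend-specific compiler and run the executable.
compiled = lowered.compile()
out = compiled(jnp.ones((4, 2)), jnp.ones((3, 4)))
```

Printing `stablehlo_text` shows the fused, hardware-agnostic representation that the backend stages then lower toward machine code.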
To be able to train and deploy large models, hundreds to thousands of chips are required.
ML Pathways is the system we use for coordinating distributed computations across hosts and TPU chips. It is designed for scalability and efficiency across hundreds of thousands of accelerators. For large-scale training, it provides a single Python client for multi-slice/multi-pod jobs, [Megascale XLA](https://openxla.org/xprof/megascale_stats) integration, Compilation Service, and Remote Python. It also supports cross-slice parallelism and preemption tolerance, enabling automatic recovery from resource preemptions.
Pathways incorporates optimized cross-host collectives that enable XLA computation graphs to extend beyond a single TPU pod. It expands XLA's support for data, model, and pipeline parallelism to work across TPU slice boundaries over DCN by integrating a distributed runtime that manages DCN communication alongside XLA's communication primitives.
#### Key Strengths
Pallas exposes a grid-based parallelism model where a user-defined kernel function is executed over a grid of program instances.
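A minimal Pallas kernel sketch follows; the `interpret=True` flag runs it in Pallas's reference interpreter so the example works without an accelerator, and the shapes are illustrative:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Refs point into on-chip memory; [...] reads/writes the whole block.
    o_ref[...] = x_ref[...] + y_ref[...]

def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # reference interpreter; drop this on a real TPU/GPU
    )(x, y)

x = jnp.arange(8.0)
result = add(x, x)
```

With no `grid` argument the kernel runs once over the whole array; real kernels specify a grid and per-instance block mappings to tile large arrays across the hardware.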
### Tokamax: A Curated Library of State-of-the-Art Kernels
If Pallas is the *tool* for authoring kernels, [Tokamax](https://github.com/openxla/tokamax) is a *library* of state-of-the-art custom accelerator kernels supporting both TPUs and GPUs, built on top of JAX and Pallas, enabling users to push their hardware to the maximum. It also provides tooling for users to build and autotune their own custom kernels.
#### Motivation
JAX, with its roots in XLA, is a compiler-first framework; however, a narrow set of cases exists where the user needs to take direct control of the hardware to achieve maximum performance[^7]. Custom kernels are critical to squeezing every last ounce of performance from expensive ML accelerator resources such as TPUs and GPUs. While they are widely employed to enable performant execution of key operators such as Attention, implementing them requires a deep understanding of both the model and the target hardware (micro)architecture. Tokamax provides one authoritative source of curated, well-tested, high-performance kernels, together with robust shared infrastructure for their development, maintenance, and lifecycle management. Such a library can also act as a reference implementation for users to build on and customize as necessary. This allows users to focus on their modeling efforts without needing to worry about infrastructure.
#### Design
For any given kernel, Tokamax provides a common API that may be backed by multiple implementations. For example, TPU kernels may be implemented either by standard XLA lowering, or explicitly via Pallas/Mosaic-TPU. GPU kernels may be implemented by standard XLA lowering, via Mosaic-GPU, or Triton. By default, it picks the best-known implementation for a given configuration, determined by cached results from periodic autotuning and benchmarking runs, though users may choose specific implementations if desired. New implementations may be added over time to better exploit specific features in new hardware generations for even better performance.
A key component of the library, beyond the kernels themselves, is the supporting infrastructure that helps power users who choose to write their own custom kernels. For example, the autotuning infrastructure lets the user define a set of configurable parameters (e.g., tile sizes) that Tokamax can sweep exhaustively to determine and cache the best tuned settings. Nightly regression tests protect users from unexpected performance and numerics issues caused by changes to underlying compiler infrastructure or other dependencies.
#### Key Strengths
vLLM-TPU builds on this foundation and develops core components for request handling.
#### Key Strengths
* **Zero Onboarding/Offboarding Cost for Users:** Users can adopt this solution without significant friction. From a user-experience perspective, processing inference requests is identical to doing so on GPUs; the CLI used to start the server, accept prompts, and return outputs is shared.
* **Fully Embrace the Ecosystem:** This approach utilizes and contributes to the vLLM interface and user experience, ensuring compatibility and ease of use.
* **Fungibility between TPUs and GPUs:** The solution works efficiently on both TPUs and GPUs, allowing users flexibility.
* **Cost Efficient (Best Perf/$):** Optimized to deliver the best performance-to-cost ratio for popular models.
XProf consists of two primary components: collection and analysis.
The true power of XProf comes from its deep integration with the full stack, providing a breadth and depth of analysis that is a tangible benefit of the co-designed JAX/TPU ecosystem.
* **Co-designed with the TPU:** XProf exploits hardware features specifically designed for seamless profile collection, enabling a collection overhead of **less than 1%**. This allows profiling to be a lightweight, iterative part of development.
* **Breadth and Depth of Analysis:** XProf yields deep analysis across multiple axes. Its tools include:
  * **Trace Viewer:** An op-by-op timeline view of execution on different hardware units (e.g., TensorCore).
  * **HLO Op Profile:** Breaks down the total time spent into different categories of operations.
  * **Memory Viewer:** Details memory allocations by different ops during the profiled window.
  * **Roofline Analysis:** Helps identify whether specific ops are compute- or memory-bound and how far they are from the hardware's peak capabilities.
  * **Graph Viewer:** Provides a view into the full HLO graph executed by the hardware.
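On the collection side, a trace can be captured programmatically with JAX's profiler API and then opened in XProf or TensorBoard's profile plugin; the workload and output directory below are illustrative:

```python
import os
import tempfile

import jax
import jax.numpy as jnp

log_dir = tempfile.mkdtemp()

# Capture a short trace of a few compiled matmuls.
jax.profiler.start_trace(log_dir)
x = jnp.ones((256, 256))
for _ in range(3):
    x = (x @ x).block_until_ready()  # block so the work lands inside the trace
jax.profiler.stop_trace()

# The trace is written in the TensorBoard profile-plugin layout.
trace_root = os.path.join(log_dir, "plugins", "profile")
```

Pointing the viewer at `log_dir` then surfaces the Trace Viewer, HLO Op Profile, and the other tools listed above for this run.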