
Commit ad33c5b

run linter
1 parent 0b29a3c commit ad33c5b

File tree

11 files changed (+26, −31 lines)


docs/source/_static/images/JAX_ecosystem.svg

Lines changed: 1 addition & 1 deletion

docs/source/_static/images/Tunix_diagram.svg

Lines changed: 1 addition & 1 deletion

docs/source/_static/images/async_checkpointing.svg

Lines changed: 1 addition & 1 deletion

docs/source/_static/images/programming_TPUS.svg

Lines changed: 1 addition & 1 deletion

docs/source/_static/images/serving_orbax_nse.svg

Lines changed: 1 addition & 1 deletion

docs/source/ecosystem_overview/architectural.md

Lines changed: 0 additions & 1 deletion
@@ -9,4 +9,3 @@ For TPUs to provide a clear path to this level of performance, the ecosystem mus
 
 ![](../_static/images/programming_TPUS.svg)
 
 **Figure 2: The JAX continuum of abstraction**
-
Lines changed: 1 addition & 2 deletions

@@ -1,4 +1,4 @@
-# Conclusion: A Durable, Production-Ready Platform for the Future of AI
+# Conclusion: A Durable, Production-Ready Platform for the Future of AI
 
 The data provided in the table above leads to a rather simple conclusion: these stacks have their own strengths and weaknesses in a small number of areas but overall are broadly similar from the software standpoint. Both stacks provide out-of-the-box turnkey solutions for pre-training, post-training adaptation, and deployment of foundational models.
 
@@ -7,4 +7,3 @@ The JAX AI stack offers a compelling and robust solution for training and deploy
 By building on battle-tested internal systems, the stack has evolved to provide inherent reliability and scalability, enabling users to confidently develop and deploy even the largest models. Its modular and composable design, rooted in the JAX ecosystem philosophy, grants users unparalleled freedom and control, allowing them to tailor the stack to their specific needs without the constraints of a monolithic framework.
 
 With XLA and Pathways providing a scalable and fault-tolerant base, JAX providing a performant and expressive numerics library, powerful core development libraries like [Flax](https://flax.readthedocs.io/en/stable/), Optax, [Grain](https://google-grain.readthedocs.io/en/latest/), and [Orbax](#orbax:tensorstore), advanced performance tools like Pallas, Tokamax, and Qwix, and a robust application and production layer in [MaxText](#foundational-model-maxtext-and), vLLM-TPU, and NSE, the JAX AI stack provides a durable foundation for users to build on and rapidly bring state-of-the-art research to production.
-

docs/source/ecosystem_overview/core.md

Lines changed: 7 additions & 7 deletions
@@ -82,14 +82,14 @@ As it can be seen in the example above, setting up an optimizer with a custom le
 
 [Orbax](https://orbax.readthedocs.io/en/latest/) is a checkpointing library for JAX designed for any scale, from single-device to large-scale distributed training. It aims to unify fragmented checkpointing implementations and deliver critical performance features, such as asynchronous and multi-tier checkpointing, to a wider audience. Orbax enables the resilience required for massive training jobs and provides a flexible format for publishing checkpoints.
 
-Unlike generalized checkpoint/restore systems that snapshot the entire system state, ML checkpointing with Orbax selectively persists only the information essential for resuming training—model weights, optimizer state, and data loader state. This targeted approach minimizes accelerator downtime. Orbax achieves this by overlapping I/O operations with computation, a critical feature for large workloads. The time accelerators are halted is thus reduced to the duration of the device-to-host data transfer, which can be further overlapped with the next training step, making checkpointing nearly free from a performance perspective.
+Unlike generalized checkpoint/restore systems that snapshot the entire system state, ML checkpointing with Orbax selectively persists only the information essential for resuming training—model weights, optimizer state, and data loader state. This targeted approach minimizes accelerator downtime. Orbax achieves this by overlapping I/O operations with computation, a critical feature for large workloads. The time accelerators are halted is thus reduced to the duration of the device-to-host data transfer, which can be further overlapped with the next training step, making checkpointing nearly free from a performance perspective.
 At its core, Orbax uses [TensorStore](https://google.github.io/tensorstore/) for efficient, parallel reading and writing of array data. The [Orbax API](https://orbax.readthedocs.io/en/latest/index.html) abstracts this complexity, offering a user-friendly interface for handling [PyTrees](https://docs.jax.dev/en/latest/pytrees.html), which are the standard representation of models in JAX.
 
 ### Key Strengths:
 
-* [Widespread Adoption](https://orbax.readthedocs.io/en/latest/guides/checkpoint/async_checkpointing.html): With millions of monthly downloads, Orbax serves as a common medium for sharing ML artifacts.
-* Easy to Use: Orbax abstracts away the complexities of distributed checkpointing, including asynchronous saving, atomicity, and filesystem details.
-* Flexible: While offering simple APIs for common use cases, Orbax allows for customization to handle specialized requirements.
+* [Widespread Adoption](https://orbax.readthedocs.io/en/latest/guides/checkpoint/async_checkpointing.html): With millions of monthly downloads, Orbax serves as a common medium for sharing ML artifacts.
+* Easy to Use: Orbax abstracts away the complexities of distributed checkpointing, including asynchronous saving, atomicity, and filesystem details.
+* Flexible: While offering simple APIs for common use cases, Orbax allows for customization to handle specialized requirements.
 * Performant and Scalable: Features like asynchronous checkpointing, an efficient storage format ([OCDBT](https://orbax.readthedocs.io/en/latest/guides/checkpoint/optimized_checkpointing.html)), and intelligent data loading strategies ensure that Orbax scales to training runs involving tens of thousands of nodes.

@@ -115,11 +115,11 @@ Out of the box, [Grain](https://google-grain.readthedocs.io/en/latest/) supports
 
 ### Key Strengths
 
 * **Deterministic data feeding:** Colocating the data worker with the accelerator and coupling it with a stable global shuffle and [checkpointable iterators](https://google-grain.readthedocs.io/en/latest/tutorials/data_loader_tutorial.html#checkpointing) allows the model state and data pipeline state to be checkpointed together in a consistent snapshot using [Orbax](https://orbax.readthedocs.io/en/latest/), enhancing the determinism of the training process.
-* **Flexible APIs to enable powerful data transformations:** A flexible pure Python [transformations](https://google-grain.readthedocs.io/en/latest/data_loader/transformations.html) API allows for extensive data transformations within the input processing pipeline.
-* **Extensible support for multiple formats and backends:** An extensible [data sources](https://google-grain.readthedocs.io/en/latest/tutorials/data_sources/index.html) API supports popular storage formats and backends and allows users to easily add support for new formats and backends.
+* **Flexible APIs to enable powerful data transformations:** A flexible pure Python [transformations](https://google-grain.readthedocs.io/en/latest/data_loader/transformations.html) API allows for extensive data transformations within the input processing pipeline.
+* **Extensible support for multiple formats and backends:** An extensible [data sources](https://google-grain.readthedocs.io/en/latest/tutorials/data_sources/index.html) API supports popular storage formats and backends and allows users to easily add support for new formats and backends.
 * **Powerful debugging interface:** Data pipeline [visualization tools](https://google-grain.readthedocs.io/en/latest/tutorials/dataset_debugging_tutorial.html) and a debug mode allow users to introspect, debug and optimize the performance of their data pipelines.
 
 
-[^5]: In Section 5.1 of the [PaLM paper](https://dl.acm.org/doi/10.5555/3648699.3648939), the authors note that they observed very large loss spikes despite having gradient clipping enabled, and the solution was to remove the offending data batches and restart training from a checkpoint before the loss spike. This is only possible with a fully deterministic and reproducible training setup.
+[^5]: In Section 5.1 of the [PaLM paper](https://dl.acm.org/doi/10.5555/3648699.3648939), the authors note that they observed very large loss spikes despite having gradient clipping enabled, and the solution was to remove the offending data batches and restart training from a checkpoint before the loss spike. This is only possible with a fully deterministic and reproducible training setup.
 
 [^6]: This is indeed how multimodal data pipelines would need to operate: image and audio tokenizers, for example, are models themselves which run in their own clusters on their own accelerators, and the input pipelines would make RPCs out to convert data examples into streams of tokens.

docs/source/ecosystem_overview/extended.md

Lines changed: 13 additions & 13 deletions
@@ -15,13 +15,13 @@ Our compiler-first design is a fundamental architectural choice that creates a d
 
 #### Design
 
-XLA works by Just-In-Time (JIT) compiling the computation graphs that JAX generates during its tracing process (e.g., when a function is decorated with `@jax.jit`).
+XLA works by Just-In-Time (JIT) compiling the computation graphs that JAX generates during its tracing process (e.g., when a function is decorated with `@jax.jit`).
 
 This compilation follows a multi-stage pipeline:
 
 JAX Computation Graph → High-Level Optimizer (HLO) → Low-Level Optimizer (LLO) → Hardware Code
 
-* **From JAX Graph to HLO**: The captured JAX computation graph is converted into XLA's HLO representation. At this high level, powerful, hardware-agnostic optimizations like operator fusion and efficient memory management are applied. The **StableHLO** dialect serves as a durable, versioned interface for this stage.
+* **From JAX Graph to HLO**: The captured JAX computation graph is converted into XLA's HLO representation. At this high level, powerful, hardware-agnostic optimizations like operator fusion and efficient memory management are applied. The **StableHLO** dialect serves as a durable, versioned interface for this stage.
 * **From HLO to LLO:** After high-level optimizations, hardware-specific backends take over, lowering the HLO representation into a machine-oriented LLO.
 * **From LLO to Hardware Code:** The LLO is finally compiled into highly-efficient machine code. For TPUs, this code is bundled as **Very Long Instruction Word (VLIW)** packets that are sent directly to the hardware.
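The first two stages of the pipeline in the diff above can be inspected directly from JAX's ahead-of-time lowering API. A minimal sketch (the function `f` is an arbitrary example, not from the document):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) + 1.0

# Stage 1: trace the function and lower it to the StableHLO dialect,
# the durable, versioned interface between JAX and XLA.
lowered = jax.jit(f).lower(jnp.ones((4, 4)))
stablehlo = lowered.as_text()

# Stage 2: hand the module to the XLA backend, which applies the
# high-level optimizations (fusion, layout/memory decisions) and
# produces backend-optimized HLO.
compiled = lowered.compile()
optimized = compiled.as_text()
```

Printing `stablehlo` shows `stablehlo.*` ops for the traced graph; `optimized` shows what the backend will actually execute after fusion, which is often useful when debugging performance.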

@@ -47,7 +47,7 @@ To be able to train and deploy large models, hundreds to thousands of chips are
 
 ML Pathways is the system we use for coordinating distributed computations across hosts and TPU chips. It is designed for scalability and efficiency across hundreds of thousands of accelerators. For large-scale training, it provides a single Python client for multi-slice/multi-pod jobs, [Megascale XLA](https://openxla.org/xprof/megascale_stats) integration, Compilation Service, and Remote Python. It also supports cross-slice parallelism and preemption tolerance, enabling automatic recovery from resource preemptions.
 
-Pathways incorporates optimized cross-host collectives which enable XLA computation graphs to extend beyond a single TPU pod. It expands XLA's support for data, model, and pipeline parallelism to work across TPU slice boundaries over DCN by integrating a distributed runtime that manages DCN communication with XLA communication primitives.
+Pathways incorporates optimized cross-host collectives which enable XLA computation graphs to extend beyond a single TPU pod. It expands XLA's support for data, model, and pipeline parallelism to work across TPU slice boundaries over DCN by integrating a distributed runtime that manages DCN communication with XLA communication primitives.
 
 #### Key strengths

@@ -65,17 +65,17 @@ Pallas exposes a grid-based parallelism model where a user-defined kernel functi
 
 ### Tokamax: A Curated Library of State-of-the-Art Kernels
 
-If Pallas is the *tool* for authoring kernels, [Tokamax](https://github.com/openxla/tokamax) is a *library* of state-of-the-art custom accelerator kernels supporting both TPUs and GPUs, built on top of JAX and Pallas, enabling users to push their hardware to the maximum. It also provides tooling for users to build and autotune their own custom kernels.
+If Pallas is the *tool* for authoring kernels, [Tokamax](https://github.com/openxla/tokamax) is a *library* of state-of-the-art custom accelerator kernels supporting both TPUs and GPUs, built on top of JAX and Pallas, enabling users to push their hardware to the maximum. It also provides tooling for users to build and autotune their own custom kernels.
 
 #### Motivation
 
 JAX, with its roots in XLA, is a compiler-first framework; however, a narrow set of cases exists where the user needs to take direct control of the hardware to achieve maximum performance[^7]. Custom kernels are critical to squeezing out every last ounce of performance from expensive ML accelerator resources such as TPUs and GPUs. While they are widely employed to enable performant execution of key operators such as Attention, implementing them requires a deep understanding of both the model and the target hardware (micro)architecture. Tokamax provides one authoritative source of curated, well-tested, high-performance kernels, in conjunction with robust shared infrastructure for their development, maintenance, and lifecycle management. Such a library can also act as a reference implementation for users to build on and customize as necessary. This allows users to focus on their modeling efforts without needing to worry about infrastructure.
 
 #### Design
 
-For any given kernel, Tokamax provides a common API that may be backed by multiple implementations. For example, TPU kernels may be implemented either by standard XLA lowering, or explicitly via Pallas/Mosaic-TPU. GPU kernels may be implemented by standard XLA lowering, via Mosaic-GPU, or Triton. By default, it picks the best-known implementation for a given configuration, determined by cached results from periodic autotuning and benchmarking runs, though users may choose specific implementations if desired. New implementations may be added over time to better exploit specific features in new hardware generations for even better performance.
+For any given kernel, Tokamax provides a common API that may be backed by multiple implementations. For example, TPU kernels may be implemented either by standard XLA lowering, or explicitly via Pallas/Mosaic-TPU. GPU kernels may be implemented by standard XLA lowering, via Mosaic-GPU, or Triton. By default, it picks the best-known implementation for a given configuration, determined by cached results from periodic autotuning and benchmarking runs, though users may choose specific implementations if desired. New implementations may be added over time to better exploit specific features in new hardware generations for even better performance.
 
-A key component of the library, beyond the kernels themselves, is the supporting infrastructure that helps power users choosing to write their own custom kernels. For example, the autotuning infrastructure lets the user define a set of configurable parameters (e.g., tile sizes) that Tokamax can perform an exhaustive sweep on, to determine and cache the best possible tuned settings. Nightly regression tests protect users from unexpected performance and numerics issues caused by changes to underlying compiler infrastructure or other dependencies.
+A key component of the library, beyond the kernels themselves, is the supporting infrastructure that helps power users choosing to write their own custom kernels. For example, the autotuning infrastructure lets the user define a set of configurable parameters (e.g., tile sizes) that Tokamax can perform an exhaustive sweep on, to determine and cache the best possible tuned settings. Nightly regression tests protect users from unexpected performance and numerics issues caused by changes to underlying compiler infrastructure or other dependencies.
 
 #### Key Strengths
@@ -193,8 +193,8 @@ vLLM-TPU builds on this foundation and develops core components for request hand
 
 #### Key Strengths
 
-* **Zero Onboarding/Offboarding Cost for Users:** Users can adopt this solution without significant friction. From a user-experience perspective, processing inference requests is identical to doing so on GPUs. The CLI commands to start the server, accept prompts, and return outputs are all shared.
-* **Fully Embrace the Ecosystem:** This approach utilizes and contributes to the vLLM interface and user experience, ensuring compatibility and ease of use.
+* **Zero Onboarding/Offboarding Cost for Users:** Users can adopt this solution without significant friction. From a user-experience perspective, processing inference requests is identical to doing so on GPUs. The CLI commands to start the server, accept prompts, and return outputs are all shared.
+* **Fully Embrace the Ecosystem:** This approach utilizes and contributes to the vLLM interface and user experience, ensuring compatibility and ease of use.
 * **Fungibility between TPUs and GPUs:** The solution works efficiently on both TPUs and GPUs, allowing users flexibility.
 * **Cost Efficient (Best Perf/$):** Optimizes performance to provide the best performance-to-cost ratio for popular models.

@@ -244,11 +244,11 @@ XProf consists of two primary components: collection and analysis.
 
 The true power of XProf comes from its deep integration with the full stack, providing a breadth and depth of analysis that is a tangible benefit of the co-designed JAX/TPU ecosystem.
 
-* **Co-designed with the TPU:** XProf exploits hardware features specifically designed for seamless profile collection, enabling a collection overhead of **less than 1%**. This allows profiling to be a lightweight, iterative part of development.
-* **Breadth and Depth of Analysis:** XProf yields deep analysis across multiple axes. Its tools include:
-  * **Trace Viewer:** An op-by-op timeline view of execution on different hardware units (e.g., TensorCore).
-  * **HLO Op Profile:** Breaks down the total time spent into different categories of operations.
-  * **Memory Viewer:** Details memory allocations by different ops during the profiled window.
+* **Co-designed with the TPU:** XProf exploits hardware features specifically designed for seamless profile collection, enabling a collection overhead of **less than 1%**. This allows profiling to be a lightweight, iterative part of development.
+* **Breadth and Depth of Analysis:** XProf yields deep analysis across multiple axes. Its tools include:
+  * **Trace Viewer:** An op-by-op timeline view of execution on different hardware units (e.g., TensorCore).
+  * **HLO Op Profile:** Breaks down the total time spent into different categories of operations.
+  * **Memory Viewer:** Details memory allocations by different ops during the profiled window.
   * **Roofline Analysis:** Helps identify whether specific ops are compute- or memory-bound and how far they are from the hardware’s peak capabilities.
   * **Graph Viewer:** Provides a view into the full HLO graph executed by the hardware.
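Collecting a trace for the XProf tools listed above can be done from JAX's built-in profiler API. A minimal sketch, not from the commit, assuming a local log directory; the resulting trace files are what the XProf/TensorBoard profile viewers open.

```python
import tempfile

import jax
import jax.numpy as jnp

logdir = tempfile.mkdtemp()

# Everything executed inside the context is captured in the trace.
with jax.profiler.trace(logdir):
    x = jnp.ones((512, 512))
    # block_until_ready() ensures the async dispatch finishes inside the trace.
    (x @ x).block_until_ready()
```

After this runs, `logdir` contains the collected profile, which can be loaded into the Trace Viewer, HLO Op Profile, and the other tools described above.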

docs/source/ecosystem_overview/modular.md

Lines changed: 0 additions & 2 deletions
@@ -43,5 +43,3 @@ The following sections provide a technical overview of the JAX ecosystem, its ke
 
 ![](../_static/images/JAX_ecosystem.svg)
 **Figure 1: The JAX AI Stack and Ecosystem Components**
-
-
