From 995838164f2eb2181f52efb67674706f2ec9bf59 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 7 Nov 2025 12:04:41 +0100 Subject: [PATCH 01/14] DOC: drop the md converted TR, add a toctree etc --- .../ecosystem_overview/architectural.md | 12 + docs/source/ecosystem_overview/comparative.md | 33 +++ docs/source/ecosystem_overview/conclusion.md | 26 ++ docs/source/ecosystem_overview/core.md | 156 +++++++++++ docs/source/ecosystem_overview/extended.md | 243 ++++++++++++++++++ docs/source/ecosystem_overview/modular.md | 44 ++++ .../the_ecosystem_overview_tr.md | 37 +++ docs/source/index.rst | 8 + 8 files changed, 559 insertions(+) create mode 100644 docs/source/ecosystem_overview/architectural.md create mode 100644 docs/source/ecosystem_overview/comparative.md create mode 100644 docs/source/ecosystem_overview/conclusion.md create mode 100644 docs/source/ecosystem_overview/core.md create mode 100644 docs/source/ecosystem_overview/extended.md create mode 100644 docs/source/ecosystem_overview/modular.md create mode 100644 docs/source/ecosystem_overview/the_ecosystem_overview_tr.md diff --git a/docs/source/ecosystem_overview/architectural.md b/docs/source/ecosystem_overview/architectural.md new file mode 100644 index 0000000..6a62a6f --- /dev/null +++ b/docs/source/ecosystem_overview/architectural.md @@ -0,0 +1,12 @@ +## The Architectural Imperative: Performance Beyond Frameworks + +As model architectures converge—for example, on multimodal Mixture-of-Experts (MoE) Transformers—the pursuit of peak performance is leading to the emergence of "Megakernels." A Megakernel is effectively the entire forward pass (or a large portion) of one specific model, hand-coded using a lower-level API like the CUDA SDK on NVIDIA GPUs. This approach achieves maximum hardware utilization by aggressively overlapping compute, memory, and communication. Recent work from the research community has demonstrated that this approach can yield significant throughput gains, over 22% in some cases, for inference on GPUs. This trend is not limited to inference; evidence suggests that some large-scale training efforts have involved low-level hardware control to achieve substantial efficiency gains. + +If this trend accelerates, all high-level frameworks as they exist today risk becoming less relevant, as low-level access to the hardware is what ultimately matters for performance on mature, stable architectures. This presents a challenge for all modern ML stacks: how to provide expert-level hardware control without sacrificing the productivity and flexibility of a high-level framework. + +For TPUs to provide a clear path to this level of performance, the ecosystem must expose an API layer that is closer to the hardware, enabling the development of these highly specialized kernels. As this report will detail, the JAX stack is designed to solve this by offering a continuum of abstraction (See Figure 2), from the automated, high-level optimizations of the XLA compiler to the fine-grained, manual control of the Pallas kernel-authoring library. + +![][image3] + +**Figure 2: The JAX continuum of abstraction** + diff --git a/docs/source/ecosystem_overview/comparative.md b/docs/source/ecosystem_overview/comparative.md new file mode 100644 index 0000000..9e11e6f --- /dev/null +++ b/docs/source/ecosystem_overview/comparative.md @@ -0,0 +1,33 @@ +## A Comparative Perspective: The JAX/TPU Stack as a Compelling Choice + +The modern Machine Learning landscape offers many excellent, mature toolchains. 
The JAX AI Stack, however, presents a unique and compelling set of advantages for developers focused on large-scale, high-performance ML, stemming directly from its modular design and deep hardware co-design. + +While many frameworks offer a wide array of features, the JAX AI Stack provides specific, powerful differentiators in key areas of the development lifecycle: + +* **A Simpler, More Powerful Developer Experience:** The "chainable gradient transformation paradigm" of [**Optax**](#optax:-composable-gradient-processing-and-optimization-strategies) allows for more powerful and flexible optimization strategies that are declared once, rather than imperatively managed in the training loop.1 At the system level, the "simpler single controller interface" of **Pathways** abstracts away the complexity of multi-pod, multi-slice training, a significant simplification for researchers. +* **Engineered for "Hero-Scale" Resilience:** The JAX stack is designed for extreme-scale training. **Orbax** provides "hero-scale training resilience" features like emergency and multi-tier checkpointing. This is complemented by **Grain**, which offers "full support for reproducibility with deterministic global shuffles and checkpointable data loaders". The ability to atomically checkpoint the data pipeline state (Grain) with the model state (Orbax) is a critical capability for guaranteeing reproducibility in long-running jobs. +* **A Complete, End-to-End Ecosystem:** The stack provides a cohesive, end-to-end solution. Developers can use [**MaxText**](https://maxtext.readthedocs.io/en/latest/) as a SOTA reference for training, [**Tunix**](https://tunix.readthedocs.io/en/latest/) for alignment, and follow a clear, dual-path to production with **vLLM** (for vLLM compatibility) and **NSE** (for native JAX performance). + +While many stacks are vastly similar from a high-level software standpoint, the deciding factor often comes down to **Performance/TCO**, which is where the co-design of JAX and TPUs provides a distinct advantage. This Performance/TCO benefit is a direct result of the "vertical integration across software and TPU hardware". The ability of the **XLA** compiler to fuse operations specifically for the TPU architecture, or for the **XProf** profiler to leverage hardware hooks for \<1% overhead profiling, are tangible benefits of this deep integration. + +For organizations adopting this stack, the "full featured nature" of the JAX AI Stack minimizes the cost of migration. For customers employing popular open model architectures, a shift from other frameworks to [MaxText](#foundation-model-training:-maxtext-and-maxdiffusion) is often a matter of setting up config files. Furthermore, the stack's ability to ingest popular checkpoint formats like safetensors allows existing checkpoints to be migrated over without needing costly re-training. + +The table below provides a mapping of the components provided by the JAX AI stack and their equivalents in other frameworks or libraries. + +| Function | JAX | Alternatives/equivalents in other frameworks[^8] | +| :---- | :---- | :---- | +| Compiler / Runtime | XLA | Inductor, Eager | +| Multipod Training | Pathways | Torch Lightning Strategies, Ray Train, Monarch (new). 
| +| Core Framework | JAX | PyTorch | +| Model authoring | Flax, Max\* models | [torch.nn](http://torch.nn).\*, NVidia TransformerEngine, HuggingFace Transformers | +| Optimizers & Losses | Optax | torch.optim.\*, torch.nn.\*Loss | +| Data Loaders | Grain | Ray Data, HuggingFace dataloaders | +| Checkpointing | Orbax | PyTorch distributed checkpointing, NeMo checkpointing | +| Quantization | Qwix | TorchAO, bitsandbytes | +| Kernel authoring & well known implementations | Pallas / Tokamax | Triton/Helion, Liger-kernel, TransformerEngine | +| Post training / tuning | Tunix | VERL, NeMoRL | +| Profiling | XProf | PyTorch profiler, NSight systems, NSight Compute | +| Foundation model Training | MaxText, MaxDiffusion | NeMo-Megatron, DeepSpeed, TorchTitan | +| LLM inference | vLLM | vLLM, SGLang | +| Non-LLM Inference | NSE | Triton Inference Server, RayServe | + diff --git a/docs/source/ecosystem_overview/conclusion.md b/docs/source/ecosystem_overview/conclusion.md new file mode 100644 index 0000000..f997df9 --- /dev/null +++ b/docs/source/ecosystem_overview/conclusion.md @@ -0,0 +1,26 @@ +## Conclusion: A Durable, Production-Ready Platform for the Future of AI + +The data provided in the table above draws to a rather simple conclusion \- these stacks have their own strengths and weaknesses in a small number of areas but overall are vastly similar from the software standpoint. Both stacks provide out of the box turnkey solutions for pre-training, post-training adaptation and deployment of foundational models. + +The JAX AI stack offers a compelling and robust solution for training and deploying ML models at any scale. It leverages deep vertical integration across software and TPU hardware to deliver class-leading performance and total cost of ownership. + +By building on battle-tested internal systems, the stack has evolved to provide inherent reliability and scalability, enabling users to confidently develop and deploy even the largest models. Its modular and composable design, rooted in the JAX ecosystem philosophy, grants users unparalleled freedom and control, allowing them to tailor the stack to their specific needs without the constraints of a monolithic framework. + +With XLA and Pathways providing a scalable and fault-tolerant base, JAX providing a performant and expressive numerics library, powerful core development libraries like [Flax](https://flax.readthedocs.io/en/stable/), Optax, [Grain](https://google-grain.readthedocs.io/en/latest/), and [Orbax](#orbax-/-tensorstore---large-scale-distributed-checkpointing), advanced performance tools like Pallas, Tokamax, and Qwix, and a robust application and production layer in [MaxText](#foundation-model-training:-maxtext-and-maxdiffusion), vLLM, and NSE, the JAX AI stack provides a durable foundation for users to build on and rapidly bring state-of-the-art research to production. + +[^1]: Included in the [jax-ai-stack Python package](https://docs.jaxstack.ai/en/latest/install.html) + +[^2]: Included in the [jax-ai-stack Python package](https://docs.jaxstack.ai/en/latest/install.html) + +[^3]: Image diffusion models are a typical example of this and can commonly be divided logically into a separately trained prompt encoder and a diffusion backbone. + +[^4]: We say effectively free since there could be other bottlenecks such as the DMA engines, HBM bandwidth contention etc. that still incur a performance penalty. 
[^5]: In Section 5.1 of the [PaLM paper](https://dl.acm.org/doi/10.5555/3648699.3648939), the authors note that they observed very large loss spikes despite having gradient clipping enabled, and the solution was to remove the offending data batches and restart training from a checkpoint before the loss spike. This is only possible with a fully deterministic and reproducible training setup.

[^6]: This is indeed how multimodal data pipelines would need to operate \- image and audio tokenizers, for example, are models themselves which run in their own clusters on their own accelerators, and the input pipelines would make RPCs out to convert data examples into streams of tokens.

[^7]: This is a well-established paradigm with precedent in the CPU world, where compiled code forms the bulk of the program and developers drop down to intrinsics or inline assembly to optimize performance-critical sections.

[^8]: Some of the equivalents here are not true 1:1 comparisons because other frameworks draw API boundaries differently compared to JAX. The list of equivalents is not exhaustive and new libraries appear frequently.

diff --git a/docs/source/ecosystem_overview/core.md b/docs/source/ecosystem_overview/core.md
new file mode 100644
index 0000000..d78566d
--- /dev/null
+++ b/docs/source/ecosystem_overview/core.md
@@ -0,0 +1,156 @@
## The Core JAX AI Stack

The core JAX AI Stack consists of four key libraries that provide the foundation for model development: JAX, [Flax](https://flax.readthedocs.io/en/stable/), [Optax](https://optax.readthedocs.io/en/latest/), and [Orbax](https://orbax.readthedocs.io/en/latest/).

### **JAX: A Foundation for Composable, High-Performance Program Transformation** {#jax:-a-foundation-for-composable,-high-performance-program-transformation}

[JAX](https://docs.jax.dev/en/latest/) is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale Machine Learning. With its functional programming model and friendly, NumPy-like API, JAX provides a solid foundation for higher-level libraries.

With its compiler-first design, JAX inherently promotes scalability by leveraging [XLA](https://openxla.org/xla) (see the XLA section of this overview) for aggressive, whole-program analysis, optimization, and hardware targeting. The JAX emphasis on functional programming (i.e., pure functions) makes its core program transformations more tractable and, crucially, composable.

These core transformations can be mixed and matched to achieve high performance and scaling of workloads across model size, cluster size, and hardware types:

* **jit**: Just-in-time compilation of Python functions into optimized, fused XLA executables.
* **grad**: Automatic differentiation, supporting forward- and reverse-mode, as well as higher-order derivatives.
* **vmap**: Automatic vectorization, enabling seamless batching and data parallelism without modifying function logic.
* **pmap / shard\_map**: Automatic parallelization across multiple devices (e.g., TPU cores), forming the basis for distributed training.

The seamless integration with XLA's GSPMD (General-purpose SPMD) model allows JAX to automatically parallelize computations across large TPU pods with minimal code changes, as sketched below.
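As a minimal sketch (the function, array shapes and mesh layout are illustrative, and an 8-device slice is assumed), composing `jit` with a device mesh and `NamedSharding` annotations is enough for XLA/GSPMD to partition the computation and insert the required collectives:

```py
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import mesh_utils

# Illustrative 2x4 device mesh (assumes 8 local devices, e.g. a small TPU slice).
devices = mesh_utils.create_device_mesh((2, 4))
mesh = Mesh(devices, axis_names=("data", "model"))

@jax.jit
def predict(w, x):
    return jnp.tanh(x @ w)

# High-level sharding annotations: the batch is split over "data",
# the weights over "model"; GSPMD handles placement and communication.
x = jax.device_put(jnp.ones((32, 512)), NamedSharding(mesh, P("data", None)))
w = jax.device_put(jnp.ones((512, 2048)), NamedSharding(mesh, P(None, "model")))

y = predict(w, x)                                             # runs sharded across the mesh
grads = jax.jit(jax.grad(lambda w: predict(w, x).sum()))(w)   # transformations compose
```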
In most cases, scaling simply requires high-level sharding annotations, a stark contrast to frameworks where scaling may require more manual management of device placement and communication collectives + +### **Flax: Flexible Neural Network Authoring and "Model Surgery"** {#flax:-flexible-neural-network-authoring-and-"model-surgery"} + +#### **Flax \- neural network layers** {#flax---neural-network-layers} + +[Flax](https://flax.readthedocs.io/en/latest/index.html) is a library designed to simplify the creation, debugging, and analysis of neural networks in JAX. While pure functional API provided by JAX can be used to fully specify and train a ML model, users coming from the PyTorch (or TensorFlow) ecosystem are more used to and comfortable with the object oriented approach of specifying models as a graph of `torch.nn.Modules`. The abstractions provided by [Flax](https://flax.readthedocs.io/en/stable/) allow users to think more in terms of layers rather than functions, making it more developer friendly to an audience who value ergonomics and experimentation ease. [Flax](https://flax.readthedocs.io/en/stable/) also enables config driven model construction systems, such as those present in [MaxText](https://maxtext.readthedocs.io/en/latest/) and AxLearn, which separate out model hyperparameters from layer definition code. + +With a simple Pythonic API, it allows developers to express models using regular Python objects, while retaining the power and performance of JAX. Flax's NNX API is an evolution of the Flax Linen interface, incorporating lessons learned to offer a more user-friendly interface that remains consistent with the core JAX APIs. Since Flax modules are fully backed by the core JAX APIs, there is no performance penalty associated with defining the model in [Flax](https://flax.readthedocs.io/en/stable/). + +##### **Motivation** {#motivation} + +JAX’s pure functional API, while powerful, can be complex for new users since it requires all the program state to be explicitly managed by the user. This paradigm can be unfamiliar to developers used to other frameworks. Modern model architectures are often complex with individual portions of the model trained separately and merged to form the final model[^3], in a process commonly referred to as model surgery. Even with decoder-only LLMs which tend to have a straightforward architecture, post training techniques such as LoRA and quantization require the model definition to be easily manipulated allowing parts of the architecture to be modified or even replaced. + +The Flax NNX library, with its simple yet powerful Pythonic API enables this functionality in a way that is intuitive to the user, reducing the amount of cognitive overhead involved in authoring and training a model. + +##### **Design** {#design} + +The [Flax](https://flax.readthedocs.io/en/stable/) NNX library introduces an object oriented model definition system that encapsulates the model and random number generator state internally, reducing the cognitive overhead of the user and provides a familiar experience for those accustomed to frameworks like PyTorch or TensorFlow. By making submodule definitions Pythonic and providing APIs to traverse the module hierarchy, it allows for the model definition to be easily editable programmatically for model introspection and surgery. 
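As a brief illustration (the layer sizes and names are made up for this sketch), the NNX style keeps parameters and RNG state inside plain Python objects, and model surgery reduces to ordinary attribute manipulation:

```py
import jax
import jax.numpy as jnp
from flax import nnx

class MLP(nnx.Module):
    def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
        self.hidden = nnx.Linear(din, dhidden, rngs=rngs)
        self.out = nnx.Linear(dhidden, dout, rngs=rngs)

    def __call__(self, x):
        return self.out(jax.nn.relu(self.hidden(x)))

model = MLP(16, 64, 4, rngs=nnx.Rngs(0))   # parameters and RNG state live inside the object
y = model(jnp.ones((2, 16)))

# "Model surgery": swap a submodule by plain attribute assignment
# (the same pattern is used to splice in e.g. LoRA or quantized layers).
model.out = nnx.Linear(64, 8, rngs=nnx.Rngs(1))
```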
The [Flax](https://flax.readthedocs.io/en/stable/) NNX APIs are designed to be consistent with the core JAX APIs, allowing users to exploit the full expressiveness and performance of JAX, with lifted transformations for common operations like sharding and jit. Models defined using the NNX APIs can also be adapted to work with functional training loops, giving users the flexibility they need while retaining an intuitive object oriented API.

##### **Key Strengths** {#key-strengths}

* **Intuitive, object oriented, flexible APIs:** Layers are represented as pure Python objects with internal state management, simplifying model construction and training loops, while also enabling advanced model surgery use cases through support for submodule replacement, partial initialization and model hierarchy traversal.
* **Consistent with Core JAX APIs:** Lifted transformations consistent with core JAX, and full compatibility with functional JAX, provide the full performance of JAX without sacrificing developer friendliness.


### **Optax: Composable Gradient Processing and Optimization Strategies** {#optax:-composable-gradient-processing-and-optimization-strategies}

[Optax](https://optax.readthedocs.io/en/latest/index.html) is a gradient processing and optimization library for JAX. It is designed to empower model builders by providing building blocks that can be recombined in custom ways to train deep learning models, among other applications. It builds on the capabilities of the core JAX library to provide a well-tested, high-performance library of losses, optimizer functions and associated techniques that can be used to train ML models.

#### Motivation {#motivation-1}

The calculation and minimization of losses is at the core of what enables the training of ML models. With its support for automatic differentiation, the core JAX library provides the numeric capabilities to train models, but it does not provide standard implementations of popular optimizers (e.g., `RMSProp`, `Adam`) or losses (`CrossEntropy`, `MSE`, etc.). While a user could implement these functions themselves (and some advanced users will choose to do so), a bug in an optimizer implementation would introduce hard-to-diagnose model quality issues. Rather than having the user implement such critical pieces, [Optax](https://optax.readthedocs.io/en/latest/) provides implementations of these algorithms that are tested for correctness and performance.

The field of optimization theory lies squarely in the realm of research; however, its central role in training also makes it an indispensable part of training production ML models. A library that serves this role needs to be flexible enough to accommodate rapid research iteration, yet robust and performant enough to be dependable for production model training. It should also provide well-tested implementations of state-of-the-art algorithms that match the standard equations. The [Optax](https://optax.readthedocs.io/en/latest/) library, through its modular, composable architecture and its emphasis on correct, readable code, is designed to achieve this.

#### Design {#design-1}

[Optax](https://optax.readthedocs.io/en/latest/) is designed to enhance both research velocity and the transition from research to production by providing readable, well-tested, and efficient implementations of core algorithms.
Optax has uses beyond the context of deep learning, however in this context it can be viewed as a collection of well known loss functions, optimization algorithms and gradient transformations implemented in a pure functional fashion in line with the JAX philosophy. The collection of well known [losses](https://optax.readthedocs.io/en/latest/api/losses.html) and [optimizers](https://optax.readthedocs.io/en/latest/api/optimizers.html) enable users to get started with ease and confidence. + +The modular approach taken by Optax easily allows users to [chain multiple optimizers](https://optax.readthedocs.io/en/latest/api/combining_optimizers.html#chain) together followed by other common [transformations](https://optax.readthedocs.io/en/latest/api/transformations.html) like gradient clipping for example and [wrap](https://optax.readthedocs.io/en/latest/api/optimizer_wrappers.html) it using common techniques like MultiStep or Lookahead to achieve powerful optimization strategies all within a few lines of code. The flexible interface allows for easy research into new optimization algorithms and also enables powerful second order optimization techniques like shampoo or muon. + +```py +# Optax implementation of a RMSProp optimizer with a custom learning rate schedule, gradient clipping and gradient accumulation. +optimizer = optax.chain( + optax.clip_by_global_norm(GRADIENT_CLIP_VALUE), + optax.rmsprop(learning_rate=optax.cosine_decay_schedule(init_value=lr,decay_steps=decay)), + optax.apply_every(k=ACCUMULATION_STEPS) +) + +# The same thing, in PyTorch +optimizer = optim.RMSprop(model_params, lr=LEARNING_RATE) +scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=TOTAL_STEPS) +for i, (inputs, targets) in enumerate(data_loader): + # ... Training loop body ... + if (i + 1) % ACCUMULATION_STEPS == 0: + torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP_VALUE) + optimizer.step() + scheduler.step() + optimizer.zero_grad() +``` + +As it can be seen in the example above, setting up an optimizer with a custom learning rate, gradient clipping and gradient accumulation is a simple drop in replacement block of code, compared to PyTorch which forces the user to modify their training loop to directly manage the learning rate scheduler, gradient clipping and gradient accumulation. + +#### Key Strengths {#key-strengths-1} + +* **Robust Library:** Provides a comprehensive library of losses, optimizers, and algorithms with a focus on correctness and readability. +* **Modular Chainable Transformations:** As shown above, this flexible API allows users to craft powerful, complex optimization strategies declaratively, without modifying the training loop. +* **Functional and Scalable:** The pure functional implementations integrate seamlessly with JAX's parallelization mechanisms (e.g., pmap), enabling the same code to scale from a single host to large clusters. + +### **Orbax / TensorStore \- Large scale distributed checkpointing** {#orbax-/-tensorstore---large-scale-distributed-checkpointing} + +[**Orbax**](https://orbax.readthedocs.io/en/latest/) is an any-scale checkpointing library for JAX users backed primarily by [**TensorStore**](https://google.github.io/tensorstore/), a library for efficiently reading and writing multi-dimensional arrays. The two libraries operate at different levels of the stack \- Orbax at the level of ML models and states \- TensorStore at the level of individual arrays. 
#### Motivation {#motivation-2}

[Orbax](https://orbax.readthedocs.io/en/latest/), which centers on JAX users and ML checkpointing, aims to reduce the fragmentation of checkpointing implementations across disparate research codebases, increase adoption of important performance features outside the most cutting-edge codebases, and provide a clean, flexible API for novice and advanced users alike. With advanced features like fully asynchronous distributed checkpointing, multi-tier checkpointing and emergency checkpointing, [Orbax](https://orbax.readthedocs.io/en/latest/) enables resilience in the largest of training jobs while also providing a flexible representation for publishing checkpoints.

#### ML Checkpointing vs Generalized Checkpoint/Restore {#ml-checkpointing-vs-generalized-checkpoint/restore}

It is worth considering the difference between ML checkpoint systems ([Orbax](https://orbax.readthedocs.io/en/latest/), NeMo-Megatron, Torch Distributed Checkpoint) and generalized checkpoint/restore systems like CRIU.

Systems like CRIU and CRIUgpu behave analogously to VM live migration; they halt the entire system and take a snapshot of every last bit of information so it can be faithfully reconstructed. This captures the entirety of the process’ host memory, device memory and operating system state. This is far more information than is actually needed to reconstruct an ML workload, since a very large fraction of this information (activations, data examples, file handles) is trivially reconstructed. Capturing this much data also keeps the job halted for a long time.

ML checkpoint systems are designed to minimize the amount of time the accelerator is halted by selectively persisting only the information that cannot be reconstructed. Specifically, this entails persisting model weights, optimizer state, dataloader state and random number generator state, which is a far smaller amount of data.

#### Design {#design-2}

The [Orbax API](https://orbax.readthedocs.io/en/latest/index.html) centers around handling [PyTrees](https://docs.jax.dev/en/latest/pytrees.html) (nested containers) of arrays as the standard representation of JAX models. Saving and loading can be synchronous or asynchronous, with saving consisting of blocking and non-blocking phases. A higher-level `CheckpointManager` class is provided, which facilitates checkpointing in a training loop, with save intervals, garbage collection, dataset checkpointing, and metadata management. Finally, Orbax provides customization layers for dealing with user-defined checkpointable objects and PyTree leaves.

The storage layer of [Orbax](https://orbax.readthedocs.io/en/latest/index.html) is the [TensorStore](https://google.github.io/tensorstore/) library, which is not technically part of the JAX ecosystem at all and seeks to provide a flexible and highly versatile library for array storage. However, it is not designed around ML concepts and introduces too much complexity and manual management for most JAX users. [Orbax](https://orbax.readthedocs.io/en/latest/index.html) smooths out this experience to provide users an easy-to-use, ML-specific API surface.

To maximize the utilization of the accelerator, the checkpointing library must minimize the time it halts training to snapshot the state. This is achieved by overlapping the checkpointing operations with the compute operations.
It’s worth noting that asynchronous checkpointing is table-stakes for large workloads and isn’t unique to [Orbax](https://orbax.readthedocs.io/en/latest/index.html). It is also present in other frameworks such as NeMO-Megatron and Torch Distributed Checkpoints. + +When considering asynchronous checkpointing with non overlapped device-to-host transfers, the amount of time the accelerator is halted is thus a function of the number of model parameters, the size of the parameters and the PCI link speed. Enabling fully overlapped D2H can further reduce this time by overlapping the D2H transfer with the forward pass of the next step. As long as the D2H transfer can complete before the next forward step completes, the checkpoint will become effectively[^4] free. + +Restarting from an error is similarly bound by two factors, the XLA compilation time and the speed of reading the weights back from storage. XLA compilation caches can make the former insignificant. Reading from storage is hardware dependent \- emergency checkpoints save to ramdisks which are extremely fast, however there is a speed spectrum that ranges from ramdisk to SSD, HDD and GCS. + +Specific industry-leading performance features have their own design challenges, and merit separate attention: + +* [**Async checkpointing**](https://orbax.readthedocs.io/en/latest/guides/checkpoint/async_checkpointing.html): Checkpointing only needs to block accelerator computations while data is being transferred from host to/from accelerator memory. Expensive I/O operations can take place in a background thread meaning save time can be reduced by 95-99% relative to blocking saves. Asynchronous loading is also possible, and can save time on startup, but requires more extensive effort to integrate and has not yet seen widespread adoption. +* [**OCDBT format**](https://orbax.readthedocs.io/en/latest/guides/checkpoint/optimized_checkpointing.html): Most previous checkpointing implementations stored parameters as separate subdirectories, which caused significant overhead for small arrays. TensorStore’s OCDBT format uses an efficient [B+ tree](https://en.wikipedia.org/wiki/B%2B_tree) format, which allows fine-grained control over shard shapes and file sizes that can be tuned to different filesystems and models. The save/load strategy provides scalability to tens of thousands of nodes by ensuring each host independently reads and writes only the relevant pieces of each array. +* [**Restore \+ broadcast**](https://cloud.google.com/blog/products/compute/unlock-faster-workload-start-time-using-orbax-on-jax): Hero-scale training runs replicate the model weights among multiple data-parallel replicas. Orbax provides a load balancing feature that distributes the burden evenly among available replicas when saving. It also leverages fast chip interconnects to avoid redundant reads of the model on different groups of hosts, instead loading on a single primary replica and broadcasting the weights to all other replicas. +* **Emergency checkpointing**: Hero-scale training suffers from frequent interruptions and hardware failures. Checkpointing to persistent RAM disk improves goodput for hero-scale jobs by allowing for increased checkpoint frequency, faster restore times, and improved resiliency, since TPU states may be corrupted on some replicas, but not all. 
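To make the training-loop integration concrete, here is a minimal sketch using the `CheckpointManager` API; the path, intervals and the shape of `state` are illustrative placeholders:

```py
import jax.numpy as jnp
import orbax.checkpoint as ocp

state = {"params": {"w": jnp.zeros((1024, 1024))}, "step": 0}   # stand-in train state

options = ocp.CheckpointManagerOptions(save_interval_steps=500, max_to_keep=3)
mngr = ocp.CheckpointManager("/tmp/run_checkpoints", options=options)

for step in range(10_000):
    # ... forward/backward pass updating `state` ...
    # Saves are asynchronous: the accelerator blocks only for the
    # device-to-host transfer while I/O continues in the background.
    mngr.save(step, args=ocp.args.StandardSave(state))

mngr.wait_until_finished()   # flush any in-flight async save

# After a restart, restore the most recent step.
restored = mngr.restore(mngr.latest_step(), args=ocp.args.StandardRestore(state))
```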
+ +#### Key Strengths {#key-strengths-2} + +* **Widespread adoption:** As checkpoints are a medium for communication of ML artifacts between different codebases and stages of ML development, widespread adoption is an inherent advantage. Currently, Orbax has [\~4 million](https://pypistats.org/packages/orbax-checkpoint) monthly package downloads. +* **Easy to use:** Orbax abstracts away complex technical aspects of checkpointing like async saving, single- vs. multi-controller, checkpoint atomicity, distributed filesystem details, TPU vs. GPU, etc. It condenses use cases into simple, but generalizable APIs (direct-to-path, sequence-of-steps). +* **Flexible:** While Orbax focuses on exposing a simple API surface for the majority of users, additional layers for handling custom checkpointable objects and PyTree nodes allow for flexibility in specialized use cases. +* **Performant and scalable:** Orbax provides a variety of features designed to make checkpointing as fast and as unobtrusive as possible, freeing developers to focus on efficiency in the remainder of the training loop. Scalability to the cutting edge of ML research is a top concern of the library; training runs at a scale of O(10k) nodes currently rely on Orbax. + +#### **Grain: Deterministic and Scalable Input Data Pipelines** {#grain:-deterministic-and-scalable-input-data-pipelines} + +[Grain](https://google-grain.readthedocs.io/en/latest/) is a Python library for reading and processing data for training and evaluating JAX models. It is flexible, fast and deterministic and supports advanced features like checkpointing which are essential to successfully training large workloads. It supports popular data formats and storage backends and also provides a flexible API to extend support to user specific formats and backends that are not natively supported. While [Grain](https://google-grain.readthedocs.io/en/latest/) is primarily designed to work with JAX, it is framework independent, does not require JAX to run and can be used with other frameworks as well. + +##### **Motivation** {#motivation-7} + +Data pipelines form a critical part of the training infrastructure \- they need to be flexible so that common transformations can be expressed efficiently, and performant enough that they are able to keep the accelerators busy at all times. They also need to be able to accommodate multiple storage formats and backends. Due to their higher step times, training large models at scale pose unique additional requirements on the data pipeline beyond those that are required by regular training workloads, primarily focused around determinism and reproducibility[^5]. The [Grain](https://google-grain.readthedocs.io/en/latest/) library is designed with a flexible enough architecture to address all these needs. + +##### **Design** {#design-7} + +At the highest level, there are two ways to structure an input pipeline, as a separate cluster of data workers or by co-locating the data workers on the hosts that drive the accelerators. [Grain](https://google-grain.readthedocs.io/en/latest/) chooses the latter for a variety of reasons. + +Accelerators are combined with powerful hosts that typically sit idle during training steps, which makes it a natural choice to run the input data pipeline. There are however additional advantages to doing so \- it simplifies the user's view of data sharding by providing a consistent view of sharding across input and compute. 
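A minimal sketch of such a co-located pipeline is shown below, assuming Grain's `MapDataset` chaining API; the in-memory source and the transforms are placeholders for real data sources and preprocessing:

```py
import grain

# Placeholder in-memory source; real workloads would point at e.g. ArrayRecord files.
source = list(range(1_000))

dataset = (
    grain.MapDataset.source(source)
    .shuffle(seed=42)                      # deterministic global shuffle
    .map(lambda x: {"feature": x * 2})     # user-defined transformation
    .batch(batch_size=32)
)

data_iter = iter(dataset)
first_batch = next(data_iter)
# The iterator's position is part of its state, so the data pipeline can be
# checkpointed together with the model state for reproducible restarts.
```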
It could be argued that putting the data worker on the accelerator host risks saturating the host CPU, however this does not preclude offloading compute intensive transformations to another cluster via RPCs[^6]. + +On the API front, with a pure python implementation that supports multiple processes and a flexible API, [Grain](https://google-grain.readthedocs.io/en/latest/) enables users to implement arbitrarily complex data transformations by composing together pipeline stages based on well understood [transformation](https://google-grain.readthedocs.io/en/latest/data_loader/transformations.html) paradigms. + +Out of the box, [Grain](https://google-grain.readthedocs.io/en/latest/) supports efficient random access data formats like `ArrayRecord` and `Bagz` alongside other popular data formats such as Parquet and `TFDS`. [Grain](https://google-grain.readthedocs.io/en/latest/) includes support for reading from local file systems as well as reading from GCS by default. Along with supporting popular storage formats and backends, a clean abstraction to the storage layer allows users to easily add support for or wrap their existing data sources to be compatible with the [Grain](https://google-grain.readthedocs.io/en/latest/) library. + +##### **Key Strengths** {#key-strengths-7} + +* **Deterministic data feeding:** Colocating the data worker with the accelerator and coupling it with a stable global shuffle and [checkpointable iterators](https://google-grain.readthedocs.io/en/latest/tutorials/data_loader_tutorial.html#checkpointing) allows the model state and data pipeline state to be checkpointed together in a consistent snapshot using [Orbax](https://docs.google.com/document/d/1rS4DGWSbHOX0rZgjv2rV2DcXuBnHvnCKOTAarZiC1Dg/edit?tab=t.0#heading=h.rtje6zr33hjw), enhancing the determinism of the training process. +* **Flexible APIs to enable powerful data transformations:** A flexible pure Python [transformations](https://google-grain.readthedocs.io/en/latest/data_loader/transformations.html) API allows for extensive data transformations within the input processing pipeline. +* **Extensible support for multiple formats and backends:** An extensible [data sources](https://google-grain.readthedocs.io/en/latest/tutorials/data_sources/index.html) API supports popular storage formats and backends and allows users to easily add support for new formats and backends. +* **Powerful debugging interface:** Data pipeline [visualization tools](https://google-grain.readthedocs.io/en/latest/tutorials/dataset_debugging_tutorial.html) and a debug mode allow users to introspect, debug and optimize the performance of their data pipelines. + diff --git a/docs/source/ecosystem_overview/extended.md b/docs/source/ecosystem_overview/extended.md new file mode 100644 index 0000000..852f46c --- /dev/null +++ b/docs/source/ecosystem_overview/extended.md @@ -0,0 +1,243 @@ +## The Extended JAX Ecosystem + +Beyond the core stack, a rich ecosystem of specialized libraries provides the infrastructure, advanced tools, and application-layer solutions needed for end-to-end ML development. + +### **Foundational Infrastructure: Compilers and Runtimes** {#foundational-infrastructure:-compilers-and-runtimes} + +#### **XLA: The Hardware-Agnostic, Compiler-Centric Engine** {#xla:-the-hardware-agnostic,-compiler-centric-engine} + +##### **Motivation** {#motivation-3} + +XLA or Accelerated Linear Algebra is our domain specific compiler, which is well integrated into JAX and supports TPU, CPU and GPU hardware devices. 
From inception, XLA has been designed to be a hardware agnostic code generator targeting TPUs, GPUs, and CPUs. + +Our compiler-first design is a fundamental architectural choice that creates a durable advantage in a rapidly evolving research landscape. In contrast, the prevailing kernel-centric approach in other ecosystems relies on hand-optimized libraries for performance. While this is highly effective for stable, well-established model architectures, it creates a bottleneck for innovation. When new research introduces novel architectures, the ecosystem must wait for new kernels to be written and optimized. Our compiler-centric design, however, can often generalize to new patterns, providing a high-performance path for cutting-edge research from day one. + +##### **Design** {#design-3} + +XLA works by Just-In-Time (JIT) compiling the computation graphs that JAX generates during its tracing process (e.g., when a function is decorated with `@jax.jit`). + +This compilation follows a multi-stage pipeline: + +JAX Computation Graph → High-Level Optimizer (HLO) → Low-Level Optimizer (LLO) → Hardware Code + +* **From JAX Graph to HLO**: The captured JAX computation graph is converted into XLA's HLO representation. At this high level, powerful, hardware-agnostic optimizations like operator fusion and efficient memory management are applied. The **StableHLO** dialect serves as a durable, versioned interface for this stage. +* **From HLO to LLO:** After high-level optimizations, hardware-specific backends take over, lowering the HLO representation into a machine-oriented LLO. +* **From LLO to Hardware Code:** The LLO is finally compiled into highly-efficient machine code. For TPUs, this code is bundled as **Very Long Instruction Word (VLIW)** packets that are sent directly to the hardware. + +For scaling, XLA's design is built around parallelism. It employs algorithms to maximally utilize the matrix multiplication units (MXUs) on a chip. Between chips, XLA uses **SPMD (Single Program Multiple Data)**, a compiler-based parallelization that uses a single program across all devices. This powerful model is exposed through JAX APIs, allowing users to manage data, model, or pipeline parallelism with high-level sharding annotations. + +For more complex parallelism patterns, **Multiple Program Multiple Data (MPMD)** is also possible, and libraries like `PartIR:MPMD` allow JAX users to provide MPMD annotations as well. + +##### **Key strengths** {#key-strengths-3} + +* **Compilation**: just in time compilation of the computation graph allows for optimizations to memory layout, buffer allocation, and memory management. Alternatives such as kernel based methodologies put that burden on the user. In most cases, XLA can achieve excellent performance without compromising developer velocity. +* **Parallelism:** XLA implements several forms of parallelism with SPMD, and this is exposed at the JAX level. This allows for users to express sharding strategies easily, allowing experimentation and scalability of models across thousands of chips. + +#### **Pathways: A Unified Runtime for Massive-Scale Distributed Computation** {#pathways:-a-unified-runtime-for-massive-scale-distributed-computation} + +[Pathways](https://docs.cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/pathways-intro) offers abstractions for distributed training and inference with built in fault tolerance and recovery, allowing ML researchers to code as if they are using a single, powerful machine. 
+ +##### **Motivation** {#motivation-4} + +To be able to train and deploy large models, hundreds to thousands of chips are necessary. These chips are spread across numerous racks and host machines. A training job is a large-scale synchronous program that requires all of these chips, and their respective hosts to be working in tandem on XLA computations that have been parallelized (sharded). In the case of large language models, which may need more than tens of thousands of chips, this service must be capable of spanning multiple pods across a data center fabric in addition to using ICI and OCI fabrics within a pod. + +##### **Design** {#design-4} + +ML Pathways is the system we use for coordinating distributed computations across hosts and TPU chips. It is designed for scalability and efficiency across hundreds of thousands of accelerators. For large-scale training, it provides a single Python client for multi-slice/multi-pod jobs, [Megascale XLA](https://openxla.org/xprof/megascale_stats) integration, Compilation Service, and Remote Python. It also supports cross-slice parallelism and preemption tolerance, enabling automatic recovery from resource preemptions. + +Pathways incorporates optimized cross host collectives which enable XLA computation graphs to further extend beyond a single TPU pod. It expands XLA's support for data, model, and pipeline parallelism to work across TPU slice boundaries using DCN by means of integrating a distributed runtime that manages DCN communication with XLA communication primitives. + +##### **Key strengths** {#key-strengths-4} + +* The single-controller architecture, integrated with JAX, is a key abstraction. It allows researchers to explore various sharding and parallelism strategies for training and deployment while scaling to tens of thousands of chips with ease. +* Scaling to tens of thousands of chips with ease, allowing exploration of various sharding and parallelism strategies during model research, training and deployment. + +### **Advanced Development: Performance, Data, and Efficiency** {#advanced-development:-performance,-data,-and-efficiency} + +#### **Pallas: Writing High-Performance Custom Kernels in JAX** {#pallas:-writing-high-performance-custom-kernels-in-jax} + +While JAX is compiler first, there are situations where the user would like to exercise fine grained control over the hardware to achieve maximum performance. Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU. It aims to provide precise control over the generated code, combined with the high-level ergonomics of JAX tracing and the `jax.numpy API`. + +Pallas exposes a grid-based parallelism model where a user-defined kernel function is launched across a multi-dimensional grid of parallel work-groups. It enables explicit management of the memory hierarchy by allowing the user to define how tensors are tiled and transferred between slower, larger memory (e.g., HBM) and faster, smaller on-chip memory (e.g., VMEM on TPU, Shared Memory on GPU), using index maps to associate grid locations with specific data blocks. Pallas can lower the same kernel definition to execute efficiently on both Google's TPUs and various GPUs by compiling kernels into an intermediate representation suitable for the target architecture – Mosaic for TPUs, or utilizing technologies like Triton for the GPU path. 
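As a small illustration of this programming model (the block size and array shapes are chosen arbitrarily), an elementwise kernel is written against references to blocks that Pallas has already staged into fast on-chip memory, with a 1-D grid and `BlockSpec`s describing the tiling; the same definition lowers to Mosaic on TPU or Triton on GPU:

```py
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # x_ref, y_ref and o_ref are (256,)-shaped blocks in fast on-chip memory.
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add(x, y):
    block = 256
    return pl.pallas_call(
        add_kernel,
        grid=(x.shape[0] // block,),                    # one program per block
        in_specs=[pl.BlockSpec((block,), lambda i: (i,)),
                  pl.BlockSpec((block,), lambda i: (i,))],
        out_specs=pl.BlockSpec((block,), lambda i: (i,)),
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)

x = jnp.arange(4096, dtype=jnp.float32)
print(add(x, 2 * x))
```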
With Pallas, users can write high performance kernels that specialize blocks like attention to achieve the best model performance on the target hardware without needing to rely on vendor specific toolkits. + +#### **Tokamax: A Curated Library of State-of-the-Art Kernels** {#tokamax:-a-curated-library-of-state-of-the-art-kernels} + +If Pallas is the *tool* for authoring kernels, [Tokamax](https://github.com/openxla/tokamax) is a *library* of state-of-the-art custom accelerator kernels supporting both TPUs and GPUs, built on top of JAX and Pallas enabling users to push their hardware to the maximum. It also provides tooling for users to build and autotune their own custom kernels. + +##### **Motivation** {#motivation-5} + +JAX, with its roots in XLA, is a compiler-first framework, however a narrow set of cases exists where the user needs to take direct control of the hardware to achieve maximum performance[^7]. Custom kernels are critical to squeezing out every last ounce of performance from expensive ML accelerator resources such as TPUs and GPUs. While they are widely employed to enable performant execution of key operators such as Attention, implementing them requires a deep understanding of both the model and the target hardware (micro)architecture. Tokamax provides one authoritative source of curated, well-tested, high-performance kernels, in conjunction with robust shared infrastructure for their development, maintenance, and lifecycle management. Such a library can also act as a reference implementation for users to build on and customize as necessary. This allows users to focus on their modeling efforts without needing to worry about infrastructure. + +##### **Design** {#design-5} + +For any given kernel, Tokamax provides a common API that may be backed by multiple implementations. For example, TPU kernels may be implemented either by standard XLA lowering, or explicitly via Pallas/Mosaic-TPU. GPU kernels may be implemented by standard XLA lowering, via Mosaic-GPU, or Triton. By default, it picks the best-known implementation for a given configuration, determined by cached results from periodic autotuning and benchmarking runs, though users may choose specific implementations if desired. New implementations may be added over time to better exploit specific features in new hardware generations for even better performance. + +A key component of the library, beyond the kernels themselves, is the supporting infrastructure that will help power users choosing to write their own custom kernels. For example, the autotuning infrastructure lets the user define a set of configurable parameters (e.g., tile sizes) that Tokamax can perform an exhaustive sweep on, to determine and cache the best possible tuned settings. Nightly regressions protect users from unexpected performance and numerics issues caused by changes to underlying compiler infrastructure or other dependencies. + +##### **Key Strengths** {#key-strengths-5} + +* **Seamless developer experience:** A unified, curated, library will provide known-good high-performance implementations of key kernels, with clear expressions of supported hardware generations and expected performance, both programmatically and in documentation. This minimizes fragmentation and churn. +* **Flexibility and lifecycle management:** Users may choose different implementations as desired, even changing them over time if appropriate. 
For example, if the XLA compiler enhances support for certain operations obviating the need for custom kernels, there is a simple path to deprecation and migration. +* **Extensibility:** Users may implement their own kernels, while leveraging well-supported shared infrastructure, allowing them to focus on their value added capabilities and optimizations. Clearly authored standard implementations serve as a starting point for users to learn from and extend. + +#### **Qwix: Non-Intrusive, Comprehensive Quantization** {#qwix:-non-intrusive,-comprehensive-quantization} + +Qwix is a comprehensive quantization library for the JAX ecosystem, supporting both LLMs and other model types across all stages, including training (QAT, QT, QLoRA) and inference (PTQ), targeting both XLA and on-device runtimes. + +##### **Motivation** {#motivation-6} + +Existing quantization libraries, particularly in the PyTorch ecosystem, often serve limited purposes (e.g., only PTQ or only QLoRA). This fragmented landscape forces users to switch tools, impeding consistent code usage and precise numerical matching between training and inference. Furthermore, many solutions require substantial model modifications, tightly coupling the model logic to the quantization logic. + +##### **Design** {#design-6} + +Qwix's design philosophy emphasizes a comprehensive solution and, critically, **non-intrusive model integration**. It is architected with a hierarchical, extensible design built on reusable functional APIs. + +This non-intrusive integration is achieved through a meticulously designed **interception mechanism** that redirects JAX functions to their quantized counterparts. This allows users to integrate their models without any modifications, completely decoupling quantization code from model definitions. + +The following example demonstrates applying `w4a4` (4-bit weight, 4-bit activation) quantization to an LLM's MLP layers and `w8` (8-bit weight) quantization to the embedder. To change the quantization recipe, only the rules list needs to be updated. + +```py +fp_model = ModelWithoutQuantization(...) +rules = [ + qwix.QuantizationRule( + module_path=r'embedder', + weight_qtype='int8', + ), + qwix.QuantizationRule( + module_path=r'layers_\d+/mlp', + weight_qtype='int4', + act_qtype='int4', + tile_size=128, + weight_calibration_method='rms,7', + ), +] +quantized_model = qwix.quantize_model(fp_model, qwix.PtqProvider(rules)) +``` + +##### Key Strengths {#key-strengths-6} + +* **Comprehensive Solution:** Qwix is broadly applicable across numerous quantization scenarios, ensuring consistent code usage between training and inference. +* **Non-Intrusive Model Integration:** As the example shows, users can integrate models with a single line of code, without modification. This allows developers to easily sweep hyperparameters over many quantization schemes to find the best quality/performance tradeoff. +* **Federated with Other Libraries:** Qwix seamlessly integrates with the JAX AI stack. For example, Tokamax automatically adapts to use quantized versions of kernels, without additional user code, when the model is quantized with Qwix. +* **Research Friendly:** Qwix's foundational APIs and extensible architecture empower researchers to explore new algorithms and facilitate straightforward comparisons with integrated benchmark and evaluation tools. 
+ +### **The Application Layer: Training and Alignment** {#the-application-layer:-training-and-alignment} + +#### **Foundation Model Training: MaxText and MaxDiffusion** {#foundation-model-training:-maxtext-and-maxdiffusion} + +[MaxText](https://maxtext.readthedocs.io/en/latest/) and [MaxDiffusion](https://github.com/AI-Hypercomputer/maxdiffusion) are Google’s flagship LLM and Diffusion model training frameworks, respectively. With a large selection of highly optimized implementations of popular open-weights models, these repositories serve a dual purpose: they function as both a ready-to-go model training codebase and as a reference that foundation model builders can use to build upon. + +##### **Motivation** {#motivation-8} + +There is rapid growth of interest across the industry in training GenAI models. The popularity of open models has accelerated this trend, providing users with proven architectures. To train and adapt these models, users require high performance, efficiency, scalability to extreme numbers of chips, and clear, understandable code. They need a framework that can adapt to new techniques and target both TPUs and GPUs. [MaxText](https://maxtext.readthedocs.io/en/latest/) and MaxDiffusion are comprehensive solutions designed to fulfill these needs. + +##### **Design** {#design-8} + +[MaxText](https://maxtext.readthedocs.io/en/latest/) and MaxDiffusion are foundation model codebases designed with readability and performance in mind. They are structured with well-tested, reusable components: model definitions that leverage custom kernels (like Tokamax) for maximum performance, a training harness for orchestration and monitoring, and a powerful config system that allows users to control details like sharding and quantization (via Qwix) through an intuitive interface. Advanced reliability features like multi-tier checkpointing are incorporated to ensure sustained goodput. + +They leverage the best-in-class JAX libraries—Qwix, [Tunix](https://tunix.readthedocs.io/en/latest/), [Orbax](https://orbax.readthedocs.io/en/latest/index.html), and [Optax](https://optax.readthedocs.io/en/latest/)—to deliver core capabilities. This allows them to provide robust, scalable infrastructure, reducing development overhead and allowing users to focus on the modeling task. For inference, the model code is shared to enable efficient and scalable serving. 
+ +##### **Key Strengths** {#key-strengths-8} + +* **Performant by Design:** With training infrastructure set up for high "goodput" (useful throughput) and model implementations optimized for high MFU (Model Flops Utilization), [MaxText](#foundation-model-training:-maxtext-and-maxdiffusion) and MaxDiffusion deliver high performance at scale out of the box +* **Built for Scale:** Leveraging the power of the JAX AI stack (especially Pathways), these frameworks allow users to scale seamlessly from tens of chips to tens of thousands of chips +* **Solid Base for Foundation Model Builders:** The high-quality, readable implementations serve as a solid starting point for builders to either use as an end-to-end solution or as a reference implementation for their own customizations + +#### **Post-Training and Alignment: The Tunix Framework** {#post-training-and-alignment:-the-tunix-framework} + +[Tunix](https://tunix.readthedocs.io/en/latest/) offers state-of-the-art open-source reinforcement learning (RL) algorithms, along with a robust framework and infrastructure, providing a streamlined path for users to experiment with LLM post-training techniques (including Supervised Fine-Tuning (SFT) and alignment) using JAX and TPUs. + +##### **Motivation** {#motivation-9} + +Post-training is the critical step in unlocking the true power of LLMs. The Reinforcement Learning stage is particularly crucial for developing alignment and reasoning capabilities. While fast-moving open-source development in this area has been prolific, it has been almost exclusively based on PyTorch and GPUs, leaving a fundamental gap for JAX and TPU solutions. [Tunix](https://tunix.readthedocs.io/en/latest/) (Tune-in-JAX) is a high-performance, JAX-native library designed to fill this gap. + +##### **Design** {#design-9} + +![][image4] + +From a framework perspective, [Tunix](https://tunix.readthedocs.io/en/latest/) enables a state-of-the-art setup that **clearly separates RL algorithms from the infrastructure**. It offers a lightweight, client-like API that hides the complexity of the RL infrastructure, allowing users to develop new algorithms easily. [Tunix](https://tunix.readthedocs.io/en/latest/) provides out-of-the-box solutions for popular algorithms, including PPO, DPO, and others. + +On the infrastructure side, [Tunix](https://tunix.readthedocs.io/en/latest/) has native integration with Pathways, enabling a single-controller architecture that makes multi-node RL training easily accessible. On the trainer side, [Tunix](https://tunix.readthedocs.io/en/latest/) natively supports parameter-efficient training (e.g., LoRA) and leverages JAX sharding and XLA (GSPMD) to generate a performant compute graph. It supports popular open-source models like Gemma and Llama out of the box. + +##### **Key Strengths** {#key-strengths-9} + +* **Simplicity:** It provides a high-level, client-like API that abstracts away the complexities of the underlying distributed infrastructure. +* **Developer Efficiency:** Tunix accelerates the R\&D lifecycle with out-of-the-box algorithms and pre-built "recipes," enabling users to get a working model and iterate quickly. +* **Performance and Scalability:** Tunix enables a highly efficient and horizontally scalable training infrastructure by leveraging Pathways as a single controller on the backend. + + +### **The Application Layer: Production and Inference** {#the-application-layer:-production-and-inference} + +A historical challenge for JAX adoption has been the path from research to production. 
The JAX AI stack now provides a mature, two-pronged production story that offers both ecosystem compatibility and native JAX performance.
+
+#### **High-Performance LLM Inference: The vLLM Solution** {#high-performance-llm-inference:-the-vllm-solution}
+
+vLLM TPU is Google's high-performance inference stack designed to run PyTorch- and JAX-native Large Language Models (LLMs) efficiently on Cloud TPUs. It achieves this by natively integrating the popular open-source vLLM framework with Google's JAX and TPU ecosystem.
+
+##### **Motivation** {#motivation-10}
+
+The industry is rapidly evolving, with growing demand for seamless, high-performing, and easy-to-use inference solutions. Users often face significant challenges from complex and inconsistent tooling, subpar performance, and limited model compatibility. The vLLM stack addresses these issues by providing a unified, performant, and intuitive platform.
+
+##### **Design** {#design-10}
+
+This solution pragmatically extends the vLLM framework, rather than reinventing it. vLLM is a highly optimized open-source LLM serving engine known for its high throughput, achieved via key features like **`PagedAttention`** (which manages KV caches like virtual memory to minimize fragmentation) and **`Continuous Batching`** (which dynamically adds requests to the batch to improve utilization).
+
+vLLM TPU builds on this foundation and develops core components for request handling, scheduling, and memory management. It introduces a **JAX-based backend** that acts as a bridge, translating vLLM's computational graph and memory operations into TPU-executable code. This backend handles device interactions, JAX model execution, and the specifics of managing the KV cache on TPU hardware. It incorporates TPU-specific optimizations, such as efficient attention mechanisms (e.g., leveraging JAX Pallas kernels for Ragged Paged Attention) and quantization, all tailored for the TPU architecture.
+
+##### **Key Strengths** {#key-strengths-10}
+
+* **Zero Onboarding/Offboarding Cost for Users:** Users can adopt this solution without significant friction. From a user-experience perspective, processing inference requests is identical to doing so on GPUs: the CLI used to start the server, accept prompts, and return outputs is shared across both.
+* **Fully Embrace the Ecosystem:** This approach utilizes and contributes to the vLLM interface and user experience, ensuring compatibility and ease of use.
+* **Fungibility between TPUs and GPUs:** The solution works efficiently on both TPUs and GPUs, allowing users flexibility.
+* **Cost Efficient (Best Perf/$):** Optimizes performance to provide the best performance-to-cost ratio for popular models.
+
+#### **JAX-Native Serving: Orbax Serialization and Neptune Serving Engine** {#jax-native-serving:-orbax-serialization-and-neptune-serving-engine}
+
+For models other than LLMs, or for users desiring a fully JAX-native pipeline, the Orbax serialization library and Neptune serving engine (NSE) system provide an end-to-end, high-performance serving solution.
+
+##### **Motivation** {#motivation-11}
+
+Historically, JAX models often relied on a circuitous path to production, such as being wrapped in TensorFlow graphs and deployed using TensorFlow serving. This approach introduced significant limitations and inefficiencies, forcing developers to engage with a separate ecosystem and slowing down iteration. A dedicated JAX-native serving system is crucial for sustainability, reduced complexity, and optimized performance.
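+
+As context for the design described next, representing a JAX model's computation natively is something core JAX already supports through its export machinery. The sketch below uses the `jax.export` module (available in recent JAX releases; older versions exposed it under `jax.experimental.export`) to capture a toy function as StableHLO. It illustrates the underlying mechanism only and is not the Orbax serialization API.
+
+```python
+import jax
+import jax.numpy as jnp
+from jax import export
+
+
+# Any jit-compiled JAX function can be exported; a toy "model" stands in here.
+@jax.jit
+def predict(params, x):
+    return jnp.tanh(x @ params["w"] + params["b"])
+
+
+params = {"w": jnp.zeros((4, 2)), "b": jnp.zeros((2,))}
+x = jnp.ones((1, 4))
+
+# Capture the traced computation as a versioned StableHLO module plus input/output specs.
+exported = export.export(predict)(params, x)
+print(exported.mlir_module()[:200])  # StableHLO text for the traced graph
+
+# The artifact round-trips through bytes and can be invoked without the original model code.
+blob = exported.serialize()
+restored = export.deserialize(blob)
+print(restored.call(params, x))
+```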
+ +##### **Design** {#design-11} + +This solution consists of two core components, as illustrated in the diagram below. + +1. **Orbax Serialization Library:** Provides user-friendly APIs for serializing JAX models into a new, robust Orbax serialization format. This format is optimized for production deployment. Its core includes: (a) directly representing JAX model computations using **StableHLO**, allowing the computation graph to be represented natively, and (b) leveraging **TensorStore** for storing weights, enabling fast checkpoint loading for serving. +2. **Neptune Serving Engine (NSE):** This is the accompanying high-performance, flexible serving engine (typically deployed as a C++ binary) designed to natively run JAX models in the Orbax format. NSE offers production-essential capabilities, such as fast model loading, high-throughput concurrent serving with built-in batching, support for multiple model versions, and both single- and multi-host serving (leveraging PJRT and Pathways). + +##### **Key Strengths** {#key-strengths-11} + +* **JAX Native Serving:** The solution is built natively for JAX, eliminating inter-framework overhead in model serialization and serving. This ensures lightning-fast model loading and optimized execution across CPUs, GPUs, and TPUs. +* **Effortless Production Deployment:** Serialized models provide a **hermetic deployment path** that is unaffected by drift in Python dependencies and enables runtime model integrity checks. This provides a seamless, intuitive path for JAX model productionization. +* **Enhanced Developer Experience:** By eliminating the need for cumbersome framework wrapping, this solution significantly reduces dependencies and system complexity, speeding up iteration for JAX developers. + +### **System-Wide Analysis and Profiling** {#system-wide-analysis-and-profiling} + +#### **XProf: Deep, Hardware-Integrated Performance Profiling** {#xprof:-deep,-hardware-integrated-performance-profiling} + +[XProf](https://openxla.org/xprof) is a profiling and performance analysis tool that provides in-depth visibility into various aspects of ML workload execution, enabling users to debug and optimize performance. It is deeply integrated into both the JAX and TPU ecosystems. + +##### **Motivation** {#motivation-12} + +On one hand, ML workloads are growing ever more complicated. On the other, there is an explosion of specialized hardware capabilities targeting these workloads. Matching the two effectively to ensure peak performance and efficiency is critical, given the enormous costs of ML infrastructure. This requires deep visibility into both the workload and the hardware, presented in a way that is easily consumable. XProf excels at this. + +##### **Design** {#design-12} + +XProf consists of two primary components: collection and analysis. + +1. **Collection:** XProf captures information from various sources: annotations in the user’s JAX code, cost models for operations within the XLA compiler, and **purpose-built hardware profiling features within the TPU**. This collection can be triggered programmatically or on-demand, generating a comprehensive event artifact. +2. **Analysis:** XProf post-processes the collected data and creates a suite of powerful visualizations, accessed via a browser. + +##### **Key Strengths** {#key-strengths-12} + +The true power of XProf comes from its deep integration with the full stack, providing a breadth and depth of analysis that is a tangible benefit of the co-designed JAX/TPU ecosystem. 
+ +* **Co-designed with the TPU:** XProf exploits hardware features specifically designed for seamless profile collection, enabling a collection overhead of **less than 1%**. This allows profiling to be a lightweight, iterative part of development. +* **Breadth and Depth of Analysis:** XProf yields deep analysis across multiple axes. Its tools include: + * **Trace Viewer:** An op-by-op timeline view of execution on different hardware units (e.g., TensorCore). + * **HLO Op Profile:** Breaks down the total time spent into different categories of operations. + * **Memory Viewer:** Details memory allocations by different ops during the profiled window. + * **Roofline Analysis:** Helps identify whether specific ops are compute- or memory-bound and how far they are from the hardware’s peak capabilities. + * **Graph Viewer:** Provides a view into the full HLO graph executed by the hardware. + diff --git a/docs/source/ecosystem_overview/modular.md b/docs/source/ecosystem_overview/modular.md new file mode 100644 index 0000000..1a06ec8 --- /dev/null +++ b/docs/source/ecosystem_overview/modular.md @@ -0,0 +1,44 @@ +## A Modular, Compiler-First Architecture for Modern AI + +The [JAX AI stack](https://jaxstack.ai/) extends the JAX numerical core with a collection of Google-backed composable libraries, evolving it into a robust, end-to-end, open-source platform for Machine Learning at extreme scales. As such, the JAX AI stack consists of a comprehensive and robust ecosystem that addresses the entire ML lifecycle: + +* **Industrial-Scale Foundation:** The stack is architected for massive scale, leveraging ML Pathways for orchestrating training across tens of thousands of chips and [Orbax](https://orbax.readthedocs.io/en/latest/index.html) for resilient, high-throughput asynchronous checkpointing, enabling production-grade training of state-of-the-art models. +* **Complete, Production-Ready Toolkit:** It provides a comprehensive set of libraries for the entire development process: Flax for flexible model authoring and "surgery," [Optax](#optax:-composable-gradient-processing-and-optimization-strategies) for composable optimization strategies, and [Grain](https://google-grain.readthedocs.io/en/latest/) for the deterministic data pipelines essential for reproducible large-scale runs. +* **Peak, Specialized Performance:** To achieve maximum hardware utilization, the stack offers specialized libraries including Tokamax for state-of-the-art custom kernels, Qwix for non-intrusive quantization that boosts training and inference speed, and XProf for deep, hardware-integrated performance profiling. +* **Full Path to Production:** The stack provides a seamless transition from research to deployment. This includes [MaxText](https://maxtext.readthedocs.io/en/latest/) as a scalable reference for foundation model training, [Tunix](https://tunix.readthedocs.io/en/latest/) for state-of-the-art reinforcement learning (RL) and alignment, and a unified inference solution via vLLM integration and the native JAX Serving runtime. + +The JAX ecosystem philosophy is one of loosely coupled components, each of which does one thing well. Rather than being a monolithic ML framework, JAX itself is narrowly-scoped and focuses on efficient array operations and program transformations. The ecosystem is built upon this core framework to provide a wide array of functionalities, related to both the training of ML models and other types of workloads such as scientific computing. 
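+
+A few lines of core JAX illustrate this narrow focus on array operations and composable program transformations (a minimal, self-contained sketch):
+
+```python
+import jax
+import jax.numpy as jnp
+
+
+def loss(w, x, y):
+    # A tiny least-squares loss written as plain array math.
+    return jnp.mean((x @ w - y) ** 2)
+
+
+# Compose transformations: differentiate, vectorize over a batch of weight
+# candidates, and JIT-compile the result for whatever backend is available.
+batched_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(0, None, None)))
+
+ws = jnp.ones((8, 3))   # 8 candidate weight vectors
+x = jnp.ones((16, 3))
+y = jnp.zeros((16,))
+print(batched_grads(ws, x, y).shape)  # (8, 3)
+```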
+
+This system of loosely coupled components hands freedom of choice back to users, enabling them to select and combine libraries in the best way to suit their requirements. From a software engineering perspective, this architecture also allows parts that would traditionally be considered core framework components (for example data pipelines, checkpointing, etc.) to be iterated upon rapidly without the risk of destabilizing the core framework or being caught up in release cycles. Given that most functionality is brought in using libraries rather than via changes to a monolithic framework, it makes the core numerics library more durable and adaptable to future shifts in the technology landscape.
+
+The following sections provide a technical overview of the JAX ecosystem, its key features, the design decisions behind them, and how they combine to build a durable platform for modern ML workloads.
+
+**Table 1: The JAX AI Stack and other Ecosystem Components**
+
+| Component | Function / Description |
+| ----- | ----- |
+| **JAX AI stack core and components**[^2] | |
+| [**JAX**](https://docs.jax.dev/en/latest/) | Accelerator-oriented array computation and program transformation (JIT, grad, vmap, pmap). |
+| [**Flax**](https://flax.readthedocs.io/en/stable/) | Flexible neural network authoring library for intuitive model creation and "surgery." |
+| [**Optax**](https://optax.readthedocs.io/en/latest/) | A library of composable gradient processing and optimization transformations. |
+| [**Orbax**](https://orbax.readthedocs.io/en/latest/) | "Any-scale" distributed checkpointing library for hero-scale training resilience. |
+| **JAX Ecosystem \- Infrastructure** | |
+| **[XLA](https://openxla.org/)** | Hardware-agnostic, compiler-centric engine that JIT-compiles JAX computations for TPUs, GPUs, and CPUs. |
+| **[Pathways](https://docs.cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/pathways-intro)** | Distributed runtime for orchestrating computation across tens of thousands of chips. |
+| **JAX Ecosystem \- Adv. Development** | |
+| **[Pallas](https://docs.jax.dev/en/latest/pallas/index.html)** | A JAX extension for writing low-level, high-performance custom kernels in Python. |
+| **[Tokamax](https://github.com/openxla/tokamax)** | A curated library of state-of-the-art, high-performance custom kernels (e.g., Attention). |
+| **[Qwix](https://github.com/google/qwix)** | A comprehensive, non-intrusive library for quantization (PTQ, QAT, QLoRA). |
+| **[Grain](https://google-grain.readthedocs.io/en/latest/)** | A scalable, deterministic, and checkpointable input data pipeline library. |
+| **JAX Ecosystem \- Application** | |
+| **[MaxText / MaxDiffusion](https://maxtext.readthedocs.io/en/latest/)** | Flagship, scalable reference frameworks for training foundation models (LLM, Diffusion). |
+| **[Tunix](https://tunix.readthedocs.io/en/latest/index.html)** | A framework for state-of-the-art post-training and alignment (RLHF, DPO). |
+| **[vLLM](https://docs.vllm.ai/projects/tpu/en/latest/)** | A high-performance LLM inference solution via native integration of the vLLM framework. |
+| **Neptune Serving Engine** (coming soon) | JAX Serving Runtime: a high-performance, JAX-native C++ server for non-LLM models. |
+| **[XProf](https://openxla.org/xprof)** | A deep, hardware-integrated profiler for system-wide performance analysis. 
| + +![][image1] +**Figure 1: The JAX AI Stack and Ecosystem Components** + +**![][image2]** + diff --git a/docs/source/ecosystem_overview/the_ecosystem_overview_tr.md b/docs/source/ecosystem_overview/the_ecosystem_overview_tr.md new file mode 100644 index 0000000..d21b814 --- /dev/null +++ b/docs/source/ecosystem_overview/the_ecosystem_overview_tr.md @@ -0,0 +1,37 @@ + + + +# The JAX Ecosystem: A Modular, Scalable, and High-Performance ML Ecosystem + + +```{toctree} +:maxdepth: 1 + +modular +architectural +core +extended +comparative +conclusion +``` + + + +[^1]: Image diffusion models are a typical example of this and can commonly be divided logically into a separately trained prompt encoder and a diffusion backbone. + +[^2]: We say effectively free since there could be other bottlenecks such as the DMA engines, HBM bandwidth contention etc. that still incur a performance penalty. + +[^3]: This is a well established paradigm and has precedent in the CPU world, where compiled code forms the bulk of the program with developers dropping down to intrinsics or inline assembly to optimize performance critical sections. + +[^4]: In the Section 5.1 of the [Palm paper](https://dl.acm.org/doi/10.5555/3648699.3648939), the authors note that they observed very large loss spikes despite having gradient clipping enabled and the solution was to remove the offending data batches and restart training from a checkpoint before the loss spike. This is only possible with a fully deterministic and reproducible training setup. + +[^5]: This is indeed how multimodal data pipelines would need to operate \- image and audio tokenizers for example are models themselves which run in their own clusters on their own accelerators and the input pipelines would make RPCs out to convert data examples into streams of tokens. + +[^6]: Some of the equivalents here are not true 1:1 comparisons because PyTorch draws API boundaries differently compared to JAX. The list of equivalents is not exhaustive and there are new libraries appearing frequently. + +[image1]: + +[image2]: + +[image3]: + diff --git a/docs/source/index.rst b/docs/source/index.rst index 7aac42b..47c61bf 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -22,6 +22,14 @@ JAX AI Stack data_loaders pytorch_users +.. toctree:: + :hidden: + :caption: FIXME OVERVIEW + :maxdepth: 1 + + ecosystem_overview/the_ecosystem_overview_tr + + .. 
toctree:: :hidden: :caption: Example applications From 34599a3152f89117d2d3cf91fa0450ad1db8b2c2 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 11 Nov 2025 16:34:34 +0100 Subject: [PATCH 02/14] MAINT: first pass cleanup the conversion artefacts --- docs/source/_static/images/JAX_ecosystem.svg | 1 + docs/source/_static/images/Tunix_diagram.svg | 1 + .../_static/images/async_checkpointing.svg | 1 + .../_static/images/programming_TPUS.svg | 1 + .../_static/images/serving_orbax_nse.svg | 1 + .../ecosystem_overview/architectural.md | 2 +- docs/source/ecosystem_overview/comparative.md | 12 +- docs/source/ecosystem_overview/conclusion.md | 18 +- docs/source/ecosystem_overview/core.md | 84 ++++++---- docs/source/ecosystem_overview/extended.md | 158 ++++++++++-------- docs/source/ecosystem_overview/modular.md | 39 +++-- .../the_ecosystem_overview_tr.md | 13 -- docs/source/index.rst | 15 +- 13 files changed, 175 insertions(+), 171 deletions(-) create mode 100644 docs/source/_static/images/JAX_ecosystem.svg create mode 100644 docs/source/_static/images/Tunix_diagram.svg create mode 100644 docs/source/_static/images/async_checkpointing.svg create mode 100644 docs/source/_static/images/programming_TPUS.svg create mode 100644 docs/source/_static/images/serving_orbax_nse.svg diff --git a/docs/source/_static/images/JAX_ecosystem.svg b/docs/source/_static/images/JAX_ecosystem.svg new file mode 100644 index 0000000..d264629 --- /dev/null +++ b/docs/source/_static/images/JAX_ecosystem.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/source/_static/images/Tunix_diagram.svg b/docs/source/_static/images/Tunix_diagram.svg new file mode 100644 index 0000000..24e03ba --- /dev/null +++ b/docs/source/_static/images/Tunix_diagram.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/source/_static/images/async_checkpointing.svg b/docs/source/_static/images/async_checkpointing.svg new file mode 100644 index 0000000..b9e2a76 --- /dev/null +++ b/docs/source/_static/images/async_checkpointing.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/source/_static/images/programming_TPUS.svg b/docs/source/_static/images/programming_TPUS.svg new file mode 100644 index 0000000..7bc4e99 --- /dev/null +++ b/docs/source/_static/images/programming_TPUS.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/source/_static/images/serving_orbax_nse.svg b/docs/source/_static/images/serving_orbax_nse.svg new file mode 100644 index 0000000..e64e40b --- /dev/null +++ b/docs/source/_static/images/serving_orbax_nse.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/source/ecosystem_overview/architectural.md b/docs/source/ecosystem_overview/architectural.md index 6a62a6f..e12fc13 100644 --- a/docs/source/ecosystem_overview/architectural.md +++ b/docs/source/ecosystem_overview/architectural.md @@ -6,7 +6,7 @@ If this trend accelerates, all high-level frameworks as they exist today risk be For TPUs to provide a clear path to this level of performance, the ecosystem must expose an API layer that is closer to the hardware, enabling the development of these highly specialized kernels. As this report will detail, the JAX stack is designed to solve this by offering a continuum of abstraction (See Figure 2), from the automated, high-level optimizations of the XLA compiler to the fine-grained, manual control of the Pallas kernel-authoring library. 
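+
+At the Pallas end of that continuum, a kernel body is ordinary Python operating on memory references. The minimal element-wise example below follows the documented `pallas_call` pattern (a sketch that assumes a backend with Pallas support, such as TPU or GPU):
+
+```python
+import jax
+import jax.numpy as jnp
+from jax.experimental import pallas as pl
+
+
+def add_kernel(x_ref, y_ref, o_ref):
+    # The refs point at blocks the runtime has staged into fast on-chip memory;
+    # this hand-written body replaces what the compiler would otherwise generate.
+    o_ref[...] = x_ref[...] + y_ref[...]
+
+
+@jax.jit
+def add(x, y):
+    return pl.pallas_call(
+        add_kernel,
+        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
+    )(x, y)
+
+
+x = jnp.arange(8.0)
+print(add(x, x))
+```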
-![][image3] +![](../_static/images/programming_TPUS.svg) **Figure 2: The JAX continuum of abstraction** diff --git a/docs/source/ecosystem_overview/comparative.md b/docs/source/ecosystem_overview/comparative.md index 9e11e6f..456dc5d 100644 --- a/docs/source/ecosystem_overview/comparative.md +++ b/docs/source/ecosystem_overview/comparative.md @@ -4,13 +4,13 @@ The modern Machine Learning landscape offers many excellent, mature toolchains. While many frameworks offer a wide array of features, the JAX AI Stack provides specific, powerful differentiators in key areas of the development lifecycle: -* **A Simpler, More Powerful Developer Experience:** The "chainable gradient transformation paradigm" of [**Optax**](#optax:-composable-gradient-processing-and-optimization-strategies) allows for more powerful and flexible optimization strategies that are declared once, rather than imperatively managed in the training loop.1 At the system level, the "simpler single controller interface" of **Pathways** abstracts away the complexity of multi-pod, multi-slice training, a significant simplification for researchers. -* **Engineered for "Hero-Scale" Resilience:** The JAX stack is designed for extreme-scale training. **Orbax** provides "hero-scale training resilience" features like emergency and multi-tier checkpointing. This is complemented by **Grain**, which offers "full support for reproducibility with deterministic global shuffles and checkpointable data loaders". The ability to atomically checkpoint the data pipeline state (Grain) with the model state (Orbax) is a critical capability for guaranteeing reproducibility in long-running jobs. -* **A Complete, End-to-End Ecosystem:** The stack provides a cohesive, end-to-end solution. Developers can use [**MaxText**](https://maxtext.readthedocs.io/en/latest/) as a SOTA reference for training, [**Tunix**](https://tunix.readthedocs.io/en/latest/) for alignment, and follow a clear, dual-path to production with **vLLM** (for vLLM compatibility) and **NSE** (for native JAX performance). +* **A Simpler, More Powerful Developer Experience:** The "chainable gradient transformation paradigm" of [**Optax**](#optax:composable) allows for more powerful and flexible optimization strategies that are declared once, rather than imperatively managed in the training loop.1 At the system level, the "simpler single controller interface" of **Pathways** abstracts away the complexity of multi-pod, multi-slice training, a significant simplification for researchers. +* **Engineered for "Hero-Scale" Resilience:** The JAX stack is designed for extreme-scale training. **Orbax** provides "hero-scale training resilience" features like emergency and multi-tier checkpointing. This is complemented by **Grain**, which offers "full support for reproducibility with deterministic global shuffles and checkpointable data loaders". The ability to atomically checkpoint the data pipeline state (Grain) with the model state (Orbax) is a critical capability for guaranteeing reproducibility in long-running jobs. +* **A Complete, End-to-End Ecosystem:** The stack provides a cohesive, end-to-end solution. Developers can use [**MaxText**](https://maxtext.readthedocs.io/en/latest/) as a SOTA reference for training, [**Tunix**](https://tunix.readthedocs.io/en/latest/) for alignment, and follow a clear, dual-path to production with **vLLM-TPU** (for vLLM compatibility) and **NSE** (for native JAX performance). 
While many stacks are vastly similar from a high-level software standpoint, the deciding factor often comes down to **Performance/TCO**, which is where the co-design of JAX and TPUs provides a distinct advantage. This Performance/TCO benefit is a direct result of the "vertical integration across software and TPU hardware". The ability of the **XLA** compiler to fuse operations specifically for the TPU architecture, or for the **XProf** profiler to leverage hardware hooks for \<1% overhead profiling, are tangible benefits of this deep integration. -For organizations adopting this stack, the "full featured nature" of the JAX AI Stack minimizes the cost of migration. For customers employing popular open model architectures, a shift from other frameworks to [MaxText](#foundation-model-training:-maxtext-and-maxdiffusion) is often a matter of setting up config files. Furthermore, the stack's ability to ingest popular checkpoint formats like safetensors allows existing checkpoints to be migrated over without needing costly re-training. +For organizations adopting this stack, the "full featured nature" of the JAX AI Stack minimizes the cost of migration. For customers employing popular open model architectures, a shift from other frameworks to [MaxText](#foundational-model-maxtext-and) is often a matter of setting up config files. Furthermore, the stack's ability to ingest popular checkpoint formats like safetensors allows existing checkpoints to be migrated over without needing costly re-training. The table below provides a mapping of the components provided by the JAX AI stack and their equivalents in other frameworks or libraries. @@ -28,6 +28,8 @@ The table below provides a mapping of the components provided by the JAX AI stac | Post training / tuning | Tunix | VERL, NeMoRL | | Profiling | XProf | PyTorch profiler, NSight systems, NSight Compute | | Foundation model Training | MaxText, MaxDiffusion | NeMo-Megatron, DeepSpeed, TorchTitan | -| LLM inference | vLLM | vLLM, SGLang | +| LLM inference | vLLM-TPU | vLLM, SGLang | | Non-LLM Inference | NSE | Triton Inference Server, RayServe | + +[^8]: Some of the equivalents here are not true 1:1 comparisons because other frameworks draw API boundaries differently compared to JAX. The list of equivalents is not exhaustive and there are new libraries appearing frequently. diff --git a/docs/source/ecosystem_overview/conclusion.md b/docs/source/ecosystem_overview/conclusion.md index f997df9..2506f32 100644 --- a/docs/source/ecosystem_overview/conclusion.md +++ b/docs/source/ecosystem_overview/conclusion.md @@ -6,21 +6,5 @@ The JAX AI stack offers a compelling and robust solution for training and deploy By building on battle-tested internal systems, the stack has evolved to provide inherent reliability and scalability, enabling users to confidently develop and deploy even the largest models. Its modular and composable design, rooted in the JAX ecosystem philosophy, grants users unparalleled freedom and control, allowing them to tailor the stack to their specific needs without the constraints of a monolithic framework. 
-With XLA and Pathways providing a scalable and fault-tolerant base, JAX providing a performant and expressive numerics library, powerful core development libraries like [Flax](https://flax.readthedocs.io/en/stable/), Optax, [Grain](https://google-grain.readthedocs.io/en/latest/), and [Orbax](#orbax-/-tensorstore---large-scale-distributed-checkpointing), advanced performance tools like Pallas, Tokamax, and Qwix, and a robust application and production layer in [MaxText](#foundation-model-training:-maxtext-and-maxdiffusion), vLLM, and NSE, the JAX AI stack provides a durable foundation for users to build on and rapidly bring state-of-the-art research to production. - -[^1]: Included in the [jax-ai-stack Python package](https://docs.jaxstack.ai/en/latest/install.html) - -[^2]: Included in the [jax-ai-stack Python package](https://docs.jaxstack.ai/en/latest/install.html) - -[^3]: Image diffusion models are a typical example of this and can commonly be divided logically into a separately trained prompt encoder and a diffusion backbone. - -[^4]: We say effectively free since there could be other bottlenecks such as the DMA engines, HBM bandwidth contention etc. that still incur a performance penalty. - -[^5]: In the Section 5.1 of the [Palm paper](https://dl.acm.org/doi/10.5555/3648699.3648939), the authors note that they observed very large loss spikes despite having gradient clipping enabled and the solution was to remove the offending data batches and restart training from a checkpoint before the loss spike. This is only possible with a fully deterministic and reproducible training setup. - -[^6]: This is indeed how multimodal data pipelines would need to operate \- image and audio tokenizers for example are models themselves which run in their own clusters on their own accelerators and the input pipelines would make RPCs out to convert data examples into streams of tokens. - -[^7]: This is a well established paradigm and has precedent in the CPU world, where compiled code forms the bulk of the program with developers dropping down to intrinsics or inline assembly to optimize performance critical sections. - -[^8]: Some of the equivalents here are not true 1:1 comparisons because other frameworks draw API boundaries differently compared to JAX. The list of equivalents is not exhaustive and there are new libraries appearing frequently. +With XLA and Pathways providing a scalable and fault-tolerant base, JAX providing a performant and expressive numerics library, powerful core development libraries like [Flax](https://flax.readthedocs.io/en/stable/), Optax, [Grain](https://google-grain.readthedocs.io/en/latest/), and [Orbax](#orbax:tensorstore), advanced performance tools like Pallas, Tokamax, and Qwix, and a robust application and production layer in [MaxText](#foundational-model-maxtext-and), vLLM-TPU, and NSE, the JAX AI stack provides a durable foundation for users to build on and rapidly bring state-of-the-art research to production. diff --git a/docs/source/ecosystem_overview/core.md b/docs/source/ecosystem_overview/core.md index d78566d..28d28f0 100644 --- a/docs/source/ecosystem_overview/core.md +++ b/docs/source/ecosystem_overview/core.md @@ -1,8 +1,8 @@ ## The Core JAX AI Stack -The core JAX AI Stack consists of four key libraries that provide the foundation for model development: JAX, [Flax](https://flax.readthedocs.io/en/stable/), [Optax](https://optax.readthedocs.io/en/latest/), and [Orbax](https://orbax.readthedocs.io/en/latest/). 
+The core JAX AI Stack consists of five key libraries that provide the foundation for model development: JAX, [Flax](https://flax.readthedocs.io/en/stable/), [Optax](https://optax.readthedocs.io/en/latest/), [Orbax](https://orbax.readthedocs.io/en/latest/) and [Grain](https://google-grain.readthedocs.io/en/latest/). -### **JAX: A Foundation for Composable, High-Performance Program Transformation** {#jax:-a-foundation-for-composable,-high-performance-program-transformation} +### JAX: A Foundation for Composable, High-Performance Program Transformation [JAX](https://docs.jax.dev/en/latest/) is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale Machine Learning. With its functional programming model and friendly, NumPy-like API, JAX provides a solid foundation for higher-level libraries. @@ -10,50 +10,49 @@ With its compiler-first design, JAX inherently promotes scalability by leveragin These core transformations can be mixed and matched to achieve high performance and scaling of workloads across model size, cluster size, and hardware types: -* **jit**: Just-in-time compilation of Python functions into optimized, fused XLA executables. -* **grad**: Automatic differentiation, supporting forward- and reverse-mode, as well as higher-order derivatives. -* **vmap**: Automatic vectorization, enabling seamless batching and data parallelism without modifying function logic. +* **jit**: Just-in-time compilation of Python functions into optimized, fused XLA executables. +* **grad**: Automatic differentiation, supporting forward- and reverse-mode, as well as higher-order derivatives. +* **vmap**: Automatic vectorization, enabling seamless batching and data parallelism without modifying function logic. * **pmap / shard\_map**: Automatic parallelization across multiple devices (e.g., TPU cores), forming the basis for distributed training. The seamless integration with XLA's GSPMD (General-purpose SPMD) model allows JAX to automatically parallelize computations across large TPU pods with minimal code changes. In most cases, scaling simply requires high-level sharding annotations, a stark contrast to frameworks where scaling may require more manual management of device placement and communication collectives -### **Flax: Flexible Neural Network Authoring and "Model Surgery"** {#flax:-flexible-neural-network-authoring-and-"model-surgery"} - -#### **Flax \- neural network layers** {#flax---neural-network-layers} +### Flax: Flexible Neural Network Authoring and "Model Surgery" [Flax](https://flax.readthedocs.io/en/latest/index.html) is a library designed to simplify the creation, debugging, and analysis of neural networks in JAX. While pure functional API provided by JAX can be used to fully specify and train a ML model, users coming from the PyTorch (or TensorFlow) ecosystem are more used to and comfortable with the object oriented approach of specifying models as a graph of `torch.nn.Modules`. The abstractions provided by [Flax](https://flax.readthedocs.io/en/stable/) allow users to think more in terms of layers rather than functions, making it more developer friendly to an audience who value ergonomics and experimentation ease. [Flax](https://flax.readthedocs.io/en/stable/) also enables config driven model construction systems, such as those present in [MaxText](https://maxtext.readthedocs.io/en/latest/) and AxLearn, which separate out model hyperparameters from layer definition code. 
With a simple Pythonic API, it allows developers to express models using regular Python objects, while retaining the power and performance of JAX. Flax's NNX API is an evolution of the Flax Linen interface, incorporating lessons learned to offer a more user-friendly interface that remains consistent with the core JAX APIs. Since Flax modules are fully backed by the core JAX APIs, there is no performance penalty associated with defining the model in [Flax](https://flax.readthedocs.io/en/stable/). -##### **Motivation** {#motivation} +#### Motivation JAX’s pure functional API, while powerful, can be complex for new users since it requires all the program state to be explicitly managed by the user. This paradigm can be unfamiliar to developers used to other frameworks. Modern model architectures are often complex with individual portions of the model trained separately and merged to form the final model[^3], in a process commonly referred to as model surgery. Even with decoder-only LLMs which tend to have a straightforward architecture, post training techniques such as LoRA and quantization require the model definition to be easily manipulated allowing parts of the architecture to be modified or even replaced. The Flax NNX library, with its simple yet powerful Pythonic API enables this functionality in a way that is intuitive to the user, reducing the amount of cognitive overhead involved in authoring and training a model. -##### **Design** {#design} +#### Design The [Flax](https://flax.readthedocs.io/en/stable/) NNX library introduces an object oriented model definition system that encapsulates the model and random number generator state internally, reducing the cognitive overhead of the user and provides a familiar experience for those accustomed to frameworks like PyTorch or TensorFlow. By making submodule definitions Pythonic and providing APIs to traverse the module hierarchy, it allows for the model definition to be easily editable programmatically for model introspection and surgery. The [Flax](https://flax.readthedocs.io/en/stable/) NNX APIs are designed to be consistent with the core JAX APIs to allow users to exploit the full expressibility and performance of JAX, with lifted transformations for common operations like sharding, jit and others. Models defined using the NNX APIs can also be adapted to work with functional training loops, allowing the user the flexibility they need while retaining an intuitive object oriented API. -##### **Key Strengths** {#key-strengths} +#### Key Strengths -* **Intuitive object oriented flexible APIs:** Layers are represented as pure Python objects with internal state management, simplifying model construction and training loops, while also advanced model surgery use cases through support for submodule replacement, partial initialization and model hierarchy traversal. +* **Intuitive object oriented flexible APIs:** Layers are represented as pure Python objects with internal state management, simplifying model construction and training loops, while also advanced model surgery use cases through support for submodule replacement, partial initialization and model hierarchy traversal. * **Consistent with Core JAX APIs:** Lifted transformations consistent with core JAX and fully compatible with functional JAX provide the full performance of JAX without sacrificing developer friendliness. 
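+
+As a point of reference, a minimal NNX-style module looks like the following. This is a sketch based on the public Flax NNX examples; consult the Flax documentation for the authoritative API surface.
+
+```python
+import jax
+import jax.numpy as jnp
+from flax import nnx
+
+
+class MLP(nnx.Module):
+    """A two-layer block whose parameters and RNG state live on the object itself."""
+
+    def __init__(self, din: int, dhidden: int, dout: int, *, rngs: nnx.Rngs):
+        self.linear1 = nnx.Linear(din, dhidden, rngs=rngs)
+        self.linear2 = nnx.Linear(dhidden, dout, rngs=rngs)
+
+    def __call__(self, x):
+        return self.linear2(jax.nn.relu(self.linear1(x)))
+
+
+model = MLP(4, 32, 2, rngs=nnx.Rngs(0))  # eager, stateful construction
+y = model(jnp.ones((8, 4)))              # call it like a regular Python object
+
+# "Model surgery" is plain attribute manipulation: swap a submodule in place.
+model.linear2 = nnx.Linear(32, 10, rngs=nnx.Rngs(1))
+```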
-### **Optax: Composable Gradient Processing and Optimization Strategies** {#optax:-composable-gradient-processing-and-optimization-strategies} +(optax:composable)= +### Optax: Composable Gradient Processing and Optimization Strategies [Optax](https://optax.readthedocs.io/en/latest/index.html) is a gradient processing and optimization library for JAX. It is designed to empower model builders by providing building blocks that can be recombined in custom ways in order to train deep learning models amongst other applications. It builds on the capabilities of the core JAX library to provide a well tested high performance library of losses and optimizer functions and associated techniques that can be used to train ML models. -#### Motivation {#motivation-1} +#### Motivation The calculation and minimization of losses is at the core of what enables the training of ML models. With its support for automatic differentiation the core JAX library provides the numeric capabilities to train models, but it does not provide standard implementations of popular optimizers (ex. `RMSProp`, `Adam`) or losses (`CrossEntropy`, `MSE` etc). While it is true that a user could implement these functions by themselves (and some advanced users will choose to do so), a bug in an optimizer implementation would introduce hard to diagnose model quality issues. Rather than having the user implement such critical pieces, [Optax](https://optax.readthedocs.io/en/latest/) provides implementations of these algorithms that are tested for correctness and performance. The field of optimization theory lies squarely in the realm of research, however its central role in training also makes it an indispensable part of training production ML models. A library that serves this role needs to be both flexible enough to accommodate rapid research iterations and also robust and performant enough to be dependable for production model training. It should also provide well tested implementations of state of the art algorithms which match the standard equations. The [Optax](https://optax.readthedocs.io/en/latest/) library, through its modular composable architecture and emphasis on correct readable code is designed to achieve this. -#### Design {#design-1} +#### Design [Optax](https://optax.readthedocs.io/en/latest/) is designed to both enhance research velocity and the transition from research to production by providing readable, well-tested, and efficient implementations of core algorithms. Optax has uses beyond the context of deep learning, however in this context it can be viewed as a collection of well known loss functions, optimization algorithms and gradient transformations implemented in a pure functional fashion in line with the JAX philosophy. The collection of well known [losses](https://optax.readthedocs.io/en/latest/api/losses.html) and [optimizers](https://optax.readthedocs.io/en/latest/api/optimizers.html) enable users to get started with ease and confidence. @@ -81,21 +80,23 @@ for i, (inputs, targets) in enumerate(data_loader): As it can be seen in the example above, setting up an optimizer with a custom learning rate, gradient clipping and gradient accumulation is a simple drop in replacement block of code, compared to PyTorch which forces the user to modify their training loop to directly manage the learning rate scheduler, gradient clipping and gradient accumulation. 
-#### Key Strengths {#key-strengths-1} +#### Key Strengths -* **Robust Library:** Provides a comprehensive library of losses, optimizers, and algorithms with a focus on correctness and readability. -* **Modular Chainable Transformations:** As shown above, this flexible API allows users to craft powerful, complex optimization strategies declaratively, without modifying the training loop. +* **Robust Library:** Provides a comprehensive library of losses, optimizers, and algorithms with a focus on correctness and readability. +* **Modular Chainable Transformations:** As shown above, this flexible API allows users to craft powerful, complex optimization strategies declaratively, without modifying the training loop. * **Functional and Scalable:** The pure functional implementations integrate seamlessly with JAX's parallelization mechanisms (e.g., pmap), enabling the same code to scale from a single host to large clusters. -### **Orbax / TensorStore \- Large scale distributed checkpointing** {#orbax-/-tensorstore---large-scale-distributed-checkpointing} + +(orbax:tensorstore)= +### Orbax / TensorStore \- Large scale distributed checkpointing [**Orbax**](https://orbax.readthedocs.io/en/latest/) is an any-scale checkpointing library for JAX users backed primarily by [**TensorStore**](https://google.github.io/tensorstore/), a library for efficiently reading and writing multi-dimensional arrays. The two libraries operate at different levels of the stack \- Orbax at the level of ML models and states \- TensorStore at the level of individual arrays. -#### Motivation {#motivation-2} +#### Motivation [Orbax](https://orbax.readthedocs.io/en/latest/), which centers on JAX users and ML checkpointing, aims to reduce the fragmentation of checkpointing implementations across disparate research codebases, increase adoption of important performance features outside the most cutting-edge codebases, and provide a clean, flexible API for novice and advanced users alike. With advanced features like fully asynchronous distributed checkpointing, multi-tier checkpointing and emergency checkpointing, [Orbax](https://orbax.readthedocs.io/en/latest/) enables resilience in the largest of training jobs while also providing a flexible representation for publishing checkpoints. -#### ML Checkpointing vs Generalized Checkpoint/Restore {#ml-checkpointing-vs-generalized-checkpoint/restore} +#### ML Checkpointing vs Generalized Checkpoint/Restore It is worth considering the difference between ML checkpoint systems ([Orbax](https://orbax.readthedocs.io/en/latest/), NeMO-Megatron, Torch Distributed Checkpoint) with generalized checkpoint systems like CRIU. @@ -103,7 +104,7 @@ Systems like CRIU & CRIUgpu behave analogously to VM live migration; they halt t ML checkpoint systems are designed to minimize the amount of time the accelerator is halted by selectively persisting information that cannot be reconstructed. Specifically, this entails persisting model weights, optimizer state, dataloader state and random number generator state, which is a far smaller amount of data. -#### Design {#design-2} +#### Design The [Orbax API](https://orbax.readthedocs.io/en/latest/index.html) centers around handling [PyTrees](https://docs.jax.dev/en/latest/pytrees.html) (nested containers) of arrays as the standard representation of JAX models. Saving and loading can be synchronous or asynchronous, with saving consisting of blocking and non-blocking phases. 
A higher-level `Checkpointer` class is provided, which facilitates checkpointing in a training loop, with save intervals, garbage collection, dataset checkpointing, and metadata management. Finally, Orbax provides customization layers for dealing with user-defined checkpointable objects and PyTree leaves. @@ -111,33 +112,36 @@ The storage layer of [Orbax](https://orbax.readthedocs.io/en/latest/index.html) To maximize the utilization of the accelerator, the checkpointing library must minimize the time it halts the training to snapshot the state. This is achieved by overlapping the checkpointing operations with the compute operations as shown in the diagram below. It’s worth noting that asynchronous checkpointing is table-stakes for large workloads and isn’t unique to [Orbax](https://orbax.readthedocs.io/en/latest/index.html). It is also present in other frameworks such as NeMO-Megatron and Torch Distributed Checkpoints. +![](../_static/images/async_checkpointing.svg) + When considering asynchronous checkpointing with non overlapped device-to-host transfers, the amount of time the accelerator is halted is thus a function of the number of model parameters, the size of the parameters and the PCI link speed. Enabling fully overlapped D2H can further reduce this time by overlapping the D2H transfer with the forward pass of the next step. As long as the D2H transfer can complete before the next forward step completes, the checkpoint will become effectively[^4] free. Restarting from an error is similarly bound by two factors, the XLA compilation time and the speed of reading the weights back from storage. XLA compilation caches can make the former insignificant. Reading from storage is hardware dependent \- emergency checkpoints save to ramdisks which are extremely fast, however there is a speed spectrum that ranges from ramdisk to SSD, HDD and GCS. Specific industry-leading performance features have their own design challenges, and merit separate attention: -* [**Async checkpointing**](https://orbax.readthedocs.io/en/latest/guides/checkpoint/async_checkpointing.html): Checkpointing only needs to block accelerator computations while data is being transferred from host to/from accelerator memory. Expensive I/O operations can take place in a background thread meaning save time can be reduced by 95-99% relative to blocking saves. Asynchronous loading is also possible, and can save time on startup, but requires more extensive effort to integrate and has not yet seen widespread adoption. -* [**OCDBT format**](https://orbax.readthedocs.io/en/latest/guides/checkpoint/optimized_checkpointing.html): Most previous checkpointing implementations stored parameters as separate subdirectories, which caused significant overhead for small arrays. TensorStore’s OCDBT format uses an efficient [B+ tree](https://en.wikipedia.org/wiki/B%2B_tree) format, which allows fine-grained control over shard shapes and file sizes that can be tuned to different filesystems and models. The save/load strategy provides scalability to tens of thousands of nodes by ensuring each host independently reads and writes only the relevant pieces of each array. -* [**Restore \+ broadcast**](https://cloud.google.com/blog/products/compute/unlock-faster-workload-start-time-using-orbax-on-jax): Hero-scale training runs replicate the model weights among multiple data-parallel replicas. Orbax provides a load balancing feature that distributes the burden evenly among available replicas when saving. 
It also leverages fast chip interconnects to avoid redundant reads of the model on different groups of hosts, instead loading on a single primary replica and broadcasting the weights to all other replicas. +* [**Async checkpointing**](https://orbax.readthedocs.io/en/latest/guides/checkpoint/async_checkpointing.html): Checkpointing only needs to block accelerator computations while data is being transferred from host to/from accelerator memory. Expensive I/O operations can take place in a background thread meaning save time can be reduced by 95-99% relative to blocking saves. Asynchronous loading is also possible, and can save time on startup, but requires more extensive effort to integrate and has not yet seen widespread adoption. +* [**OCDBT format**](https://orbax.readthedocs.io/en/latest/guides/checkpoint/optimized_checkpointing.html): Most previous checkpointing implementations stored parameters as separate subdirectories, which caused significant overhead for small arrays. TensorStore’s OCDBT format uses an efficient [B+ tree](https://en.wikipedia.org/wiki/B%2B_tree) format, which allows fine-grained control over shard shapes and file sizes that can be tuned to different filesystems and models. The save/load strategy provides scalability to tens of thousands of nodes by ensuring each host independently reads and writes only the relevant pieces of each array. +* [**Restore \+ broadcast**](https://cloud.google.com/blog/products/compute/unlock-faster-workload-start-time-using-orbax-on-jax): Hero-scale training runs replicate the model weights among multiple data-parallel replicas. Orbax provides a load balancing feature that distributes the burden evenly among available replicas when saving. It also leverages fast chip interconnects to avoid redundant reads of the model on different groups of hosts, instead loading on a single primary replica and broadcasting the weights to all other replicas. * **Emergency checkpointing**: Hero-scale training suffers from frequent interruptions and hardware failures. Checkpointing to persistent RAM disk improves goodput for hero-scale jobs by allowing for increased checkpoint frequency, faster restore times, and improved resiliency, since TPU states may be corrupted on some replicas, but not all. -#### Key Strengths {#key-strengths-2} +#### Key Strengths -* **Widespread adoption:** As checkpoints are a medium for communication of ML artifacts between different codebases and stages of ML development, widespread adoption is an inherent advantage. Currently, Orbax has [\~4 million](https://pypistats.org/packages/orbax-checkpoint) monthly package downloads. -* **Easy to use:** Orbax abstracts away complex technical aspects of checkpointing like async saving, single- vs. multi-controller, checkpoint atomicity, distributed filesystem details, TPU vs. GPU, etc. It condenses use cases into simple, but generalizable APIs (direct-to-path, sequence-of-steps). -* **Flexible:** While Orbax focuses on exposing a simple API surface for the majority of users, additional layers for handling custom checkpointable objects and PyTree nodes allow for flexibility in specialized use cases. +* **Widespread adoption:** As checkpoints are a medium for communication of ML artifacts between different codebases and stages of ML development, widespread adoption is an inherent advantage. Currently, Orbax has [\~4 million](https://pypistats.org/packages/orbax-checkpoint) monthly package downloads. 
+* **Easy to use:** Orbax abstracts away complex technical aspects of checkpointing like async saving, single- vs. multi-controller, checkpoint atomicity, distributed filesystem details, TPU vs. GPU, etc. It condenses use cases into simple, but generalizable APIs (direct-to-path, sequence-of-steps). +* **Flexible:** While Orbax focuses on exposing a simple API surface for the majority of users, additional layers for handling custom checkpointable objects and PyTree nodes allow for flexibility in specialized use cases. * **Performant and scalable:** Orbax provides a variety of features designed to make checkpointing as fast and as unobtrusive as possible, freeing developers to focus on efficiency in the remainder of the training loop. Scalability to the cutting edge of ML research is a top concern of the library; training runs at a scale of O(10k) nodes currently rely on Orbax. -#### **Grain: Deterministic and Scalable Input Data Pipelines** {#grain:-deterministic-and-scalable-input-data-pipelines} + +### Grain: Deterministic and Scalable Input Data Pipelines [Grain](https://google-grain.readthedocs.io/en/latest/) is a Python library for reading and processing data for training and evaluating JAX models. It is flexible, fast and deterministic and supports advanced features like checkpointing which are essential to successfully training large workloads. It supports popular data formats and storage backends and also provides a flexible API to extend support to user specific formats and backends that are not natively supported. While [Grain](https://google-grain.readthedocs.io/en/latest/) is primarily designed to work with JAX, it is framework independent, does not require JAX to run and can be used with other frameworks as well. -##### **Motivation** {#motivation-7} +#### Motivation Data pipelines form a critical part of the training infrastructure \- they need to be flexible so that common transformations can be expressed efficiently, and performant enough that they are able to keep the accelerators busy at all times. They also need to be able to accommodate multiple storage formats and backends. Due to their higher step times, training large models at scale pose unique additional requirements on the data pipeline beyond those that are required by regular training workloads, primarily focused around determinism and reproducibility[^5]. The [Grain](https://google-grain.readthedocs.io/en/latest/) library is designed with a flexible enough architecture to address all these needs. -##### **Design** {#design-7} +#### Design At the highest level, there are two ways to structure an input pipeline, as a separate cluster of data workers or by co-locating the data workers on the hosts that drive the accelerators. [Grain](https://google-grain.readthedocs.io/en/latest/) chooses the latter for a variety of reasons. @@ -147,10 +151,18 @@ On the API front, with a pure python implementation that supports multiple proce Out of the box, [Grain](https://google-grain.readthedocs.io/en/latest/) supports efficient random access data formats like `ArrayRecord` and `Bagz` alongside other popular data formats such as Parquet and `TFDS`. [Grain](https://google-grain.readthedocs.io/en/latest/) includes support for reading from local file systems as well as reading from GCS by default. 
Along with supporting popular storage formats and backends, a clean abstraction to the storage layer allows users to easily add support for or wrap their existing data sources to be compatible with the [Grain](https://google-grain.readthedocs.io/en/latest/) library. -##### **Key Strengths** {#key-strengths-7} +#### Key Strengths -* **Deterministic data feeding:** Colocating the data worker with the accelerator and coupling it with a stable global shuffle and [checkpointable iterators](https://google-grain.readthedocs.io/en/latest/tutorials/data_loader_tutorial.html#checkpointing) allows the model state and data pipeline state to be checkpointed together in a consistent snapshot using [Orbax](https://docs.google.com/document/d/1rS4DGWSbHOX0rZgjv2rV2DcXuBnHvnCKOTAarZiC1Dg/edit?tab=t.0#heading=h.rtje6zr33hjw), enhancing the determinism of the training process. -* **Flexible APIs to enable powerful data transformations:** A flexible pure Python [transformations](https://google-grain.readthedocs.io/en/latest/data_loader/transformations.html) API allows for extensive data transformations within the input processing pipeline. -* **Extensible support for multiple formats and backends:** An extensible [data sources](https://google-grain.readthedocs.io/en/latest/tutorials/data_sources/index.html) API supports popular storage formats and backends and allows users to easily add support for new formats and backends. +* **Deterministic data feeding:** Colocating the data worker with the accelerator and coupling it with a stable global shuffle and [checkpointable iterators](https://google-grain.readthedocs.io/en/latest/tutorials/data_loader_tutorial.html#checkpointing) allows the model state and data pipeline state to be checkpointed together in a consistent snapshot using [Orbax](https://orbax.readthedocs.io/en/latest/), enhancing the determinism of the training process. +* **Flexible APIs to enable powerful data transformations:** A flexible pure Python [transformations](https://google-grain.readthedocs.io/en/latest/data_loader/transformations.html) API allows for extensive data transformations within the input processing pipeline. +* **Extensible support for multiple formats and backends:** An extensible [data sources](https://google-grain.readthedocs.io/en/latest/tutorials/data_sources/index.html) API supports popular storage formats and backends and allows users to easily add support for new formats and backends. * **Powerful debugging interface:** Data pipeline [visualization tools](https://google-grain.readthedocs.io/en/latest/tutorials/dataset_debugging_tutorial.html) and a debug mode allow users to introspect, debug and optimize the performance of their data pipelines. + +[^3]: Image diffusion models are a typical example of this and can commonly be divided logically into a separately trained prompt encoder and a diffusion backbone. + +[^4]: We say effectively free since there could be other bottlenecks such as the DMA engines, HBM bandwidth contention etc. that still incur a performance penalty. + +[^5]: In the Section 5.1 of the [Palm paper](https://dl.acm.org/doi/10.5555/3648699.3648939), the authors note that they observed very large loss spikes despite having gradient clipping enabled and the solution was to remove the offending data batches and restart training from a checkpoint before the loss spike. This is only possible with a fully deterministic and reproducible training setup. 
+ +[^6]: This is indeed how multimodal data pipelines would need to operate \- image and audio tokenizers for example are models themselves which run in their own clusters on their own accelerators and the input pipelines would make RPCs out to convert data examples into streams of tokens. diff --git a/docs/source/ecosystem_overview/extended.md b/docs/source/ecosystem_overview/extended.md index 852f46c..b618bb6 100644 --- a/docs/source/ecosystem_overview/extended.md +++ b/docs/source/ecosystem_overview/extended.md @@ -2,17 +2,17 @@ Beyond the core stack, a rich ecosystem of specialized libraries provides the infrastructure, advanced tools, and application-layer solutions needed for end-to-end ML development. -### **Foundational Infrastructure: Compilers and Runtimes** {#foundational-infrastructure:-compilers-and-runtimes} +### Foundational Infrastructure: Compilers and Runtimes -#### **XLA: The Hardware-Agnostic, Compiler-Centric Engine** {#xla:-the-hardware-agnostic,-compiler-centric-engine} +#### XLA: The Hardware-Agnostic, Compiler-Centric Engine -##### **Motivation** {#motivation-3} +##### Motivation XLA or Accelerated Linear Algebra is our domain specific compiler, which is well integrated into JAX and supports TPU, CPU and GPU hardware devices. From inception, XLA has been designed to be a hardware agnostic code generator targeting TPUs, GPUs, and CPUs. Our compiler-first design is a fundamental architectural choice that creates a durable advantage in a rapidly evolving research landscape. In contrast, the prevailing kernel-centric approach in other ecosystems relies on hand-optimized libraries for performance. While this is highly effective for stable, well-established model architectures, it creates a bottleneck for innovation. When new research introduces novel architectures, the ecosystem must wait for new kernels to be written and optimized. Our compiler-centric design, however, can often generalize to new patterns, providing a high-performance path for cutting-edge research from day one. -##### **Design** {#design-3} +##### Design XLA works by Just-In-Time (JIT) compiling the computation graphs that JAX generates during its tracing process (e.g., when a function is decorated with `@jax.jit`). @@ -20,75 +20,78 @@ This compilation follows a multi-stage pipeline: JAX Computation Graph → High-Level Optimizer (HLO) → Low-Level Optimizer (LLO) → Hardware Code -* **From JAX Graph to HLO**: The captured JAX computation graph is converted into XLA's HLO representation. At this high level, powerful, hardware-agnostic optimizations like operator fusion and efficient memory management are applied. The **StableHLO** dialect serves as a durable, versioned interface for this stage. -* **From HLO to LLO:** After high-level optimizations, hardware-specific backends take over, lowering the HLO representation into a machine-oriented LLO. +* **From JAX Graph to HLO**: The captured JAX computation graph is converted into XLA's HLO representation. At this high level, powerful, hardware-agnostic optimizations like operator fusion and efficient memory management are applied. The **StableHLO** dialect serves as a durable, versioned interface for this stage. +* **From HLO to LLO:** After high-level optimizations, hardware-specific backends take over, lowering the HLO representation into a machine-oriented LLO. * **From LLO to Hardware Code:** The LLO is finally compiled into highly-efficient machine code. 
For TPUs, this code is bundled as **Very Long Instruction Word (VLIW)** packets that are sent directly to the hardware. For scaling, XLA's design is built around parallelism. It employs algorithms to maximally utilize the matrix multiplication units (MXUs) on a chip. Between chips, XLA uses **SPMD (Single Program Multiple Data)**, a compiler-based parallelization that uses a single program across all devices. This powerful model is exposed through JAX APIs, allowing users to manage data, model, or pipeline parallelism with high-level sharding annotations. For more complex parallelism patterns, **Multiple Program Multiple Data (MPMD)** is also possible, and libraries like `PartIR:MPMD` allow JAX users to provide MPMD annotations as well. -##### **Key strengths** {#key-strengths-3} +##### Key strengths -* **Compilation**: just in time compilation of the computation graph allows for optimizations to memory layout, buffer allocation, and memory management. Alternatives such as kernel based methodologies put that burden on the user. In most cases, XLA can achieve excellent performance without compromising developer velocity. +* **Compilation**: just in time compilation of the computation graph allows for optimizations to memory layout, buffer allocation, and memory management. Alternatives such as kernel based methodologies put that burden on the user. In most cases, XLA can achieve excellent performance without compromising developer velocity. * **Parallelism:** XLA implements several forms of parallelism with SPMD, and this is exposed at the JAX level. This allows for users to express sharding strategies easily, allowing experimentation and scalability of models across thousands of chips. -#### **Pathways: A Unified Runtime for Massive-Scale Distributed Computation** {#pathways:-a-unified-runtime-for-massive-scale-distributed-computation} + +#### Pathways: A Unified Runtime for Massive-Scale Distributed Computation [Pathways](https://docs.cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/pathways-intro) offers abstractions for distributed training and inference with built in fault tolerance and recovery, allowing ML researchers to code as if they are using a single, powerful machine. -##### **Motivation** {#motivation-4} +##### Motivation To be able to train and deploy large models, hundreds to thousands of chips are necessary. These chips are spread across numerous racks and host machines. A training job is a large-scale synchronous program that requires all of these chips, and their respective hosts to be working in tandem on XLA computations that have been parallelized (sharded). In the case of large language models, which may need more than tens of thousands of chips, this service must be capable of spanning multiple pods across a data center fabric in addition to using ICI and OCI fabrics within a pod. -##### **Design** {#design-4} +##### Design ML Pathways is the system we use for coordinating distributed computations across hosts and TPU chips. It is designed for scalability and efficiency across hundreds of thousands of accelerators. For large-scale training, it provides a single Python client for multi-slice/multi-pod jobs, [Megascale XLA](https://openxla.org/xprof/megascale_stats) integration, Compilation Service, and Remote Python. It also supports cross-slice parallelism and preemption tolerance, enabling automatic recovery from resource preemptions. 
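As a concrete illustration of the compilation pipeline and the high-level sharding annotations described in the XLA section above, the short sketch below (not taken from the report) jit-compiles a function, prints the StableHLO that is handed to the XLA pipeline, and shards the input batch so that GSPMD partitions the computation across whatever devices are available:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

@jax.jit
def predict(w, x):
    return jnp.tanh(x @ w)

w = jnp.ones((128, 128))
x = jnp.ones((8, 128))

# The captured computation as StableHLO text, before XLA's HLO/LLO passes.
print(predict.lower(w, x).as_text()[:400])

# SPMD: a 1-D device mesh plus a sharding annotation on the batch dimension is
# enough for XLA/GSPMD to partition the computation across all devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
x_sharded = jax.device_put(x, NamedSharding(mesh, P("data", None)))
print(predict(w, x_sharded).sharding)
```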
Pathways incorporates optimized cross host collectives which enable XLA computation graphs to further extend beyond a single TPU pod. It expands XLA's support for data, model, and pipeline parallelism to work across TPU slice boundaries using DCN by means of integrating a distributed runtime that manages DCN communication with XLA communication primitives. -##### **Key strengths** {#key-strengths-4} +##### Key strengths -* The single-controller architecture, integrated with JAX, is a key abstraction. It allows researchers to explore various sharding and parallelism strategies for training and deployment while scaling to tens of thousands of chips with ease. +* The single-controller architecture, integrated with JAX, is a key abstraction. It allows researchers to explore various sharding and parallelism strategies for training and deployment while scaling to tens of thousands of chips with ease. * Scaling to tens of thousands of chips with ease, allowing exploration of various sharding and parallelism strategies during model research, training and deployment. -### **Advanced Development: Performance, Data, and Efficiency** {#advanced-development:-performance,-data,-and-efficiency} -#### **Pallas: Writing High-Performance Custom Kernels in JAX** {#pallas:-writing-high-performance-custom-kernels-in-jax} +### Advanced Development: Performance, Data, and Efficiency + +#### Pallas: Writing High-Performance Custom Kernels in JAX While JAX is compiler first, there are situations where the user would like to exercise fine grained control over the hardware to achieve maximum performance. Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU. It aims to provide precise control over the generated code, combined with the high-level ergonomics of JAX tracing and the `jax.numpy API`. Pallas exposes a grid-based parallelism model where a user-defined kernel function is launched across a multi-dimensional grid of parallel work-groups. It enables explicit management of the memory hierarchy by allowing the user to define how tensors are tiled and transferred between slower, larger memory (e.g., HBM) and faster, smaller on-chip memory (e.g., VMEM on TPU, Shared Memory on GPU), using index maps to associate grid locations with specific data blocks. Pallas can lower the same kernel definition to execute efficiently on both Google's TPUs and various GPUs by compiling kernels into an intermediate representation suitable for the target architecture – Mosaic for TPUs, or utilizing technologies like Triton for the GPU path. With Pallas, users can write high performance kernels that specialize blocks like attention to achieve the best model performance on the target hardware without needing to rely on vendor specific toolkits. -#### **Tokamax: A Curated Library of State-of-the-Art Kernels** {#tokamax:-a-curated-library-of-state-of-the-art-kernels} +#### Tokamax: A Curated Library of State-of-the-Art Kernels If Pallas is the *tool* for authoring kernels, [Tokamax](https://github.com/openxla/tokamax) is a *library* of state-of-the-art custom accelerator kernels supporting both TPUs and GPUs, built on top of JAX and Pallas enabling users to push their hardware to the maximum. It also provides tooling for users to build and autotune their own custom kernels. 
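To ground the Pallas programming model described above before turning to Tokamax, the simplest possible kernel is sketched below (illustrative only; it uses the public `jax.experimental.pallas` `pallas_call` API and omits the grid/`BlockSpec` tiling that production kernels, such as those in Tokamax, rely on):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
    # Refs point at blocks resident in fast on-chip memory (VMEM on TPU,
    # shared memory on GPU); reads and writes use NumPy-style indexing.
    o_ref[...] = x_ref[...] + 1.0

@jax.jit
def add_one(x):
    return pl.pallas_call(
        add_one_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x)

print(add_one(jnp.arange(8, dtype=jnp.float32)))
```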
-##### **Motivation** {#motivation-5} +##### Motivation JAX, with its roots in XLA, is a compiler-first framework, however a narrow set of cases exists where the user needs to take direct control of the hardware to achieve maximum performance[^7]. Custom kernels are critical to squeezing out every last ounce of performance from expensive ML accelerator resources such as TPUs and GPUs. While they are widely employed to enable performant execution of key operators such as Attention, implementing them requires a deep understanding of both the model and the target hardware (micro)architecture. Tokamax provides one authoritative source of curated, well-tested, high-performance kernels, in conjunction with robust shared infrastructure for their development, maintenance, and lifecycle management. Such a library can also act as a reference implementation for users to build on and customize as necessary. This allows users to focus on their modeling efforts without needing to worry about infrastructure. -##### **Design** {#design-5} +##### Design For any given kernel, Tokamax provides a common API that may be backed by multiple implementations. For example, TPU kernels may be implemented either by standard XLA lowering, or explicitly via Pallas/Mosaic-TPU. GPU kernels may be implemented by standard XLA lowering, via Mosaic-GPU, or Triton. By default, it picks the best-known implementation for a given configuration, determined by cached results from periodic autotuning and benchmarking runs, though users may choose specific implementations if desired. New implementations may be added over time to better exploit specific features in new hardware generations for even better performance. A key component of the library, beyond the kernels themselves, is the supporting infrastructure that will help power users choosing to write their own custom kernels. For example, the autotuning infrastructure lets the user define a set of configurable parameters (e.g., tile sizes) that Tokamax can perform an exhaustive sweep on, to determine and cache the best possible tuned settings. Nightly regressions protect users from unexpected performance and numerics issues caused by changes to underlying compiler infrastructure or other dependencies. -##### **Key Strengths** {#key-strengths-5} +##### Key Strengths -* **Seamless developer experience:** A unified, curated, library will provide known-good high-performance implementations of key kernels, with clear expressions of supported hardware generations and expected performance, both programmatically and in documentation. This minimizes fragmentation and churn. -* **Flexibility and lifecycle management:** Users may choose different implementations as desired, even changing them over time if appropriate. For example, if the XLA compiler enhances support for certain operations obviating the need for custom kernels, there is a simple path to deprecation and migration. +* **Seamless developer experience:** A unified, curated, library will provide known-good high-performance implementations of key kernels, with clear expressions of supported hardware generations and expected performance, both programmatically and in documentation. This minimizes fragmentation and churn. +* **Flexibility and lifecycle management:** Users may choose different implementations as desired, even changing them over time if appropriate. For example, if the XLA compiler enhances support for certain operations obviating the need for custom kernels, there is a simple path to deprecation and migration. 
* **Extensibility:** Users may implement their own kernels, while leveraging well-supported shared infrastructure, allowing them to focus on their value added capabilities and optimizations. Clearly authored standard implementations serve as a starting point for users to learn from and extend. -#### **Qwix: Non-Intrusive, Comprehensive Quantization** {#qwix:-non-intrusive,-comprehensive-quantization} + +#### Qwix: Non-Intrusive, Comprehensive Quantization Qwix is a comprehensive quantization library for the JAX ecosystem, supporting both LLMs and other model types across all stages, including training (QAT, QT, QLoRA) and inference (PTQ), targeting both XLA and on-device runtimes. -##### **Motivation** {#motivation-6} +##### Motivation Existing quantization libraries, particularly in the PyTorch ecosystem, often serve limited purposes (e.g., only PTQ or only QLoRA). This fragmented landscape forces users to switch tools, impeding consistent code usage and precise numerical matching between training and inference. Furthermore, many solutions require substantial model modifications, tightly coupling the model logic to the quantization logic. -##### **Design** {#design-6} +##### Design Qwix's design philosophy emphasizes a comprehensive solution and, critically, **non-intrusive model integration**. It is architected with a hierarchical, extensible design built on reusable functional APIs. @@ -114,130 +117,139 @@ rules = [ quantized_model = qwix.quantize_model(fp_model, qwix.PtqProvider(rules)) ``` -##### Key Strengths {#key-strengths-6} +##### Key Strengths -* **Comprehensive Solution:** Qwix is broadly applicable across numerous quantization scenarios, ensuring consistent code usage between training and inference. -* **Non-Intrusive Model Integration:** As the example shows, users can integrate models with a single line of code, without modification. This allows developers to easily sweep hyperparameters over many quantization schemes to find the best quality/performance tradeoff. -* **Federated with Other Libraries:** Qwix seamlessly integrates with the JAX AI stack. For example, Tokamax automatically adapts to use quantized versions of kernels, without additional user code, when the model is quantized with Qwix. +* **Comprehensive Solution:** Qwix is broadly applicable across numerous quantization scenarios, ensuring consistent code usage between training and inference. +* **Non-Intrusive Model Integration:** As the example shows, users can integrate models with a single line of code, without modification. This allows developers to easily sweep hyperparameters over many quantization schemes to find the best quality/performance tradeoff. +* **Federated with Other Libraries:** Qwix seamlessly integrates with the JAX AI stack. For example, Tokamax automatically adapts to use quantized versions of kernels, without additional user code, when the model is quantized with Qwix. * **Research Friendly:** Qwix's foundational APIs and extensible architecture empower researchers to explore new algorithms and facilitate straightforward comparisons with integrated benchmark and evaluation tools. 
-### **The Application Layer: Training and Alignment** {#the-application-layer:-training-and-alignment} -#### **Foundation Model Training: MaxText and MaxDiffusion** {#foundation-model-training:-maxtext-and-maxdiffusion} +### The Application Layer: Training and Alignment + +(foundational-model-maxtext-and)= +#### Foundation Model Training: MaxText and MaxDiffusion [MaxText](https://maxtext.readthedocs.io/en/latest/) and [MaxDiffusion](https://github.com/AI-Hypercomputer/maxdiffusion) are Google’s flagship LLM and Diffusion model training frameworks, respectively. With a large selection of highly optimized implementations of popular open-weights models, these repositories serve a dual purpose: they function as both a ready-to-go model training codebase and as a reference that foundation model builders can use to build upon. -##### **Motivation** {#motivation-8} +##### Motivation There is rapid growth of interest across the industry in training GenAI models. The popularity of open models has accelerated this trend, providing users with proven architectures. To train and adapt these models, users require high performance, efficiency, scalability to extreme numbers of chips, and clear, understandable code. They need a framework that can adapt to new techniques and target both TPUs and GPUs. [MaxText](https://maxtext.readthedocs.io/en/latest/) and MaxDiffusion are comprehensive solutions designed to fulfill these needs. -##### **Design** {#design-8} +##### Design [MaxText](https://maxtext.readthedocs.io/en/latest/) and MaxDiffusion are foundation model codebases designed with readability and performance in mind. They are structured with well-tested, reusable components: model definitions that leverage custom kernels (like Tokamax) for maximum performance, a training harness for orchestration and monitoring, and a powerful config system that allows users to control details like sharding and quantization (via Qwix) through an intuitive interface. Advanced reliability features like multi-tier checkpointing are incorporated to ensure sustained goodput. They leverage the best-in-class JAX libraries—Qwix, [Tunix](https://tunix.readthedocs.io/en/latest/), [Orbax](https://orbax.readthedocs.io/en/latest/index.html), and [Optax](https://optax.readthedocs.io/en/latest/)—to deliver core capabilities. This allows them to provide robust, scalable infrastructure, reducing development overhead and allowing users to focus on the modeling task. For inference, the model code is shared to enable efficient and scalable serving. 
-##### **Key Strengths** {#key-strengths-8} +##### Key Strengths -* **Performant by Design:** With training infrastructure set up for high "goodput" (useful throughput) and model implementations optimized for high MFU (Model Flops Utilization), [MaxText](#foundation-model-training:-maxtext-and-maxdiffusion) and MaxDiffusion deliver high performance at scale out of the box -* **Built for Scale:** Leveraging the power of the JAX AI stack (especially Pathways), these frameworks allow users to scale seamlessly from tens of chips to tens of thousands of chips +* **Performant by Design:** With training infrastructure set up for high "goodput" (useful throughput) and model implementations optimized for high MFU (Model Flops Utilization), [MaxText](#foundational-model-maxtext-and) and MaxDiffusion deliver high performance at scale out of the box +* **Built for Scale:** Leveraging the power of the JAX AI stack (especially Pathways), these frameworks allow users to scale seamlessly from tens of chips to tens of thousands of chips * **Solid Base for Foundation Model Builders:** The high-quality, readable implementations serve as a solid starting point for builders to either use as an end-to-end solution or as a reference implementation for their own customizations -#### **Post-Training and Alignment: The Tunix Framework** {#post-training-and-alignment:-the-tunix-framework} + +#### Post-Training and Alignment: The Tunix Framework [Tunix](https://tunix.readthedocs.io/en/latest/) offers state-of-the-art open-source reinforcement learning (RL) algorithms, along with a robust framework and infrastructure, providing a streamlined path for users to experiment with LLM post-training techniques (including Supervised Fine-Tuning (SFT) and alignment) using JAX and TPUs. -##### **Motivation** {#motivation-9} +##### Motivation Post-training is the critical step in unlocking the true power of LLMs. The Reinforcement Learning stage is particularly crucial for developing alignment and reasoning capabilities. While fast-moving open-source development in this area has been prolific, it has been almost exclusively based on PyTorch and GPUs, leaving a fundamental gap for JAX and TPU solutions. [Tunix](https://tunix.readthedocs.io/en/latest/) (Tune-in-JAX) is a high-performance, JAX-native library designed to fill this gap. -##### **Design** {#design-9} +##### Design -![][image4] +![](../_static/images/Tunix_diagram.svg) From a framework perspective, [Tunix](https://tunix.readthedocs.io/en/latest/) enables a state-of-the-art setup that **clearly separates RL algorithms from the infrastructure**. It offers a lightweight, client-like API that hides the complexity of the RL infrastructure, allowing users to develop new algorithms easily. [Tunix](https://tunix.readthedocs.io/en/latest/) provides out-of-the-box solutions for popular algorithms, including PPO, DPO, and others. On the infrastructure side, [Tunix](https://tunix.readthedocs.io/en/latest/) has native integration with Pathways, enabling a single-controller architecture that makes multi-node RL training easily accessible. On the trainer side, [Tunix](https://tunix.readthedocs.io/en/latest/) natively supports parameter-efficient training (e.g., LoRA) and leverages JAX sharding and XLA (GSPMD) to generate a performant compute graph. It supports popular open-source models like Gemma and Llama out of the box. 
-##### **Key Strengths** {#key-strengths-9} +##### Key Strengths -* **Simplicity:** It provides a high-level, client-like API that abstracts away the complexities of the underlying distributed infrastructure. -* **Developer Efficiency:** Tunix accelerates the R\&D lifecycle with out-of-the-box algorithms and pre-built "recipes," enabling users to get a working model and iterate quickly. +* **Simplicity:** It provides a high-level, client-like API that abstracts away the complexities of the underlying distributed infrastructure. +* **Developer Efficiency:** Tunix accelerates the R\&D lifecycle with out-of-the-box algorithms and pre-built "recipes," enabling users to get a working model and iterate quickly. * **Performance and Scalability:** Tunix enables a highly efficient and horizontally scalable training infrastructure by leveraging Pathways as a single controller on the backend. -### **The Application Layer: Production and Inference** {#the-application-layer:-production-and-inference} +### The Application Layer: Production and Inference A historical challenge for JAX adoption has been the path from research to production. The JAX AI stack now provides a mature, two-pronged production story that offers both ecosystem compatibility and native JAX performance. -#### **High-Performance LLM Inference: The vLLM Solution** {#high-performance-llm-inference:-the-vllm-solution} +#### High-Performance LLM Inference: The vLLM-TPU Solutions -vLLM TPU is Google's high-performance inference stack designed to run PyTorch and JAX native Large Language Models (LLMs) efficiently on Cloud TPUs. It achieves this by natively integrating the popular open-source vLLM framework with Google's JAX and TPU ecosystem. +vLLM-TPU is Google's high-performance inference stack designed to run PyTorch and JAX native Large Language Models (LLMs) efficiently on Cloud TPUs. It achieves this by natively integrating the popular open-source vLLM framework with Google's JAX and TPU ecosystem. -##### **Motivation** {#motivation-10} +##### Motivation -The industry is rapidly evolving, with growing demand for seamless, high-performing, and easy-to-use inference solutions. Users often face significant challenges from complex and inconsistent tooling, subpar performance, and limited model compatibility. The vLLM stack addresses these issues by providing a unified, performant, and intuitive platform. +The industry is rapidly evolving, with growing demand for seamless, high-performing, and easy-to-use inference solutions. Users often face significant challenges from complex and inconsistent tooling, subpar performance, and limited model compatibility. The vLLM-TPU stack addresses these issues by providing a unified, performant, and intuitive platform. -##### **Design** {#design-10} +##### Design -This solution pragmatically extends the vLLM framework, rather than reinventing it. vLLM is a highly optimized open-source LLM serving engine known for its high throughput, achieved via key features like **`PagedAttention`** (which manages KV caches like virtual memory to minimize fragmentation) and **`Continuous Batching`** (which dynamically adds requests to the batch to improve utilization). +This solution pragmatically extends the vLLM framework, rather than reinventing it. 
vLLM itself is a highly optimized open-source LLM serving engine known for its high throughput, achieved via key features like **`PagedAttention`** (which manages KV caches like virtual memory to minimize fragmentation) and **`Continuous Batching`** (which dynamically adds requests to the batch to improve utilization).

-vLLM TPU builds on this foundation and develops core components for request handling, scheduling, and memory management. It introduces a **JAX-based backend** that acts as a bridge, translating vLLM's computational graph and memory operations into TPU-executable code. This backend handles device interactions, JAX model execution, and the specifics of managing the KV cache on TPU hardware. It incorporates TPU-specific optimizations, such as efficient attention mechanisms (e.g., leveraging JAX Pallas kernels for Ragged Paged Attention) and quantization, all tailored for the TPU architecture.
+vLLM-TPU builds on this foundation and develops core components for request handling, scheduling, and memory management. It introduces a **JAX-based backend** that acts as a bridge, translating vLLM's computational graph and memory operations into TPU-executable code. This backend handles device interactions, JAX model execution, and the specifics of managing the KV cache on TPU hardware. It incorporates TPU-specific optimizations, such as efficient attention mechanisms (e.g., leveraging JAX Pallas kernels for Ragged Paged Attention) and quantization, all tailored for the TPU architecture.

-##### **Key Strengths** {#key-strengths-10}
+##### Key Strengths

-* **Zero Onboarding/Offboarding Cost for Users:** Users can adopt this solution without significant friction. From a user-experience perspective, processing inference requests is identical to on GPUs. The CLI to start the server, accept prompts, and return outputs are all shared.
-* **Fully Embrace the Ecosystem:** This approach utilizes and contributes to the vLLM interface and user experience, ensuring compatibility and ease of use.
-* **Fungibility between TPUs and GPUs:** The solution works efficiently on both TPUs and GPUs, allowing users flexibility.
+* **Zero Onboarding/Offboarding Cost for Users:** Users can adopt this solution without significant friction. From a user-experience perspective, processing inference requests is identical to doing so on GPUs. The CLI commands to start the server, accept prompts, and return outputs are all shared.
+* **Fully Embrace the Ecosystem:** This approach utilizes and contributes to the vLLM interface and user experience, ensuring compatibility and ease of use.
+* **Fungibility between TPUs and GPUs:** The solution works efficiently on both TPUs and GPUs, allowing users flexibility.
* **Cost Efficient (Best Perf/$):** Optimizes performance to provide the best performance-to-cost ratio for popular models.

-#### **JAX-Native Serving: Orbax Serialization and Neptune Serving Engine** {#jax-native-serving:-orbax-serialization-and-neptune-serving-engine}
+
+#### JAX-Native Serving: Orbax Serialization and Neptune Serving Engine

For models other than LLMs, or for users desiring a fully JAX-native pipeline, the Orbax serialization library and Neptune serving engine (NSE) system provide an end-to-end, high-performance serving solution.

-##### **Motivation** {#motivation-11}
+##### Motivation

Historically, JAX models often relied on a circuitous path to production, such as being wrapped in TensorFlow graphs and deployed using TensorFlow serving.
This approach introduced significant limitations and inefficiencies, forcing developers to engage with a separate ecosystem and slowing down iteration. A dedicated JAX-native serving system is crucial for sustainability, reduced complexity, and optimized performance. -##### **Design** {#design-11} +##### Design This solution consists of two core components, as illustrated in the diagram below. -1. **Orbax Serialization Library:** Provides user-friendly APIs for serializing JAX models into a new, robust Orbax serialization format. This format is optimized for production deployment. Its core includes: (a) directly representing JAX model computations using **StableHLO**, allowing the computation graph to be represented natively, and (b) leveraging **TensorStore** for storing weights, enabling fast checkpoint loading for serving. +![](../_static/images/serving_orbax_nse.svg) + + +1. **Orbax Serialization Library:** Provides user-friendly APIs for serializing JAX models into a new, robust Orbax serialization format. This format is optimized for production deployment. Its core includes: (a) directly representing JAX model computations using **StableHLO**, allowing the computation graph to be represented natively, and (b) leveraging **TensorStore** for storing weights, enabling fast checkpoint loading for serving. 2. **Neptune Serving Engine (NSE):** This is the accompanying high-performance, flexible serving engine (typically deployed as a C++ binary) designed to natively run JAX models in the Orbax format. NSE offers production-essential capabilities, such as fast model loading, high-throughput concurrent serving with built-in batching, support for multiple model versions, and both single- and multi-host serving (leveraging PJRT and Pathways). -##### **Key Strengths** {#key-strengths-11} +##### Key Strengths -* **JAX Native Serving:** The solution is built natively for JAX, eliminating inter-framework overhead in model serialization and serving. This ensures lightning-fast model loading and optimized execution across CPUs, GPUs, and TPUs. -* **Effortless Production Deployment:** Serialized models provide a **hermetic deployment path** that is unaffected by drift in Python dependencies and enables runtime model integrity checks. This provides a seamless, intuitive path for JAX model productionization. +* **JAX Native Serving:** The solution is built natively for JAX, eliminating inter-framework overhead in model serialization and serving. This ensures lightning-fast model loading and optimized execution across CPUs, GPUs, and TPUs. +* **Effortless Production Deployment:** Serialized models provide a **hermetic deployment path** that is unaffected by drift in Python dependencies and enables runtime model integrity checks. This provides a seamless, intuitive path for JAX model productionization. * **Enhanced Developer Experience:** By eliminating the need for cumbersome framework wrapping, this solution significantly reduces dependencies and system complexity, speeding up iteration for JAX developers. 
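The Orbax serialization and NSE APIs are not reproduced here, but the underlying idea, capturing a JAX computation as versioned StableHLO alongside its weights, can be sketched with JAX's own `jax.export` module. This is an illustration of the mechanism only, not the Orbax or NSE interface:

```python
import jax
import jax.numpy as jnp
from jax import export

@jax.jit
def predict(w, x):
    return jnp.tanh(x @ w)

w = jnp.ones((16, 16))
x_spec = jax.ShapeDtypeStruct((4, 16), jnp.float32)

# Capture the computation as StableHLO and serialize it to a portable blob.
exported = export.export(predict)(w, x_spec)
blob = exported.serialize()

# A serving process can later rehydrate and call it without the Python model code.
rehydrated = export.deserialize(blob)
print(rehydrated.call(w, jnp.ones((4, 16), jnp.float32)).shape)
```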
-### **System-Wide Analysis and Profiling** {#system-wide-analysis-and-profiling} +### System-Wide Analysis and Profiling -#### **XProf: Deep, Hardware-Integrated Performance Profiling** {#xprof:-deep,-hardware-integrated-performance-profiling} +#### XProf: Deep, Hardware-Integrated Performance Profiling [XProf](https://openxla.org/xprof) is a profiling and performance analysis tool that provides in-depth visibility into various aspects of ML workload execution, enabling users to debug and optimize performance. It is deeply integrated into both the JAX and TPU ecosystems. -##### **Motivation** {#motivation-12} +##### Motivation On one hand, ML workloads are growing ever more complicated. On the other, there is an explosion of specialized hardware capabilities targeting these workloads. Matching the two effectively to ensure peak performance and efficiency is critical, given the enormous costs of ML infrastructure. This requires deep visibility into both the workload and the hardware, presented in a way that is easily consumable. XProf excels at this. -##### **Design** {#design-12} +##### Design XProf consists of two primary components: collection and analysis. -1. **Collection:** XProf captures information from various sources: annotations in the user’s JAX code, cost models for operations within the XLA compiler, and **purpose-built hardware profiling features within the TPU**. This collection can be triggered programmatically or on-demand, generating a comprehensive event artifact. +1. **Collection:** XProf captures information from various sources: annotations in the user’s JAX code, cost models for operations within the XLA compiler, and **purpose-built hardware profiling features within the TPU**. This collection can be triggered programmatically or on-demand, generating a comprehensive event artifact. 2. **Analysis:** XProf post-processes the collected data and creates a suite of powerful visualizations, accessed via a browser. -##### **Key Strengths** {#key-strengths-12} +##### Key Strengths The true power of XProf comes from its deep integration with the full stack, providing a breadth and depth of analysis that is a tangible benefit of the co-designed JAX/TPU ecosystem. -* **Co-designed with the TPU:** XProf exploits hardware features specifically designed for seamless profile collection, enabling a collection overhead of **less than 1%**. This allows profiling to be a lightweight, iterative part of development. -* **Breadth and Depth of Analysis:** XProf yields deep analysis across multiple axes. Its tools include: - * **Trace Viewer:** An op-by-op timeline view of execution on different hardware units (e.g., TensorCore). - * **HLO Op Profile:** Breaks down the total time spent into different categories of operations. - * **Memory Viewer:** Details memory allocations by different ops during the profiled window. - * **Roofline Analysis:** Helps identify whether specific ops are compute- or memory-bound and how far they are from the hardware’s peak capabilities. +* **Co-designed with the TPU:** XProf exploits hardware features specifically designed for seamless profile collection, enabling a collection overhead of **less than 1%**. This allows profiling to be a lightweight, iterative part of development. +* **Breadth and Depth of Analysis:** XProf yields deep analysis across multiple axes. Its tools include: + * **Trace Viewer:** An op-by-op timeline view of execution on different hardware units (e.g., TensorCore). 
+ * **HLO Op Profile:** Breaks down the total time spent into different categories of operations. + * **Memory Viewer:** Details memory allocations by different ops during the profiled window. + * **Roofline Analysis:** Helps identify whether specific ops are compute- or memory-bound and how far they are from the hardware’s peak capabilities. * **Graph Viewer:** Provides a view into the full HLO graph executed by the hardware. + +[^7]: This is a well established paradigm and has precedent in the CPU world, where compiled code forms the bulk of the program with developers dropping down to intrinsics or inline assembly to optimize performance critical sections. diff --git a/docs/source/ecosystem_overview/modular.md b/docs/source/ecosystem_overview/modular.md index 1a06ec8..50e7597 100644 --- a/docs/source/ecosystem_overview/modular.md +++ b/docs/source/ecosystem_overview/modular.md @@ -2,10 +2,10 @@ The [JAX AI stack](https://jaxstack.ai/) extends the JAX numerical core with a collection of Google-backed composable libraries, evolving it into a robust, end-to-end, open-source platform for Machine Learning at extreme scales. As such, the JAX AI stack consists of a comprehensive and robust ecosystem that addresses the entire ML lifecycle: -* **Industrial-Scale Foundation:** The stack is architected for massive scale, leveraging ML Pathways for orchestrating training across tens of thousands of chips and [Orbax](https://orbax.readthedocs.io/en/latest/index.html) for resilient, high-throughput asynchronous checkpointing, enabling production-grade training of state-of-the-art models. -* **Complete, Production-Ready Toolkit:** It provides a comprehensive set of libraries for the entire development process: Flax for flexible model authoring and "surgery," [Optax](#optax:-composable-gradient-processing-and-optimization-strategies) for composable optimization strategies, and [Grain](https://google-grain.readthedocs.io/en/latest/) for the deterministic data pipelines essential for reproducible large-scale runs. -* **Peak, Specialized Performance:** To achieve maximum hardware utilization, the stack offers specialized libraries including Tokamax for state-of-the-art custom kernels, Qwix for non-intrusive quantization that boosts training and inference speed, and XProf for deep, hardware-integrated performance profiling. -* **Full Path to Production:** The stack provides a seamless transition from research to deployment. This includes [MaxText](https://maxtext.readthedocs.io/en/latest/) as a scalable reference for foundation model training, [Tunix](https://tunix.readthedocs.io/en/latest/) for state-of-the-art reinforcement learning (RL) and alignment, and a unified inference solution via vLLM integration and the native JAX Serving runtime. +* **Industrial-Scale Foundation:** The stack is architected for massive scale, leveraging ML Pathways for orchestrating training across tens of thousands of chips and [Orbax](https://orbax.readthedocs.io/en/latest/index.html) for resilient, high-throughput asynchronous checkpointing, enabling production-grade training of state-of-the-art models. +* **Complete, Production-Ready Toolkit:** It provides a comprehensive set of libraries for the entire development process: Flax for flexible model authoring and "surgery," [Optax](https://optax.readthedocs.io/en/latest/) for composable optimization strategies, and [Grain](https://google-grain.readthedocs.io/en/latest/) for the deterministic data pipelines essential for reproducible large-scale runs. 
+* **Peak, Specialized Performance:** To achieve maximum hardware utilization, the stack offers specialized libraries including Tokamax for state-of-the-art custom kernels, Qwix for non-intrusive quantization that boosts training and inference speed, and XProf for deep, hardware-integrated performance profiling.
+* **Full Path to Production:** The stack provides a seamless transition from research to deployment. This includes [MaxText](https://maxtext.readthedocs.io/en/latest/) as a scalable reference for foundation model training, [Tunix](https://tunix.readthedocs.io/en/latest/) for state-of-the-art reinforcement learning (RL) and alignment, and a unified inference solution via vLLM-TPU integration and the native JAX Serving runtime.

The JAX ecosystem philosophy is one of loosely coupled components, each of which does one thing well. Rather than being a monolithic ML framework, JAX itself is narrowly-scoped and focuses on efficient array operations and program transformations. The ecosystem is built upon this core framework to provide a wide array of functionalities, related to both the training of ML models and other types of workloads such as scientific computing.

@@ -13,32 +13,35 @@
This system of loosely coupled components hands freedom of choice back to users,

The following sections provide a technical overview of the JAX ecosystem, its key features, the design decisions behind them, and how they combine to build a durable platform for modern ML workloads.

-**Table 1: The JAX AI Stack and other Ecosystem Components**

| Component | Function / Description |
| ----- | ----- |
-| **JAX AI stack core and components**[^2] | |
+| **JAX AI stack core and components**[^1] | |
| [**JAX**](https://docs.jax.dev/en/latest/) | Accelerator-oriented array computation and program transformation (JIT, grad, vmap, pmap). |
| [**Flax**](https://flax.readthedocs.io/en/stable/) | Flexible neural network authoring library for intuitive model creation and "surgery." |
| [**Optax**](https://optax.readthedocs.io/en/latest/) | A library of composable gradient processing and optimization transformations. |
| [**Orbax**](https://orbax.readthedocs.io/en/latest/) | "Any-scale" distributed checkpointing library for hero-scale training resilience. |
| **JAX Ecosystem \- Infrastructure** | |
-| **[XLA](https://openxla.org/)** | Distributed runtime for orchestrating computation across tens of thousands of chips. |
-| **[Pathways](https://docs.cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/pathways-intro)** | A JAX extension for writing low-level, high-performance custom kernels in Python. |
+| [**XLA**](https://openxla.org/) | Hardware-agnostic, compiler-centric engine that JIT-compiles JAX computation graphs for TPU, GPU and CPU. |
+| [**Pathways**](https://docs.cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/pathways-intro) | Distributed runtime for orchestrating computation across tens of thousands of chips. |
| **JAX Ecosystem \- Adv. Development** | |
-| **[Pallas](https://docs.jax.dev/en/latest/pallas/index.html)** | A JAX extension for writing low-level, high-performance custom kernels in Python. |
-| **[Tokamax](https://github.com/openxla/tokamax)** | A curated library of state-of-the-art, high-performance custom kernels (e.g., Attention). |
-| **[Qwix](https://github.com/google/qwix)** | A comprehensive, non-intrusive library for quantization (PTQ, QAT, QLoRA).
| -| **[Grain](https://google-grain.readthedocs.io/en/latest/)** | A scalable, deterministic, and checkpointable input data pipeline library. | +| [**Pallas**](https://docs.jax.dev/en/latest/pallas/index.html) | A JAX extension for writing low-level, high-performance custom kernels in Python. | +| [**Tokamax**](https://github.com/openxla/tokamax) | A curated library of state-of-the-art, high-performance custom kernels (e.g., Attention). | +| [**Qwix**](https://github.com/google/qwix) | A comprehensive, non-intrusive library for quantization (PTQ, QAT, QLoRA). | +| [**Grain**](https://google-grain.readthedocs.io/en/latest/) | A scalable, deterministic, and checkpointable input data pipeline library. | | **JAX Ecosystem \- Application** | | -| **[MaxText / MaxDiffusion](https://maxtext.readthedocs.io/en/latest/)** | Flagship, scalable reference frameworks for training foundation models (LLM, Diffusion). | -| **[Tunix](https://tunix.readthedocs.io/en/latest/index.html)** | A framework for state-of-the-art post-training and alignment (RLHF, DPO). | -| **[vLLM](https://docs.vllm.ai/projects/tpu/en/latest/)** | A high-performance LLM inference solution via native integration of the vLLM framework. | +| [**MaxText** / **MaxDiffusion**](https://maxtext.readthedocs.io/en/latest/) | Flagship, scalable reference frameworks for training foundation models (LLM, Diffusion). | +| [**Tunix**](https://tunix.readthedocs.io/en/latest/index.html) | A framework for state-of-the-art post-training and alignment (RLHF, DPO). | +| [**vLLM-TPU**](https://docs.vllm.ai/projects/tpu/en/latest/) | A high-performance LLM inference solution via native integration of the vLLM framework. | | **Neptune Serving Engine** (coming soon) | JAX Serving Runtime: a high-performance, JAX-native C++ server for non-LLM models. | -| **[XProf](https://openxla.org/xprof)** | A deep, hardware-integrated profiler for system-wide performance analysis. | +| [**XProf**](https://openxla.org/xprof) | A deep, hardware-integrated profiler for system-wide performance analysis. | + +**Table 1: The JAX AI Stack and other Ecosystem Components** + +[^1]: The core components are included in the [`jax-ai-stack` Python package](https://docs.jaxstack.ai/en/latest/install.html). + -![][image1] +![](../_static/images/JAX_ecosystem.svg) **Figure 1: The JAX AI Stack and Ecosystem Components** -**![][image2]** diff --git a/docs/source/ecosystem_overview/the_ecosystem_overview_tr.md b/docs/source/ecosystem_overview/the_ecosystem_overview_tr.md index d21b814..a1725b9 100644 --- a/docs/source/ecosystem_overview/the_ecosystem_overview_tr.md +++ b/docs/source/ecosystem_overview/the_ecosystem_overview_tr.md @@ -16,19 +16,6 @@ conclusion ``` - -[^1]: Image diffusion models are a typical example of this and can commonly be divided logically into a separately trained prompt encoder and a diffusion backbone. - -[^2]: We say effectively free since there could be other bottlenecks such as the DMA engines, HBM bandwidth contention etc. that still incur a performance penalty. - -[^3]: This is a well established paradigm and has precedent in the CPU world, where compiled code forms the bulk of the program with developers dropping down to intrinsics or inline assembly to optimize performance critical sections. 
- -[^4]: In the Section 5.1 of the [Palm paper](https://dl.acm.org/doi/10.5555/3648699.3648939), the authors note that they observed very large loss spikes despite having gradient clipping enabled and the solution was to remove the offending data batches and restart training from a checkpoint before the loss spike. This is only possible with a fully deterministic and reproducible training setup. - -[^5]: This is indeed how multimodal data pipelines would need to operate \- image and audio tokenizers for example are models themselves which run in their own clusters on their own accelerators and the input pipelines would make RPCs out to convert data examples into streams of tokens. - -[^6]: Some of the equivalents here are not true 1:1 comparisons because PyTorch draws API boundaries differently compared to JAX. The list of equivalents is not exhaustive and there are new libraries appearing frequently. - [image1]: [image2]: diff --git a/docs/source/index.rst b/docs/source/index.rst index 47c61bf..04ca430 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -22,14 +22,6 @@ JAX AI Stack data_loaders pytorch_users -.. toctree:: - :hidden: - :caption: FIXME OVERVIEW - :maxdepth: 1 - - ecosystem_overview/the_ecosystem_overview_tr - - .. toctree:: :hidden: :caption: Example applications @@ -44,6 +36,13 @@ JAX AI Stack JAX_Vision_transformer JAX_time_series_classification +.. toctree:: + :hidden: + :caption: The JAX Ecosystem: A Modular, Scalable, and High-Performance ML Ecosystem + :maxdepth: 1 + + ecosystem_overview/the_ecosystem_overview_tr + .. toctree:: :hidden: :caption: Developer resources From 8619ae538ef1876da1da9809a42ebcc46db8b4d2 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 12 Nov 2025 19:53:57 +0100 Subject: [PATCH 03/14] fix headings, H2->H1 --- .../ecosystem_overview/architectural.md | 2 +- docs/source/ecosystem_overview/comparative.md | 2 +- docs/source/ecosystem_overview/conclusion.md | 2 +- docs/source/ecosystem_overview/core.md | 38 ++++---- docs/source/ecosystem_overview/extended.md | 86 +++++++++---------- docs/source/ecosystem_overview/modular.md | 2 +- 6 files changed, 66 insertions(+), 66 deletions(-) diff --git a/docs/source/ecosystem_overview/architectural.md b/docs/source/ecosystem_overview/architectural.md index e12fc13..9e9ebba 100644 --- a/docs/source/ecosystem_overview/architectural.md +++ b/docs/source/ecosystem_overview/architectural.md @@ -1,4 +1,4 @@ -## The Architectural Imperative: Performance Beyond Frameworks +# The Architectural Imperative: Performance Beyond Frameworks As model architectures converge—for example, on multimodal Mixture-of-Experts (MoE) Transformers—the pursuit of peak performance is leading to the emergence of "Megakernels." A Megakernel is effectively the entire forward pass (or a large portion) of one specific model, hand-coded using a lower-level API like the CUDA SDK on NVIDIA GPUs. This approach achieves maximum hardware utilization by aggressively overlapping compute, memory, and communication. Recent work from the research community has demonstrated that this approach can yield significant throughput gains, over 22% in some cases, for inference on GPUs. This trend is not limited to inference; evidence suggests that some large-scale training efforts have involved low-level hardware control to achieve substantial efficiency gains. 
diff --git a/docs/source/ecosystem_overview/comparative.md b/docs/source/ecosystem_overview/comparative.md index 456dc5d..47584da 100644 --- a/docs/source/ecosystem_overview/comparative.md +++ b/docs/source/ecosystem_overview/comparative.md @@ -1,4 +1,4 @@ -## A Comparative Perspective: The JAX/TPU Stack as a Compelling Choice +# A Comparative Perspective: The JAX/TPU Stack as a Compelling Choice The modern Machine Learning landscape offers many excellent, mature toolchains. The JAX AI Stack, however, presents a unique and compelling set of advantages for developers focused on large-scale, high-performance ML, stemming directly from its modular design and deep hardware co-design. diff --git a/docs/source/ecosystem_overview/conclusion.md b/docs/source/ecosystem_overview/conclusion.md index 2506f32..dec9c7e 100644 --- a/docs/source/ecosystem_overview/conclusion.md +++ b/docs/source/ecosystem_overview/conclusion.md @@ -1,4 +1,4 @@ -## Conclusion: A Durable, Production-Ready Platform for the Future of AI +# Conclusion: A Durable, Production-Ready Platform for the Future of AI The data provided in the table above draws to a rather simple conclusion \- these stacks have their own strengths and weaknesses in a small number of areas but overall are vastly similar from the software standpoint. Both stacks provide out of the box turnkey solutions for pre-training, post-training adaptation and deployment of foundational models. diff --git a/docs/source/ecosystem_overview/core.md b/docs/source/ecosystem_overview/core.md index 28d28f0..6a43938 100644 --- a/docs/source/ecosystem_overview/core.md +++ b/docs/source/ecosystem_overview/core.md @@ -1,8 +1,8 @@ -## The Core JAX AI Stack +# The Core JAX AI Stack The core JAX AI Stack consists of five key libraries that provide the foundation for model development: JAX, [Flax](https://flax.readthedocs.io/en/stable/), [Optax](https://optax.readthedocs.io/en/latest/), [Orbax](https://orbax.readthedocs.io/en/latest/) and [Grain](https://google-grain.readthedocs.io/en/latest/). -### JAX: A Foundation for Composable, High-Performance Program Transformation +## JAX: A Foundation for Composable, High-Performance Program Transformation [JAX](https://docs.jax.dev/en/latest/) is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale Machine Learning. With its functional programming model and friendly, NumPy-like API, JAX provides a solid foundation for higher-level libraries. @@ -17,42 +17,42 @@ These core transformations can be mixed and matched to achieve high performance The seamless integration with XLA's GSPMD (General-purpose SPMD) model allows JAX to automatically parallelize computations across large TPU pods with minimal code changes. In most cases, scaling simply requires high-level sharding annotations, a stark contrast to frameworks where scaling may require more manual management of device placement and communication collectives -### Flax: Flexible Neural Network Authoring and "Model Surgery" +## Flax: Flexible Neural Network Authoring and "Model Surgery" [Flax](https://flax.readthedocs.io/en/latest/index.html) is a library designed to simplify the creation, debugging, and analysis of neural networks in JAX. 
While pure functional API provided by JAX can be used to fully specify and train a ML model, users coming from the PyTorch (or TensorFlow) ecosystem are more used to and comfortable with the object oriented approach of specifying models as a graph of `torch.nn.Modules`. The abstractions provided by [Flax](https://flax.readthedocs.io/en/stable/) allow users to think more in terms of layers rather than functions, making it more developer friendly to an audience who value ergonomics and experimentation ease. [Flax](https://flax.readthedocs.io/en/stable/) also enables config driven model construction systems, such as those present in [MaxText](https://maxtext.readthedocs.io/en/latest/) and AxLearn, which separate out model hyperparameters from layer definition code. With a simple Pythonic API, it allows developers to express models using regular Python objects, while retaining the power and performance of JAX. Flax's NNX API is an evolution of the Flax Linen interface, incorporating lessons learned to offer a more user-friendly interface that remains consistent with the core JAX APIs. Since Flax modules are fully backed by the core JAX APIs, there is no performance penalty associated with defining the model in [Flax](https://flax.readthedocs.io/en/stable/). -#### Motivation +### Motivation JAX’s pure functional API, while powerful, can be complex for new users since it requires all the program state to be explicitly managed by the user. This paradigm can be unfamiliar to developers used to other frameworks. Modern model architectures are often complex with individual portions of the model trained separately and merged to form the final model[^3], in a process commonly referred to as model surgery. Even with decoder-only LLMs which tend to have a straightforward architecture, post training techniques such as LoRA and quantization require the model definition to be easily manipulated allowing parts of the architecture to be modified or even replaced. The Flax NNX library, with its simple yet powerful Pythonic API enables this functionality in a way that is intuitive to the user, reducing the amount of cognitive overhead involved in authoring and training a model. -#### Design +### Design The [Flax](https://flax.readthedocs.io/en/stable/) NNX library introduces an object oriented model definition system that encapsulates the model and random number generator state internally, reducing the cognitive overhead of the user and provides a familiar experience for those accustomed to frameworks like PyTorch or TensorFlow. By making submodule definitions Pythonic and providing APIs to traverse the module hierarchy, it allows for the model definition to be easily editable programmatically for model introspection and surgery. The [Flax](https://flax.readthedocs.io/en/stable/) NNX APIs are designed to be consistent with the core JAX APIs to allow users to exploit the full expressibility and performance of JAX, with lifted transformations for common operations like sharding, jit and others. Models defined using the NNX APIs can also be adapted to work with functional training loops, allowing the user the flexibility they need while retaining an intuitive object oriented API. 
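As a small illustration of this style (a sketch using the public `flax.nnx` API; the layer sizes and names here are arbitrary):

```python
import jax
import jax.numpy as jnp
from flax import nnx

class MLP(nnx.Module):
    def __init__(self, din: int, dhidden: int, dout: int, *, rngs: nnx.Rngs):
        # Submodules and parameters are plain Python attributes; RNG state is
        # carried by the `rngs` object rather than threaded through by hand.
        self.hidden = nnx.Linear(din, dhidden, rngs=rngs)
        self.out = nnx.Linear(dhidden, dout, rngs=rngs)

    def __call__(self, x):
        return self.out(jax.nn.relu(self.hidden(x)))

model = MLP(4, 32, 2, rngs=nnx.Rngs(0))
y = model(jnp.ones((8, 4)))

# "Model surgery": submodules are ordinary attributes, so they can be inspected
# or swapped directly, e.g. replacing the output head with a freshly initialized one.
model.out = nnx.Linear(32, 2, rngs=nnx.Rngs(1))
```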
-#### Key Strengths
+### Key Strengths

* **Intuitive object oriented flexible APIs:** Layers are represented as pure Python objects with internal state management, simplifying model construction and training loops, while also enabling advanced model surgery use cases through support for submodule replacement, partial initialization and model hierarchy traversal.
* **Consistent with Core JAX APIs:** Lifted transformations consistent with core JAX and fully compatible with functional JAX provide the full performance of JAX without sacrificing developer friendliness.

(optax:composable)=
-### Optax: Composable Gradient Processing and Optimization Strategies
+## Optax: Composable Gradient Processing and Optimization Strategies

[Optax](https://optax.readthedocs.io/en/latest/index.html) is a gradient processing and optimization library for JAX. It is designed to empower model builders by providing building blocks that can be recombined in custom ways in order to train deep learning models, amongst other applications. It builds on the capabilities of the core JAX library to provide a well-tested, high-performance library of losses, optimizer functions and associated techniques that can be used to train ML models.

-#### Motivation
+### Motivation

The calculation and minimization of losses is at the core of what enables the training of ML models. With its support for automatic differentiation, the core JAX library provides the numeric capabilities to train models, but it does not provide standard implementations of popular optimizers (e.g., `RMSProp`, `Adam`) or losses (`CrossEntropy`, `MSE`, etc.). While it is true that a user could implement these functions by themselves (and some advanced users will choose to do so), a bug in an optimizer implementation would introduce hard-to-diagnose model quality issues. Rather than having the user implement such critical pieces, [Optax](https://optax.readthedocs.io/en/latest/) provides implementations of these algorithms that are tested for correctness and performance.

The field of optimization theory lies squarely in the realm of research; however, its central role in training also makes it an indispensable part of training production ML models. A library that serves this role needs to be both flexible enough to accommodate rapid research iterations and robust and performant enough to be dependable for production model training. It should also provide well-tested implementations of state-of-the-art algorithms which match the standard equations. The [Optax](https://optax.readthedocs.io/en/latest/) library, through its modular, composable architecture and emphasis on correct, readable code, is designed to achieve this.

-#### Design
+### Design

[Optax](https://optax.readthedocs.io/en/latest/) is designed to enhance both research velocity and the transition from research to production by providing readable, well-tested, and efficient implementations of core algorithms. Optax has uses beyond deep learning; in this context, however, it can be viewed as a collection of well known loss functions, optimization algorithms and gradient transformations implemented in a pure functional fashion in line with the JAX philosophy. The collection of well known [losses](https://optax.readthedocs.io/en/latest/api/losses.html) and [optimizers](https://optax.readthedocs.io/en/latest/api/optimizers.html) enables users to get started with ease and confidence.
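A minimal sketch of this functional style is shown below: a gradient transformation is declared once with `optax.chain` and then driven through the standard `init`/`update`/`apply_updates` cycle (standard Optax calls; the loss and parameter shapes are arbitrary placeholders):

```python
import jax
import jax.numpy as jnp
import optax

params = {"w": jnp.zeros((4, 2)), "b": jnp.zeros(2)}

# Declared once: gradient clipping chained with AdamW on a cosine schedule.
tx = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(optax.cosine_decay_schedule(1e-3, decay_steps=1_000)),
)
opt_state = tx.init(params)

def loss_fn(params, x, y):
    logits = x @ params["w"] + params["b"]
    return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()

@jax.jit
def train_step(params, opt_state, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    updates, opt_state = tx.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state

x, y = jnp.ones((8, 4)), jnp.zeros(8, dtype=jnp.int32)
params, opt_state = train_step(params, opt_state, x, y)
```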
@@ -80,7 +80,7 @@ for i, (inputs, targets) in enumerate(data_loader): As it can be seen in the example above, setting up an optimizer with a custom learning rate, gradient clipping and gradient accumulation is a simple drop in replacement block of code, compared to PyTorch which forces the user to modify their training loop to directly manage the learning rate scheduler, gradient clipping and gradient accumulation. -#### Key Strengths +### Key Strengths * **Robust Library:** Provides a comprehensive library of losses, optimizers, and algorithms with a focus on correctness and readability. * **Modular Chainable Transformations:** As shown above, this flexible API allows users to craft powerful, complex optimization strategies declaratively, without modifying the training loop. @@ -88,15 +88,15 @@ As it can be seen in the example above, setting up an optimizer with a custom le (orbax:tensorstore)= -### Orbax / TensorStore \- Large scale distributed checkpointing +## Orbax / TensorStore \- Large scale distributed checkpointing [**Orbax**](https://orbax.readthedocs.io/en/latest/) is an any-scale checkpointing library for JAX users backed primarily by [**TensorStore**](https://google.github.io/tensorstore/), a library for efficiently reading and writing multi-dimensional arrays. The two libraries operate at different levels of the stack \- Orbax at the level of ML models and states \- TensorStore at the level of individual arrays. -#### Motivation +### Motivation [Orbax](https://orbax.readthedocs.io/en/latest/), which centers on JAX users and ML checkpointing, aims to reduce the fragmentation of checkpointing implementations across disparate research codebases, increase adoption of important performance features outside the most cutting-edge codebases, and provide a clean, flexible API for novice and advanced users alike. With advanced features like fully asynchronous distributed checkpointing, multi-tier checkpointing and emergency checkpointing, [Orbax](https://orbax.readthedocs.io/en/latest/) enables resilience in the largest of training jobs while also providing a flexible representation for publishing checkpoints. -#### ML Checkpointing vs Generalized Checkpoint/Restore +### ML Checkpointing vs Generalized Checkpoint/Restore It is worth considering the difference between ML checkpoint systems ([Orbax](https://orbax.readthedocs.io/en/latest/), NeMO-Megatron, Torch Distributed Checkpoint) with generalized checkpoint systems like CRIU. @@ -104,7 +104,7 @@ Systems like CRIU & CRIUgpu behave analogously to VM live migration; they halt t ML checkpoint systems are designed to minimize the amount of time the accelerator is halted by selectively persisting information that cannot be reconstructed. Specifically, this entails persisting model weights, optimizer state, dataloader state and random number generator state, which is a far smaller amount of data. -#### Design +### Design The [Orbax API](https://orbax.readthedocs.io/en/latest/index.html) centers around handling [PyTrees](https://docs.jax.dev/en/latest/pytrees.html) (nested containers) of arrays as the standard representation of JAX models. Saving and loading can be synchronous or asynchronous, with saving consisting of blocking and non-blocking phases. A higher-level `Checkpointer` class is provided, which facilitates checkpointing in a training loop, with save intervals, garbage collection, dataset checkpointing, and metadata management. 
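As an illustrative sketch of the PyTree-centric API described above, using the higher-level `CheckpointManager` and the `args`-style save/restore calls from recent Orbax releases (the directory, save interval, and state contents here are arbitrary):

```python
import jax.numpy as jnp
import orbax.checkpoint as ocp

# The training state is just a PyTree of arrays.
state = {"params": {"w": jnp.ones((4, 4))}, "step": jnp.array(0)}

options = ocp.CheckpointManagerOptions(max_to_keep=3, save_interval_steps=100)
mngr = ocp.CheckpointManager("/tmp/ckpts", options=options)

mngr.save(0, args=ocp.args.StandardSave(state))  # returns quickly; I/O continues in the background
mngr.wait_until_finished()                       # block only when a consistent snapshot is required
restored = mngr.restore(0, args=ocp.args.StandardRestore(state))
mngr.close()
```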
Finally, Orbax provides customization layers for dealing with user-defined checkpointable objects and PyTree leaves. @@ -125,7 +125,7 @@ Specific industry-leading performance features have their own design challenges, * [**Restore \+ broadcast**](https://cloud.google.com/blog/products/compute/unlock-faster-workload-start-time-using-orbax-on-jax): Hero-scale training runs replicate the model weights among multiple data-parallel replicas. Orbax provides a load balancing feature that distributes the burden evenly among available replicas when saving. It also leverages fast chip interconnects to avoid redundant reads of the model on different groups of hosts, instead loading on a single primary replica and broadcasting the weights to all other replicas. * **Emergency checkpointing**: Hero-scale training suffers from frequent interruptions and hardware failures. Checkpointing to persistent RAM disk improves goodput for hero-scale jobs by allowing for increased checkpoint frequency, faster restore times, and improved resiliency, since TPU states may be corrupted on some replicas, but not all. -#### Key Strengths +### Key Strengths * **Widespread adoption:** As checkpoints are a medium for communication of ML artifacts between different codebases and stages of ML development, widespread adoption is an inherent advantage. Currently, Orbax has [\~4 million](https://pypistats.org/packages/orbax-checkpoint) monthly package downloads. * **Easy to use:** Orbax abstracts away complex technical aspects of checkpointing like async saving, single- vs. multi-controller, checkpoint atomicity, distributed filesystem details, TPU vs. GPU, etc. It condenses use cases into simple, but generalizable APIs (direct-to-path, sequence-of-steps). @@ -133,15 +133,15 @@ Specific industry-leading performance features have their own design challenges, * **Performant and scalable:** Orbax provides a variety of features designed to make checkpointing as fast and as unobtrusive as possible, freeing developers to focus on efficiency in the remainder of the training loop. Scalability to the cutting edge of ML research is a top concern of the library; training runs at a scale of O(10k) nodes currently rely on Orbax. -### Grain: Deterministic and Scalable Input Data Pipelines +## Grain: Deterministic and Scalable Input Data Pipelines [Grain](https://google-grain.readthedocs.io/en/latest/) is a Python library for reading and processing data for training and evaluating JAX models. It is flexible, fast and deterministic and supports advanced features like checkpointing which are essential to successfully training large workloads. It supports popular data formats and storage backends and also provides a flexible API to extend support to user specific formats and backends that are not natively supported. While [Grain](https://google-grain.readthedocs.io/en/latest/) is primarily designed to work with JAX, it is framework independent, does not require JAX to run and can be used with other frameworks as well. -#### Motivation +### Motivation Data pipelines form a critical part of the training infrastructure \- they need to be flexible so that common transformations can be expressed efficiently, and performant enough that they are able to keep the accelerators busy at all times. They also need to be able to accommodate multiple storage formats and backends. 
Due to their higher step times, training large models at scale pose unique additional requirements on the data pipeline beyond those that are required by regular training workloads, primarily focused around determinism and reproducibility[^5]. The [Grain](https://google-grain.readthedocs.io/en/latest/) library is designed with a flexible enough architecture to address all these needs. -#### Design +### Design At the highest level, there are two ways to structure an input pipeline, as a separate cluster of data workers or by co-locating the data workers on the hosts that drive the accelerators. [Grain](https://google-grain.readthedocs.io/en/latest/) chooses the latter for a variety of reasons. @@ -151,7 +151,7 @@ On the API front, with a pure python implementation that supports multiple proce Out of the box, [Grain](https://google-grain.readthedocs.io/en/latest/) supports efficient random access data formats like `ArrayRecord` and `Bagz` alongside other popular data formats such as Parquet and `TFDS`. [Grain](https://google-grain.readthedocs.io/en/latest/) includes support for reading from local file systems as well as reading from GCS by default. Along with supporting popular storage formats and backends, a clean abstraction to the storage layer allows users to easily add support for or wrap their existing data sources to be compatible with the [Grain](https://google-grain.readthedocs.io/en/latest/) library. -#### Key Strengths +### Key Strengths * **Deterministic data feeding:** Colocating the data worker with the accelerator and coupling it with a stable global shuffle and [checkpointable iterators](https://google-grain.readthedocs.io/en/latest/tutorials/data_loader_tutorial.html#checkpointing) allows the model state and data pipeline state to be checkpointed together in a consistent snapshot using [Orbax](https://orbax.readthedocs.io/en/latest/), enhancing the determinism of the training process. * **Flexible APIs to enable powerful data transformations:** A flexible pure Python [transformations](https://google-grain.readthedocs.io/en/latest/data_loader/transformations.html) API allows for extensive data transformations within the input processing pipeline. diff --git a/docs/source/ecosystem_overview/extended.md b/docs/source/ecosystem_overview/extended.md index b618bb6..3216e5d 100644 --- a/docs/source/ecosystem_overview/extended.md +++ b/docs/source/ecosystem_overview/extended.md @@ -1,18 +1,18 @@ -## The Extended JAX Ecosystem +# The Extended JAX Ecosystem Beyond the core stack, a rich ecosystem of specialized libraries provides the infrastructure, advanced tools, and application-layer solutions needed for end-to-end ML development. -### Foundational Infrastructure: Compilers and Runtimes +## Foundational Infrastructure: Compilers and Runtimes -#### XLA: The Hardware-Agnostic, Compiler-Centric Engine +### XLA: The Hardware-Agnostic, Compiler-Centric Engine -##### Motivation +#### Motivation XLA or Accelerated Linear Algebra is our domain specific compiler, which is well integrated into JAX and supports TPU, CPU and GPU hardware devices. From inception, XLA has been designed to be a hardware agnostic code generator targeting TPUs, GPUs, and CPUs. Our compiler-first design is a fundamental architectural choice that creates a durable advantage in a rapidly evolving research landscape. In contrast, the prevailing kernel-centric approach in other ecosystems relies on hand-optimized libraries for performance. 
While this is highly effective for stable, well-established model architectures, it creates a bottleneck for innovation. When new research introduces novel architectures, the ecosystem must wait for new kernels to be written and optimized. Our compiler-centric design, however, can often generalize to new patterns, providing a high-performance path for cutting-edge research from day one. -##### Design +#### Design XLA works by Just-In-Time (JIT) compiling the computation graphs that JAX generates during its tracing process (e.g., when a function is decorated with `@jax.jit`). @@ -28,70 +28,70 @@ For scaling, XLA's design is built around parallelism. It employs algorithms to For more complex parallelism patterns, **Multiple Program Multiple Data (MPMD)** is also possible, and libraries like `PartIR:MPMD` allow JAX users to provide MPMD annotations as well. -##### Key strengths +#### Key strengths * **Compilation**: just in time compilation of the computation graph allows for optimizations to memory layout, buffer allocation, and memory management. Alternatives such as kernel based methodologies put that burden on the user. In most cases, XLA can achieve excellent performance without compromising developer velocity. * **Parallelism:** XLA implements several forms of parallelism with SPMD, and this is exposed at the JAX level. This allows for users to express sharding strategies easily, allowing experimentation and scalability of models across thousands of chips. -#### Pathways: A Unified Runtime for Massive-Scale Distributed Computation +### Pathways: A Unified Runtime for Massive-Scale Distributed Computation [Pathways](https://docs.cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/pathways-intro) offers abstractions for distributed training and inference with built in fault tolerance and recovery, allowing ML researchers to code as if they are using a single, powerful machine. -##### Motivation +#### Motivation To be able to train and deploy large models, hundreds to thousands of chips are necessary. These chips are spread across numerous racks and host machines. A training job is a large-scale synchronous program that requires all of these chips, and their respective hosts to be working in tandem on XLA computations that have been parallelized (sharded). In the case of large language models, which may need more than tens of thousands of chips, this service must be capable of spanning multiple pods across a data center fabric in addition to using ICI and OCI fabrics within a pod. -##### Design +#### Design ML Pathways is the system we use for coordinating distributed computations across hosts and TPU chips. It is designed for scalability and efficiency across hundreds of thousands of accelerators. For large-scale training, it provides a single Python client for multi-slice/multi-pod jobs, [Megascale XLA](https://openxla.org/xprof/megascale_stats) integration, Compilation Service, and Remote Python. It also supports cross-slice parallelism and preemption tolerance, enabling automatic recovery from resource preemptions. Pathways incorporates optimized cross host collectives which enable XLA computation graphs to further extend beyond a single TPU pod. It expands XLA's support for data, model, and pipeline parallelism to work across TPU slice boundaries using DCN by means of integrating a distributed runtime that manages DCN communication with XLA communication primitives. 
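To ground this in user-facing code, the sketch below (shapes and the axis name are illustrative) shows how a JAX program declares a device mesh and sharding; XLA's GSPMD partitioner then inserts the required collectives, and under Pathways the same single-controller script simply sees more devices via `jax.devices()`:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D mesh over whatever devices are visible (CPU, GPU, or TPU).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Shard the batch dimension across the "data" axis of the mesh.
x = jnp.ones((8 * devices.size, 128))
x = jax.device_put(x, NamedSharding(mesh, P("data", None)))

@jax.jit
def step(x):
    return jnp.tanh(x) @ jnp.ones((128, 16))

y = step(x)  # runs as a single SPMD program across all devices in the mesh
```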
-##### Key strengths +#### Key strengths * The single-controller architecture, integrated with JAX, is a key abstraction. It allows researchers to explore various sharding and parallelism strategies for training and deployment while scaling to tens of thousands of chips with ease. * Scaling to tens of thousands of chips with ease, allowing exploration of various sharding and parallelism strategies during model research, training and deployment. -### Advanced Development: Performance, Data, and Efficiency +## Advanced Development: Performance, Data, and Efficiency -#### Pallas: Writing High-Performance Custom Kernels in JAX +### Pallas: Writing High-Performance Custom Kernels in JAX While JAX is compiler-first, there are situations where the user would like to exercise fine-grained control over the hardware to achieve maximum performance. Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU. It aims to provide precise control over the generated code, combined with the high-level ergonomics of JAX tracing and the `jax.numpy` API. Pallas exposes a grid-based parallelism model where a user-defined kernel function is launched across a multi-dimensional grid of parallel work-groups. It enables explicit management of the memory hierarchy by allowing the user to define how tensors are tiled and transferred between slower, larger memory (e.g., HBM) and faster, smaller on-chip memory (e.g., VMEM on TPU, Shared Memory on GPU), using index maps to associate grid locations with specific data blocks. Pallas can lower the same kernel definition to execute efficiently on both Google's TPUs and various GPUs by compiling kernels into an intermediate representation suitable for the target architecture – Mosaic for TPUs, or utilizing technologies like Triton for the GPU path. With Pallas, users can write high-performance kernels that specialize blocks like attention to achieve the best model performance on the target hardware without needing to rely on vendor-specific toolkits. -#### Tokamax: A Curated Library of State-of-the-Art Kernels +### Tokamax: A Curated Library of State-of-the-Art Kernels If Pallas is the *tool* for authoring kernels, [Tokamax](https://github.com/openxla/tokamax) is a *library* of state-of-the-art custom accelerator kernels supporting both TPUs and GPUs, built on top of JAX and Pallas, enabling users to push their hardware to the maximum. It also provides tooling for users to build and autotune their own custom kernels. -##### Motivation +#### Motivation JAX, with its roots in XLA, is a compiler-first framework; however, a narrow set of cases exists where the user needs to take direct control of the hardware to achieve maximum performance[^7]. Custom kernels are critical to squeezing out every last ounce of performance from expensive ML accelerator resources such as TPUs and GPUs. While they are widely employed to enable performant execution of key operators such as Attention, implementing them requires a deep understanding of both the model and the target hardware (micro)architecture. Tokamax provides one authoritative source of curated, well-tested, high-performance kernels, in conjunction with robust shared infrastructure for their development, maintenance, and lifecycle management. Such a library can also act as a reference implementation for users to build on and customize as necessary. This allows users to focus on their modeling efforts without needing to worry about infrastructure.
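To illustrate the Pallas programming model described above (the grid, blocks, and index maps), here is a minimal, hedged sketch of an element-wise kernel; the block size is arbitrary, `interpret=True` lets it run without an accelerator, and the `BlockSpec` calling convention assumes a recent JAX release:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_kernel(x_ref, y_ref, o_ref):
    # Each grid step sees one (128,)-sized block staged into fast on-chip memory.
    o_ref[...] = x_ref[...] + y_ref[...]


@jax.jit
def add(x, y):
    block = pl.BlockSpec((128,), lambda i: (i,))  # grid index i -> i-th block of the array
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(x.shape[0] // 128,),
        in_specs=[block, block],
        out_specs=block,
        interpret=True,  # remove to compile via Mosaic (TPU) or Triton (GPU)
    )(x, y)


x = jnp.arange(1024, dtype=jnp.float32)
print(add(x, x)[:4])
```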
-##### Design +#### Design For any given kernel, Tokamax provides a common API that may be backed by multiple implementations. For example, TPU kernels may be implemented either by standard XLA lowering, or explicitly via Pallas/Mosaic-TPU. GPU kernels may be implemented by standard XLA lowering, via Mosaic-GPU, or Triton. By default, it picks the best-known implementation for a given configuration, determined by cached results from periodic autotuning and benchmarking runs, though users may choose specific implementations if desired. New implementations may be added over time to better exploit specific features in new hardware generations for even better performance. A key component of the library, beyond the kernels themselves, is the supporting infrastructure that will help power users choosing to write their own custom kernels. For example, the autotuning infrastructure lets the user define a set of configurable parameters (e.g., tile sizes) that Tokamax can perform an exhaustive sweep on, to determine and cache the best possible tuned settings. Nightly regressions protect users from unexpected performance and numerics issues caused by changes to underlying compiler infrastructure or other dependencies. -##### Key Strengths +#### Key Strengths * **Seamless developer experience:** A unified, curated, library will provide known-good high-performance implementations of key kernels, with clear expressions of supported hardware generations and expected performance, both programmatically and in documentation. This minimizes fragmentation and churn. * **Flexibility and lifecycle management:** Users may choose different implementations as desired, even changing them over time if appropriate. For example, if the XLA compiler enhances support for certain operations obviating the need for custom kernels, there is a simple path to deprecation and migration. * **Extensibility:** Users may implement their own kernels, while leveraging well-supported shared infrastructure, allowing them to focus on their value added capabilities and optimizations. Clearly authored standard implementations serve as a starting point for users to learn from and extend. -#### Qwix: Non-Intrusive, Comprehensive Quantization +### Qwix: Non-Intrusive, Comprehensive Quantization Qwix is a comprehensive quantization library for the JAX ecosystem, supporting both LLMs and other model types across all stages, including training (QAT, QT, QLoRA) and inference (PTQ), targeting both XLA and on-device runtimes. -##### Motivation +#### Motivation Existing quantization libraries, particularly in the PyTorch ecosystem, often serve limited purposes (e.g., only PTQ or only QLoRA). This fragmented landscape forces users to switch tools, impeding consistent code usage and precise numerical matching between training and inference. Furthermore, many solutions require substantial model modifications, tightly coupling the model logic to the quantization logic. -##### Design +#### Design Qwix's design philosophy emphasizes a comprehensive solution and, critically, **non-intrusive model integration**. It is architected with a hierarchical, extensible design built on reusable functional APIs. @@ -117,7 +117,7 @@ rules = [ quantized_model = qwix.quantize_model(fp_model, qwix.PtqProvider(rules)) ``` -##### Key Strengths +#### Key Strengths * **Comprehensive Solution:** Qwix is broadly applicable across numerous quantization scenarios, ensuring consistent code usage between training and inference. 
* **Non-Intrusive Model Integration:** As the example shows, users can integrate models with a single line of code, without modification. This allows developers to easily sweep hyperparameters over many quantization schemes to find the best quality/performance tradeoff. @@ -125,39 +125,39 @@ quantized_model = qwix.quantize_model(fp_model, qwix.PtqProvider(rules)) * **Research Friendly:** Qwix's foundational APIs and extensible architecture empower researchers to explore new algorithms and facilitate straightforward comparisons with integrated benchmark and evaluation tools. -### The Application Layer: Training and Alignment +## The Application Layer: Training and Alignment (foundational-model-maxtext-and)= -#### Foundation Model Training: MaxText and MaxDiffusion +### Foundation Model Training: MaxText and MaxDiffusion [MaxText](https://maxtext.readthedocs.io/en/latest/) and [MaxDiffusion](https://github.com/AI-Hypercomputer/maxdiffusion) are Google’s flagship LLM and Diffusion model training frameworks, respectively. With a large selection of highly optimized implementations of popular open-weights models, these repositories serve a dual purpose: they function as both a ready-to-go model training codebase and as a reference that foundation model builders can use to build upon. -##### Motivation +#### Motivation There is rapid growth of interest across the industry in training GenAI models. The popularity of open models has accelerated this trend, providing users with proven architectures. To train and adapt these models, users require high performance, efficiency, scalability to extreme numbers of chips, and clear, understandable code. They need a framework that can adapt to new techniques and target both TPUs and GPUs. [MaxText](https://maxtext.readthedocs.io/en/latest/) and MaxDiffusion are comprehensive solutions designed to fulfill these needs. -##### Design +#### Design [MaxText](https://maxtext.readthedocs.io/en/latest/) and MaxDiffusion are foundation model codebases designed with readability and performance in mind. They are structured with well-tested, reusable components: model definitions that leverage custom kernels (like Tokamax) for maximum performance, a training harness for orchestration and monitoring, and a powerful config system that allows users to control details like sharding and quantization (via Qwix) through an intuitive interface. Advanced reliability features like multi-tier checkpointing are incorporated to ensure sustained goodput. They leverage the best-in-class JAX libraries—Qwix, [Tunix](https://tunix.readthedocs.io/en/latest/), [Orbax](https://orbax.readthedocs.io/en/latest/index.html), and [Optax](https://optax.readthedocs.io/en/latest/)—to deliver core capabilities. This allows them to provide robust, scalable infrastructure, reducing development overhead and allowing users to focus on the modeling task. For inference, the model code is shared to enable efficient and scalable serving. 
-##### Key Strengths +#### Key Strengths * **Performant by Design:** With training infrastructure set up for high "goodput" (useful throughput) and model implementations optimized for high MFU (Model Flops Utilization), [MaxText](#foundational-model-maxtext-and) and MaxDiffusion deliver high performance at scale out of the box * **Built for Scale:** Leveraging the power of the JAX AI stack (especially Pathways), these frameworks allow users to scale seamlessly from tens of chips to tens of thousands of chips * **Solid Base for Foundation Model Builders:** The high-quality, readable implementations serve as a solid starting point for builders to either use as an end-to-end solution or as a reference implementation for their own customizations -#### Post-Training and Alignment: The Tunix Framework +### Post-Training and Alignment: The Tunix Framework [Tunix](https://tunix.readthedocs.io/en/latest/) offers state-of-the-art open-source reinforcement learning (RL) algorithms, along with a robust framework and infrastructure, providing a streamlined path for users to experiment with LLM post-training techniques (including Supervised Fine-Tuning (SFT) and alignment) using JAX and TPUs. -##### Motivation +#### Motivation Post-training is the critical step in unlocking the true power of LLMs. The Reinforcement Learning stage is particularly crucial for developing alignment and reasoning capabilities. While fast-moving open-source development in this area has been prolific, it has been almost exclusively based on PyTorch and GPUs, leaving a fundamental gap for JAX and TPU solutions. [Tunix](https://tunix.readthedocs.io/en/latest/) (Tune-in-JAX) is a high-performance, JAX-native library designed to fill this gap. -##### Design +#### Design ![](../_static/images/Tunix_diagram.svg) @@ -165,32 +165,32 @@ From a framework perspective, [Tunix](https://tunix.readthedocs.io/en/latest/) e On the infrastructure side, [Tunix](https://tunix.readthedocs.io/en/latest/) has native integration with Pathways, enabling a single-controller architecture that makes multi-node RL training easily accessible. On the trainer side, [Tunix](https://tunix.readthedocs.io/en/latest/) natively supports parameter-efficient training (e.g., LoRA) and leverages JAX sharding and XLA (GSPMD) to generate a performant compute graph. It supports popular open-source models like Gemma and Llama out of the box. -##### Key Strengths +#### Key Strengths * **Simplicity:** It provides a high-level, client-like API that abstracts away the complexities of the underlying distributed infrastructure. * **Developer Efficiency:** Tunix accelerates the R\&D lifecycle with out-of-the-box algorithms and pre-built "recipes," enabling users to get a working model and iterate quickly. * **Performance and Scalability:** Tunix enables a highly efficient and horizontally scalable training infrastructure by leveraging Pathways as a single controller on the backend. -### The Application Layer: Production and Inference +## The Application Layer: Production and Inference A historical challenge for JAX adoption has been the path from research to production. The JAX AI stack now provides a mature, two-pronged production story that offers both ecosystem compatibility and native JAX performance. -#### High-Performance LLM Inference: The vLLM-TPU Solutions +### High-Performance LLM Inference: The vLLM-TPU Solutions vLLM-TPU is Google's high-performance inference stack designed to run PyTorch and JAX native Large Language Models (LLMs) efficiently on Cloud TPUs. 
It achieves this by natively integrating the popular open-source vLLM framework with Google's JAX and TPU ecosystem. -##### Motivation +#### Motivation The industry is rapidly evolving, with growing demand for seamless, high-performing, and easy-to-use inference solutions. Users often face significant challenges from complex and inconsistent tooling, subpar performance, and limited model compatibility. The vLLM-TPU stack addresses these issues by providing a unified, performant, and intuitive platform. -##### Design +#### Design This solution pragmatically extends the vLLM framework, rather than reinventing it. vLLM is a highly optimized open-source LLM serving engine known for its high throughput, achieved via key features like **`PagedAttention`** (which manages KV caches like virtual memory to minimize fragmentation) and **`Continuous Batching`** (which dynamically adds requests to the batch to improve utilization). vLLM-TPU builds on this foundation and develops core components for request handling, scheduling, and memory management. It introduces a **JAX-based backend** that acts as a bridge, translating vLLM's computational graph and memory operations into TPU-executable code. This backend handles device interactions, JAX model execution, and the specifics of managing the KV cache on TPU hardware. It incorporates TPU-specific optimizations, such as efficient attention mechanisms (e.g., leveraging JAX Pallas kernels for Ragged Paged Attention) and quantization, all tailored for the TPU architecture. -##### Key Strengths +#### Key Strengths * **Zero Onboarding/Offboarding Cost for Users:** Users can adopt this solution without significant friction. From a user-experience perspective, processing inference requests is identical to doing so on GPUs. The CLI commands to start the server, accept prompts, and return outputs are all shared. * **Fully Embrace the Ecosystem:** This approach utilizes and contributes to the vLLM interface and user experience, ensuring compatibility and ease of use. @@ -198,15 +198,15 @@ vLLM-TPU builds on this foundation and develops core components for request hand * **Cost Efficient (Best Perf/$):** Optimizes performance to provide the best performance-to-cost ratio for popular models. -#### JAX-Native Serving: Orbax Serialization and Neptune Serving Engine +### JAX-Native Serving: Orbax Serialization and Neptune Serving Engine For models other than LLMs, or for users desiring a fully JAX-native pipeline, the Orbax serialization library and Neptune serving engine (NSE) system provide an end-to-end, high-performance serving solution. -##### Motivation +#### Motivation Historically, JAX models often relied on a circuitous path to production, such as being wrapped in TensorFlow graphs and deployed using TensorFlow Serving. This approach introduced significant limitations and inefficiencies, forcing developers to engage with a separate ecosystem and slowing down iteration. A dedicated JAX-native serving system is crucial for sustainability, reduced complexity, and optimized performance. -##### Design +#### Design This solution consists of two core components, as illustrated in the diagram below. @@ -216,30 +216,30 @@ This solution consists of two core components, as illustrated in the diagram bel 1. **Orbax Serialization Library:** Provides user-friendly APIs for serializing JAX models into a new, robust Orbax serialization format. This format is optimized for production deployment.
Its core includes: (a) directly representing JAX model computations using **StableHLO**, allowing the computation graph to be represented natively, and (b) leveraging **TensorStore** for storing weights, enabling fast checkpoint loading for serving. 2. **Neptune Serving Engine (NSE):** This is the accompanying high-performance, flexible serving engine (typically deployed as a C++ binary) designed to natively run JAX models in the Orbax format. NSE offers production-essential capabilities, such as fast model loading, high-throughput concurrent serving with built-in batching, support for multiple model versions, and both single- and multi-host serving (leveraging PJRT and Pathways). -##### Key Strengths +#### Key Strengths * **JAX Native Serving:** The solution is built natively for JAX, eliminating inter-framework overhead in model serialization and serving. This ensures lightning-fast model loading and optimized execution across CPUs, GPUs, and TPUs. * **Effortless Production Deployment:** Serialized models provide a **hermetic deployment path** that is unaffected by drift in Python dependencies and enables runtime model integrity checks. This provides a seamless, intuitive path for JAX model productionization. * **Enhanced Developer Experience:** By eliminating the need for cumbersome framework wrapping, this solution significantly reduces dependencies and system complexity, speeding up iteration for JAX developers. -### System-Wide Analysis and Profiling +## System-Wide Analysis and Profiling -#### XProf: Deep, Hardware-Integrated Performance Profiling +### XProf: Deep, Hardware-Integrated Performance Profiling [XProf](https://openxla.org/xprof) is a profiling and performance analysis tool that provides in-depth visibility into various aspects of ML workload execution, enabling users to debug and optimize performance. It is deeply integrated into both the JAX and TPU ecosystems. -##### Motivation +#### Motivation On one hand, ML workloads are growing ever more complicated. On the other, there is an explosion of specialized hardware capabilities targeting these workloads. Matching the two effectively to ensure peak performance and efficiency is critical, given the enormous costs of ML infrastructure. This requires deep visibility into both the workload and the hardware, presented in a way that is easily consumable. XProf excels at this. -##### Design +#### Design XProf consists of two primary components: collection and analysis. 1. **Collection:** XProf captures information from various sources: annotations in the user’s JAX code, cost models for operations within the XLA compiler, and **purpose-built hardware profiling features within the TPU**. This collection can be triggered programmatically or on-demand, generating a comprehensive event artifact. 2. **Analysis:** XProf post-processes the collected data and creates a suite of powerful visualizations, accessed via a browser. -##### Key Strengths +#### Key Strengths The true power of XProf comes from its deep integration with the full stack, providing a breadth and depth of analysis that is a tangible benefit of the co-designed JAX/TPU ecosystem. 
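As a minimal sketch of the collection flow from the JAX side (the log directory and annotation name are arbitrary), a trace captured like this can then be opened in the XProf / TensorBoard profile UI:

```python
import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    return jnp.tanh(x @ x.T)


x = jnp.ones((2048, 2048))
step(x).block_until_ready()  # compile outside the trace so the profile shows steady-state execution

# User annotations, XLA op metadata and, on TPU, hardware-level counters are written
# to the log directory for later analysis.
with jax.profiler.trace("/tmp/xprof-demo"):
    with jax.profiler.TraceAnnotation("train_step"):
        step(x).block_until_ready()
```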
diff --git a/docs/source/ecosystem_overview/modular.md b/docs/source/ecosystem_overview/modular.md index 50e7597..586bbbc 100644 --- a/docs/source/ecosystem_overview/modular.md +++ b/docs/source/ecosystem_overview/modular.md @@ -1,4 +1,4 @@ -## A Modular, Compiler-First Architecture for Modern AI +# A Modular, Compiler-First Architecture for Modern AI The [JAX AI stack](https://jaxstack.ai/) extends the JAX numerical core with a collection of Google-backed composable libraries, evolving it into a robust, end-to-end, open-source platform for Machine Learning at extreme scales. As such, the JAX AI stack consists of a comprehensive and robust ecosystem that addresses the entire ML lifecycle: From c169b0fa579f0abc938cd30335d2f16f92e3e3fd Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 13 Nov 2025 16:04:07 +0100 Subject: [PATCH 04/14] cross-link the XLA section --- docs/source/ecosystem_overview/core.md | 2 +- docs/source/ecosystem_overview/extended.md | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source/ecosystem_overview/core.md b/docs/source/ecosystem_overview/core.md index 6a43938..da5320d 100644 --- a/docs/source/ecosystem_overview/core.md +++ b/docs/source/ecosystem_overview/core.md @@ -6,7 +6,7 @@ The core JAX AI Stack consists of five key libraries that provide the foundation [JAX](https://docs.jax.dev/en/latest/) is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale Machine Learning. With its functional programming model and friendly, NumPy-like API, JAX provides a solid foundation for higher-level libraries. -With its compiler-first design, JAX inherently promotes scalability by leveraging [XLA](https://openxla.org/xla) (see Section \<\*\*\*\*\*\*\>) for aggressive, whole-program analysis, optimization, and hardware targeting. The JAX emphasis on functional programming (i.e., pure functions) makes its core program transformations more tractable and, crucially, composable +With its compiler-first design, JAX inherently promotes scalability by leveraging [XLA](https://openxla.org/xla) (see the [XLA Section](#xla-section)) for aggressive, whole-program analysis, optimization, and hardware targeting. The JAX emphasis on functional programming (i.e., pure functions) makes its core program transformations more tractable and, crucially, composable These core transformations can be mixed and matched to achieve high performance and scaling of workloads across model size, cluster size, and hardware types: diff --git a/docs/source/ecosystem_overview/extended.md b/docs/source/ecosystem_overview/extended.md index 3216e5d..43a0877 100644 --- a/docs/source/ecosystem_overview/extended.md +++ b/docs/source/ecosystem_overview/extended.md @@ -4,11 +4,12 @@ Beyond the core stack, a rich ecosystem of specialized libraries provides the in ## Foundational Infrastructure: Compilers and Runtimes +(xla-section)= ### XLA: The Hardware-Agnostic, Compiler-Centric Engine #### Motivation -XLA or Accelerated Linear Algebra is our domain specific compiler, which is well integrated into JAX and supports TPU, CPU and GPU hardware devices. From inception, XLA has been designed to be a hardware agnostic code generator targeting TPUs, GPUs, and CPUs. +[XLA](https://openxla.org/xla) or Accelerated Linear Algebra is our domain specific compiler, which is well integrated into JAX and supports TPU, CPU and GPU hardware devices. 
From inception, XLA has been designed to be a hardware agnostic code generator targeting TPUs, GPUs, and CPUs. Our compiler-first design is a fundamental architectural choice that creates a durable advantage in a rapidly evolving research landscape. In contrast, the prevailing kernel-centric approach in other ecosystems relies on hand-optimized libraries for performance. While this is highly effective for stable, well-established model architectures, it creates a bottleneck for innovation. When new research introduces novel architectures, the ecosystem must wait for new kernels to be written and optimized. Our compiler-centric design, however, can often generalize to new patterns, providing a high-performance path for cutting-edge research from day one. From 55e2c0013656f670ba08e8215a87b0b4f2fd22b7 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 13 Nov 2025 20:20:06 +0100 Subject: [PATCH 05/14] sync the GDoc edits --- docs/source/ecosystem_overview/core.md | 69 +++++--------------------- 1 file changed, 13 insertions(+), 56 deletions(-) diff --git a/docs/source/ecosystem_overview/core.md b/docs/source/ecosystem_overview/core.md index da5320d..b49c37a 100644 --- a/docs/source/ecosystem_overview/core.md +++ b/docs/source/ecosystem_overview/core.md @@ -19,26 +19,16 @@ The seamless integration with XLA's GSPMD (General-purpose SPMD) model allows JA ## Flax: Flexible Neural Network Authoring and "Model Surgery" -[Flax](https://flax.readthedocs.io/en/latest/index.html) is a library designed to simplify the creation, debugging, and analysis of neural networks in JAX. While pure functional API provided by JAX can be used to fully specify and train a ML model, users coming from the PyTorch (or TensorFlow) ecosystem are more used to and comfortable with the object oriented approach of specifying models as a graph of `torch.nn.Modules`. The abstractions provided by [Flax](https://flax.readthedocs.io/en/stable/) allow users to think more in terms of layers rather than functions, making it more developer friendly to an audience who value ergonomics and experimentation ease. [Flax](https://flax.readthedocs.io/en/stable/) also enables config driven model construction systems, such as those present in [MaxText](https://maxtext.readthedocs.io/en/latest/) and AxLearn, which separate out model hyperparameters from layer definition code. +[Flax](https://flax.readthedocs.io/en/latest/index.html) simplifies the creation, debugging, and analysis of neural networks in JAX by providing an intuitive, object-oriented approach to model building. While JAX's functional API is powerful, Flax offers a more familiar layer-based abstraction for developers accustomed to frameworks like PyTorch, without any performance penalty. -With a simple Pythonic API, it allows developers to express models using regular Python objects, while retaining the power and performance of JAX. Flax's NNX API is an evolution of the Flax Linen interface, incorporating lessons learned to offer a more user-friendly interface that remains consistent with the core JAX APIs. Since Flax modules are fully backed by the core JAX APIs, there is no performance penalty associated with defining the model in [Flax](https://flax.readthedocs.io/en/stable/). +This design simplifies modern ML practices like "model surgery"—the process of modifying or combining trained model components. Techniques such as LoRA and quantization require easily manipulable model definitions, which Flax's NNX API provides through a simple, Pythonic interface. 
NNX encapsulates model state, reducing user cognitive load, and allows for programmatic traversal and modification of the model hierarchy. -### Motivation - -JAX’s pure functional API, while powerful, can be complex for new users since it requires all the program state to be explicitly managed by the user. This paradigm can be unfamiliar to developers used to other frameworks. Modern model architectures are often complex with individual portions of the model trained separately and merged to form the final model[^3], in a process commonly referred to as model surgery. Even with decoder-only LLMs which tend to have a straightforward architecture, post training techniques such as LoRA and quantization require the model definition to be easily manipulated allowing parts of the architecture to be modified or even replaced. +### Key Strengths: -The Flax NNX library, with its simple yet powerful Pythonic API enables this functionality in a way that is intuitive to the user, reducing the amount of cognitive overhead involved in authoring and training a model. - -### Design +* Intuitive Object-Oriented API: Simplifies model construction and enables advanced use cases like submodule replacement and partial initialization. -The [Flax](https://flax.readthedocs.io/en/stable/) NNX library introduces an object oriented model definition system that encapsulates the model and random number generator state internally, reducing the cognitive overhead of the user and provides a familiar experience for those accustomed to frameworks like PyTorch or TensorFlow. By making submodule definitions Pythonic and providing APIs to traverse the module hierarchy, it allows for the model definition to be easily editable programmatically for model introspection and surgery. - -The [Flax](https://flax.readthedocs.io/en/stable/) NNX APIs are designed to be consistent with the core JAX APIs to allow users to exploit the full expressibility and performance of JAX, with lifted transformations for common operations like sharding, jit and others. Models defined using the NNX APIs can also be adapted to work with functional training loops, allowing the user the flexibility they need while retaining an intuitive object oriented API. - -### Key Strengths +* Consistent with Core JAX: Flax provides lifted transformations that are fully compatible with JAX's functional paradigm, offering the full performance of JAX with enhanced developer friendliness. -* **Intuitive object oriented flexible APIs:** Layers are represented as pure Python objects with internal state management, simplifying model construction and training loops, while also advanced model surgery use cases through support for submodule replacement, partial initialization and model hierarchy traversal. -* **Consistent with Core JAX APIs:** Lifted transformations consistent with core JAX and fully compatible with functional JAX provide the full performance of JAX without sacrificing developer friendliness. (optax:composable)= @@ -90,47 +80,18 @@ As it can be seen in the example above, setting up an optimizer with a custom le (orbax:tensorstore)= ## Orbax / TensorStore \- Large scale distributed checkpointing -[**Orbax**](https://orbax.readthedocs.io/en/latest/) is an any-scale checkpointing library for JAX users backed primarily by [**TensorStore**](https://google.github.io/tensorstore/), a library for efficiently reading and writing multi-dimensional arrays. 
The two libraries operate at different levels of the stack \- Orbax at the level of ML models and states \- TensorStore at the level of individual arrays. +[Orbax](https://orbax.readthedocs.io/en/latest/) is a checkpointing library for JAX designed for any scale, from single-device to large-scale distributed training. It aims to unify fragmented checkpointing implementations and deliver critical performance features, such as asynchronous and multi-tier checkpointing, to a wider audience. Orbax enables the resilience required for massive training jobs and provides a flexible format for publishing checkpoints. -### Motivation - -[Orbax](https://orbax.readthedocs.io/en/latest/), which centers on JAX users and ML checkpointing, aims to reduce the fragmentation of checkpointing implementations across disparate research codebases, increase adoption of important performance features outside the most cutting-edge codebases, and provide a clean, flexible API for novice and advanced users alike. With advanced features like fully asynchronous distributed checkpointing, multi-tier checkpointing and emergency checkpointing, [Orbax](https://orbax.readthedocs.io/en/latest/) enables resilience in the largest of training jobs while also providing a flexible representation for publishing checkpoints. - -### ML Checkpointing vs Generalized Checkpoint/Restore +Unlike generalized checkpoint/restore systems that snapshot the entire system state, ML checkpointing with Orbax selectively persists only the information essential for resuming training—model weights, optimizer state, and data loader state. This targeted approach minimizes accelerator downtime. Orbax achieves this by overlapping I/O operations with computation, a critical feature for large workloads. The time accelerators are halted is thus reduced to the duration of the device-to-host data transfer, which can be further overlapped with the next training step, making checkpointing nearly free from a performance perspective. +At its core, Orbax uses [TensorStore](https://google.github.io/tensorstore/) for efficient, parallel reading and writing of array data. The [Orbax API](https://orbax.readthedocs.io/en/latest/index.html) abstracts this complexity, offering a user-friendly interface for handling [PyTrees](https://docs.jax.dev/en/latest/pytrees.html), which are the standard representation of models in JAX. -It is worth considering the difference between ML checkpoint systems ([Orbax](https://orbax.readthedocs.io/en/latest/), NeMO-Megatron, Torch Distributed Checkpoint) with generalized checkpoint systems like CRIU. +### Key Strengths: -Systems like CRIU & CRIUgpu behave analogously to VM live migration; they halt the entire system and take a snapshot of every last bit of information so it can be faithfully reconstructed. This captures the entirety of the process’ host memory, device memory and operating system state. This is far more information that is actually needed to reconstruct a ML workload, since for a ML workload, a very large fraction of this information (activations, data examples, file handles) is trivially reconstructed. Capturing this much data also incurs a large amount of time when the job is halted. - -ML checkpoint systems are designed to minimize the amount of time the accelerator is halted by selectively persisting information that cannot be reconstructed. Specifically, this entails persisting model weights, optimizer state, dataloader state and random number generator state, which is a far smaller amount of data. 
- -### Design +* [Widespread Adoption](https://orbax.readthedocs.io/en/latest/guides/checkpoint/async_checkpointing.html): With millions of monthly downloads, Orbax serves as a common medium for sharing ML artifacts. +* Easy to Use: Orbax abstracts away the complexities of distributed checkpointing, including asynchronous saving, atomicity, and filesystem details. +* Flexible: While offering simple APIs for common use cases, Orbax allows for customization to handle specialized requirements. +* Performant and Scalable: Features like asynchronous checkpointing, an efficient storage format ([OCDBT](https://orbax.readthedocs.io/en/latest/guides/checkpoint/optimized_checkpointing.html)), and intelligent data loading strategies ensure that Orbax scales to training runs involving tens of thousands of nodes. -The [Orbax API](https://orbax.readthedocs.io/en/latest/index.html) centers around handling [PyTrees](https://docs.jax.dev/en/latest/pytrees.html) (nested containers) of arrays as the standard representation of JAX models. Saving and loading can be synchronous or asynchronous, with saving consisting of blocking and non-blocking phases. A higher-level `Checkpointer` class is provided, which facilitates checkpointing in a training loop, with save intervals, garbage collection, dataset checkpointing, and metadata management. Finally, Orbax provides customization layers for dealing with user-defined checkpointable objects and PyTree leaves. - -The storage layer of [Orbax](https://orbax.readthedocs.io/en/latest/index.html) is the [TensorStore](https://google.github.io/tensorstore/) library, which is not technically part of the JAX ecosystem at all, and seeks to provide a flexible and highly versatile library for array storage. However, it is not designed around ML concepts and introduces too much complexity and manual management for most JAX users. [Orbax](https://orbax.readthedocs.io/en/latest/index.html) smooths out this experience to provide users an easy to use ML specific API surface. - -To maximize the utilization of the accelerator, the checkpointing library must minimize the time it halts the training to snapshot the state. This is achieved by overlapping the checkpointing operations with the compute operations as shown in the diagram below. It’s worth noting that asynchronous checkpointing is table-stakes for large workloads and isn’t unique to [Orbax](https://orbax.readthedocs.io/en/latest/index.html). It is also present in other frameworks such as NeMO-Megatron and Torch Distributed Checkpoints. - -![](../_static/images/async_checkpointing.svg) - -When considering asynchronous checkpointing with non overlapped device-to-host transfers, the amount of time the accelerator is halted is thus a function of the number of model parameters, the size of the parameters and the PCI link speed. Enabling fully overlapped D2H can further reduce this time by overlapping the D2H transfer with the forward pass of the next step. As long as the D2H transfer can complete before the next forward step completes, the checkpoint will become effectively[^4] free. - -Restarting from an error is similarly bound by two factors, the XLA compilation time and the speed of reading the weights back from storage. XLA compilation caches can make the former insignificant. Reading from storage is hardware dependent \- emergency checkpoints save to ramdisks which are extremely fast, however there is a speed spectrum that ranges from ramdisk to SSD, HDD and GCS. 
- -Specific industry-leading performance features have their own design challenges, and merit separate attention: - -* [**Async checkpointing**](https://orbax.readthedocs.io/en/latest/guides/checkpoint/async_checkpointing.html): Checkpointing only needs to block accelerator computations while data is being transferred from host to/from accelerator memory. Expensive I/O operations can take place in a background thread meaning save time can be reduced by 95-99% relative to blocking saves. Asynchronous loading is also possible, and can save time on startup, but requires more extensive effort to integrate and has not yet seen widespread adoption. -* [**OCDBT format**](https://orbax.readthedocs.io/en/latest/guides/checkpoint/optimized_checkpointing.html): Most previous checkpointing implementations stored parameters as separate subdirectories, which caused significant overhead for small arrays. TensorStore’s OCDBT format uses an efficient [B+ tree](https://en.wikipedia.org/wiki/B%2B_tree) format, which allows fine-grained control over shard shapes and file sizes that can be tuned to different filesystems and models. The save/load strategy provides scalability to tens of thousands of nodes by ensuring each host independently reads and writes only the relevant pieces of each array. -* [**Restore \+ broadcast**](https://cloud.google.com/blog/products/compute/unlock-faster-workload-start-time-using-orbax-on-jax): Hero-scale training runs replicate the model weights among multiple data-parallel replicas. Orbax provides a load balancing feature that distributes the burden evenly among available replicas when saving. It also leverages fast chip interconnects to avoid redundant reads of the model on different groups of hosts, instead loading on a single primary replica and broadcasting the weights to all other replicas. -* **Emergency checkpointing**: Hero-scale training suffers from frequent interruptions and hardware failures. Checkpointing to persistent RAM disk improves goodput for hero-scale jobs by allowing for increased checkpoint frequency, faster restore times, and improved resiliency, since TPU states may be corrupted on some replicas, but not all. - -### Key Strengths - -* **Widespread adoption:** As checkpoints are a medium for communication of ML artifacts between different codebases and stages of ML development, widespread adoption is an inherent advantage. Currently, Orbax has [\~4 million](https://pypistats.org/packages/orbax-checkpoint) monthly package downloads. -* **Easy to use:** Orbax abstracts away complex technical aspects of checkpointing like async saving, single- vs. multi-controller, checkpoint atomicity, distributed filesystem details, TPU vs. GPU, etc. It condenses use cases into simple, but generalizable APIs (direct-to-path, sequence-of-steps). -* **Flexible:** While Orbax focuses on exposing a simple API surface for the majority of users, additional layers for handling custom checkpointable objects and PyTree nodes allow for flexibility in specialized use cases. -* **Performant and scalable:** Orbax provides a variety of features designed to make checkpointing as fast and as unobtrusive as possible, freeing developers to focus on efficiency in the remainder of the training loop. Scalability to the cutting edge of ML research is a top concern of the library; training runs at a scale of O(10k) nodes currently rely on Orbax. 
## Grain: Deterministic and Scalable Input Data Pipelines @@ -159,10 +120,6 @@ Out of the box, [Grain](https://google-grain.readthedocs.io/en/latest/) supports * **Powerful debugging interface:** Data pipeline [visualization tools](https://google-grain.readthedocs.io/en/latest/tutorials/dataset_debugging_tutorial.html) and a debug mode allow users to introspect, debug and optimize the performance of their data pipelines. -[^3]: Image diffusion models are a typical example of this and can commonly be divided logically into a separately trained prompt encoder and a diffusion backbone. - -[^4]: We say effectively free since there could be other bottlenecks such as the DMA engines, HBM bandwidth contention etc. that still incur a performance penalty. - [^5]: In the Section 5.1 of the [Palm paper](https://dl.acm.org/doi/10.5555/3648699.3648939), the authors note that they observed very large loss spikes despite having gradient clipping enabled and the solution was to remove the offending data batches and restart training from a checkpoint before the loss spike. This is only possible with a fully deterministic and reproducible training setup. [^6]: This is indeed how multimodal data pipelines would need to operate \- image and audio tokenizers for example are models themselves which run in their own clusters on their own accelerators and the input pipelines would make RPCs out to convert data examples into streams of tokens. From efef3dbaad1b775fc31c26b0b39e4da27aa1090d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Melissa=20Weber=20Mendon=C3=A7a?= Date: Thu, 13 Nov 2025 16:38:00 -0300 Subject: [PATCH 06/14] Add link to TR in Learn navbar dropdown --- docs/source/_templates/navbar-top.html | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/_templates/navbar-top.html b/docs/source/_templates/navbar-top.html index 475d322..17eb6a8 100644 --- a/docs/source/_templates/navbar-top.html +++ b/docs/source/_templates/navbar-top.html @@ -7,6 +7,7 @@