# Custom Kernels via Pallas

With the rise of OpenAI [Triton](https://openai.com/research/triton), custom kernels have become more and more popular in the GPU community, as shown by the introduction of [FlashAttention](https://github.com/Dao-AILab/flash-attention) and [PagedAttention](https://blog.vllm.ai/2023/06/20/vllm.html). To provide feature parity in the TPU world, Google has introduced [Pallas](http://go/jax-pallas) and [Mosaic](http://go/mosaic-tpu). For PyTorch/XLA to keep pushing TPU performance, we have to support custom kernels, and the best way to do so is through Pallas and Mosaic. The design doc is [TBA]().

Let's assume you have a Pallas kernel defined as follows:
```python3
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp

def add_vectors_kernel(x_ref, y_ref, o_ref):
  # Read the input refs, add them, and write the sum to the output ref.
  x, y = x_ref[...], y_ref[...]
  o_ref[...] = x + y

@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
  return pl.pallas_call(add_vectors_kernel,
                        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
                        )(x, y)
```
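Before wrapping it for PyTorch/XLA, you can sanity-check the kernel directly from JAX. The snippet below is only an illustrative sketch with made-up inputs, following the kernel defined above:

```python3
# Illustrative only: run the Pallas kernel natively in JAX on a TPU host.
x = jnp.arange(8)
y = jnp.arange(8)
print(add_vectors(x, y))  # expected: [ 0  2  4  6  8 10 12 14]
```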

## Adapt the above kernel to be compatible with PyTorch/XLA

Example usage:
```python3
import torch
import torch_xla

q = torch.randn(3, 2, 128, 4).to("xla")
k = torch.randn(3, 2, 128, 4).to("xla")
v = torch.randn(3, 2, 128, 4).to("xla")  # unused here; reused by the flash_attention example below

# Adapt any Pallas kernel: the second argument maps the inputs to a list of output (shape, dtype) specs.
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
pt_kernel = make_kernel_from_pallas(add_vectors, lambda x, y: [(x.shape, x.dtype)])
output = pt_kernel(q, k)
```
For simple kernels, the adaptation is as simple as a one-liner. For more complicated kernels, you can refer to our Flash Attention implementation for details.
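The second argument to `make_kernel_from_pallas` exists because the output spec does not have to match the inputs. As a hedged illustration (the `add_vectors_fp32` kernel below is hypothetical, not part of PyTorch/XLA), here is how you might wrap a kernel that accumulates in `float32` even for lower-precision inputs, assuming the callback receives the input torch tensors and returns a list of `(shape, dtype)` pairs as in the one-liner above:

```python3
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
import torch
import torch_xla
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas

def add_fp32_kernel(x_ref, y_ref, o_ref):
  # Hypothetical kernel: upcast the operands and store a float32 sum.
  o_ref[...] = x_ref[...].astype(jnp.float32) + y_ref[...].astype(jnp.float32)

@jax.jit
def add_vectors_fp32(x: jax.Array, y: jax.Array) -> jax.Array:
  return pl.pallas_call(add_fp32_kernel,
                        out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32)
                        )(x, y)

# The callback declares a float32 output regardless of the input dtype.
pt_fp32_kernel = make_kernel_from_pallas(add_vectors_fp32,
                                         lambda x, y: [(x.shape, torch.float32)])
a = torch.randn(128, 128, dtype=torch.bfloat16).to("xla")
b = torch.randn(128, 128, dtype=torch.bfloat16).to("xla")
out = pt_fp32_kernel(a, b)  # out.dtype should be torch.float32
```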

## Use built-in kernels

Besides manually wrapping external Pallas kernels, PyTorch/XLA also provides built-in kernels whose adaptation is already done for you.

Example usage:
```python3
# Use built-in kernels; q, k, v are the XLA tensors defined above.
from torch_xla.experimental.custom_kernel import flash_attention
output = flash_attention(q, k, v)
```

You can use it just like any other torch op.
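For instance, here is a minimal sketch of routing a toy attention layer through the built-in kernel, assuming the same `(batch, num_heads, seq_len, head_dim)` layout as the `q`, `k`, `v` tensors above; the `ToyAttention` module and its sizes are hypothetical, not part of PyTorch/XLA:

```python3
import torch
import torch.nn as nn
import torch_xla
from torch_xla.experimental.custom_kernel import flash_attention

class ToyAttention(nn.Module):
  """Hypothetical module that routes its attention through the built-in kernel."""

  def __init__(self, hidden=8, heads=2):
    super().__init__()
    self.heads = heads
    self.head_dim = hidden // heads
    self.qkv = nn.Linear(hidden, 3 * hidden, bias=False)

  def forward(self, x):  # x: (batch, seq, hidden)
    b, s, _ = x.shape
    q, k, v = self.qkv(x).chunk(3, dim=-1)
    # Reshape to the (batch, heads, seq, head_dim) layout used above.
    q, k, v = (t.view(b, s, self.heads, self.head_dim).transpose(1, 2) for t in (q, k, v))
    out = flash_attention(q, k, v)
    return out.transpose(1, 2).reshape(b, s, -1)

model = ToyAttention().to("xla")
x = torch.randn(3, 128, 8).to("xla")
print(model(x).shape)  # torch.Size([3, 128, 8])
```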

## HuggingFace Llama 3 Example
We have a fork of HF Llama 3 that demonstrates a potential integration [here](https://github.com/pytorch-tpu/transformers/tree/alanwaketan/flash_attention).

## Dependencies
The Pallas integration depends on JAX. However, not every JAX version is compatible with your installed PyTorch/XLA. To install a compatible JAX:
```bash
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
```
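As a quick, purely illustrative check that the two stacks load together after installation:

```python3
# Illustrative sanity check: both libraries should import and report their versions.
import jax
import torch_xla

print("jax:", jax.__version__)
print("torch_xla:", torch_xla.__version__)
```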