
Commit 6f93cc1

[Backport] Update Pallas user guide (#6965)
1 parent b94c014 commit 6f93cc1

File tree: 1 file changed, +57 -0 lines changed

docs/pallas.md

Lines changed: 57 additions & 0 deletions

# Custom Kernels via Pallas

With the rise of OpenAI [Triton](https://openai.com/research/triton), custom kernels have become increasingly popular in the GPU community, as shown by the introduction of [FlashAttention](https://github.com/Dao-AILab/flash-attention) and [PagedAttention](https://blog.vllm.ai/2023/06/20/vllm.html). To provide feature parity in the TPU world, Google has introduced [Pallas](http://go/jax-pallas) and [Mosaic](http://go/mosaic-tpu). For PyTorch/XLA to keep pushing TPU performance, we have to support custom kernels, and the best way to do that is through Pallas and Mosaic. The design doc is TBA.

Let's assume you have a Pallas kernel defined as follows:
```python3
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp

def add_vectors_kernel(x_ref, y_ref, o_ref):
  x, y = x_ref[...], y_ref[...]
  o_ref[...] = x + y

@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
  return pl.pallas_call(add_vectors_kernel,
                        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
                        )(x, y)
```
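
Before wiring the kernel into PyTorch/XLA, you can sanity-check it from pure JAX. This is a minimal sketch with made-up shapes; the `(8, 128)` shape is only chosen to match a single float32 TPU tile and is not required by the guide above:

```python3
# Hypothetical smoke test: call the Pallas kernel directly from JAX.
x = jnp.ones((8, 128), dtype=jnp.float32)
y = jnp.full((8, 128), 2.0, dtype=jnp.float32)
out = add_vectors(x, y)
print(out.shape, out.dtype)  # (8, 128) float32, every element equal to 3.0
```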

## Adapt the above kernel to be compatible with PyTorch/XLA

Example usage:
```python3
import torch
import torch_xla

q = torch.randn(3, 2, 128, 4).to("xla")
k = torch.randn(3, 2, 128, 4).to("xla")
v = torch.randn(3, 2, 128, 4).to("xla")

# Adapt any Pallas kernel
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
pt_kernel = make_kernel_from_pallas(add_vectors, lambda x, y: [(x.shape, x.dtype)])
output = pt_kernel(q, k)
```
For simple kernels, the adaptation is as simple as a one-liner. For more complicated kernels, you can refer to our Flash Attention implementation for details.
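
The second argument to `make_kernel_from_pallas` is a callback that maps the input tensors to the output `(shape, dtype)` metadata (here a single-element list, since `add_vectors` has one output). As a further illustration, below is a minimal, hypothetical sketch of wrapping a kernel whose output dtype differs from its inputs; the kernel itself and the use of `torch.float32` in the callback are assumptions for the example, not part of the documented API above:

```python3
# Hypothetical kernel that always produces float32 output, whatever the input dtype.
def add_fp32_kernel(x_ref, y_ref, o_ref):
  # Upcast before adding; the output ref is declared as float32 below.
  o_ref[...] = x_ref[...].astype(jnp.float32) + y_ref[...].astype(jnp.float32)

@jax.jit
def add_fp32(x: jax.Array, y: jax.Array) -> jax.Array:
  return pl.pallas_call(add_fp32_kernel,
                        out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32)
                        )(x, y)

# The callback sees the torch inputs, so the output metadata is expressed with
# torch dtypes: same shape as x, always float32.
pt_add_fp32 = make_kernel_from_pallas(
    add_fp32, lambda x, y: [(x.shape, torch.float32)])
```
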
## Use built-in kernels

Besides manually wrapping external Pallas kernels, PyTorch/XLA also provides a set of built-in kernels whose adaptations are already done for you.

Example usage:
```python3
# Use built-in kernels
from torch_xla.experimental.custom_kernel import flash_attention
output = flash_attention(q, k, v)
```

You can use them just like any other torch ops.
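
For example, one way to use `flash_attention` is as a drop-in replacement for the attention computation in an existing module. The sketch below is illustrative only: the module and shapes are hypothetical, and it assumes `flash_attention` accepts `(q, k, v)` laid out as `(batch, num_heads, seq_len, head_dim)`, matching the example above:

```python3
import torch
import torch.nn as nn
from torch_xla.experimental.custom_kernel import flash_attention

class PallasSelfAttention(nn.Module):
  """Hypothetical self-attention block that delegates to the built-in kernel."""

  def __init__(self, hidden_dim: int, num_heads: int):
    super().__init__()
    self.num_heads = num_heads
    self.head_dim = hidden_dim // num_heads
    self.qkv = nn.Linear(hidden_dim, 3 * hidden_dim)
    self.proj = nn.Linear(hidden_dim, hidden_dim)

  def forward(self, x):  # x: (batch, seq_len, hidden_dim)
    b, s, _ = x.shape
    q, k, v = self.qkv(x).chunk(3, dim=-1)
    # Reshape to (batch, num_heads, seq_len, head_dim) as in the example above.
    q, k, v = (t.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
               for t in (q, k, v))
    out = flash_attention(q, k, v)
    return self.proj(out.transpose(1, 2).reshape(b, s, -1))
```
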
## HuggingFace Llama 3 Example
We have a fork of HF Llama 3 to demonstrate a potential integration [here](https://github.com/pytorch-tpu/transformers/tree/alanwaketan/flash_attention).

## Dependencies
The Pallas integration depends on JAX to function. However, not every JAX version is compatible with your installed PyTorch/XLA. To install the proper JAX:
```bash
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
```
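
After installing, a quick way to confirm the JAX side is healthy is to import it and list the devices it can see; this smoke test is an assumption of mine, not part of the guide:

```python3
# Hypothetical smoke test: verify that JAX imports and can see the TPU devices.
import jax

print(jax.__version__)
print(jax.devices())  # On a TPU VM this should list TpuDevice entries.
```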
