Programming for accelerators such as GPUs is critical for modern AI systems. This often means programming directly in proprietary low-level languages such as CUDA. Triton is an alternative open-source language that lets you code at a higher level and compile to accelerators like GPUs.
Coding in Triton is very similar to NumPy and PyTorch in both syntax and semantics. However, as a lower-level language, it has a lot of details that you need to keep track of. In particular, one area that learners have trouble with is memory loading and storing, which is critical for speed on low-level devices.
This set of puzzles is meant to teach you how to use Triton from first principles in an interactive fashion. You will start with trivial examples and build your way up to real algorithms like Flash Attention and quantized neural networks. These puzzles do not need to run on a GPU, since they use a Triton interpreter.
To begin with, we will only use `tl.load` and `tl.store` in order to build simple programs.
Here's an example of a load. It takes an `arange` over the memory. By default, torch tensors are indexed right to left, i.e. columns, then rows, then depths. It also takes a mask as its second argument. The mask is critically important because all shapes in Triton need to be powers of two.
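For example, a load of 5 valid values padded out to an 8-wide block might look like this sketch (the kernel and pointer names are illustrative, not part of the puzzles):

```python
import triton
import triton.language as tl

@triton.jit
def demo_load(x_ptr, z_ptr):
    # Offsets must cover a power-of-two range; the mask keeps only the first 5.
    offs = tl.arange(0, 8)
    mask = offs < 5
    x = tl.load(x_ptr + offs, mask)      # masked-out positions are never read
    tl.store(z_ptr + offs, x, mask)      # copy the valid values back out
```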
You can also use this trick to read in a 2d array.
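A 2D version of the same masking trick might look roughly like this (again illustrative names; the memory is assumed row-major):

```python
import triton
import triton.language as tl

@triton.jit
def demo_load_2d(x_ptr, z_ptr):
    # A (4, 8) block of offsets; the right-most dimension indexes columns.
    i = tl.arange(0, 4)[:, None]     # rows
    j = tl.arange(0, 8)[None, :]     # columns
    offs = i * 8 + j                 # row-major flattening into memory
    mask = (i < 3) & (j < 5)         # only a 3 x 5 region is actually valid
    x = tl.load(x_ptr + offs, mask)
    tl.store(z_ptr + offs, x, mask)
```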
The `tl.store` function is quite similar. It allows you to write to a tensor.
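A minimal store sketch (names illustrative):

```python
import triton
import triton.language as tl

@triton.jit
def demo_store(z_ptr):
    # Write values into the first 5 slots of z; the block itself is 8 wide.
    offs = tl.arange(0, 8)
    mask = offs < 5
    tl.store(z_ptr + offs, offs.to(tl.float32), mask)
```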
You can only load relatively small blocks at a time in Triton. To work with larger tensors, you need to use a program id axis to run multiple blocks in parallel. Here is an example with one program axis and 3 blocks. You can use the visualizer to scroll over it.
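A sketch of this pattern, with a launch that uses 3 blocks of size 4 to cover 10 elements (the names and the grid are my own; under the Triton interpreter, e.g. `TRITON_INTERPRET=1`, this can run on CPU tensors):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def demo_blocks(x_ptr, z_ptr, N, B0: tl.constexpr):
    # Each program (block) handles its own B0 contiguous elements.
    pid = tl.program_id(0)
    offs = pid * B0 + tl.arange(0, B0)
    mask = offs < N
    x = tl.load(x_ptr + offs, mask)
    tl.store(z_ptr + offs, x, mask)

x = torch.arange(10, dtype=torch.float32)
z = torch.zeros_like(x)
demo_blocks[(3,)](x, z, 10, B0=4)    # 3 programs along axis 0
```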
See the Triton Docs for further information.
Add a constant to a vector. Uses one program id axis. Block size `B0` is always the same as vector `x` with length `N0`.
Add a constant to a vector. Uses one program block axis (no `for` loops yet). Block size `B0` is now smaller than the shape of vector `x`, which is `N0`.
Add two vectors. Uses one program block axis. Block size `B0` is always the same as vector `x` length `N0`. Block size `B1` is always the same as vector `y` length `N1`.
Add a row vector to a column vector. Uses two program block axes. Block size `B0` is always less than the vector `x` length `N0`. Block size `B1` is always less than the vector `y` length `N1`.
Multiply a row vector by a column vector and take a relu. Uses two program block axes. Block size `B0` is always less than the vector `x` length `N0`. Block size `B1` is always less than the vector `y` length `N1`.
Backwards of a function that multiplies a matrix with a row vector and takes a relu. Uses two program block axes. Block size `B0` is always less than the vector `x` length `N0`. Block size `B1` is always less than the vector `y` length `N1`. Chain rule backward `dz` is of shape `N1` by `N0`.

$$dx_{j, i} = f_x'(x, y)_{j, i} \times dz_{j, i}$$
Sum of a batch of numbers. Uses one program block axis. Block size `B0` represents a range of batches of `x` of length `N0`. Each element is of length `T`. Process it `B1 < T` elements at a time.

$$z_{i} = \sum_{j}^{T} x_{i,j} \text{ for } i = 1\ldots N_0$$

Hint: You will need a `for` loop for this problem. These work and look the same as in Python.
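Not a solution to the puzzle, but the general shape of such a loop, here summing a single length-`T` vector `B1` elements at a time (the kernel name and accumulator are my own):

```python
import triton
import triton.language as tl

@triton.jit
def long_sum_sketch(x_ptr, z_ptr, T, B1: tl.constexpr):
    # For-loops inside a Triton kernel look just like Python for-loops.
    acc = tl.zeros((B1,), dtype=tl.float32)
    for start in range(0, T, B1):
        offs = start + tl.arange(0, B1)
        mask = offs < T
        acc += tl.load(x_ptr + offs, mask, other=0.0)   # out-of-range reads count as 0
    tl.store(z_ptr, tl.sum(acc, axis=0))                # reduce the block at the end
```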
Softmax of a batch of logits. Uses one program block axis. Block size `B0` represents the batch of `x` of length `N0`. Block logit length `T`. Process it `B1 < T` elements at a time.

Note softmax needs to be computed in numerically stable form as in Python. In addition, in Triton it is recommended not to use `exp` but instead `exp2`. You will need the identity

$$\exp(x) = 2^{\log_2(e)\, x}$$

Advanced: there is a way to do this with 3 loops. You can also do it with 2 loops if you are clever. Hint: you will find the rescaling identity

$$\exp(x - m_{\text{new}}) = \exp(x - m_{\text{old}}) \exp(m_{\text{old}} - m_{\text{new}})$$

useful for correcting a running maximum after the fact.
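As a small illustration of the `exp2` identity (not the blocked solution to this puzzle), a numerically stable softmax over a single row that fits in one block could look like this; the kernel name is mine and `T` is assumed to be a power of two:

```python
import triton
import triton.language as tl

@triton.jit
def softmax_block_sketch(x_ptr, z_ptr, T: tl.constexpr):
    log2_e = 1.44269504                    # log2(e), so exp(x) = exp2(log2(e) * x)
    offs = tl.arange(0, T)
    x = tl.load(x_ptr + offs)
    m = tl.max(x, axis=0)                  # subtract the max for numerical stability
    num = tl.exp2(log2_e * (x - m))
    denom = tl.sum(num, axis=0)
    tl.store(z_ptr + offs, num / denom)
```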
A scalar version of FlashAttention. Uses zero programs. Block size `B0` represents the batches of `q` to process out of `N0`. Sequence length is `T`. Process it `B1 < T` elements (`k`, `v`) at a time for some `B1`.

$$z_{i} = \sum_{j=1}^{T} \text{softmax}(q_i k_1, \ldots, q_i k_T)_j v_{j} \text{ for } i = 1\ldots N_0$$

This can be done in 1 loop using a trick similar to the last puzzle.

Hint: Use `tl.where` to mask `q dot k` to `-inf` to avoid overflow (NaN).
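The masking hint in isolation looks roughly like the sketch below, which computes one block of scalar scores $q_i k_j$ and sets out-of-range positions to `-inf` (names are illustrative):

```python
import triton
import triton.language as tl

@triton.jit
def masked_scores_sketch(q_ptr, k_ptr, z_ptr, T, B1: tl.constexpr):
    offs = tl.arange(0, B1)
    mask = offs < T
    q = tl.load(q_ptr)                         # a single scalar query
    k = tl.load(k_ptr + offs, mask, other=0.0)
    qk = q * k                                 # scalar attention scores for this block
    qk = tl.where(mask, qk, float("-inf"))     # masked positions contribute exp(-inf) = 0
    tl.store(z_ptr + offs, qk, mask)
```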
A batched 2D convolution. Uses one program id axis. Block size `B0` represents the batches to process out of `N0`. Image `x` is of size `H` by `W` with only 1 channel, and kernel `k` is of size `KH` by `KW`.
A blocked matrix multiplication. Uses three program id axes. Block size `B2` represents the batches to process out of `N2`. Block size `B0` represents the rows of `x` to process out of `N0`. Block size `B1` represents the cols of `y` to process out of `N1`. The middle shape is `MID`.

You are allowed to use `tl.dot`, which computes a smaller matmul.

Hint: the main trick is that you can split a matmul into smaller parts.
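To illustrate just the hint (ignoring the batch axis of the actual puzzle), accumulating `tl.dot` products over chunks of the middle dimension looks roughly like this; the chunk size `B_MID` and kernel name are my own, and `tl.dot` typically wants block dimensions of at least 16 on real hardware:

```python
import triton
import triton.language as tl

@triton.jit
def matmul_chunk_sketch(x_ptr, y_ptr, z_ptr,
                        N0: tl.constexpr, N1: tl.constexpr, MID: tl.constexpr,
                        B0: tl.constexpr, B1: tl.constexpr, B_MID: tl.constexpr):
    # One (B0, B1) output tile, built by summing (B0, B_MID) @ (B_MID, B1) pieces.
    row = tl.program_id(0) * B0 + tl.arange(0, B0)
    col = tl.program_id(1) * B1 + tl.arange(0, B1)
    acc = tl.zeros((B0, B1), dtype=tl.float32)
    for mid in range(0, MID, B_MID):
        k = mid + tl.arange(0, B_MID)
        x = tl.load(x_ptr + row[:, None] * MID + k[None, :],
                    (row[:, None] < N0) & (k[None, :] < MID), other=0.0)
        y = tl.load(y_ptr + k[:, None] * N1 + col[None, :],
                    (k[:, None] < MID) & (col[None, :] < N1), other=0.0)
        acc += tl.dot(x, y)                    # each tl.dot is a smaller matmul
    tl.store(z_ptr + row[:, None] * N1 + col[None, :], acc,
             (row[:, None] < N0) & (col[None, :] < N1))
```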
When doing matrix multiplication with quantized neural networks, a common strategy is to store the weight matrix in lower precision, with a shift and scale term.

For this problem our `weight` will be stored in 4 bits. We can store `FPINT` of these in a 32-bit integer. In addition, for every `group` of weights in order we will store 1 `scale` float value and 1 `shift` 4-bit value. We store these for each column of `weight`. The `activation`s are stored separately in standard floats.

Mathematically it looks like:

$$z_{j} = \sum_{k} \text{scale}_{j, k/g} \cdot (\text{weight}_{j, k} - \text{shift}_{j, k/g}) \cdot \text{activation}_{k}$$

where `g` is the number of groups (`GROUP`).

However, it is a bit more complex, since we first need to extract the 4-bit values into floats.

Note:

- We don't consider batch size, i.e. `i`, in this puzzle.
- Remember to unpack the `FPINT` values into separate 4-bit values. This involves some shape manipulation, as in the sketch after this list.
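A hedged sketch of just the unpacking step (the nibble order, lowest bits first, and the kernel name are assumptions of mine):

```python
import triton
import triton.language as tl

@triton.jit
def unpack_sketch(packed_ptr, out_ptr, N: tl.constexpr, FPINT: tl.constexpr):
    # Each 32-bit integer holds FPINT (= 8) packed 4-bit values.
    offs = tl.arange(0, N)                        # N packed 32-bit words
    packed = tl.load(packed_ptr + offs)
    pos = tl.arange(0, FPINT)                     # nibble position inside each word
    vals = (packed[:, None] >> (4 * pos[None, :])) & 0xF   # shift, then mask off 4 bits
    # vals has shape (N, FPINT); store it flattened as N * FPINT small floats.
    out_offs = offs[:, None] * FPINT + pos[None, :]
    tl.store(out_ptr + out_offs, vals.to(tl.float32))
```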