Triton Puzzles Lite

Programming for accelerators such as GPUs is critical for modern AI systems. This often means programming directly in proprietary low-level languages such as CUDA. Triton is an alternative open-source language that allows you to code at a higher level and compile to accelerators like GPUs.


Coding in Triton is very similar to NumPy and PyTorch in both syntax and semantics. However, as a lower-level language, there are many details you need to keep track of. In particular, one area that learners have trouble with is memory loading and storing, which is critical for speed on low-level devices.

This set of puzzles is meant to teach you how to use Triton from first principles in an interactive fashion. You will start with trivial examples and build your way up to real algorithms like Flash Attention and quantized neural networks. These puzzles do not need to run on a GPU since they use a Triton interpreter.

Introduction

To begin with, we will only use tl.load and tl.store in order to build simple programs.

Here's an example of load. It takes an arange over the memory. By default the indexing follows that of torch tensors: columns, rows, depths, i.e. right-to-left. It also takes in a mask as the second argument. The mask is critically important because all shapes in Triton need to be powers of two.
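
To make this concrete, here is a minimal sketch of such a load (the sizes and names are illustrative, and interpreter mode is assumed so it runs on CPU): it reads a length-5 tensor through an arange of 8 offsets and masks off the extra slots.

```python
import os
os.environ["TRITON_INTERPRET"] = "1"  # run on CPU via the Triton interpreter; set before importing triton

import torch
import triton
import triton.language as tl

@triton.jit
def demo_load_kernel(x_ptr):
    offs = tl.arange(0, 8)                                # block shapes must be powers of two
    x = tl.load(x_ptr + offs, mask=offs < 5, other=0.0)   # mask off the 3 out-of-bounds slots
    print(x)                                              # plain print works under the interpreter

x = torch.arange(5, dtype=torch.float32)
demo_load_kernel[(1,)](x)                                 # a single block is enough here
```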

You can also use this trick to read in a 2d array.
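
A sketch of the 2D version, again with illustrative sizes: broadcasting a column of row offsets against a row of column offsets gives a 2D grid of addresses, and the mask is the conjunction of the two bounds.

```python
import os
os.environ["TRITON_INTERPRET"] = "1"
import torch
import triton
import triton.language as tl

@triton.jit
def demo_load2d_kernel(x_ptr):
    i_range = tl.arange(0, 8)[:, None]            # row indices, padded to a power of two
    j_range = tl.arange(0, 4)[None, :]            # column indices
    offs = i_range * 4 + j_range                  # row-major offsets for a width-4 tensor
    mask = (i_range < 4) & (j_range < 3)          # keep only the real 4 x 3 region
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    print(x)

demo_load2d_kernel[(1,)](torch.ones(4, 4))
```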

The tl.store function is quite similar. It allows you to write to a tensor.
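
A sketch of a masked store that mirrors the load above (names and sizes illustrative):

```python
import os
os.environ["TRITON_INTERPRET"] = "1"
import torch
import triton
import triton.language as tl

@triton.jit
def demo_store_kernel(z_ptr):
    offs = tl.arange(0, 8)
    z = tl.load(z_ptr + offs, mask=offs < 5, other=0.0)
    tl.store(z_ptr + offs, z + 10.0, mask=offs < 5)   # write back with the same mask

z = torch.zeros(5)
demo_store_kernel[(1,)](z)
print(z)   # every element is now 10
```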

You can only load relatively small blocks at a time in Triton. To work with larger tensors you need to use a program id axis to run multiple blocks in parallel. Here is an example with one program axis with 3 blocks. You can use the visualizer to scroll over it.
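
Here is a sketch of that pattern with illustrative sizes: each of the 3 blocks reads its own program id and covers its own slice of the tensor.

```python
import os
os.environ["TRITON_INTERPRET"] = "1"
import torch
import triton
import triton.language as tl

@triton.jit
def demo_blocks_kernel(x_ptr):
    pid = tl.program_id(0)                      # which block am I?
    offs = pid * 8 + tl.arange(0, 8)            # each block covers its own 8 elements
    x = tl.load(x_ptr + offs, mask=offs < 20, other=0.0)
    print("block", pid, x)                      # plain print works under the interpreter

demo_blocks_kernel[(3,)](torch.ones(20))        # 3 blocks along program axis 0
```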

See the Triton Docs for further information.

Puzzle 1: Constant Add

Add a constant to a vector. Uses one program id axis. Block size B0 is always the same as the length N0 of vector x.

$$z_i = 10 + x_i \text{ for } i = 1\ldots N_0$$
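
A minimal sketch of the kind of kernel this puzzle asks for, assuming pointer arguments x_ptr and z_ptr and a constexpr block size B0 (argument names are illustrative, not necessarily those of the repo's stub):

```python
import os
os.environ["TRITON_INTERPRET"] = "1"
import torch
import triton
import triton.language as tl

@triton.jit
def constant_add_kernel(x_ptr, z_ptr, N0, B0: tl.constexpr):
    offs = tl.arange(0, B0)            # B0 == N0, so one block covers the whole vector
    x = tl.load(x_ptr + offs)
    tl.store(z_ptr + offs, x + 10.0)

N0 = 32
x, z = torch.randn(N0), torch.empty(N0)
constant_add_kernel[(1,)](x, z, N0, B0=N0)
assert torch.allclose(z, x + 10.0)
```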


Puzzle 2: Constant Add Block

Add a constant to a vector. Uses one program block axis (no for loops yet). Block size B0 is now smaller than the length N0 of vector x.

$$z_i = 10 + x_i \text{ for } i = 1\ldots N_0$$
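
A sketch of the blocked variant (names illustrative): the program id picks which slice of the vector this block handles, and a mask guards the tail block.

```python
import os
os.environ["TRITON_INTERPRET"] = "1"
import torch
import triton
import triton.language as tl

@triton.jit
def constant_add_block_kernel(x_ptr, z_ptr, N0, B0: tl.constexpr):
    offs = tl.program_id(0) * B0 + tl.arange(0, B0)   # this block's slice of the vector
    mask = offs < N0                                  # guard the tail when N0 % B0 != 0
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(z_ptr + offs, x + 10.0, mask=mask)

N0, B0 = 200, 32
x, z = torch.randn(N0), torch.empty(N0)
constant_add_block_kernel[(triton.cdiv(N0, B0),)](x, z, N0, B0=B0)
assert torch.allclose(z, x + 10.0)
```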


Puzzle 3: Outer Vector Add

Add two vectors.

Uses one program block axis. Block size B0 is always the same as the length N0 of vector x. Block size B1 is always the same as the length N1 of vector y.

$$z_{j, i} = x_i + y_j\text{ for } i = 1\ldots B_0,\ j = 1\ldots B_1$$
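
A PyTorch sketch of what the output should contain (shapes illustrative). Note the result is indexed z[j, i], so y varies along rows; the same [:, None] / [None, :] broadcasting works on tl.arange offsets inside a kernel.

```python
import torch

def outer_add_spec(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # z[j, i] = x[i] + y[j]: broadcast x along columns and y along rows
    return x[None, :] + y[:, None]

x = torch.randn(32)                 # length N0
y = torch.randn(16)                 # length N1
print(outer_add_spec(x, y).shape)   # torch.Size([16, 32]): rows follow y, columns follow x
```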


Puzzle 4: Outer Vector Add Block

Add a row vector to a column vector.

Uses two program block axes. Block size B0 is always less than the length N0 of vector x. Block size B1 is always less than the length N1 of vector y.

$$z_{j, i} = x_i + y_j\text{ for } i = 1\ldots N_0,\ j = 1\ldots N_1$$
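
One possible sketch using two program id axes (names and sizes illustrative): each (pid 0, pid 1) pair owns a B1 by B0 tile of the output.

```python
import os
os.environ["TRITON_INTERPRET"] = "1"
import torch
import triton
import triton.language as tl

@triton.jit
def outer_add_block_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    off_i = tl.program_id(0) * B0 + tl.arange(0, B0)     # x index (columns of z)
    off_j = tl.program_id(1) * B1 + tl.arange(0, B1)     # y index (rows of z)
    mask_i = off_i < N0
    mask_j = off_j < N1
    x = tl.load(x_ptr + off_i, mask=mask_i)
    y = tl.load(y_ptr + off_j, mask=mask_j)
    z = x[None, :] + y[:, None]                          # (B1, B0) tile of the output
    off_z = off_j[:, None] * N0 + off_i[None, :]         # row-major offsets into an N1 x N0 output
    tl.store(z_ptr + off_z, z, mask=mask_j[:, None] & mask_i[None, :])

N0, N1, B0, B1 = 100, 90, 32, 32
x, y = torch.randn(N0), torch.randn(N1)
z = torch.empty(N1, N0)
grid = (triton.cdiv(N0, B0), triton.cdiv(N1, B1))
outer_add_block_kernel[grid](x, y, z, N0, N1, B0=B0, B1=B1)
assert torch.allclose(z, x[None, :] + y[:, None])
```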


Puzzle 5: Fused Outer Multiplication

Multiply a row vector by a column vector and take a relu.

Uses two program block axes. Block size B0 is always less than the length N0 of vector x. Block size B1 is always less than the length N1 of vector y.

$$z_{j, i} = \text{relu}(x_i \times y_j)\text{ for } i = 1\ldots N_0,\ j = 1\ldots N_1$$


Puzzle 6: Fused Outer Multiplication - Backwards

Backwards of a function that multiplies a matrix with a row vector and takes a relu.

Uses two program block axes. Block size B0 is always less than the length N0 of vector x. Block size B1 is always less than the length N1 of vector y. The chain-rule gradient dz has shape N1 by N0.

$$f(x, y) = \text{relu}(x_{j, i} \times y_j)\text{ for } i = 1\ldots N_0,\ j = 1\ldots N_1$$

$$dx_{j, i} = f'_x(x, y)_{j, i} \times dz_{j, i}$$
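
As a worked check of the chain rule: the derivative of relu(u) is 1 where u > 0 and 0 otherwise, so dx[j, i] = 1[x[j, i] * y[j] > 0] * y[j] * dz[j, i]. A small PyTorch comparison against autograd, with illustrative shapes:

```python
import torch

N0, N1 = 8, 4
x = torch.randn(N1, N0, requires_grad=True)   # matrix x[j, i]
y = torch.randn(N1)                           # row vector y[j]
dz = torch.randn(N1, N0)                      # upstream gradient of shape N1 by N0

f = torch.relu(x * y[:, None])                # f[j, i] = relu(x[j, i] * y[j])
f.backward(dz)                                # autograd's dx lands in x.grad

# Manual chain rule: dx[j, i] = 1[x[j, i] * y[j] > 0] * y[j] * dz[j, i]
dx_manual = torch.where(x.detach() * y[:, None] > 0, y[:, None], torch.tensor(0.0)) * dz
assert torch.allclose(x.grad, dx_manual)
```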


Puzzle 7: Long Sum

Sum of a batch of numbers.

Uses one program block axis. Block size B0 represents a range of the N0 batches of x. Each batch element is of length T. Process it B1 < T elements at a time.

$$z_{i} = \sum_{j=1}^{T} x_{i,j} \text{ for } i = 1\ldots N_0$$

Hint: You will need a for loop for this problem. These work and look the same as in Python.
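
A sketch of the loop-plus-accumulator pattern (argument names illustrative): keep a running sum per row and add one B1-wide chunk per loop iteration.

```python
import os
os.environ["TRITON_INTERPRET"] = "1"
import torch
import triton
import triton.language as tl

@triton.jit
def long_sum_kernel(x_ptr, z_ptr, N0, T, B0: tl.constexpr, B1: tl.constexpr):
    off_i = tl.program_id(0) * B0 + tl.arange(0, B0)   # which rows this block owns
    mask_i = off_i < N0
    acc = tl.zeros([B0], dtype=tl.float32)             # running sum for each row
    for start in range(0, T, B1):                      # walk the length-T rows B1 at a time
        off_j = start + tl.arange(0, B1)
        mask_j = off_j < T
        offs = off_i[:, None] * T + off_j[None, :]     # (B0, B1) tile of x
        x = tl.load(x_ptr + offs, mask=mask_i[:, None] & mask_j[None, :], other=0.0)
        acc += tl.sum(x, axis=1)
    tl.store(z_ptr + off_i, acc, mask=mask_i)

N0, T, B0, B1 = 4, 200, 4, 32
x, z = torch.randn(N0, T), torch.empty(N0)
long_sum_kernel[(triton.cdiv(N0, B0),)](x, z, N0, T, B0=B0, B1=B1)
assert torch.allclose(z, x.sum(1), atol=1e-4)
```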


Puzzle 8: Long Softmax

Softmax of a batch of logits.

Uses one program block axis. Block size B0 represents the batch of x out of N0. Each row of logits has length T. Process it B1 < T elements at a time.

$$z_{i, j} = \text{softmax}(x_{i,1} \ldots x_{i, T}) \text{ for } i = 1\ldots N_0$$

Note that softmax needs to be computed in numerically stable form, just as in regular Python. In addition, in Triton it is recommended to use exp2 instead of exp. You will need the identity

$$\exp(x) = 2^{\log_2(e) x}$$

Advanced: there is one way to do this with 3 loops. You can also do it with 2 loops if you are clever. Hint: you will find this identity useful:

$$\exp(x_i - m) = \exp(x_i - m/2 - m/2) = \exp(x_i - m/ 2) / \exp(m/2) $$
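
Both identities are easy to sanity-check numerically, for example in PyTorch:

```python
import math
import torch

x = torch.randn(16)
log2e = math.log2(math.e)

# exp(x) == 2 ** (log2(e) * x), so exp2 can stand in for exp inside the kernel
assert torch.allclose(torch.exp(x), torch.exp2(log2e * x))

# exp(x - m) == exp(x - m/2) / exp(m/2): partial sums computed against an old
# maximum can be rescaled when a new maximum shows up in a later block
m = x.max()
assert torch.allclose(torch.exp(x - m), torch.exp(x - m / 2) / torch.exp(m / 2))
```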


Puzzle 9: Simple FlashAttention

A scalar version of FlashAttention.

Uses zero programs. Block size B0 represents the batches of q to process out of N0. Sequence length is T. Process B1 < T elements of (k, v) at a time for some B1.

$$z_{i} = \sum_{j=1}^{T} \text{softmax}(q_i k_1, \ldots, q_i k_T)_j v_j \text{ for } i = 1\ldots N_0$$

This can be done in 1 loop using a trick similar to the last puzzle.

Hint: Use tl.where to mask q dot k to -inf to avoid overflow (NaN).
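
The heart of the trick is the online softmax update: keep a running maximum, a running denominator, and a running weighted sum of v, and rescale the latter two whenever the maximum changes. A plain-Python scalar sketch with made-up values:

```python
import math

# Toy scalar example: one query q against a stream of (k, v) pairs, processed
# one at a time while keeping the softmax numerically stable.
q = 0.5
ks = [1.0, -2.0, 3.0, 0.25]
vs = [0.1, 0.2, 0.3, 0.4]

m = float("-inf")   # running maximum of the scores seen so far
denom = 0.0         # running softmax denominator, scaled relative to m
acc = 0.0           # running weighted sum of v, scaled the same way

for k, v in zip(ks, vs):
    s = q * k
    m_new = max(m, s)
    rescale = math.exp(m - m_new)        # shrink old partial sums to the new maximum
    denom = denom * rescale + math.exp(s - m_new)
    acc = acc * rescale + math.exp(s - m_new) * v
    m = m_new

z = acc / denom

# Reference: the ordinary softmax-weighted sum computed in one shot.
scores = [q * k for k in ks]
mx = max(scores)
weights = [math.exp(s - mx) for s in scores]
ref = sum(w * v for w, v in zip(weights, vs)) / sum(weights)
assert abs(z - ref) < 1e-9
```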


Puzzle 10: Two Dimensional Convolution

A batched 2D convolution.

Uses one program id axis. Block size B0 represents the batches to process out of N0. Image x is of size H by W with only 1 channel, and kernel k is of size KH by KW.

$$ z_{i, j, l} = \sum_{oj, ol}^{j+oj\le H, l+ol\le W} k_{oj,ol} \times x_{i,j + oj, l + ol} \text{ for } i = 1\ldots N_0 \text{ for } j = 1\ldots H \text{ for } l = 1\ldots W $$
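
A naive PyTorch reference of this spec (shapes illustrative): the kernel is anchored at (j, l), and terms that would fall off the bottom or right edge are dropped, so the output has the same spatial size as x.

```python
import torch

def conv2d_spec(x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # x: (N0, H, W) single-channel images, k: (KH, KW) kernel
    N0, H, W = x.shape
    KH, KW = k.shape
    z = torch.zeros(N0, H, W)
    for j in range(H):
        for l in range(W):
            for oj in range(KH):
                for ol in range(KW):
                    if j + oj < H and l + ol < W:   # drop terms past the bottom/right edge
                        z[:, j, l] += k[oj, ol] * x[:, j + oj, l + ol]
    return z

x = torch.randn(2, 8, 8)
k = torch.randn(4, 4)
print(conv2d_spec(x, k).shape)   # torch.Size([2, 8, 8]), same spatial size as x
```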


Puzzle 11: Matrix Multiplication

A blocked matrix multiplication.

Uses three program id axes. Block size B2 represents the batches to process out of N2. Block size B0 represents the rows of x to process out of N0. Block size B1 represents the cols of y to process out of N1. The middle shape is MID.

$$z_{i, j, k} = \sum_{l} x_{i,j, l} \times y_{i, l, k} \text{ for } i = 1\ldots N_2, j = 1\ldots N_0, k = 1\ldots N_1$$

You are allowed to use tl.dot, which computes a smaller matrix multiplication.

Hint: the main trick is that you can split a matmul into smaller parts.

$$z_{i, j, k} = \sum_{l=1}^{L/2} x_{i,j, l} \times y_{i, l, k} + \sum_{l=L/2+1}^{L} x_{i,j, l} \times y_{i, l, k} $$
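
A quick PyTorch check of the split, with illustrative shapes: summing the two half-products over the middle dimension reproduces the full batched matmul. Inside the kernel this becomes a loop that accumulates tl.dot tiles into a tl.zeros accumulator.

```python
import torch

N2, N0, MID, N1 = 2, 16, 32, 16
x = torch.randn(N2, N0, MID)
y = torch.randn(N2, MID, N1)

full = x @ y                          # the batched matmul we want

# Split the reduction over the middle dimension into two halves and add them.
half = MID // 2
partial = x[:, :, :half] @ y[:, :half, :] + x[:, :, half:] @ y[:, half:, :]
assert torch.allclose(full, partial, atol=1e-5)
```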


Puzzle 12: Quantized Matrix Mult

When doing matrix multiplication with quantized neural networks a common strategy is to store the weight matrix in lower precision, with a shift and scale term.

For this problem our weights will be stored in 4 bits. We can pack FPINT of these into one 32-bit integer. In addition, for every group of weights, in order, we will store 1 scale float value and 1 shift 4-bit value. We store these per column of the weight matrix. The activations are stored separately in standard floats.

Mathematically it looks like this:

$$z_{j, k} = \sum_{l} sc_{j, \frac{l}{g}} (w_{j, l} - sh_{j, \frac{l}{g}}) \times y_{l, k} \text{ for } j = 1\ldots N_0, k = 1\ldots N_1$$

Where g is the number of groups (GROUP).

However, it is a bit more complex since we first need to extract the 4-bit values into floats.
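
Here is a PyTorch sketch of that unpacking step, assuming FPINT = 8 four-bit values per 32-bit integer and a lowest-bits-first nibble order (check the puzzle's spec for the actual packing convention; the scale and shift values are made up):

```python
import torch

FPINT = 8    # a 32-bit int packs 8 four-bit values

# One packed int32 holding the nibbles 0..7 (lowest bits first; check the
# puzzle's spec function for the actual packing convention).
packed = torch.tensor([0x76543210], dtype=torch.int32)

shifts = torch.arange(FPINT) * 4                      # 0, 4, 8, ..., 28
nibbles = (packed[:, None] >> shifts[None, :]) & 0xF  # -> tensor([[0, 1, ..., 7]])

# Dequantize one group with a made-up scale and (4-bit) shift value.
scale = torch.tensor([0.5])
shift = torch.tensor([3.0])
w = scale[:, None] * (nibbles.float() - shift[:, None])
print(w)   # the recovered floating-point weights for this group
```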

Note:

  • We don't consider batch size, i.e. i, in this puzzle.
  • Remember to unpack the FPINT values into separate 4-bit values. This involves some shape manipulation.
