MNIST from scratch in HIP

Digit classification on the MNIST dataset is widely used as a 'Hello World' problem in machine learning. In the age of capable ML frameworks and abstractions, it is easier than ever to produce a working MNIST classifier, but those abstractions, however powerful, can obscure the underlying mechanisms.

This small project aims to restore some of that visibility by implementing a simple MNIST linear classification model from scratch in HIP, tuned for the Radeon Pro VII (gfx906).

Figure: MNIST test dataset visualization and predictions (300 training iterations, learning rate 0.01).

Requirements

A ROCm/HIP toolchain and a supported AMD GPU; the kernels are tuned for the Radeon Pro VII (gfx906), and the results below were obtained with ROCm 6.4.3.

Build and run

make
./mnist ./dataset [iterations] [learning rate]

The optional arguments default to 100 iterations and a learning rate of 0.01.
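
For example, the results reported below correspond to 300 training iterations at a learning rate of 0.01:

./mnist ./dataset 300 0.01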

Model

Let $X^{jk}$ denote images (training or test), where $j$ indexes the image, and $k$ indexes pixels (plus bias).

Let $W^{ik}$ denote model weights, where $i$ indexes digit class (0–9).

Repeated indices are implicitly summed (Einstein convention), and unrepeated indices are free labels of matrix components. All indices are written up, so up/down placement carries no geometric significance.

Class scores are given by:

$$ \hat{Y}^{ij} = W^{ik} X^{jk}, $$

and loss by the average squared residual:

$$ L := \frac{1}{N}\lVert\hat{Y}^{ij} - Y_{\mathrm{true}}^{ij}\rVert_{\mathrm{F}}^{2}, $$

where $N$ is the number of images. Training proceeds by iterative update of the weights via gradient descent:

$$ \begin{aligned} W^{ik} &\leftarrow W^{ik} - \Delta \frac{\partial L }{\partial W^{ik}}\\ &= W^{ik} - \frac{2 \Delta}{N} ( \hat{Y}^{ij} - Y_{\mathrm{true}}^{ij}) X^{jk}, \end{aligned} $$

where $\Delta$ is the learning rate.
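
As a concrete (if unoptimized) reading of this update rule, the sketch below assigns one thread per weight component $W^{ik}$. The kernel and parameter names are assumptions for illustration, not taken from the repository, and R is assumed to already hold the scaled residual $\frac{2}{N}(\hat{Y} - Y_{\mathrm{true}})$.

```cpp
#include <hip/hip_runtime.h>

// Illustrative gradient-descent update, one thread per weight W[i][k]:
//   W[i][k] -= lr * sum_j R[i][j] * X[j][k],
// where R is assumed to already hold (2/N) * (Yhat - Ytrue).
// Launch with a 2D grid covering (num_pixels, num_classes).
__global__ void update_weights(float* W, const float* R, const float* X,
                               int num_classes, int num_pixels,
                               int num_images, float lr) {
  int i = blockIdx.y * blockDim.y + threadIdx.y;  // digit class index
  int k = blockIdx.x * blockDim.x + threadIdx.x;  // pixel (+ bias) index
  if (i >= num_classes || k >= num_pixels) return;

  float grad = 0.f;
  for (int j = 0; j < num_images; ++j) {
    grad += R[i * num_images + j] * X[j * num_pixels + k];
  }
  W[i * num_pixels + k] -= lr * grad;  // lr plays the role of Delta above
}
```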

Training loss and accuracy

On a Radeon Pro VII with ROCm 6.4.3, this implementation reached 84.8% training accuracy and 85.5% test accuracy after 300 iterations with a learning rate of 0.01, at 4.18 ms per iteration.

Figure: MNIST training loss over 300 iterations (learning rate 0.01).

Figure: MNIST training accuracy over 300 iterations (learning rate 0.01).

Optimizations

The majority of time in the training loop is spent in sgemm_sub_and_scale and sgemm, which roughly correspond to the forward pass and backpropagation, respectively. The primary mathematical operation in each is a matrix multiplication. Key performance considerations and optimizations incorporated into these functions are:

  • data layout: format matrices to be amenable to coalesced memory accesses (e.g. transpose matrices)
  • ping-pong buffering: overlap the load of the $(k+1)^{\mathrm{th}}$ tile with compute on the $k^{\mathrm{th}}$ tile
  • shared memory tiling: compute partial dot products in a way that maximizes data reuse (see the first sketch after this list)
  • hand-tuned tiling parameters
  • occupancy tuning via launch bounds, thread, and block parameters
  • loop unrolling (not always a win, but can help when it does not induce register spilling)
  • warp shuffle reduction: avoid contention on atomics (see the second sketch after this list)
  • kernel fusion: store $2(\hat{Y} - Y_{\mathrm{true}})$ during the forward pass, in preparation for backpropagation
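
To make the tiling and launch-bounds bullets concrete, here is a minimal shared-memory-tiled GEMM sketch in HIP. It is an illustration only, not the repository's tuned sgemm: the kernel name, the TILE value, and the plain row-major C = A·B layout are assumptions for the example; the real kernels choose layouts and tile sizes by hand, add ping-pong buffering, and the fused sgemm_sub_and_scale variant presumably applies the subtract-and-scale epilogue at the point where the accumulator is written back.

```cpp
#include <hip/hip_runtime.h>

#define TILE 16  // illustrative tile size; the real kernels use hand-tuned values

// Tiled C = A * B, with A (M x K), B (K x N), C (M x N), all row-major.
// Each block computes one TILE x TILE tile of C, staging tiles of A and B in
// LDS so every element fetched from global memory is reused TILE times; both
// global loads below are coalesced across threadIdx.x.
// Launch with dim3 block(TILE, TILE), dim3 grid(ceil(N/TILE), ceil(M/TILE)).
__global__ void __launch_bounds__(TILE * TILE)
sgemm_tiled(const float* A, const float* B, float* C, int M, int N, int K) {
  __shared__ float As[TILE][TILE];
  __shared__ float Bs[TILE][TILE];

  const int row = blockIdx.y * TILE + threadIdx.y;  // row of C
  const int col = blockIdx.x * TILE + threadIdx.x;  // column of C
  float acc = 0.f;

  for (int t = 0; t < (K + TILE - 1) / TILE; ++t) {
    const int k = t * TILE;
    // Stage one tile of A and one tile of B (zero-padded at the edges).
    As[threadIdx.y][threadIdx.x] =
        (row < M && k + threadIdx.x < K) ? A[row * K + k + threadIdx.x] : 0.f;
    Bs[threadIdx.y][threadIdx.x] =
        (k + threadIdx.y < K && col < N) ? B[(k + threadIdx.y) * N + col] : 0.f;
    __syncthreads();

    // Partial dot product over this tile; unrolled since TILE is a constant.
#pragma unroll
    for (int kk = 0; kk < TILE; ++kk) {
      acc += As[threadIdx.y][kk] * Bs[kk][threadIdx.x];
    }
    __syncthreads();
  }

  if (row < M && col < N) C[row * N + col] = acc;
}
```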
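
The warp-shuffle bullet refers to combining per-lane partial sums through registers rather than shared memory or per-thread atomics. A minimal sketch of the pattern follows; the kernel and function names are again illustrative, not the repository's.

```cpp
#include <hip/hip_runtime.h>

// Reduce one partial sum per lane down to a single value per wavefront
// (64 lanes on gfx906) using register-to-register shuffles, so the
// intra-wavefront step needs no shared memory or atomics.
__device__ float wave_reduce_sum(float v) {
  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
    v += __shfl_down(v, offset);
  }
  return v;  // lane 0 holds the wavefront's sum
}

// Example use: sum of `in` into *out (assumed zero-initialized), with one
// atomicAdd per wavefront instead of one per thread.
__global__ void reduce_sum(const float* in, float* out, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  float v = (idx < n) ? in[idx] : 0.f;
  v = wave_reduce_sum(v);
  if ((threadIdx.x % warpSize) == 0) {
    atomicAdd(out, v);  // far fewer atomics than a per-thread atomicAdd
  }
}
```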

License

This project is licensed under the MIT License.
