This repository contains simple and educational implementations of deep learning models using JAX.

- notes/: Contains notes.py, which includes helpful notes and code snippets related to JAX and deep learning concepts.
- basic_transformer/: A basic implementation of the Transformer architecture in JAX.
- main.py: A command-line interface (CLI) to run and experiment with the different model implementations.
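
The repository's own code is not reproduced here. As a rough illustration of the central computation a basic Transformer is built around, the sketch below shows single-head scaled dot-product attention in JAX with an optional causal mask. The function name `scaled_dot_product_attention`, the argument shapes, and the masking convention are assumptions made for this example. They are not the API of basic_transformer/.

```python
import jax
import jax.numpy as jnp


# Hypothetical helper for illustration only; not taken from basic_transformer/.
def scaled_dot_product_attention(q, k, v, mask=None):
    """Single-head attention: softmax(q @ k^T / sqrt(d_k)) @ v.

    q: (seq_q, d_k), k: (seq_k, d_k), v: (seq_k, d_v).
    mask (optional): boolean (seq_q, seq_k); False entries are blocked.
    """
    d_k = q.shape[-1]
    # Similarity score between every query and every key, scaled by sqrt(d_k).
    scores = q @ k.T / jnp.sqrt(d_k)
    if mask is not None:
        # Blocked positions get a large negative score, so softmax gives them ~0 weight.
        scores = jnp.where(mask, scores, -1e9)
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ v, weights


key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
seq_len, d_model = 4, 8
q = jax.random.normal(key_q, (seq_len, d_model))
k = jax.random.normal(key_k, (seq_len, d_model))
v = jax.random.normal(key_v, (seq_len, d_model))

# Causal mask: position i may only attend to positions <= i.
causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
out, weights = jax.jit(scaled_dot_product_attention)(q, k, v, causal)
print(out.shape)  # (4, 8)
```

Wrapping the function in `jax.jit` compiles it once for these input shapes. With the causal mask, the first query can only see the first key, so `out[0]` equals `v[0]`.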