These are my notes from learning JAX, including example code, important features, and mini projects.
My goal is to learn JAX for deep learning and use it to implement useful models, environments, and algorithms.
References: Basic JAX Tutorials:
- JAX: https://github.com/google/jax https://jax.readthedocs.io/en/latest/notebooks/quickstart.html
- Intro to JAX: Accelerating Machine Learning research: https://www.youtube.com/watch?v=WdTeDXsOSj4&ab_channel=TensorFlow
- https://github.com/gordicaleksa/get-started-with-JAX
DL Resources for JAX: