Description
I would like to contribute a new example that demonstrates how to implement Low-Rank Adaptation (LoRA) for fine-tuning language models using JAX and Flax.
Why this might be useful
LoRA is one of the most popular parameter-efficient fine-tuning techniques, and efficiency is a large part of why people reach for JAX in the first place. Yet implementing LoRA in JAX isn't straightforward today: someone new to the ecosystem has to piece it together from scattered documentation and GitHub issues.
Implementation
The example will build on the JAX for LLM pretraining tutorial and will compare a model trained end-to-end as in that tutorial against one fine-tuned with LoRA.
I have a draft implementation ready to submit.
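To make the scope concrete, here is a minimal sketch of the core idea in Flax Linen (the layer name `LoRADense`, the parameter names `lora_a`/`lora_b`, and the rank/alpha defaults are illustrative choices for this issue, not necessarily what the draft uses): the pretrained kernel stays frozen, and a trainable low-rank update scaled by `alpha / rank` is added on top.

```python
from flax import linen as nn


class LoRADense(nn.Module):
    """Dense layer with a low-rank update: y = x @ W + (alpha / r) * x @ A @ B."""
    features: int
    rank: int = 8
    alpha: float = 16.0

    @nn.compact
    def __call__(self, x):
        # Frozen base weight; in a real fine-tune this is loaded from the
        # pretrained checkpoint and excluded from optimizer updates.
        kernel = self.param('kernel', nn.initializers.lecun_normal(),
                            (x.shape[-1], self.features))
        # Trainable low-rank factors. B starts at zero so the LoRA branch
        # contributes nothing at initialization and the adapted model is
        # exactly the base model before training.
        lora_a = self.param('lora_a', nn.initializers.normal(stddev=0.01),
                            (x.shape[-1], self.rank))
        lora_b = self.param('lora_b', nn.initializers.zeros,
                            (self.rank, self.features))
        return x @ kernel + (self.alpha / self.rank) * (x @ lora_a @ lora_b)
```

Freezing the base weights is the part that tends to trip people up; one way to handle it is an optax label mask that applies the optimizer only to parameters whose pytree path contains "lora", for example:

```python
import jax
import optax

# Label every parameter as "lora" or "frozen" based on its pytree path,
# then update only the LoRA factors and zero out all other gradients.
def label_fn(params):
    return jax.tree_util.tree_map_with_path(
        lambda path, _: 'lora' if 'lora' in jax.tree_util.keystr(path)
        else 'frozen',
        params)

tx = optax.multi_transform(
    {'lora': optax.adamw(1e-4), 'frozen': optax.set_to_zero()},
    label_fn)
```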