Add LoRA Training Example to JAX Examples #186

Open
@nikolasavic3

Description

I would like to contribute a new example that demonstrates how to implement Low-Rank Adaptation (LoRA) for fine-tuning language models using JAX and Flax.

Why this might be useful

LoRA is one of the most popular parameter-efficient fine-tuning techniques, and efficiency is a big part of why people reach for JAX in the first place. Yet implementing LoRA in JAX isn't straightforward: someone new to the ecosystem has to piece it together from scattered documentation and GitHub issues.
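
To give a sense of the core idea, here is a minimal sketch of what a LoRA-adapted dense layer could look like in `flax.linen`. The class name `LoRALinear` and the `rank`/`alpha` hyperparameters are illustrative choices for this sketch, not taken from the draft implementation:

```python
import jax
import jax.numpy as jnp
from flax import linen as nn


class LoRALinear(nn.Module):
    """Dense layer whose frozen base kernel is adapted by a low-rank update A @ B."""
    features: int
    rank: int = 8
    alpha: float = 16.0

    @nn.compact
    def __call__(self, x):
        in_features = x.shape[-1]
        # Base weight: learned during pretraining, frozen during LoRA fine-tuning.
        kernel = self.param(
            'kernel', nn.initializers.lecun_normal(), (in_features, self.features))
        # LoRA factors: A starts small and random, B starts at zero, so the
        # adapted layer initially computes exactly the same output as the base layer.
        lora_a = self.param(
            'lora_a', nn.initializers.normal(stddev=0.01), (in_features, self.rank))
        lora_b = self.param(
            'lora_b', nn.initializers.zeros, (self.rank, self.features))
        scale = self.alpha / self.rank
        return x @ kernel + scale * (x @ lora_a @ lora_b)


# Example: adapt a 512 -> 512 projection with a rank-8 update.
layer = LoRALinear(features=512, rank=8)
params = layer.init(jax.random.PRNGKey(0), jnp.ones((1, 512)))
y = layer.apply(params, jnp.ones((1, 512)))
```

With `lora_b` initialized to zero, fine-tuning starts from the pretrained model's behavior and only the `rank * (in_features + features)` LoRA parameters need gradients.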

Implementation

The example will build on the JAX for LLM pretraining tutorial and will compare a model trained using that approach against one trained using LoRA.

I have a draft implementation ready to submit.
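
Since LoRA fine-tuning updates only the low-rank factors, the base weights have to be frozen during training. One way to express that in optax, sketched under the assumption that parameters are named as in the `LoRALinear` example above (the draft implementation may do this differently), is `optax.multi_transform`:

```python
import jax
import optax


def label_params(params):
    # Tag each parameter leaf by name: LoRA factors train, everything else is frozen.
    return jax.tree_util.tree_map_with_path(
        lambda path, _: 'lora'
        if any(getattr(entry, 'key', None) in ('lora_a', 'lora_b') for entry in path)
        else 'frozen',
        params,
    )


# AdamW updates the LoRA factors; set_to_zero() leaves the base weights untouched.
tx = optax.multi_transform(
    {'lora': optax.adamw(learning_rate=1e-4), 'frozen': optax.set_to_zero()},
    label_params,
)
opt_state = tx.init(params)  # `params` as produced by the model's .init() above
```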
