Transformerx: JAX implementation of modern transformers

cs-giung/transformerx

Transformerx

Important

This project is currently under development and is not yet stable. Any feature of the project may change.

Transformerx is inspired by Hugging Face Transformers and aims to implement state-of-the-art deep neural network architectures in JAX with minimal dependencies. It prioritizes simplicity and hackability by favoring code replication over complexity or additional abstraction. Transformerx is currently developed in a Python 3.12 environment; other Python versions may lead to unexpected behavior.
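To illustrate the "plain JAX, minimal dependencies" style described above, here is a minimal sketch of a pre-norm causal self-attention block written as pure functions over a dictionary of arrays, with no neural-network framework. This is not Transformerx's actual API: the function names (`init_attention_params`, `rms_norm`, `causal_self_attention`) and the parameter layout are hypothetical, and it assumes only `jax` is installed.

```python
import jax
import jax.numpy as jnp


def init_attention_params(key, d_model, num_heads):
    # Hypothetical parameter layout: a plain dict of arrays, no framework.
    head_dim = d_model // num_heads
    k_q, k_k, k_v, k_o = jax.random.split(key, 4)
    scale = 1.0 / jnp.sqrt(d_model)
    return {
        "norm": jnp.ones((d_model,)),
        "q": jax.random.normal(k_q, (d_model, num_heads, head_dim)) * scale,
        "k": jax.random.normal(k_k, (d_model, num_heads, head_dim)) * scale,
        "v": jax.random.normal(k_v, (d_model, num_heads, head_dim)) * scale,
        "o": jax.random.normal(k_o, (num_heads, head_dim, d_model)) * scale,
    }


def rms_norm(x, weight, eps=1e-6):
    # Scale each token by the root-mean-square of its features.
    rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
    return x / rms * weight


def causal_self_attention(params, x):
    # x: (seq_len, d_model). Pre-norm attention with a residual connection.
    h = rms_norm(x, params["norm"])
    q = jnp.einsum("sd,dhk->hsk", h, params["q"])  # (heads, seq, head_dim)
    k = jnp.einsum("sd,dhk->hsk", h, params["k"])
    v = jnp.einsum("sd,dhk->hsk", h, params["v"])
    logits = jnp.einsum("hqd,hsd->hqs", q, k) / jnp.sqrt(q.shape[-1])
    # Causal mask: position i may only attend to positions <= i.
    seq_len = x.shape[0]
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
    weights = jax.nn.softmax(logits, axis=-1)
    out = jnp.einsum("hqs,hsd->hqd", weights, v)
    return x + jnp.einsum("hqd,hdm->qm", out, params["o"])


params = init_attention_params(jax.random.PRNGKey(0), d_model=16, num_heads=4)
x = jax.random.normal(jax.random.PRNGKey(1), (8, 16))
y = jax.jit(causal_self_attention)(params, x)
print(y.shape)  # (8, 16)
```

Because the parameters are an ordinary pytree, the same function can be passed directly to `jax.jit`, `jax.grad`, or `jax.vmap` without any module system, which is one way to keep the code easy to copy and modify.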

The notebooks in examples/notebooks/*.ipynb walk through how to use the supported models.

Getting started

Basic dependencies for development in TPU environments:

pip install -U pip setuptools wheel google-cloud-tpu
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install tensorflow-cpu tensorflow-datasets
pip install datasets einops einshard jax-smi pylint qax tabulate transformers
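After running the commands above, a small sanity-check script can confirm that the packages are importable and, if JAX is present, which backend it selected. This is a minimal sketch, not part of Transformerx; the import names in `REQUIRED_MODULES` are assumptions about how each pip package is imported (for example, `tensorflow-cpu` is imported as `tensorflow` and `jax-smi` as `jax_smi`).

```python
import importlib.util

# Assumed import names for the packages installed above. pip distribution
# names can differ from module names (tensorflow-cpu -> tensorflow).
REQUIRED_MODULES = (
    "jax", "torch", "tensorflow", "tensorflow_datasets", "datasets",
    "einops", "einshard", "jax_smi", "pylint", "qax", "tabulate",
    "transformers",
)


def check_dependencies(modules=REQUIRED_MODULES):
    """Return {module_name: is_importable} without importing the modules."""
    status = {}
    for name in modules:
        try:
            status[name] = importlib.util.find_spec(name) is not None
        except (ImportError, ValueError):
            status[name] = False
    return status


if __name__ == "__main__":
    for name, ok in check_dependencies().items():
        print(f"{name:20s} {'ok' if ok else 'MISSING'}")
    if importlib.util.find_spec("jax") is not None:
        import jax
        # On a correctly configured TPU VM this is expected to report "tpu".
        print("JAX backend:", jax.default_backend(), "| devices:", jax.device_count())
```

Using `importlib.util.find_spec` checks availability without paying the import cost of large libraries such as `torch` or `tensorflow`.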
