Skip to content

Added LSTM Model in JAX from scratch #4690

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from

Conversation

Tushaam
Copy link

@Tushaam Tushaam commented Apr 5, 2025

This PR introduces a new LSTM (Long Short-Term Memory) model implementation from scratch using JAX within the examples/ directory.

Motivation:

There was previously no basic LSTM implementation in Flax using pure JAX APIs. This example serves as an educational reference for those interested in understanding how to build an LSTM layer manually in JAX and integrate it with Flax-style modeling.

In this:

LSTM_JAX.py: Manual implementation of an LSTM model using JAX primitives.

LSTM_TestCase.py: A simple test case that verifies input-output dimensions and prints sample outputs.

init.py: To ensure module accessibility.

This can be especially useful for beginners or researchers looking to build custom RNN models using Flax + JAX.

Fixes : (N/A – This is a new contribution not tied to an open issue.)

Tushaam and others added 8 commits April 6, 2025 00:21
This commit adds a basic LSTM cell and sequence module implemented using Flax and JAX. 

Includes a dummy input test script for demonstration.

The implementation is modular and can be reused/modified for other sequence modeling tasks.
This commit adds a basic LSTM cell and sequence module implemented using Flax and JAX. 

It also Includes a dummy input test script for demonstration.

The implementation is modular and can be reused/modified for other sequence modeling tasks.
Copy link

google-cla bot commented Apr 5, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant