This repository contains the code to accompany the paper "A Unifying Framework for Parallelizing Sequential Models with Linear Dynamical Systems." The primary contributions of our paper are unifying different fixed-point algorithms (Newton, quasi-Newton, Picard, and Jacobi) in the framework of linear dynamical systems (LDS), and demonstrating the effectiveness of these algorithms in parallelizing stateful (Markov) models.
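We focus on parallelizing state space models, i.e. models of the form

$$s_t = f_t(s_{t-1}), \qquad t = 1, \dots, T,$$

where $s_t \in \mathbb{R}^D$ is the state at time $t$, $s_0$ is a given initial state, and $f_t : \mathbb{R}^D \to \mathbb{R}^D$ is a (possibly nonlinear) transition function.
In the context of parallelizing such state space models, we find that a wide variety of fixed-point methods have iterations that can be expressed as a linear dynamical system (LDS) over the sequence of states, with update given by

$$s_t^{(i+1)} = f_t\big(s_{t-1}^{(i)}\big) + A_t^{(i)}\big(s_{t-1}^{(i+1)} - s_{t-1}^{(i)}\big),$$

where $i$ indexes the fixed-point iteration and the transition matrix $A_t^{(i)} \in \mathbb{R}^{D \times D}$ is determined by the choice of method.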
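As a minimal sketch (not this repository's implementation), one iteration of the LDS update $s_t^{(i+1)} = f_t(s_{t-1}^{(i)}) + A_t^{(i)}(s_{t-1}^{(i+1)} - s_{t-1}^{(i)})$ can be written as a plain sequential loop over $t$. The toy transition $f_t(s) = \tanh(W s + u_t)$, the values of `W` and the inputs $u_t$, and the helper name `lds_fixed_point_step` are all hypothetical, chosen only for illustration. With $A_t = 0$ (the Jacobi choice), $T$ iterations reproduce the exact trajectory, since each iteration fixes at least one more state.

```python
import jax.numpy as jnp

def lds_fixed_point_step(f, As, states, s0, inputs):
    """One iteration s^{(i)} -> s^{(i+1)} of the LDS update
        s_t^{(i+1)} = f_t(s_{t-1}^{(i)}) + A_t (s_{t-1}^{(i+1)} - s_{t-1}^{(i)}),
    evaluated with a plain sequential loop over t for readability.

    f: callable (s_prev, u_t) -> s_t; As: (T, D, D); states: (T, D) current
    guess for s_1..s_T; s0: (D,) fixed initial state; inputs: (T, D).
    """
    # s_{t-1}^{(i)} for t = 1..T, with s_0 prepended.
    prev_old = jnp.concatenate([s0[None], states[:-1]])
    new_states = []
    s_prev_new = s0  # s_0 is fixed, so it is the same in every iteration
    for t in range(states.shape[0]):
        s_t = f(prev_old[t], inputs[t]) + As[t] @ (s_prev_new - prev_old[t])
        new_states.append(s_t)
        s_prev_new = s_t
    return jnp.stack(new_states)

# Hypothetical toy model: f_t(s) = tanh(W s + u_t) with made-up W and inputs.
D, T = 3, 6
W = 0.4 * jnp.eye(D) + 0.1
inputs = jnp.linspace(-1.0, 1.0, T * D).reshape(T, D)
s0 = jnp.zeros(D)

def f(s_prev, u_t):
    return jnp.tanh(W @ s_prev + u_t)

# Exact trajectory by ordinary sequential evaluation, for comparison.
exact, s = [], s0
for t in range(T):
    s = f(s, inputs[t])
    exact.append(s)
exact = jnp.stack(exact)

# Jacobi choice (A_t = 0), starting from an all-zeros initial guess.
states = jnp.zeros((T, D))
for _ in range(T):
    states = lds_fixed_point_step(f, jnp.zeros((T, D, D)), states, s0, inputs)
print(jnp.max(jnp.abs(states - exact)))
```

The loop is deliberately sequential so that it mirrors the equation term by term; the point of the LDS view is that the same update can instead be evaluated with a parallel scan.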
We summarize the fixed-point methods we consider in Table 1 of our paper.
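| Fixed-point method | Order | Transition matrix |
|---|---|---|
| Newton | first-order | $\frac{\partial f_t}{\partial s_{t-1}}\big(s_{t-1}^{(i)}\big)$ |
| Quasi-Newton | quasi first-order | $\mathrm{diag}\left[\frac{\partial f_t}{\partial s_{t-1}}\big(s_{t-1}^{(i)}\big)\right]$ |
| Picard | zeroth-order | $I_D$ |
| Jacobi | zeroth-order | $0$ |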
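The four transition matrices can be sketched in JAX for a single time step, evaluated at the current guess $s_{t-1}^{(i)}$. This is a minimal illustration under assumptions: the toy transition $f_t(s) = \tanh(W s + u_t)$, the matrix `W`, the helper `transition_matrix`, and the method strings are hypothetical; `jax.jacfwd` is the standard JAX forward-mode Jacobian.

```python
import jax
import jax.numpy as jnp

D = 3
# Hypothetical weights for a toy transition f_t(s) = tanh(W s + u_t).
W = jnp.array([[0.5, -0.2, 0.1],
               [0.3, 0.4, 0.0],
               [-0.1, 0.2, 0.6]])

def f(s_prev, u_t):
    """Toy transition f_t(s_{t-1}); the time dependence enters through u_t."""
    return jnp.tanh(W @ s_prev + u_t)

def transition_matrix(method, s_prev_guess, u_t):
    """A_t^{(i)} for one step, evaluated at the current guess s_{t-1}^{(i)}."""
    if method == "newton":
        # Full D x D Jacobian of f_t with respect to s_{t-1}.
        return jax.jacfwd(f)(s_prev_guess, u_t)
    if method == "quasi_newton":
        # Diagonal of the Jacobian only; off-diagonal entries are zeroed.
        return jax.jacfwd(f)(s_prev_guess, u_t) * jnp.eye(D)
    if method == "picard":
        return jnp.eye(D)
    if method == "jacobi":
        return jnp.zeros((D, D))
    raise ValueError(f"unknown method: {method!r}")

s = jnp.array([0.1, -0.2, 0.3])
u = jnp.zeros(D)
for m in ["newton", "quasi_newton", "picard", "jacobi"]:
    print(m, transition_matrix(m, s, u))
```

Forming the full Jacobian and then masking it costs as much as Newton; it is done here only for clarity. For an elementwise nonlinearity like this toy example, the diagonal could instead be computed directly as `(1 - h**2) * jnp.diag(W)` with `h = jnp.tanh(W @ s + u)`.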
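Because each of these iterations is an LDS, it can be evaluated in parallel over the sequence length with a parallel (associative) scan, which has $O(\log T)$ depth for a sequence of length $T$.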
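A minimal sketch of that parallel evaluation uses `jax.lax.associative_scan`; the helper names `combine`, `parallel_lds`, and `sequential_lds` are hypothetical and not taken from this repository. Each time step is treated as an affine map $x \mapsto A_t x + b_t$, and composing affine maps is associative, so the scan can combine them in a tree. For the update $s_t^{(i+1)} = f_t(s_{t-1}^{(i)}) + A_t^{(i)}(s_{t-1}^{(i+1)} - s_{t-1}^{(i)})$, one would pass $A_t = A_t^{(i)}$ and $b_t = f_t(s_{t-1}^{(i)}) - A_t^{(i)} s_{t-1}^{(i)}$.

```python
import jax
import jax.numpy as jnp

def combine(left, right):
    """Compose two batches of affine maps x -> A x + b; `left` is applied first."""
    A_l, b_l = left
    A_r, b_r = right
    # right after left: x -> A_r (A_l x + b_l) + b_r
    return A_r @ A_l, jnp.einsum("...ij,...j->...i", A_r, b_l) + b_r

def parallel_lds(As, bs, s0):
    """Solve x_t = A_t x_{t-1} + b_t for t = 1..T with x_0 = s0, in parallel.

    As: (T, D, D), bs: (T, D). Returns (T, D) holding x_1, ..., x_T.
    """
    # Fold the initial state into the first offset, so the composed maps
    # only ever need to be applied to the zero vector.
    bs = bs.at[0].add(As[0] @ s0)
    _, xs = jax.lax.associative_scan(combine, (As, bs))
    return xs

def sequential_lds(As, bs, s0):
    """Reference evaluation of the same recurrence, one step at a time."""
    def step(x_prev, Ab):
        A_t, b_t = Ab
        x_t = A_t @ x_prev + b_t
        return x_t, x_t
    _, xs = jax.lax.scan(step, s0, (As, bs))
    return xs

# Random (made-up) recurrence for a quick consistency check.
T, D = 64, 3
key_A, key_b, key_s = jax.random.split(jax.random.PRNGKey(0), 3)
As = 0.3 * jax.random.normal(key_A, (T, D, D))
bs = jax.random.normal(key_b, (T, D))
s0 = jax.random.normal(key_s, (D,))

xs_par = parallel_lds(As, bs, s0)
xs_seq = sequential_lds(As, bs, s0)
print(jnp.max(jnp.abs(xs_par - xs_seq)))
```

With dense $A_t$ (Newton), each composition in the scan costs a $D \times D$ matrix product, i.e. $O(D^3)$ work, whereas a diagonal $A_t$ (quasi-Newton) would reduce this to $O(D)$ elementwise work; this sketch always uses dense matrices.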
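First install JAX following the official instructions for your hardware: https://docs.jax.dev/en/latest/installation.html

We use Python 3.12.1 and JAX 0.5.2. For CPU only:

```bash
pip install --upgrade "jax[cpu]==0.5.2"
```

For NVIDIA GPUs with CUDA 12:

```bash
pip install --upgrade "jax[cuda12]==0.5.2" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

After installing JAX for your hardware, install this package in editable mode by running the following from the repository root:

```bash
pip install -e .
```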
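A quick sanity check of the installation might look like the following. It only uses standard JAX calls (`jax.__version__`, `jax.default_backend()`, `jax.devices()`); the expectation that a CUDA install lists a GPU device is an assumption about a correctly configured machine, not something this repository checks.

```python
import jax
import jax.numpy as jnp

print("JAX version:", jax.__version__)  # expected to be 0.5.2 for this repo
print("Default backend:", jax.default_backend())  # e.g. "cpu" or "gpu"
print("Devices:", jax.devices())

# A tiny computation to confirm that the backend actually runs code.
x = jnp.arange(4.0)
print("Sum check:", float(jnp.sum(x)))
```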
We ran our experiments using `experiments/harness.py` on an NVIDIA H100 GPU with 80 GB of memory.
Please also star this repo if you find the code interesting or useful!
To cite our paper, please use:

```bibtex
@article{UnifyingFramework2025,
title={A Unifying Framework for Parallelizing Sequential Models with Linear Dynamical Systems},
author={Xavier Gonzalez and E. Kelly Buchanan and Hyun Dong Lee and Jerry Weihong Liu and Ke Alexander Wang and David M. Zoltowski and Chris R\'e and Scott W. Linderman},
year={2025},
}
```

