This repository is an implementation of continuous-time recurrent neural networks (CT-RNNs) in the Python programming language and JAX ecosystem. Specifically, the architecture and training of CT-RNNs are implemented with the Flax Linen and Optax APIs.
A CT-RNN is a recurrent neural network architecture whose hidden state evolves in continuous time according to a differential equation.
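For reference, CT-RNN dynamics are commonly written in the form sketched below. This is the standard textbook formulation, not a transcription of this repository's definition, and every symbol is an assumption: x is the hidden state, u the input, y the output, W_rec, W_in, and W_out are weight matrices, b_rec and b_out are biases, tau is the time constant, and xi is a noise term. The Euler-discretized update uses alpha = dt / tau, which is assumed to correspond to the `alpha` argument of `CTRNNCell`.

```latex
\begin{align}
\tau \frac{d\mathbf{x}}{dt} &= -\mathbf{x} + W_{rec} \tanh(\mathbf{x}) + W_{in} \mathbf{u} + \mathbf{b}_{rec} + \boldsymbol{\xi} \\
\mathbf{y} &= W_{out} \tanh(\mathbf{x}) + \mathbf{b}_{out} \\
\mathbf{x}_{t+1} &= (1 - \alpha)\, \mathbf{x}_t + \alpha \left( W_{rec} \tanh(\mathbf{x}_t) + W_{in} \mathbf{u}_t + \mathbf{b}_{rec} + \boldsymbol{\xi}_t \right), \qquad \alpha = \frac{\Delta t}{\tau}
\end{align}
```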
Computational neuroscience studies typically train CT-RNNs with the backpropagation-through-time (BPTT) learning algorithm to generate hypotheses about the low-dimensional dynamical systems underlying cognitive tasks. CT-RNNs, though, can be time-consuming to train due to the `for` loop inherent in their architecture. JAX eliminates the Python-level `for` loop by expressing the recurrence with the `scan` primitive, which compiles the loop over time steps into a single operation. Conceptually, the CT-RNN becomes a large feedforward network with shared parameters broadcast to each layer. Using JAX's `scan` primitive to implement and train CT-RNNs can result in large speedups over other deep learning frameworks, such as PyTorch and TensorFlow.
One challenge in using JAX is that pseudorandom numbers must be generated with an explicit key. For the `scan` primitive, this means each time step needs its own key if the noise is to differ across steps. Fortunately, Flax gracefully manages layer-dependent keys.
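As an illustration of what managing per-step keys involves without Flax, the sketch below splits one key into a key per time step and passes those keys through `jax.lax.scan` alongside the inputs. The function `noisy_rollout` and its leaky-integrator update are hypothetical and only stand in for a noisy recurrent cell.

```python
import jax
import jax.numpy as jnp


def noisy_rollout(key, x0, inputs, alpha=0.1, noise_const=0.1):
    """Leaky integrator that draws fresh Gaussian noise at every time step."""
    # One independent key per time step, derived from the single input key.
    step_keys = jax.random.split(key, inputs.shape[0])

    def body(x, step):
        u, k = step
        noise = noise_const * jax.random.normal(k, x.shape)
        x_next = (1.0 - alpha) * x + alpha * (u + noise)
        return x_next, x_next

    # scan iterates over the leading axis of both inputs and step_keys.
    _, xs = jax.lax.scan(body, x0, (inputs, step_keys))
    return xs


inputs = jnp.ones((20, 3))  # (time, features), hypothetical sizes
x0 = jnp.zeros(3)

run_a = noisy_rollout(jax.random.PRNGKey(0), x0, inputs)
run_b = noisy_rollout(jax.random.PRNGKey(0), x0, inputs)  # same seed
run_c = noisy_rollout(jax.random.PRNGKey(1), x0, inputs)  # different seed
```

Reusing a seed reproduces the trajectory exactly, while a different seed gives a different noise realization. Reusing one key inside `body` instead of splitting would apply the same noise sample at every step.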
pip install git+https://github.com/keith-murray/ctrnn-jax.git
import jax.numpy as jnp
from jax import random
from flax import linen as nn
from ctrnn_jax.model import CTRNNCell
ctrnn = nn.RNN(
    CTRNNCell(
        hidden_features=2,
        output_features=1,
        alpha=jnp.float32(0.1),
        noise_const=jnp.float32(0.1),
    ),
    split_rngs={"params": False, "noise_stream": True},
)
params = ctrnn.init(
    random.PRNGKey(0),
    jnp.ones([1, 5, 1]),  # (batch, time, input_features)
)
output, rates = ctrnn.apply(
    params,
    jnp.ones([1, 5, 1]),
    rngs={"noise_stream": random.PRNGKey(0)},
)
Take note that Flax's `nn.RNN` module is a wrapper for JAX's `scan` primitive. Also note that the `rngs` argument in the `ctrnn.apply` method is necessary to seed the `noise_stream` random number stream.
Refer to the `scripts/` directory in this repository for examples of how to train and analyze CT-RNNs. Our examples consist of:
- training a CT-RNN on the sine wave generator task
- visualizing its learned solution with principal component analysis (PCA)
- identifying fixed-point attractors
- performing linear stability analysis
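The last two analysis steps in the list can be sketched as follows. This is a generic illustration under stated assumptions, not the code in `scripts/`: the two-unit weight matrix `W`, the dynamics in `velocity`, and the plain gradient-descent settings are all hypothetical. Fixed points are found by minimizing the squared speed of the dynamics, and stability is read from the eigenvalues of the Jacobian at each fixed point.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical 2-unit network, chosen so that unit 0 is bistable.
W = jnp.array([[1.5, 0.0], [0.0, 0.5]])
b = jnp.zeros(2)


def velocity(x):
    """Assumed continuous-time dynamics with tau = 1: dx/dt = -x + W tanh(x) + b."""
    return -x + W @ jnp.tanh(x) + b


def speed(x):
    """q(x) = 0.5 * ||dx/dt||^2, which is zero exactly at fixed points."""
    return 0.5 * jnp.sum(velocity(x) ** 2)


grad_speed = jax.grad(speed)


@jax.jit
def find_fixed_point(x_init):
    """Minimize q(x) with plain gradient descent (step size and count are arbitrary)."""

    def body(_, x):
        return x - 0.1 * grad_speed(x)

    return jax.lax.fori_loop(0, 2000, body, x_init)


# Start from a hypothetical state, e.g. one sampled from a trained trajectory.
x_star = find_fixed_point(jnp.array([1.0, 0.5]))

# Linear stability: a fixed point of dx/dt = f(x) is stable when every
# eigenvalue of the Jacobian of f has a negative real part.
jacobian = jax.jacobian(velocity)
eig_star = np.linalg.eigvals(np.asarray(jacobian(x_star)))
eig_origin = np.linalg.eigvals(np.asarray(jacobian(jnp.zeros(2))))

is_stable_star = bool(np.all(eig_star.real < 0))
is_stable_origin = bool(np.all(eig_origin.real < 0))
```

In this toy system the optimizer converges to a stable fixed point away from the origin, while the origin is itself a fixed point with one positive eigenvalue, so it is a saddle. In practice the search is usually started from many initial states so that several fixed points can be recovered.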
We also demonstrate how CT-RNNs can exhibit a variety of nonlinear dynamical phenomena.
Refer to our rnn-workbench repository for more examples of task-optimized CT-RNNs.