
CT-RNN Implementation in Python's JAX ecosystem

This repository is an implementation of continuous-time recurrent neural networks (CT-RNNs) in the Python programming language and JAX ecosystem. Specifically, the architecture and training of CT-RNNs are implemented with the Flax Linen and Optax APIs.


What is a CT-RNN?

A CT-RNN is a recurrent neural network architecture described by the following equations:

$$x_{t+1}=(1-\alpha)x_t+\alpha(W_{\text{rec}}f(x_t)+W_{\text{in}}u_t + b_{\text{rec}} + \eta_t)$$ $$y_t=W_{\text{out}}f(x_t) + b_{\text{out}}$$

where $x_t\in\mathbb{R}^{n}$ is the voltage vector of recurrent neurons, $u_t\in\mathbb{R}^{m}$ is the input vector to the CT-RNN, $y_t\in\mathbb{R}^{p}$ is the firing rate vector of output neurons, $f$ is an activation function mapping voltage to firing rate, $W$ is a weight matrix, $b$ is a bias vector, and $\eta_t\in\mathbb{R}^{n}$ is a vector of randomly sampled values from $\mathcal{N}(0,\sigma^2)$ at each time step, $t$.
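
To make the update concrete, here is a minimal single-step sketch in plain JAX. It assumes a tanh activation and hypothetical parameter names (W_rec, W_in, W_out, b_rec, b_out stored in a dict) and is meant only as an illustration, not as the repository's CTRNNCell implementation:

import jax.numpy as jnp
from jax import random

def ctrnn_step(x, u, params, key, alpha=0.1, sigma=0.1):
    eta = sigma * random.normal(key, x.shape)  # noise term eta_t ~ N(0, sigma^2)
    rates = jnp.tanh(x)  # f(x_t), assuming a tanh activation
    x_next = (1 - alpha) * x + alpha * (
        params["W_rec"] @ rates + params["W_in"] @ u + params["b_rec"] + eta
    )
    y = params["W_out"] @ rates + params["b_out"]  # readout y_t
    return x_next, y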

Computational neuroscience studies typically train CT-RNNs via the backpropagation-through-time (BPTT) learning algorithm to hypothesize low-dimensional dynamical systems underlying cognitive tasks. CT-RNNs, though, can be time-consuming to train due to the for loop inherent in their architecture. JAX eliminates the for loop by unrolling CT-RNNs in time via the scan primitive, creating a large feedforward network with shared parameters broadcast to each layer. Using JAX's scan primitive to implement and train CT-RNNs results in large speedups over existing deep learning frameworks, like PyTorch and TensorFlow.
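
As a rough sketch of what scan-based unrolling looks like (illustrative only, with hypothetical parameter names and the noise term omitted for brevity), jax.lax.scan replaces the explicit Python for loop:

import jax
import jax.numpy as jnp

def run_ctrnn(x0, inputs, params, alpha=0.1):
    # inputs has shape (time, input_features); scan unrolls the step over time
    def step(x, u):
        rates = jnp.tanh(x)
        x_next = (1 - alpha) * x + alpha * (
            params["W_rec"] @ rates + params["W_in"] @ u + params["b_rec"]
        )
        y = params["W_out"] @ rates + params["b_out"]
        return x_next, y  # carry the voltage forward, collect the output
    _, outputs = jax.lax.scan(step, x0, inputs)
    return outputs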

One challenge of using JAX is that pseudorandom numbers must be generated with an explicit key. For the $\eta_t$ vector, this requires broadcasting unique keys to all time steps of the CT-RNN through scan. Fortunately, Flax gracefully manages these layer-dependent keys.
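
As a toy sketch of that key broadcasting (not the internals of CTRNNCell), one key can be split into a unique key per time step and scanned over alongside the inputs:

import jax
import jax.numpy as jnp
from jax import random

def run_noisy(x0, inputs, key, alpha=0.1, sigma=0.1):
    keys = random.split(key, inputs.shape[0])  # one unique key per time step
    def step(x, u_and_key):
        u, k = u_and_key
        eta = sigma * random.normal(k, x.shape)  # fresh eta_t at every step
        x_next = (1 - alpha) * x + alpha * (jnp.tanh(x) + u + eta)  # toy recurrence
        return x_next, x_next
    _, trajectory = jax.lax.scan(step, x0, (inputs, keys))
    return trajectory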

Installation

pip install git+https://github.com/keith-murray/ctrnn-jax.git

Usage

import jax.numpy as jnp
from jax import random
from flax import linen as nn
from ctrnn_jax.model import CTRNNCell

ctrnn = nn.RNN(
    CTRNNCell(
        hidden_features=2,
        output_features=1,
        alpha=jnp.float32(0.1),
        noise_const=jnp.float32(0.1),
    ),
    split_rngs={"params": False, "noise_stream": True},  # give each time step its own noise key
)
params = ctrnn.init(
    random.PRNGKey(0),
    jnp.ones([1, 5, 1]),  # (batch, time, input_features)
)

output, rates = ctrnn.apply(
    params,
    jnp.ones([1, 5, 1]),
    rngs={"noise_stream": random.PRNGKey(0)},  # base key for the noise term
)

Note that Flax's nn.RNN module is a wrapper around JAX's scan primitive. Also note that the rngs argument in the ctrnn.apply method is necessary to seed the $\eta_t$ noise vector.
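
Since ctrnn.apply is a pure function, it can also be wrapped in jax.jit. The snippet below is a hypothetical convenience wrapper continuing the usage example above:

import jax

@jax.jit
def forward(params, inputs, key):
    # compile the scanned recurrence once; later calls reuse the compiled trace
    return ctrnn.apply(params, inputs, rngs={"noise_stream": key})

output, rates = forward(params, jnp.ones([1, 5, 1]), random.PRNGKey(1))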

Examples

Refer to the scripts/ directory in this repository for examples of how to train and analyze CT-RNNs. Our examples consist of:

  • training a CT-RNN on the sine wave generator task
  • visualizing its learned solution with principal component analysis (PCA)
  • identifying fixed-point attractors
  • performing linear stability analysis

We also demonstrate how CT-RNNs can exhibit a variety of nonlinear dynamical phenomena.
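
As a rough illustration of the PCA step listed above (the full analyses live in scripts/; scikit-learn is assumed here), the firing rates returned by ctrnn.apply can be projected onto their top principal components:

import numpy as np
from sklearn.decomposition import PCA

# rates has shape (batch, time, hidden_features); flatten batch and time
rates_flat = np.asarray(rates).reshape(-1, rates.shape[-1])
pca = PCA(n_components=2)
trajectories = pca.fit_transform(rates_flat)  # low-dimensional trajectories
print(pca.explained_variance_ratio_)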

Refer to our rnn-workbench repository for more examples of task-optimized CT-RNNs.
