Commit 810fda4 (initial commit): release
rtqichen committed Nov 14, 2018
Showing 24 changed files with 2,472 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -0,0 +1,4 @@
*.egg-info/
.installed.cfg
*.egg
*__pycache__*
21 changes: 21 additions & 0 deletions LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2018 Ricky Tian Qi Chen

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
76 changes: 76 additions & 0 deletions README.md
@@ -0,0 +1,76 @@
# PyTorch Implementation of Differentiable ODE Solvers

This library provides ordinary differential equation (ODE) solvers implemented in PyTorch. Backpropagation through all solvers is supported using the adjoint method. For usage of ODE solvers in deep learning applications, see [1].

As the solvers are implemented in PyTorch, all algorithms in this repository run natively on the GPU.

---

<p align="center">
<img align="middle" src="./assets/resnet_0_viz.png" alt="Discrete-depth network" width="240" height="330" />
<img align="middle" src="./assets/odenet_0_viz.png" alt="Continuous-depth network" width="240" height="330" />
</p>

## Installation
```
git clone https://github.com/rtqichen/torchdiffeq.git
cd torchdiffeq
pip install -e .
```

## Examples
Examples are placed in the [`examples`](./examples) directory.

We encourage those interested in using this library to take a look at `examples/ode_demo.py` to see how `torchdiffeq` can be used to fit a simple spiral ODE.

<p align="center">
<img align="middle" src="./assets/ode_demo.gif" alt="ODE Demo" width="500" height="250" />
</p>

## Basic usage
This library provides one main interface, `odeint`, which contains general-purpose algorithms for solving initial value problems (IVPs), with gradients implemented for all main arguments. An initial value problem consists of an ODE and an initial value,
```
dy/dt = f(t, y)    y(t_0) = y_0
```
The goal of an ODE solver is to find a continuous trajectory satisfying the ODE that passes through the initial condition.

To solve an IVP using the default solver:
```
from torchdiffeq import odeint
odeint(func, y0, t)
```
where `func` is any callable implementing the ordinary differential equation `f(t, y)`, `y0` is an _any_-D Tensor or a tuple of _any_-D Tensors representing the initial values, and `t` is a 1-D Tensor containing the evaluation points. The initial time is taken to be `t[0]`.
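
As a concrete illustration, here is a minimal sketch that solves the scalar ODE `dy/dt = -y` from `y(0) = 1` (the dynamics and names here are illustrative, not part of the library):
```
import torch
from torchdiffeq import odeint

def f(t, y):
    # Simple exponential decay; any callable with signature f(t, y) works here.
    return -y

y0 = torch.tensor([1.0])
t = torch.linspace(0., 2., 10)
ys = odeint(f, y0, t)  # Tensor of shape (10, 1): the solution at each time in t
```
The output stacks the solution values along a new leading time dimension, so `ys[0]` equals `y0`.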

Backpropagation through `odeint` goes through the internals of the solver, but this is not supported for all solvers. Instead, we encourage the use of the adjoint method explained in [1], which will allow solving with as many steps as necessary due to O(1) memory usage.

To use the adjoint method:
```
from torchdiffeq import odeint_adjoint as odeint
odeint(func, y0, t)
```
`odeint_adjoint` simply wraps around `odeint`, but will use only O(1) memory in exchange for solving an adjoint ODE in the backward call.

The biggest **gotcha** is that `func` must be an `nn.Module` when using the adjoint method. This is used to collect the parameters of the differential equation.
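
For instance, a minimal sketch of an `nn.Module` dynamics function used with the adjoint method (the `Dynamics` class here is a hypothetical example, not part of the library):
```
import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint

class Dynamics(nn.Module):
    def __init__(self):
        super(Dynamics, self).__init__()
        self.linear = nn.Linear(2, 2)

    def forward(self, t, y):
        return self.linear(y)

func = Dynamics()
y0 = torch.randn(1, 2)
t = torch.linspace(0., 1., 5)
ys = odeint_adjoint(func, y0, t)
ys.sum().backward()  # gradients reach func.parameters() via the adjoint ODE
```
Because `func` is an `nn.Module`, its parameters are collected automatically and receive gradients in the backward pass.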

### Keyword Arguments
- `rtol` Relative tolerance.
- `atol` Absolute tolerance.
- `method` One of the solvers listed below.

#### List of ODE Solvers:

Adaptive-step:
- `dopri5` Runge-Kutta 4(5) of Dormand-Prince [default].
- `adams` Adaptive-order implicit Adams.

Fixed-step:
- `euler` Euler method.
- `midpoint` Midpoint method.
- `rk4` Fourth-order Runge-Kutta with 3/8 rule.
- `explicit_adams` Explicit Adams.
- `fixed_adams` Implicit Adams.
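
Tolerances and the solver are passed directly to `odeint`; a brief sketch, assuming `func`, `y0`, and `t` as defined above (the values are illustrative):
```
from torchdiffeq import odeint

ys = odeint(func, y0, t, rtol=1e-6, atol=1e-8, method='dopri5')  # adaptive
ys_fixed = odeint(func, y0, t, method='rk4')                     # fixed-step
```
Tighter tolerances make adaptive solvers take more, smaller steps, trading compute for accuracy.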

### References
[1] Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, David Duvenaud. "Neural Ordinary Differential Equations." *Advances in Neural Information Processing Systems.* 2018.
Binary file added assets/ode_demo.gif
Binary file added assets/odenet_0_viz.png
Binary file added assets/resnet_0_viz.png
45 changes: 45 additions & 0 deletions examples/README.md
@@ -0,0 +1,45 @@
# Overview of Examples

This `examples` directory contains cleaned-up code demonstrating the usage of adaptive ODE solvers in machine learning. The scripts in this directory assume that `torchdiffeq` is installed following the instructions from the main directory.

## Demo
The `ode_demo.py` file contains a short implementation of learning a dynamics model to mimic a spiral ODE.

To visualize the training progress, run
```
python ode_demo.py --viz
```
The training should look similar to this:

<p align="center">
<img align="middle" src="../assets/ode_demo.gif" alt="ODE Demo" width="500" height="250" />
</p>

## ODEnet for MNIST
The `odenet_mnist.py` file contains a reproduction of the MNIST experiments from our Neural ODE paper. Notably, the ODE solver library and method differ from those in our original experiments, but the results are similar to those we report in the paper.

We can use an adaptive ODE solver to approximate our continuous-depth network while still backpropagating through the network.
```
python odenet_mnist.py --network odenet
```
However, the memory requirements for this will blow up very fast, especially for more complex problems where the number of function evaluations can reach nearly a thousand.

For applications that require solving complex trajectories, we recommend using the adjoint method.
```
python odenet_mnist.py --network odenet --adjoint True
```
The adjoint method can be slower when using an adaptive ODE solver, as it involves another solve in the backward pass with a much larger system, so we recommend experimenting with direct backpropagation on small systems first.

Thankfully, it is extremely easy to write code for both adjoint and non-adjoint backpropagation, as they use the same interface.
```
if adjoint:
from torchdiffeq import odeint_adjoint as odeint
else:
from torchdiffeq import odeint
```
The main gotcha is that `odeint_adjoint` requires implementing the dynamics network as an `nn.Module`, while `odeint` can work with any Python callable.

## Continuous Normalizing Flows
Code for continuous normalizing flows (CNFs) has its own public repository. Tools for training, evaluating, and visualizing CNFs for reversible generative modeling are provided along with FFJORD, a linear-cost stochastic approximation of CNFs.

Find the code at https://github.com/rtqichen/ffjord, which also contains some advanced tricks for `torchdiffeq`.
180 changes: 180 additions & 0 deletions examples/ode_demo.py
@@ -0,0 +1,180 @@
import os
import argparse
import time
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

parser = argparse.ArgumentParser('ODE demo')
parser.add_argument('--method', type=str, default='dopri5')
parser.add_argument('--data_size', type=int, default=1000)
parser.add_argument('--batch_time', type=int, default=10)
parser.add_argument('--batch_size', type=int, default=20)
parser.add_argument('--niters', type=int, default=2000)
parser.add_argument('--test_freq', type=int, default=20)
parser.add_argument('--viz', action='store_true')
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--adjoint', action='store_true')
args = parser.parse_args()

if args.adjoint:
    from torchdiffeq import odeint_adjoint as odeint
else:
    from torchdiffeq import odeint

device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

true_y0 = torch.tensor([[2., 0.]])
t = torch.linspace(0., 25., args.data_size)
true_A = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]])


class Lambda(nn.Module):
    # Ground-truth dynamics used to generate the training trajectory.

    def forward(self, t, y):
        return torch.mm(y**3, true_A)


with torch.no_grad():
    true_y = odeint(Lambda(), true_y0, t)


def get_batch():
    # Sample batch_size random starting indices, then slice out short
    # sub-trajectories of length batch_time from the true solution.
    s = torch.from_numpy(np.random.choice(np.arange(args.data_size - args.batch_time), args.batch_size, replace=False))
    batch_y0 = true_y[s]  # (M, D)
    batch_t = t[:args.batch_time]  # (T)
    batch_y = torch.stack([true_y[s + i] for i in range(args.batch_time)], dim=0)  # (T, M, D)
    return batch_y0, batch_t, batch_y


def makedirs(dirname):
    if not os.path.exists(dirname):
        os.makedirs(dirname)


if args.viz:
    makedirs('png')
    import matplotlib.pyplot as plt
    fig = plt.figure(figsize=(12, 4), facecolor='white')
    ax_traj = fig.add_subplot(131, frameon=False)
    ax_phase = fig.add_subplot(132, frameon=False)
    ax_vecfield = fig.add_subplot(133, frameon=False)
    plt.show(block=False)


def visualize(true_y, pred_y, odefunc, itr):

    if args.viz:

        ax_traj.cla()
        ax_traj.set_title('Trajectories')
        ax_traj.set_xlabel('t')
        ax_traj.set_ylabel('x,y')
        ax_traj.plot(t.numpy(), true_y.numpy()[:, 0, 0], t.numpy(), true_y.numpy()[:, 0, 1], 'g-')
        ax_traj.plot(t.numpy(), pred_y.numpy()[:, 0, 0], '--', t.numpy(), pred_y.numpy()[:, 0, 1], 'b--')
        ax_traj.set_xlim(t.min(), t.max())
        ax_traj.set_ylim(-2, 2)
        ax_traj.legend()

        ax_phase.cla()
        ax_phase.set_title('Phase Portrait')
        ax_phase.set_xlabel('x')
        ax_phase.set_ylabel('y')
        ax_phase.plot(true_y.numpy()[:, 0, 0], true_y.numpy()[:, 0, 1], 'g-')
        ax_phase.plot(pred_y.numpy()[:, 0, 0], pred_y.numpy()[:, 0, 1], 'b--')
        ax_phase.set_xlim(-2, 2)
        ax_phase.set_ylim(-2, 2)

        ax_vecfield.cla()
        ax_vecfield.set_title('Learned Vector Field')
        ax_vecfield.set_xlabel('x')
        ax_vecfield.set_ylabel('y')

        y, x = np.mgrid[-2:2:21j, -2:2:21j]
        dydt = odefunc(0, torch.Tensor(np.stack([x, y], -1).reshape(21 * 21, 2))).cpu().detach().numpy()
        mag = np.sqrt(dydt[:, 0]**2 + dydt[:, 1]**2).reshape(-1, 1)
        dydt = (dydt / mag)
        dydt = dydt.reshape(21, 21, 2)

        ax_vecfield.streamplot(x, y, dydt[:, :, 0], dydt[:, :, 1], color="black")
        ax_vecfield.set_xlim(-2, 2)
        ax_vecfield.set_ylim(-2, 2)

        fig.tight_layout()
        plt.savefig('png/{:03d}'.format(itr))
        plt.draw()
        plt.pause(0.001)


class ODEFunc(nn.Module):

    def __init__(self):
        super(ODEFunc, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(2, 50),
            nn.Tanh(),
            nn.Linear(50, 2),
        )

        for m in self.net.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.1)
                nn.init.constant_(m.bias, val=0)

    def forward(self, t, y):
        return self.net(y**3)

> **@pgeorgant** commented on Aug 22, 2023:
> I don't get why `**3` is used here. In fact, removing it makes `ODEFunc` unable to approximate the kind of dynamics that `Lambda` implements.

class RunningAverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.reset()

    def reset(self):
        self.val = None
        self.avg = 0

    def update(self, val):
        if self.val is None:
            self.avg = val
        else:
            self.avg = self.avg * self.momentum + val * (1 - self.momentum)
        self.val = val


if __name__ == '__main__':

    ii = 0

    func = ODEFunc()
    optimizer = optim.Adam(func.parameters(), lr=1e-3)
    end = time.time()

    time_meter = RunningAverageMeter(0.97)
    loss_meter = RunningAverageMeter(0.97)

    for itr in range(1, args.niters + 1):
        optimizer.zero_grad()
        batch_y0, batch_t, batch_y = get_batch()
        pred_y = odeint(func, batch_y0, batch_t)
        loss = torch.mean(torch.abs(pred_y - batch_y))
        loss.backward()
        optimizer.step()

        time_meter.update(time.time() - end)
        loss_meter.update(loss.item())

        if itr % args.test_freq == 0:
            with torch.no_grad():
                pred_y = odeint(func, true_y0, t)
                loss = torch.mean(torch.abs(pred_y - true_y))
                print('Iter {:04d} | Total Loss {:.6f}'.format(itr, loss.item()))
                visualize(true_y, pred_y, func, ii)
                ii += 1

        end = time.time()