-
Notifications
You must be signed in to change notification settings - Fork 948
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 810fda4
Showing
24 changed files
with
2,472 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
*__pycache__* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2018 Ricky Tian Qi Chen | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# PyTorch Implementation of Differentiable ODE Solvers | ||
|
||
This library provides ordinary differential equation (ODE) solvers implemented in PyTorch. Backpropagation through all solvers is supported using the adjoint method. For usage of ODE solvers in deep learning applications, see [1]. | ||
|
||
As the solvers are implemented in PyTorch, algorithms in this repository are fully supported to run on the GPU. | ||
|
||
--- | ||
|
||
<p align="center"> | ||
<img align="middle" src="./assets/resnet_0_viz.png" alt="Discrete-depth network" width="240" height="330" /> | ||
<img align="middle" src="./assets/odenet_0_viz.png" alt="Continuous-depth network" width="240" height="330" /> | ||
</p> | ||
|
||
## Installation | ||
``` | ||
git clone https://github.com/rtqichen/torchdiffeq.git | ||
cd torchdiffeq | ||
pip install -e . | ||
``` | ||
|
||
## Examples | ||
Examples are placed in the [`examples`](./examples) directory. | ||
|
||
We encourage those who are interested in using this library to take a look at `examples/ode_demo.py` for understanding how to use `torchdiffeq` to fit a simple spiral ODE. | ||
|
||
<p align="center"> | ||
<img align="middle" src="./assets/ode_demo.gif" alt="ODE Demo" width="500" height="250" /> | ||
</p> | ||
|
||
## Basic usage | ||
This library provides one main interface `odeint` which contains general-purpose algorithms for solving initial value problems (IVP), with gradients implemented for all main arguments. An initial value problem provides an ODE and an initial value | ||
``` | ||
dy/dt = f(t, y) y(t_0) = y_0 | ||
``` | ||
The goal of an ODE solver is to find a continuous trajectory satisfying the ODE that passes through the initial condition. | ||
|
||
To solve an IVP using the default solver: | ||
``` | ||
from torchdiffeq import odeint | ||
odeint(func, y0, t) | ||
``` | ||
where `func` is any callable implementing the ordinary differential equation `f(t, x)`, `y0` is an _any_-D Tensor or a tuple of _any_-D Tensors representing the initial values, and `t` is a 1-D Tensor containing the evaluation points. The initial time is taken to be `t[0]`. | ||
|
||
Backpropagation through `odeint` goes through the internals of the solver, but this is not supported for all solvers. Instead, we encourage the use of the adjoint method explained in [1], which will allow solving with as many steps as necessary due to O(1) memory usage. | ||
|
||
To use the adjoint method: | ||
``` | ||
from torchdiffeq import odeint_adjoint as odeint | ||
odeint(func, y0, t) | ||
``` | ||
`odeint_adjoint` simply wraps around `odeint`, but will use only O(1) memory in exchange for solving an adjoint ODE in the backward call. | ||
|
||
The biggest **gotcha** is that `func` must be a `nn.Module` when using the adjoint method. This is used to collect parameters of the differential equation. | ||
|
||
### Keyword Arguments | ||
- `rtol` Relative tolerance. | ||
- `atol` Absolute tolerance. | ||
- `method` One of the solvers listed below. | ||
|
||
#### List of ODE Solvers: | ||
|
||
Adaptive-step: | ||
- `dopri5` Runge-Kutta 4(5) [default]. | ||
- `adams` Adaptive-order implicit Adams. | ||
|
||
Fixed-step: | ||
- `euler` Euler method. | ||
- `midpoint` Midpoint method. | ||
- `rk4` Fourth-order Runge-Kutta with 3/8 rule. | ||
- `explicit_adams` Explicit Adams. | ||
- `fixed_adams` Implicit Adams. | ||
|
||
### References | ||
[1] Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, David Duvenaud. "Neural Ordinary Differential Equations." *Advances in Neural Processing Information Systems.* 2018. |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# Overview of Examples | ||
|
||
This `examples` directory contains cleaned up code regarding the usage of adaptive ODE solvers in machine learning. The scripts in this directory assume that `torchdiffeq` is installed following instructions from the main directory. | ||
|
||
## Demo | ||
The `ode_demo.py` file contains a short implementation of learning a dynamics model to mimic a spiral ODE. | ||
|
||
To visualize the training progress, run | ||
``` | ||
python ode_demo.py --viz | ||
``` | ||
The training should look similar to this: | ||
|
||
<p align="center"> | ||
<img align="middle" src="../assets/ode_demo.gif" alt="ODE Demo" width="500" height="250" /> | ||
</p> | ||
|
||
## ODEnet for MNIST | ||
The `odenet_mnist.py` file contains a reproduction of the MNIST experiments in our Neural ODE paper. Notably the ODE solver library and method are different from our original experiments, but the results are similar to those we report in the paper. | ||
|
||
We can use an adaptive ODE solver to approximate our continuous-depth network while still backpropagating through the network. | ||
``` | ||
python odenet_mnist.py --network odenet | ||
``` | ||
However, the memory requirements for this will blow up very fast, especially for more complex problems where the number of function evaluations can reach nearly a thousand. | ||
|
||
For applications that require solving complex trajectories, we recommend using the adjoint method. | ||
``` | ||
python odenet_mnist.py --network odenet --adjoint True | ||
``` | ||
The adjoint method can be slower when using an adaptive ODE solver as it involves another solve in the backward pass with a much larger system, so experimenting on small systems with direct backpropagation first is recommended. | ||
|
||
Thankfully, it is extremely easy to write code for both adjoint and non-adjoint backpropagation, as they use the same interface. | ||
``` | ||
if adjoint: | ||
from torchdiffeq import odeint_adjoint as odeint | ||
else: | ||
from torchdiffeq import odeint | ||
``` | ||
The main gotcha is that `odeint_adjoint` requires implementing the dynamics network as a `nn.Module` while `odeint` can work with any callable in Python. | ||
|
||
## Continuous Normalizing Flows | ||
Code for continuous normalizing flows (CNF) have their own public repository. Tools for training, evaluating, and visualizing CNF for reversible generative modeling are provided along with FFJORD, a linear cost stochastic approximation of CNF. | ||
|
||
Find the code in https://github.com/rtqichen/ffjord. This code contains some advanced tricks for `torchdiffeq`. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
import os | ||
import argparse | ||
import time | ||
import numpy as np | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
|
||
parser = argparse.ArgumentParser('ODE demo') | ||
parser.add_argument('--method', type=str, default='dopri5') | ||
parser.add_argument('--data_size', type=int, default=1000) | ||
parser.add_argument('--batch_time', type=int, default=10) | ||
parser.add_argument('--batch_size', type=int, default=20) | ||
parser.add_argument('--niters', type=int, default=2000) | ||
parser.add_argument('--test_freq', type=int, default=20) | ||
parser.add_argument('--viz', action='store_true') | ||
parser.add_argument('--gpu', type=int, default=0) | ||
parser.add_argument('--adjoint', action='store_true') | ||
args = parser.parse_args() | ||
|
||
if args.adjoint: | ||
from torchdiffeq import odeint_adjoint as odeint | ||
else: | ||
from torchdiffeq import odeint | ||
|
||
device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') | ||
|
||
true_y0 = torch.tensor([[2., 0.]]) | ||
t = torch.linspace(0., 25., args.data_size) | ||
true_A = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]]) | ||
|
||
|
||
class Lambda(nn.Module): | ||
|
||
def forward(self, t, y): | ||
return torch.mm(y**3, true_A) | ||
|
||
|
||
with torch.no_grad(): | ||
true_y = odeint(Lambda(), true_y0, t) | ||
|
||
|
||
def get_batch(): | ||
s = torch.from_numpy(np.random.choice(np.arange(args.data_size - args.batch_time), args.batch_size, replace=False)) | ||
batch_y0 = true_y[s] # (M, D) | ||
batch_t = t[:args.batch_time] # (T) | ||
batch_y = torch.stack([true_y[s + i] for i in range(args.batch_time)], dim=0) # (T, M, D) | ||
return batch_y0, batch_t, batch_y | ||
|
||
|
||
def makedirs(dirname): | ||
if not os.path.exists(dirname): | ||
os.makedirs(dirname) | ||
|
||
|
||
if args.viz: | ||
makedirs('png') | ||
import matplotlib.pyplot as plt | ||
fig = plt.figure(figsize=(12, 4), facecolor='white') | ||
ax_traj = fig.add_subplot(131, frameon=False) | ||
ax_phase = fig.add_subplot(132, frameon=False) | ||
ax_vecfield = fig.add_subplot(133, frameon=False) | ||
plt.show(block=False) | ||
|
||
|
||
def visualize(true_y, pred_y, odefunc, itr): | ||
|
||
if args.viz: | ||
|
||
ax_traj.cla() | ||
ax_traj.set_title('Trajectories') | ||
ax_traj.set_xlabel('t') | ||
ax_traj.set_ylabel('x,y') | ||
ax_traj.plot(t.numpy(), true_y.numpy()[:, 0, 0], t.numpy(), true_y.numpy()[:, 0, 1], 'g-') | ||
ax_traj.plot(t.numpy(), pred_y.numpy()[:, 0, 0], '--', t.numpy(), pred_y.numpy()[:, 0, 1], 'b--') | ||
ax_traj.set_xlim(t.min(), t.max()) | ||
ax_traj.set_ylim(-2, 2) | ||
ax_traj.legend() | ||
|
||
ax_phase.cla() | ||
ax_phase.set_title('Phase Portrait') | ||
ax_phase.set_xlabel('x') | ||
ax_phase.set_ylabel('y') | ||
ax_phase.plot(true_y.numpy()[:, 0, 0], true_y.numpy()[:, 0, 1], 'g-') | ||
ax_phase.plot(pred_y.numpy()[:, 0, 0], pred_y.numpy()[:, 0, 1], 'b--') | ||
ax_phase.set_xlim(-2, 2) | ||
ax_phase.set_ylim(-2, 2) | ||
|
||
ax_vecfield.cla() | ||
ax_vecfield.set_title('Learned Vector Field') | ||
ax_vecfield.set_xlabel('x') | ||
ax_vecfield.set_ylabel('y') | ||
|
||
y, x = np.mgrid[-2:2:21j, -2:2:21j] | ||
dydt = odefunc(0, torch.Tensor(np.stack([x, y], -1).reshape(21 * 21, 2))).cpu().detach().numpy() | ||
mag = np.sqrt(dydt[:, 0]**2 + dydt[:, 1]**2).reshape(-1, 1) | ||
dydt = (dydt / mag) | ||
dydt = dydt.reshape(21, 21, 2) | ||
|
||
ax_vecfield.streamplot(x, y, dydt[:, :, 0], dydt[:, :, 1], color="black") | ||
ax_vecfield.set_xlim(-2, 2) | ||
ax_vecfield.set_ylim(-2, 2) | ||
|
||
fig.tight_layout() | ||
plt.savefig('png/{:03d}'.format(itr)) | ||
plt.draw() | ||
plt.pause(0.001) | ||
|
||
|
||
class ODEFunc(nn.Module): | ||
|
||
def __init__(self): | ||
super(ODEFunc, self).__init__() | ||
|
||
self.net = nn.Sequential( | ||
nn.Linear(2, 50), | ||
nn.Tanh(), | ||
nn.Linear(50, 2), | ||
) | ||
|
||
for m in self.net.modules(): | ||
if isinstance(m, nn.Linear): | ||
nn.init.normal_(m.weight, mean=0, std=0.1) | ||
nn.init.constant_(m.bias, val=0) | ||
|
||
def forward(self, t, y): | ||
return self.net(y**3) | ||
This comment has been minimized.
Sorry, something went wrong. |
||
|
||
|
||
class RunningAverageMeter(object): | ||
"""Computes and stores the average and current value""" | ||
|
||
def __init__(self, momentum=0.99): | ||
self.momentum = momentum | ||
self.reset() | ||
|
||
def reset(self): | ||
self.val = None | ||
self.avg = 0 | ||
|
||
def update(self, val): | ||
if self.val is None: | ||
self.avg = val | ||
else: | ||
self.avg = self.avg * self.momentum + val * (1 - self.momentum) | ||
self.val = val | ||
|
||
|
||
if __name__ == '__main__': | ||
|
||
ii = 0 | ||
|
||
func = ODEFunc() | ||
optimizer = optim.Adam(func.parameters(), lr=1e-3) | ||
end = time.time() | ||
|
||
time_meter = RunningAverageMeter(0.97) | ||
loss_meter = RunningAverageMeter(0.97) | ||
|
||
for itr in range(1, args.niters + 1): | ||
optimizer.zero_grad() | ||
batch_y0, batch_t, batch_y = get_batch() | ||
pred_y = odeint(func, batch_y0, batch_t) | ||
loss = torch.mean(torch.abs(pred_y - batch_y)) | ||
loss.backward() | ||
optimizer.step() | ||
|
||
time_meter.update(time.time() - end) | ||
loss_meter.update(loss.item()) | ||
|
||
if itr % args.test_freq == 0: | ||
with torch.no_grad(): | ||
pred_y = odeint(func, true_y0, t) | ||
loss = torch.mean(torch.abs(pred_y - true_y)) | ||
print('Iter {:04d} | Total Loss {:.6f}'.format(itr, loss.item())) | ||
visualize(true_y, pred_y, func, ii) | ||
ii += 1 | ||
|
||
end = time.time() |
Oops, something went wrong.
I don't get why use **3 here. In fact, removing this makes ODEFunc unable to approximate the type of dynamics Lambda implements.
