Skip to content

Commit

Permalink
fix and update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rtqichen committed Feb 6, 2021
1 parent 0529edf commit 63bfb13
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 27 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
.installed.cfg
*.egg
*__pycache__*
*.pyc
.vscode
build
dist
58 changes: 34 additions & 24 deletions tests/norm_tests.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import unittest

import torch
Expand All @@ -6,21 +7,35 @@
from problems import (DTYPES, DEVICES, ADAPTIVE_METHODS)


@contextlib.contextmanager
def random_seed_torch(seed):
cpu_rng_state = torch.get_rng_state()
torch.manual_seed(seed)

try:
yield
finally:
torch.set_rng_state(cpu_rng_state)


class _NeuralF(torch.nn.Module):
def __init__(self, width, oscillate):
super(_NeuralF, self).__init__()
self.linears = torch.nn.Sequential(torch.nn.Linear(2, width),
torch.nn.Tanh(),
torch.nn.Linear(width, 2),
torch.nn.Tanh())

# Use the same set of random weights for every instance.
with random_seed_torch(0):
self.linears = torch.nn.Sequential(torch.nn.Linear(2, width),
torch.nn.Tanh(),
torch.nn.Linear(width, 2),
torch.nn.Tanh())
self.nfe = 0
self.oscillate = oscillate

def forward(self, t, x):
self.nfe += 1
out = self.linears(x)
if self.oscillate:
out = out * t.mul(20).sin()
out = out * t.mul(2).sin()
return out


Expand Down Expand Up @@ -227,7 +242,6 @@ def adjoint_norm(tensor_tuple):
self.assertTrue(is_called)

def test_large_norm(self):
torch.manual_seed(3456789) # test can be flaky

def norm(tensor):
return tensor.abs().max()
Expand Down Expand Up @@ -256,38 +270,34 @@ def large_norm(tensor):
self.assertLessEqual(norm_f.nfe, large_norm_f.nfe)

def test_seminorm(self):
torch.manual_seed(3456786) # test can be flaky
for dtype in DTYPES:
for device in DEVICES:
for method in ADAPTIVE_METHODS:
if method == 'adaptive_heun':
# Adaptive heun is consistently an awful choice with seminorms, it seems. My guess is that it's
# consistently overconfident with its step sizes, and that having seminorms turned off means
# that it actually gets it right more often.
continue
if dtype == torch.float32 and method == 'dopri8':
continue

with self.subTest(dtype=dtype, device=device, method=method):

if dtype == torch.float32:
tol = 1e-6
else:
tol = 1e-8

x0 = torch.tensor([1.0, 2.0], device=device, dtype=dtype)
t = torch.tensor([0., 1.0], device=device, dtype=dtype)

norm_f = _NeuralF(width=256, oscillate=True).to(device, dtype)
out = torchdiffeq.odeint_adjoint(norm_f, x0, t, atol=3e-7, method=method)
norm_f.nfe = 0
ode_f = _NeuralF(width=1024, oscillate=True).to(device, dtype)

out = torchdiffeq.odeint_adjoint(ode_f, x0, t, atol=tol, rtol=tol, method=method)
ode_f.nfe = 0
out.sum().backward()
default_nfe = ode_f.nfe

seminorm_f = _NeuralF(width=256, oscillate=True).to(device, dtype)
with torch.no_grad():
for norm_param, seminorm_param in zip(norm_f.parameters(), seminorm_f.parameters()):
seminorm_param.copy_(norm_param)
out = torchdiffeq.odeint_adjoint(seminorm_f, x0, t, atol=1e-6, method=method,
out = torchdiffeq.odeint_adjoint(ode_f, x0, t, atol=tol, rtol=tol, method=method,
adjoint_options=dict(norm='seminorm'))
seminorm_f.nfe = 0
ode_f.nfe = 0
out.sum().backward()
seminorm_nfe = ode_f.nfe

self.assertLessEqual(seminorm_f.nfe, norm_f.nfe)
self.assertLessEqual(seminorm_nfe, default_nfe)


if __name__ == '__main__':
Expand Down
4 changes: 2 additions & 2 deletions tests/odeint_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_odeint(self):
elif ode == 'linear':
eps = 2e-3
else:
eps = 1e-4
eps = 3e-4

with self.subTest(reverse=reverse, dtype=dtype, device=device, ode=ode, method=method):
f, y0, t_points, sol = construct_problem(dtype=dtype, device=device, ode=ode,
Expand Down Expand Up @@ -117,7 +117,7 @@ def test_odeint_jump_t(self):

simple_f = _JumpF()
odeint = partial(torchdiffeq.odeint_adjoint, adjoint_params=()) if adjoint else torchdiffeq.odeint
simple_xs = odeint(simple_f, x0, t, atol=1e-7, method=method)
simple_xs = odeint(simple_f, x0, t, atol=1e-6, method=method)

better_f = _JumpF()
options = dict(jump_t=torch.tensor([0.5], device=device))
Expand Down
2 changes: 1 addition & 1 deletion tests/problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def y_exact(self, t):
DEVICES.append('cuda')
FIXED_METHODS = ('euler', 'midpoint', 'rk4', 'explicit_adams', 'implicit_adams')
ADAMS_METHODS = ('explicit_adams', 'implicit_adams')
ADAPTIVE_METHODS = ('dopri5', 'bosh3', 'adaptive_heun', 'dopri8')
ADAPTIVE_METHODS = ('adaptive_heun', 'fehlberg2', 'bosh3', 'dopri5', 'dopri8')
SCIPY_METHODS = ('scipy_solver',)
METHODS = FIXED_METHODS + ADAPTIVE_METHODS + SCIPY_METHODS

Expand Down

0 comments on commit 63bfb13

Please sign in to comment.