Skip to content

Commit

Permalink
Merge pull request #261 from psv4/add-tsit5
Browse files Browse the repository at this point in the history
Adding tsit5 as a solver
  • Loading branch information
rtqichen authored Feb 13, 2025
2 parents a88aac5 + 4186274 commit f3135f3
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 1 deletion.
2 changes: 1 addition & 1 deletion tests/problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def y_exact(self, t):
DEVICES.append('cuda')
FIXED_METHODS = ('euler', 'midpoint', 'heun2', 'heun3', 'rk4', 'explicit_adams', 'implicit_adams')
ADAMS_METHODS = ('explicit_adams', 'implicit_adams')
ADAPTIVE_METHODS = ('adaptive_heun', 'fehlberg2', 'bosh3', 'dopri5', 'dopri8')
ADAPTIVE_METHODS = ('adaptive_heun', 'fehlberg2', 'bosh3', 'tsit5', 'dopri5', 'dopri8')
SCIPY_METHODS = ('scipy_solver',)
METHODS = FIXED_METHODS + ADAPTIVE_METHODS + SCIPY_METHODS

Expand Down
2 changes: 2 additions & 0 deletions torchdiffeq/_impl/odeint.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
from .fixed_grid import Euler, Midpoint, Heun2, Heun3, RK4
from .fixed_adams import AdamsBashforth, AdamsBashforthMoulton
from .dopri8 import Dopri8Solver
from .tsit5 import Tsit5Solver
from .scipy_wrapper import ScipyWrapperODESolver
from .misc import _check_inputs, _flat_to_shape
from .interp import _interp_evaluate

SOLVERS = {
'dopri8': Dopri8Solver,
'dopri5': Dopri5Solver,
'tsit5': Tsit5Solver,
'bosh3': Bosh3Solver,
'fehlberg2': Fehlberg2,
'adaptive_heun': AdaptiveHeunSolver,
Expand Down
82 changes: 82 additions & 0 deletions torchdiffeq/_impl/tsit5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import torch
from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver
# https://github.com/SciML/OrdinaryDiffEq.jl/blob/master/lib/OrdinaryDiffEqTsit5/src/tsit_tableaus.jl
# https://github.com/patrick-kidger/diffrax/blob/14baa1edddcacf27c0483962b3c9cf2e86e6e5b6/diffrax/_solver/tsit5.py#L158

_TSITOURAS_TABLEAU = _ButcherTableau(
alpha=torch.tensor([
161 / 1000,
327 / 1000,
9 / 10,
.9800255409045096857298102862870245954942137979563024768854764293221195950761080302604,
1,
1
], dtype=torch.float64),
beta=[
torch.tensor([161 / 1000], dtype=torch.float64),
torch.tensor([
-.8480655492356988544426874250230774675121177393430391537369234245294192976164141156943e-2,
.3354806554923569885444268742502307746751211773934303915373692342452941929761641411569
], dtype=torch.float64),
torch.tensor([
2.897153057105493432130432594192938764924887287701866490314866693455023795137503079289,
-6.359448489975074843148159912383825625952700647415626703305928850207288721235210244366,
4.362295432869581411017727318190886861027813359713760212991062156752264926097707165077,
], dtype=torch.float64),
torch.tensor([
5.325864828439256604428877920840511317836476253097040101202360397727981648835607691791,
-11.74888356406282787774717033978577296188744178259862899288666928009020615663593781589,
7.495539342889836208304604784564358155658679161518186721010132816213648793440552049753,
-.9249506636175524925650207933207191611349983406029535244034750452930469056411389539635e-1
], dtype=torch.float64),
torch.tensor([
5.861455442946420028659251486982647890394337666164814434818157239052507339770711679748,
-12.92096931784710929170611868178335939541780751955743459166312250439928519268343184452,
8.159367898576158643180400794539253485181918321135053305748355423955009222648673734986,
-.7158497328140099722453054252582973869127213147363544882721139659546372402303777878835e-1,
-.2826905039406838290900305721271224146717633626879770007617876201276764571291579142206e-1
], dtype=torch.float64),
torch.tensor([
.9646076681806522951816731316512876333711995238157997181903319145764851595234062815396e-1,
1 / 100,
.4798896504144995747752495322905965199130404621990332488332634944254542060153074523509,
1.379008574103741893192274821856872770756462643091360525934940067397245698027561293331,
-3.290069515436080679901047585711363850115683290894936158531296799594813811049925401677,
2.324710524099773982415355918398765796109060233222962411944060046314465391054716027841
], dtype=torch.float64),
],
c_sol=torch.tensor([
.9468075576583945807478876255758922856117527357724631226139574065785592789071067303271e-1,
.9183565540343253096776363936645313759813746240984095238905939532922955247253608687270e-2,
.4877705284247615707855642599631228241516691959761363774365216240304071651579571959813,
1.234297566930478985655109673884237654035539930748192848315425833500484878378061439761,
-2.707712349983525454881109975059321670689605166938197378763992255714444407154902012702,
1.866628418170587035753719399566211498666255505244122593996591602841258328965767580089,
1 / 66
], dtype=torch.float64),
c_error=torch.tensor([
-1.780011052225771443378550607539534775944678804333659557637450799792588061629796e-03,
-8.164344596567469032236360633546862401862537590159047610940604670770447527463931e-04,
7.880878010261996010314727672526304238628733777103128603258129604952959142646516e-03,
-1.44711007173262907537165147972635116720922712343167677619514233896760819649515e-01,
5.823571654525552250199376106520421794260781239567387797673045438803694038950012e-01,
-4.580821059291869466616365188325542974428047279788398179474684434732070620889539e-01,
1 / 66
], dtype=torch.float64),
)

x = 1 / 2
TSIT_C_MID = torch.tensor([
-1.0530884977290216*x*(x-1.329989018975412)*(x*x-1.4364028541716351*x+0.7139816917074209),
0.1017*x*x*(x*x-2.1966568338249754*x+1.2949852507374631),
2.490627285651252793*x*x*(x*x-2.38535645472061657*x+1.57803468208092486),
-16.54810288924490272*(x-1.21712927295533244)*(x-0.61620406037800089)*x*x,
47.37952196281928122*(x-1.203071208372362603)*(x-0.658047292653547382)*x*x,
-34.87065786149660974*(x-1.2)*(x-2/3)*x*x,
2.5*(x-1)*(x-0.6)*x*x
], dtype=torch.float64)

class Tsit5Solver(RKAdaptiveStepsizeODESolver):
order = 5
tableau = _TSITOURAS_TABLEAU
mid = TSIT_C_MID

0 comments on commit f3135f3

Please sign in to comment.