Skip to content

Commit f055519

Browse files
authored
Merge pull request #139 from SciML/ap/defaults
Default to using SimpleGMRES for the backward pass
2 parents 8b1d48d + 64f84a1 commit f055519

4 files changed

+60
-12
lines changed

Diff for: Project.toml

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DeepEquilibriumNetworks"
22
uuid = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
33
authors = ["Avik Pal <[email protected]>"]
4-
version = "2.0.0"
4+
version = "2.0.1"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -18,11 +18,12 @@ SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
1818
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
1919

2020
[weakdeps]
21+
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
2122
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
2223
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2324

2425
[extensions]
25-
DeepEquilibriumNetworksSciMLSensitivityExt = "SciMLSensitivity"
26+
DeepEquilibriumNetworksLinearSolveSciMLSensitivityExt = ["LinearSolve", "SciMLSensitivity"]
2627
DeepEquilibriumNetworksZygoteExt = "Zygote"
2728

2829
[compat]
@@ -32,6 +33,7 @@ ConcreteStructs = "0.2"
3233
ConstructionBase = "1"
3334
DiffEqBase = "6.119"
3435
LinearAlgebra = "1"
36+
LinearSolve = "2.21.2"
3537
Lux = "0.5.11"
3638
Random = "1"
3739
SciMLBase = "2"

Diff for: docs/src/api.md

+38
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,44 @@ To construct a continuous DEQ, any ODE solver compatible with `DifferentialEquat
99
can be passed as the solver. To construct a discrete DEQ, any root finding algorithm
1010
compatible with `NonlinearSolve.jl` API can be passed as the solver.
1111

12+
## Choosing a Solver
13+
14+
### Root Finding Algorithms
15+
16+
Using Root Finding Algorithms give fast convergence when possible, but these methods also
17+
tend to be unstable. If you must use a root finding algorithm, we recommend using:
18+
19+
1. `NewtonRaphson` or `TrustRegion` for small models
20+
2. `LimitedMemoryBroyden` for large Deep Learning applications (with well-conditioned
21+
Jacobians)
22+
3. `NewtonRaphson(; linsolve = KrylovJL_GMRES())` for cases when Broyden methods fail
23+
24+
Note that Krylov Methods rely on efficient VJPs which are not available for all Lux models.
25+
If you think this is causing a performance regression, please open an issue in
26+
[Lux.jl](https://github.com/LuxDL/Lux.jl).
27+
28+
### ODE Solvers
29+
30+
Using ODE Solvers give slower convergence, but are more stable. We generally recommend these
31+
methods over root finding algorithms. If you use implicit ODE solvers, remember to use
32+
Krylov linear solvers, see OrdinaryDiffEq.jl documentation for these. For most cases, we
33+
recommend:
34+
35+
1. `VCAB3()` for high tolerance problems
36+
2. `Tsit5()` for high tolerance problems where `VCAB3()` fails
37+
3. In all other cases, follow the recommendation given in [OrdinaryDiffEq.jl](https://docs.sciml.ai/DiffEqDocs/stable/solvers/ode_solve/#ode_solve) documentation
38+
39+
### Sensitivity Analysis
40+
41+
1. For `MultiScaleNeuralODE`, we default to `GaussAdjoint(; autojacvec = ZygoteVJP())`. A
42+
faster alternative would be `BacksolveAdjoint(; autojacvec = ZygoteVJP())` but there are
43+
stability concerns for using that. Follow the recommendation given in [SciMLSensitivity.jl](https://docs.sciml.ai/SciMLSensitivity/stable/manual/differential_equation_sensitivities/#Choosing-a-Sensitivity-Algorithm) documentation.
44+
2. For Steady State Problems, we default to
45+
`SteadyStateAdjoint(; linsolve = SimpleGMRES(; blocksize, linsolve_kwargs = (; maxiters=10, abstol=1e-3, reltol=1e-3)))`.
46+
This default will perform poorly on small models. It is recommended to pass
47+
`sensealg = SteadyStateAdjoint()` or
48+
`sensealg = SteadyStateAdjoint(; linsolve = LUFactorization())` for small models.
49+
1250
## Standard Models
1351

1452
```@docs
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
module DeepEquilibriumNetworksLinearSolveSciMLSensitivityExt
2+
3+
# Linear Solve is a dependency of SciMLSensitivity, so we only need to load SciMLSensitivity
4+
# to load this extension
5+
using LinearSolve, SciMLBase, SciMLSensitivity
6+
import DeepEquilibriumNetworks: __default_sensealg
7+
8+
@inline function __default_sensealg(prob::SteadyStateProblem)
9+
# We want to avoid the cost for cache construction for linsolve = nothing
10+
# For small problems we should use concrete jacobian but we assume users want to solve
11+
# large problems with this package so we default to GMRES and avoid runtime dispatches
12+
linsolve = SimpleGMRES{true}(; blocksize=prod(size(prob.u0)[1:(end - 1)]))
13+
linsolve_kwargs = (; maxiters=10, abstol=1e-3, reltol=1e-3)
14+
return SteadyStateAdjoint(; linsolve, linsolve_kwargs, autojacvec=ZygoteVJP())
15+
end
16+
@inline __default_sensealg(::ODEProblem) = GaussAdjoint(; autojacvec=ZygoteVJP())
17+
18+
end

Diff for: ext/DeepEquilibriumNetworksSciMLSensitivityExt.jl

-10
This file was deleted.

0 commit comments

Comments
 (0)