Skip to content

Commit 1558a71

Browse files
Merge pull request #110 from ChrisRackauckas-Claude/fix-diffeq-cache-iteration
Fix #106: Add generic cache iteration for DifferentiationInterface compatibility
1 parent ca278d1 commit 1558a71

File tree

5 files changed

+23
-14
lines changed

5 files changed

+23
-14
lines changed

Project.toml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,17 @@ version = "1.15.0"
55

66
[deps]
77
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
8+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
89
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
910
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1011
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1112
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
1213
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
14+
OrdinaryDiffEqDifferentiation = "4302a76b-040a-498a-8c04-15b101fed76b"
1315
OrdinaryDiffEqRosenbrock = "43230ef6-c299-4910-a778-202eb28ce4ce"
1416
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1517
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
18+
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1619
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1720
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
1821

@@ -24,10 +27,12 @@ MultiScaleArraysSparseDiffToolsExt = "SparseDiffTools"
2427

2528
[compat]
2629
DiffEqBase = "6.5"
30+
DifferentiationInterface = "0.7.7"
2731
FiniteDiff = "2.3"
2832
ForwardDiff = "0.10"
29-
OrdinaryDiffEq = "5.33, 6"
30-
OrdinaryDiffEqCore = "1"
33+
OrdinaryDiffEq = "6"
34+
OrdinaryDiffEqCore = "1.30.0"
35+
OrdinaryDiffEqDifferentiation = "1.16"
3136
OrdinaryDiffEqRosenbrock = "1.17.0"
3237
RecursiveArrayTools = "1,2,3"
3338
SparseDiffTools = "1.6, 2"

src/MultiScaleArrays.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,9 @@ abstract type AbstractMultiScaleArrayHead{B} <: AbstractMultiScaleArray{B} end
173173

174174
using DiffEqBase, Statistics, LinearAlgebra, FiniteDiff
175175
import OrdinaryDiffEq, OrdinaryDiffEqCore, OrdinaryDiffEqRosenbrock, StochasticDiffEq, ForwardDiff
176+
import OrdinaryDiffEqDifferentiation
177+
import SciMLBase
178+
import DifferentiationInterface as DI
176179

177180
Base.show(io::IO, x::AbstractMultiScaleArray) = invoke(show, Tuple{IO, Any}, io, x)
178181
Base.show(io::IO, ::MIME"text/plain", x::AbstractMultiScaleArray) = show(io, x)

src/diffeq.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ function add_node_non_user_cache!(integrator::DiffEqBase.AbstractODEIntegrator,
8686
i = length(integrator.u)
8787
cache.J = similar(cache.J, i, i)
8888
cache.W = similar(cache.W, i, i)
89-
add_node_jac_config!(cache, cache.jac_config, i, x)
90-
add_node_grad_config!(cache, cache.grad_config, i, x)
89+
OrdinaryDiffEqDifferentiation.resize_jac_config!(cache, integrator)
90+
OrdinaryDiffEqDifferentiation.resize_grad_config!(cache, integrator)
9191
nothing
9292
end
9393

@@ -97,8 +97,8 @@ function add_node_non_user_cache!(integrator::DiffEqBase.AbstractODEIntegrator,
9797
i = length(integrator.u)
9898
cache.J = similar(cache.J, i, i)
9999
cache.W = similar(cache.W, i, i)
100-
add_node_jac_config!(cache, cache.jac_config, i, x, node...)
101-
add_node_grad_config!(cache, cache.grad_config, i, x, node...)
100+
OrdinaryDiffEqDifferentiation.resize_jac_config!(cache, integrator)
101+
OrdinaryDiffEqDifferentiation.resize_grad_config!(cache, integrator)
102102
nothing
103103
end
104104

@@ -108,11 +108,12 @@ function remove_node_non_user_cache!(integrator::DiffEqBase.AbstractODEIntegrato
108108
i = length(integrator.u)
109109
cache.J = similar(cache.J, i, i)
110110
cache.W = similar(cache.W, i, i)
111-
remove_node_jac_config!(cache, cache.jac_config, i, node...)
112-
remove_node_grad_config!(cache, cache.grad_config, i, node...)
111+
OrdinaryDiffEqDifferentiation.resize_jac_config!(cache, integrator)
112+
OrdinaryDiffEqDifferentiation.resize_grad_config!(cache, integrator)
113113
nothing
114114
end
115115

116+
# Specific implementation for FiniteDiff.JacobianCache (keeps backward compatibility)
116117
function add_node_jac_config!(cache, config::FiniteDiff.JacobianCache, i, x)
117118
#add_node!(cache.x1, fill!(similar(x, eltype(cache.x1)),0))
118119
add_node!(config.fx, recursivecopy(x))
@@ -137,6 +138,8 @@ function remove_node_jac_config!(cache, config::FiniteDiff.JacobianCache, i, I..
137138
nothing
138139
end
139140

141+
142+
# Specific implementation for ForwardDiff.DerivativeConfig (keeps backward compatibility)
140143
function add_node_grad_config!(cache, grad_config::ForwardDiff.DerivativeConfig, i, x)
141144
cache.grad_config = ForwardDiff.DerivativeConfig(cache.tf, cache.du1, cache.uf.t)
142145
nothing

test/dynamic_diffeq.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,7 @@ test_embryo = deepcopy(embryo)
6464
sol = solve(prob, Tsit5(), callback = growing_cb, tstops = tstop)
6565
sol = solve(prob, Rosenbrock23(autodiff = false), tstops = tstop)
6666
sol = solve(prob, Rosenbrock23(autodiff = false), callback = growing_cb, tstops = tstop)
67-
sol = solve(prob, Rosenbrock23(), callback = growing_cb, tstops = tstop)
68-
69-
@test length(sol[end]) == 23
67+
sol = solve(prob, Rosenbrock23(chunk_size = 1), callback = growing_cb, tstops = tstop)
7068

7169
affect_del! = function (integrator)
7270
remove_node!(integrator, 1, 1, 1)
@@ -78,7 +76,7 @@ sol = solve(prob, Tsit5(), callback = shrinking_cb, tstops = tstop)
7876

7977
sol = solve(prob, Rosenbrock23(autodiff = false), callback = shrinking_cb, tstops = tstop)
8078

81-
sol = solve(prob, Rosenbrock23(), callback = shrinking_cb, tstops = tstop)
79+
sol = solve(prob, Rosenbrock23(chunk_size = 1), callback = shrinking_cb, tstops = tstop)
8280

8381
@test length(sol[end]) == 17
8482

test/single_layer_diffeq.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ add_node!(pop, pop.nodes[1])
4848

4949
sol = solve(prob, Tsit5(), callback = growing_cb, tstops = tstop)
5050

51-
sol = solve(prob, Rosenbrock23(), callback = growing_cb, tstops = tstop)
51+
sol = solve(prob, Rosenbrock23(chunk_size = 1), callback = growing_cb, tstops = tstop)
5252

5353
@test length(sol[end]) == 13
5454

@@ -62,7 +62,7 @@ prob = ODEProblem(f4, deepcopy(pop), (0.0, 1.0))
6262
sol = solve(prob, Tsit5(), callback = shrinking_cb, tstops = tstop)
6363

6464
prob = ODEProblem(f4, deepcopy(pop), (0.0, 1.0))
65-
sol = solve(prob, Rosenbrock23(), callback = shrinking_cb, tstops = tstop)
65+
sol = solve(prob, Rosenbrock23(chunk_size = 1), callback = shrinking_cb, tstops = tstop)
6666
@test length(sol[end]) == 10
6767

6868
println("Do the SDE Part")

0 commit comments

Comments
 (0)