diff --git a/Project.toml b/Project.toml index e07b51c..ca5cb02 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ version = "0.4.0" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" [compat] Adapt = "3" @@ -15,15 +16,19 @@ ForwardDiff = "0.10.3" julia = "1.6" [extras] +FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["LabelledArrays", "LinearAlgebra", "OrdinaryDiffEq", "Test", "RecursiveArrayTools", "Pkg", "SafeTestsets", "Optimization", "OptimizationOptimJL"] +test = ["FiniteDiff", "LabelledArrays", "LinearAlgebra", "OrdinaryDiffEq", "Test", "Random", "RecursiveArrayTools", "Pkg", "SafeTestsets", "Optimization", "OptimizationOptimJL", "SciMLSensitivity", "Zygote"] diff --git a/src/PreallocationTools.jl b/src/PreallocationTools.jl index b0c5334..8f32a14 100644 --- a/src/PreallocationTools.jl +++ b/src/PreallocationTools.jl @@ -1,6 +1,7 @@ module PreallocationTools using ForwardDiff, ArrayInterfaceCore, Adapt +import ReverseDiff struct DiffCache{T <: AbstractArray, S <: AbstractArray} du::T @@ -87,7 +88,17 @@ function Base.getindex(b::LazyBufferCache, u::T) where {T <: AbstractArray} s = b.sizemap(size(u)) # required buffer size buf = get!(b.bufs, (T, s)) do similar(u, s) # buffer to allocate if it was not found in b.bufs - end::T # declare type since b.bufs dictionary is untyped + end::T # declare type since b.bufs dictionary is untyped + return buf +end + +function Base.getindex(b::LazyBufferCache, u::ReverseDiff.TrackedArray) + s = b.sizemap(size(u)) # required buffer size + T = ReverseDiff.TrackedArray + buf = get!(b.bufs, (T, s)) do + # declare type since b.bufs dictionary is untyped + similar(u, s) + end return buf end diff --git a/test/runtests.jl b/test/runtests.jl index 75acbce..899cd6b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,6 +15,7 @@ if GROUP == "All" || GROUP == "Core" @safetestset "ODE tests" begin include("core_odes.jl") end @safetestset "Resizing" begin include("core_resizing.jl") end @safetestset "Nested Duals" begin include("core_nesteddual.jl") end + @safetestset "ODE Sensitivity analysis" begin include("upstream/sensitivity_analysis.jl") end end if !is_APPVEYOR && GROUP == "GPU" diff --git a/test/upstream/sensitivity_analysis.jl b/test/upstream/sensitivity_analysis.jl new file mode 100644 index 0000000..c43d8de --- /dev/null +++ b/test/upstream/sensitivity_analysis.jl @@ -0,0 +1,44 @@ +using LinearAlgebra, OrdinaryDiffEq, Test, PreallocationTools +using Random, FiniteDiff, ForwardDiff, ReverseDiff, SciMLSensitivity, Zygote + +# see https://github.com/SciML/PreallocationTools.jl/issues/29 +@testset "VJP computation with LazyBuffer" begin + u0 = rand(2, 2) + p = rand(2, 2) + struct foo{T} + lbc::T + end + + f = foo(LazyBufferCache()) + + function (f::foo)(du, u, p, t) + tmp = f.lbc[u] + mul!(tmp, p, u) # avoid tmp = p*u + @. du = u + tmp + nothing + end + + prob = ODEProblem(f, u0, (0.0, 1.0), p) + + function loss(u0, p; sensealg = nothing) + _prob = remake(prob, u0 = u0, p = p) + _sol = solve(_prob, Tsit5(), sensealg = sensealg, saveat = 0.1, abstol = 1e-14, + reltol = 1e-14) + sum(abs2, _sol) + end + + loss(u0, p) + + du0 = FiniteDiff.finite_difference_gradient(u0 -> loss(u0, p), u0) + dp = FiniteDiff.finite_difference_gradient(p -> loss(u0, p), p) + Fdu0 = ForwardDiff.gradient(u0 -> loss(u0, p), u0) + Fdp = ForwardDiff.gradient(p -> loss(u0, p), p) + @test du0≈Fdu0 rtol=1e-8 + @test dp≈Fdp rtol=1e-8 + + Zdu0, Zdp = Zygote.gradient((u0, p) -> loss(u0, p; + sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP())), + u0, p) + @test du0≈Zdu0 rtol=1e-8 + @test dp≈Zdp rtol=1e-8 +end