From b48fe60ddd7d45c8d0f377cea6cd962a93b7fff6 Mon Sep 17 00:00:00 2001 From: Frank Schaefer Date: Thu, 4 Aug 2022 15:25:09 -0400 Subject: [PATCH 1/5] Change type annotation and `similar` to `zero` --- src/PreallocationTools.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/PreallocationTools.jl b/src/PreallocationTools.jl index b0c5334..5b2f68a 100644 --- a/src/PreallocationTools.jl +++ b/src/PreallocationTools.jl @@ -86,8 +86,9 @@ end function Base.getindex(b::LazyBufferCache, u::T) where {T <: AbstractArray} s = b.sizemap(size(u)) # required buffer size buf = get!(b.bufs, (T, s)) do - similar(u, s) # buffer to allocate if it was not found in b.bufs - end::T # declare type since b.bufs dictionary is untyped + # declare type since b.bufs dictionary is untyped + zero(u)::T # buffer to allocate if it was not found in b.bufs + end return buf end From 0363dd1545226682ee8b93635e22bf9b251f9e52 Mon Sep 17 00:00:00 2001 From: Frank Schaefer Date: Tue, 9 Aug 2022 15:05:14 -0400 Subject: [PATCH 2/5] add specialization for TrackedArray --- src/PreallocationTools.jl | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/PreallocationTools.jl b/src/PreallocationTools.jl index 5b2f68a..20fb4ae 100644 --- a/src/PreallocationTools.jl +++ b/src/PreallocationTools.jl @@ -84,10 +84,18 @@ end # override the [] method function Base.getindex(b::LazyBufferCache, u::T) where {T <: AbstractArray} + s = b.sizemap(size(u)) # required buffer size + buf = get!(b.bufs, (T, s)) do + similar(u, s) # buffer to allocate if it was not found in b.bufs + end::T # declare type since b.bufs dictionary is untyped + return buf +end + +function Base.getindex(b::LazyBufferCache, u::ReverseDiff.TrackedArray) s = b.sizemap(size(u)) # required buffer size buf = get!(b.bufs, (T, s)) do # declare type since b.bufs dictionary is untyped - zero(u)::T # buffer to allocate if it was not found in b.bufs + similar(u, s)::T # buffer to allocate if it was not found in b.bufs end return buf end From 1c90fdcc7f3451fc34011a4525528b747361b9cd Mon Sep 17 00:00:00 2001 From: Frank Schaefer Date: Tue, 9 Aug 2022 16:54:36 -0400 Subject: [PATCH 3/5] add ReverseDiff as dependency --- Project.toml | 1 + src/PreallocationTools.jl | 1 + 2 files changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index e07b51c..879fb2d 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ version = "0.4.0" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" [compat] Adapt = "3" diff --git a/src/PreallocationTools.jl b/src/PreallocationTools.jl index 20fb4ae..200ad3e 100644 --- a/src/PreallocationTools.jl +++ b/src/PreallocationTools.jl @@ -1,6 +1,7 @@ module PreallocationTools using ForwardDiff, ArrayInterfaceCore, Adapt +import ReverseDiff struct DiffCache{T <: AbstractArray, S <: AbstractArray} du::T From 9bdbec91faa04ab997e2e2978cedcfb239d398cc Mon Sep 17 00:00:00 2001 From: Frank Schaefer Date: Tue, 9 Aug 2022 17:58:06 -0400 Subject: [PATCH 4/5] add upstream test? --- Project.toml | 5 ++- src/PreallocationTools.jl | 3 +- test/runtests.jl | 1 + test/upstream/sensitivity_analysis.jl | 44 +++++++++++++++++++++++++++ 4 files changed, 51 insertions(+), 2 deletions(-) create mode 100644 test/upstream/sensitivity_analysis.jl diff --git a/Project.toml b/Project.toml index 879fb2d..ceea9da 100644 --- a/Project.toml +++ b/Project.toml @@ -16,6 +16,7 @@ ForwardDiff = "0.10.3" julia = "1.6" [extras] +FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" @@ -24,7 +25,9 @@ OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["LabelledArrays", "LinearAlgebra", "OrdinaryDiffEq", "Test", "RecursiveArrayTools", "Pkg", "SafeTestsets", "Optimization", "OptimizationOptimJL"] +test = ["FiniteDiff", "LabelledArrays", "LinearAlgebra", "OrdinaryDiffEq", "Test", "RecursiveArrayTools", "Pkg", "SafeTestsets", "Optimization", "OptimizationOptimJL", "SciMLSensitivity", "Zygote"] diff --git a/src/PreallocationTools.jl b/src/PreallocationTools.jl index 200ad3e..8f32a14 100644 --- a/src/PreallocationTools.jl +++ b/src/PreallocationTools.jl @@ -94,9 +94,10 @@ end function Base.getindex(b::LazyBufferCache, u::ReverseDiff.TrackedArray) s = b.sizemap(size(u)) # required buffer size + T = ReverseDiff.TrackedArray buf = get!(b.bufs, (T, s)) do # declare type since b.bufs dictionary is untyped - similar(u, s)::T # buffer to allocate if it was not found in b.bufs + similar(u, s) end return buf end diff --git a/test/runtests.jl b/test/runtests.jl index 75acbce..899cd6b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,6 +15,7 @@ if GROUP == "All" || GROUP == "Core" @safetestset "ODE tests" begin include("core_odes.jl") end @safetestset "Resizing" begin include("core_resizing.jl") end @safetestset "Nested Duals" begin include("core_nesteddual.jl") end + @safetestset "ODE Sensitivity analysis" begin include("upstream/sensitivity_analysis.jl") end end if !is_APPVEYOR && GROUP == "GPU" diff --git a/test/upstream/sensitivity_analysis.jl b/test/upstream/sensitivity_analysis.jl new file mode 100644 index 0000000..c43d8de --- /dev/null +++ b/test/upstream/sensitivity_analysis.jl @@ -0,0 +1,44 @@ +using LinearAlgebra, OrdinaryDiffEq, Test, PreallocationTools +using Random, FiniteDiff, ForwardDiff, ReverseDiff, SciMLSensitivity, Zygote + +# see https://github.com/SciML/PreallocationTools.jl/issues/29 +@testset "VJP computation with LazyBuffer" begin + u0 = rand(2, 2) + p = rand(2, 2) + struct foo{T} + lbc::T + end + + f = foo(LazyBufferCache()) + + function (f::foo)(du, u, p, t) + tmp = f.lbc[u] + mul!(tmp, p, u) # avoid tmp = p*u + @. du = u + tmp + nothing + end + + prob = ODEProblem(f, u0, (0.0, 1.0), p) + + function loss(u0, p; sensealg = nothing) + _prob = remake(prob, u0 = u0, p = p) + _sol = solve(_prob, Tsit5(), sensealg = sensealg, saveat = 0.1, abstol = 1e-14, + reltol = 1e-14) + sum(abs2, _sol) + end + + loss(u0, p) + + du0 = FiniteDiff.finite_difference_gradient(u0 -> loss(u0, p), u0) + dp = FiniteDiff.finite_difference_gradient(p -> loss(u0, p), p) + Fdu0 = ForwardDiff.gradient(u0 -> loss(u0, p), u0) + Fdp = ForwardDiff.gradient(p -> loss(u0, p), p) + @test du0≈Fdu0 rtol=1e-8 + @test dp≈Fdp rtol=1e-8 + + Zdu0, Zdp = Zygote.gradient((u0, p) -> loss(u0, p; + sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP())), + u0, p) + @test du0≈Zdu0 rtol=1e-8 + @test dp≈Zdp rtol=1e-8 +end From fab0ffd106518b95f9466937b5d842402052c9dc Mon Sep 17 00:00:00 2001 From: Frank Schaefer Date: Wed, 10 Aug 2022 16:09:13 -0400 Subject: [PATCH 5/5] add Random --- Project.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ceea9da..ca5cb02 100644 --- a/Project.toml +++ b/Project.toml @@ -23,6 +23,7 @@ Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" @@ -30,4 +31,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["FiniteDiff", "LabelledArrays", "LinearAlgebra", "OrdinaryDiffEq", "Test", "RecursiveArrayTools", "Pkg", "SafeTestsets", "Optimization", "OptimizationOptimJL", "SciMLSensitivity", "Zygote"] +test = ["FiniteDiff", "LabelledArrays", "LinearAlgebra", "OrdinaryDiffEq", "Test", "Random", "RecursiveArrayTools", "Pkg", "SafeTestsets", "Optimization", "OptimizationOptimJL", "SciMLSensitivity", "Zygote"]