From 9825fc701a4c5f8eef1aa29ab4f319cdfe88f473 Mon Sep 17 00:00:00 2001 From: ArnoStrouwen Date: Sat, 12 Nov 2022 07:19:48 +0100 Subject: [PATCH] non mutating gradient --- src/gradients.jl | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/src/gradients.jl b/src/gradients.jl index 3ed61fb..f4745d1 100644 --- a/src/gradients.jl +++ b/src/gradients.jl @@ -130,7 +130,12 @@ function finite_difference_gradient( end end cache = GradientCache(df, x, fdtype, returntype, inplace) - finite_difference_gradient!(df, f, x, cache, relstep=relstep, absstep=absstep, dir=dir) + if typeof(x) <: AbstractArray && fdtype == Val(:central) + df = finite_difference_gradient(f, x, cache, relstep=relstep, absstep=absstep, dir=dir) + else + df = finite_difference_gradient!(df,f, x, cache, relstep=relstep, absstep=absstep, dir=dir) + end + return df end function finite_difference_gradient!( @@ -169,6 +174,29 @@ end # vector of derivatives of a vector->scalar map by each component of a vector x # this ignores the value of "inplace", because it doesn't make much sense +function finite_difference_gradient( + f, + x::AbstractVector{<:Number}, + cache::GradientCache{T1,T2,T3,T4,fdtype,returntype,inplace}; + relstep=default_relstep(fdtype, eltype(x)), + absstep=relstep, + dir=true) where {T1,T2,T3,T4,fdtype,returntype,inplace} + + df = deepcopy(x) # how to get correct output type here cache.returntype + + if fdtype == Val(:central) + @inbounds for i ∈ eachindex(x) + epsilon = compute_epsilon(fdtype, x[i], relstep, absstep, dir) + c1 = f(setindex(x,x[i]+epsilon ,i)) + c2 = f(setindex(x,x[i]-epsilon ,i)) + dfi = (c1-c2)/(2epsilon) + df = setindex(df,dfi,i) + end + else + fdtype_error(returntype) + end + df +end function finite_difference_gradient!( df, f,