From d23fa129b22e5561d64825245ceaa2dfd7b85d53 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 4 Mar 2021 06:44:25 -0500 Subject: [PATCH] fix bounds translation (#499) * fix bounds translation fixes https://github.com/SciML/DiffEqFlux.jl/issues/498 * patch release --- Project.toml | 2 +- src/train.jl | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index cb0f0ef534..c36dbc9170 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DiffEqFlux" uuid = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0" authors = ["Chris Rackauckas "] -version = "1.34.0" +version = "1.34.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/train.jl b/src/train.jl index cdad55021a..aa286c05ed 100644 --- a/src/train.jl +++ b/src/train.jl @@ -1,6 +1,7 @@ -function sciml_train(loss, θ, opt, adtype::DiffEqBase.AbstractADType = GalacticOptim.AutoZygote(), args...; kwargs...) +function sciml_train(loss, θ, opt, adtype::DiffEqBase.AbstractADType = GalacticOptim.AutoZygote(), args...; + lower_bounds = nothing, upper_bounds = nothing, kwargs...) optf = GalacticOptim.OptimizationFunction((x, p) -> loss(x), adtype) optfunc = GalacticOptim.instantiate_function(optf, θ, adtype, nothing) - optprob = GalacticOptim.OptimizationProblem(optfunc, θ; kwargs...) + optprob = GalacticOptim.OptimizationProblem(optfunc, θ; lb = lower_bounds, ub = upper_bounds, kwargs...) GalacticOptim.solve(optprob, opt, args...; kwargs...) end