Skip to content

Commit bb07f7f

Browse files
committed
changed: NonLinMPC without collect on decision vars.
1 parent a4795ad commit bb07f7f

File tree

1 file changed

+23
-39
lines changed

1 file changed

+23
-39
lines changed

src/controller/nonlinmpc.jl

+23-39
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,7 @@ function init_optimization!(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where
304304
Jfunc, gfunc = let mpc=mpc, model=model, ng=ng, nΔŨ=nΔŨ, nŶ=Hp*ny, nx̂=nx̂, nu=nu, nU=Hp*nu
305305
Nc = nΔŨ + 3
306306
last_ΔŨtup_float, last_ΔŨtup_dual = nothing, nothing
307+
ΔŨ_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nΔŨ), Nc)
307308
Ŷ_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nŶ), Nc)
308309
U_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nU), Nc)
309310
g_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, ng), Nc)
@@ -313,63 +314,46 @@ function init_optimization!(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where
313314
û_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nu), Nc)
314315
Ȳ_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nŶ), Nc)
315316
Ū_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nU), Nc)
316-
function Jfunc(ΔŨtup::JNT...)
317+
function Jfunc(ΔŨtup::T...)::T where T <: Real
317318
ΔŨ1 = ΔŨtup[begin]
318-
= get_tmp(Ŷ_cache, ΔŨ1)
319-
ΔŨ = collect(ΔŨtup)
320-
if ΔŨtup !== last_ΔŨtup_float
321-
x̂, x̂next = get_tmp(x̂_cache, ΔŨ1), get_tmp(x̂next_cache, ΔŨ1)
322-
u, û = get_tmp(u_cache, ΔŨ1), get_tmp(û_cache, ΔŨ1)
323-
Ŷ, x̂end = predict!(Ŷ, x̂, x̂next, u, û, mpc, model, ΔŨ)
324-
g = get_tmp(g_cache, ΔŨ1)
325-
g = con_nonlinprog!(g, mpc, model, x̂end, Ŷ, ΔŨ)
326-
last_ΔŨtup_float = ΔŨtup
319+
ΔŨ, Ŷ = get_tmp(ΔŨ_cache, ΔŨ1), get_tmp(Ŷ_cache, ΔŨ1)
320+
if T == JNT
321+
isnewvalue = (ΔŨtup !== last_ΔŨtup_float)
322+
isnewvalue && (last_ΔŨtup_float = ΔŨtup)
323+
else
324+
isnewvalue = (ΔŨtup !== last_ΔŨtup_dual)
325+
isnewvalue && (last_ΔŨtup_dual = ΔŨtup)
327326
end
328-
U, Ȳ, Ū = get_tmp(U_cache, ΔŨ1), get_tmp(Ȳ_cache, ΔŨ1), get_tmp(Ū_cache, ΔŨ1)
329-
return obj_nonlinprog!(U, Ȳ, Ū, mpc, model, Ŷ, ΔŨ)
330-
end
331-
function Jfunc(ΔŨtup::ForwardDiff.Dual...)
332-
ΔŨ1 = ΔŨtup[begin]
333-
= get_tmp(Ŷ_cache, ΔŨ1)
334-
ΔŨ = collect(ΔŨtup)
335-
if ΔŨtup !== last_ΔŨtup_dual
327+
if isnewvalue
336328
x̂, x̂next = get_tmp(x̂_cache, ΔŨ1), get_tmp(x̂next_cache, ΔŨ1)
337329
u, û = get_tmp(u_cache, ΔŨ1), get_tmp(û_cache, ΔŨ1)
330+
ΔŨ .= ΔŨtup
338331
Ŷ, x̂end = predict!(Ŷ, x̂, x̂next, u, û, mpc, model, ΔŨ)
339332
g = get_tmp(g_cache, ΔŨ1)
340333
g = con_nonlinprog!(g, mpc, model, x̂end, Ŷ, ΔŨ)
341-
last_ΔŨtup_dual = ΔŨtup
342334
end
343335
U, Ȳ, Ū = get_tmp(U_cache, ΔŨ1), get_tmp(Ȳ_cache, ΔŨ1), get_tmp(Ū_cache, ΔŨ1)
344-
return obj_nonlinprog!(U, Ȳ, Ū, mpc, model, Ŷ, ΔŨ)
336+
return obj_nonlinprog!(U, Ȳ, Ū, mpc, model, Ŷ, ΔŨ)::T
345337
end
346-
function gfunc_i(i, ΔŨtup::NTuple{N, JNT}) where N
338+
function gfunc_i(i, ΔŨtup::NTuple{N, T})::T where {N, T <:Real}
347339
ΔŨ1 = ΔŨtup[begin]
348340
g = get_tmp(g_cache, ΔŨ1)
349-
if ΔŨtup !== last_ΔŨtup_float
350-
= get_tmp(Ŷ_cache, ΔŨ1)
351-
ΔŨ = collect(ΔŨtup)
352-
x̂, x̂next = get_tmp(x̂_cache, ΔŨ1), get_tmp(x̂next_cache, ΔŨ1)
353-
u, û = get_tmp(u_cache, ΔŨ1), get_tmp(û_cache, ΔŨ1)
354-
Ŷ, x̂end = predict!(Ŷ, x̂, x̂next, u, û, mpc, model, ΔŨ)
355-
g = con_nonlinprog!(g, mpc, model, x̂end, Ŷ, ΔŨ)
356-
last_ΔŨtup_float = ΔŨtup
341+
if T == JNT
342+
isnewvalue = (ΔŨtup !== last_ΔŨtup_float)
343+
isnewvalue && (last_ΔŨtup_float = ΔŨtup)
344+
else
345+
isnewvalue = (ΔŨtup !== last_ΔŨtup_dual)
346+
isnewvalue && (last_ΔŨtup_dual = ΔŨtup)
357347
end
358-
return g[i]
359-
end
360-
function gfunc_i(i, ΔŨtup::NTuple{N, ForwardDiff.Dual}) where N
361-
ΔŨ1 = ΔŨtup[begin]
362-
g = get_tmp(g_cache, ΔŨ1)
363-
if ΔŨtup !== last_ΔŨtup_dual
364-
= get_tmp(Ŷ_cache, ΔŨ1)
365-
ΔŨ = collect(ΔŨtup)
348+
if isnewvalue
349+
ΔŨ, Ŷ = get_tmp(ΔŨ_cache, ΔŨ1), get_tmp(Ŷ_cache, ΔŨ1)
366350
x̂, x̂next = get_tmp(x̂_cache, ΔŨ1), get_tmp(x̂next_cache, ΔŨ1)
367351
u, û = get_tmp(u_cache, ΔŨ1), get_tmp(û_cache, ΔŨ1)
352+
ΔŨ .= ΔŨtup
368353
Ŷ, x̂end = predict!(Ŷ, x̂, x̂next, u, û, mpc, model, ΔŨ)
369354
g = con_nonlinprog!(g, mpc, model, x̂end, Ŷ, ΔŨ)
370-
last_ΔŨtup_dual = ΔŨtup
371355
end
372-
return g[i]
356+
return g[i]::T
373357
end
374358
gfunc = [(ΔŨ...) -> gfunc_i(i, ΔŨ) for i in 1:ng]
375359
(Jfunc, gfunc)

0 commit comments

Comments
 (0)