Skip to content

Commit 0e2b855

Browse files
committedMar 18, 2025·
added: mpc and estim object as DI.Constant
1 parent 6f72934 commit 0e2b855

File tree

2 files changed

+54
-51
lines changed

2 files changed

+54
-51
lines changed
 

‎src/controller/nonlinmpc.jl

+31-30
Original file line numberDiff line numberDiff line change
@@ -589,11 +589,9 @@ function get_optim_functions(
589589
grad_backend::AbstractADType,
590590
jac_backend ::AbstractADType
591591
) where JNT<:Real
592-
model, transcription = mpc.estim.model, mpc.transcription
593-
#TODO: fix type of all cache to ::Vector{JNT} (verify performance difference with and w/o)
594-
#TODO: mêmes choses pour le MHE
595-
# --------------------- update simulation function ------------------------------------
596-
function update_simulations!(Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
592+
# ------ update simulation function (all args after `mpc` are mutated) ----------------
593+
function update_simulations!(Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
594+
model, transcription = mpc.estim.model, mpc.transcription
597595
U0 = getU0!(U0, mpc, Z̃)
598596
ΔŨ = getΔŨ!(ΔŨ, mpc, transcription, Z̃)
599597
Ŷ0, x̂0end = predict!(Ŷ0, x̂0end, X̂0, Û0, mpc, model, transcription, U0, Z̃)
@@ -605,36 +603,38 @@ function get_optim_functions(
605603
return nothing
606604
end
607605
# ----- common cache for Jfunc, gfuncs, geqfuncs called with floats -------------------
606+
model = mpc.estim.model
608607
nu, ny, nx̂, nϵ, Hp, Hc = model.nu, model.ny, mpc.estim.nx̂, mpc.nϵ, mpc.Hp, mpc.Hc
609608
ng, nc, neq = length(mpc.con.i_g), mpc.con.nc, mpc.con.neq
610609
nZ̃, nU, nŶ, nX̂ = length(mpc.Z̃), Hp*nu, Hp*ny, Hp*nx̂
611610
nΔŨ, nUe, nŶe = nu*Hc + nϵ, nU + nu, nŶ + ny
612-
myNaN = convert(JNT, NaN)
613-
= fill(myNaN, nZ̃) # NaN to force update_simulations! at first call
614-
ΔŨ = zeros(JNT, nΔŨ)
615-
x̂0end = zeros(JNT, nx̂)
616-
Ue, Ŷe = zeros(JNT, nUe), zeros(JNT, nŶe)
617-
U0, Ŷ0 = zeros(JNT, nU), zeros(JNT, nŶ)
618-
Û0, X̂0 = zeros(JNT, nU), zeros(JNT, nX̂)
619-
gc, g = zeros(JNT, nc), zeros(JNT, ng)
620-
geq = zeros(JNT, neq)
621-
# ---------------------- objective function ------------------------------------------
611+
myNaN = convert(JNT, NaN) # NaN to force update_simulations! at first call:
612+
::Vector{JNT} = fill(myNaN, nZ̃)
613+
ΔŨ::Vector{JNT} = zeros(JNT, nΔŨ)
614+
x̂0end::Vector{JNT} = zeros(JNT, nx̂)
615+
Ue::Vector{JNT}, Ŷe::Vector{JNT} = zeros(JNT, nUe), zeros(JNT, nŶe)
616+
U0::Vector{JNT}, Ŷ0::Vector{JNT} = zeros(JNT, nU), zeros(JNT, nŶ)
617+
Û0::Vector{JNT}, X̂0::Vector{JNT} = zeros(JNT, nU), zeros(JNT, nX̂)
618+
gc::Vector{JNT}, g::Vector{JNT} = zeros(JNT, nc), zeros(JNT, ng)
619+
geq::Vector{JNT} = zeros(JNT, neq)
620+
# ---------------------- objective function -------------------------------------------
622621
function Jfunc(Z̃arg::Vararg{T, N}) where {N, T<:Real}
623622
if isdifferent(Z̃arg, Z̃)
624623
Z̃ .= Z̃arg
625-
update_simulations!(Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
624+
update_simulations!(Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
626625
end
627626
return obj_nonlinprog!(Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ)::T
628627
end
629-
function Jfunc!(Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
630-
update_simulations!(Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
628+
function Jfunc!(Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
629+
update_simulations!(Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
631630
return obj_nonlinprog!(Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ)
632631
end
633632
Z̃_∇J = fill(myNaN, nZ̃)
634633
∇J_context = (
634+
Constant(mpc),
635635
Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
636636
Cache(Û0), Cache(X̂0),
637-
Cache(gc), Cache(g), Cache(geq)
637+
Cache(gc), Cache(g), Cache(geq),
638638
)
639639
∇J_prep = prepare_gradient(Jfunc!, grad_backend, Z̃_∇J, ∇J_context...)
640640
∇J = Vector{JNT}(undef, nZ̃)
@@ -657,26 +657,26 @@ function get_optim_functions(
657657
gfunc_i = function (Z̃arg::Vararg{T, N}) where {N, T<:Real}
658658
if isdifferent(Z̃arg, Z̃)
659659
Z̃ .= Z̃arg
660-
update_simulations!(Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
660+
update_simulations!(Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
661661
end
662662
return g[i]::T
663663
end
664664
gfuncs[i] = gfunc_i
665665
end
666-
function gfunc!(g, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, geq)
667-
return update_simulations!(Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
666+
function gfunc!(g, Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, geq)
667+
return update_simulations!(Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
668668
end
669669
Z̃_∇g = fill(myNaN, nZ̃)
670670
∇g_context = (
671+
Constant(mpc),
671672
Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
672673
Cache(Û0), Cache(X̂0),
673-
Cache(gc), Cache(geq)
674+
Cache(gc), Cache(geq),
674675
)
675-
# temporarily enable all the inequality constraints for sparsity pattern detection:
676-
i_g_old = copy(mpc.con.i_g)
677-
mpc.con.i_g .= true
676+
# temporarily enable all the inequality constraints for sparsity detection:
677+
mpc.con.i_g[1:end-nc] .= true
678678
∇g_prep = prepare_jacobian(gfunc!, g, jac_backend, Z̃_∇g, ∇g_context...)
679-
mpc.con.i_g .= i_g_old
679+
mpc.con.i_g[1:end-nc] .= false
680680
∇g = init_diffmat(JNT, jac_backend, ∇g_prep, nZ̃, ng)
681681
∇gfuncs! = Vector{Function}(undef, ng)
682682
for i in eachindex(∇gfuncs!)
@@ -705,17 +705,18 @@ function get_optim_functions(
705705
geqfunc_i = function (Z̃arg::Vararg{T, N}) where {N, T<:Real}
706706
if isdifferent(Z̃arg, Z̃)
707707
Z̃ .= Z̃arg
708-
update_simulations!(Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
708+
update_simulations!(Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
709709
end
710710
return geq[i]::T
711711
end
712712
geqfuncs[i] = geqfunc_i
713713
end
714-
function geqfunc!(geq, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g)
715-
return update_simulations!(Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
714+
function geqfunc!(geq, Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g)
715+
return update_simulations!(Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
716716
end
717717
Z̃_∇geq = fill(myNaN, nZ̃)
718718
∇geq_context = (
719+
Constant(mpc),
719720
Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
720721
Cache(Û0), Cache(X̂0),
721722
Cache(gc), Cache(g)

‎src/estimator/mhe/construct.jl

+23-21
Original file line numberDiff line numberDiff line change
@@ -1332,37 +1332,39 @@ function get_optim_functions(
13321332
grad_backend::AbstractADType,
13331333
jac_backend::AbstractADType
13341334
) where {JNT <: Real}
1335-
model, con = estim.model, estim.con
1336-
# --------------------- update simulation function ------------------------------------
1337-
function update_simulations!(Z̃, V̂, X̂0, û0, ŷ0, g)
1335+
# -------- update simulation function (all args after `estim` are mutated) ------------
1336+
function update_simulations!(Z̃, estim, V̂, X̂0, û0, ŷ0, g)
1337+
model = estim.model
13381338
V̂, X̂0 = predict!(V̂, X̂0, û0, ŷ0, estim, model, Z̃)
13391339
ϵ = getϵ(estim, Z̃)
13401340
g = con_nonlinprog!(g, estim, model, X̂0, V̂, ϵ)
13411341
return nothing
13421342
end
13431343
# ---------- common cache for Jfunc, gfuncs called with floats ------------------------
1344+
model, con = estim.model, estim.con
13441345
nx̂, nym, nŷ, nu, nϵ, He = estim.nx̂, estim.nym, model.ny, model.nu, estim.nϵ, estim.He
13451346
nV̂, nX̂, ng, nZ̃ = He*nym, He*nx̂, length(con.i_g), length(estim.Z̃)
1346-
myNaN = convert(JNT, NaN)
1347-
= fill(myNaN, nZ̃) # NaN to force update_simulations! at first call
1348-
, X̂0 = zeros(JNT, nV̂), zeros(JNT, nX̂)
1349-
û0, ŷ0 = zeros(JNT, nu), zeros(JNT, nŷ)
1350-
g = zeros(JNT, ng)
1351-
= zeros(JNT, nx̂)
1347+
myNaN = convert(JNT, NaN) # NaN to force update_simulations! at first call
1348+
::Vector{JNT} = fill(myNaN, nZ̃)
1349+
::Vector{JNT}, X̂0::Vector{JNT} = zeros(JNT, nV̂), zeros(JNT, nX̂)
1350+
û0::Vector{JNT}, ŷ0::Vector{JNT} = zeros(JNT, nu), zeros(JNT, nŷ)
1351+
g::Vector{JNT} = zeros(JNT, ng)
1352+
::Vector{JNT} = zeros(JNT, nx̂)
13521353
# --------------------- objective functions -------------------------------------------
13531354
function Jfunc(Z̃arg::Vararg{T, N}) where {N, T<:Real}
13541355
if isdifferent(Z̃arg, Z̃)
13551356
Z̃ .= Z̃arg
1356-
update_simulations!(Z̃, V̂, X̂0, û0, ŷ0, g)
1357+
update_simulations!(Z̃, estim, V̂, X̂0, û0, ŷ0, g)
13571358
end
13581359
return obj_nonlinprog!(x̄, estim, model, V̂, Z̃)::T
13591360
end
1360-
function Jfunc!(Z̃, V̂, X̂0, û0, ŷ0, g, x̄)
1361-
update_simulations!(Z̃, V̂, X̂0, û0, ŷ0, g)
1361+
function Jfunc!(Z̃, estim, V̂, X̂0, û0, ŷ0, g, x̄)
1362+
update_simulations!(Z̃, estim, V̂, X̂0, û0, ŷ0, g)
13621363
return obj_nonlinprog!(x̄, estim, model, V̂, Z̃)
13631364
end
13641365
Z̃_∇J = fill(myNaN, nZ̃)
13651366
∇J_context = (
1367+
Constant(estim),
13661368
Cache(V̂), Cache(X̂0),
13671369
Cache(û0), Cache(ŷ0),
13681370
Cache(g),
@@ -1389,27 +1391,27 @@ function get_optim_functions(
13891391
gfunc_i = function (Z̃arg::Vararg{T, N}) where {N, T<:Real}
13901392
if isdifferent(Z̃arg, Z̃)
13911393
Z̃ .= Z̃arg
1392-
update_simulations!(Z̃, V̂, X̂0, û0, ŷ0, g)
1394+
update_simulations!(Z̃, estim, V̂, X̂0, û0, ŷ0, g)
13931395
end
13941396
return g[i]::T
13951397
end
13961398
gfuncs[i] = gfunc_i
13971399
end
1398-
function gfunc!(g, Z̃, V̂, X̂0, û0, ŷ0)
1399-
return update_simulations!(Z̃, V̂, X̂0, û0, ŷ0, g)
1400+
function gfunc!(g, Z̃, estim, V̂, X̂0, û0, ŷ0)
1401+
return update_simulations!(Z̃, estim, V̂, X̂0, û0, ŷ0, g)
14001402
end
14011403
Z̃_∇g = fill(myNaN, nZ̃)
14021404
∇g_context = (
1405+
Constant(estim),
14031406
Cache(V̂), Cache(X̂0),
14041407
Cache(û0), Cache(ŷ0),
14051408
)
1406-
# temporarily enable all the inequality constraints for sparsity pattern detection:
1407-
i_g_old = copy(estim.con.i_g)
1408-
estim.con.i_g .= true
1409-
estim.Nk .= estim.He
1409+
# temporarily enable all the inequality constraints for sparsity detection:
1410+
estim.con.i_g .= true
1411+
estim.Nk[] = He
14101412
∇g_prep = prepare_jacobian(gfunc!, g, jac_backend, Z̃_∇g, ∇g_context...)
1411-
estim.con.i_g .= i_g_old
1412-
estim.Nk .= 0
1413+
estim.con.i_g .= false
1414+
estim.Nk[] = 0
14131415
∇g = init_diffmat(JNT, jac_backend, ∇g_prep, nZ̃, ng)
14141416
∇gfuncs! = Vector{Function}(undef, ng)
14151417
for i in eachindex(∇gfuncs!)

0 commit comments

Comments
 (0)