@@ -589,11 +589,9 @@ function get_optim_functions(
589
589
grad_backend:: AbstractADType ,
590
590
jac_backend :: AbstractADType
591
591
) where JNT<: Real
592
- model, transcription = mpc. estim. model, mpc. transcription
593
- # TODO : fix type of all cache to ::Vector{JNT} (verify performance difference with and w/o)
594
- # TODO : mêmes choses pour le MHE
595
- # --------------------- update simulation function ------------------------------------
596
- function update_simulations! (Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
592
+ # ------ update simulation function (all args after `mpc` are mutated) ----------------
593
+ function update_simulations! (Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
594
+ model, transcription = mpc. estim. model, mpc. transcription
597
595
U0 = getU0! (U0, mpc, Z̃)
598
596
ΔŨ = getΔŨ! (ΔŨ, mpc, transcription, Z̃)
599
597
Ŷ0, x̂0end = predict! (Ŷ0, x̂0end, X̂0, Û0, mpc, model, transcription, U0, Z̃)
@@ -605,36 +603,38 @@ function get_optim_functions(
605
603
return nothing
606
604
end
607
605
# ----- common cache for Jfunc, gfuncs, geqfuncs called with floats -------------------
606
+ model = mpc. estim. model
608
607
nu, ny, nx̂, nϵ, Hp, Hc = model. nu, model. ny, mpc. estim. nx̂, mpc. nϵ, mpc. Hp, mpc. Hc
609
608
ng, nc, neq = length (mpc. con. i_g), mpc. con. nc, mpc. con. neq
610
609
nZ̃, nU, nŶ, nX̂ = length (mpc. Z̃), Hp* nu, Hp* ny, Hp* nx̂
611
610
nΔŨ, nUe, nŶe = nu* Hc + nϵ, nU + nu, nŶ + ny
612
- myNaN = convert (JNT, NaN )
613
- Z̃ = fill (myNaN, nZ̃) # NaN to force update_simulations! at first call
614
- ΔŨ = zeros (JNT, nΔŨ)
615
- x̂0end = zeros (JNT, nx̂)
616
- Ue, Ŷe = zeros (JNT, nUe), zeros (JNT, nŶe)
617
- U0, Ŷ0 = zeros (JNT, nU), zeros (JNT, nŶ)
618
- Û0, X̂0 = zeros (JNT, nU), zeros (JNT, nX̂)
619
- gc, g = zeros (JNT, nc), zeros (JNT, ng)
620
- geq = zeros (JNT, neq)
621
- # ---------------------- objective function ------------------------------------------
611
+ myNaN = convert (JNT, NaN ) # NaN to force update_simulations! at first call:
612
+ Z̃ :: Vector{JNT} = fill (myNaN, nZ̃)
613
+ ΔŨ:: Vector{JNT} = zeros (JNT, nΔŨ)
614
+ x̂0end:: Vector{JNT} = zeros (JNT, nx̂)
615
+ Ue:: Vector{JNT} , Ŷe:: Vector{JNT} = zeros (JNT, nUe), zeros (JNT, nŶe)
616
+ U0:: Vector{JNT} , Ŷ0:: Vector{JNT} = zeros (JNT, nU), zeros (JNT, nŶ)
617
+ Û0:: Vector{JNT} , X̂0:: Vector{JNT} = zeros (JNT, nU), zeros (JNT, nX̂)
618
+ gc:: Vector{JNT} , g:: Vector{JNT} = zeros (JNT, nc), zeros (JNT, ng)
619
+ geq:: Vector{JNT} = zeros (JNT, neq)
620
+ # ---------------------- objective function -------------------------------------------
622
621
function Jfunc (Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
623
622
if isdifferent (Z̃arg, Z̃)
624
623
Z̃ .= Z̃arg
625
- update_simulations! (Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
624
+ update_simulations! (Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
626
625
end
627
626
return obj_nonlinprog! (Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ):: T
628
627
end
629
- function Jfunc! (Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
630
- update_simulations! (Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
628
+ function Jfunc! (Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
629
+ update_simulations! (Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
631
630
return obj_nonlinprog! (Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ)
632
631
end
633
632
Z̃_∇J = fill (myNaN, nZ̃)
634
633
∇J_context = (
634
+ Constant (mpc),
635
635
Cache (ΔŨ), Cache (x̂0end), Cache (Ue), Cache (Ŷe), Cache (U0), Cache (Ŷ0),
636
636
Cache (Û0), Cache (X̂0),
637
- Cache (gc), Cache (g), Cache (geq)
637
+ Cache (gc), Cache (g), Cache (geq),
638
638
)
639
639
∇J_prep = prepare_gradient (Jfunc!, grad_backend, Z̃_∇J, ∇J_context... )
640
640
∇J = Vector {JNT} (undef, nZ̃)
@@ -657,26 +657,26 @@ function get_optim_functions(
657
657
gfunc_i = function (Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
658
658
if isdifferent (Z̃arg, Z̃)
659
659
Z̃ .= Z̃arg
660
- update_simulations! (Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
660
+ update_simulations! (Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
661
661
end
662
662
return g[i]:: T
663
663
end
664
664
gfuncs[i] = gfunc_i
665
665
end
666
- function gfunc! (g, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, geq)
667
- return update_simulations! (Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
666
+ function gfunc! (g, Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, geq)
667
+ return update_simulations! (Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
668
668
end
669
669
Z̃_∇g = fill (myNaN, nZ̃)
670
670
∇g_context = (
671
+ Constant (mpc),
671
672
Cache (ΔŨ), Cache (x̂0end), Cache (Ue), Cache (Ŷe), Cache (U0), Cache (Ŷ0),
672
673
Cache (Û0), Cache (X̂0),
673
- Cache (gc), Cache (geq)
674
+ Cache (gc), Cache (geq),
674
675
)
675
- # temporarily enable all the inequality constraints for sparsity pattern detection:
676
- i_g_old = copy (mpc. con. i_g)
677
- mpc. con. i_g .= true
676
+ # temporarily enable all the inequality constraints for sparsity detection:
677
+ mpc. con. i_g[1 : end - nc] .= true
678
678
∇g_prep = prepare_jacobian (gfunc!, g, jac_backend, Z̃_∇g, ∇g_context... )
679
- mpc. con. i_g .= i_g_old
679
+ mpc. con. i_g[ 1 : end - nc] .= false
680
680
∇g = init_diffmat (JNT, jac_backend, ∇g_prep, nZ̃, ng)
681
681
∇gfuncs! = Vector {Function} (undef, ng)
682
682
for i in eachindex (∇gfuncs!)
@@ -705,17 +705,18 @@ function get_optim_functions(
705
705
geqfunc_i = function (Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
706
706
if isdifferent (Z̃arg, Z̃)
707
707
Z̃ .= Z̃arg
708
- update_simulations! (Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
708
+ update_simulations! (Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
709
709
end
710
710
return geq[i]:: T
711
711
end
712
712
geqfuncs[i] = geqfunc_i
713
713
end
714
- function geqfunc! (geq, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g)
715
- return update_simulations! (Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
714
+ function geqfunc! (geq, Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g)
715
+ return update_simulations! (Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
716
716
end
717
717
Z̃_∇geq = fill (myNaN, nZ̃)
718
718
∇geq_context = (
719
+ Constant (mpc),
719
720
Cache (ΔŨ), Cache (x̂0end), Cache (Ue), Cache (Ŷe), Cache (U0), Cache (Ŷ0),
720
721
Cache (Û0), Cache (X̂0),
721
722
Cache (gc), Cache (g)
0 commit comments