@@ -1008,72 +1008,52 @@ function init_optimization!(
1008
1008
Jfunc, gfunc = let estim= estim, model= model, nZ̃= nZ̃, nV̂= nV̂, nX̂= nX̂, ng= ng, nx̂= nx̂, nu= nu, nŷ= nŷ
1009
1009
Nc = nZ̃ + 3
1010
1010
last_Z̃tup_float, last_Z̃tup_dual = nothing , nothing
1011
+ Z̃_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nZ̃), Nc)
1011
1012
V̂_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nV̂), Nc)
1012
1013
g_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, ng), Nc)
1013
1014
X̂_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nX̂), Nc)
1014
1015
x̄_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nx̂), Nc)
1015
1016
û_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nu), Nc)
1016
1017
ŷ_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nŷ), Nc)
1017
- function Jfunc (Z̃tup:: JNT ... )
1018
+ function Jfunc (Z̃tup:: T ... ):: T where {T <: Real }
1018
1019
Z̃1 = Z̃tup[begin ]
1019
- V̂ = get_tmp (V̂_cache, Z̃1)
1020
- Z̃ = collect (Z̃tup)
1021
- if Z̃tup != = last_Z̃tup_float
1022
- g = get_tmp (g_cache, Z̃1)
1023
- X̂ = get_tmp (X̂_cache, Z̃1)
1024
- û, ŷ = get_tmp (û_cache, Z̃1), get_tmp (ŷ_cache, Z̃1)
1025
- V̂, X̂ = predict! (V̂, X̂, û, ŷ, estim, model, Z̃)
1026
- g = con_nonlinprog! (g, estim, model, X̂, V̂, Z̃)
1020
+ if T == JNT
1027
1021
last_Z̃tup_float = Z̃tup
1028
- end
1029
- x̄ = get_tmp (x̄_cache, Z̃1)
1030
- return obj_nonlinprog! (x̄, estim, model, V̂, Z̃)
1031
- end
1032
- function Jfunc (Z̃tup:: ForwardDiff.Dual... )
1033
- Z̃1 = Z̃tup[begin ]
1034
- V̂ = get_tmp (V̂_cache, Z̃1)
1035
- Z̃ = collect (Z̃tup)
1036
- if Z̃tup != = last_Z̃tup_dual
1037
- g = get_tmp (g_cache, Z̃1)
1038
- X̂ = get_tmp (X̂_cache, Z̃1)
1039
- û, ŷ = get_tmp (û_cache, Z̃1), get_tmp (ŷ_cache, Z̃1)
1040
- V̂, X̂ = predict! (V̂, X̂, û, ŷ, estim, model, Z̃)
1041
- g = con_nonlinprog! (g, estim, model, X̂, V̂, Z̃)
1022
+ else
1042
1023
last_Z̃tup_dual = Z̃tup
1043
1024
end
1025
+ Z̃, V̂ = get_tmp (Z̃_cache, Z̃1), get_tmp (V̂_cache, Z̃1)
1026
+ X̂ = get_tmp (X̂_cache, Z̃1)
1027
+ û, ŷ = get_tmp (û_cache, Z̃1), get_tmp (ŷ_cache, Z̃1)
1028
+ Z̃ .= Z̃tup
1029
+ V̂, X̂ = predict! (V̂, X̂, û, ŷ, estim, model, Z̃)
1030
+ g = get_tmp (g_cache, Z̃1)
1031
+ g = con_nonlinprog! (g, estim, model, X̂, V̂, Z̃)
1044
1032
x̄ = get_tmp (x̄_cache, Z̃1)
1045
- return obj_nonlinprog! (x̄, estim, model, V̂, Z̃)
1033
+ return obj_nonlinprog! (x̄, estim, model, V̂, Z̃):: T
1046
1034
end
1047
- function gfunc_i (i, Z̃tup:: NTuple{N, JNT} ) where N
1035
+ function gfunc_i (i, Z̃tup:: NTuple{N, T} ) :: T where {N, T <: Real }
1048
1036
Z̃1 = Z̃tup[begin ]
1049
1037
g = get_tmp (g_cache, Z̃1)
1050
- if Z̃tup != = last_Z̃tup_float
1051
- Z̃ = collect (Z̃tup)
1052
- V̂ = get_tmp (V̂_cache, Z̃1)
1053
- X̂ = get_tmp (X̂_cache, Z̃1)
1054
- û, ŷ = get_tmp (û_cache, Z̃1), get_tmp (ŷ_cache, Z̃1)
1055
- V̂, X̂ = predict! (V̂, X̂, û, ŷ, estim, model, Z̃)
1056
- g = con_nonlinprog! (g, estim, model, X̂, V̂, Z̃)
1057
- last_Z̃tup_float = Z̃tup
1038
+ if T == JNT
1039
+ isnewvalue = (Z̃tup != = last_Z̃tup_float)
1040
+ isnewvalue && (last_Z̃tup_float = Z̃tup)
1041
+ else
1042
+ isnewvalue = (Z̃tup != = last_Z̃tup_dual)
1043
+ isnewvalue && (last_Z̃tup_dual = Z̃tup)
1058
1044
end
1059
- return g[i]
1060
- end
1061
- function gfunc_i (i, Z̃tup:: NTuple{N, ForwardDiff.Dual} ) where N
1062
- Z̃1 = Z̃tup[begin ]
1063
- g = get_tmp (g_cache, Z̃1)
1064
- if Z̃tup != = last_Z̃tup_dual
1065
- Z̃ = collect (Z̃tup)
1066
- V̂ = get_tmp (V̂_cache, Z̃1)
1045
+ if isnewvalue
1046
+ Z̃, V̂ = get_tmp (Z̃_cache, Z̃1), get_tmp (V̂_cache, Z̃1)
1067
1047
X̂ = get_tmp (X̂_cache, Z̃1)
1068
1048
û, ŷ = get_tmp (û_cache, Z̃1), get_tmp (ŷ_cache, Z̃1)
1049
+ Z̃ .= Z̃tup
1069
1050
V̂, X̂ = predict! (V̂, X̂, û, ŷ, estim, model, Z̃)
1070
1051
g = con_nonlinprog! (g, estim, model, X̂, V̂, Z̃)
1071
- last_Z̃tup_dual = Z̃tup
1072
1052
end
1073
1053
return g[i]
1074
1054
end
1075
1055
gfunc = [(Z̃... ) -> gfunc_i (i, Z̃) for i in 1 : ng]
1076
- Jfunc, gfunc
1056
+ ( Jfunc, gfunc)
1077
1057
end
1078
1058
register (optim, :Jfunc , nZ̃, Jfunc, autodiff= true )
1079
1059
@NLobjective (optim, Min, Jfunc (Z̃var... ))
0 commit comments