Skip to content

Commit 0493147

Browse files
authored
Merge pull request #44 from JuliaControl/reduce_alloc_fhat
Reduce allocation for estimators based on augmented `NonLinModel`
2 parents 3ca6aea + 2739ce8 commit 0493147

File tree

8 files changed

+87
-86
lines changed

8 files changed

+87
-86
lines changed

src/controller/execute.jl

+16-23
Original file line numberDiff line numberDiff line change
@@ -103,24 +103,25 @@ julia> round.(getinfo(mpc)[:Ŷ], digits=3)
103103
```
104104
"""
105105
function getinfo(mpc::PredictiveController{NT}) where NT<:Real
106+
model = mpc.estim.model
106107
info = Dict{Symbol, Union{JuMP._SolutionSummary, Vector{NT}, NT}}()
107-
Ŷ, u = similar(mpc.Ŷop), similar(mpc.estim.lastu0)
108+
Ŷ, u, û = similar(mpc.Ŷop), similar(model.uop), similar(model.uop)
108109
x̂, x̂next = similar(mpc.estim.x̂), similar(mpc.estim.x̂)
109-
Ŷ, x̂end = predict!(Ŷ, x̂, x̂next, u, mpc, mpc.estim.model, mpc.ΔŨ)
110+
Ŷ, x̂end = predict!(Ŷ, x̂, x̂next, u, , mpc, model, mpc.ΔŨ)
110111
U = mpc.*mpc.ΔŨ + mpc.T_lastu
111112
Ȳ, Ū = similar(Ŷ), similar(U)
112-
J = obj_nonlinprog!(U, Ȳ, Ū, mpc, mpc.estim.model, Ŷ, mpc.ΔŨ)
113-
info[:ΔU] = mpc.ΔŨ[1:mpc.Hc*mpc.estim.model.nu]
113+
J = obj_nonlinprog!(U, Ȳ, Ū, mpc, model, Ŷ, mpc.ΔŨ)
114+
info[:ΔU] = mpc.ΔŨ[1:mpc.Hc*model.nu]
114115
info[] = isinf(mpc.C) ? NaN : mpc.ΔŨ[end]
115116
info[:J] = J
116117
info[:U] = U
117-
info[:u] = info[:U][1:mpc.estim.model.nu]
118-
info[:d] = mpc.d0 + mpc.estim.model.dop
118+
info[:u] = info[:U][1:model.nu]
119+
info[:d] = mpc.d0 + model.dop
119120
info[:D̂] = mpc.D̂0 + mpc.Dop
120121
info[:ŷ] = mpc.
121122
info[:Ŷ] =
122123
info[:x̂end] =end
123-
info[:Ŷs] = mpc.Ŷop - repeat(mpc.estim.model.yop, mpc.Hp) # Ŷop = Ŷs + Yop
124+
info[:Ŷs] = mpc.Ŷop - repeat(model.yop, mpc.Hp) # Ŷop = Ŷs + Yop
124125
info[:R̂y] = mpc.R̂y
125126
info[:R̂u] = mpc.R̂u
126127
info = addinfo!(info, mpc)
@@ -296,16 +297,14 @@ function linconstraint!(mpc::PredictiveController, ::SimModel)
296297
end
297298

298299
@doc raw"""
299-
predict!(Ŷ, x̂, _ , _ , mpc::PredictiveController, model::LinModel, ΔŨ) -> Ŷ, x̂end
300+
predict!(Ŷ, x̂, _ , _ , _ , mpc::PredictiveController, model::LinModel, ΔŨ) -> Ŷ, x̂end
300301
301302
Compute the predictions `Ŷ` and terminal states `x̂end` if model is a [`LinModel`](@ref).
302303
303304
The method mutates `Ŷ` and `x̂` vector arguments. The `x̂end` vector is used for
304305
the terminal constraints applied on ``\mathbf{x̂}_{k-1}(k+H_p)``.
305306
"""
306-
function predict!(
307-
Ŷ, x̂, _ , _ , mpc::PredictiveController, ::LinModel, ΔŨ::Vector{NT}
308-
) where {NT<:Real}
307+
function predict!(Ŷ, x̂, _ , _ , _ , mpc::PredictiveController, ::LinModel, ΔŨ)
309308
# in-place operations to reduce allocations :
310309
Ŷ .= mul!(Ŷ, mpc.Ẽ, ΔŨ) .+ mpc.F
311310
x̂ .= mul!(x̂, mpc.con.ẽx̂, ΔŨ) .+ mpc.con.fx̂
@@ -314,15 +313,13 @@ function predict!(
314313
end
315314

316315
@doc raw"""
317-
predict!(Ŷ, x̂, x̂next, u, mpc::PredictiveController, model::SimModel, ΔŨ) -> Ŷ, x̂end
316+
predict!(Ŷ, x̂, x̂next, u, û, mpc::PredictiveController, model::SimModel, ΔŨ) -> Ŷ, x̂end
318317
319318
Compute both vectors if `model` is not a [`LinModel`](@ref).
320319
321-
The method mutates `Ŷ`, `x̂`, `x̂next` and `u` arguments.
320+
The method mutates `Ŷ`, `x̂`, `x̂next`, `u` and `` arguments.
322321
"""
323-
function predict!(
324-
Ŷ, x̂, x̂next, u, mpc::PredictiveController, model::SimModel, ΔŨ::Vector{NT}
325-
) where {NT<:Real}
322+
function predict!(Ŷ, x̂, x̂next, u, û, mpc::PredictiveController, model::SimModel, ΔŨ)
326323
nu, ny, nd, Hp, Hc = model.nu, model.ny, model.nd, mpc.Hp, mpc.Hc
327324
u0 = u
328325
x̂ .= mpc.estim.
@@ -332,7 +329,7 @@ function predict!(
332329
if j Hc
333330
u0 .+= @views ΔŨ[(1 + nu*(j-1)):(nu*j)]
334331
end
335-
f̂!(x̂next, mpc.estim, model, x̂, u0, d0)
332+
f̂!(x̂next, û, mpc.estim, model, x̂, u0, d0)
336333
x̂ .= x̂next
337334
d0 = @views mpc.D̂0[(1 + nd*(j-1)):(nd*j)]
338335
= @views Ŷ[(1 + ny*(j-1)):(ny*j)]
@@ -352,9 +349,7 @@ The function is called by the nonlinear optimizer of [`NonLinMPC`](@ref) control
352349
also be called on any [`PredictiveController`](@ref)s to evaluate the objective function `J`
353350
at specific input increments `ΔŨ` and predictions `Ŷ` values. It mutates the `U` argument.
354351
"""
355-
function obj_nonlinprog!(
356-
U , _ , _ , mpc::PredictiveController, model::LinModel, Ŷ, ΔŨ::Vector{NT}
357-
) where {NT<:Real}
352+
function obj_nonlinprog!(U, _ , _ , mpc::PredictiveController, model::LinModel, Ŷ, ΔŨ)
358353
J = obj_quadprog(ΔŨ, mpc.H̃, mpc.q̃) + mpc.p[]
359354
if !iszero(mpc.E)
360355
U .= mul!(U, mpc.S̃, ΔŨ) .+ mpc.T_lastu
@@ -373,9 +368,7 @@ function `dot(x, A, x)` is a performant way of calculating `x'*A*x`. This method
373368
`U`, `Ȳ` and `Ū` arguments (input over `Hp`, and output and input setpoint tracking error,
374369
respectively).
375370
"""
376-
function obj_nonlinprog!(
377-
U, Ȳ, Ū, mpc::PredictiveController, model::SimModel, Ŷ, ΔŨ::Vector{NT}
378-
) where {NT<:Real}
371+
function obj_nonlinprog!(U, Ȳ, Ū, mpc::PredictiveController, model::SimModel, Ŷ, ΔŨ)
379372
# --- output setpoint tracking term ---
380373
Ȳ .= mpc.R̂y .-
381374
JR̂y = dot(Ȳ, mpc.M_Hp, Ȳ)

src/controller/explicitmpc.jl

+1-3
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,7 @@ The solution is ``\mathbf{ΔŨ = - H̃^{-1} q̃}``, see [`init_quadprog`](@ref)
197197
optim_objective!(mpc::ExplicitMPC) = lmul!(-1, ldiv!(mpc.ΔŨ, mpc.H̃_chol, mpc.q̃))
198198

199199
"Compute the predictions but not the terminal states if `mpc` is an [`ExplicitMPC`](@ref)."
200-
function predict!(
201-
Ŷ, x̂, _ , _ , mpc::ExplicitMPC, ::LinModel, ΔŨ::Vector{NT}
202-
) where {NT<:Real}
200+
function predict!(Ŷ, x̂, _ , _ , _ , mpc::ExplicitMPC, ::LinModel, ΔŨ)
203201
# in-place operations to reduce allocations :
204202
Ŷ .= mul!(Ŷ, mpc.Ẽ, ΔŨ) .+ mpc.F
205203
x̂ .= NaN

src/controller/nonlinmpc.jl

+9-8
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,7 @@ function init_optimization!(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where
310310
x̂_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nx̂), Nc)
311311
x̂next_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nx̂), Nc)
312312
u_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nu), Nc)
313+
û_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nu), Nc)
313314
Ȳ_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nŶ), Nc)
314315
Ū_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nU), Nc)
315316
function Jfunc(ΔŨtup::JNT...)
@@ -318,8 +319,8 @@ function init_optimization!(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where
318319
ΔŨ = collect(ΔŨtup)
319320
if ΔŨtup !== last_ΔŨtup_float
320321
x̂, x̂next = get_tmp(x̂_cache, ΔŨ1), get_tmp(x̂next_cache, ΔŨ1)
321-
u = get_tmp(u_cache, ΔŨ1)
322-
Ŷ, x̂end = predict!(Ŷ, x̂, x̂next, u, mpc, model, ΔŨ)
322+
u, û = get_tmp(u_cache, ΔŨ1), get_tmp(û_cache, ΔŨ1)
323+
Ŷ, x̂end = predict!(Ŷ, x̂, x̂next, u, û, mpc, model, ΔŨ)
323324
g = get_tmp(g_cache, ΔŨ1)
324325
g = con_nonlinprog!(g, mpc, model, x̂end, Ŷ, ΔŨ)
325326
last_ΔŨtup_float = ΔŨtup
@@ -333,8 +334,8 @@ function init_optimization!(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where
333334
ΔŨ = collect(ΔŨtup)
334335
if ΔŨtup !== last_ΔŨtup_dual
335336
x̂, x̂next = get_tmp(x̂_cache, ΔŨ1), get_tmp(x̂next_cache, ΔŨ1)
336-
u = get_tmp(u_cache, ΔŨ1)
337-
Ŷ, x̂end = predict!(Ŷ, x̂, x̂next, u, mpc, model, ΔŨ)
337+
u, û = get_tmp(u_cache, ΔŨ1), get_tmp(û_cache, ΔŨ1)
338+
Ŷ, x̂end = predict!(Ŷ, x̂, x̂next, u, û, mpc, model, ΔŨ)
338339
g = get_tmp(g_cache, ΔŨ1)
339340
g = con_nonlinprog!(g, mpc, model, x̂end, Ŷ, ΔŨ)
340341
last_ΔŨtup_dual = ΔŨtup
@@ -349,8 +350,8 @@ function init_optimization!(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where
349350
= get_tmp(Ŷ_cache, ΔŨ1)
350351
ΔŨ = collect(ΔŨtup)
351352
x̂, x̂next = get_tmp(x̂_cache, ΔŨ1), get_tmp(x̂next_cache, ΔŨ1)
352-
u = get_tmp(u_cache, ΔŨ1)
353-
Ŷ, x̂end = predict!(Ŷ, x̂, x̂next, u, mpc, model, ΔŨ)
353+
u, û = get_tmp(u_cache, ΔŨ1), get_tmp(û_cache, ΔŨ1)
354+
Ŷ, x̂end = predict!(Ŷ, x̂, x̂next, u, û, mpc, model, ΔŨ)
354355
g = con_nonlinprog!(g, mpc, model, x̂end, Ŷ, ΔŨ)
355356
last_ΔŨtup_float = ΔŨtup
356357
end
@@ -363,8 +364,8 @@ function init_optimization!(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where
363364
= get_tmp(Ŷ_cache, ΔŨ1)
364365
ΔŨ = collect(ΔŨtup)
365366
x̂, x̂next = get_tmp(x̂_cache, ΔŨ1), get_tmp(x̂next_cache, ΔŨ1)
366-
u = get_tmp(u_cache, ΔŨ1)
367-
Ŷ, x̂end = predict!(Ŷ, x̂, x̂next, u, mpc, model, ΔŨ)
367+
u, û = get_tmp(u_cache, ΔŨ1), get_tmp(û_cache, ΔŨ1)
368+
Ŷ, x̂end = predict!(Ŷ, x̂, x̂next, u, û, mpc, model, ΔŨ)
368369
g = con_nonlinprog!(g, mpc, model, x̂end, Ŷ, ΔŨ)
369370
last_ΔŨtup_dual = ΔŨtup
370371
end

src/estimator/execute.jl

+17-9
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ function remove_op!(estim::StateEstimator, u, ym, d)
1515
end
1616

1717
@doc raw"""
18-
f̂!(x̂next, estim::StateEstimator, model::SimModel, x̂, u, d) -> nothing
18+
f̂!(x̂next, û, estim::StateEstimator, model::SimModel, x̂, u, d) -> nothing
1919
2020
Mutating state function ``\mathbf{f̂}`` of the augmented model.
2121
@@ -27,22 +27,26 @@ function returns the next state of the augmented model, defined as:
2727
\mathbf{ŷ}(k) &= \mathbf{ĥ}\Big(\mathbf{x̂}(k), \mathbf{d}(k)\Big)
2828
\end{aligned}
2929
```
30-
where ``\mathbf{x̂}(k+1)`` is stored in `x̂next` argument.
30+
where ``\mathbf{x̂}(k+1)`` is stored in `x̂next` argument. The method mutates `x̂next` and `û`
31+
in place, the latter stores the input vector of the augmented model ``\mathbf{u + ŷ_{s_u}}``.
3132
"""
32-
function f̂!(x̂next, estim::StateEstimator, model::SimModel, x̂, u, d)
33+
function f̂!(x̂next, û, estim::StateEstimator, model::SimModel, x̂, u, d)
3334
# `@views` macro avoid copies with matrix slice operator e.g. [a:b]
3435
@views x̂d, x̂s = x̂[1:model.nx], x̂[model.nx+1:end]
3536
@views x̂d_next, x̂s_next = x̂next[1:model.nx], x̂next[model.nx+1:end]
36-
T = promote_type(eltype(x̂), eltype(u))
37-
= Vector{T}(undef, model.nu) # TODO: avoid this allocation if possible
38-
û .= u .+ mul!(û, estim.Cs_u, x̂s)
37+
mul!(û, estim.Cs_u, x̂s)
38+
.+= u
3939
f!(x̂d_next, model, x̂d, û, d)
4040
mul!(x̂s_next, estim.As, x̂s)
4141
return nothing
4242
end
4343

44-
"Use the augmented model matrices if `model` is a [`LinModel`](@ref)."
45-
function f̂!(x̂next, estim::StateEstimator, ::LinModel, x̂, u, d)
44+
"""
45+
f̂!(x̂next, _ , estim::StateEstimator, model::LinModel, x̂, u, d) -> nothing
46+
47+
Use the augmented model matrices if `model` is a [`LinModel`](@ref).
48+
"""
49+
function f̂!(x̂next, _ , estim::StateEstimator, ::LinModel, x̂, u, d)
4650
mul!(x̂next, estim.Â, x̂)
4751
mul!(x̂next, estim.B̂u, u, 1, 1)
4852
mul!(x̂next, estim.B̂d, d, 1, 1)
@@ -61,7 +65,11 @@ function ĥ!(ŷ, estim::StateEstimator, model::SimModel, x̂, d)
6165
mul!(ŷ, estim.Cs_y, x̂s, 1, 1)
6266
return nothing
6367
end
64-
"Use the augmented model matrices if `model` is a [`LinModel`](@ref)."
68+
"""
69+
ĥ!(ŷ, estim::StateEstimator, model::LinModel, x̂, d) -> nothing
70+
71+
Use the augmented model matrices if `model` is a [`LinModel`](@ref).
72+
"""
6573
function ĥ!(ŷ, estim::StateEstimator, ::LinModel, x̂, d)
6674
mul!(ŷ, estim.Ĉ, x̂)
6775
mul!(ŷ, estim.D̂d, d, 1, 1)

src/estimator/internal_model.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -144,16 +144,16 @@ function matrices_internalmodel(model::SimModel{NT}) where NT<:Real
144144
end
145145

146146
@doc raw"""
147-
f̂!(x̂next, ::InternalModel, model::NonLinModel, x̂, u, d)
147+
f̂!(x̂next, _ , estim::InternalModel, model::NonLinModel, x̂, u, d)
148148
149149
State function ``\mathbf{f̂}`` of [`InternalModel`](@ref) for [`NonLinModel`](@ref).
150150
151151
It calls `model.f!(x̂next, x̂, u ,d)` since this estimator does not augment the states.
152152
"""
153-
f̂!(x̂next, ::InternalModel, model::NonLinModel, x̂, u, d) = model.f!(x̂next, x̂, u, d)
153+
f̂!(x̂next, _ , ::InternalModel, model::NonLinModel, x̂, u, d) = model.f!(x̂next, x̂, u, d)
154154

155155
@doc raw"""
156-
ĥ!(ŷ, ::InternalModel, model::NonLinModel, x̂, d)
156+
ĥ!(ŷ, estim::InternalModel, model::NonLinModel, x̂, d)
157157
158158
Output function ``\mathbf{ĥ}`` of [`InternalModel`](@ref), it calls `model.h!`.
159159
"""

src/estimator/kalman.jl

+14-7
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,7 @@ function update_estimate!(estim::UnscentedKalmanFilter{NT}, u, ym, d) where NT<:
570570
γ, m̂, Ŝ = estim.γ, estim.m̂, estim.
571571
# --- initialize matrices ---
572572
X̂, X̂_next = Matrix{NT}(undef, nx̂, nσ), Matrix{NT}(undef, nx̂, nσ)
573+
= Vector{NT}(undef, estim.model.nu)
573574
ŷm = Vector{NT}(undef, nym)
574575
= Vector{NT}(undef, estim.model.ny)
575576
Ŷm = Matrix{NT}(undef, nym, nσ)
@@ -604,7 +605,7 @@ function update_estimate!(estim::UnscentedKalmanFilter{NT}, u, ym, d) where NT<:
604605
X̂_cor[:, nx̂+2:end] .-= γ_sqrt_P̂_cor
605606
X̂_next = similar(X̂_cor)
606607
for j in axes(X̂_next, 2)
607-
@views f̂!(X̂_next[:, j], estim, estim.model, X̂_cor[:, j], u, d)
608+
@views f̂!(X̂_next[:, j], û, estim, estim.model, X̂_cor[:, j], u, d)
608609
end
609610
x̂_next = mul!(x̂, X̂_next, m̂)
610611
X̄_next = X̂_next
@@ -765,10 +766,15 @@ function update_estimate!(
765766
estim::ExtendedKalmanFilter{NT}, u, ym, d=empty(estim.x̂)
766767
) where NT<:Real
767768
model = estim.model
768-
x̂next, ŷ = Vector{NT}(undef, estim.nx̂), Vector{NT}(undef, model.ny)
769-
= ForwardDiff.jacobian((x̂next, x̂) -> f̂!(x̂next, estim, model, x̂, u, d), x̂next, estim.x̂)
770-
= ForwardDiff.jacobian((ŷ, x̂) -> ĥ!(ŷ, estim, model, x̂, d), ŷ, estim.x̂)
771-
return update_estimate_kf!(estim, u, ym, d, F̂, Ĥ[estim.i_ym, :], estim.P̂, estim.x̂)
769+
nx̂, nu, ny = estim.nx̂, model.nu, model.ny
770+
x̂, P̂ = estim.x̂, estim.
771+
# concatenate x̂next and û vectors to allows û vector with dual numbers for auto diff:
772+
x̂nextû, ŷ = Vector{NT}(undef, nx̂ + nu), Vector{NT}(undef, ny)
773+
f̂AD! = (x̂nextû, x̂) -> @views f̂!(x̂nextû[1:nx̂], x̂nextû[nx̂+1:end], estim, model, x̂, u, d)
774+
ĥAD! = (ŷ, x̂) -> ĥ!(ŷ, estim, model, x̂, d)
775+
= ForwardDiff.jacobian(f̂AD!, x̂nextû, x̂)[1:nx̂, :]
776+
Ĥm = ForwardDiff.jacobian(ĥAD!, ŷ, x̂)[estim.i_ym, :]
777+
return update_estimate_kf!(estim, u, ym, d, F̂, Ĥm, P̂, x̂)
772778
end
773779

774780
"Set `estim.P̂` to `estim.P̂0` for the time-varying Kalman Filters."
@@ -810,15 +816,16 @@ allocations. See e.g. [`KalmanFilter`](@ref) docstring for the equations.
810816
"""
811817
function update_estimate_kf!(estim::StateEstimator{NT}, u, ym, d, Â, Ĉm, P̂, x̂) where NT<:Real
812818
Q̂, R̂, M̂, K̂ = estim.Q̂, estim.R̂, estim.M̂, estim.
819+
nx̂, nu, ny = estim.nx̂, estim.model.nu, estim.model.ny
820+
x̂next, û, ŷ = Vector{NT}(undef, nx̂), Vector{NT}(undef, nu), Vector{NT}(undef, ny)
813821
mul!(M̂, P̂, Ĉm')
814822
rdiv!(M̂, cholesky!(Hermitian(Ĉm ** Ĉm' .+ R̂)))
815823
mul!(K̂, Â, M̂)
816-
x̂next, ŷ = Vector{NT}(undef, estim.nx̂), Vector{NT}(undef, estim.model.ny)
817824
ĥ!(ŷ, estim, estim.model, x̂, d)
818825
ŷm = @views ŷ[estim.i_ym]
819826
= ŷm
820827
v̂ .= ym .- ŷm
821-
f̂!(x̂next, estim, estim.model, x̂, u, d)
828+
f̂!(x̂next, û, estim, estim.model, x̂, u, d)
822829
mul!(x̂next, K̂, v̂, 1, 1)
823830
estim.x̂ .= x̂next
824831
.data .=* (P̂ .-* Ĉm * P̂) *' .+# .data is necessary for Hermitians

src/estimator/mhe/construct.jl

+11-10
Original file line numberDiff line numberDiff line change
@@ -1003,15 +1003,16 @@ function init_optimization!(
10031003
end
10041004
He = estim.He
10051005
nV̂, nX̂, ng = He*estim.nym, He*estim.nx̂, length(con.i_g)
1006-
nx̂, nŷ = estim.nx̂, model.ny
1006+
nx̂, nŷ, nu = estim.nx̂, model.ny, model.nu
10071007
# see init_optimization!(mpc::NonLinMPC, optim) for details on the inspiration
1008-
Jfunc, gfunc = let estim=estim, model=model, nZ̃=nZ̃ , nV̂=nV̂, nX̂=nX̂, ng=ng, nx̂=nx̂, nŷ=nŷ
1008+
Jfunc, gfunc = let estim=estim, model=model, nZ̃=nZ̃, nV̂=nV̂, nX̂=nX̂, ng=ng, nx̂=nx̂, nu=nu, nŷ=nŷ
10091009
Nc = nZ̃ + 3
10101010
last_Z̃tup_float, last_Z̃tup_dual = nothing, nothing
10111011
V̂_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nV̂), Nc)
10121012
g_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, ng), Nc)
10131013
X̂_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nX̂), Nc)
10141014
x̄_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nx̂), Nc)
1015+
û_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nu), Nc)
10151016
ŷ_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nŷ), Nc)
10161017
function Jfunc(Z̃tup::JNT...)
10171018
Z̃1 = Z̃tup[begin]
@@ -1020,8 +1021,8 @@ function init_optimization!(
10201021
if Z̃tup !== last_Z̃tup_float
10211022
g = get_tmp(g_cache, Z̃1)
10221023
= get_tmp(X̂_cache, Z̃1)
1023-
= get_tmp(ŷ_cache, Z̃1)
1024-
V̂, X̂ = predict!(V̂, X̂, ŷ, estim, model, Z̃)
1024+
û, = get_tmp(û_cache, Z̃1), get_tmp(ŷ_cache, Z̃1)
1025+
V̂, X̂ = predict!(V̂, X̂, û, ŷ, estim, model, Z̃)
10251026
g = con_nonlinprog!(g, estim, model, X̂, V̂, Z̃)
10261027
last_Z̃tup_float = Z̃tup
10271028
end
@@ -1035,8 +1036,8 @@ function init_optimization!(
10351036
if Z̃tup !== last_Z̃tup_dual
10361037
g = get_tmp(g_cache, Z̃1)
10371038
= get_tmp(X̂_cache, Z̃1)
1038-
= get_tmp(ŷ_cache, Z̃1)
1039-
V̂, X̂ = predict!(V̂, X̂, ŷ, estim, model, Z̃)
1039+
û, = get_tmp(û_cache, Z̃1), get_tmp(ŷ_cache, Z̃1)
1040+
V̂, X̂ = predict!(V̂, X̂, û, ŷ, estim, model, Z̃)
10401041
g = con_nonlinprog!(g, estim, model, X̂, V̂, Z̃)
10411042
last_Z̃tup_dual = Z̃tup
10421043
end
@@ -1050,8 +1051,8 @@ function init_optimization!(
10501051
= collect(Z̃tup)
10511052
= get_tmp(V̂_cache, Z̃1)
10521053
= get_tmp(X̂_cache, Z̃1)
1053-
= get_tmp(ŷ_cache, Z̃1)
1054-
V̂, X̂ = predict!(V̂, X̂, ŷ, estim, model, Z̃)
1054+
û, = get_tmp(û_cache, Z̃1), get_tmp(ŷ_cache, Z̃1)
1055+
V̂, X̂ = predict!(V̂, X̂, û, ŷ, estim, model, Z̃)
10551056
g = con_nonlinprog!(g, estim, model, X̂, V̂, Z̃)
10561057
last_Z̃tup_float = Z̃tup
10571058
end
@@ -1064,8 +1065,8 @@ function init_optimization!(
10641065
= collect(Z̃tup)
10651066
= get_tmp(V̂_cache, Z̃1)
10661067
= get_tmp(X̂_cache, Z̃1)
1067-
= get_tmp(ŷ_cache, Z̃1)
1068-
V̂, X̂ = predict!(V̂, X̂, ŷ, estim, model, Z̃)
1068+
û, = get_tmp(û_cache, Z̃1), get_tmp(ŷ_cache, Z̃1)
1069+
V̂, X̂ = predict!(V̂, X̂, û, ŷ, estim, model, Z̃)
10691070
g = con_nonlinprog!(g, estim, model, X̂, V̂, Z̃)
10701071
last_Z̃tup_dual = Z̃tup
10711072
end

0 commit comments

Comments
 (0)