@@ -304,6 +304,7 @@ function init_optimization!(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where
304
304
Jfunc, gfunc = let mpc= mpc, model= model, ng= ng, nΔŨ= nΔŨ, nŶ= Hp* ny, nx̂= nx̂, nu= nu, nU= Hp* nu
305
305
Nc = nΔŨ + 3
306
306
last_ΔŨtup_float, last_ΔŨtup_dual = nothing , nothing
307
+ ΔŨ_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nΔŨ), Nc)
307
308
Ŷ_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nŶ), Nc)
308
309
U_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nU), Nc)
309
310
g_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, ng), Nc)
@@ -313,63 +314,42 @@ function init_optimization!(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where
313
314
û_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nu), Nc)
314
315
Ȳ_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nŶ), Nc)
315
316
Ū_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nU), Nc)
316
- function Jfunc (ΔŨtup:: JNT ... )
317
+ function Jfunc (ΔŨtup:: T ... ):: T where {T <: Real }
317
318
ΔŨ1 = ΔŨtup[begin ]
318
- Ŷ = get_tmp (Ŷ_cache, ΔŨ1)
319
- ΔŨ = collect (ΔŨtup)
320
- if ΔŨtup != = last_ΔŨtup_float
321
- x̂, x̂next = get_tmp (x̂_cache, ΔŨ1), get_tmp (x̂next_cache, ΔŨ1)
322
- u, û = get_tmp (u_cache, ΔŨ1), get_tmp (û_cache, ΔŨ1)
323
- Ŷ, x̂end = predict! (Ŷ, x̂, x̂next, u, û, mpc, model, ΔŨ)
324
- g = get_tmp (g_cache, ΔŨ1)
325
- g = con_nonlinprog! (g, mpc, model, x̂end , Ŷ, ΔŨ)
319
+ if T == JNT
326
320
last_ΔŨtup_float = ΔŨtup
321
+ else
322
+ last_ΔŨtup_dual = ΔŨtup
327
323
end
324
+ ΔŨ, Ŷ = get_tmp (ΔŨ_cache, ΔŨ1), get_tmp (Ŷ_cache, ΔŨ1)
325
+ x̂, x̂next = get_tmp (x̂_cache, ΔŨ1), get_tmp (x̂next_cache, ΔŨ1)
326
+ u, û = get_tmp (u_cache, ΔŨ1), get_tmp (û_cache, ΔŨ1)
327
+ ΔŨ .= ΔŨtup
328
+ Ŷ, x̂end = predict! (Ŷ, x̂, x̂next, u, û, mpc, model, ΔŨ)
329
+ g = get_tmp (g_cache, ΔŨ1)
330
+ g = con_nonlinprog! (g, mpc, model, x̂end , Ŷ, ΔŨ)
328
331
U, Ȳ, Ū = get_tmp (U_cache, ΔŨ1), get_tmp (Ȳ_cache, ΔŨ1), get_tmp (Ū_cache, ΔŨ1)
329
- return obj_nonlinprog! (U, Ȳ, Ū, mpc, model, Ŷ, ΔŨ)
330
- end
331
- function Jfunc (ΔŨtup:: ForwardDiff.Dual... )
332
- ΔŨ1 = ΔŨtup[begin ]
333
- Ŷ = get_tmp (Ŷ_cache, ΔŨ1)
334
- ΔŨ = collect (ΔŨtup)
335
- if ΔŨtup != = last_ΔŨtup_dual
336
- x̂, x̂next = get_tmp (x̂_cache, ΔŨ1), get_tmp (x̂next_cache, ΔŨ1)
337
- u, û = get_tmp (u_cache, ΔŨ1), get_tmp (û_cache, ΔŨ1)
338
- Ŷ, x̂end = predict! (Ŷ, x̂, x̂next, u, û, mpc, model, ΔŨ)
339
- g = get_tmp (g_cache, ΔŨ1)
340
- g = con_nonlinprog! (g, mpc, model, x̂end , Ŷ, ΔŨ)
341
- last_ΔŨtup_dual = ΔŨtup
342
- end
343
- U, Ȳ, Ū = get_tmp (U_cache, ΔŨ1), get_tmp (Ȳ_cache, ΔŨ1), get_tmp (Ū_cache, ΔŨ1)
344
- return obj_nonlinprog! (U, Ȳ, Ū, mpc, model, Ŷ, ΔŨ)
332
+ return obj_nonlinprog! (U, Ȳ, Ū, mpc, model, Ŷ, ΔŨ):: T
345
333
end
346
- function gfunc_i (i, ΔŨtup:: NTuple{N, JNT} ) where N
334
+ function gfunc_i (i, ΔŨtup:: NTuple{N, T} ) :: T where {N, T <: Real }
347
335
ΔŨ1 = ΔŨtup[begin ]
348
336
g = get_tmp (g_cache, ΔŨ1)
349
- if ΔŨtup != = last_ΔŨtup_float
350
- Ŷ = get_tmp (Ŷ_cache, ΔŨ1)
351
- ΔŨ = collect (ΔŨtup)
352
- x̂, x̂next = get_tmp (x̂_cache, ΔŨ1), get_tmp (x̂next_cache, ΔŨ1)
353
- u, û = get_tmp (u_cache, ΔŨ1), get_tmp (û_cache, ΔŨ1)
354
- Ŷ, x̂end = predict! (Ŷ, x̂, x̂next, u, û, mpc, model, ΔŨ)
355
- g = con_nonlinprog! (g, mpc, model, x̂end , Ŷ, ΔŨ)
356
- last_ΔŨtup_float = ΔŨtup
337
+ if T == JNT
338
+ isnewvalue = (ΔŨtup != = last_ΔŨtup_float)
339
+ isnewvalue && (last_ΔŨtup_float = ΔŨtup)
340
+ else
341
+ isnewvalue = (ΔŨtup != = last_ΔŨtup_dual)
342
+ isnewvalue && (last_ΔŨtup_dual = ΔŨtup)
357
343
end
358
- return g[i]
359
- end
360
- function gfunc_i (i, ΔŨtup:: NTuple{N, ForwardDiff.Dual} ) where N
361
- ΔŨ1 = ΔŨtup[begin ]
362
- g = get_tmp (g_cache, ΔŨ1)
363
- if ΔŨtup != = last_ΔŨtup_dual
364
- Ŷ = get_tmp (Ŷ_cache, ΔŨ1)
365
- ΔŨ = collect (ΔŨtup)
344
+ if isnewvalue
345
+ ΔŨ, Ŷ = get_tmp (ΔŨ_cache, ΔŨ1), get_tmp (Ŷ_cache, ΔŨ1)
366
346
x̂, x̂next = get_tmp (x̂_cache, ΔŨ1), get_tmp (x̂next_cache, ΔŨ1)
367
347
u, û = get_tmp (u_cache, ΔŨ1), get_tmp (û_cache, ΔŨ1)
348
+ ΔŨ .= ΔŨtup
368
349
Ŷ, x̂end = predict! (Ŷ, x̂, x̂next, u, û, mpc, model, ΔŨ)
369
350
g = con_nonlinprog! (g, mpc, model, x̂end , Ŷ, ΔŨ)
370
- last_ΔŨtup_dual = ΔŨtup
371
351
end
372
- return g[i]
352
+ return g[i]:: T
373
353
end
374
354
gfunc = [(ΔŨ... ) -> gfunc_i (i, ΔŨ) for i in 1 : ng]
375
355
(Jfunc, gfunc)
0 commit comments