@@ -304,6 +304,7 @@ function init_optimization!(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where
304
304
Jfunc, gfunc = let mpc= mpc, model= model, ng= ng, nΔŨ= nΔŨ, nŶ= Hp* ny, nx̂= nx̂, nu= nu, nU= Hp* nu
305
305
Nc = nΔŨ + 3
306
306
last_ΔŨtup_float, last_ΔŨtup_dual = nothing , nothing
307
+ ΔŨ_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nΔŨ), Nc)
307
308
Ŷ_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nŶ), Nc)
308
309
U_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nU), Nc)
309
310
g_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, ng), Nc)
@@ -313,63 +314,46 @@ function init_optimization!(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where
313
314
û_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nu), Nc)
314
315
Ȳ_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nŶ), Nc)
315
316
Ū_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nU), Nc)
316
- function Jfunc (ΔŨtup:: JNT ... )
317
+ function Jfunc (ΔŨtup:: T ... ):: T where T <: Real
317
318
ΔŨ1 = ΔŨtup[begin ]
318
- Ŷ = get_tmp (Ŷ_cache, ΔŨ1)
319
- ΔŨ = collect (ΔŨtup)
320
- if ΔŨtup != = last_ΔŨtup_float
321
- x̂, x̂next = get_tmp (x̂_cache, ΔŨ1), get_tmp (x̂next_cache, ΔŨ1)
322
- u, û = get_tmp (u_cache, ΔŨ1), get_tmp (û_cache, ΔŨ1)
323
- Ŷ, x̂end = predict! (Ŷ, x̂, x̂next, u, û, mpc, model, ΔŨ)
324
- g = get_tmp (g_cache, ΔŨ1)
325
- g = con_nonlinprog! (g, mpc, model, x̂end , Ŷ, ΔŨ)
326
- last_ΔŨtup_float = ΔŨtup
319
+ ΔŨ, Ŷ = get_tmp (ΔŨ_cache, ΔŨ1), get_tmp (Ŷ_cache, ΔŨ1)
320
+ if T == JNT
321
+ isnewvalue = (ΔŨtup != = last_ΔŨtup_float)
322
+ isnewvalue && (last_ΔŨtup_float = ΔŨtup)
323
+ else
324
+ isnewvalue = (ΔŨtup != = last_ΔŨtup_dual)
325
+ isnewvalue && (last_ΔŨtup_dual = ΔŨtup)
327
326
end
328
- U, Ȳ, Ū = get_tmp (U_cache, ΔŨ1), get_tmp (Ȳ_cache, ΔŨ1), get_tmp (Ū_cache, ΔŨ1)
329
- return obj_nonlinprog! (U, Ȳ, Ū, mpc, model, Ŷ, ΔŨ)
330
- end
331
- function Jfunc (ΔŨtup:: ForwardDiff.Dual... )
332
- ΔŨ1 = ΔŨtup[begin ]
333
- Ŷ = get_tmp (Ŷ_cache, ΔŨ1)
334
- ΔŨ = collect (ΔŨtup)
335
- if ΔŨtup != = last_ΔŨtup_dual
327
+ if isnewvalue
336
328
x̂, x̂next = get_tmp (x̂_cache, ΔŨ1), get_tmp (x̂next_cache, ΔŨ1)
337
329
u, û = get_tmp (u_cache, ΔŨ1), get_tmp (û_cache, ΔŨ1)
330
+ ΔŨ .= ΔŨtup
338
331
Ŷ, x̂end = predict! (Ŷ, x̂, x̂next, u, û, mpc, model, ΔŨ)
339
332
g = get_tmp (g_cache, ΔŨ1)
340
333
g = con_nonlinprog! (g, mpc, model, x̂end , Ŷ, ΔŨ)
341
- last_ΔŨtup_dual = ΔŨtup
342
334
end
343
335
U, Ȳ, Ū = get_tmp (U_cache, ΔŨ1), get_tmp (Ȳ_cache, ΔŨ1), get_tmp (Ū_cache, ΔŨ1)
344
- return obj_nonlinprog! (U, Ȳ, Ū, mpc, model, Ŷ, ΔŨ)
336
+ return obj_nonlinprog! (U, Ȳ, Ū, mpc, model, Ŷ, ΔŨ):: T
345
337
end
346
- function gfunc_i (i, ΔŨtup:: NTuple{N, JNT} ) where N
338
+ function gfunc_i (i, ΔŨtup:: NTuple{N, T} ) :: T where {N, T <: Real }
347
339
ΔŨ1 = ΔŨtup[begin ]
348
340
g = get_tmp (g_cache, ΔŨ1)
349
- if ΔŨtup != = last_ΔŨtup_float
350
- Ŷ = get_tmp (Ŷ_cache, ΔŨ1)
351
- ΔŨ = collect (ΔŨtup)
352
- x̂, x̂next = get_tmp (x̂_cache, ΔŨ1), get_tmp (x̂next_cache, ΔŨ1)
353
- u, û = get_tmp (u_cache, ΔŨ1), get_tmp (û_cache, ΔŨ1)
354
- Ŷ, x̂end = predict! (Ŷ, x̂, x̂next, u, û, mpc, model, ΔŨ)
355
- g = con_nonlinprog! (g, mpc, model, x̂end , Ŷ, ΔŨ)
356
- last_ΔŨtup_float = ΔŨtup
341
+ if T == JNT
342
+ isnewvalue = (ΔŨtup != = last_ΔŨtup_float)
343
+ isnewvalue && (last_ΔŨtup_float = ΔŨtup)
344
+ else
345
+ isnewvalue = (ΔŨtup != = last_ΔŨtup_dual)
346
+ isnewvalue && (last_ΔŨtup_dual = ΔŨtup)
357
347
end
358
- return g[i]
359
- end
360
- function gfunc_i (i, ΔŨtup:: NTuple{N, ForwardDiff.Dual} ) where N
361
- ΔŨ1 = ΔŨtup[begin ]
362
- g = get_tmp (g_cache, ΔŨ1)
363
- if ΔŨtup != = last_ΔŨtup_dual
364
- Ŷ = get_tmp (Ŷ_cache, ΔŨ1)
365
- ΔŨ = collect (ΔŨtup)
348
+ if isnewvalue
349
+ ΔŨ, Ŷ = get_tmp (ΔŨ_cache, ΔŨ1), get_tmp (Ŷ_cache, ΔŨ1)
366
350
x̂, x̂next = get_tmp (x̂_cache, ΔŨ1), get_tmp (x̂next_cache, ΔŨ1)
367
351
u, û = get_tmp (u_cache, ΔŨ1), get_tmp (û_cache, ΔŨ1)
352
+ ΔŨ .= ΔŨtup
368
353
Ŷ, x̂end = predict! (Ŷ, x̂, x̂next, u, û, mpc, model, ΔŨ)
369
354
g = con_nonlinprog! (g, mpc, model, x̂end , Ŷ, ΔŨ)
370
- last_ΔŨtup_dual = ΔŨtup
371
355
end
372
- return g[i]
356
+ return g[i]:: T
373
357
end
374
358
gfunc = [(ΔŨ... ) -> gfunc_i (i, ΔŨ) for i in 1 : ng]
375
359
(Jfunc, gfunc)
0 commit comments