Skip to content

Commit 77f0b14

Browse files
committed
added: store AbstractADType backends as struct parameters
This is "safer" like that. And it allows the user to verify what are the DI backends in a given MPC/MHE object. Also useful for the tests.
1 parent 1a1a474 commit 77f0b14

File tree

2 files changed

+60
-80
lines changed

2 files changed

+60
-80
lines changed

src/controller/nonlinmpc.jl

+26-39
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ struct NonLinMPC{
1212
NT<:Real,
1313
SE<:StateEstimator,
1414
TM<:TranscriptionMethod,
15-
JM<:JuMP.GenericModel,
15+
JM<:JuMP.GenericModel,
16+
GB<:AbstractADType,
17+
JB<:AbstractADType,
1618
PT<:Any,
1719
JEfunc<:Function,
1820
GCfunc<:Function
@@ -23,6 +25,8 @@ struct NonLinMPC{
2325
# different since solvers that support non-Float64 are scarce.
2426
optim::JM
2527
con::ControllerConstraint{NT, GCfunc}
28+
gradient::GB
29+
jacobian::JB
2630
::Vector{NT}
2731
::Vector{NT}
2832
Hp::Int
@@ -59,12 +63,14 @@ struct NonLinMPC{
5963
function NonLinMPC{NT}(
6064
estim::SE,
6165
Hp, Hc, M_Hp, N_Hc, L_Hp, Cwt, Ewt, JE::JEfunc, gc!::GCfunc, nc, p::PT,
62-
transcription::TM, optim::JM, gradient, jacobian
66+
transcription::TM, optim::JM, gradient::GB, jacobian::JB
6367
) where {
6468
NT<:Real,
6569
SE<:StateEstimator,
6670
TM<:TranscriptionMethod,
6771
JM<:JuMP.GenericModel,
72+
GB<:AbstractADType,
73+
JB<:AbstractADType,
6874
PT<:Any,
6975
JEfunc<:Function,
7076
GCfunc<:Function,
@@ -101,8 +107,9 @@ struct NonLinMPC{
101107
nZ̃ = get_nZ(estim, transcription, Hp, Hc) +
102108
= zeros(NT, nZ̃)
103109
buffer = PredictiveControllerBuffer(estim, transcription, Hp, Hc, nϵ)
104-
mpc = new{NT, SE, TM, JM, PT, JEfunc, GCfunc}(
110+
mpc = new{NT, SE, TM, JM, GB, JB, PT, JEfunc, GCfunc}(
105111
estim, transcription, optim, con,
112+
gradient, jacobian,
106113
Z̃, ŷ,
107114
Hp, Hc, nϵ,
108115
weights,
@@ -116,7 +123,7 @@ struct NonLinMPC{
116123
Uop, Yop, Dop,
117124
buffer
118125
)
119-
init_optimization!(mpc, model, optim, gradient, jacobian)
126+
init_optimization!(mpc, model, optim)
120127
return mpc
121128
end
122129
end
@@ -505,23 +512,11 @@ function addinfo!(info, mpc::NonLinMPC{NT}) where NT<:Real
505512
end
506513

507514
"""
508-
init_optimization!(
509-
mpc::NonLinMPC,
510-
model::SimModel,
511-
optim::JuMP.GenericModel,
512-
gradient::AbstractADType,
513-
Jacobian::AbstractADType
514-
) -> nothing
515+
init_optimization!(mpc::NonLinMPC, model::SimModel, optim::JuMP.GenericModel) -> nothing
515516
516517
Init the nonlinear optimization for [`NonLinMPC`](@ref) controllers.
517518
"""
518-
function init_optimization!(
519-
mpc::NonLinMPC,
520-
model::SimModel,
521-
optim::JuMP.GenericModel,
522-
gradient::AbstractADType,
523-
jacobian::AbstractADType
524-
)
519+
function init_optimization!(mpc::NonLinMPC, model::SimModel, optim::JuMP.GenericModel)
525520
# --- variables and linear constraints ---
526521
con, transcription = mpc.con, mpc.transcription
527522
nZ̃ = length(mpc.Z̃)
@@ -546,7 +541,7 @@ function init_optimization!(
546541
end
547542
end
548543
Jfunc, ∇Jfunc!, gfuncs, ∇gfuncs!, geqfuncs, ∇geqfuncs! = get_optim_functions(
549-
mpc, optim, gradient, jacobian
544+
mpc, optim
550545
)
551546
@operator(optim, J, nZ̃, Jfunc, ∇Jfunc!)
552547
@objective(optim, Min, J(Z̃var...))
@@ -557,10 +552,7 @@ end
557552

558553
"""
559554
get_optim_functions(
560-
mpc::NonLinMPC,
561-
optim::JuMP.GenericModel,
562-
grad_backend::AbstractADType,
563-
jac_backend ::AbstractADType
555+
mpc::NonLinMPC, optim::JuMP.GenericModel
564556
) -> Jfunc, ∇Jfunc!, gfuncs, ∇gfuncs!, geqfuncs, ∇geqfuncs!
565557
566558
Return the functions for the nonlinear optimization of `mpc` [`NonLinMPC`](@ref) controller.
@@ -583,12 +575,7 @@ This method is really intricate and I'm not proud of it. That's because of 3 ele
583575
584576
Inspired from: [User-defined operators with vector outputs](@extref JuMP User-defined-operators-with-vector-outputs)
585577
"""
586-
function get_optim_functions(
587-
mpc::NonLinMPC,
588-
optim::JuMP.GenericModel{JNT},
589-
grad_backend::AbstractADType,
590-
jac_backend ::AbstractADType
591-
) where JNT<:Real
578+
function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT<:Real
592579
# ----- common cache for Jfunc, gfuncs, geqfuncs called with floats -------------------
593580
model = mpc.estim.model
594581
nu, ny, nx̂, nϵ, Hp, Hc = model.nu, model.ny, mpc.estim.nx̂, mpc.nϵ, mpc.Hp, mpc.Hc
@@ -624,18 +611,18 @@ function get_optim_functions(
624611
Cache(Û0), Cache(X̂0),
625612
Cache(gc), Cache(g), Cache(geq),
626613
)
627-
∇J_prep = prepare_gradient(Jfunc!, grad_backend, Z̃_∇J, ∇J_context...; strict)
614+
∇J_prep = prepare_gradient(Jfunc!, mpc.gradient, Z̃_∇J, ∇J_context...; strict)
628615
∇J = Vector{JNT}(undef, nZ̃)
629616
∇Jfunc! = if nZ̃ == 1
630617
function (Z̃arg)
631618
Z̃_∇J .= Z̃arg
632-
gradient!(Jfunc!, ∇J, ∇J_prep, grad_backend, Z̃_∇J, ∇J_context...)
619+
gradient!(Jfunc!, ∇J, ∇J_prep, mpc.gradient, Z̃_∇J, ∇J_context...)
633620
return ∇J[begin] # univariate syntax, see JuMP.@operator doc
634621
end
635622
else
636623
function (∇J::AbstractVector{T}, Z̃arg::Vararg{T, N}) where {N, T<:Real}
637624
Z̃_∇J .= Z̃arg
638-
gradient!(Jfunc!, ∇J, ∇J_prep, grad_backend, Z̃_∇J, ∇J_context...)
625+
gradient!(Jfunc!, ∇J, ∇J_prep, mpc.gradient, Z̃_∇J, ∇J_context...)
639626
return ∇J # multivariate syntax, see JuMP.@operator doc
640627
end
641628
end
@@ -663,24 +650,24 @@ function get_optim_functions(
663650
)
664651
# temporarily enable all the inequality constraints for sparsity detection:
665652
mpc.con.i_g[1:end-nc] .= true
666-
∇g_prep = prepare_jacobian(gfunc!, g, jac_backend, Z̃_∇g, ∇g_context...; strict)
653+
∇g_prep = prepare_jacobian(gfunc!, g, mpc.jacobian, Z̃_∇g, ∇g_context...; strict)
667654
mpc.con.i_g[1:end-nc] .= false
668-
∇g = init_diffmat(JNT, jac_backend, ∇g_prep, nZ̃, ng)
655+
∇g = init_diffmat(JNT, mpc.jacobian, ∇g_prep, nZ̃, ng)
669656
∇gfuncs! = Vector{Function}(undef, ng)
670657
for i in eachindex(∇gfuncs!)
671658
∇gfuncs_i! = if nZ̃ == 1
672659
function (Z̃arg::T) where T<:Real
673660
if isdifferent(Z̃arg, Z̃_∇g)
674661
Z̃_∇g .= Z̃arg
675-
jacobian!(gfunc!, g, ∇g, ∇g_prep, jac_backend, Z̃_∇g, ∇g_context...)
662+
jacobian!(gfunc!, g, ∇g, ∇g_prep, mpc.jacobian, Z̃_∇g, ∇g_context...)
676663
end
677664
return ∇g[i, begin] # univariate syntax, see JuMP.@operator doc
678665
end
679666
else
680667
function (∇g_i, Z̃arg::Vararg{T, N}) where {N, T<:Real}
681668
if isdifferent(Z̃arg, Z̃_∇g)
682669
Z̃_∇g .= Z̃arg
683-
jacobian!(gfunc!, g, ∇g, ∇g_prep, jac_backend, Z̃_∇g, ∇g_context...)
670+
jacobian!(gfunc!, g, ∇g, ∇g_prep, mpc.jacobian, Z̃_∇g, ∇g_context...)
684671
end
685672
return ∇g_i .= @views ∇g[i, :] # multivariate syntax, see JuMP.@operator doc
686673
end
@@ -709,8 +696,8 @@ function get_optim_functions(
709696
Cache(Û0), Cache(X̂0),
710697
Cache(gc), Cache(g)
711698
)
712-
∇geq_prep = prepare_jacobian(geqfunc!, geq, jac_backend, Z̃_∇geq, ∇geq_context...; strict)
713-
∇geq = init_diffmat(JNT, jac_backend, ∇geq_prep, nZ̃, neq)
699+
∇geq_prep = prepare_jacobian(geqfunc!, geq, mpc.jacobian, Z̃_∇geq, ∇geq_context...; strict)
700+
∇geq = init_diffmat(JNT, mpc.jacobian, ∇geq_prep, nZ̃, neq)
714701
∇geqfuncs! = Vector{Function}(undef, neq)
715702
for i in eachindex(∇geqfuncs!)
716703
# only multivariate syntax, univariate is impossible since nonlinear equality
@@ -720,7 +707,7 @@ function get_optim_functions(
720707
if isdifferent(Z̃arg, Z̃_∇geq)
721708
Z̃_∇geq .= Z̃arg
722709
jacobian!(
723-
geqfunc!, geq, ∇geq, ∇geq_prep, jac_backend, Z̃_∇geq, ∇geq_context...
710+
geqfunc!, geq, ∇geq, ∇geq_prep, mpc.jacobian, Z̃_∇geq, ∇geq_context...
724711
)
725712
end
726713
return ∇geq_i .= @views ∇geq[i, :]

src/estimator/mhe/construct.jl

+34-41
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,17 @@ struct MovingHorizonEstimator{
4848
NT<:Real,
4949
SM<:SimModel,
5050
JM<:JuMP.GenericModel,
51+
GB<:AbstractADType,
52+
JB<:AbstractADType,
5153
CE<:StateEstimator,
5254
} <: StateEstimator{NT}
5355
model::SM
5456
# note: `NT` and the number type `JNT` in `JuMP.GenericModel{JNT}` can be
5557
# different since solvers that support non-Float64 are scarce.
5658
optim::JM
5759
con::EstimatorConstraint{NT}
60+
gradient::GB
61+
jacobian::JB
5862
covestim::CE
5963
::Vector{NT}
6064
lastu0::Vector{NT}
@@ -112,9 +116,16 @@ struct MovingHorizonEstimator{
112116
function MovingHorizonEstimator{NT}(
113117
model::SM,
114118
He, i_ym, nint_u, nint_ym, P̂_0, Q̂, R̂, Cwt,
115-
optim::JM, gradient, jacobian, covestim::CE;
119+
optim::JM, gradient::GB, jacobian::JB, covestim::CE;
116120
direct=true
117-
) where {NT<:Real, SM<:SimModel{NT}, JM<:JuMP.GenericModel, CE<:StateEstimator{NT}}
121+
) where {
122+
NT<:Real,
123+
SM<:SimModel{NT},
124+
JM<:JuMP.GenericModel,
125+
GB<:AbstractADType,
126+
JB<:AbstractADType,
127+
CE<:StateEstimator{NT}
128+
}
118129
nu, ny, nd = model.nu, model.ny, model.nd
119130
He < 1 && throw(ArgumentError("Estimation horizon He should be ≥ 1"))
120131
Cwt < 0 && throw(ArgumentError("Cwt weight should be ≥ 0"))
@@ -158,8 +169,10 @@ struct MovingHorizonEstimator{
158169
P̂arr_old = copy(P̂_0)
159170
Nk = [0]
160171
corrected = [false]
161-
estim = new{NT, SM, JM, CE}(
162-
model, optim, con, covestim,
172+
estim = new{NT, SM, JM, GB, JB, CE}(
173+
model, optim, con,
174+
gradient, jacobian,
175+
covestim,
163176
Z̃, lastu0, x̂op, f̂op, x̂0,
164177
He, nϵ,
165178
i_ym, nx̂, nym, nyu, nxs,
@@ -173,7 +186,7 @@ struct MovingHorizonEstimator{
173186
direct, corrected,
174187
buffer
175188
)
176-
init_optimization!(estim, model, optim, gradient, jacobian)
189+
init_optimization!(estim, model, optim)
177190
return estim
178191
end
179192
end
@@ -261,9 +274,9 @@ transcription for now.
261274
nonlinear optimizer for solving (default to [`Ipopt`](https://github.com/jump-dev/Ipopt.jl),
262275
or [`OSQP`](https://osqp.org/docs/parsers/jump.html) if `model` is a [`LinModel`](@ref)).
263276
- `gradient=AutoForwardDiff()` : an `AbstractADType` backend for the gradient of the objective
264-
function if `model` is not a [`LinModel`](@ref), see [`DifferentiationInterface` doc](@extref DifferentiationInterface List).
277+
function when `model` is not a [`LinModel`](@ref), see [`DifferentiationInterface` doc](@extref DifferentiationInterface List).
265278
- `jacobian=AutoForwardDiff()` : an `AbstractADType` backend for the Jacobian of the
266-
constraints if `model` is not a [`LinModel`](@ref), see `gradient` above for the options.
279+
constraints when `model` is not a [`LinModel`](@ref), see `gradient` above for the options.
267280
- `direct=true`: construct with a direct transmission from ``\mathbf{y^m}`` (a.k.a. current
268281
estimator, in opposition to the delayed/predictor form).
269282
@@ -1197,22 +1210,18 @@ end
11971210

11981211
"""
11991212
init_optimization!(
1200-
estim::MovingHorizonEstimator, model::SimModel, optim::JuMP.GenericModel, _ , _
1213+
estim::MovingHorizonEstimator, model::LinModel, optim::JuMP.GenericModel
12011214
)
12021215
12031216
Init the quadratic optimization of [`MovingHorizonEstimator`](@ref).
12041217
"""
12051218
function init_optimization!(
1206-
estim::MovingHorizonEstimator,
1207-
::LinModel,
1208-
optim::JuMP.GenericModel,
1209-
::AbstractADType,
1210-
::AbstractADType
1219+
estim::MovingHorizonEstimator, model::LinModel, optim::JuMP.GenericModel,
12111220
)
12121221
nZ̃ = length(estim.Z̃)
12131222
JuMP.num_variables(optim) == 0 || JuMP.empty!(optim)
12141223
JuMP.set_silent(optim)
1215-
limit_solve_time(estim.optim, estim.model.Ts)
1224+
limit_solve_time(optim, model.Ts)
12161225
@variable(optim, Z̃var[1:nZ̃])
12171226
A = estim.con.A[estim.con.i_b, :]
12181227
b = estim.con.b[estim.con.i_b]
@@ -1223,28 +1232,20 @@ end
12231232

12241233
"""
12251234
init_optimization!(
1226-
estim::MovingHorizonEstimator,
1227-
model::SimModel,
1228-
optim::JuMP.GenericModel,
1229-
gradient::AbstractADType,
1230-
jacobian::AbstractADType
1235+
estim::MovingHorizonEstimator, model::SimModel, optim::JuMP.GenericModel,
12311236
) -> nothing
12321237
12331238
Init the nonlinear optimization of [`MovingHorizonEstimator`](@ref).
12341239
"""
12351240
function init_optimization!(
1236-
estim::MovingHorizonEstimator,
1237-
model::SimModel,
1238-
optim::JuMP.GenericModel{JNT},
1239-
gradient::AbstractADType,
1240-
jacobian::AbstractADType
1241+
estim::MovingHorizonEstimator, model::SimModel, optim::JuMP.GenericModel{JNT}
12411242
) where JNT<:Real
12421243
C, con = estim.C, estim.con
12431244
nZ̃ = length(estim.Z̃)
12441245
# --- variables and linear constraints ---
12451246
JuMP.num_variables(optim) == 0 || JuMP.empty!(optim)
12461247
JuMP.set_silent(optim)
1247-
limit_solve_time(estim.optim, estim.model.Ts)
1248+
limit_solve_time(optim, model.Ts)
12481249
@variable(optim, Z̃var[1:nZ̃])
12491250
A = estim.con.A[con.i_b, :]
12501251
b = estim.con.b[con.i_b]
@@ -1258,9 +1259,7 @@ function init_optimization!(
12581259
JuMP.set_attribute(optim, "nlp_scaling_max_gradient", 10.0/C)
12591260
end
12601261
end
1261-
Jfunc, ∇Jfunc!, gfuncs, ∇gfuncs! = get_optim_functions(
1262-
estim, optim, gradient, jacobian
1263-
)
1262+
Jfunc, ∇Jfunc!, gfuncs, ∇gfuncs! = get_optim_functions(estim, optim)
12641263
@operator(optim, J, nZ̃, Jfunc, ∇Jfunc!)
12651264
@objective(optim, Min, J(Z̃var...))
12661265
nV̂, nX̂ = estim.He*estim.nym, estim.He*estim.nx̂
@@ -1301,10 +1300,7 @@ end
13011300

13021301
"""
13031302
get_optim_functions(
1304-
estim::MovingHorizonEstimator,
1305-
optim::JuMP.GenericModel,
1306-
grad_backend::AbstractADType,
1307-
jac_backend::AbstractADType
1303+
estim::MovingHorizonEstimator, optim::JuMP.GenericModel,
13081304
) -> Jfunc, ∇Jfunc!, gfuncs, ∇gfuncs!
13091305
13101306
Return the functions for the nonlinear optimization of [`MovingHorizonEstimator`](@ref).
@@ -1327,10 +1323,7 @@ This method is really intricate and I'm not proud of it. That's because of 3 ele
13271323
Inspired from: [User-defined operators with vector outputs](@extref JuMP User-defined-operators-with-vector-outputs)
13281324
"""
13291325
function get_optim_functions(
1330-
estim::MovingHorizonEstimator,
1331-
optim::JuMP.GenericModel{JNT},
1332-
grad_backend::AbstractADType,
1333-
jac_backend::AbstractADType
1326+
estim::MovingHorizonEstimator, ::JuMP.GenericModel{JNT},
13341327
) where {JNT <: Real}
13351328
# ---------- common cache for Jfunc, gfuncs called with floats ------------------------
13361329
model, con = estim.model, estim.con
@@ -1363,13 +1356,13 @@ function get_optim_functions(
13631356
Cache(g),
13641357
Cache(x̄),
13651358
)
1366-
∇J_prep = prepare_gradient(Jfunc!, grad_backend, Z̃_∇J, ∇J_context...; strict)
1359+
∇J_prep = prepare_gradient(Jfunc!, estim.gradient, Z̃_∇J, ∇J_context...; strict)
13671360
∇J = Vector{JNT}(undef, nZ̃)
13681361
∇Jfunc! = function (∇J::AbstractVector{T}, Z̃arg::Vararg{T, N}) where {N, T<:Real}
13691362
# only the multivariate syntax of JuMP.@operator, univariate is impossible for MHE
13701363
# since Z̃ comprises the arrival state estimate AND the estimated process noise
13711364
Z̃_∇J .= Z̃arg
1372-
gradient!(Jfunc!, ∇J, ∇J_prep, grad_backend, Z̃_∇J, ∇J_context...)
1365+
gradient!(Jfunc!, ∇J, ∇J_prep, estim.gradient, Z̃_∇J, ∇J_context...)
13731366
return ∇J
13741367
end
13751368

@@ -1397,17 +1390,17 @@ function get_optim_functions(
13971390
# temporarily enable all the inequality constraints for sparsity detection:
13981391
estim.con.i_g .= true
13991392
estim.Nk[] = He
1400-
∇g_prep = prepare_jacobian(gfunc!, g, jac_backend, Z̃_∇g, ∇g_context...; strict)
1393+
∇g_prep = prepare_jacobian(gfunc!, g, estim.jacobian, Z̃_∇g, ∇g_context...; strict)
14011394
estim.con.i_g .= false
14021395
estim.Nk[] = 0
1403-
∇g = init_diffmat(JNT, jac_backend, ∇g_prep, nZ̃, ng)
1396+
∇g = init_diffmat(JNT, estim.jacobian, ∇g_prep, nZ̃, ng)
14041397
∇gfuncs! = Vector{Function}(undef, ng)
14051398
for i in eachindex(∇gfuncs!)
14061399
∇gfuncs![i] = function (∇g_i, Z̃arg::Vararg{T, N}) where {N, T<:Real}
14071400
# only the multivariate syntax of JuMP.@operator, see above for the explanation
14081401
if isdifferent(Z̃arg, Z̃_∇g)
14091402
Z̃_∇g .= Z̃arg
1410-
jacobian!(gfunc!, g, ∇g, ∇g_prep, jac_backend, Z̃_∇g, ∇g_context...)
1403+
jacobian!(gfunc!, g, ∇g, ∇g_prep, estim.jacobian, Z̃_∇g, ∇g_context...)
14111404
end
14121405
return ∇g_i .= @views ∇g[i, :]
14131406
end

0 commit comments

Comments
 (0)