Skip to content

Commit 478688f

Browse files
committed
added: DI backend as NonLinModel parameter in the struct
1 parent 0a2c63f commit 478688f

File tree

2 files changed

+37
-21
lines changed

2 files changed

+37
-21
lines changed

src/model/linearization.jl

+12-11
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
11
"""
2-
get_linearization_func(NT, f!, h!, nu, nx, ny, nd, p) -> linfunc!
2+
get_linearization_func(NT, f!, h!, nu, nx, ny, nd, p, backend) -> linfunc!
33
44
Return the `linfunc!` function that computes the Jacobians of `f!` and `h!` functions.
55
66
The function has the following signature:
77
```
8-
linfunc!(xnext, y, A, Bu, C, Bd, Dd, x, u, d, cst_x, cst_u, cst_d) -> nothing
8+
linfunc!(xnext, y, A, Bu, C, Bd, Dd, backend, x, u, d, cst_x, cst_u, cst_d) -> nothing
99
```
10-
and it should modifies in-place all the arguments before `x`. The `cst_x`, `cst_u`, `cst_d`
11-
and are `DifferentiationInterface.Constant` objects with the linearization points.
10+
and it should modifies in-place all the arguments before `backend`. The `backend` argument
11+
is an `AbstractADType` backend from `DifferentiationInterface`. The `cst_x`, `cst_u` and
12+
`cst_d` are `DifferentiationInterface.Constant` objects with the linearization points.
1213
"""
13-
function get_linearization_func(NT, f!, h!, nu, nx, ny, nd, p)
14+
function get_linearization_func(NT, f!, h!, nu, nx, ny, nd, p, backend)
1415
f_x!(xnext, x, u, d) = f!(xnext, x, u, d, p)
1516
f_u!(xnext, u, x, d) = f!(xnext, x, u, d, p)
1617
f_d!(xnext, d, x, u) = f!(xnext, x, u, d, p)
1718
h_x!(y, x, d) = h!(y, x, d, p)
1819
h_d!(y, d, x) = h!(y, x, d, p)
19-
backend = AutoForwardDiff()
2020
strict = Val(true)
2121
xnext = zeros(NT, nx)
2222
y = zeros(NT, ny)
@@ -31,7 +31,7 @@ function get_linearization_func(NT, f!, h!, nu, nx, ny, nd, p)
3131
Bd_prep = prepare_jacobian(f_d!, xnext, backend, d, cst_x, cst_u; strict)
3232
C_prep = prepare_jacobian(h_x!, y, backend, x, cst_d ; strict)
3333
Dd_prep = prepare_jacobian(h_d!, y, backend, d, cst_x ; strict)
34-
function linfunc!(xnext, y, A, Bu, C, Bd, Dd, x, u, d, cst_x, cst_u, cst_d)
34+
function linfunc!(xnext, y, A, Bu, C, Bd, Dd, backend, x, u, d, cst_x, cst_u, cst_d)
3535
# all the arguments before `x` are mutated in this function
3636
jacobian!(f_x!, xnext, A, A_prep, backend, x, cst_u, cst_d)
3737
jacobian!(f_u!, xnext, Bu, Bu_prep, backend, u, cst_x, cst_d)
@@ -183,10 +183,11 @@ end
183183
function linearize_core!(linmodel::LinModel, model::SimModel, x0, u0, d0)
184184
xnext0, y0 = linmodel.buffer.x, linmodel.buffer.y
185185
A, Bu, C, Bd, Dd = linmodel.A, linmodel.Bu, linmodel.C, linmodel.Bd, linmodel.Dd
186-
cx = Constant(x0)
187-
cu = Constant(u0)
188-
cd = Constant(d0)
189-
model.linfunc!(xnext0, y0, A, Bu, C, Bd, Dd, x0, u0, d0, cx, cu, cd)
186+
cst_x = Constant(x0)
187+
cst_u = Constant(u0)
188+
cst_d = Constant(d0)
189+
backend = model.jacobian
190+
model.linfunc!(xnext0, y0, A, Bu, C, Bd, Dd, backend, x0, u0, d0, cst_x, cst_u, cst_d)
190191
return nothing
191192
end
192193

src/model/nonlinmodel.jl

+25-10
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
struct NonLinModel{
2-
NT<:Real, F<:Function, H<:Function, PT<:Any, DS<:DiffSolver, LF<:Function
2+
NT<:Real,
3+
F<:Function,
4+
H<:Function,
5+
PT<:Any,
6+
DS<:DiffSolver,
7+
JB<:AbstractADType,
8+
LF<:Function
39
} <: SimModel{NT}
410
x0::Vector{NT}
511
f!::F
@@ -21,11 +27,20 @@ struct NonLinModel{
2127
yname::Vector{String}
2228
dname::Vector{String}
2329
xname::Vector{String}
30+
jacobian::JB
2431
linfunc!::LF
2532
buffer::SimModelBuffer{NT}
2633
function NonLinModel{NT}(
27-
f!::F, h!::H, Ts, nu, nx, ny, nd, p::PT, solver::DS, linfunc!::LF
28-
) where {NT<:Real, F<:Function, H<:Function, PT<:Any, DS<:DiffSolver, LF<:Function}
34+
f!::F, h!::H, Ts, nu, nx, ny, nd, p::PT, solver::DS, jacobian::JB, linfunc!::LF
35+
) where {
36+
NT<:Real,
37+
F<:Function,
38+
H<:Function,
39+
PT<:Any,
40+
DS<:DiffSolver,
41+
JB<:AbstractADType,
42+
LF<:Function
43+
}
2944
Ts > 0 || error("Sampling time Ts must be positive")
3045
uop = zeros(NT, nu)
3146
yop = zeros(NT, ny)
@@ -39,7 +54,7 @@ struct NonLinModel{
3954
x0 = zeros(NT, nx)
4055
t = zeros(NT, 1)
4156
buffer = SimModelBuffer{NT}(nu, nx, ny, nd)
42-
return new{NT, F, H, PT, DS, LF}(
57+
return new{NT, F, H, PT, DS, JB, LF}(
4358
x0,
4459
f!, h!,
4560
p,
@@ -48,7 +63,7 @@ struct NonLinModel{
4863
nu, nx, ny, nd,
4964
uop, yop, dop, xop, fop,
5065
uname, yname, dname, xname,
51-
linfunc!,
66+
jacobian, linfunc!,
5267
buffer
5368
)
5469
end
@@ -143,20 +158,20 @@ NonLinModel with a sample time Ts = 2.0 s, empty solver and:
143158
"""
144159
function NonLinModel{NT}(
145160
f::Function, h::Function, Ts::Real, nu::Int, nx::Int, ny::Int, nd::Int=0;
146-
p=NT[], solver=RungeKutta(4)
161+
p=NT[], solver=RungeKutta(4), jacobian=AutoForwardDiff()
147162
) where {NT<:Real}
148163
isnothing(solver) && (solver=EmptySolver())
149164
f!, h! = get_mutating_functions(NT, f, h)
150165
f!, h! = get_solver_functions(NT, solver, f!, h!, Ts, nu, nx, ny, nd)
151-
linfunc! = get_linearization_func(NT, f!, h!, nu, nx, ny, nd, p)
152-
return NonLinModel{NT}(f!, h!, Ts, nu, nx, ny, nd, p, solver, linfunc!)
166+
linfunc! = get_linearization_func(NT, f!, h!, nu, nx, ny, nd, p, jacobian)
167+
return NonLinModel{NT}(f!, h!, Ts, nu, nx, ny, nd, p, solver, jacobian, linfunc!)
153168
end
154169

155170
function NonLinModel(
156171
f::Function, h::Function, Ts::Real, nu::Int, nx::Int, ny::Int, nd::Int=0;
157-
p=Float64[], solver=RungeKutta(4)
172+
p=Float64[], solver=RungeKutta(4), jacobian=AutoForwardDiff()
158173
)
159-
return NonLinModel{Float64}(f, h, Ts, nu, nx, ny, nd; p, solver)
174+
return NonLinModel{Float64}(f, h, Ts, nu, nx, ny, nd; p, solver, jacobian)
160175
end
161176

162177
"Get the mutating functions `f!` and `h!` from the provided functions in argument."

0 commit comments

Comments
 (0)