Skip to content

Commit 64cebf2

Browse files
committed
RK4 now support supersample and AD
1 parent 6e92824 commit 64cebf2

File tree

1 file changed

+27
-17
lines changed

1 file changed

+27
-17
lines changed

Diff for: src/model/solver.jl

+27-17
Original file line numberDiff line numberDiff line change
@@ -38,24 +38,34 @@ function RungeKutta(order::Int=4; supersample::Int=1)
3838
return RungeKutta(order, supersample)
3939
end
4040

41-
function get_solver_functions(NT::DataType, ::RungeKutta, f!, h!, Ts, _ , nx, _ , _ )
42-
f! = let fc! = f!, Ts=Ts, nx=nx
43-
# k1::DiffCache{Vector{NT}, Vector{NT}} = DiffCache(zeros(NT, nx), Nc)
44-
k1 = zeros(NT, nx)
45-
k2 = zeros(NT, nx)
46-
k3 = zeros(NT, nx)
47-
k4 = zeros(NT, nx)
41+
function get_solver_functions(NT::DataType, solver::RungeKutta, f!, h!, Ts, _ , nx, _ , _ )
42+
f! = let fc! = f!, Ts=(Ts/solver.supersample), nx=nx
43+
xcur_cache::DiffCache{Vector{NT}, Vector{NT}} = DiffCache(zeros(NT, nx))
44+
k1_cache::DiffCache{Vector{NT}, Vector{NT}} = DiffCache(zeros(NT, nx))
45+
k2_cache::DiffCache{Vector{NT}, Vector{NT}} = DiffCache(zeros(NT, nx))
46+
k3_cache::DiffCache{Vector{NT}, Vector{NT}} = DiffCache(zeros(NT, nx))
47+
k4_cache::DiffCache{Vector{NT}, Vector{NT}} = DiffCache(zeros(NT, nx))
4848
f! = function inner_solver(xnext, x, u, d)
49-
xterm = xnext
50-
@. xterm = x
51-
fc!(k1, xterm, u, d)
52-
@. xterm = x + k1 * Ts/2
53-
fc!(k2, xterm, u, d)
54-
@. xterm = x + k2 * Ts/2
55-
fc!(k3, xterm, u, d)
56-
@. xterm = x + k3 * Ts
57-
fc!(k4, xterm, u, d)
58-
@. xnext = x + (k1 + 2k2 + 2k3 + k4)*Ts/6
49+
x1 = x[begin]
50+
xcur = get_tmp(xcur_cache, x1)
51+
k1 = get_tmp(k1_cache, x1)
52+
k2 = get_tmp(k2_cache, x1)
53+
k3 = get_tmp(k3_cache, x1)
54+
k4 = get_tmp(k4_cache, x1)
55+
xcur .= x
56+
for i=1:solver.supersample
57+
xterm = xnext
58+
@. xterm = xcur
59+
fc!(k1, xterm, u, d)
60+
@. xterm = xcur + k1 * Ts/2
61+
fc!(k2, xterm, u, d)
62+
@. xterm = xcur + k2 * Ts/2
63+
fc!(k3, xterm, u, d)
64+
@. xterm = xcur + k3 * Ts
65+
fc!(k4, xterm, u, d)
66+
@. xnext = xcur + (k1 + 2k2 + 2k3 + k4)*Ts/6
67+
@. xcur = xnext
68+
end
5969
return nothing
6070
end
6171
end

0 commit comments

Comments
 (0)