
Commit c98db2a

add parallelism tutorial (#465)
1 parent f215d96 commit c98db2a

3 files changed: +279 -3 lines changed

Diff for: docs/make.jl

+4-2
```diff
@@ -18,7 +18,7 @@ makedocs(
     "examples/lotka_volterra.md",
     "examples/delay_diffeq.md",
     "examples/pde_constrained.md",
-  ],
+  ],
   "Neural ODE and SDE Tutorials" => Any[
     "examples/neural_ode_sciml.md",
     "examples/neural_ode_flux.md",
@@ -27,7 +27,8 @@ makedocs(
     "examples/augmented_neural_ode.md",
     "examples/collocation.md",
     "examples/neural_gde.md",
-    "examples/normalizing_flows.md"],
+    "examples/normalizing_flows.md"
+  ],
   "Bayesian Estimation Tutorials" => Any[
     "examples/turing_bayesian.md",
     "examples/BayesianNODE_NUTS.md",
@@ -36,6 +37,7 @@ makedocs(
   "FAQ, Tips, and Tricks" => Any[
     "examples/local_minima.md",
     "examples/multiple_nn.md",
+    "examples/data_parallel.md",
     "examples/second_order_neural.md",
     "examples/second_order_adjoints.md",
     "examples/minibatch.md",
```

Diff for: docs/src/examples/data_parallel.md

+274
@@ -0,0 +1,274 @@
# Data-Parallel Multithreaded, Distributed, and Multi-GPU Batching

DiffEqFlux.jl allows for data-parallel batching on a single computer,
across an entire compute cluster, and across GPUs. This can be done by
parallelizing within an ODE solve or across the ODE solves. The automatic
differentiation tooling is compatible with both forms of parallelism.
The following examples demonstrate training over a few different modes
of parallelism. These examples are not exhaustive.
## Within-ODE Multithreaded and GPU Batching

First, note that there is a way of batching within a single ODE solve
which can be more efficient in some cases, such as neural ODEs. With a
neural network, columns are treated independently (by the properties of
matrix multiplication). Thus, for example, with `FastChain` we can
define an ODE:
```julia
using DiffEqFlux, OrdinaryDiffEq

dudt = FastChain(FastDense(2,50,tanh),FastDense(50,2))
p = initial_params(dudt)
f(u,p,t) = dudt(u,p)
```
and we can solve this ODE where the initial condition is a vector:

```julia
u0 = Float32[2.; 0.]
prob = ODEProblem(f,u0,(0f0,1f0),p)
solve(prob,Tsit5())
```
or we can solve this ODE where the initial condition is a matrix, where
each column is an independent system:

```julia
u0 = Float32.([0 1 2
               0 0 0])
prob = ODEProblem(f,u0,(0f0,1f0),p)
solve(prob,Tsit5())
```
On the CPU this will multithread across the system (due to BLAS), and
on GPUs this will parallelize the operations across the GPU. To run this
on the GPU, simply move the parameters and the initial condition to the
GPU:

```julia
using Flux: gpu  # assumes Flux and a working CUDA setup are available for `gpu`

u0 = Float32.([0 1 2
               0 0 0])
prob = ODEProblem(f,gpu(u0),(0f0,1f0),gpu(p))
solve(prob,Tsit5())
```
This method of parallelism is optimal if all of the operations are
linear algebra operations, as in a neural ODE. This is the method of
parallelism demonstrated in the [MNIST tutorial](@ref mnist).

However, this method of parallelism has limitations. First of all,
the ODE function is required to be written in a way that is independent
across the columns. Not all ODEs are written like this, so one needs to
be careful. Additionally, this method is ineffective if the ODE
function has many serial scalar operations, like `u[1]*u[2] - u[3]`. In
such a case, the indexing behavior will dominate the runtime and can
make the parallelism detrimental.
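
For instance (a hypothetical illustration, not part of the tutorial above), a Lotka-Volterra right-hand side written with explicit scalar indexing assumes `u` is a length-2 vector and mixes its components, so stacking initial conditions as columns would not give independent per-column systems, and the scalar indexing itself would dominate the cost:

```julia
# Hypothetical example of an ODE function that is NOT column-independent:
# it indexes and mixes specific components of `u`, so column-wise batching
# of initial conditions does not apply.
function lotka!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = α*u[1] - β*u[1]*u[2]   # prey
    du[2] = δ*u[1]*u[2] - γ*u[2]   # predator
end
```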

# Out of ODE Parallelism

Instead of parallelizing within an ODE solve, one can parallelize across
the ODE solves themselves. While this will be less effective on very
large ODEs, like big neural ODE image classifiers, this method can be effective
even if the ODE is small or the `f` function is not well-parallelized.
This kind of parallelism is done via the [DifferentialEquations.jl ensemble interface](https://diffeq.sciml.ai/stable/features/ensemble/). The following examples
showcase multithreaded, cluster, and (multi-)GPU parallelism through this
interface.
## Multithreaded Batching At a Glance

The following is a full copy-paste example of multithreaded batching.
Distributed and GPU minibatching are described below.
```julia
using OrdinaryDiffEq, DiffEqSensitivity, DiffEqFlux
pa = [1.0]
u0 = [3.0]
θ = [u0;pa]

function model1(θ,ensemble)
    prob = ODEProblem((u, p, t) -> 1.01u .* p, [θ[1]], (0.0, 1.0), [θ[2]])

    function prob_func(prob, i, repeat)
        remake(prob, u0 = 0.5 .+ i/100 .* prob.u0)
    end

    ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
    sim = solve(ensemble_prob, Tsit5(), ensemble, saveat = 0.1, trajectories = 100)
end

# loss functions
loss_serial(θ) = sum(abs2,1.0.-Array(model1(θ,EnsembleSerial())))
loss_threaded(θ) = sum(abs2,1.0.-Array(model1(θ,EnsembleThreads())))

cb = function (θ,l) # callback function to observe training
    @show l
    false
end

opt = ADAM(0.1)
l1 = loss_serial(θ)
res_serial = DiffEqFlux.sciml_train(loss_serial, θ, opt; cb = cb, maxiters=100)
res_threads = DiffEqFlux.sciml_train(loss_threaded, θ, opt; cb = cb, maxiters=100)
```

## Multithreaded Batching In-Depth

In order to make use of the ensemble interface, we need to build an
`EnsembleProblem`. The `prob_func` is the function for determining
the different `DEProblem`s to solve. This is the place where we can
randomly sample initial conditions or pull initial conditions from
an array of batches in order to perform our study. To do this, we
first define a prototype `DEProblem`. Here we use the following
`ODEProblem` as our base:

```julia
prob = ODEProblem((u, p, t) -> 1.01u .* p, [θ[1]], (0.0, 1.0), [θ[2]])
```
In the `prob_func` we define how to build a new problem based on the
base problem. In this case, we want to modify `u0`, setting it to
`0.5 .+ i/100 .* prob.u0` for the trajectory labelled by `i`.
Thus we use the [remake function from the problem interface](https://diffeq.sciml.ai/stable/basics/problem/#Modification-of-problem-types) to do so:

```julia
function prob_func(prob, i, repeat)
    remake(prob, u0 = 0.5 .+ i/100 .* prob.u0)
end
```
We now build the `EnsembleProblem` from this base problem:

```julia
ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
```

Now to solve an ensemble problem, we need to choose an ensembling
algorithm and choose the number of trajectories to solve. Here let's
solve this in serial with 100 trajectories. Note that `i` will thus run
from `1:100`.

```julia
sim = solve(ensemble_prob, Tsit5(), EnsembleSerial(), saveat = 0.1, trajectories = 100)
```
and running with multithreading would be:

```julia
sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), saveat = 0.1, trajectories = 100)
```

This whole mechanism is differentiable, so we can then put it in a training
loop and it soars. Note that you need to make sure that [Julia's multithreading](https://docs.julialang.org/en/v1/manual/multi-threading/)
is enabled (the thread count is set at startup, e.g. via the
`JULIA_NUM_THREADS` environment variable or the `--threads` command-line
flag), which you can check via:

```julia
Threads.nthreads()
```

## Distributed Batching Across a Cluster

Changing to distributed computing is very simple as well. The setup is
all the same, except you utilize `EnsembleDistributed` as the ensemble
algorithm:

```julia
sim = solve(ensemble_prob, Tsit5(), EnsembleDistributed(), saveat = 0.1, trajectories = 100)
```

Note that for this to work you need to ensure that your worker processes are
already started. For more information on setting up processes and utilizing
a compute cluster, see [the official distributed documentation](https://docs.julialang.org/en/v1/manual/distributed-computing/). The key point to recognize is that, due to
the message passing required for cluster computing, one needs to ensure
that all of the required functions are defined on the worker processes.
The following is a full example of a distributed batching setup:

```julia
using Distributed
addprocs(4)

@everywhere begin
    using OrdinaryDiffEq, DiffEqSensitivity, Flux, DiffEqFlux
    function f(u,p,t)
        1.01u .* p
    end
end

pa = [1.0]
u0 = [3.0]
θ = [u0;pa]

function model1(θ,ensemble)
    prob = ODEProblem(f, [θ[1]], (0.0, 1.0), [θ[2]])

    function prob_func(prob, i, repeat)
        remake(prob, u0 = 0.5 .+ i/100 .* prob.u0)
    end

    ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
    sim = solve(ensemble_prob, Tsit5(), ensemble, saveat = 0.1, trajectories = 100)
end

cb = function (θ,l) # callback function to observe training
    @show l
    false
end

opt = ADAM(0.1)
loss_distributed(θ) = sum(abs2,1.0.-Array(model1(θ,EnsembleDistributed())))
l1 = loss_distributed(θ)
res_distributed = DiffEqFlux.sciml_train(loss_distributed, θ, opt; cb = cb, maxiters=100)
```

And note that only `addprocs(4)` needs to be changed in order to make
this demo run across a cluster. For more information on adding processes
to a cluster, check out [ClusterManagers.jl](https://github.com/JuliaParallel/ClusterManagers.jl).
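
For instance, on a Slurm-managed cluster, the worker setup might look roughly like the following sketch, which assumes ClusterManagers.jl's `addprocs_slurm`; the right manager call and its options depend on your scheduler and cluster configuration:

```julia
# A hypothetical sketch for a Slurm cluster; substitute the ClusterManagers.jl
# manager appropriate for your scheduler (PBS, SGE, LSF, ...).
using Distributed, ClusterManagers

addprocs_slurm(32)   # request 32 worker processes through Slurm instead of addprocs(4)

# the rest of the distributed example above stays the same
```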

## Minibatching Across GPUs with DiffEqGPU

DiffEqGPU.jl parallelizes an ensemble by compiling it into generated
CUDA kernels. This method is efficient for sufficiently small problems
(fewer than ~100 equations) where the significant computational cost
is due to the large number of batch trajectories that need to be
solved. The kernel-building process adds a few restrictions to the ODE
function, such as requiring that it has no bounds checking or allocations.
The following is an example of minibatch ensemble parallelism across
a GPU:
```julia
using OrdinaryDiffEq, DiffEqSensitivity, Flux, DiffEqFlux, DiffEqGPU
# in-place, kernel-compatible form of du = 1.01u .* p
function f(du,u,p,t)
    @inbounds begin
        du[1] = 1.01 * u[1] * p[1]
    end
end

pa = [1.0]
u0 = [3.0]
θ = [u0;pa]

function model1(θ,ensemble)
    prob = ODEProblem(f, [θ[1]], (0.0, 1.0), [θ[2]])

    function prob_func(prob, i, repeat)
        remake(prob, u0 = 0.5 .+ i/100 .* prob.u0)
    end

    ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
    sim = solve(ensemble_prob, Tsit5(), ensemble, saveat = 0.1, trajectories = 100)
end

cb = function (θ,l) # callback function to observe training
    @show l
    false
end

opt = ADAM(0.1)
loss_gpu(θ) = sum(abs2,1.0.-Array(model1(θ,EnsembleGPUArray())))
l1 = loss_gpu(θ)
res_gpu = DiffEqFlux.sciml_train(loss_gpu, θ, opt; cb = cb, maxiters=100)
```

## Multi-GPU Batching

DiffEqGPU supports batching across multiple GPUs. See [its README](https://github.com/SciML/DiffEqGPU.jl#setting-up-multi-gpu)
for details on setting it up.
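
The general pattern is one worker process per GPU, with each worker pinned to its own device. The following is only a rough, hypothetical sketch assuming a CUDA.jl-based setup; the README linked above is the authoritative reference:

```julia
# Rough sketch (assumes CUDA.jl): one worker process per visible GPU,
# each pinned to its own device before running the ensemble solve.
using Distributed, CUDA

addprocs(length(CUDA.devices()))
@everywhere using CUDA

# pin the i-th worker to the i-th GPU
asyncmap(collect(zip(workers(), CUDA.devices()))) do (p, d)
    remotecall_wait(CUDA.device!, p, d)
end
```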

Diff for: docs/src/examples/mnist_neural_ode.md

+1-1
```diff
@@ -1,4 +1,4 @@
-# GPU-based MNIST Neural ODE Classifier
+# [GPU-based MNIST Neural ODE Classifier](@id mnist)

 Training a classifier for **MNIST** using a neural ordinary differential equation **NN-ODE**
 on **GPUs** with **Minibatching**.
```

0 commit comments
