# Data-Parallel Multithreaded, Distributed, and Multi-GPU Batching

DiffEqFlux.jl allows for data-parallel batching on a single multithreaded
computer, across an entire compute cluster, and across GPUs. This can be
done by parallelizing within an ODE solve or across the ODE solves. The
automatic differentiation tooling is compatible with both forms of
parallelism. The following examples demonstrate training over a few
different modes of parallelism; they are not exhaustive.

## Within-ODE Multithreaded and GPU Batching

An alternative to batching across ODE solves (described below) is to
batch within the ODE itself, which can be more efficient in some cases
such as neural ODEs. With a neural network, each column of the input is
treated independently (by the properties of matrix multiplication).
Thus, for example, with `FastChain` we can define an ODE:

```julia
using DiffEqFlux, OrdinaryDiffEq

dudt = FastChain(FastDense(2,50,tanh),FastDense(50,2))
p = initial_params(dudt)
f(u,p,t) = dudt(u,p)
```

and we can solve this ODE where the initial condition is a vector:

```julia
u0 = Float32[2.; 0.]
prob = ODEProblem(f,u0,(0f0,1f0),p)
solve(prob,Tsit5())
```

or we can solve this ODE where the initial condition is a matrix, where
each column is an independent system:

```julia
u0 = Float32.([0 1 2
               0 0 0])
prob = ODEProblem(f,u0,(0f0,1f0),p)
solve(prob,Tsit5())
```

On the CPU this will multithread across the system (due to BLAS) and
on GPUs this will parallelize the operations across the GPU. To run on
the GPU, simply move the parameters and the initial condition to the
GPU:

```julia
# note: `gpu` is provided by Flux.jl and requires a CUDA-capable GPU
u0 = Float32.([0 1 2
               0 0 0])
prob = ODEProblem(f,gpu(u0),(0f0,1f0),gpu(p))
solve(prob,Tsit5())
```

This method of parallelism is optimal if all of the operations are
linear algebra operations, as in a neural ODE. This method of
parallelism is demonstrated in the [MNIST tutorial](@ref mnist).
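
As a small illustration, the following is a hedged sketch (not from the
original docs) of how the column-batched solve can feed a loss function.
Here `target` is a hypothetical matrix of desired final states, one
column per trajectory:

```julia
# Hypothetical target final states, one column per batched trajectory
target = Float32.([1 2 3
                   0 0 0])

function loss_batched(p)
    prob = ODEProblem(f, u0, (0f0, 1f0), p)
    sol = solve(prob, Tsit5())
    sum(abs2, sol[end] .- target)   # sol[end] is the 2×3 final state
end

loss_batched(p)
```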

However, this method of parallelism has limitations. First, the ODE
function must be written so that it operates independently across the
columns; not all ODEs are written like this, so one needs to be careful.
Additionally, this method is ineffective if the ODE function has many
serial operations, like `u[1]*u[2] - u[3]`. In such a case, the indexing
behavior will dominate the runtime and the parallelism can sometimes
even be detrimental.
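
For example, the following is a hedged illustration (not from the
original docs) of an ODE written with scalar indexing. It assumes `u` is
a vector, so simply widening `u0` to a matrix would compute the wrong
thing, and per-element operations like these see little benefit from
BLAS or GPU parallelism:

```julia
# Hedged sketch: a Lotka-Volterra system written with scalar indexing.
# This is not column-independent, so the within-ODE batching trick above
# does not apply; the ensemble approach below is a better fit.
function lotka!(du, u, p, t)
    du[1] =  p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
end
```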

# Out of ODE Parallelism

Instead of parallelizing within an ODE solve, one can parallelize across
the ODE solves themselves. While this will be less effective on very
large ODEs, like big neural ODE image classifiers, this method can be
effective even if the ODE is small or the `f` function is not
well-parallelized. This kind of parallelism is done via the
[DifferentialEquations.jl ensemble interface](https://diffeq.sciml.ai/stable/features/ensemble/).
The following examples showcase multithreaded, cluster, and (multi)GPU
parallelism through this interface.

## Multithreaded Batching At a Glance

The following is a full copy-paste example for multithreading.
Distributed and GPU minibatching are described below.

```julia
using OrdinaryDiffEq, DiffEqSensitivity, DiffEqFlux
pa = [1.0]
u0 = [3.0]
θ = [u0;pa]

function model1(θ,ensemble)
    prob = ODEProblem((u, p, t) -> 1.01u .* p, [θ[1]], (0.0, 1.0), [θ[2]])

    function prob_func(prob, i, repeat)
        remake(prob, u0 = 0.5 .+ i/100 .* prob.u0)
    end

    ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
    sim = solve(ensemble_prob, Tsit5(), ensemble, saveat = 0.1, trajectories = 100)
end

# loss functions, differing only in the ensemble algorithm
loss_serial(θ) = sum(abs2,1.0.-Array(model1(θ,EnsembleSerial())))
loss_threaded(θ) = sum(abs2,1.0.-Array(model1(θ,EnsembleThreads())))

cb = function (θ,l) # callback function to observe training
    @show l
    false
end

opt = ADAM(0.1)
l1 = loss_serial(θ)
res_serial = DiffEqFlux.sciml_train(loss_serial, θ, opt; cb = cb, maxiters=100)
res_threads = DiffEqFlux.sciml_train(loss_threaded, θ, opt; cb = cb, maxiters=100)
```

## Multithreaded Batching In-Depth

In order to make use of the ensemble interface, we need to build an
`EnsembleProblem`. The `prob_func` is the function that determines
the different `DEProblem`s to solve. This is the place where we can
randomly sample initial conditions or pull initial conditions from
an array of batches in order to perform our study. To do this, we
first define a prototype `DEProblem`. Here we use the following
`ODEProblem` as our base:

```julia
prob = ODEProblem((u, p, t) -> 1.01u .* p, [θ[1]], (0.0, 1.0), [θ[2]])
```

In the `prob_func` we define how to build a new problem based on the
base problem. In this case, we want to change `u0` by a constant, i.e.
`0.5 .+ i/100 .* prob.u0` for different trajectories labelled by `i`.
Thus we use the [remake function from the problem interface](https://diffeq.sciml.ai/stable/basics/problem/#Modification-of-problem-types) to do so:

```julia
function prob_func(prob, i, repeat)
    remake(prob, u0 = 0.5 .+ i/100 .* prob.u0)
end
```
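
Alternatively, as mentioned above, the initial conditions could be
pulled from a pre-built array of batches. A hedged sketch (the `u0s`
array here is hypothetical):

```julia
# Hedged sketch: pull the i-th initial condition from a hypothetical
# pre-built array of batched initial conditions instead of perturbing u0.
u0s = [rand(1) .+ 2.0 for i in 1:100]
function prob_func_data(prob, i, repeat)
    remake(prob, u0 = u0s[i])
end
```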

We now build the `EnsembleProblem` with this basis:

```julia
ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
```

Now to solve an ensemble problem, we need to choose an ensembling
algorithm and the number of trajectories to solve. Here let's
solve this in serial with 100 trajectories. Note that `i` will thus run
from `1:100`.

```julia
sim = solve(ensemble_prob, Tsit5(), EnsembleSerial(), saveat = 0.1, trajectories = 100)
```

and running it with multithreading would simply be:

```julia
sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), saveat = 0.1, trajectories = 100)
```

This whole mechanism is differentiable, so we then put it in a training
loop and it soars. Note that you need to make sure that [Julia's multithreading](https://docs.julialang.org/en/v1/manual/multi-threading/)
is enabled. You can check the number of available threads via:

```julia
Threads.nthreads()
```

Multithreading is enabled by starting Julia with multiple threads, e.g.
`julia --threads=4`, or by setting the `JULIA_NUM_THREADS` environment
variable before launching.

## Distributed Batching Across a Cluster

Changing to distributed computing is very simple as well. The setup is
all the same, except you utilize `EnsembleDistributed` as the ensemble
algorithm:

```julia
sim = solve(ensemble_prob, Tsit5(), EnsembleDistributed(), saveat = 0.1, trajectories = 100)
```

Note that for this to work you need to ensure that your processes are
already started. For more information on setting up processes and utilizing
a compute cluster, see [the official distributed documentation](https://docs.julialang.org/en/v1/manual/distributed-computing/).
The key point to recognize is that, due to the message passing required
for cluster compute, one needs to ensure that all of the required
functions are defined on the worker processes. The following is a full
example of a distributed batching setup:

```julia
using Distributed
addprocs(4)

# make sure the packages and the ODE function are defined on all workers
@everywhere begin
    using OrdinaryDiffEq, DiffEqSensitivity, Flux, DiffEqFlux
    function f(u,p,t)
        1.01u .* p
    end
end

pa = [1.0]
u0 = [3.0]
θ = [u0;pa]

function model1(θ,ensemble)
    prob = ODEProblem(f, [θ[1]], (0.0, 1.0), [θ[2]])

    function prob_func(prob, i, repeat)
        remake(prob, u0 = 0.5 .+ i/100 .* prob.u0)
    end

    ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
    sim = solve(ensemble_prob, Tsit5(), ensemble, saveat = 0.1, trajectories = 100)
end

cb = function (θ,l) # callback function to observe training
    @show l
    false
end

opt = ADAM(0.1)
loss_distributed(θ) = sum(abs2,1.0.-Array(model1(θ,EnsembleDistributed())))
l1 = loss_distributed(θ)
res_distributed = DiffEqFlux.sciml_train(loss_distributed, θ, opt; cb = cb, maxiters=100)
```

And note that only `addprocs(4)` needs to be changed in order to make
this demo run across a cluster. For more information on adding processes
to a cluster, check out [ClusterManagers.jl](https://github.com/JuliaParallel/ClusterManagers.jl).
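
For example, a hedged sketch (assuming a Slurm cluster and that
ClusterManagers.jl is installed) of replacing the local `addprocs(4)`
with cluster-managed workers:

```julia
# Hedged sketch: request 100 workers from a Slurm allocation instead of
# spawning local processes. SlurmManager comes from ClusterManagers.jl.
using Distributed, ClusterManagers
addprocs(SlurmManager(100), exeflags = "--project")
```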

## Minibatching Across GPUs with DiffEqGPU

DiffEqGPU.jl allows for generating code that parallelizes an ensemble
across generated CUDA kernels. This method is efficient for sufficiently
small problems (fewer than ~100 equations) where the significant
computational cost is due to the large number of batch trajectories that
need to be solved. This kernel-building process adds a few restrictions
to the function, such as requiring that it has no bounds checking or
allocations. The following is an example of minibatch ensemble
parallelism across a GPU:

```julia
using OrdinaryDiffEq, DiffEqSensitivity, DiffEqGPU, Flux, DiffEqFlux
function f(du,u,p,t)
    @inbounds begin
        du[1] = 1.01 * u[1] * p[1]
    end
end

pa = [1.0]
u0 = [3.0]
θ = [u0;pa]

function model1(θ,ensemble)
    prob = ODEProblem(f, [θ[1]], (0.0, 1.0), [θ[2]])

    function prob_func(prob, i, repeat)
        remake(prob, u0 = 0.5 .+ i/100 .* prob.u0)
    end

    ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
    sim = solve(ensemble_prob, Tsit5(), ensemble, saveat = 0.1, trajectories = 100)
end

cb = function (θ,l) # callback function to observe training
    @show l
    false
end

opt = ADAM(0.1)
loss_gpu(θ) = sum(abs2,1.0.-Array(model1(θ,EnsembleGPUArray()))) # EnsembleGPUArray comes from DiffEqGPU
l1 = loss_gpu(θ)
res_gpu = DiffEqFlux.sciml_train(loss_gpu, θ, opt; cb = cb, maxiters=100)
```

## Multi-GPU Batching

DiffEqGPU supports batching across multiple GPUs. See [its README](https://github.com/SciML/DiffEqGPU.jl#setting-up-multi-gpu)
for details on setting it up.
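
As a rough sketch (assumptions: one worker process per GPU and CUDA.jl
installed; the README is the authoritative reference), the setup amounts
to adding one process per GPU and pinning each worker to its own device
before running the ensemble solve:

```julia
# Hedged sketch: pin one worker process to each GPU before running the
# EnsembleGPUArray solve above. Assumes 2 GPUs are available.
using Distributed
addprocs(2)
@everywhere using OrdinaryDiffEq, DiffEqGPU, CUDA

# assign each worker its own CUDA device
asyncmap(collect(zip(workers(), CUDA.devices()))) do (p, d)
    remotecall_wait(CUDA.device!, p, d)
end
```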