
Commit e7a14c5

Merge pull request #24 from pzimbrod/23-Implement-DeepONet
🆕 initial DeepONet implementation
2 parents ac0fd84 + 7f6de3f commit e7a14c5


8 files changed, +289 -10 lines changed


README.md

Lines changed: 32 additions & 5 deletions
````diff
@@ -12,9 +12,9 @@
 
 A Package that provides Layers for the learning of (nonlinear) operators in order to solve parametric PDEs.
 
-For now, this package contains the Fourier Neural Operator originally proposed by Li et al.
+For now, this package contains the Fourier Neural Operator originally proposed by Li et al. [1] as well as the DeepONet conceived by Lu et al. [2].
 
-I decided to implement this method in Julia because coding up a layer using PyTorch in Python is rather cumbersome in comparison and Julia as a whole simply runs at comparable or faster speed than Python. Please do check out the [original work](https://github.com/zongyi-li/fourier_neural_operator) at GitHub as well.
+I decided to implement this method in Julia because coding up a layer using PyTorch in Python is rather cumbersome in comparison and Julia as a whole simply runs at comparable or faster speed than Python.
 
 The implementation of the layers is influenced heavily by the basic layers provided in the [Flux.jl](https://github.com/FluxML/Flux.jl) package.
 
@@ -28,6 +28,8 @@ pkg> add OperatorLearning
 
 ## Usage/Examples
 
+### Fourier Layer
+
 The basic workflow is more or less in line with the layer architectures that `Flux` provides, i.e. you construct individual layers, chain them if desired and pass the inputs as arguments to the layers.
 
 The Fourier Layer performs a linear transform as well as convolution (linear transform in fourier space), adds them and passes it through the activation.
@@ -47,11 +49,34 @@ model = FourierLayer(101, 101, 100, 16, σ)
 model = FourierLayer(101, 101, 100, 16, σ; bias_fourier=false)
 ```
 
-To see a full implementation, check the Burgers equation example at `examples/burgers.jl`.
+To see a full implementation, check the Burgers equation example at `examples/burgers_FNO.jl`.
 Compared to the original implementation by [Li et al.](https://github.com/zongyi-li/fourier_neural_operator/blob/master/fourier_1d.py) using PyTorch, this version written in Julia clocks in about 20 - 25% faster when running on a NVIDIA RTX A5000 GPU.
 
 If you'd like to replicate the example, you need to get the dataset for learning the Burgers equation. You can get it [here](https://drive.google.com/drive/folders/1UnbQh2WWc6knEHbLn-ZaXrKUZhp7pjt-) or alternatively use the provided [scripts](https://github.com/zongyi-li/fourier_neural_operator/tree/master/data_generation/burgers).
 
+### DeepONet
+
+The `DeepONet` constructor sets up two separate Flux `Chain` structs, the branch and trunk net, and combines their outputs into a single array via an einsum/dot product.
+
+You can either set up a "vanilla" DeepONet via the constructor function, which creates `Dense` layers for you, or, if you feel fancy, pass two Chains directly to the function so you can use other architectures such as CNNs or RNNs as well.
+The former takes two tuples that describe each architecture, e.g. `(32,64,72)` sets up a DNN with 32 neurons in the first, 64 in the second and 72 in the last layer.
+
+```julia
+using OperatorLearning
+using Flux
+
+# Create a DeepONet with branch 32 -> 64 -> 72 and sigmoid activation
+# and trunk 24 -> 64 -> 72 and tanh activation without biases
+model = DeepONet((32,64,72), (24,64,72), σ, tanh; init_branch=Flux.glorot_normal, bias_trunk=false)
+
+# Alternatively, set up your own nets altogether and pass them to DeepONet
+branch = Chain(Dense(2,128),Dense(128,64),Dense(64,72))
+trunk = Chain(Dense(1,24),Dense(24,72))
+model = DeepONet(branch,trunk)
+```
+
+For usage, check the Burgers equation example at `examples/burgers_DeepONet.jl`.
+
 ## License
 
 [MIT](https://choosealicense.com/licenses/mit/)
@@ -60,7 +85,7 @@ If you'd like to replicate the example, you need to get the dataset for learning
 
 - [x] 1D Fourier Layer
 - [ ] 2D / 3D Fourier Layer
-- [ ] DeepONet
+- [x] DeepONet
 - [ ] Physics informed Loss
 
 ## Contributing
@@ -69,4 +94,6 @@ Contributions are always welcome! Please submit a PR if you'd like to participat
 
 ## References
 
-- Li et al., 2020 [arXiv:2010.08895](https://arxiv.org/abs/2010.08895)
+[1] Z. Li et al., "Fourier Neural Operator for Parametric Partial Differential Equations", [arXiv:2010.08895](https://arxiv.org/abs/2010.08895) [cs, math], May 2021
+
+[2] L. Lu, P. Jin, and G. E. Karniadakis, "DeepONet: Learning nonlinear operators for identifying differential equations based on the universal approximation theorem of operators", [arXiv:1910.03193](http://arxiv.org/abs/1910.03193) [cs, stat], Apr. 2020
````
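The README snippet shows construction only; a minimal sketch of the corresponding forward pass (the batch size and number of query points below are arbitrary assumptions, chosen to match the `(32,64,72)`/`(24,64,72)` constructor call):

```julia
using OperatorLearning, Flux

model = DeepONet((32,64,72), (24,64,72), σ, tanh)

x = rand(Float32, 32, 16)   # 16 input-function samples, each evaluated at 32 sensors
y = rand(Float32, 24, 8)    # 8 query points, each described by 24 coordinates

# branch(x)' * trunk(y) contracts the shared 72-dimensional latent space,
# giving one value per (sample, query point) pair
u = model(x, y)
size(u)  # (16, 8)
```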

examples/burgers_DeepONet.jl

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
```julia
using Flux: length, reshape, train!, throttle, @epochs
using OperatorLearning, Flux, MAT

device = cpu;

#=
We would like to implement and train a DeepONet that infers the solution
u(x) of the Burgers equation on a grid of 1024 points at time one, based
on the initial condition a(x) = u(x,0)
=#

# Read the data from the MAT file and store it in a dict
# key "a" is the IC
# key "u" is the desired solution at time 1
vars = matread("burgers_data_R10.mat") |> device

# For trial purposes, we might want to train with different resolutions
# So we sample only every n-th element
subsample = 2^3;

# create the x training array, according to our desired grid size
xtrain = vars["a"][1:1000, 1:subsample:end]' |> device;
# create the x test array
xtest = vars["a"][end-99:end, 1:subsample:end]' |> device;

# Create the y training array
ytrain = vars["u"][1:1000, 1:subsample:end] |> device;
# Create the y test array
ytest = vars["u"][end-99:end, 1:subsample:end] |> device;

# The data is missing grid data, so we create it
# `collect` converts data type `range` into an array
grid = collect(range(0, 1, length=1024))' |> device

# Pass the data to the Flux DataLoader and give it a batch of 20
#train_loader = Flux.Data.DataLoader((xtrain, ytrain), batchsize=20, shuffle=true) |> device
#test_loader = Flux.Data.DataLoader((xtest, ytest), batchsize=20, shuffle=false) |> device

# Create the DeepONet:
# IC is given on a grid of 1024 points, and we solve for a fixed time t in one
# spatial dimension x, making the branch input of size 1024 and the trunk input of size 1
# We choose GeLU activation for both subnets
model = DeepONet((1024,1024,1024),(1,1024,1024),gelu,gelu) |> device

# We use the ADAM optimizer for training
learning_rate = 0.001
opt = ADAM(learning_rate)

# Specify the model parameters
parameters = params(model)

# The loss function
# We can't use the "vanilla" implementation of the mse here since we have
# two distinct inputs to our DeepONet, so the loss takes the sensor locations
# (the grid) as an additional argument
loss(xtrain,ytrain,sensor) = Flux.Losses.mse(model(xtrain,sensor),ytrain)

# Define a callback function that gives some output during training
evalcb() = @show(loss(xtest,ytest,grid))
# Print the callback only every 5 seconds
throttled_cb = throttle(evalcb, 5)

# Do the training loop
Flux.@epochs 500 train!(loss, parameters, [(xtrain,ytrain,grid)], opt, cb = throttled_cb)
```
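After the loop finishes, a quick way to gauge generalization is to evaluate the loss on the held-out samples; a minimal sketch using the arrays defined above (the choice of metric is an assumption, the example file itself stops after training):

```julia
# model(xtest, grid) returns a 100 x 1024 array, matching ytest,
# so the test error is simply the same mse evaluated on the held-out ICs
test_mse = loss(xtest, ytest, grid)
@show test_mse
```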

examples/burgers.jl renamed to examples/burgers_FNO.jl

Lines changed: 6 additions & 4 deletions
```diff
@@ -1,4 +1,4 @@
-using Flux: length, reshape, train!, @epochs
+using Flux: length, reshape, train!, throttle, @epochs
 using OperatorLearning, Flux, MAT
 
 device = gpu;
@@ -74,10 +74,12 @@ parameters = params(model)
 loss(x,y) = Flux.Losses.mse(model(x),y)
 
 # Define a callback function that gives some output during training
-evalcb() = @show(loss(x,y))
+evalcb() = @show(loss(xtest,ytest))
+# Print the callback only every 5 seconds
+throttled_cb = throttle(evalcb, 5)
 
 # Do the training loop
-Flux.@epochs 500 train!(loss, parameters, train_loader, opt, cb = evalcb)
+Flux.@epochs 500 train!(loss, parameters, train_loader, opt, cb = throttled_cb)
 
 # Accuracy metrics
 val_loader = Flux.Data.DataLoader((xtest, ytest), batchsize=1, shuffle=false) |> device
@@ -86,4 +88,4 @@ loss = 0.0 |> device
 for (x,y) in val_loader
     ŷ = model(x)
     loss += Flux.Losses.mse(ŷ,y)
-end
+end
```

src/DeepONet.jl

Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
````julia
"""
`DeepONet(architecture_branch::Tuple, architecture_trunk::Tuple,
            act_branch = identity, act_trunk = identity;
            init_branch = Flux.glorot_uniform,
            init_trunk = Flux.glorot_uniform,
            bias_branch=true, bias_trunk=true)`
`DeepONet(branch_net::Flux.Chain, trunk_net::Flux.Chain)`

Create an (unstacked) DeepONet architecture as proposed by Lu et al.
arXiv:1910.03193

The model works as follows:

x --- branch --
               |
                -⊠--u-
               |
y --- trunk ---

Where `x` represents the input function, discretely evaluated at its respective sensors. So the input is of shape [m] for one instance or [m x b] for a training set.
`y` are the probing locations for the operator to be trained. It has shape [N x n] for N different variables in the PDE (i.e. spatial and temporal coordinates) with each n distinct evaluation points.
`u` is the solution of the queried instance of the PDE, given by the specific choice of parameters.

The outputs of the branch and trunk net are combined via the dot product Σᵢ bᵢⱼ tᵢₖ.

You can set up this architecture in two ways:

1. By specifying the architecture and all its parameters as given above. This always creates `Dense` layers for the branch and trunk net and corresponds to the DeepONet proposed by Lu et al.

2. By passing two architectures in the form of two Chain structs directly. Do this if you want more flexibility and e.g. use an RNN or CNN instead of simple `Dense` layers.

Strictly speaking, DeepONet does not imply either of the branch or trunk net to be a simple DNN. Usually though, this is the case, which is why it's treated as the default case here.

# Example

Consider a transient 1D advection problem ∂ₜu + u ⋅ ∇u = 0, with an IC u(x,0) = g(x).
We are given several (b = 200) instances of the IC, discretized at 50 points each, and want to query the solution for 100 different locations and times in [0;1].

That makes the branch input of shape [50 x 200] and the trunk input of shape [2 x 100]. So the input size for the branch net is 50 and for the trunk net it is 2.

# Usage

```julia
julia> model = DeepONet((32,64,72), (24,64,72))
DeepONet with
branch net: (Chain(Dense(32, 64), Dense(64, 72)))
Trunk net: (Chain(Dense(24, 64), Dense(64, 72)))

julia> model = DeepONet((32,64,72), (24,64,72), σ, tanh; init_branch=Flux.glorot_normal, bias_trunk=false)
DeepONet with
branch net: (Chain(Dense(32, 64, σ), Dense(64, 72, σ)))
Trunk net: (Chain(Dense(24, 64, tanh; bias=false), Dense(64, 72, tanh; bias=false)))

julia> branch = Chain(Dense(2,128),Dense(128,64),Dense(64,72))
Chain(
  Dense(2, 128),        # 384 parameters
  Dense(128, 64),       # 8_256 parameters
  Dense(64, 72),        # 4_680 parameters
)                       # Total: 6 arrays, 13_320 parameters, 52.406 KiB.

julia> trunk = Chain(Dense(1,24),Dense(24,72))
Chain(
  Dense(1, 24),         # 48 parameters
  Dense(24, 72),        # 1_800 parameters
)                       # Total: 4 arrays, 1_848 parameters, 7.469 KiB.

julia> model = DeepONet(branch,trunk)
DeepONet with
branch net: (Chain(Dense(2, 128), Dense(128, 64), Dense(64, 72)))
Trunk net: (Chain(Dense(1, 24), Dense(24, 72)))
```
"""
struct DeepONet
    branch_net::Flux.Chain
    trunk_net::Flux.Chain
end

# Declare the function that assigns Weights and biases to the layer
function DeepONet(architecture_branch::Tuple, architecture_trunk::Tuple,
                    act_branch = identity, act_trunk = identity;
                    init_branch = Flux.glorot_uniform,
                    init_trunk = Flux.glorot_uniform,
                    bias_branch=true, bias_trunk=true)

    @assert architecture_branch[end] == architecture_trunk[end] "Branch and Trunk net must share the same amount of nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't work."

    # To construct the subnets we use the helper function in subnets.jl
    # Initialize the branch net
    branch_net = construct_subnet(architecture_branch, act_branch;
                                    init=init_branch, bias=bias_branch)
    # Initialize the trunk net
    trunk_net = construct_subnet(architecture_trunk, act_trunk;
                                    init=init_trunk, bias=bias_trunk)

    return DeepONet(branch_net, trunk_net)
end

Flux.@functor DeepONet

#= The actual layer that does stuff
x is the input function, evaluated at m locations (or m x b in case of batches)
y is the array of sensors, i.e. the variables of the output function
with shape (N x n) - N different variables with each n evaluation points =#
function (a::DeepONet)(x::AbstractArray, y::AbstractVecOrMat)
    # Assign the parameters
    branch, trunk = a.branch_net, a.trunk_net

    #= Dot product needs a dim to contract
    However, we perform the transformations by the NNs always in the first dim
    so we need to adjust (i.e. transpose) one of the inputs,
    which we do on the branch input here =#
    return branch(x)' * trunk(y)
end

# Sensors stay the same and shouldn't be batched
(a::DeepONet)(x::AbstractArray, y::AbstractArray) =
    throw(ArgumentError("Sensor locations fed to trunk net can't be batched."))

# Print nicely
function Base.show(io::IO, l::DeepONet)
    print(io, "DeepONet with\nbranch net: (", l.branch_net)
    print(io, ")\n")
    print(io, "Trunk net: (", l.trunk_net)
    print(io, ")\n")
end
````

src/OperatorLearning.jl

Lines changed: 3 additions & 1 deletion
```diff
@@ -10,10 +10,12 @@ using Random: AbstractRNG
 using Flux: nfan, glorot_uniform, batch
 using OMEinsum
 
-export FourierLayer
+export FourierLayer, DeepONet
 
 include("FourierLayer.jl")
+include("DeepONet.jl")
 include("ComplexWeights.jl")
 include("batched.jl")
+include("subnets.jl")
 
 end # module
```

src/subnets.jl

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
````julia
"""
Construct a Chain of `Dense` layers from a given tuple of integers.

Input:
A tuple (m,n,o,p) of integer type numbers that each describe the width of the i-th `Dense` layer to construct

Output:
A `Flux` Chain with one `Dense` layer per pair of adjacent tuple elements (i.e. one layer fewer than the tuple length), with the individual widths given by the tuple elements

# Example

```julia
julia> model = OperatorLearning.construct_subnet((2,128,64,32,1))
Chain(
  Dense(2, 128),        # 384 parameters
  Dense(128, 64),       # 8_256 parameters
  Dense(64, 32),        # 2_080 parameters
  Dense(32, 1),         # 33 parameters
)                       # Total: 8 arrays, 10_753 parameters, 42.504 KiB.

julia> model([2,1])
1-element Vector{Float32}:
 -0.7630446
```
"""
function construct_subnet(architecture::Tuple, σ = identity;
                            init=Flux.glorot_uniform, bias=true)
    # First, create an array that contains all Dense layers independently
    # Given an n-element architecture, this constructs n-1 layers
    layers = Array{Flux.Dense}(undef, length(architecture)-1)
    @inbounds for i ∈ 2:length(architecture)
        layers[i-1] = Flux.Dense(architecture[i-1], architecture[i], σ;
                                    init=init, bias=bias)
    end

    # Concatenate the layers to a string, chain them and parse them into
    # the Flux Chain constructor syntax
    return Meta.parse("Chain("*join(layers,",")*")") |> eval
end
````
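Since `Chain` accepts layers as a vararg, splatting the layer array would avoid the string/`Meta.parse`/`eval` round-trip, which rebuilds the layers from their printed form and can silently drop options such as a non-default `init`. A possible simplification, sketched here under a hypothetical name rather than as the committed implementation:

```julia
using Flux

function construct_subnet_splat(architecture::Tuple, σ = identity;
                                init=Flux.glorot_uniform, bias=true)
    # Build one Dense layer per pair of adjacent widths, then splat into Chain
    layers = [Flux.Dense(architecture[i-1], architecture[i], σ; init=init, bias=bias)
              for i ∈ 2:length(architecture)]
    return Flux.Chain(layers...)
end
```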

test/deeponet.jl

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
```julia
using Test, Random, Flux

@testset "DeepONet" begin
    @testset "dimensions" begin
        # Test the proper construction
        # Branch net
        @test size(DeepONet((32,64,72), (24,48,72), σ, tanh).branch_net.layers[end].weight) == (72,64)
        @test size(DeepONet((32,64,72), (24,48,72), σ, tanh).branch_net.layers[end].bias) == (72,)
        # Trunk net
        @test size(DeepONet((32,64,72), (24,48,72), σ, tanh).trunk_net.layers[end].weight) == (72,48)
        @test size(DeepONet((32,64,72), (24,48,72), σ, tanh).trunk_net.layers[end].bias) == (72,)
    end

    # Accept only Int as architecture parameters
    @test_throws MethodError DeepONet((32.5,64,72), (24,48,72), σ, tanh)
    @test_throws MethodError DeepONet((32,64,72), (24.1,48,72))
end
```
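The tests above only cover construction. A sketch of an additional forward-pass check that could sit alongside the dimension tests; the dummy input sizes are assumptions matching the constructor calls already used:

```julia
@testset "forward pass" begin
    model = DeepONet((32,64,72), (24,48,72), σ, tanh)
    x = rand(Float32, 32, 10)   # 10 function samples at 32 sensors
    y = rand(Float32, 24, 5)    # 5 query points with 24 coordinates each
    # One output value per (sample, query point) pair
    @test size(model(x, y)) == (10, 5)
end
```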

test/runtests.jl

Lines changed: 4 additions & 0 deletions
```diff
@@ -8,6 +8,10 @@ Random.seed!(0)
     include("fourierlayer.jl")
 end
 
+@testset "DeepONet" begin
+    include("deeponet.jl")
+end
+
 @testset "Weights" begin
     include("complexweights.jl")
 end
```
