Skip to content

Commit b9aee60

Browse files
committed
perf: add HNN (2nd order AD) benchmarks
[skip ci]
1 parent 88acd71 commit b9aee60

File tree

5 files changed

+187
-0
lines changed

5 files changed

+187
-0
lines changed

perf/HNN/Project.toml

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
[deps]
2+
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
3+
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
4+
Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
5+
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
6+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
7+
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
8+
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
9+
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
10+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
11+
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
12+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
13+
14+
[sources]
15+
Reactant = {path = "../.."}

perf/HNN/main.jl

+167
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
using Lux,
2+
Random,
3+
Reactant,
4+
Enzyme,
5+
Zygote,
6+
BenchmarkTools,
7+
LuxCUDA,
8+
DataFrames,
9+
OrderedCollections,
10+
CSV,
11+
Comonicon
12+
13+
struct HamiltonianNN{E,M} <: AbstractLuxWrapperLayer{:model}
14+
model::M
15+
16+
HamiltonianNN{E}(model::M) where {E,M} = new{E,M}(model)
17+
end
18+
19+
function (hnn::HamiltonianNN{false})(x::AbstractArray, ps, st)
20+
model = StatefulLuxLayer{true}(hnn.model, ps, st)
21+
∂x = only(Zygote.gradient(sum model, x))
22+
n = size(x, ndims(x) - 1) ÷ 2
23+
y = cat(
24+
selectdim(∂x, ndims(∂x) - 1, (n + 1):(2n)),
25+
selectdim(∂x, ndims(∂x) - 1, 1:n);
26+
dims=Val(ndims(∂x) - 1),
27+
)
28+
return y, model.st
29+
end
30+
31+
function (hnn::HamiltonianNN{true})(x::AbstractArray, ps, st)
32+
∂x = similar(x)
33+
model = StatefulLuxLayer{true}(hnn.model, ps, st)
34+
Enzyme.autodiff(Reverse, Const(sum model), Duplicated(x, ∂x))
35+
n = size(x, ndims(x) - 1) ÷ 2
36+
y = cat(
37+
selectdim(∂x, ndims(∂x) - 1, (n + 1):(2n)),
38+
selectdim(∂x, ndims(∂x) - 1, 1:n);
39+
dims=Val(ndims(∂x) - 1),
40+
)
41+
return y, model.st
42+
end
43+
44+
function loss_fn(model, ps, st, x, y)
45+
pred, _ = model(x, ps, st)
46+
return MSELoss()(pred, y)
47+
end
48+
49+
function ∇zygote_loss_fn(model, ps, st, x, y)
50+
_, dps, _, dx, _ = Zygote.gradient(loss_fn, model, ps, st, x, y)
51+
return dps, dx
52+
end
53+
54+
function ∇enzyme_loss_fn(model, ps, st, x, y)
55+
_, dps, _, dx, _ = Enzyme.gradient(
56+
Reverse, loss_fn, Const(model), ps, Const(st), x, Const(y)
57+
)
58+
return dps, dx
59+
end
60+
61+
function reclaim_fn(backend, reactant)
62+
if backend == "gpu" && !reactant
63+
CUDA.reclaim()
64+
end
65+
GC.gc(true)
66+
return nothing
67+
end
68+
69+
Comonicon.@main function main(; backend::String="gpu")
70+
@assert backend in ("cpu", "gpu")
71+
72+
Reactant.set_default_backend(backend)
73+
filename = joinpath(@__DIR__, "results_$(backend).csv")
74+
75+
@info "Using backend" backend
76+
77+
cdev = cpu_device()
78+
gdev = backend == "gpu" ? gpu_device(; force=true) : cdev
79+
xdev = reactant_device(; force=true)
80+
81+
df = DataFrame(
82+
OrderedDict(
83+
"Kind" => [],
84+
"Fwd Vanilla" => [],
85+
"Fwd Reactant" => [],
86+
"Fwd Reactant SpeedUp" => [],
87+
"Bwd Zygote" => [],
88+
"Bwd Reactant" => [],
89+
"Bwd Reactant SpeedUp" => [],
90+
),
91+
)
92+
93+
mlp = Chain(
94+
Dense(32, 128, gelu),
95+
Dense(128, 128, gelu),
96+
Dense(128, 128, gelu),
97+
Dense(128, 128, gelu),
98+
Dense(128, 1),
99+
)
100+
101+
model_enz = HamiltonianNN{true}(mlp)
102+
model_zyg = HamiltonianNN{false}(mlp)
103+
104+
ps, st = Lux.setup(Random.default_rng(), model_enz)
105+
106+
x = randn(Float32, 32, 1024)
107+
y = randn(Float32, 32, 1024)
108+
109+
x_gdev = gdev(x)
110+
y_gdev = gdev(y)
111+
x_xdev = xdev(x)
112+
y_xdev = xdev(y)
113+
114+
ps_gdev, st_gdev = gdev((ps, st))
115+
ps_xdev, st_xdev = xdev((ps, st))
116+
117+
@info "Compiling Forward Functions"
118+
lfn_compiled = @compile sync = true loss_fn(model_enz, ps_xdev, st_xdev, x_xdev, y_xdev)
119+
120+
@info "Running Forward Benchmarks"
121+
122+
t_gdev = @belapsed CUDA.@sync(loss_fn($model_zyg, $ps_gdev, $st_gdev, $x_gdev, $y_gdev)) setup = (reclaim_fn(
123+
$backend, false
124+
))
125+
126+
t_xdev = @belapsed $lfn_compiled($model_enz, $ps_xdev, $st_xdev, $x_xdev, $y_xdev) setup = (reclaim_fn(
127+
$backend, true
128+
))
129+
130+
@info "Forward Benchmarks" t_gdev t_xdev
131+
132+
@info "Compiling Backward Functions"
133+
grad_fn_compiled = @compile sync = true ∇enzyme_loss_fn(
134+
model_enz, ps_xdev, st_xdev, x_xdev, y_xdev
135+
)
136+
137+
@info "Running Backward Benchmarks"
138+
139+
t_rev_gdev = @belapsed CUDA.@sync(
140+
∇zygote_loss_fn($model_zyg, $ps_gdev, $st_gdev, $x_gdev, $y_gdev)
141+
) setup = (reclaim_fn($backend, false))
142+
143+
t_rev_xdev = @belapsed $grad_fn_compiled(
144+
$model_enz, $ps_xdev, $st_xdev, $x_xdev, $y_xdev
145+
) setup = (reclaim_fn($backend, true))
146+
147+
@info "Backward Benchmarks" t_rev_gdev t_rev_xdev
148+
149+
push!(
150+
df,
151+
[
152+
"HNN",
153+
t_gdev,
154+
t_xdev,
155+
t_gdev / t_xdev,
156+
t_rev_gdev,
157+
t_rev_xdev,
158+
t_rev_gdev / t_rev_xdev,
159+
],
160+
)
161+
162+
display(df)
163+
CSV.write(filename, df)
164+
165+
@info "Results saved to $filename"
166+
return nothing
167+
end

perf/HNN/results_cpu.csv

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Kind,Fwd Vanilla,Fwd Reactant,Fwd Reactant SpeedUp,Bwd Zygote,Bwd Reactant,Bwd Reactant SpeedUp
2+
HNN,0.012209751,0.002101077,5.811186834180757,0.173089096,0.004597676,37.64708430955117

perf/HNN/results_gpu.csv

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Kind,Fwd Vanilla,Fwd Reactant,Fwd Reactant SpeedUp,Bwd Zygote,Bwd Reactant,Bwd Reactant SpeedUp
2+
HNN,0.000681027,8.4721e-5,8.038467440186022,0.003330234,0.00012123,27.470378619153674

perf/KAN/main.jl

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ function reclaim_fn(backend, reactant)
3535
CUDA.reclaim()
3636
end
3737
GC.gc(true)
38+
return nothing
3839
end
3940

4041
Comonicon.@main function main(; backend::String="gpu")

0 commit comments

Comments
 (0)