perf: add some static benchmarks #508

Draft · wants to merge 4 commits into base: main
15 changes: 15 additions & 0 deletions perf/HNN/Project.toml
@@ -0,0 +1,15 @@
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[sources]
Reactant = {path = "../.."}
167 changes: 167 additions & 0 deletions perf/HNN/main.jl
@@ -0,0 +1,167 @@
using Lux,
Random,
Reactant,
Enzyme,
Zygote,
BenchmarkTools,
LuxCUDA,
DataFrames,
OrderedCollections,
CSV,
Comonicon

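# Wrapper layer around a scalar-output network that models the Hamiltonian H(x).
# The type parameter E selects the AD backend used for the inner gradient ∂H/∂x:
# E = true uses Enzyme, E = false uses Zygote.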
struct HamiltonianNN{E,M} <: AbstractLuxWrapperLayer{:model}
model::M

HamiltonianNN{E}(model::M) where {E,M} = new{E,M}(model)
end

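# Zygote path: differentiate the scalar Hamiltonian w.r.t. the input, then swap
# the two halves of ∂x along the feature dimension to form the vector field.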
function (hnn::HamiltonianNN{false})(x::AbstractArray, ps, st)
model = StatefulLuxLayer{true}(hnn.model, ps, st)
∂x = only(Zygote.gradient(sum ∘ model, x))
n = size(x, ndims(x) - 1) ÷ 2
y = cat(
selectdim(∂x, ndims(∂x) - 1, (n + 1):(2n)),
selectdim(∂x, ndims(∂x) - 1, 1:n);
dims=Val(ndims(∂x) - 1),
)
return y, model.st
end

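# Enzyme path: same computation, with the input gradient accumulated into a
# zero-initialized shadow array by Enzyme's reverse mode.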
function (hnn::HamiltonianNN{true})(x::AbstractArray, ps, st)
∂x = zero(x)  # shadow must start at zero: Enzyme accumulates the gradient into it
model = StatefulLuxLayer{true}(hnn.model, ps, st)
Enzyme.autodiff(Reverse, Const(sum ∘ model), Duplicated(x, ∂x))
n = size(x, ndims(x) - 1) ÷ 2
y = cat(
selectdim(∂x, ndims(∂x) - 1, (n + 1):(2n)),
selectdim(∂x, ndims(∂x) - 1, 1:n);
dims=Val(ndims(∂x) - 1),
)
return y, model.st
end

function loss_fn(model, ps, st, x, y)
pred, _ = model(x, ps, st)
return MSELoss()(pred, y)
end

function ∇zygote_loss_fn(model, ps, st, x, y)
_, dps, _, dx, _ = Zygote.gradient(loss_fn, model, ps, st, x, y)
return dps, dx
end

function ∇enzyme_loss_fn(model, ps, st, x, y)
_, dps, _, dx, _ = Enzyme.gradient(
Reverse, loss_fn, Const(model), ps, Const(st), x, Const(y)
)
return dps, dx
end

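# Release memory between benchmark samples so allocations from one trial do not
# skew the next; CUDA.reclaim is skipped for CPU and Reactant runs, where only
# a full GC is triggered.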
function reclaim_fn(backend, reactant)
if backend == "gpu" && !reactant
CUDA.reclaim()
end
GC.gc(true)
return nothing
end

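# Comonicon.@main generates a CLI entry point; assuming its default flag naming,
# a typical invocation would be: julia --project=. main.jl --backend=cpu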
Comonicon.@main function main(; backend::String="gpu")
@assert backend in ("cpu", "gpu")

Reactant.set_default_backend(backend)
filename = joinpath(@__DIR__, "results_$(backend).csv")

@info "Using backend" backend

cdev = cpu_device()
gdev = backend == "gpu" ? gpu_device(; force=true) : cdev
xdev = reactant_device(; force=true)

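# Results table; timings below are wall-clock seconds as reported by @belapsed.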
df = DataFrame(
OrderedDict(
"Kind" => [],
"Fwd Vanilla" => [],
"Fwd Reactant" => [],
"Fwd Reactant SpeedUp" => [],
"Bwd Zygote" => [],
"Bwd Reactant" => [],
"Bwd Reactant SpeedUp" => [],
),
)

mlp = Chain(
Dense(32, 128, gelu),
Dense(128, 128, gelu),
Dense(128, 128, gelu),
Dense(128, 128, gelu),
Dense(128, 1),
)

model_enz = HamiltonianNN{true}(mlp)
model_zyg = HamiltonianNN{false}(mlp)

ps, st = Lux.setup(Random.default_rng(), model_enz)

x = randn(Float32, 32, 1024)
y = randn(Float32, 32, 1024)

x_gdev = gdev(x)
y_gdev = gdev(y)
x_xdev = xdev(x)
y_xdev = xdev(y)

ps_gdev, st_gdev = gdev((ps, st))
ps_xdev, st_xdev = xdev((ps, st))

@info "Compiling Forward Functions"
lfn_compiled = @compile sync = true loss_fn(model_enz, ps_xdev, st_xdev, x_xdev, y_xdev)

@info "Running Forward Benchmarks"

t_gdev = @belapsed CUDA.@sync(loss_fn($model_zyg, $ps_gdev, $st_gdev, $x_gdev, $y_gdev)) setup = (reclaim_fn(
$backend, false
))

t_xdev = @belapsed $lfn_compiled($model_enz, $ps_xdev, $st_xdev, $x_xdev, $y_xdev) setup = (reclaim_fn(
$backend, true
))

@info "Forward Benchmarks" t_gdev t_xdev

@info "Compiling Backward Functions"
grad_fn_compiled = @compile sync = true ∇enzyme_loss_fn(
model_enz, ps_xdev, st_xdev, x_xdev, y_xdev
)

@info "Running Backward Benchmarks"

t_rev_gdev = @belapsed CUDA.@sync(
∇zygote_loss_fn($model_zyg, $ps_gdev, $st_gdev, $x_gdev, $y_gdev)
) setup = (reclaim_fn($backend, false))

t_rev_xdev = @belapsed $grad_fn_compiled(
$model_enz, $ps_xdev, $st_xdev, $x_xdev, $y_xdev
) setup = (reclaim_fn($backend, true))

@info "Backward Benchmarks" t_rev_gdev t_rev_xdev

push!(
df,
[
"HNN",
t_gdev,
t_xdev,
t_gdev / t_xdev,
t_rev_gdev,
t_rev_xdev,
t_rev_gdev / t_rev_xdev,
],
)

display(df)
CSV.write(filename, df)

@info "Results saved to $filename"
return nothing
end
2 changes: 2 additions & 0 deletions perf/HNN/results_cpu.csv
@@ -0,0 +1,2 @@
Kind,Fwd Vanilla,Fwd Reactant,Fwd Reactant SpeedUp,Bwd Zygote,Bwd Reactant,Bwd Reactant SpeedUp
HNN,0.012209751,0.002101077,5.811186834180757,0.173089096,0.004597676,37.64708430955117
2 changes: 2 additions & 0 deletions perf/HNN/results_gpu.csv
@@ -0,0 +1,2 @@
Kind,Fwd Vanilla,Fwd Reactant,Fwd Reactant SpeedUp,Bwd Zygote,Bwd Reactant,Bwd Reactant SpeedUp
HNN,0.000681027,8.4721e-5,8.038467440186022,0.003330234,0.00012123,27.470378619153674
16 changes: 16 additions & 0 deletions perf/KAN/Project.toml
@@ -0,0 +1,16 @@
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
KolmogorovArnold = "eec8b66d-f71a-4a43-b228-0fe5d6721cd3"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[sources]
Reactant = {path = "../.."}
221 changes: 221 additions & 0 deletions perf/KAN/main.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
# This has been adapted from https://github.com/vpuri3/KolmogorovArnold.jl/blob/38616fc66b3c5c1550afa7c718a0629608def19b/examples/eg3.jl

using KolmogorovArnold,
Lux,
Random,
Reactant,
Enzyme,
Zygote,
BenchmarkTools,
LuxCUDA,
DataFrames,
OrderedCollections,
CSV,
Comonicon

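# Benchmark helpers below mirror perf/HNN/main.jl: an MSE loss plus Zygote and
# Enzyme wrappers that return the parameter and input gradients.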
function loss_fn(model, ps, st, x, y)
pred, _ = model(x, ps, st)
return MSELoss()(pred, y)
end

function ∇zygote_loss_fn(model, ps, st, x, y)
_, dps, _, dx, _ = Zygote.gradient(loss_fn, model, ps, st, x, y)
return dps, dx
end

function ∇enzyme_loss_fn(model, ps, st, x, y)
_, dps, _, dx, _ = Enzyme.gradient(
Reverse, loss_fn, Const(model), ps, Const(st), x, Const(y)
)
return dps, dx
end

function reclaim_fn(backend, reactant)
if backend == "gpu" && !reactant
CUDA.reclaim()
end
GC.gc(true)
return nothing
end

Comonicon.@main function main(; backend::String="gpu")
@assert backend in ("cpu", "gpu")

Reactant.set_default_backend(backend)
filename = joinpath(@__DIR__, "results_$(backend).csv")

@info "Using backend" backend

cdev = cpu_device()
gdev = backend == "gpu" ? gpu_device(; force=true) : cdev
xdev = reactant_device(; force=true)

df = DataFrame(
OrderedDict(
"Kind" => [],
"Fwd Vanilla" => [],
"Fwd Reactant" => [],
"Fwd Reactant SpeedUp" => [],
"Bwd Zygote" => [],
"Bwd Reactant" => [],
"Bwd Reactant SpeedUp" => [],
),
)

x = randn(Float32, 1, 1024)
x_gdev = gdev(x)
x_xdev = xdev(x)

y_gdev = x_gdev .^ 2
y_xdev = x_xdev .^ 2

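# wM: MLP hidden width; wK: KAN layer width; G: grid size for the KAN basis.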
wM = 128
wK = 40
G = 10

mlp = Chain(Dense(1, wM, tanh), Dense(wM, wK, tanh), Dense(wK, 1))

basis_func = rbf
normalizer = softsign

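# Two KAN variants from KolmogorovArnold.jl: kan1 enables the extra base
# activation branch (use_base_act=true), kan2 disables it.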
kan1 = Chain(
KDense(1, wK, G; use_base_act=true, basis_func, normalizer),
KDense(wK, wK, G; use_base_act=true, basis_func, normalizer),
KDense(wK, 1, G; use_base_act=true, basis_func, normalizer),
)

kan2 = Chain(
KDense(1, wK, G; use_base_act=false, basis_func, normalizer),
KDense(wK, wK, G; use_base_act=false, basis_func, normalizer),
KDense(wK, 1, G; use_base_act=false, basis_func, normalizer),
)

ps_mlp, st_mlp = Lux.setup(Random.default_rng(), mlp)
ps_kan1, st_kan1 = Lux.setup(Random.default_rng(), kan1)
ps_kan2, st_kan2 = Lux.setup(Random.default_rng(), kan2)

ps_mlp_gdev, st_mlp_gdev = gdev((ps_mlp, st_mlp))
ps_kan1_gdev, st_kan1_gdev = gdev((ps_kan1, st_kan1))
ps_kan2_gdev, st_kan2_gdev = gdev((ps_kan2, st_kan2))

ps_mlp_xdev, st_mlp_xdev = xdev((ps_mlp, st_mlp))
ps_kan1_xdev, st_kan1_xdev = xdev((ps_kan1, st_kan1))
ps_kan2_xdev, st_kan2_xdev = xdev((ps_kan2, st_kan2))

@info "Compiling Forward Functions"
lfn_mlp_compiled = @compile sync = true loss_fn(
mlp, ps_mlp_xdev, st_mlp_xdev, x_xdev, y_xdev
)
lfn_kan1_compiled = @compile sync = true loss_fn(
kan1, ps_kan1_xdev, st_kan1_xdev, x_xdev, y_xdev
)
lfn_kan2_compiled = @compile sync = true loss_fn(
kan2, ps_kan2_xdev, st_kan2_xdev, x_xdev, y_xdev
)

@info "Running Forward Benchmarks"

tmlp_gdev = @belapsed CUDA.@sync(
loss_fn($mlp, $ps_mlp_gdev, $st_mlp_gdev, $x_gdev, $y_gdev)
) setup = (reclaim_fn($backend, false))
tkan1_gdev = @belapsed CUDA.@sync(
loss_fn($kan1, $ps_kan1_gdev, $st_kan1_gdev, $x_gdev, $y_gdev)
) setup = (reclaim_fn($backend, false))
tkan2_gdev = @belapsed CUDA.@sync(
loss_fn($kan2, $ps_kan2_gdev, $st_kan2_gdev, $x_gdev, $y_gdev)
) setup = (reclaim_fn($backend, false))

@info "Vanilla Forward Benchmarks" tmlp_gdev tkan1_gdev tkan2_gdev

tmlp_xdev = @belapsed $lfn_mlp_compiled(
$mlp, $ps_mlp_xdev, $st_mlp_xdev, $x_xdev, $y_xdev
) setup = (reclaim_fn($backend, true))
tkan1_xdev = @belapsed $lfn_kan1_compiled(
$kan1, $ps_kan1_xdev, $st_kan1_xdev, $x_xdev, $y_xdev
) setup = (reclaim_fn($backend, true))
tkan2_xdev = @belapsed $lfn_kan2_compiled(
$kan2, $ps_kan2_xdev, $st_kan2_xdev, $x_xdev, $y_xdev
) setup = (reclaim_fn($backend, true))

@info "Reactant Forward Benchmarks" tmlp_xdev tkan1_xdev tkan2_xdev

@info "Compiling Backward Functions"
grad_fn_mlp_compiled = @compile sync = true ∇enzyme_loss_fn(
mlp, ps_mlp_xdev, st_mlp_xdev, x_xdev, y_xdev
)
grad_fn_kan1_compiled = @compile sync = true ∇enzyme_loss_fn(
kan1, ps_kan1_xdev, st_kan1_xdev, x_xdev, y_xdev
)
grad_fn_kan2_compiled = @compile sync = true ∇enzyme_loss_fn(
kan2, ps_kan2_xdev, st_kan2_xdev, x_xdev, y_xdev
)

@info "Running Backward Benchmarks"

tmlp_rev_gdev = @belapsed CUDA.@sync(
∇zygote_loss_fn($mlp, $ps_mlp_gdev, $st_mlp_gdev, $x_gdev, $y_gdev)
) setup = (reclaim_fn($backend, false))
tkan1_rev_gdev = @belapsed CUDA.@sync(
∇zygote_loss_fn($kan1, $ps_kan1_gdev, $st_kan1_gdev, $x_gdev, $y_gdev)
) setup = (reclaim_fn($backend, false))
tkan2_rev_gdev = @belapsed CUDA.@sync(
∇zygote_loss_fn($kan2, $ps_kan2_gdev, $st_kan2_gdev, $x_gdev, $y_gdev)
) setup = (reclaim_fn($backend, false))

@info "Zygote Backward Benchmarks" tmlp_rev_gdev tkan1_rev_gdev tkan2_rev_gdev

tmlp_rev_xdev = @belapsed $grad_fn_mlp_compiled(
$mlp, $ps_mlp_xdev, $st_mlp_xdev, $x_xdev, $y_xdev
) setup = (reclaim_fn($backend, true))
tkan1_rev_xdev = @belapsed $grad_fn_kan1_compiled(
$kan1, $ps_kan1_xdev, $st_kan1_xdev, $x_xdev, $y_xdev
) setup = (reclaim_fn($backend, true))
tkan2_rev_xdev = @belapsed $grad_fn_kan2_compiled(
$kan2, $ps_kan2_xdev, $st_kan2_xdev, $x_xdev, $y_xdev
) setup = (reclaim_fn($backend, true))

@info "Reactant Backward Benchmarks" tmlp_rev_xdev tkan1_rev_xdev tkan2_rev_xdev

push!(
df,
[
"MLP",
tmlp_gdev,
tmlp_xdev,
tmlp_gdev / tmlp_xdev,
tmlp_rev_gdev,
tmlp_rev_xdev,
tmlp_rev_gdev / tmlp_rev_xdev,
],
)
push!(
df,
[
"KAN1",
tkan1_gdev,
tkan1_xdev,
tkan1_gdev / tkan1_xdev,
tkan1_rev_gdev,
tkan1_rev_xdev,
tkan1_rev_gdev / tkan1_rev_xdev,
],
)
push!(
df,
[
"KAN2",
tkan2_gdev,
tkan2_xdev,
tkan2_gdev / tkan2_xdev,
tkan2_rev_gdev,
tkan2_rev_xdev,
tkan2_rev_gdev / tkan2_rev_xdev,
],
)

display(df)
CSV.write(filename, df)

@info "Results saved to $filename"
return nothing
end
4 changes: 4 additions & 0 deletions perf/KAN/results_cpu.csv
@@ -0,0 +1,4 @@
Kind,Fwd Vanilla,Fwd Reactant,Fwd Reactant SpeedUp,Bwd Zygote,Bwd Reactant,Bwd Reactant SpeedUp
MLP,0.000504229,0.000517288,0.9747548754272283,0.001275736,0.00091808,1.389569536423841
KAN1,0.00637513,0.002134405,2.986841766206507,0.015504896,0.004843588,3.201117848999543
KAN2,0.005651889,0.001934863,2.9210796836778625,0.014421515,0.004603253,3.1328964538772905
4 changes: 4 additions & 0 deletions perf/KAN/results_gpu.csv
@@ -0,0 +1,4 @@
Kind,Fwd Vanilla,Fwd Reactant,Fwd Reactant SpeedUp,Bwd Zygote,Bwd Reactant,Bwd Reactant SpeedUp
MLP,0.000161916,7.392e-5,2.190422077922078,0.000502497,8.827e-5,5.692726860768098
KAN1,0.00035248,8.2417e-5,4.276787555965396,0.001132969,0.000111215,10.187195971766398
KAN2,0.000260817,7.7983e-5,3.344536629778285,0.000888154,0.000103864,8.551124547485173
3 changes: 3 additions & 0 deletions src/Compiler.jl
@@ -196,6 +196,8 @@ function optimization_passes(; no_nan::Bool=false)
"slice_pad<1>",
"dot_reshape_dot<1>",
"concat_const_prop<1>",
"log_const_prop<1>",
"log_plus_one_const_prop<1>",
"concat_fuse<1>",
"pad_reshape_pad<1>",
"pad_pad<1>",
@@ -265,6 +267,7 @@ function optimization_passes(; no_nan::Bool=false)
"if_to_select<1>",
"dynamic_update_slice_const_prop",
"dynamic_gather_op_is_not_dynamic<16>",
"divide_sqrt_to_multiply_rsqrt<16>",
"binary_op_transpose_simplify_add",
"binary_op_transpose_simplify_sub",
"binary_op_transpose_simplify_mul",