diff --git a/perf/HNN/Project.toml b/perf/HNN/Project.toml
new file mode 100644
index 0000000000..de982f7dc8
--- /dev/null
+++ b/perf/HNN/Project.toml
@@ -0,0 +1,15 @@
+[deps]
+BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
+CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
+Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
+DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
+OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+
+[sources]
+Reactant = {path = "../.."}
diff --git a/perf/HNN/main.jl b/perf/HNN/main.jl
new file mode 100644
index 0000000000..96b92c0f81
--- /dev/null
+++ b/perf/HNN/main.jl
@@ -0,0 +1,167 @@
+using Lux,
+    Random,
+    Reactant,
+    Enzyme,
+    Zygote,
+    BenchmarkTools,
+    LuxCUDA,
+    DataFrames,
+    OrderedCollections,
+    CSV,
+    Comonicon
+
+struct HamiltonianNN{E,M} <: AbstractLuxWrapperLayer{:model}
+    model::M
+
+    HamiltonianNN{E}(model::M) where {E,M} = new{E,M}(model)
+end
+
+function (hnn::HamiltonianNN{false})(x::AbstractArray, ps, st)
+    model = StatefulLuxLayer{true}(hnn.model, ps, st)
+    ∂x = only(Zygote.gradient(sum ∘ model, x))
+    n = size(x, ndims(x) - 1) ÷ 2
+    y = cat(
+        selectdim(∂x, ndims(∂x) - 1, (n + 1):(2n)),
+        selectdim(∂x, ndims(∂x) - 1, 1:n);
+        dims=Val(ndims(∂x) - 1),
+    )
+    return y, model.st
+end
+
+function (hnn::HamiltonianNN{true})(x::AbstractArray, ps, st)
+    # The shadow must start at zero: Enzyme accumulates gradients into it.
+    ∂x = zero(x)
+    model = StatefulLuxLayer{true}(hnn.model, ps, st)
+    Enzyme.autodiff(Reverse, Const(sum ∘ model), Duplicated(x, ∂x))
+    n = size(x, ndims(x) - 1) ÷ 2
+    y = cat(
+        selectdim(∂x, ndims(∂x) - 1, (n + 1):(2n)),
+        selectdim(∂x, ndims(∂x) - 1, 1:n);
+        dims=Val(ndims(∂x) - 1),
+    )
+    return y, model.st
+end
+
+function loss_fn(model, ps, st, x, y)
+    pred, _ = model(x, ps, st)
+    return MSELoss()(pred, y)
+end
+
+function ∇zygote_loss_fn(model, ps, st, x, y)
+    _, dps, _, dx, _ = Zygote.gradient(loss_fn, model, ps, st, x, y)
+    return dps, dx
+end
+
+function ∇enzyme_loss_fn(model, ps, st, x, y)
+    _, dps, _, dx, _ = Enzyme.gradient(
+        Reverse, loss_fn, Const(model), ps, Const(st), x, Const(y)
+    )
+    return dps, dx
+end
+
+function reclaim_fn(backend, reactant)
+    if backend == "gpu" && !reactant
+        CUDA.reclaim()
+    end
+    GC.gc(true)
+    return nothing
+end
+
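+# Minimal usage sketch for the `HamiltonianNN` wrapper (illustrative only; the
+# layer sizes here are made up and much smaller than the benchmark below):
+#
+#   hnn = HamiltonianNN{false}(Chain(Dense(4 => 16, gelu), Dense(16 => 1)))
+#   ps, st = Lux.setup(Random.default_rng(), hnn)
+#   y, _ = hnn(randn(Float32, 4, 8), ps, st)  # `y` has the same shape as the input
+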
+Comonicon.@main function main(; backend::String="gpu")
+    @assert backend in ("cpu", "gpu")
+
+    Reactant.set_default_backend(backend)
+    filename = joinpath(@__DIR__, "results_$(backend).csv")
+
+    @info "Using backend" backend
+
+    cdev = cpu_device()
+    gdev = backend == "gpu" ? gpu_device(; force=true) : cdev
+    xdev = reactant_device(; force=true)
+
+    df = DataFrame(
+        OrderedDict(
+            "Kind" => [],
+            "Fwd Vanilla" => [],
+            "Fwd Reactant" => [],
+            "Fwd Reactant SpeedUp" => [],
+            "Bwd Zygote" => [],
+            "Bwd Reactant" => [],
+            "Bwd Reactant SpeedUp" => [],
+        ),
+    )
+
+    mlp = Chain(
+        Dense(32, 128, gelu),
+        Dense(128, 128, gelu),
+        Dense(128, 128, gelu),
+        Dense(128, 128, gelu),
+        Dense(128, 1),
+    )
+
+    model_enz = HamiltonianNN{true}(mlp)
+    model_zyg = HamiltonianNN{false}(mlp)
+
+    ps, st = Lux.setup(Random.default_rng(), model_enz)
+
+    x = randn(Float32, 32, 1024)
+    y = randn(Float32, 32, 1024)
+
+    x_gdev = gdev(x)
+    y_gdev = gdev(y)
+    x_xdev = xdev(x)
+    y_xdev = xdev(y)
+
+    ps_gdev, st_gdev = gdev((ps, st))
+    ps_xdev, st_xdev = xdev((ps, st))
+
+    @info "Compiling Forward Functions"
+    lfn_compiled = @compile sync = true loss_fn(model_enz, ps_xdev, st_xdev, x_xdev, y_xdev)
+
+    @info "Running Forward Benchmarks"
+
+    t_gdev = @belapsed CUDA.@sync(
+        loss_fn($model_zyg, $ps_gdev, $st_gdev, $x_gdev, $y_gdev)
+    ) setup = (reclaim_fn($backend, false))
+
+    t_xdev = @belapsed $lfn_compiled(
+        $model_enz, $ps_xdev, $st_xdev, $x_xdev, $y_xdev
+    ) setup = (reclaim_fn($backend, true))
+
+    @info "Forward Benchmarks" t_gdev t_xdev
+
+    @info "Compiling Backward Functions"
+    grad_fn_compiled = @compile sync = true ∇enzyme_loss_fn(
+        model_enz, ps_xdev, st_xdev, x_xdev, y_xdev
+    )
+
+    @info "Running Backward Benchmarks"
+
+    t_rev_gdev = @belapsed CUDA.@sync(
+        ∇zygote_loss_fn($model_zyg, $ps_gdev, $st_gdev, $x_gdev, $y_gdev)
+    ) setup = (reclaim_fn($backend, false))
+
+    t_rev_xdev = @belapsed $grad_fn_compiled(
+        $model_enz, $ps_xdev, $st_xdev, $x_xdev, $y_xdev
+    ) setup = (reclaim_fn($backend, true))
+
+    @info "Backward Benchmarks" t_rev_gdev t_rev_xdev
+
+    push!(
+        df,
+        [
+            "HNN",
+            t_gdev,
+            t_xdev,
+            t_gdev / t_xdev,
+            t_rev_gdev,
+            t_rev_xdev,
+            t_rev_gdev / t_rev_xdev,
+        ],
+    )
+
+    display(df)
+    CSV.write(filename, df)
+
+    @info "Results saved to $filename"
+    return nothing
+end
diff --git a/perf/HNN/results_cpu.csv b/perf/HNN/results_cpu.csv
new file mode 100644
index 0000000000..b8dfa433df
--- /dev/null
+++ b/perf/HNN/results_cpu.csv
@@ -0,0 +1,2 @@
+Kind,Fwd Vanilla,Fwd Reactant,Fwd Reactant SpeedUp,Bwd Zygote,Bwd Reactant,Bwd Reactant SpeedUp
+HNN,0.012209751,0.002101077,5.811186834180757,0.173089096,0.004597676,37.64708430955117
diff --git a/perf/HNN/results_gpu.csv b/perf/HNN/results_gpu.csv
new file mode 100644
index 0000000000..3bcd1c5e70
--- /dev/null
+++ b/perf/HNN/results_gpu.csv
@@ -0,0 +1,2 @@
+Kind,Fwd Vanilla,Fwd Reactant,Fwd Reactant SpeedUp,Bwd Zygote,Bwd Reactant,Bwd Reactant SpeedUp
+HNN,0.000681027,8.4721e-5,8.038467440186022,0.003330234,0.00012123,27.470378619153674
diff --git a/perf/KAN/Project.toml b/perf/KAN/Project.toml
new file mode 100644
index 0000000000..f06a075e4b
--- /dev/null
+++ b/perf/KAN/Project.toml
@@ -0,0 +1,16 @@
+[deps]
+BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
+CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
+Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
+DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+KolmogorovArnold = "eec8b66d-f71a-4a43-b228-0fe5d6721cd3"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
+OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+
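+# Benchmark against the in-repo Reactant checkout (the repository root, two
+# directories up) instead of the registered release: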
+[sources]
+Reactant = {path = "../.."}
diff --git a/perf/KAN/main.jl b/perf/KAN/main.jl
new file mode 100644
index 0000000000..7837352e37
--- /dev/null
+++ b/perf/KAN/main.jl
@@ -0,0 +1,221 @@
+# This has been adapted from https://github.com/vpuri3/KolmogorovArnold.jl/blob/38616fc66b3c5c1550afa7c718a0629608def19b/examples/eg3.jl
+
+using KolmogorovArnold,
+    Lux,
+    Random,
+    Reactant,
+    Enzyme,
+    Zygote,
+    BenchmarkTools,
+    LuxCUDA,
+    DataFrames,
+    OrderedCollections,
+    CSV,
+    Comonicon
+
+function loss_fn(model, ps, st, x, y)
+    pred, _ = model(x, ps, st)
+    return MSELoss()(pred, y)
+end
+
+function ∇zygote_loss_fn(model, ps, st, x, y)
+    _, dps, _, dx, _ = Zygote.gradient(loss_fn, model, ps, st, x, y)
+    return dps, dx
+end
+
+function ∇enzyme_loss_fn(model, ps, st, x, y)
+    _, dps, _, dx, _ = Enzyme.gradient(
+        Reverse, loss_fn, Const(model), ps, Const(st), x, Const(y)
+    )
+    return dps, dx
+end
+
+function reclaim_fn(backend, reactant)
+    if backend == "gpu" && !reactant
+        CUDA.reclaim()
+    end
+    GC.gc(true)
+    return nothing
+end
+
+Comonicon.@main function main(; backend::String="gpu")
+    @assert backend in ("cpu", "gpu")
+
+    Reactant.set_default_backend(backend)
+    filename = joinpath(@__DIR__, "results_$(backend).csv")
+
+    @info "Using backend" backend
+
+    cdev = cpu_device()
+    gdev = backend == "gpu" ? gpu_device(; force=true) : cdev
+    xdev = reactant_device(; force=true)
+
+    df = DataFrame(
+        OrderedDict(
+            "Kind" => [],
+            "Fwd Vanilla" => [],
+            "Fwd Reactant" => [],
+            "Fwd Reactant SpeedUp" => [],
+            "Bwd Zygote" => [],
+            "Bwd Reactant" => [],
+            "Bwd Reactant SpeedUp" => [],
+        ),
+    )
+
+    x = randn(Float32, 1, 1024)
+    x_gdev = gdev(x)
+    x_xdev = xdev(x)
+
+    y_gdev = x_gdev .^ 2
+    y_xdev = x_xdev .^ 2
+
+    wM = 128
+    wK = 40
+    G = 10
+
+    mlp = Chain(Dense(1, wM, tanh), Dense(wM, wK, tanh), Dense(wK, 1))
+
+    basis_func = rbf
+    normalizer = softsign
+
+    kan1 = Chain(
+        KDense(1, wK, G; use_base_act=true, basis_func, normalizer),
+        KDense(wK, wK, G; use_base_act=true, basis_func, normalizer),
+        KDense(wK, 1, G; use_base_act=true, basis_func, normalizer),
+    )
+
+    kan2 = Chain(
+        KDense(1, wK, G; use_base_act=false, basis_func, normalizer),
+        KDense(wK, wK, G; use_base_act=false, basis_func, normalizer),
+        KDense(wK, 1, G; use_base_act=false, basis_func, normalizer),
+    )
+
+    ps_mlp, st_mlp = Lux.setup(Random.default_rng(), mlp)
+    ps_kan1, st_kan1 = Lux.setup(Random.default_rng(), kan1)
+    ps_kan2, st_kan2 = Lux.setup(Random.default_rng(), kan2)
+
+    ps_mlp_gdev, st_mlp_gdev = gdev((ps_mlp, st_mlp))
+    ps_kan1_gdev, st_kan1_gdev = gdev((ps_kan1, st_kan1))
+    ps_kan2_gdev, st_kan2_gdev = gdev((ps_kan2, st_kan2))
+
+    ps_mlp_xdev, st_mlp_xdev = xdev((ps_mlp, st_mlp))
+    ps_kan1_xdev, st_kan1_xdev = xdev((ps_kan1, st_kan1))
+    ps_kan2_xdev, st_kan2_xdev = xdev((ps_kan2, st_kan2))
+
+    @info "Compiling Forward Functions"
+    lfn_mlp_compiled = @compile sync = true loss_fn(
+        mlp, ps_mlp_xdev, st_mlp_xdev, x_xdev, y_xdev
+    )
+    lfn_kan1_compiled = @compile sync = true loss_fn(
+        kan1, ps_kan1_xdev, st_kan1_xdev, x_xdev, y_xdev
+    )
+    lfn_kan2_compiled = @compile sync = true loss_fn(
+        kan2, ps_kan2_xdev, st_kan2_xdev, x_xdev, y_xdev
+    )
+
+    @info "Running Forward Benchmarks"
+
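+    # Every `@belapsed` below uses `reclaim_fn` as its setup step, so leftover
+    # GPU allocations and Julia garbage from earlier samples do not skew the
+    # next measurement.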
+    tmlp_gdev = @belapsed CUDA.@sync(
+        loss_fn($mlp, $ps_mlp_gdev, $st_mlp_gdev, $x_gdev, $y_gdev)
+    ) setup = (reclaim_fn($backend, false))
+    tkan1_gdev = @belapsed CUDA.@sync(
+        loss_fn($kan1, $ps_kan1_gdev, $st_kan1_gdev, $x_gdev, $y_gdev)
+    ) setup = (reclaim_fn($backend, false))
+    tkan2_gdev = @belapsed CUDA.@sync(
+        loss_fn($kan2, $ps_kan2_gdev, $st_kan2_gdev, $x_gdev, $y_gdev)
+    ) setup = (reclaim_fn($backend, false))
+
+    @info "Vanilla Forward Benchmarks" tmlp_gdev tkan1_gdev tkan2_gdev
+
+    tmlp_xdev = @belapsed $lfn_mlp_compiled(
+        $mlp, $ps_mlp_xdev, $st_mlp_xdev, $x_xdev, $y_xdev
+    ) setup = (reclaim_fn($backend, true))
+    tkan1_xdev = @belapsed $lfn_kan1_compiled(
+        $kan1, $ps_kan1_xdev, $st_kan1_xdev, $x_xdev, $y_xdev
+    ) setup = (reclaim_fn($backend, true))
+    tkan2_xdev = @belapsed $lfn_kan2_compiled(
+        $kan2, $ps_kan2_xdev, $st_kan2_xdev, $x_xdev, $y_xdev
+    ) setup = (reclaim_fn($backend, true))
+
+    @info "Reactant Forward Benchmarks" tmlp_xdev tkan1_xdev tkan2_xdev
+
+    @info "Compiling Backward Functions"
+    grad_fn_mlp_compiled = @compile sync = true ∇enzyme_loss_fn(
+        mlp, ps_mlp_xdev, st_mlp_xdev, x_xdev, y_xdev
+    )
+    grad_fn_kan1_compiled = @compile sync = true ∇enzyme_loss_fn(
+        kan1, ps_kan1_xdev, st_kan1_xdev, x_xdev, y_xdev
+    )
+    grad_fn_kan2_compiled = @compile sync = true ∇enzyme_loss_fn(
+        kan2, ps_kan2_xdev, st_kan2_xdev, x_xdev, y_xdev
+    )
+
+    @info "Running Backward Benchmarks"
+
+    tmlp_rev_gdev = @belapsed CUDA.@sync(
+        ∇zygote_loss_fn($mlp, $ps_mlp_gdev, $st_mlp_gdev, $x_gdev, $y_gdev)
+    ) setup = (reclaim_fn($backend, false))
+    tkan1_rev_gdev = @belapsed CUDA.@sync(
+        ∇zygote_loss_fn($kan1, $ps_kan1_gdev, $st_kan1_gdev, $x_gdev, $y_gdev)
+    ) setup = (reclaim_fn($backend, false))
+    tkan2_rev_gdev = @belapsed CUDA.@sync(
+        ∇zygote_loss_fn($kan2, $ps_kan2_gdev, $st_kan2_gdev, $x_gdev, $y_gdev)
+    ) setup = (reclaim_fn($backend, false))
+
+    @info "Zygote Backward Benchmarks" tmlp_rev_gdev tkan1_rev_gdev tkan2_rev_gdev
+
+    tmlp_rev_xdev = @belapsed $grad_fn_mlp_compiled(
+        $mlp, $ps_mlp_xdev, $st_mlp_xdev, $x_xdev, $y_xdev
+    ) setup = (reclaim_fn($backend, true))
+    tkan1_rev_xdev = @belapsed $grad_fn_kan1_compiled(
+        $kan1, $ps_kan1_xdev, $st_kan1_xdev, $x_xdev, $y_xdev
+    ) setup = (reclaim_fn($backend, true))
+    tkan2_rev_xdev = @belapsed $grad_fn_kan2_compiled(
+        $kan2, $ps_kan2_xdev, $st_kan2_xdev, $x_xdev, $y_xdev
+    ) setup = (reclaim_fn($backend, true))
+
+    @info "Reactant Backward Benchmarks" tmlp_rev_xdev tkan1_rev_xdev tkan2_rev_xdev
+
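+    # The SpeedUp columns are vanilla time divided by Reactant time, so values
+    # above 1 mean the Reactant-compiled version is faster.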
+    push!(
+        df,
+        [
+            "MLP",
+            tmlp_gdev,
+            tmlp_xdev,
+            tmlp_gdev / tmlp_xdev,
+            tmlp_rev_gdev,
+            tmlp_rev_xdev,
+            tmlp_rev_gdev / tmlp_rev_xdev,
+        ],
+    )
+    push!(
+        df,
+        [
+            "KAN1",
+            tkan1_gdev,
+            tkan1_xdev,
+            tkan1_gdev / tkan1_xdev,
+            tkan1_rev_gdev,
+            tkan1_rev_xdev,
+            tkan1_rev_gdev / tkan1_rev_xdev,
+        ],
+    )
+    push!(
+        df,
+        [
+            "KAN2",
+            tkan2_gdev,
+            tkan2_xdev,
+            tkan2_gdev / tkan2_xdev,
+            tkan2_rev_gdev,
+            tkan2_rev_xdev,
+            tkan2_rev_gdev / tkan2_rev_xdev,
+        ],
+    )
+
+    display(df)
+    CSV.write(filename, df)
+
+    @info "Results saved to $filename"
+    return nothing
+end
diff --git a/perf/KAN/results_cpu.csv b/perf/KAN/results_cpu.csv
new file mode 100644
index 0000000000..88ab0d2387
--- /dev/null
+++ b/perf/KAN/results_cpu.csv
@@ -0,0 +1,4 @@
+Kind,Fwd Vanilla,Fwd Reactant,Fwd Reactant SpeedUp,Bwd Zygote,Bwd Reactant,Bwd Reactant SpeedUp
+MLP,0.000504229,0.000517288,0.9747548754272283,0.001275736,0.00091808,1.389569536423841
+KAN1,0.00637513,0.002134405,2.986841766206507,0.015504896,0.004843588,3.201117848999543
+KAN2,0.005651889,0.001934863,2.9210796836778625,0.014421515,0.004603253,3.1328964538772905
diff --git a/perf/KAN/results_gpu.csv b/perf/KAN/results_gpu.csv
new file mode 100644
index 0000000000..a592f1f856
--- /dev/null
+++ b/perf/KAN/results_gpu.csv
@@ -0,0 +1,4 @@
+Kind,Fwd Vanilla,Fwd Reactant,Fwd Reactant SpeedUp,Bwd Zygote,Bwd Reactant,Bwd Reactant SpeedUp
+MLP,0.000161916,7.392e-5,2.190422077922078,0.000502497,8.827e-5,5.692726860768098
+KAN1,0.00035248,8.2417e-5,4.276787555965396,0.001132969,0.000111215,10.187195971766398
+KAN2,0.000260817,7.7983e-5,3.344536629778285,0.000888154,0.000103864,8.551124547485173
diff --git a/src/Compiler.jl b/src/Compiler.jl
index 7bc4f29fa7..f99fb556f5 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -196,6 +196,8 @@ function optimization_passes(; no_nan::Bool=false)
         "slice_pad<1>",
         "dot_reshape_dot<1>",
         "concat_const_prop<1>",
+        "log_const_prop<1>",
+        "log_plus_one_const_prop<1>",
         "concat_fuse<1>",
         "pad_reshape_pad<1>",
         "pad_pad<1>",
@@ -265,6 +267,7 @@ function optimization_passes(; no_nan::Bool=false)
         "if_to_select<1>",
         "dynamic_update_slice_const_prop",
         "dynamic_gather_op_is_not_dynamic<16>",
+        "divide_sqrt_to_multiply_rsqrt<16>",
         "binary_op_transpose_simplify_add",
         "binary_op_transpose_simplify_sub",
         "binary_op_transpose_simplify_mul",