diff --git a/examples/2-deep-kernel-learning/Project.toml b/examples/2-deep-kernel-learning/Project.toml
index 94f146e4..9315aa23 100644
--- a/examples/2-deep-kernel-learning/Project.toml
+++ b/examples/2-deep-kernel-learning/Project.toml
@@ -12,7 +12,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 [compat]
 AbstractGPs = "0.3,0.4,0.5"
 Distributions = "0.25"
-Flux = "0.12, 0.13, 0.14"
+Flux = "0.15, 0.16"
 KernelFunctions = "0.10"
 Literate = "2"
 MLDataUtils = "0.5"
diff --git a/examples/2-deep-kernel-learning/script.jl b/examples/2-deep-kernel-learning/script.jl
index 9ce11164..2ae3b3fe 100644
--- a/examples/2-deep-kernel-learning/script.jl
+++ b/examples/2-deep-kernel-learning/script.jl
@@ -71,13 +71,14 @@ plot!(vcat(x_test...), mean.(pred_init); ribbon=std.(pred_init), label="Predicti
 # ## Training
 nmax = 200
 opt = Flux.Adam(0.1)
+state = Flux.setup(opt, ps)
 anim = Animation()
 for i in 1:nmax
     grads = gradient(ps) do
         loss(y_train)
     end
-    Flux.Optimise.update!(opt, ps, grads)
+    Flux.update!(state, ps, grads)
     if i % 10 == 0
         L = loss(y_train)
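
Note (not part of the diff): Flux 0.15/0.16, which this compat bump targets, removes the Flux.Optimise module and moves to the explicit-parameter training API, where setup/update! operate on the model itself rather than on an implicit parameter collection. Since the rest of script.jl is not visible in this hunk, the sketch below only illustrates that style under assumed names; neuralnet, x_train, y_train, and the loss signature are placeholders, not the example's actual definitions.

using Flux

# Hypothetical stand-ins; the real script builds its network, kernel, and data earlier.
neuralnet = Chain(Dense(1 => 8, relu), Dense(8 => 1))
x_train = rand(Float32, 1, 64)
y_train = rand(Float32, 1, 64)
loss(m, x, y) = Flux.Losses.mse(m(x), y)

opt = Adam(0.1)
state = Flux.setup(opt, neuralnet)   # optimiser state tied to the model's parameters

for i in 1:200
    # Explicit-style gradient: differentiate with respect to the model itself.
    grads = Flux.gradient(m -> loss(m, x_train, y_train), neuralnet)
    Flux.update!(state, neuralnet, grads[1])   # replaces the old implicit Flux.Optimise.update!(opt, ps, grads)
end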