diff --git a/Project.toml b/Project.toml
index 5f86abc9d1..9cdf88807d 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "Lux"
 uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
 authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
-version = "1.4.2"
+version = "1.4.3"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -83,7 +83,7 @@ Adapt = "4.1"
 ArgCheck = "2.3"
 ArrayInterface = "7.17.1"
 CUDA = "5.3.2"
-ChainRulesCore = "1.24"
+ChainRulesCore = "1.25"
 Compat = "4.16"
 ComponentArrays = "0.15.18"
 ConcreteStructs = "0.2.3"
@@ -106,11 +106,11 @@ MPI = "0.20.19"
 MacroTools = "0.5.13"
 Markdown = "1.10"
 NCCL = "0.1.1"
-NNlib = "0.9.24"
+NNlib = "0.9.26"
 Optimisers = "0.4.1"
 Preferences = "1.4.3"
 Random = "1.10"
-Reactant = "0.2.8"
+Reactant = "0.2.12"
 Reexport = "1.2.2"
 ReverseDiff = "1.15"
 SIMDTypes = "0.1"
diff --git a/docs/Project.toml b/docs/Project.toml
index 3eb44b24ef..0e561d7625 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -33,7 +33,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 [compat]
 ADTypes = "1.10"
 Adapt = "4"
-ChainRulesCore = "1.24"
+ChainRulesCore = "1.25"
 ComponentArrays = "0.15.18"
 Documenter = "1.4"
 DocumenterVitepress = "0.1.3"
@@ -51,12 +51,12 @@ LuxCore = "1.2"
 LuxLib = "1.3.4"
 LuxTestUtils = "1.5"
 MLDataDevices = "1.6"
-NNlib = "0.9.24"
+NNlib = "0.9.26"
 Optimisers = "0.4.1"
 Pkg = "1.10"
 Printf = "1.10"
 Random = "1.10"
-Reactant = "0.2.8"
+Reactant = "0.2.12"
 StableRNGs = "1"
 StaticArrays = "1"
 WeightInitializers = "1"
diff --git a/docs/make.jl b/docs/make.jl
index fac7081d55..c9f2e98a3c 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -29,7 +29,7 @@ pages = [
             "tutorials/intermediate/1_NeuralODE.md",
             "tutorials/intermediate/2_BayesianNN.md",
             "tutorials/intermediate/3_HyperNet.md",
-            "tutorials/intermediate/4_PINN2DPDE.md"
+            "tutorials/intermediate/4_PINN2DPDE.md",
         ],
         "Advanced" => [
             "tutorials/advanced/1_GravitationalWaveForm.md"
diff --git a/ext/LuxReactantExt/LuxReactantExt.jl b/ext/LuxReactantExt/LuxReactantExt.jl
index 89ce8831fe..6f49b076d0 100644
--- a/ext/LuxReactantExt/LuxReactantExt.jl
+++ b/ext/LuxReactantExt/LuxReactantExt.jl
@@ -2,13 +2,22 @@ module LuxReactantExt
 
 using Enzyme: Enzyme, Const, Duplicated, Active
 using Optimisers: Optimisers
-using Reactant: Reactant, @compile, TracedRArray, TracedRNumber
+using Reactant: Reactant, @compile, AnyTracedRArray, TracedRArray, TracedRNumber
 using Setfield: @set!
 using Static: False
 
-using Lux: Lux, LuxOps, Training
+using Lux: Lux, LuxOps, Training, Utils
 using Lux.Training: TrainingBackendCache, ReactantBackend
 
+Lux.is_extension_loaded(::Val{:Reactant}) = true
+
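+# These implement the function stubs declared in `src/utils.jl`, letting Lux convert
+# values to Reactant arrays/numbers whenever this extension is loaded.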
+Utils.to_rarray(x; kwargs...) = Reactant.to_rarray(x; kwargs...)
+
+function Utils.promote_to(::Type{T}, x::Number) where {T <: Number}
+    x isa Reactant.TracedType && return x
+    return Reactant.ConcreteRNumber{T}(x)
+end
+
 include("patches.jl")
 include("training.jl")
 
diff --git a/ext/LuxReactantExt/patches.jl b/ext/LuxReactantExt/patches.jl
index 8b13789179..f9f4519e0a 100644
--- a/ext/LuxReactantExt/patches.jl
+++ b/ext/LuxReactantExt/patches.jl
@@ -1 +1,4 @@
+Utils.vec(x::AnyTracedRArray) = Reactant.TracedUtils.materialize_traced_array(vec(x))
 
+# XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint
+Lux.calculate_pool_dims(g::Lux.GlobalPoolMode, ::TracedRArray) = g
diff --git a/ext/LuxReactantExt/training.jl b/ext/LuxReactantExt/training.jl
index 605093e1ea..c35d5cb054 100644
--- a/ext/LuxReactantExt/training.jl
+++ b/ext/LuxReactantExt/training.jl
@@ -1,3 +1,28 @@
+mutable struct StatsAndNewStateWrapper
+    stats::Any
+    st::Any
+end
+
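+# Wrap the user-supplied objective so Enzyme differentiates a scalar-valued function;
+# the auxiliary outputs (updated state and stats) are stashed in the mutable wrapper.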
+function wrapped_objective_function(
+        fn::F, model, ps, st, data, cache::StatsAndNewStateWrapper
+) where {F}
+    loss, stₙ, stats = fn(model, ps, st, data)
+    cache.stats = stats
+    cache.st = stₙ
+    return loss
+end
+
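+# `ReverseWithPrimal` returns both the primal loss and the derivatives; `derivs[3]`
+# corresponds to `ps`, the only non-`Const` argument.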
+function compute_gradients_internal(objective_function::F, model, data, ps, st) where {F}
+    stats_wrapper = StatsAndNewStateWrapper(nothing, nothing)
+    res = Enzyme.gradient(
+        Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI),
+        Const(wrapped_objective_function), Const(objective_function),
+        Const(model), ps, Const(st), Const(data), Const(stats_wrapper)
+    )
+    loss, dps = res.val, res.derivs[3]
+    return dps, loss, stats_wrapper.stats, stats_wrapper.st
+end
+
 function Lux.Training.compute_gradients_impl(
         backend::ReactantBackend, objective_function::F,
         data, ts::Training.TrainState) where {F}
@@ -22,18 +47,33 @@ function Lux.Training.compute_gradients_impl(::ReactantBackend, obj_fn::F, data,
     return grads, loss, stats, ts
 end
 
-function compute_gradients_internal(objective_function::F, model, data, ps, st) where {F}
-    dps = Enzyme.make_zero(ps)
-    _, (loss, stₙ, stats) = Enzyme.autodiff(
-        Enzyme.ReverseWithPrimal, Const(objective_function), Active, Const(model),
-        Duplicated(ps, dps), Const(st), Const(data))
-    return dps, loss, stats, stₙ
-end
-
 for inplace in ("!", "")
     fname = Symbol(:single_train_step_impl, inplace)
     internal_fn = Symbol(:compute_gradients_internal_and_step, inplace)
+    apply_gradients_fn = Symbol(:apply_gradients, inplace)
+    update_fn = Symbol(:update, inplace)
+
+    # Ideally users never hit this dispatch, but it is still good to have as a fallback
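+    # The compiled `Optimisers.update(!)` call is cached in `ts.cache.extras`, so it is
+    # only compiled on the first invocation.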
+    @eval function Lux.Training.$(apply_gradients_fn)(
+            ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}}, grads
+    )
+        if hasfield(typeof(ts.cache.extras), :update_function)
+            update_function = ts.cache.extras.update_function
+        else
+            update_function = @compile Optimisers.$(update_fn)(
+                ts.optimizer_state, ts.parameters, grads)
+            @set! ts.cache.extras = merge(ts.cache.extras, (; update_function))
+        end
 
+        opt_state, ps = update_function(ts.optimizer_state, ts.parameters, grads)
+        @set! ts.parameters = ps
+        @set! ts.optimizer_state = opt_state
+        @set! ts.step = ts.step + 1
+        return ts
+    end
+
+    # XXX: Should we add a check that the inputs to this function match the ones used in
+    #      the compiled function? We could re-trigger the compilation with a warning.
     @eval function Lux.Training.$(fname)(backend::ReactantBackend, objective_function::F,
             data, ts::Training.TrainState) where {F}
         compiled_grad_and_step_function = @compile $(internal_fn)(
@@ -68,27 +108,13 @@ for inplace in ("!", "")
 
         return grads, loss, stats, ts
     end
-end
 
-function compute_gradients_internal_and_step(objective_function::F, model, data, ps,
-        st, opt_state) where {F}
-    dps = Enzyme.make_zero(ps)
-    _, (loss, stₙ, stats) = Enzyme.autodiff(
-        Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI),
-        Const(objective_function), Active, Const(model),
-        Duplicated(ps, dps), Const(st), Const(data))
-    opt_state, ps = Optimisers.update(opt_state, ps, dps)
-    return dps, ps, loss, stats, stₙ, opt_state
-end
-
-function compute_gradients_internal_and_step!(objective_function::F, model, data, ps,
-        st, opt_state) where {F}
-    dps = Enzyme.make_zero(ps)
-    _, (loss, stₙ, stats) = Enzyme.autodiff(
-        Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI),
-        Const(objective_function), Active, Const(model),
-        Duplicated(ps, dps), Const(st), Const(data))
-    # XXX: Inplace updates not actually inplace
-    opt_state, ps = Optimisers.update!(opt_state, ps, dps)
-    return dps, ps, loss, stats, stₙ, opt_state
+    # XXX: The in-place version is not actually in-place
+    @eval function $(internal_fn)(
+            objective_function::F, model, data, ps, st, opt_state) where {F}
+        dps, loss, stats, stₙ = compute_gradients_internal(
+            objective_function, model, data, ps, st)
+        opt_state, ps = Optimisers.$(update_fn)(opt_state, ps, dps)
+        return dps, ps, loss, stats, stₙ, opt_state
+    end
 end
diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml
index aa927150be..5c22c0be3f 100644
--- a/lib/LuxLib/Project.toml
+++ b/lib/LuxLib/Project.toml
@@ -1,7 +1,7 @@
 name = "LuxLib"
 uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
 authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
-version = "1.3.10"
+version = "1.3.11"
 
 [deps]
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -77,7 +77,7 @@ LuxCore = "1.2"
 MKL = "0.7"
 MLDataDevices = "1.6"
 Markdown = "1.10"
-NNlib = "0.9.24"
+NNlib = "0.9.26"
 Octavian = "0.3.28"
 Preferences = "1.4.3"
 Polyester = "0.7.15"
diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl
index d7b0225937..a7234c5eb4 100644
--- a/lib/LuxLib/ext/LuxLibTrackerExt.jl
+++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl
@@ -97,7 +97,7 @@ for RM in (:TrackedVector, :Nothing, :AbstractVector),
     Utils.is_tracked(RM, RV, S, B, XT) || continue
 
     @eval Tracker.@grad_from_chainrules LuxLib.Impl.batchnorm_cudnn(
-        γ::$S, β::$B, x::$XT, rμ::$RM, rσ²::$RV, m::Real, ϵ::Real, training::StaticBool)
+        γ::$S, β::$B, x::$XT, rμ::$RM, rσ²::$RV, m, ϵ, training::StaticBool)
 end
 
 # Utils extensions
diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl
index 77e59d3e4b..b35af417bc 100644
--- a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl
+++ b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl
@@ -21,7 +21,7 @@ include("batchnorm.jl")
 function Impl.batchnorm(x::Union{<:CuArray{T, 2}, <:CuArray{T, 4}, <:CuArray{T, 5}},
         γ::Optional{<:CuVector{T}}, β::Optional{<:CuVector{T}},
         rμ::Optional{<:CuVector{T}}, rσ²::Optional{<:CuVector{T}},
-        training::StaticBool, σ::F, m::Real, ϵ::Real) where {T <: cuDNNFloat, F}
+        training::StaticBool, σ::F, m, ϵ) where {T <: cuDNNFloat, F}
     rμₙ, rσ²ₙ = Impl.get_batchnorm_statistics(x, rμ, rσ², training)
     y = Impl.batchnorm_cudnn(γ, β, x, rμₙ, rσ²ₙ, m, ϵ, training)[1]
     return Impl.activation!!(σ, y), safe_vec(rμₙ), safe_vec(rσ²ₙ)
diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl
index 05964f0c6b..bba8a5af27 100644
--- a/lib/LuxLib/src/api/batchnorm.jl
+++ b/lib/LuxLib/src/api/batchnorm.jl
@@ -37,7 +37,7 @@ mean and variance.
 function batchnorm(x::AbstractArray{T, N}, γ::Optional{<:AbstractVector},
         β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector},
         rσ²::Optional{<:AbstractVector}, training::TrainingType, act::F=identity,
-        momentum::Real=0.1f0, epsilon::Real=default_epsilon(x)) where {F, T, N}
+        momentum=0.1f0, epsilon=default_epsilon(x)) where {F, T, N}
     σ = select_fastest_activation(act, x, γ, β, rμ, rσ²)
     y, rμ, rσ² = batchnorm_impl(
         x, γ, β, rμ, rσ², static_training_mode(training, x, γ, β, rμ, rσ²),
diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl
index 4e6a7bff86..1053ff9dfe 100644
--- a/lib/LuxLib/src/api/groupnorm.jl
+++ b/lib/LuxLib/src/api/groupnorm.jl
@@ -1,6 +1,6 @@
 @doc doc"""
     groupnorm(x, scale, bias, groups::Int, σ::F=identity,
-        epsilon::Real=eps(eltype(x)) ^ (5 // 7))
+        epsilon=eps(eltype(x)) ^ (5 // 7))
 
 Group Normalization. For details see [1].
 
@@ -30,7 +30,7 @@ The normalized array is returned.
 """
 function groupnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector},
         bias::Optional{<:AbstractVector}, groups::Int, σ::F=identity,
-        epsilon::Real=default_epsilon(x)) where {F, N}
+        epsilon=default_epsilon(x)) where {F, N}
     assert_valid_groupnorm_arguments(x, scale, bias, groups)
     return groupnorm_impl(
         x, scale, bias, groups, select_fastest_activation(σ, x, scale, bias), epsilon)
diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl
index 1587855242..259e14bb4e 100644
--- a/lib/LuxLib/src/api/instancenorm.jl
+++ b/lib/LuxLib/src/api/instancenorm.jl
@@ -36,7 +36,7 @@ mean and variance.
 """
 function instancenorm(x::AbstractArray, γ::Optional{<:AbstractVector},
         β::Optional{<:AbstractVector}, training::TrainingType,
-        σ::F=identity, epsilon::Real=default_epsilon(x)) where {F}
+        σ::F=identity, epsilon=default_epsilon(x)) where {F}
     # This API is kept for legacy purposes when we didn't support passing running stats
     return instancenorm(x, γ, β, nothing, nothing, training, σ, nothing, epsilon)
 end
@@ -44,7 +44,7 @@ end
 function instancenorm(x::AbstractArray, γ::Optional{<:AbstractVector},
         β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector},
         rσ²::Optional{<:AbstractVector}, training::TrainingType, σ::F=identity,
-        momentum::Optional{<:Real}=0.1f0, epsilon::Real=default_epsilon(x)) where {F}
+        momentum::Optional{<:Real}=0.1f0, epsilon=default_epsilon(x)) where {F}
     assert_valid_instancenorm_arguments(x)
 
     y, rμₙ, rσ²ₙ = instancenorm_impl(
diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl
index eb147d30ef..7148fbb0da 100644
--- a/lib/LuxLib/src/api/layernorm.jl
+++ b/lib/LuxLib/src/api/layernorm.jl
@@ -36,7 +36,7 @@ Normalized Array of same size as `x`.
 """
 function layernorm(x::AbstractArray{xT, N}, scale::Optional{<:AbstractArray},
         bias::Optional{<:AbstractArray}, σ::F=identity, dims=1:(N - 1),
-        epsilon::Real=default_epsilon(x)) where {F, xT, N}
+        epsilon=default_epsilon(x)) where {F, xT, N}
     return layernorm_impl(
         x, scale, bias, select_fastest_activation(σ, x, scale, bias), dims, epsilon)
 end
diff --git a/lib/LuxLib/src/impl/batchnorm.jl b/lib/LuxLib/src/impl/batchnorm.jl
index 995aacf857..d37bee3464 100644
--- a/lib/LuxLib/src/impl/batchnorm.jl
+++ b/lib/LuxLib/src/impl/batchnorm.jl
@@ -27,7 +27,7 @@ CRC.@non_differentiable get_batchnorm_statistics(::Any...)
 function batchnorm(x::AbstractArray{xT, N}, γ::Optional{<:AbstractVector},
         β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector},
         rσ²::Optional{<:AbstractVector}, training::StaticBool, act::F,
-        momentum::Real, ϵ::Real) where {F, xT, N}
+        momentum, ϵ) where {F, xT, N}
     (μ, σ²), (rμ, rσ²) = compute_batch_statistics(
         x, reshape_norm_dims(x, rμ), reshape_norm_dims(x, rσ²),
         batchnorm_reduce_dims(x), training, momentum)
@@ -37,7 +37,7 @@ end
 function batchnorm_affine_normalize(
         act::F, x::AbstractArray{xT, N}, μ::AbstractArray{μT, N},
         σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector},
-        β::Optional{<:AbstractVector}, ϵ::Real) where {F, xT, μT, σ²T, N}
+        β::Optional{<:AbstractVector}, ϵ) where {F, xT, μT, σ²T, N}
     return batchnorm_affine_normalize(
         internal_operation_mode((x, μ, σ², γ, β)), act, x, μ, σ², γ, β, ϵ)
 end
@@ -45,7 +45,7 @@ end
 function batchnorm_affine_normalize(
         ::GenericBroadcastOp, act::F, x::AbstractArray{xT, N}, μ::AbstractArray{μT, N},
         σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector},
-        β::Optional{<:AbstractVector}, ϵ::Real) where {F, xT, μT, σ²T, N}
+        β::Optional{<:AbstractVector}, ϵ) where {F, xT, μT, σ²T, N}
     return affine_normalize(
         act, x, μ, σ², reshape_norm_dims(x, γ), reshape_norm_dims(x, β), ϵ)
 end
@@ -54,7 +54,7 @@ function batchnorm_affine_normalize(
         opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{xT, N},
         μ::AbstractArray{μT, N}, σ²::AbstractArray{σ²T, N},
         γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector},
-        ϵ::Real) where {F, xT, μT, σ²T, N}
+        ϵ) where {F, xT, μT, σ²T, N}
     x′ = reshape(x, :, size(x, N - 1), size(x, N))
     return reshape(
         batchnorm_affine_normalize_internal(opmode, act, x′, vec(μ), vec(σ²), γ, β, ϵ),
@@ -64,7 +64,7 @@ end
 @stable default_mode="disable" function batchnorm_affine_normalize_internal(
         opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{xT, 3},
         μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector},
-        β::Optional{<:AbstractVector}, ϵ::Real) where {F, xT}
+        β::Optional{<:AbstractVector}, ϵ) where {F, xT}
     y = similar(x,
         promote_type(safe_eltype(x), safe_eltype(μ), safe_eltype(σ²),
             safe_eltype(γ), safe_eltype(β)))
@@ -75,7 +75,7 @@ end
 function batchnorm_affine_normalize_internal!(
         y::AbstractArray{yT, 3}, opmode::LoopedArrayOp, act::F, x::AbstractArray{xT, 3},
         μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector},
-        β::Optional{<:AbstractVector}, ϵ::Real,
+        β::Optional{<:AbstractVector}, ϵ,
         γ′::Optional{<:AbstractVector}=nothing) where {F, xT, yT}
     N = size(y, 2)
     γ′ = γ′ === nothing ?
@@ -225,7 +225,7 @@ end
 function batchnorm_affine_normalize_internal!(
         y::AbstractArray{yT, 3}, ::GPUBroadcastOp, act::F, x::AbstractArray{xT, 3},
         μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector},
-        β::Optional{<:AbstractVector}, ϵ::Real,
+        β::Optional{<:AbstractVector}, ϵ,
         γ′::Optional{<:AbstractVector}=nothing) where {F, xT, yT}
     backend = KA.get_backend(y)
     run_ka_kernel(
@@ -278,7 +278,7 @@ function CRC.rrule(
         cfg::RuleConfig{>:HasReverseMode}, ::typeof(batchnorm_affine_normalize_internal),
         opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{T, N},
         μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector},
-        β::Optional{<:AbstractVector}, ϵ::Real) where {F, T, N}
+        β::Optional{<:AbstractVector}, ϵ) where {F, T, N}
     y = similar(x,
         promote_type(safe_eltype(x), safe_eltype(μ), safe_eltype(σ²),
             safe_eltype(γ), safe_eltype(β)))
@@ -304,7 +304,7 @@ end
 
 function ∇batchnorm_affine_normalize(opmode::LoopedArrayOp, ∂y::AbstractArray{∂yT, 3},
         x::AbstractArray{xT, 3}, μ::AbstractVector, σ²::AbstractVector,
-        γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real,
+        γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ,
         γ′::AbstractVector) where {∂yT, xT}
     ∂x, ∂μ, ∂σ² = similar(x), similar(μ), similar(σ²)
     ∂γ = γ === nothing ? nothing : similar(γ)
@@ -322,7 +322,7 @@ function ∇batchnorm_affine_normalize_cpu!(
         ∂x::AbstractArray{∂xT, 3}, ∂μ::AbstractVector{∂μT},
         ∂σ²::AbstractVector{∂σ²T}, ::Nothing, ::Nothing, ∂y::AbstractArray{∂yT, 3},
         x::AbstractArray{xT, 3}, μ::AbstractVector, σ²::AbstractVector, ::Nothing,
-        ϵ::Real, γ′::AbstractVector) where {∂xT, ∂μT, ∂σ²T, ∂yT, xT}
+        ϵ, γ′::AbstractVector) where {∂xT, ∂μT, ∂σ²T, ∂yT, xT}
     half = eltype(∂σ²)(0.5)
 
     fill!(∂μ, 0)
@@ -361,7 +361,7 @@ function ∇batchnorm_affine_normalize_cpu!(
         ∂x::AbstractArray{∂xT, 3}, ∂μ::AbstractVector{∂μT},
         ∂σ²::AbstractVector{∂σ²T}, ∂γ::AbstractVector{∂γT},
         ∂β::AbstractVector{∂βT}, ∂y::AbstractArray{∂yT, 3}, x::AbstractArray{xT, 3},
-        μ::AbstractVector, σ²::AbstractVector, γ::AbstractVector, ϵ::Real,
+        μ::AbstractVector, σ²::AbstractVector, γ::AbstractVector, ϵ,
         γ′::AbstractVector) where {∂xT, ∂μT, ∂σ²T, ∂γT, ∂βT, ∂yT, xT}
     half = eltype(∂σ²)(0.5)
 
@@ -406,7 +406,7 @@ end
 function ∇batchnorm_affine_normalize(
         opmode::AbstractInternalArrayOpMode, ∂y::AbstractArray{∂yT, 3},
         x::AbstractArray{xT, 3}, μ::AbstractVector, σ²::AbstractVector,
-        γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real,
+        γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ,
         γ′::AbstractVector) where {∂yT, xT}
     ∂x, ∂σ² = similar(x), similar(σ², size(x))
     ∂γ = γ === nothing ? nothing : similar(γ, size(x))
@@ -425,7 +425,7 @@ function ∇batchnorm_affine_normalize!(
         ∂x::AbstractArray{∂xT, 3}, ∂σ²::AbstractArray{∂σ²T, 3},
         ∂γ::Optional{<:AbstractArray{<:Any, 3}}, ::GPUBroadcastOp,
         ∂y::AbstractArray{∂yT, 3}, x::AbstractArray{xT, 3}, μ::AbstractVector,
-        σ²::AbstractVector, γ::Optional{<:AbstractVector}, ϵ::Real,
+        σ²::AbstractVector, γ::Optional{<:AbstractVector}, ϵ,
         γ′::AbstractVector) where {∂xT, ∂σ²T, ∂yT, xT}
     backend = KA.get_backend(∂x)
     run_ka_kernel(
diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl
index 5b4248291f..10dda2f69e 100644
--- a/lib/LuxLib/src/impl/dropout.jl
+++ b/lib/LuxLib/src/impl/dropout.jl
@@ -62,22 +62,22 @@ function alpha_dropout(noise::AbstractArray, p, x::AbstractArray, α, A, B)
 end
 
 @stable default_mode="disable" function alpha_dropout(
-        ::AbstractInternalArrayOpMode, noise::AbstractArray, p::Real,
-        x::AbstractArray{T}, α::Real, A::Real, B::Real) where {T}
+        ::AbstractInternalArrayOpMode, noise::AbstractArray, p,
+        x::AbstractArray{T}, α, A, B) where {T}
     A′, B′, α = T(A), T(B), T(α)
     return @. muladd(ifelse(noise > p, x, α), A′, B′)
 end
 
 @stable default_mode="disable" function alpha_dropout(
-        opmode::LoopedArrayOp, noise::AbstractArray, p::Real,
-        x::AbstractArray, α::Real, A::Real, B::Real)
+        opmode::LoopedArrayOp, noise::AbstractArray, p,
+        x::AbstractArray, α, A, B)
     res = similar(x, promote_type(typeof(p), typeof(α)))
     alpha_dropout!(res, opmode, noise, p, x, α, A, B)
     return res
 end
 
 function CRC.rrule(::typeof(alpha_dropout), ::LoopedArrayOp, noise::AbstractArray,
-        p::Real, x::AbstractArray, α::Real, A::Real, B::Real)
+        p, x::AbstractArray, α, A, B)
     cond = similar(noise, Bool)
     y = similar(x, promote_type(typeof(p), typeof(α), typeof(A), typeof(B), eltype(x)))
     @simd ivdep for I in eachindex(noise, x, y, cond)
@@ -99,7 +99,7 @@ function CRC.rrule(::typeof(alpha_dropout), ::LoopedArrayOp, noise::AbstractArra
 end
 
 function CRC.rrule(::typeof(alpha_dropout), ::AbstractInternalArrayOpMode,
-        noise::AbstractArray, p::Real, x::AbstractArray, α::Real, A::Real, B::Real)
+        noise::AbstractArray, p, x::AbstractArray, α, A, B)
     cond = noise .> p
     y = @. ifelse(cond, x, α) * A + B
 
@@ -114,7 +114,7 @@ end
 
 function alpha_dropout!(
         res::AbstractArray{T}, ::LoopedArrayOp, noise::AbstractArray{T},
-        p::Real, x::AbstractArray{T}, α::Real, A::Real, B::Real) where {T}
+        p, x::AbstractArray{T}, α, A, B) where {T}
     @simd ivdep for I in eachindex(noise, x, res)
         res[I] = ifelse(noise[I] > p, x[I], α) * A + B
     end
diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl
index 9a64fd7350..df52b8508b 100644
--- a/lib/LuxLib/src/impl/groupnorm.jl
+++ b/lib/LuxLib/src/impl/groupnorm.jl
@@ -3,7 +3,7 @@ groupnorm_reduce_dims(::AbstractArray{T, N}) where {T, N} = ntuple(static, N - 1
 CRC.@non_differentiable groupnorm_reduce_dims(::Any)
 
 function groupnorm(x::AbstractArray{xT, N}, γ::Optional{<:AbstractVector},
-        β::Optional{<:AbstractVector}, groups::Int, act::F, ϵ::Real) where {F, N, xT}
+        β::Optional{<:AbstractVector}, groups::Int, act::F, ϵ) where {F, N, xT}
     x′ = reshape(x, size(x)[1:(N - 2)]..., size(x, N - 1) ÷ groups, groups, size(x, N))
     (μ, σ²), _ = compute_batch_statistics(
         x′, nothing, nothing, groupnorm_reduce_dims(x), False(), nothing)
@@ -13,7 +13,7 @@ end
 function groupnorm_affine_normalize(
         act::F, x::AbstractArray{xT, N}, μ::AbstractArray{μT, N},
         σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector},
-        β::Optional{<:AbstractVector}, ϵ::Real) where {F, N, xT, μT, σ²T}
+        β::Optional{<:AbstractVector}, ϵ) where {F, N, xT, μT, σ²T}
     return groupnorm_affine_normalize(
         internal_operation_mode((x, μ, σ², γ, β)), act, x, μ, σ², γ, β, ϵ)
 end
@@ -21,7 +21,7 @@ end
 function groupnorm_affine_normalize(
         ::GenericBroadcastOp, act::F, x::AbstractArray{xT, N}, μ::AbstractArray{μT, N},
         σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector},
-        β::Optional{<:AbstractVector}, ϵ::Real) where {F, N, xT, μT, σ²T}
+        β::Optional{<:AbstractVector}, ϵ) where {F, N, xT, μT, σ²T}
     return affine_normalize(
         act, x, μ, σ², reshape_norm_dims(x, γ), reshape_norm_dims(x, β), ϵ)
 end
@@ -29,7 +29,7 @@ end
 @generated function groupnorm_affine_normalize(
         opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{xT, N},
         μ::AbstractArray{μT, N}, σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector},
-        β::Optional{<:AbstractVector}, ϵ::Real) where {F, N, xT, μT, σ²T}
+        β::Optional{<:AbstractVector}, ϵ) where {F, N, xT, μT, σ²T}
     reshape_calls = if γ != Nothing
         quote
             γ′ = reshape(γ, 1, size(x, N - 2), size(x, N - 1), 1)
@@ -57,7 +57,7 @@ end
         opmode::AbstractInternalArrayOpMode, act::F,
         x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4},
         γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}},
-        ϵ::Real) where {F, xT, μT, σ²T}
+        ϵ) where {F, xT, μT, σ²T}
     y = similar(x,
         promote_type(safe_eltype(x), safe_eltype(μ), safe_eltype(σ²),
             safe_eltype(γ), safe_eltype(β)))
@@ -69,7 +69,7 @@ function groupnorm_affine_normalize_internal!(
         y::AbstractArray{yT, 4}, opmode::LoopedArrayOp, act::F, x::AbstractArray{xT, 4},
         μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4},
         γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}},
-        ϵ::Real) where {F, xT, yT, μT, σ²T}
+        ϵ) where {F, xT, yT, μT, σ²T}
     if unsafe_known(fuse_cpu_activation(act))
         groupnorm_affine_normalize_act_cpu!(y, x, μ, σ², γ, β, ϵ, act)
     else
@@ -82,7 +82,7 @@ end
 function groupnorm_affine_normalize_act_cpu!(
         y::AbstractArray{yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4},
         σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}},
-        β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real, act::F) where {F, xT, yT, μT, σ²T}
+        β::Optional{<:AbstractArray{<:Any, 4}}, ϵ, act::F) where {F, xT, yT, μT, σ²T}
     if size(y, 1) == 1
         groupnorm_affine_normalize_act_3d_serial_cpu!(y, x, μ, σ², γ, β, ϵ, act)
     else
@@ -93,7 +93,7 @@ end
 function groupnorm_affine_normalize_act_3d_serial_cpu!(
         y::AbstractArray{yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4},
         σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}},
-        β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real, σ::F) where {F, xT, yT, μT, σ²T}
+        β::Optional{<:AbstractArray{<:Any, 4}}, ϵ, σ::F) where {F, xT, yT, μT, σ²T}
     if γ === nothing && β === nothing
         @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3)
             γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ))
@@ -117,7 +117,7 @@ end
 function groupnorm_affine_normalize_act_4d_serial_cpu!(
         y::AbstractArray{yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4},
         σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}},
-        β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real, σ::F) where {F, xT, yT, μT, σ²T}
+        β::Optional{<:AbstractArray{<:Any, 4}}, ϵ, σ::F) where {F, xT, yT, μT, σ²T}
     if γ === nothing && β === nothing
         @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3)
             γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ))
@@ -145,7 +145,7 @@ end
 function groupnorm_affine_normalize_cpu!(
         y::AbstractArray{yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4},
         σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}},
-        β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real) where {xT, yT, μT, σ²T}
+        β::Optional{<:AbstractArray{<:Any, 4}}, ϵ) where {xT, yT, μT, σ²T}
     if size(y, 1) == 1
         groupnorm_affine_normalize_3d_serial_cpu!(y, x, μ, σ², γ, β, ϵ)
     else
@@ -156,7 +156,7 @@ end
 @inline function groupnorm_affine_normalize_3d_serial_cpu!(
         y::AbstractArray{yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4},
         σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}},
-        β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real) where {xT, yT, μT, σ²T}
+        β::Optional{<:AbstractArray{<:Any, 4}}, ϵ) where {xT, yT, μT, σ²T}
     if γ === nothing && β === nothing
         @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3)
             γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ))
@@ -180,7 +180,7 @@ end
 @inline function groupnorm_affine_normalize_4d_serial_cpu!(
         y::AbstractArray{yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4},
         σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}},
-        β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real) where {xT, yT, μT, σ²T}
+        β::Optional{<:AbstractArray{<:Any, 4}}, ϵ) where {xT, yT, μT, σ²T}
     if γ === nothing && β === nothing
         @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3)
             γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ))
@@ -209,7 +209,7 @@ function groupnorm_affine_normalize_internal!(
         y::AbstractArray{yT, 4}, ::GPUBroadcastOp, act::F, x::AbstractArray{xT, 4},
         μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4},
         γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}},
-        ϵ::Real) where {F, xT, yT, μT, σ²T}
+        ϵ) where {F, xT, yT, μT, σ²T}
     backend = KA.get_backend(y)
     run_ka_kernel(
         groupnorm_affine_normalize_kernel!, backend, nothing, size(y),
@@ -240,7 +240,7 @@ function CRC.rrule(
         opmode::AbstractInternalArrayOpMode, f::F,
         x::AbstractArray{T, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4},
         γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}},
-        ϵ::Real) where {F, T, μT, σ²T}
+        ϵ) where {F, T, μT, σ²T}
     y = similar(x,
         promote_type(safe_eltype(x), safe_eltype(μ), safe_eltype(σ²),
             safe_eltype(γ), safe_eltype(β)))
@@ -264,7 +264,7 @@ function ∇groupnorm_affine_normalize(
         opmode::AbstractInternalArrayOpMode, ∂y::AbstractArray{∂yT, 4},
         x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4},
         γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}},
-        ϵ::Real) where {∂yT, xT, μT, σ²T}
+        ϵ) where {∂yT, xT, μT, σ²T}
     ∂x, ∂σ² = similar(x), similar(σ², size(x))
     ∂γ = γ === nothing ? nothing : similar(γ, size(x))
 
@@ -281,7 +281,7 @@ end
 function ∇groupnorm_affine_normalize(::LoopedArrayOp, ∂y::AbstractArray{∂yT, 4},
         x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4},
         γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}},
-        ϵ::Real) where {∂yT, xT, μT, σ²T}
+        ϵ) where {∂yT, xT, μT, σ²T}
     ∂x, ∂μ, ∂σ² = similar(x), similar(μ), similar(σ²)
     ∂γ = γ === nothing ? nothing : similar(γ)
     ∂β = β === nothing ? nothing : similar(β)
@@ -298,7 +298,7 @@ function ∇groupnorm_affine_normalize_cpu!(
         ∂x::AbstractArray{∂xT, 4}, ∂μ::AbstractArray{∂μT, 4}, ∂σ²::AbstractArray{∂σ²T, 4},
         ::Nothing, ::Nothing, ∂y::AbstractArray{∂yT, 4}, x::AbstractArray{xT, 4},
         μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, ::Nothing,
-        ϵ::Real) where {∂xT, ∂μT, ∂σ²T, ∂yT, xT, μT, σ²T}
+        ϵ) where {∂xT, ∂μT, ∂σ²T, ∂yT, xT, μT, σ²T}
     half = eltype(∂σ²)(0.5)
 
     fill!(∂μ, 0)
@@ -340,7 +340,7 @@ function ∇groupnorm_affine_normalize_cpu!(
         ∂γ::AbstractArray{∂γT, 4}, ∂β::AbstractArray{∂βT, 4}, ∂y::AbstractArray{∂yT, 4},
         x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4},
         γ::AbstractArray{γT, 4},
-        ϵ::Real) where {∂xT, ∂μT, ∂σ²T, ∂γT, ∂βT, ∂yT, xT, μT, σ²T, γT}
+        ϵ) where {∂xT, ∂μT, ∂σ²T, ∂γT, ∂βT, ∂yT, xT, μT, σ²T, γT}
     half = eltype(∂σ²)(0.5)
 
     fill!(∂μ, 0)
@@ -391,7 +391,7 @@ function ∇groupnorm_affine_normalize!(
         ∂γ::Optional{<:AbstractArray{<:Any, 4}}, ::GPUBroadcastOp,
         ∂y::AbstractArray{∂yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4},
         σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}},
-        ϵ::Real) where {∂xT, ∂σ²T, ∂yT, xT, μT, σ²T}
+        ϵ) where {∂xT, ∂σ²T, ∂yT, xT, μT, σ²T}
     backend = KA.get_backend(∂x)
     run_ka_kernel(
         ∇groupnorm_affine_normalize_kernel!, backend, nothing, size(∂x),
diff --git a/lib/LuxLib/src/impl/layernorm.jl b/lib/LuxLib/src/impl/layernorm.jl
index 4655972670..6c27d1ac31 100644
--- a/lib/LuxLib/src/impl/layernorm.jl
+++ b/lib/LuxLib/src/impl/layernorm.jl
@@ -1,7 +1,7 @@
 # TODO: For the `dims === nothing` case, we can optimize using a loop vectorization and
 #       kernel abstractions
 function layernorm(x::AbstractArray{xT, N}, γ::Optional{<:AbstractArray},
-        β::Optional{<:AbstractArray}, act::F, dims, epsilon::Real) where {N, F, xT}
+        β::Optional{<:AbstractArray}, act::F, dims, epsilon) where {N, F, xT}
     μ, σ² = mean_var(x; dims=compute_layernorm_dims(x, γ, β, dims), corrected=false)
     γ′, β′ = expand_layernorm_dims(x, γ, β, dims)
     return affine_normalize(act, x, μ, σ², γ′, β′, epsilon)
diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl
index c2c11f12ab..8fa64dda0b 100644
--- a/lib/LuxLib/src/impl/normalization.jl
+++ b/lib/LuxLib/src/impl/normalization.jl
@@ -1,14 +1,14 @@
 # In most cases this implementation should not be preferred. But this is nice to have
 # because it works for arbitrary dimensions
 function affine_normalize(act::F, x::AbstractArray, μ::Numeric, σ²::Numeric,
-        ::Nothing, ::Nothing, ϵ::Real) where {F}
+        ::Nothing, ::Nothing, ϵ) where {F}
     γ′ = @. inv(sqrt(σ² + ϵ))
     β′ = @. -μ * γ′
     return @. act(x * γ′ + β′)
 end
 
 function affine_normalize(act::F, x::AbstractArray, μ::Numeric, σ²::Numeric,
-        γ::AbstractArray, β::AbstractArray, ϵ::Real) where {F}
+        γ::AbstractArray, β::AbstractArray, ϵ) where {F}
     γ′ = @. γ / sqrt(σ² + ϵ)
     β′ = @. β - μ * γ′
     return @. act(x * γ′ + β′)
@@ -69,7 +69,7 @@ end
 function update_normalization_statistics(
         x::AbstractArray{T, N}, rμ::AbstractArray{rμT, N}, rσ²::AbstractArray{rσ²T, N},
         μ::AbstractArray{μT, N}, σ²::AbstractArray{σ²T, N},
-        momentum::Real, reduce_dims) where {T, N, rμT, rσ²T, μT, σ²T}
+        momentum, reduce_dims) where {T, N, rμT, rσ²T, μT, σ²T}
     if last(reduce_dims) != N
         μ = mean(μ; dims=N)
         σ² = mean(σ²; dims=N)
diff --git a/lib/LuxLib/test/Project.toml b/lib/LuxLib/test/Project.toml
index 403bc57fb5..df34c29520 100644
--- a/lib/LuxLib/test/Project.toml
+++ b/lib/LuxLib/test/Project.toml
@@ -49,7 +49,7 @@ LoopVectorization = "0.12.171"
 LuxTestUtils = "1.5"
 MKL = "0.7"
 MLDataDevices = "1.6"
-NNlib = "0.9.21"
+NNlib = "0.9.26"
 Octavian = "0.3.28"
 Pkg = "1.10"
 Random = "1.10"
diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl
index 58b6196c1a..8d30f4285d 100644
--- a/lib/LuxLib/test/normalization/batchnorm_tests.jl
+++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl
@@ -21,7 +21,7 @@ function batchnorm_fallback(
         bias::LuxLib.Optional{<:AbstractVector},
         running_mean::LuxLib.Optional{<:AbstractVector},
         running_var::LuxLib.Optional{<:AbstractVector}, training::Val,
-        σ::F=identity, momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F, N}
+        σ::F=identity, momentum=0.1f0, epsilon=1.0f-5) where {F, N}
     y, xm, xv = LuxLib.Impl.normalization(x, LuxLib.Utils.remove_tracking(running_mean),
         LuxLib.Utils.remove_tracking(running_var), scale, bias,
         LuxLib.Impl.batchnorm_reduce_dims(x), static(training), momentum, epsilon, σ)
diff --git a/lib/LuxLib/test/normalization/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl
index c103595f99..f54a0ebf5e 100644
--- a/lib/LuxLib/test/normalization/groupnorm_tests.jl
+++ b/lib/LuxLib/test/normalization/groupnorm_tests.jl
@@ -16,7 +16,7 @@ end
 function groupnorm_fallback(
         x::AbstractArray{<:Real, N}, scale::LuxLib.Optional{<:AbstractVector},
         bias::LuxLib.Optional{<:AbstractVector}, groups::Int,
-        σ::F=identity, epsilon::Real=1.0f-5) where {F, N}
+        σ::F=identity, epsilon=1.0f-5) where {F, N}
     sz = size(x)
     x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N])
     y, _, _ = LuxLib.Impl.normalization(x_reshaped, nothing, nothing, scale, bias,
diff --git a/src/Lux.jl b/src/Lux.jl
index 29014cfa6f..64f0af07f1 100644
--- a/src/Lux.jl
+++ b/src/Lux.jl
@@ -44,6 +44,7 @@ include("utils.jl")
 include("extended_ops.jl")
 
 # Training Helpers
+include("helpers/optimizers.jl")
 include("helpers/training.jl")
 
 # Experimental
diff --git a/src/helpers/losses.jl b/src/helpers/losses.jl
index 9a8f575b6b..00c4cae59d 100644
--- a/src/helpers/losses.jl
+++ b/src/helpers/losses.jl
@@ -92,7 +92,7 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode},
     return CRC.rrule_via_ad(cfg, fallback_fused_agg, sum, op, x, y)
 end
 
-get_ϵ(::Type{T}, ϵ::Real) where {T} = T(ϵ)
+get_ϵ(::Type{T}, ϵ) where {T} = T(ϵ)
 get_ϵ(::Type{T}, ::Nothing) where {T} = eps(float(T))
 
 get_loss_dims(::AbstractVector) = Colon()
@@ -160,13 +160,13 @@ function msle_loss(x::T1, y::T2, ϵ) where {T1, T2}
 end
 
 label_smoothing(::Nothing, y, ::Type{T}) where {T} = y
-function label_smoothing(label_smoothing::Real, y, ::Type{T}) where {T}
+function label_smoothing(label_smoothing, y, ::Type{T}) where {T}
     label_smoothing = T(label_smoothing)
     return y .* (1 - label_smoothing) .+ label_smoothing ./ size(y, ndims(y) - 1)
 end
 
 label_smoothing_binary(::Nothing, y, ::Type{T}) where {T} = y
-function label_smoothing_binary(label_smoothing::Real, y, ::Type{T}) where {T}
+function label_smoothing_binary(label_smoothing, y, ::Type{T}) where {T}
     label_smoothing = T(label_smoothing)
     return y .* (1 - label_smoothing) .+ label_smoothing ./ 2
 end
@@ -725,7 +725,7 @@ true
 invariant mapping." 2006 IEEE computer society conference on computer vision and pattern
 recognition (CVPR'06). Vol. 2. IEEE, 2006.
 """
-function SiameseContrastiveLoss(; margin::Real=true, agg=mean)
+function SiameseContrastiveLoss(; margin=true, agg=mean)
     @argcheck margin ≥ 0
     return GenericLossFunction(
         Utils.Fix3(LossFunctionImpl.siamese_contrastive_loss, margin); agg)
diff --git a/src/helpers/optimizers.jl b/src/helpers/optimizers.jl
new file mode 100644
index 0000000000..fe0116bb4e
--- /dev/null
+++ b/src/helpers/optimizers.jl
@@ -0,0 +1,184 @@
+# This is mostly an internal implementation detail that users shouldn't need to worry about.
+# We can remove this once https://github.com/FluxML/Optimisers.jl/issues/205 is resolved.
+module ReactantCompatibleOptimisers
+
+using ConcreteStructs: @concrete
+using Optimisers: Optimisers, AbstractRule
+using Setfield: Setfield, @set!
+
+using ..Lux: Lux, Utils
+
+abstract type ReactantCompatibleOptimisersRule <: AbstractRule end
+
+function make_reactant_compatible(opt::AbstractRule)
+    @warn "`make_reactant_compatible` is not defined for $(opt). Returning the original \
+           optimizer. This means adjustments to the learning rate and other \
+           hyperparameters won't be reflected in the generated MLIR." maxlog=1
+    return opt
+end
+make_reactant_compatible(opt::ReactantCompatibleOptimisersRule) = opt
+
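+# Overwrite a hyperparameter only when the adjustment `NamedTuple` provides it, converting
+# the new value to a Reactant number so `Optimisers.adjust` is visible in compiled code.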
+function setfield_if_present(opt, field::Symbol, nt::NamedTuple)
+    if hasfield(typeof(nt), field)
+        opt = Setfield.set(
+            opt, Setfield.PropertyLens{field}(),
+            convert(
+                typeof(getproperty(opt, field)),
+                Utils.to_rarray(getproperty(nt, field); track_numbers=true)
+            )
+        )
+    end
+    return opt
+end
+
+# OptimiserChain
+function make_reactant_compatible(opt::Optimisers.OptimiserChain)
+    return Optimisers.OptimiserChain(make_reactant_compatible.(opt.opts))
+end
+
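+# The rules below mirror their Optimisers.jl counterparts but store hyperparameters as
+# Reactant numbers (`Utils.to_rarray(...; track_numbers=true)`) so that adjusted values
+# show up in the generated MLIR.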
+# Descent
+@concrete struct ReactantDescent <: ReactantCompatibleOptimisersRule
+    eta
+end
+
+function make_reactant_compatible(opt::Optimisers.Descent)
+    return ReactantDescent(Utils.to_rarray(opt.eta; track_numbers=true))
+end
+
+Optimisers.init(::ReactantDescent, ::AbstractArray) = nothing
+
+function Optimisers.apply!(opt::ReactantDescent, state, x::AbstractArray{T}, dx) where {T}
+    η = T(opt.eta)
+    return state, @. dx * η
+end
+
+function Optimisers._adjust(opt::ReactantDescent, nt::NamedTuple)
+    return setfield_if_present(opt, :eta, nt)
+end
+
+# Momentum
+@concrete struct ReactantMomentum <: ReactantCompatibleOptimisersRule
+    eta
+    rho
+end
+
+function make_reactant_compatible(opt::Optimisers.Momentum)
+    return ReactantMomentum(
+        Utils.to_rarray(opt.eta; track_numbers=true),
+        Utils.to_rarray(opt.rho; track_numbers=true)
+    )
+end
+
+function Optimisers.init(::ReactantMomentum, x::AbstractArray)
+    return Optimisers.init(Optimisers.Momentum(0.0, 0.0), x)
+end
+
+function Optimisers.apply!(opt::ReactantMomentum, mvel, ::AbstractArray{T}, dx) where {T}
+    η, ρ = T(opt.eta), T(opt.rho)
+    @. mvel = ρ * mvel + η * dx
+    return mvel, mvel
+end
+
+function Optimisers._adjust(opt::ReactantMomentum, nt::NamedTuple)
+    opt = setfield_if_present(opt, :eta, nt)
+    opt = setfield_if_present(opt, :rho, nt)
+    return opt
+end
+
+# Adam
+@concrete struct ReactantAdam <: ReactantCompatibleOptimisersRule
+    eta
+    beta
+    epsilon
+end
+
+function make_reactant_compatible(opt::Optimisers.Adam)
+    return ReactantAdam(
+        Utils.to_rarray(opt.eta; track_numbers=true),
+        Utils.to_rarray(opt.beta; track_numbers=true),
+        Utils.to_rarray(opt.epsilon; track_numbers=true)
+    )
+end
+
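+# Store the running β powers as Reactant numbers (`Utils.promote_to`) so they can be
+# carried and updated inside the compiled training step.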
+function Optimisers.init(opt::ReactantAdam, x::AbstractArray{T}) where {T}
+    return (
+        zero(x),
+        zero(x),
+        (Utils.promote_to(T, opt.beta[1]), Utils.promote_to(T, opt.beta[2]))
+    )
+end
+
+function Optimisers.apply!(o::ReactantAdam, state, ::AbstractArray{T}, dx) where {T}
+    η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon) # XXX: See Optimisers._eps
+    mt, vt, βt = state
+
+    @. mt = β[1] * mt + (1 - β[1]) * dx
+    @. vt = β[2] * vt + (1 - β[2]) * abs2(dx)
+    dx′ = @. mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) * η
+
+    return (mt, vt, βt .* β), dx′
+end
+
+function Optimisers._adjust(opt::ReactantAdam, nt::NamedTuple)
+    opt = setfield_if_present(opt, :eta, nt)
+    opt = setfield_if_present(opt, :beta, nt)
+    opt = setfield_if_present(opt, :epsilon, nt)
+    return opt
+end
+
+# AdamW
+@concrete struct ReactantAdamW <: ReactantCompatibleOptimisersRule
+    eta
+    beta
+    lambda
+    epsilon
+    couple::Bool
+end
+
+function make_reactant_compatible(opt::Optimisers.AdamW)
+    return ReactantAdamW(
+        Utils.to_rarray(opt.eta; track_numbers=true),
+        Utils.to_rarray(opt.beta; track_numbers=true),
+        Utils.to_rarray(opt.lambda; track_numbers=true),
+        Utils.to_rarray(opt.epsilon; track_numbers=true),
+        opt.couple
+    )
+end
+
+function Optimisers.init(opt::ReactantAdamW, x::AbstractArray{T}) where {T}
+    return (
+        zero(x),
+        zero(x),
+        (Utils.promote_to(T, opt.beta[1]), Utils.promote_to(T, opt.beta[2]))
+    )
+end
+
+function Optimisers.apply!(o::ReactantAdamW, state, x::AbstractArray{T}, dx) where {T}
+    η, β, ϵ, λ = T(o.eta), T.(o.beta), T(o.epsilon), T(o.lambda) # XXX: See Optimisers._eps
+    mt, vt, βt = state
+
+    # standard Adam update with learning rate eta=1
+    @. mt = β[1] * mt + (1 - β[1]) * dx
+    @. vt = β[2] * vt + (1 - β[2]) * abs2(dx)
+    dx′ = @. mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ)
+
+    # apply learning rate and weight decay
+    if o.couple
+        dx′′ = @. η * (dx′ + λ * x)
+    else
+        dx′′ = @. η * dx′ + λ * x
+    end
+
+    return (mt, vt, βt .* β), dx′′
+end
+
+function Optimisers._adjust(opt::ReactantAdamW, nt::NamedTuple)
+    opt = setfield_if_present(opt, :eta, nt)
+    opt = setfield_if_present(opt, :beta, nt)
+    opt = setfield_if_present(opt, :lambda, nt)
+    opt = setfield_if_present(opt, :epsilon, nt)
+    opt = setfield_if_present(opt, :couple, nt)
+    return opt
+end
+
+end
diff --git a/src/helpers/size_propagator.jl b/src/helpers/size_propagator.jl
index fc0d12b78a..6e67453406 100644
--- a/src/helpers/size_propagator.jl
+++ b/src/helpers/size_propagator.jl
@@ -152,12 +152,12 @@ end
 function LuxLib.Impl.batchnorm(
         x::AnyNilArray{N}, ::Optional{<:AbstractVector}, ::Optional{<:AbstractVector},
         rμ::Optional{<:AbstractVector}, rσ²::Optional{<:AbstractVector},
-        ::StaticBool, act::F, ::Real, ::Real) where {N, F}
+        ::StaticBool, act::F, ::Number, ::Number) where {N, F}
     return x, rμ, rσ²
 end
 
 function LuxLib.Impl.groupnorm(x::AnyNilArray{N}, ::Optional{<:AbstractVector},
-        ::Optional{<:AbstractVector}, ::Int, act::F, ::Real) where {N, F}
+        ::Optional{<:AbstractVector}, ::Int, act::F, ::Number) where {N, F}
     return x
 end
 
@@ -168,11 +168,11 @@ function LuxLib.Impl.normalization(x::AnyNilArray, rμ::Optional{<:AbstractVecto
 end
 
 function LuxLib.Impl.affine_normalize(
-        ::F, x::AnyNilArray, ::Numeric, ::Numeric, ::Nothing, ::Nothing, ::Real) where {F}
+        ::F, x::AnyNilArray, ::Numeric, ::Numeric, ::Nothing, ::Nothing, ::Number) where {F}
     return x
 end
 function LuxLib.Impl.affine_normalize(::F, x::AnyNilArray, ::Numeric, ::Numeric,
-        ::AbstractArray, ::AbstractArray, ::Real) where {F}
+        ::AbstractArray, ::AbstractArray, ::Number) where {F}
     return x
 end
 
diff --git a/src/helpers/training.jl b/src/helpers/training.jl
index da2b597a94..c11f74b93f 100644
--- a/src/helpers/training.jl
+++ b/src/helpers/training.jl
@@ -4,14 +4,14 @@ using ADTypes: AbstractADType, AutoEnzyme, AutoReverseDiff, AutoTracker, AutoZyg
 using Compat: @compat
 using ConcreteStructs: @concrete
 using FastClosures: @closure
-using Functors: fmap
+using Functors: Functors, fmap
 using Optimisers: Optimisers
 using Setfield: @set!
 using Static: StaticBool, Static, False, True
 
-using ..Lux: Lux, Utils
+using ..Lux: Lux, Utils, ReactantCompatibleOptimisers
 using LuxCore: LuxCore, AbstractLuxLayer
-using MLDataDevices: MLDataDevices, ReactantDevice, get_device_type, get_device, cpu_device
+using MLDataDevices: MLDataDevices, ReactantDevice, get_device_type
 
 """
     TrainState
@@ -63,10 +63,10 @@ Constructor for [`TrainState`](@ref).
 [`TrainState`](@ref) object.
 """
 function TrainState(model::AbstractLuxLayer, ps, st, optimizer::Optimisers.AbstractRule)
-    dev = get_device(ps)
-    st_opt = if dev isa ReactantDevice
-        ps_cpu = ps |> cpu_device()
-        Optimisers.setup(optimizer, ps_cpu) |> dev
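+    # On Reactant devices, set up a Reactant-compatible optimizer directly on the device
+    # parameters instead of round-tripping the setup through the CPU.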
+    st_opt = if get_device_type(ps) <: ReactantDevice
+        Optimisers.setup(
+            ReactantCompatibleOptimisers.make_reactant_compatible(optimizer), ps
+        )
     else
         Optimisers.setup(optimizer, ps)
     end
diff --git a/src/layers/normalize.jl b/src/layers/normalize.jl
index ac1e040082..91964b0dea 100644
--- a/src/layers/normalize.jl
+++ b/src/layers/normalize.jl
@@ -134,7 +134,9 @@ function (BN::BatchNorm)(x::AbstractArray, ps, st::NamedTuple)
 end
 
 function update_batchnorm_state(BN::BatchNorm, st::NamedTuple, stats)
-    has_track_stats(BN) && return merge(st, (; stats.running_mean, stats.running_var))
+    has_track_stats(BN) && return merge(st,
+        (; running_mean=Utils.vec(stats.running_mean),
+            running_var=Utils.vec(stats.running_var)))
     return st
 end
 
@@ -378,14 +380,23 @@ statelength(l::InstanceNorm) = ifelse(has_track_stats(l), l.chs * 2, 0) + 1
 function (IN::InstanceNorm)(x::AbstractArray, ps, st::NamedTuple)
     x′ = match_eltype(IN, ps, st, x)
     σ = NNlib.fast_act(IN.activation, x′)
-    y, _ = instancenorm(
+    y, stats = instancenorm(
         x′, safe_getproperty(ps, Val(:scale)), safe_getproperty(ps, Val(:bias)),
         safe_getproperty(st, Val(:running_mean)), safe_getproperty(st, Val(:running_var)),
         st.training, σ, convert(unwrapped_eltype(x′), IN.momentum),
         convert(unwrapped_eltype(x′), IN.epsilon))
-    return y, st
+    return y, update_instancenorm_state(IN, st, stats)
 end
 
+function update_instancenorm_state(IN::InstanceNorm, st::NamedTuple, stats)
+    has_track_stats(IN) && return merge(st,
+        (; running_mean=Utils.vec(stats.running_mean),
+            running_var=Utils.vec(stats.running_var)))
+    return st
+end
+
+CRC.@non_differentiable update_instancenorm_state(::Any...)
+
 function Base.show(io::IO, l::InstanceNorm)
     print(io, "InstanceNorm($(l.chs)")
     (l.activation == identity) || print(io, ", $(l.activation)")
diff --git a/src/layers/pooling.jl b/src/layers/pooling.jl
index 943eb947c1..3be3a5e24e 100644
--- a/src/layers/pooling.jl
+++ b/src/layers/pooling.jl
@@ -40,15 +40,23 @@ symbol_to_pool_mode(::StaticSymbol{:adaptive}) = AdaptivePoolMode
 abstract type AbstractPoolOp end
 
 struct MaxPoolOp <: AbstractPoolOp end
+
 (m::MaxPoolOp)(x, pdims) = maxpool(x, pdims)
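+# With Reactant, `calculate_pool_dims` returns the `GlobalPoolMode` itself, so the pool
+# ops dispatch on it and reduce over all spatial dimensions directly instead of going
+# through `PoolDims`.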
+function (m::MaxPoolOp)(x, ::GlobalPoolMode)
+    return maximum(x; dims=1:(ndims(x) - 2), init=eltype(x)(-Inf))
+end
 
 struct MeanPoolOp <: AbstractPoolOp end
+
 (m::MeanPoolOp)(x, pdims) = meanpool(x, pdims)
+(m::MeanPoolOp)(x, ::GlobalPoolMode) = mean(x; dims=1:(ndims(x) - 2))
 
 @concrete struct LpPoolOp <: AbstractPoolOp
     p
 end
+
 (m::LpPoolOp)(x, pdims) = lpnormpool(x, pdims; m.p)
+(m::LpPoolOp)(x, ::GlobalPoolMode) = lpnormpool(x, PoolDims(x, size(x)[1:(end - 2)]); m.p)
 
 symbol_to_pool_op(::StaticSymbol{:max}, _) = MaxPoolOp()
 symbol_to_pool_op(::StaticSymbol{:mean}, _) = MeanPoolOp()
diff --git a/src/utils.jl b/src/utils.jl
index 2a6930a2a8..99429f5c14 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -211,6 +211,9 @@ matrix_to_array(x::SMatrix{L, 1, T}, ::AbstractVector) where {L, T} = SVector{L,
 matrix_to_array(x::AbstractMatrix, ::AbstractMatrix) = x
 matrix_to_array(x::AbstractMatrix, y::AbstractArray) = reshape(x, :, size(y)[2:end]...)
 
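+# Function stubs; the implementations live in the Reactant extension (`ext/LuxReactantExt`).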
+function to_rarray end
+function promote_to end
+
 # This should probably be in WeightInitializers.jl
 calculate_gain(_, __) = 1.0f0
 calculate_gain(::typeof(identity), _) = 1.0f0
@@ -222,7 +225,7 @@ calculate_gain(::typeof(NNlib.tanh_fast), _) = 5.0f0 / 3.0f0
 function calculate_gain(::typeof(NNlib.leakyrelu), ::Nothing)
     return calculate_gain(NNlib.leakyrelu, 0.1f0)
 end
-calculate_gain(::typeof(NNlib.leakyrelu), x::Real) = typeof(x)(√(2 / (1 + x^2)))
+calculate_gain(::typeof(NNlib.leakyrelu), x) = typeof(x)(√(2 / (1 + x^2)))
 calculate_gain(::typeof(NNlib.selu), _) = 3.0f0 / 4
 
 end
diff --git a/test/Project.toml b/test/Project.toml
index 58dd94c2ee..7f9cb93e5c 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -62,7 +62,7 @@ LuxLib = "1.3.4"
 LuxTestUtils = "1.5"
 MLDataDevices = "1.6"
 MLUtils = "0.4.3"
-NNlib = "0.9.24"
+NNlib = "0.9.26"
 Octavian = "0.3.28"
 OneHotArrays = "0.2.5"
 Optimisers = "0.4.1"
diff --git a/test/helpers/loss_tests.jl b/test/helpers/loss_tests.jl
index ba79343140..3c0113b5c0 100644
--- a/test/helpers/loss_tests.jl
+++ b/test/helpers/loss_tests.jl
@@ -46,12 +46,12 @@
 
     @testset "$mode" for (mode, aType, dev, ongpu) in MODES
         x = rand(10) |> aType
-        __f = sum ∘ Broadcast.BroadcastFunction(LuxOps.xlogx)
-        @test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, soft_fail=[AutoFiniteDiff()])
+        @test_gradients(sum∘Broadcast.BroadcastFunction(LuxOps.xlogx),
+            x; atol=1.0f-3, rtol=1.0f-3, soft_fail=[AutoFiniteDiff()])
 
         y = rand(10) |> aType
-        __f = sum ∘ Broadcast.BroadcastFunction(LuxOps.xlogy)
-        @test_gradients(__f, x, y; atol=1.0f-3, rtol=1.0f-3, soft_fail=[AutoFiniteDiff()])
+        @test_gradients(sum∘Broadcast.BroadcastFunction(LuxOps.xlogy),
+            x, y; atol=1.0f-3, rtol=1.0f-3, soft_fail=[AutoFiniteDiff()])
     end
 end
 
@@ -79,8 +79,7 @@ end
             @jet loss_mean(ŷ, y)
             @jet loss_sum(ŷ, y)
 
-            __f = Base.Fix2(loss_mean, y)
-            @test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3)
+            @test_gradients(Base.Fix2(loss_mean, y), ŷ; atol=1.0f-3, rtol=1.0f-3)
         end
 
         @testset "MSLE" begin
@@ -93,8 +92,7 @@ end
 
             @test @inferred(Zygote.gradient(MSLELoss(), ŷ, y)) isa Any broken=ongpu
 
-            __f = Base.Fix2(MSLELoss(), y)
-            @test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3)
+            @test_gradients(Base.Fix2(MSLELoss(), y), ŷ; atol=1.0f-3, rtol=1.0f-3)
         end
     end
 end
@@ -203,9 +201,8 @@ end
 
             @test @inferred(Zygote.gradient(bceloss, σ.(logŷ), y)) isa Any
 
-            __f = Base.Fix2(bceloss, y)
-            σlogŷ = σ.(logŷ)
-            @test_gradients(__f, σlogŷ; atol=1.0f-3, rtol=1.0f-3)
+            @test_gradients(Base.Fix2(bceloss, y), σ.(logŷ); atol=1.0f-3, rtol=1.0f-3,
+                enzyme_set_runtime_activity=true)
         end
 
         @testset "Logit BinaryCrossEntropyLoss" begin
@@ -225,8 +222,8 @@ end
 
             @test @inferred(Zygote.gradient(logitbceloss, logŷ, y)) isa Any
 
-            __f = Base.Fix2(logitbceloss, y)
-            @test_gradients(__f, logŷ; atol=1.0f-3, rtol=1.0f-3)
+            @test_gradients(Base.Fix2(logitbceloss, y), logŷ; atol=1.0f-3, rtol=1.0f-3,
+                enzyme_set_runtime_activity=true)
         end
 
         @testset "BinaryFocalLoss" begin
@@ -248,8 +245,7 @@ end
 
             @test @inferred(Zygote.gradient(BinaryFocalLoss(), ŷ, y)) isa Any broken=ongpu
 
-            __f = Base.Fix2(BinaryFocalLoss(), y)
-            @test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3)
+            @test_gradients(Base.Fix2(BinaryFocalLoss(), y), ŷ; atol=1.0f-3, rtol=1.0f-3)
         end
 
         @testset "FocalLoss" begin
diff --git a/test/layers/normalize_tests.jl b/test/layers/normalize_tests.jl
index 6b545e15b6..ec3704d90e 100644
--- a/test/layers/normalize_tests.jl
+++ b/test/layers/normalize_tests.jl
@@ -56,7 +56,7 @@
 
             @jet m(x, ps, Lux.testmode(st))
             @test_gradients(sumabs2first, m, x, ps, st; atol=1.0f-3,
-                rtol=1.0f-3, skip_backends=[AutoFiniteDiff()], broken_backends)
+                rtol=1.0f-3, skip_backends=[AutoFiniteDiff()])
 
             # with activation function
             m = BatchNorm(2, sigmoid; affine)
diff --git a/test/reactant/training_tests.jl b/test/reactant/training_tests.jl
index 06c550a3ba..6384cd49a6 100644
--- a/test/reactant/training_tests.jl
+++ b/test/reactant/training_tests.jl
@@ -18,7 +18,9 @@
         @testset "MLP Training: $(version)" for version in (:iip, :oop)
             model = Chain(
                 Dense(2 => 32, gelu),
+                BatchNorm(32),
                 Dense(32 => 32, gelu),
+                BatchNorm(32),
                 Dense(32 => 2)
             )
             ps, st = Lux.setup(StableRNG(1234), model) |> xdev
@@ -43,27 +45,31 @@
                 inference_loss_fn_compiled(xᵢ, yᵢ, model, ps, st)
             end
 
-            train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
+            @testset for opt in (
+                Descent(0.01f0), Momentum(0.01f0), Adam(0.01f0), AdamW(0.01f0)
+            )
+                train_state = Training.TrainState(model, ps, st, opt)
 
-            for epoch in 1:100, (xᵢ, yᵢ) in dataloader
-                grads, loss, stats, train_state = if version === :iip
-                    Training.single_train_step!(
-                        AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), train_state)
-                elseif version === :oop
-                    Training.single_train_step(
-                        AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), train_state)
-                else
-                    error("Invalid version: $(version)")
+                for epoch in 1:100, (xᵢ, yᵢ) in dataloader
+                    grads, loss, stats, train_state = if version === :iip
+                        Training.single_train_step!(
+                            AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), train_state)
+                    elseif version === :oop
+                        Training.single_train_step(
+                            AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), train_state)
+                    else
+                        error("Invalid version: $(version)")
+                    end
                 end
-            end
 
-            total_final_loss = mapreduce(+, dataloader) do (xᵢ, yᵢ)
-                inference_loss_fn_compiled(
-                    xᵢ, yᵢ, model, train_state.parameters, train_state.states
-                )
-            end
+                total_final_loss = mapreduce(+, dataloader) do (xᵢ, yᵢ)
+                    inference_loss_fn_compiled(
+                        xᵢ, yᵢ, model, train_state.parameters, train_state.states
+                    )
+                end
 
-            @test total_final_loss < 100 * total_initial_loss
+                @test total_final_loss < 100 * total_initial_loss
+            end
         end
     end
 end
diff --git a/test/runtests.jl b/test/runtests.jl
index 0f96e8b49f..91db71bcb2 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -127,9 +127,7 @@ const RETESTITEMS_NWORKER_THREADS = parse(
         string(max(Hwloc.num_virtual_cores() ÷ RETESTITEMS_NWORKERS, 1))))
 
 @testset "Lux.jl Tests" begin
-    for (i, tag) in enumerate(LUX_TEST_GROUP)
-        @info "Running tests for group: [$(i)/$(length(LUX_TEST_GROUP))] $tag"
-
+    @testset "[$(tag)] [$(i)/$(length(LUX_TEST_GROUP))]" for (i, tag) in enumerate(LUX_TEST_GROUP)
         nworkers = (tag == "reactant") || (BACKEND_GROUP == "amdgpu") ? 0 :
                    RETESTITEMS_NWORKERS
 
diff --git a/test/setup_modes.jl b/test/setup_modes.jl
index 1617179a5b..b7c581ccca 100644
--- a/test/setup_modes.jl
+++ b/test/setup_modes.jl
@@ -1,4 +1,4 @@
-using Lux, MLDataDevices
+using Lux, MLDataDevices, Pkg
 
 if !@isdefined(BACKEND_GROUP)
     const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "all"))