From 89e67ca41eacce882f89f93517346bb71b540a46 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Thu, 5 Jan 2023 20:24:31 -0500
Subject: [PATCH] use NNlib.within_gradient

---
 Project.toml                 |  2 +-
 src/cuda/cudnn.jl            |  2 +-
 src/deprecations.jl          | 11 +++++++++++
 src/layers/normalise.jl      | 11 ++++-------
 test/layers/normalisation.jl | 15 +++++++++++++++
 5 files changed, 32 insertions(+), 9 deletions(-)

diff --git a/Project.toml b/Project.toml
index f0422bc486..b4c3dc6d75 100644
--- a/Project.toml
+++ b/Project.toml
@@ -30,7 +30,7 @@ ChainRulesCore = "1.12"
 Functors = "0.3, 0.4"
 MLUtils = "0.2, 0.3.1, 0.4"
 MacroTools = "0.5"
-NNlib = "0.8.9"
+NNlib = "0.8.14"
 NNlibCUDA = "0.2.4"
 OneHotArrays = "0.1, 0.2"
 Optimisers = "0.2.12"
diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl
index 9e6bdb53a0..24226ab4b1 100644
--- a/src/cuda/cudnn.jl
+++ b/src/cuda/cudnn.jl
@@ -8,7 +8,7 @@ function (BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}},
   @assert length(BN.β) == size(x, ndims(x)-1) "BatchNorm: input has wrong number of channels"
   return BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum;
                cache=cache, alpha=1, beta=0, eps=BN.ϵ,
-               training=Flux._isactive(BN)))
+               training=Flux._isactive(BN, x)))
 end
 
 function ChainRulesCore.rrule(::typeof(batchnorm), g, b, x, running_mean, running_var, momentum; kw...)
diff --git a/src/deprecations.jl b/src/deprecations.jl
index a763ffd905..8625c458e0 100644
--- a/src/deprecations.jl
+++ b/src/deprecations.jl
@@ -86,6 +86,17 @@ Base.@deprecate_binding Data Flux false "Sub-module Flux.Data has been removed.
 
 @deprecate rng_from_array() default_rng_value()
 
+function istraining()
+  Base.depwarn("Flux.istraining() is deprecated, use NNlib.within_gradient(x) instead", :istraining)
+  false
+end
+ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),)
+
+function _isactive(m)
+  Base.depwarn("_isactive(m) is deprecated, use _isactive(m,x)", :_isactive, force=true)
+  _isactive(m, 1:0)
+end
+
 #=
   # Valid method in Optimise, old implicit style, is:
   train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index 9f9176e3e9..7ee28f64e9 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -1,8 +1,5 @@
-istraining() = false
-ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),)
-
-_isactive(m) = isnothing(m.active) ? istraining() : m.active
+_isactive(m, x) = isnothing(m.active) ? NNlib.within_gradient(x) : m.active
 
 _dropout_shape(s, ::Colon) = size(s)
 _dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)
 
@@ -107,7 +104,7 @@ end
 trainable(a::Dropout) = (;)
 
 function (a::Dropout)(x)
-  _isactive(a) || return x
+  _isactive(a, x) || return x
   return dropout(a.rng, x, a.p; dims=a.dims, active=true)
 end
 
@@ -162,7 +159,7 @@ AlphaDropout(p; rng = default_rng_value()) = AlphaDropout(p, nothing, rng)
 trainable(a::AlphaDropout) = (;)
 
 function (a::AlphaDropout)(x::AbstractArray{T}) where T
-  _isactive(a) || return x
+  _isactive(a, x) || return x
   p = a.p
   iszero(p) && return x
   isone(p) && return sign.(x) .* T(0)
@@ -242,7 +239,7 @@ end
 function _norm_layer_forward(
   l, x::AbstractArray{T, N}; reduce_dims, affine_shape,
 ) where {T, N}
-  if !_isactive(l) && l.track_stats # testmode with tracked stats
+  if !_isactive(l, x) && l.track_stats # testmode with tracked stats
     stats_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
     μ = reshape(l.μ, stats_shape)
     σ² = reshape(l.σ², stats_shape)
diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl
index 859d703368..2aa26bb2a7 100644
--- a/test/layers/normalisation.jl
+++ b/test/layers/normalisation.jl
@@ -475,5 +475,20 @@ end
 
   # This was an error, https://github.com/FluxML/Flux.jl/issues/2122
   @test ForwardDiff.jacobian(bn, rand(Float32, 3, 4)) isa Matrix{Float32}
   @test !iszero(bn.μ)
+
+  # Easy case of 2122, gradient with x
+  x5 = rand(Float32, 5, 3)
+  bn1 = BatchNorm(5, relu)
+  bn2 = BatchNorm(5, relu)
+  g1 = Zygote.gradient(x -> sum(abs2, bn1(x)), x5)[1]
+  g2 = ForwardDiff.gradient(x -> sum(abs2, bn2(x)), x5)
+  @test g1 ≈ g2
+
+  # Harder case?
+  v1, re1 = Flux.destructure(BatchNorm(5, relu));
+  g1 = Zygote.gradient(v -> sum(abs2, re1(v)(x5)), v1)[1]
+
+  v2, re2 = Flux.destructure(BatchNorm(5, relu));
+  g2 = ForwardDiff.gradient(v -> sum(abs2, re2(v)(x5)), v2)
 end
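
Note (not part of the patch above): a minimal sketch of the behaviour the new `_isactive(m, x)` check relies on, assuming NNlib ≥ 0.8.14 and Zygote are loaded. `NNlib.within_gradient(x)` is `false` in an ordinary forward pass, and its rrule makes it `true` while a gradient is being taken, so a layer whose `active` field is `nothing` switches to training behaviour automatically.

```julia
# Rough sketch, not the patch itself; names other than within_gradient are illustrative.
using NNlib, Zygote

NNlib.within_gradient(rand(3))   # false: ordinary forward evaluation

# Inside Zygote, the rrule for within_gradient returns true, so the
# training branch below is the one that gets differentiated:
g = Zygote.gradient(rand(3)) do x
  NNlib.within_gradient(x) ? sum(abs2, x) : sum(x)
end
# g[1] equals 2 .* input, showing the `true` branch was taken
```

The ForwardDiff comparisons added to the tests lean on the same check detecting Dual-number inputs as "inside a gradient" (which, as far as I recall, NNlib also provides), so Zygote and ForwardDiff both see the batch-statistics branch of BatchNorm and their gradients can agree.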