
Commit c850df5

Allow ForwardDiff in BatchNorm's track_stats (#2127)
* allow ForwardDiff in BatchNorm's track_stats
* second test
* add comments
* Update test/layers/normalisation.jl
1 parent 815deaa commit c850df5

3 files changed: +19 −3 lines changed

src/Flux.jl

+1
@@ -11,6 +11,7 @@ import Optimisers: Optimisers, trainable, destructure # before v0.13, Flux owne
 
 using Zygote, ChainRulesCore
 using Zygote: Params, @adjoint, gradient, pullback, @nograd
+using Zygote.ForwardDiff: value
 export gradient
 
 # Pirate error to catch a common mistake. (Internal function `base` because overloading `update!` is more likely to give ambiguities.)
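A note on the import path (an observation, not part of the diff): Zygote itself depends on ForwardDiff, so `Zygote.ForwardDiff` reaches the module without Flux declaring a direct dependency of its own. A quick check:

using Zygote
# Zygote imports ForwardDiff internally, so its Dual machinery is reachable:
Zygote.ForwardDiff.value(Zygote.ForwardDiff.Dual(1.0, 2.0))  # returns 1.0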

src/layers/normalise.jl

+3 −2
@@ -275,8 +275,9 @@ function _track_stats!(
   μnew = vec(N ∈ reduce_dims ? μ : mean(μ, dims=N))
   σ²new = vec(N ∈ reduce_dims ? σ² : mean(σ², dims=N))
 
-  bn.μ = res_mtm .* bn.μ .+ mtm .* μnew
-  bn.σ² = res_mtm .* bn.σ² .+ mtm .* (m / (m - one(V))) .* σ²new
+  # ForwardDiff.value removes Dual, was an error, issue #2122
+  bn.μ .= value.(res_mtm .* bn.μ .+ mtm .* μnew)
+  bn.σ² .= value.(res_mtm .* bn.σ² .+ mtm .* (m / (m - one(V))) .* σ²new)
   return nothing
 end
 
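For context (an illustrative sketch, not part of the commit): `res_mtm` is the residual momentum `1 - mtm`, so these lines are an exponential moving average of the batch statistics, with Bessel's correction `m / (m - 1)` on the variance. Under `ForwardDiff.jacobian` those statistics are `Dual` numbers, and storing them into the layer's plain `Float32` buffers fails unless the primal value is extracted first:

using ForwardDiff: Dual, value

d = Dual(1.0, 2.0)   # primal 1.0 carrying one partial derivative, 2.0
value(d)             # 1.0 -- the Dual stripped back to its primal
value(3.0)           # 3.0 -- a no-op on plain numbers

# Mimicking _track_stats!: an in-place update of a Float32 buffer.
μ = zeros(Float32, 3)                # stored running mean
μnew = Dual.(rand(Float32, 3), 1f0)  # batch mean computed from Dual inputs
# μ .= 0.9f0 .* μ .+ 0.1f0 .* μnew   # errors: a Dual cannot be converted to Float32
μ .= value.(0.9f0 .* μ .+ 0.1f0 .* μnew)  # works: only the primal is stored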
test/layers/normalisation.jl

+15 −1
@@ -1,5 +1,5 @@
 using Flux, Test, Statistics
-using Zygote: pullback
+using Zygote: pullback, ForwardDiff
 
 evalwgrad(f, x...) = pullback(f, x...)[1]
 
@@ -462,4 +462,18 @@ end
 @testset "second derivatives" begin
   m1 = Dropout(0.5)
   @test Zygote.hessian_reverse(sum∘m1, [1.0,2.0,3.0]) == zeros(3, 3)
+
+  m2 = Chain(BatchNorm(3), sum)
+  @test Zygote.hessian_reverse(m2, Float32[1 2; 3 4; 5 6]) == zeros(Float32, 6, 6)
+end
+
+@testset "ForwardDiff" begin
+  bn = BatchNorm(3)
+  @test ForwardDiff.jacobian(bn, rand(Float32, 3, 4)) isa Matrix{Float32}
+  # iszero(bn.μ)  # is true. But ideally would not be, if Flux would automatically choose trainmode
+  Flux.trainmode!(bn)
+  # This was an error, https://github.com/FluxML/Flux.jl/issues/2122
+  @test ForwardDiff.jacobian(bn, rand(Float32, 3, 4)) isa Matrix{Float32}
+  @test !iszero(bn.μ)
 end
+

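For illustration (a usage sketch, not part of the commit): after this change, the failing scenario from issue #2122 runs end to end.

using Flux, ForwardDiff

bn = BatchNorm(3)
Flux.trainmode!(bn)              # make the call update bn.μ and bn.σ²
x = rand(Float32, 3, 4)
J = ForwardDiff.jacobian(bn, x)  # errored before this commit
size(J)                          # (12, 12): vec(bn(x)) w.r.t. vec(x)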