@@ -27,7 +27,7 @@ CRC.@non_differentiable get_batchnorm_statistics(::Any...)
27
27
function batchnorm (x:: AbstractArray{xT, N} , γ:: Optional{<:AbstractVector} ,
28
28
β:: Optional{<:AbstractVector} , rμ:: Optional{<:AbstractVector} ,
29
29
rσ²:: Optional{<:AbstractVector} , training:: StaticBool , act:: F ,
30
- momentum:: Real , ϵ:: Real ) where {F, xT, N}
30
+ momentum, ϵ) where {F, xT, N}
31
31
(μ, σ²), (rμ, rσ²) = compute_batch_statistics (
32
32
x, reshape_norm_dims (x, rμ), reshape_norm_dims (x, rσ²),
33
33
batchnorm_reduce_dims (x), training, momentum)
37
37
function batchnorm_affine_normalize (
38
38
act:: F , x:: AbstractArray{xT, N} , μ:: AbstractArray{μT, N} ,
39
39
σ²:: AbstractArray{σ²T, N} , γ:: Optional{<:AbstractVector} ,
40
- β:: Optional{<:AbstractVector} , ϵ:: Real ) where {F, xT, μT, σ²T, N}
40
+ β:: Optional{<:AbstractVector} , ϵ) where {F, xT, μT, σ²T, N}
41
41
return batchnorm_affine_normalize (
42
42
internal_operation_mode ((x, μ, σ², γ, β)), act, x, μ, σ², γ, β, ϵ)
43
43
end
44
44
45
45
function batchnorm_affine_normalize (
46
46
:: GenericBroadcastOp , act:: F , x:: AbstractArray{xT, N} , μ:: AbstractArray{μT, N} ,
47
47
σ²:: AbstractArray{σ²T, N} , γ:: Optional{<:AbstractVector} ,
48
- β:: Optional{<:AbstractVector} , ϵ:: Real ) where {F, xT, μT, σ²T, N}
48
+ β:: Optional{<:AbstractVector} , ϵ) where {F, xT, μT, σ²T, N}
49
49
return affine_normalize (
50
50
act, x, μ, σ², reshape_norm_dims (x, γ), reshape_norm_dims (x, β), ϵ)
51
51
end
@@ -54,7 +54,7 @@ function batchnorm_affine_normalize(
54
54
opmode:: AbstractInternalArrayOpMode , act:: F , x:: AbstractArray{xT, N} ,
55
55
μ:: AbstractArray{μT, N} , σ²:: AbstractArray{σ²T, N} ,
56
56
γ:: Optional{<:AbstractVector} , β:: Optional{<:AbstractVector} ,
57
- ϵ:: Real ) where {F, xT, μT, σ²T, N}
57
+ ϵ) where {F, xT, μT, σ²T, N}
58
58
x′ = reshape (x, :, size (x, N - 1 ), size (x, N))
59
59
return reshape (
60
60
batchnorm_affine_normalize_internal (opmode, act, x′, vec (μ), vec (σ²), γ, β, ϵ),
64
64
@stable default_mode= " disable" function batchnorm_affine_normalize_internal (
65
65
opmode:: AbstractInternalArrayOpMode , act:: F , x:: AbstractArray{xT, 3} ,
66
66
μ:: AbstractVector , σ²:: AbstractVector , γ:: Optional{<:AbstractVector} ,
67
- β:: Optional{<:AbstractVector} , ϵ:: Real ) where {F, xT}
67
+ β:: Optional{<:AbstractVector} , ϵ) where {F, xT}
68
68
y = similar (x,
69
69
promote_type (safe_eltype (x), safe_eltype (μ), safe_eltype (σ²),
70
70
safe_eltype (γ), safe_eltype (β)))
75
75
function batchnorm_affine_normalize_internal! (
76
76
y:: AbstractArray{yT, 3} , opmode:: LoopedArrayOp , act:: F , x:: AbstractArray{xT, 3} ,
77
77
μ:: AbstractVector , σ²:: AbstractVector , γ:: Optional{<:AbstractVector} ,
78
- β:: Optional{<:AbstractVector} , ϵ:: Real ,
78
+ β:: Optional{<:AbstractVector} , ϵ,
79
79
γ′:: Optional{<:AbstractVector} = nothing ) where {F, xT, yT}
80
80
N = size (y, 2 )
81
81
γ′ = γ′ === nothing ?
225
225
function batchnorm_affine_normalize_internal! (
226
226
y:: AbstractArray{yT, 3} , :: GPUBroadcastOp , act:: F , x:: AbstractArray{xT, 3} ,
227
227
μ:: AbstractVector , σ²:: AbstractVector , γ:: Optional{<:AbstractVector} ,
228
- β:: Optional{<:AbstractVector} , ϵ:: Real ,
228
+ β:: Optional{<:AbstractVector} , ϵ,
229
229
γ′:: Optional{<:AbstractVector} = nothing ) where {F, xT, yT}
230
230
backend = KA. get_backend (y)
231
231
run_ka_kernel (
@@ -278,7 +278,7 @@ function CRC.rrule(
278
278
cfg:: RuleConfig{>:HasReverseMode} , :: typeof (batchnorm_affine_normalize_internal),
279
279
opmode:: AbstractInternalArrayOpMode , act:: F , x:: AbstractArray{T, N} ,
280
280
μ:: AbstractVector , σ²:: AbstractVector , γ:: Optional{<:AbstractVector} ,
281
- β:: Optional{<:AbstractVector} , ϵ:: Real ) where {F, T, N}
281
+ β:: Optional{<:AbstractVector} , ϵ) where {F, T, N}
282
282
y = similar (x,
283
283
promote_type (safe_eltype (x), safe_eltype (μ), safe_eltype (σ²),
284
284
safe_eltype (γ), safe_eltype (β)))
304
304
305
305
function ∇batchnorm_affine_normalize (opmode:: LoopedArrayOp , ∂y:: AbstractArray{∂yT, 3} ,
306
306
x:: AbstractArray{xT, 3} , μ:: AbstractVector , σ²:: AbstractVector ,
307
- γ:: Optional{<:AbstractVector} , β:: Optional{<:AbstractVector} , ϵ:: Real ,
307
+ γ:: Optional{<:AbstractVector} , β:: Optional{<:AbstractVector} , ϵ,
308
308
γ′:: AbstractVector ) where {∂yT, xT}
309
309
∂x, ∂μ, ∂σ² = similar (x), similar (μ), similar (σ²)
310
310
∂γ = γ === nothing ? nothing : similar (γ)
@@ -322,7 +322,7 @@ function ∇batchnorm_affine_normalize_cpu!(
322
322
∂x:: AbstractArray{∂xT, 3} , ∂μ:: AbstractVector{∂μT} ,
323
323
∂σ²:: AbstractVector{∂σ²T} , :: Nothing , :: Nothing , ∂y:: AbstractArray{∂yT, 3} ,
324
324
x:: AbstractArray{xT, 3} , μ:: AbstractVector , σ²:: AbstractVector , :: Nothing ,
325
- ϵ:: Real , γ′:: AbstractVector ) where {∂xT, ∂μT, ∂σ²T, ∂yT, xT}
325
+ ϵ, γ′:: AbstractVector ) where {∂xT, ∂μT, ∂σ²T, ∂yT, xT}
326
326
half = eltype (∂σ²)(0.5 )
327
327
328
328
fill! (∂μ, 0 )
@@ -361,7 +361,7 @@ function ∇batchnorm_affine_normalize_cpu!(
361
361
∂x:: AbstractArray{∂xT, 3} , ∂μ:: AbstractVector{∂μT} ,
362
362
∂σ²:: AbstractVector{∂σ²T} , ∂γ:: AbstractVector{∂γT} ,
363
363
∂β:: AbstractVector{∂βT} , ∂y:: AbstractArray{∂yT, 3} , x:: AbstractArray{xT, 3} ,
364
- μ:: AbstractVector , σ²:: AbstractVector , γ:: AbstractVector , ϵ:: Real ,
364
+ μ:: AbstractVector , σ²:: AbstractVector , γ:: AbstractVector , ϵ,
365
365
γ′:: AbstractVector ) where {∂xT, ∂μT, ∂σ²T, ∂γT, ∂βT, ∂yT, xT}
366
366
half = eltype (∂σ²)(0.5 )
367
367
406
406
function ∇batchnorm_affine_normalize (
407
407
opmode:: AbstractInternalArrayOpMode , ∂y:: AbstractArray{∂yT, 3} ,
408
408
x:: AbstractArray{xT, 3} , μ:: AbstractVector , σ²:: AbstractVector ,
409
- γ:: Optional{<:AbstractVector} , β:: Optional{<:AbstractVector} , ϵ:: Real ,
409
+ γ:: Optional{<:AbstractVector} , β:: Optional{<:AbstractVector} , ϵ,
410
410
γ′:: AbstractVector ) where {∂yT, xT}
411
411
∂x, ∂σ² = similar (x), similar (σ², size (x))
412
412
∂γ = γ === nothing ? nothing : similar (γ, size (x))
@@ -425,7 +425,7 @@ function ∇batchnorm_affine_normalize!(
425
425
∂x:: AbstractArray{∂xT, 3} , ∂σ²:: AbstractArray{∂σ²T, 3} ,
426
426
∂γ:: Optional{<:AbstractArray{<:Any, 3}} , :: GPUBroadcastOp ,
427
427
∂y:: AbstractArray{∂yT, 3} , x:: AbstractArray{xT, 3} , μ:: AbstractVector ,
428
- σ²:: AbstractVector , γ:: Optional{<:AbstractVector} , ϵ:: Real ,
428
+ σ²:: AbstractVector , γ:: Optional{<:AbstractVector} , ϵ,
429
429
γ′:: AbstractVector ) where {∂xT, ∂σ²T, ∂yT, xT}
430
430
backend = KA. get_backend (∂x)
431
431
run_ka_kernel (
0 commit comments