Skip to content

Commit 8ca0808

Browse files
committed
fix: don't force ::Real
1 parent ec11357 commit 8ca0808

17 files changed

+63 −63 lines changed

lib/LuxLib/Project.toml

+1 −1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LuxLib"
22
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
33
authors = ["Avik Pal <[email protected]> and contributors"]
4-
version = "1.3.10"
4+
version = "1.3.11"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

lib/LuxLib/ext/LuxLibTrackerExt.jl

+1 −1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ for RM in (:TrackedVector, :Nothing, :AbstractVector),
9797
Utils.is_tracked(RM, RV, S, B, XT) || continue
9898

9999
@eval Tracker.@grad_from_chainrules LuxLib.Impl.batchnorm_cudnn(
100-
γ::$S, β::$B, x::$XT, rμ::$RM, rσ²::$RV, m::Real, ϵ::Real, training::StaticBool)
100+
γ::$S, β::$B, x::$XT, rμ::$RM, rσ²::$RV, m, ϵ, training::StaticBool)
101101
end
102102

103103
# Utils extensions

lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl

+1 −1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ include("batchnorm.jl")
2121
function Impl.batchnorm(x::Union{<:CuArray{T, 2}, <:CuArray{T, 4}, <:CuArray{T, 5}},
2222
γ::Optional{<:CuVector{T}}, β::Optional{<:CuVector{T}},
2323
::Optional{<:CuVector{T}}, rσ²::Optional{<:CuVector{T}},
24-
training::StaticBool, σ::F, m::Real, ϵ::Real) where {T <: cuDNNFloat, F}
24+
training::StaticBool, σ::F, m, ϵ) where {T <: cuDNNFloat, F}
2525
rμₙ, rσ²ₙ = Impl.get_batchnorm_statistics(x, rμ, rσ², training)
2626
y = Impl.batchnorm_cudnn(γ, β, x, rμₙ, rσ²ₙ, m, ϵ, training)[1]
2727
return Impl.activation!!(σ, y), safe_vec(rμₙ), safe_vec(rσ²ₙ)

lib/LuxLib/src/api/batchnorm.jl

+1 −1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ mean and variance.
3737
function batchnorm(x::AbstractArray{T, N}, γ::Optional{<:AbstractVector},
3838
β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector},
3939
rσ²::Optional{<:AbstractVector}, training::TrainingType, act::F=identity,
40-
momentum::Real=0.1f0, epsilon::Real=default_epsilon(x)) where {F, T, N}
40+
momentum=0.1f0, epsilon=default_epsilon(x)) where {F, T, N}
4141
σ = select_fastest_activation(act, x, γ, β, rμ, rσ²)
4242
y, rμ, rσ² = batchnorm_impl(
4343
x, γ, β, rμ, rσ², static_training_mode(training, x, γ, β, rμ, rσ²),

lib/LuxLib/src/api/groupnorm.jl

+2 −2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
@doc doc"""
22
groupnorm(x, scale, bias, groups::Int, σ::F=identity,
3-
epsilon::Real=eps(eltype(x)) ^ (5 // 7))
3+
epsilon=eps(eltype(x)) ^ (5 // 7))
44
55
Group Normalization. For details see [1].
66
@@ -30,7 +30,7 @@ The normalized array is returned.
3030
"""
3131
function groupnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector},
3232
bias::Optional{<:AbstractVector}, groups::Int, σ::F=identity,
33-
epsilon::Real=default_epsilon(x)) where {F, N}
33+
epsilon=default_epsilon(x)) where {F, N}
3434
assert_valid_groupnorm_arguments(x, scale, bias, groups)
3535
return groupnorm_impl(
3636
x, scale, bias, groups, select_fastest_activation(σ, x, scale, bias), epsilon)

lib/LuxLib/src/api/instancenorm.jl

+2 −2
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,15 @@ mean and variance.
3636
"""
3737
function instancenorm(x::AbstractArray, γ::Optional{<:AbstractVector},
3838
β::Optional{<:AbstractVector}, training::TrainingType,
39-
σ::F=identity, epsilon::Real=default_epsilon(x)) where {F}
39+
σ::F=identity, epsilon=default_epsilon(x)) where {F}
4040
# This API is kept for legacy purposes when we didn't support passing running stats
4141
return instancenorm(x, γ, β, nothing, nothing, training, σ, nothing, epsilon)
4242
end
4343

4444
function instancenorm(x::AbstractArray, γ::Optional{<:AbstractVector},
4545
β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector},
4646
rσ²::Optional{<:AbstractVector}, training::TrainingType, σ::F=identity,
47-
momentum::Optional{<:Real}=0.1f0, epsilon::Real=default_epsilon(x)) where {F}
47+
momentum::Optional{<:Real}=0.1f0, epsilon=default_epsilon(x)) where {F}
4848
assert_valid_instancenorm_arguments(x)
4949

5050
y, rμₙ, rσ²ₙ = instancenorm_impl(

lib/LuxLib/src/api/layernorm.jl

+1 −1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ Normalized Array of same size as `x`.
3636
"""
3737
function layernorm(x::AbstractArray{xT, N}, scale::Optional{<:AbstractArray},
3838
bias::Optional{<:AbstractArray}, σ::F=identity, dims=1:(N - 1),
39-
epsilon::Real=default_epsilon(x)) where {F, xT, N}
39+
epsilon=default_epsilon(x)) where {F, xT, N}
4040
return layernorm_impl(
4141
x, scale, bias, select_fastest_activation(σ, x, scale, bias), dims, epsilon)
4242
end

lib/LuxLib/src/impl/batchnorm.jl

+13 −13
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ CRC.@non_differentiable get_batchnorm_statistics(::Any...)
2727
function batchnorm(x::AbstractArray{xT, N}, γ::Optional{<:AbstractVector},
2828
β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector},
2929
rσ²::Optional{<:AbstractVector}, training::StaticBool, act::F,
30-
momentum::Real, ϵ::Real) where {F, xT, N}
30+
momentum, ϵ) where {F, xT, N}
3131
(μ, σ²), (rμ, rσ²) = compute_batch_statistics(
3232
x, reshape_norm_dims(x, rμ), reshape_norm_dims(x, rσ²),
3333
batchnorm_reduce_dims(x), training, momentum)
@@ -37,15 +37,15 @@ end
3737
function batchnorm_affine_normalize(
3838
act::F, x::AbstractArray{xT, N}, μ::AbstractArray{μT, N},
3939
σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector},
40-
β::Optional{<:AbstractVector}, ϵ::Real) where {F, xT, μT, σ²T, N}
40+
β::Optional{<:AbstractVector}, ϵ) where {F, xT, μT, σ²T, N}
4141
return batchnorm_affine_normalize(
4242
internal_operation_mode((x, μ, σ², γ, β)), act, x, μ, σ², γ, β, ϵ)
4343
end
4444

4545
function batchnorm_affine_normalize(
4646
::GenericBroadcastOp, act::F, x::AbstractArray{xT, N}, μ::AbstractArray{μT, N},
4747
σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector},
48-
β::Optional{<:AbstractVector}, ϵ::Real) where {F, xT, μT, σ²T, N}
48+
β::Optional{<:AbstractVector}, ϵ) where {F, xT, μT, σ²T, N}
4949
return affine_normalize(
5050
act, x, μ, σ², reshape_norm_dims(x, γ), reshape_norm_dims(x, β), ϵ)
5151
end
@@ -54,7 +54,7 @@ function batchnorm_affine_normalize(
5454
opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{xT, N},
5555
μ::AbstractArray{μT, N}, σ²::AbstractArray{σ²T, N},
5656
γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector},
57-
ϵ::Real) where {F, xT, μT, σ²T, N}
57+
ϵ) where {F, xT, μT, σ²T, N}
5858
x′ = reshape(x, :, size(x, N - 1), size(x, N))
5959
return reshape(
6060
batchnorm_affine_normalize_internal(opmode, act, x′, vec(μ), vec(σ²), γ, β, ϵ),
@@ -64,7 +64,7 @@ end
6464
@stable default_mode="disable" function batchnorm_affine_normalize_internal(
6565
opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{xT, 3},
6666
μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector},
67-
β::Optional{<:AbstractVector}, ϵ::Real) where {F, xT}
67+
β::Optional{<:AbstractVector}, ϵ) where {F, xT}
6868
y = similar(x,
6969
promote_type(safe_eltype(x), safe_eltype(μ), safe_eltype(σ²),
7070
safe_eltype(γ), safe_eltype(β)))
@@ -75,7 +75,7 @@ end
7575
function batchnorm_affine_normalize_internal!(
7676
y::AbstractArray{yT, 3}, opmode::LoopedArrayOp, act::F, x::AbstractArray{xT, 3},
7777
μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector},
78-
β::Optional{<:AbstractVector}, ϵ::Real,
78+
β::Optional{<:AbstractVector}, ϵ,
7979
γ′::Optional{<:AbstractVector}=nothing) where {F, xT, yT}
8080
N = size(y, 2)
8181
γ′ = γ′ === nothing ?
@@ -225,7 +225,7 @@ end
225225
function batchnorm_affine_normalize_internal!(
226226
y::AbstractArray{yT, 3}, ::GPUBroadcastOp, act::F, x::AbstractArray{xT, 3},
227227
μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector},
228-
β::Optional{<:AbstractVector}, ϵ::Real,
228+
β::Optional{<:AbstractVector}, ϵ,
229229
γ′::Optional{<:AbstractVector}=nothing) where {F, xT, yT}
230230
backend = KA.get_backend(y)
231231
run_ka_kernel(
@@ -278,7 +278,7 @@ function CRC.rrule(
278278
cfg::RuleConfig{>:HasReverseMode}, ::typeof(batchnorm_affine_normalize_internal),
279279
opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{T, N},
280280
μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector},
281-
β::Optional{<:AbstractVector}, ϵ::Real) where {F, T, N}
281+
β::Optional{<:AbstractVector}, ϵ) where {F, T, N}
282282
y = similar(x,
283283
promote_type(safe_eltype(x), safe_eltype(μ), safe_eltype(σ²),
284284
safe_eltype(γ), safe_eltype(β)))
@@ -304,7 +304,7 @@ end
304304

305305
function ∇batchnorm_affine_normalize(opmode::LoopedArrayOp, ∂y::AbstractArray{∂yT, 3},
306306
x::AbstractArray{xT, 3}, μ::AbstractVector, σ²::AbstractVector,
307-
γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real,
307+
γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ,
308308
γ′::AbstractVector) where {∂yT, xT}
309309
∂x, ∂μ, ∂σ² = similar(x), similar(μ), similar(σ²)
310310
∂γ = γ === nothing ? nothing : similar(γ)
@@ -322,7 +322,7 @@ function ∇batchnorm_affine_normalize_cpu!(
322322
∂x::AbstractArray{∂xT, 3}, ∂μ::AbstractVector{∂μT},
323323
∂σ²::AbstractVector{∂σ²T}, ::Nothing, ::Nothing, ∂y::AbstractArray{∂yT, 3},
324324
x::AbstractArray{xT, 3}, μ::AbstractVector, σ²::AbstractVector, ::Nothing,
325-
ϵ::Real, γ′::AbstractVector) where {∂xT, ∂μT, ∂σ²T, ∂yT, xT}
325+
ϵ, γ′::AbstractVector) where {∂xT, ∂μT, ∂σ²T, ∂yT, xT}
326326
half = eltype(∂σ²)(0.5)
327327

328328
fill!(∂μ, 0)
@@ -361,7 +361,7 @@ function ∇batchnorm_affine_normalize_cpu!(
361361
∂x::AbstractArray{∂xT, 3}, ∂μ::AbstractVector{∂μT},
362362
∂σ²::AbstractVector{∂σ²T}, ∂γ::AbstractVector{∂γT},
363363
∂β::AbstractVector{∂βT}, ∂y::AbstractArray{∂yT, 3}, x::AbstractArray{xT, 3},
364-
μ::AbstractVector, σ²::AbstractVector, γ::AbstractVector, ϵ::Real,
364+
μ::AbstractVector, σ²::AbstractVector, γ::AbstractVector, ϵ,
365365
γ′::AbstractVector) where {∂xT, ∂μT, ∂σ²T, ∂γT, ∂βT, ∂yT, xT}
366366
half = eltype(∂σ²)(0.5)
367367

@@ -406,7 +406,7 @@ end
406406
function ∇batchnorm_affine_normalize(
407407
opmode::AbstractInternalArrayOpMode, ∂y::AbstractArray{∂yT, 3},
408408
x::AbstractArray{xT, 3}, μ::AbstractVector, σ²::AbstractVector,
409-
γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real,
409+
γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ,
410410
γ′::AbstractVector) where {∂yT, xT}
411411
∂x, ∂σ² = similar(x), similar(σ², size(x))
412412
∂γ = γ === nothing ? nothing : similar(γ, size(x))
@@ -425,7 +425,7 @@ function ∇batchnorm_affine_normalize!(
425425
∂x::AbstractArray{∂xT, 3}, ∂σ²::AbstractArray{∂σ²T, 3},
426426
∂γ::Optional{<:AbstractArray{<:Any, 3}}, ::GPUBroadcastOp,
427427
∂y::AbstractArray{∂yT, 3}, x::AbstractArray{xT, 3}, μ::AbstractVector,
428-
σ²::AbstractVector, γ::Optional{<:AbstractVector}, ϵ::Real,
428+
σ²::AbstractVector, γ::Optional{<:AbstractVector}, ϵ,
429429
γ′::AbstractVector) where {∂xT, ∂σ²T, ∂yT, xT}
430430
backend = KA.get_backend(∂x)
431431
run_ka_kernel(

lib/LuxLib/src/impl/dropout.jl

+7 −7
Original file line numberDiff line numberDiff line change
@@ -62,22 +62,22 @@ function alpha_dropout(noise::AbstractArray, p, x::AbstractArray, α, A, B)
6262
end
6363

6464
@stable default_mode="disable" function alpha_dropout(
65-
::AbstractInternalArrayOpMode, noise::AbstractArray, p::Real,
66-
x::AbstractArray{T}, α::Real, A::Real, B::Real) where {T}
65+
::AbstractInternalArrayOpMode, noise::AbstractArray, p,
66+
x::AbstractArray{T}, α, A, B) where {T}
6767
A′, B′, α = T(A), T(B), T(α)
6868
return @. muladd(ifelse(noise > p, x, α), A′, B′)
6969
end
7070

7171
@stable default_mode="disable" function alpha_dropout(
72-
opmode::LoopedArrayOp, noise::AbstractArray, p::Real,
73-
x::AbstractArray, α::Real, A::Real, B::Real)
72+
opmode::LoopedArrayOp, noise::AbstractArray, p,
73+
x::AbstractArray, α, A, B)
7474
res = similar(x, promote_type(typeof(p), typeof(α)))
7575
alpha_dropout!(res, opmode, noise, p, x, α, A, B)
7676
return res
7777
end
7878

7979
function CRC.rrule(::typeof(alpha_dropout), ::LoopedArrayOp, noise::AbstractArray,
80-
p::Real, x::AbstractArray, α::Real, A::Real, B::Real)
80+
p, x::AbstractArray, α, A, B)
8181
cond = similar(noise, Bool)
8282
y = similar(x, promote_type(typeof(p), typeof(α), typeof(A), typeof(B), eltype(x)))
8383
@simd ivdep for I in eachindex(noise, x, y, cond)
@@ -99,7 +99,7 @@ function CRC.rrule(::typeof(alpha_dropout), ::LoopedArrayOp, noise::AbstractArra
9999
end
100100

101101
function CRC.rrule(::typeof(alpha_dropout), ::AbstractInternalArrayOpMode,
102-
noise::AbstractArray, p::Real, x::AbstractArray, α::Real, A::Real, B::Real)
102+
noise::AbstractArray, p, x::AbstractArray, α, A, B)
103103
cond = noise .> p
104104
y = @. ifelse(cond, x, α) * A + B
105105

@@ -114,7 +114,7 @@ end
114114

115115
function alpha_dropout!(
116116
res::AbstractArray{T}, ::LoopedArrayOp, noise::AbstractArray{T},
117-
p::Real, x::AbstractArray{T}, α::Real, A::Real, B::Real) where {T}
117+
p, x::AbstractArray{T}, α, A, B) where {T}
118118
@simd ivdep for I in eachindex(noise, x, res)
119119
res[I] = ifelse(noise[I] > p, x[I], α) * A + B
120120
end

0 commit comments

Comments (0)