Skip to content

Commit e1278a9

Browse files
authored
Try #1837:
2 parents cce7ad0 + 7e4480b commit e1278a9

File tree

3 files changed

+16
-10
lines changed

3 files changed

+16
-10
lines changed

src/layers/basic.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,8 @@ end
154154
@functor Dense
155155

156156
function (a::Dense)(x::AbstractVecOrMat)
157-
W, b, σ = a.weight, a.bias, a.σ
157+
W, b= a.weight, a.bias
158+
σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
158159
return σ.(W*x .+ b)
159160
end
160161

src/layers/conv.jl

+8-4
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,8 @@ end
161161
@functor Conv
162162

163163
function (c::Conv)(x::AbstractArray)
164-
σ, b = c.σ, reshape(c.bias, ntuple(_ -> 1, length(c.stride))..., :, 1)
164+
b = reshape(c.bias, ntuple(_ -> 1, length(c.stride))..., :, 1)
165+
σ = NNlib.fast_act(c.σ, x)
165166
cdims = DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation, groups = c.groups)
166167
σ.(conv(x, c.weight, cdims) .+ b)
167168
end
@@ -278,7 +279,8 @@ end
278279
@nograd conv_transpose_dims
279280

280281
function (c::ConvTranspose)(x::AbstractArray)
281-
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
282+
b = reshape(c.bias, map(_->1, c.stride)..., :, 1)
283+
σ = NNlib.fast_act(c.σ, x)
282284
cdims = conv_transpose_dims(c, x)
283285
σ.(∇conv_data(x, c.weight, cdims) .+ b)
284286
end
@@ -371,7 +373,8 @@ depthwiseconvfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
371373
init = glorot_uniform) where N = init(filter..., div(ch[2], ch[1]), ch[1])
372374

373375
function (c::DepthwiseConv)(x)
374-
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
376+
b = reshape(c.bias, map(_->1, c.stride)..., :, 1)
377+
σ = NNlib.fast_act(c.σ, x)
375378
cdims = DepthwiseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
376379
σ.(depthwiseconv(x, c.weight, cdims) .+ b)
377380
end
@@ -450,7 +453,8 @@ function crosscor(x, w, ddims::DenseConvDims)
450453
end
451454

452455
function (c::CrossCor)(x::AbstractArray)
453-
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
456+
b = reshape(c.bias, map(_->1, c.stride)..., :, 1)
457+
σ = NNlib.fast_act(c.σ, x)
454458
cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
455459
σ.(crosscor(x, c.weight, cdims) .+ b)
456460
end

src/layers/recurrent.jl

+6-5
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,8 @@ RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zero
117117
RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out,1))
118118

119119
function (m::RNNCell{F,A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {F,A,V,T}
120-
σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
120+
Wi, Wh, b = m.Wi, m.Wh, m.b
121+
σ = NNlib.fast_act(m.σ, x)
121122
h = σ.(Wi*x .+ Wh*h .+ b)
122123
return h, reshape_cell_output(h, x)
123124
end
@@ -224,8 +225,8 @@ function (m::LSTMCell{A,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{Abstr
224225
b, o = m.b, size(h, 1)
225226
g = m.Wi*x .+ m.Wh*h .+ b
226227
input, forget, cell, output = multigate(g, o, Val(4))
227-
c′ = @. σ(forget) * c + σ(input) * tanh(cell)
228-
h′ = @. σ(output) * tanh(c′)
228+
c′ = @. sigmoid_fast(forget) * c + sigmoid_fast(input) * tanh_fast(cell)
229+
h′ = @. sigmoid_fast(output) * tanh_fast(c′)
229230
return (h′, c′), reshape_cell_output(h′, x)
230231
end
231232

@@ -309,7 +310,7 @@ function (m::GRUCell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},O
309310
Wi, Wh, b, o = m.Wi, m.Wh, m.b, size(h, 1)
310311
gxs, ghs, bs = multigate(Wi*x, o, Val(3)), multigate(Wh*h, o, Val(3)), multigate(b, o, Val(3))
311312
r, z = _gru_output(gxs, ghs, bs)
312-
= @. tanh(gxs[3] + r * ghs[3] + bs[3])
313+
= @. tanh_fast(gxs[3] + r * ghs[3] + bs[3])
313314
h′ = @. (1 - z) *+ z * h
314315
return h′, reshape_cell_output(h′, x)
315316
end
@@ -387,7 +388,7 @@ function (m::GRUv3Cell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T}
387388
Wi, Wh, b, Wh_h̃, o = m.Wi, m.Wh, m.b, m.Wh_h̃, size(h, 1)
388389
gxs, ghs, bs = multigate(Wi*x, o, Val(3)), multigate(Wh*h, o, Val(2)), multigate(b, o, Val(3))
389390
r, z = _gru_output(gxs, ghs, bs)
390-
= tanh.(gxs[3] .+ (Wh_h̃ * (r .* h)) .+ bs[3])
391+
= tanh_fast.(gxs[3] .+ (Wh_h̃ * (r .* h)) .+ bs[3])
391392
h′ = @. (1 - z) *+ z * h
392393
return h′, reshape_cell_output(h′, x)
393394
end

0 commit comments

Comments
 (0)