@@ -117,7 +117,8 @@ RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zero
117
117
RNNCell (σ, init (out, in), init (out, out), initb (out), init_state (out,1 ))
118
118
119
119
function (m:: RNNCell{F,A,V,<:AbstractMatrix{T}} )(h, x:: Union{AbstractVecOrMat{T},OneHotArray} ) where {F,A,V,T}
120
- σ, Wi, Wh, b = m. σ, m. Wi, m. Wh, m. b
120
+ Wi, Wh, b = m. Wi, m. Wh, m. b
121
+ σ = NNlib. fast_act (m. σ, x)
121
122
h = σ .(Wi* x .+ Wh* h .+ b)
122
123
return h, reshape_cell_output (h, x)
123
124
end
@@ -224,8 +225,8 @@ function (m::LSTMCell{A,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{Abstr
224
225
b, o = m. b, size (h, 1 )
225
226
g = m. Wi* x .+ m. Wh* h .+ b
226
227
input, forget, cell, output = multigate (g, o, Val (4 ))
227
- c′ = @. σ (forget) * c + σ (input) * tanh (cell)
228
- h′ = @. σ (output) * tanh (c′)
228
+ c′ = @. sigmoid_fast (forget) * c + sigmoid_fast (input) * tanh_fast (cell)
229
+ h′ = @. sigmoid_fast (output) * tanh_fast (c′)
229
230
return (h′, c′), reshape_cell_output (h′, x)
230
231
end
231
232
@@ -309,7 +310,7 @@ function (m::GRUCell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},O
309
310
Wi, Wh, b, o = m. Wi, m. Wh, m. b, size (h, 1 )
310
311
gxs, ghs, bs = multigate (Wi* x, o, Val (3 )), multigate (Wh* h, o, Val (3 )), multigate (b, o, Val (3 ))
311
312
r, z = _gru_output (gxs, ghs, bs)
312
- h̃ = @. tanh (gxs[3 ] + r * ghs[3 ] + bs[3 ])
313
+ h̃ = @. tanh_fast (gxs[3 ] + r * ghs[3 ] + bs[3 ])
313
314
h′ = @. (1 - z) * h̃ + z * h
314
315
return h′, reshape_cell_output (h′, x)
315
316
end
@@ -387,7 +388,7 @@ function (m::GRUv3Cell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T}
387
388
Wi, Wh, b, Wh_h̃, o = m. Wi, m. Wh, m. b, m. Wh_h̃, size (h, 1 )
388
389
gxs, ghs, bs = multigate (Wi* x, o, Val (3 )), multigate (Wh* h, o, Val (2 )), multigate (b, o, Val (3 ))
389
390
r, z = _gru_output (gxs, ghs, bs)
390
- h̃ = tanh .(gxs[3 ] .+ (Wh_h̃ * (r .* h)) .+ bs[3 ])
391
+ h̃ = tanh_fast .(gxs[3 ] .+ (Wh_h̃ * (r .* h)) .+ bs[3 ])
391
392
h′ = @. (1 - z) * h̃ + z * h
392
393
return h′, reshape_cell_output (h′, x)
393
394
end
0 commit comments