From 6573f6525ebb940d0f3f1068a2cc2e1f836f9a71 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Wed, 12 Jan 2022 16:48:48 -0600 Subject: [PATCH 1/2] Use NNlib.conv_bias_act for Conv --- src/layers/conv.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 519618e4be..15fb997f75 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -162,9 +162,13 @@ function (c::Conv)(x::AbstractArray) b = reshape(c.bias, map(_->1, c.stride)..., :, 1) σ = NNlib.fast_act(c.σ, x) cdims = DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation, groups = c.groups) - σ.(conv(x, c.weight, cdims) .+ b) + _conv_bias_act(x, c.weight, cdims, b, σ) end +_conv_bias_act(x, w, cdims, b, σ) = NNlib.conv_bias_act(x, w, cdims, b, σ) +_conv_bias_act(x::CuArray, w::CuArray, cdims, b::Zeros, σ) = + _conv_bias_act(x, w, cdims, CUDA.zeros(size(b)...), σ) + _channels_in(l ::Conv) = size(l.weight, ndims(l.weight)-1) * l.groups _channels_out(l::Conv) = size(l.weight, ndims(l.weight)) From 8b5b26a008c21bca805a6801076b1a91ac4f1da0 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Wed, 12 Jan 2022 18:00:00 -0600 Subject: [PATCH 2/2] Skip conv_bias_act with Flux.Zeros --- src/layers/conv.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 15fb997f75..e9bffca108 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -166,6 +166,7 @@ function (c::Conv)(x::AbstractArray) end _conv_bias_act(x, w, cdims, b, σ) = NNlib.conv_bias_act(x, w, cdims, b, σ) +_conv_bias_act(x, w, cdims, ::Zeros, σ) = σ.(conv(x, w, cdims)) _conv_bias_act(x::CuArray, w::CuArray, cdims, b::Zeros, σ) = _conv_bias_act(x, w, cdims, CUDA.zeros(size(b)...), σ)