Commit 73131bf

Formatting, and some tweaks
1 parent 99eb25a commit 73131bf

9 files changed, +76 -83 lines changed

src/convnets/mobilenet/mobilenetv1.jl (+1)

@@ -45,6 +45,7 @@ function mobilenetv1(width_mult, config;
                        Dense(inchannels, nclasses)))
 end
 
+# Layer configurations for MobileNetv1
 const MOBILENETV1_CONFIGS = [
     # dw, c, s, r
     (false, 32, 2, 1),
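The new comment sits above a table of per-stage settings. As a rough illustration (not part of the diff), the column comment suggests each row is read as a depthwise flag, output channel count, stride, and repeat count; a minimal sketch of destructuring one row under that assumption:

    # Assumed column meanings (dw, c, s, r); purely illustrative.
    for (dw, outchannels, stride, nrepeats) in MOBILENETV1_CONFIGS
        @info "stage" depthwise=dw outchannels stride nrepeats
    end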

src/convnets/mobilenet/mobilenetv2.jl (+1 -2)

@@ -47,7 +47,7 @@ end
 
 # Layer configurations for MobileNetv2
 const MOBILENETV2_CONFIGS = [
-# t, c, n, s, a
+    # t, c, n, s, a
     (1, 16, 1, 1, relu6),
     (6, 24, 2, 2, relu6),
     (6, 32, 3, 2, relu6),
@@ -57,7 +57,6 @@ const MOBILENETV2_CONFIGS = [
     (6, 320, 1, 1, relu6),
 ]
 
-# Model definition for MobileNetv2
 struct MobileNetv2
     layers::Any
 end

src/convnets/mobilenet/mobilenetv3.jl (+33 -34)

@@ -52,41 +52,40 @@ function mobilenetv3(width_mult, configs; inchannels = 3, max_width = 1024, ncla
                  Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, classifier))
 end
 
-# Configurations for small and large mode for MobileNetv3
-MOBILENETV3_CONFIGS = Dict(:small => [
-        # k, t, c, SE, a, s
-        (3, 1, 16, 4, relu, 2),
-        (3, 4.5, 24, nothing, relu, 2),
-        (3, 3.67, 24, nothing, relu, 1),
-        (5, 4, 40, 4, hardswish, 2),
-        (5, 6, 40, 4, hardswish, 1),
-        (5, 6, 40, 4, hardswish, 1),
-        (5, 3, 48, 4, hardswish, 1),
-        (5, 3, 48, 4, hardswish, 1),
-        (5, 6, 96, 4, hardswish, 2),
-        (5, 6, 96, 4, hardswish, 1),
-        (5, 6, 96, 4, hardswish, 1),
-    ],
-    :large => [
-        # k, t, c, SE, a, s
-        (3, 1, 16, nothing, relu, 1),
-        (3, 4, 24, nothing, relu, 2),
-        (3, 3, 24, nothing, relu, 1),
-        (5, 3, 40, 4, relu, 2),
-        (5, 3, 40, 4, relu, 1),
-        (5, 3, 40, 4, relu, 1),
-        (3, 6, 80, nothing, hardswish, 2),
-        (3, 2.5, 80, nothing, hardswish, 1),
-        (3, 2.3, 80, nothing, hardswish, 1),
-        (3, 2.3, 80, nothing, hardswish, 1),
-        (3, 6, 112, 4, hardswish, 1),
-        (3, 6, 112, 4, hardswish, 1),
-        (5, 6, 160, 4, hardswish, 2),
-        (5, 6, 160, 4, hardswish, 1),
-        (5, 6, 160, 4, hardswish, 1),
-    ])
+# Layer configurations for small and large models for MobileNetv3
+const MOBILENETV3_CONFIGS = Dict(:small => [
+        # k, t, c, SE, a, s
+        (3, 1, 16, 4, relu, 2),
+        (3, 4.5, 24, nothing, relu, 2),
+        (3, 3.67, 24, nothing, relu, 1),
+        (5, 4, 40, 4, hardswish, 2),
+        (5, 6, 40, 4, hardswish, 1),
+        (5, 6, 40, 4, hardswish, 1),
+        (5, 3, 48, 4, hardswish, 1),
+        (5, 3, 48, 4, hardswish, 1),
+        (5, 6, 96, 4, hardswish, 2),
+        (5, 6, 96, 4, hardswish, 1),
+        (5, 6, 96, 4, hardswish, 1),
+    ],
+    :large => [
+        # k, t, c, SE, a, s
+        (3, 1, 16, nothing, relu, 1),
+        (3, 4, 24, nothing, relu, 2),
+        (3, 3, 24, nothing, relu, 1),
+        (5, 3, 40, 4, relu, 2),
+        (5, 3, 40, 4, relu, 1),
+        (5, 3, 40, 4, relu, 1),
+        (3, 6, 80, nothing, hardswish, 2),
+        (3, 2.5, 80, nothing, hardswish, 1),
+        (3, 2.3, 80, nothing, hardswish, 1),
+        (3, 2.3, 80, nothing, hardswish, 1),
+        (3, 6, 112, 4, hardswish, 1),
+        (3, 6, 112, 4, hardswish, 1),
+        (5, 6, 160, 4, hardswish, 2),
+        (5, 6, 160, 4, hardswish, 1),
+        (5, 6, 160, 4, hardswish, 1),
+    ])
 
-# Model definition for MobileNetv3
 struct MobileNetv3
     layers::Any
 end
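With the table now bound to a `const` Dict, a usage sketch based only on the `mobilenetv3(width_mult, configs; inchannels = 3, ...)` signature visible in the hunk header (the actual call sites are not part of this diff):

    # Look up the :small variant and feed it to the builder shown above.
    configs = MOBILENETV3_CONFIGS[:small]    # Vector of (k, t, c, SE, a, s) tuples
    layers = mobilenetv3(1.0, configs; inchannels = 3)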

src/convnets/resnets/core.jl (+8 -4)

@@ -215,7 +215,8 @@ function basicblock_builder(block_repeats::Vector{<:Integer}; inplanes::Integer
         drop_block = DropBlock(blockschedule[schedule_idx])
         block = basicblock(inplanes, planes; stride, reduction_factor, activation,
                            norm_layer, revnorm, attn_fn, drop_path, drop_block)
-        downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, revnorm)
+        downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer,
+                                   revnorm)
         # inplanes increases by expansion after each block
         inplanes = planes * expansion
         return block, downsample
@@ -248,7 +249,8 @@ function bottleneck_builder(block_repeats::Vector{<:Integer}; inplanes::Integer
         block = bottleneck(inplanes, planes; stride, cardinality, base_width,
                            reduction_factor, activation, norm_layer, revnorm,
                            attn_fn, drop_path, drop_block)
-        downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, revnorm)
+        downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer,
+                                   revnorm)
         # inplanes increases by expansion after each block
         inplanes = planes * expansion
         return block, downsample
@@ -298,13 +300,15 @@ function resnet(block_type::Symbol, block_repeats::Vector{<:Integer};
         get_layers = basicblock_builder(block_repeats; inplanes, reduction_factor,
                                         activation, norm_layer, revnorm, attn_fn,
                                         drop_block_rate, drop_path_rate,
-                                        stride_fn = resnet_stride, planes_fn = resnet_planes,
+                                        stride_fn = resnet_stride,
+                                        planes_fn = resnet_planes,
                                         downsample_tuple = downsample_opt)
     elseif block_type == :bottleneck
         get_layers = bottleneck_builder(block_repeats; inplanes, cardinality, base_width,
                                         reduction_factor, activation, norm_layer,
                                         revnorm, attn_fn, drop_block_rate, drop_path_rate,
-                                        stride_fn = resnet_stride, planes_fn = resnet_planes,
+                                        stride_fn = resnet_stride,
+                                        planes_fn = resnet_planes,
                                         downsample_tuple = downsample_opt)
     else
         # TODO: write better message when we have link to dev docs for resnet
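For context, the `resnet(block_type::Symbol, block_repeats::Vector{<:Integer}; ...)` entry point dispatches on `:basicblock` or `:bottleneck` as shown above. A minimal sketch exercising both branches, with all keyword arguments left at their defaults (the stage counts are the conventional ResNet-18/ResNet-50 layouts, not something this diff specifies):

    backbone18 = resnet(:basicblock, [2, 2, 2, 2])   # basicblock_builder path
    backbone50 = resnet(:bottleneck, [3, 4, 6, 3])   # bottleneck_builder path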

src/layers/Layers.jl (+10 -10)

@@ -16,6 +16,12 @@ include("../utilities.jl")
 include("attention.jl")
 export MHAttention
 
+include("conv.jl")
+export conv_norm, depthwise_sep_conv_norm, invertedresidual
+
+include("drop.jl")
+export DropBlock, DropPath
+
 include("embeddings.jl")
 export PatchEmbedding, ViPosEmbedding, ClassTokens
 
@@ -25,19 +31,13 @@ export mlp_block, gated_mlp_block, create_fc, create_classifier
 include("normalise.jl")
 export prenorm, ChannelLayerNorm
 
-include("conv.jl")
-export conv_norm, depthwise_sep_conv_norm, invertedresidual
-
-include("drop.jl")
-export DropBlock, DropPath, droppath_rates
-
-include("selayers.jl")
-export squeeze_excite, effective_squeeze_excite
+include("pool.jl")
+export AdaptiveMeanMaxPool
 
 include("scale.jl")
 export LayerScale, inputscale
 
-include("pool.jl")
-export AdaptiveMeanMaxPool
+include("selayers.jl")
+export squeeze_excite, effective_squeeze_excite
 
 end

src/layers/attention.jl (+13 -26)

@@ -1,46 +1,33 @@
 """
-    MHAttention(nheads::Integer, qkv_layer, attn_drop_rate, projection)
+    MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, attn_dropout_rate = 0., proj_dropout_rate = 0.)
 
 Multi-head self-attention layer.
 
 # Arguments
 
-  - `nheads`: Number of heads
-  - `qkv_layer`: layer to be used for getting the query, key and value
-  - `attn_drop_rate`: dropout rate after the self-attention layer
-  - `projection`: projection layer to be used after self-attention
+  - `planes`: number of input channels
+  - `nheads`: number of heads
+  - `qkv_bias`: whether to use bias in the layer to get the query, key and value
+  - `attn_dropout_rate`: dropout rate after the self-attention layer
+  - `proj_dropout_rate`: dropout rate after the projection layer
 """
 struct MHAttention{P, Q, R}
     nheads::Int
     qkv_layer::P
-    attn_drop_rate::Q
+    attn_drop::Q
    projection::R
 end
+@functor MHAttention
 
-"""
-    MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, attn_drop_rate = 0., proj_drop_rate = 0.)
-
-Multi-head self-attention layer.
-
-# Arguments
-
-  - `planes`: number of input channels
-  - `nheads`: number of heads
-  - `qkv_bias`: whether to use bias in the layer to get the query, key and value
-  - `attn_drop_rate`: dropout rate after the self-attention layer
-  - `proj_drop_rate`: dropout rate after the projection layer
-"""
 function MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false,
-                     attn_drop_rate = 0.0, proj_drop_rate = 0.0)
+                     attn_dropout_rate = 0.0, proj_dropout_rate = 0.0)
     @assert planes % nheads==0 "planes should be divisible by nheads"
     qkv_layer = Dense(planes, planes * 3; bias = qkv_bias)
-    attn_drop_rate = Dropout(attn_drop_rate)
-    proj = Chain(Dense(planes, planes), Dropout(proj_drop_rate))
-    return MHAttention(nheads, qkv_layer, attn_drop_rate, proj)
+    attn_drop = Dropout(attn_dropout_rate)
+    proj = Chain(Dense(planes, planes), Dropout(proj_dropout_rate))
+    return MHAttention(nheads, qkv_layer, attn_drop, proj)
 end
 
-@functor MHAttention
-
 function (m::MHAttention)(x::AbstractArray{T, 3}) where {T}
     nfeatures, seq_len, batch_size = size(x)
     x_reshaped = reshape(x, nfeatures, seq_len * batch_size)
@@ -52,7 +39,7 @@ function (m::MHAttention)(x::AbstractArray{T, 3}) where {T}
                            seq_len * batch_size)
     query_reshaped = reshape(permutedims(query, (1, 2, 3, 4)), nfeatures ÷ m.nheads,
                              m.nheads, seq_len * batch_size)
-    attention = m.attn_drop_rate(softmax(batched_mul(query_reshaped, key_reshaped) .* scale))
+    attention = m.attn_drop(softmax(batched_mul(query_reshaped, key_reshaped) .* scale))
     value_reshaped = reshape(permutedims(value, (1, 2, 3, 4)), nfeatures ÷ m.nheads,
                              m.nheads, seq_len * batch_size)
     pre_projection = reshape(batched_mul(attention, value_reshaped),
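A small usage sketch for the renamed keywords; the array layout follows the `(nfeatures, seq_len, batch_size)` destructuring in the forward pass above, and the output should keep that shape since the projection maps `planes` back to `planes`:

    attn = MHAttention(64, 8; qkv_bias = false,
                       attn_dropout_rate = 0.1, proj_dropout_rate = 0.1)
    x = rand(Float32, 64, 16, 4)   # 64 features, 16 tokens, batch of 4
    y = attn(x)                    # expected size: (64, 16, 4)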

src/layers/conv.jl (+1 -2)

@@ -52,8 +52,7 @@ function conv_norm(kernel_size, inplanes::Int, outplanes::Int, activation = relu
     return revnorm ? reverse(layers) : layers
 end
 
-function conv_norm(kernel_size, ch::Pair{<:Integer, <:Integer}, outplanes,
-                   activation = identity; kwargs...)
+function conv_norm(kernel_size, ch::Pair{<:Integer, <:Integer}, activation = identity; kwargs...)
     inplanes, outplanes = ch
     return conv_norm(kernel_size, inplanes, outplanes, activation; kwargs...)
 end
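After this change the `Pair` method no longer takes a separate `outplanes` argument, so the two call forms below should be equivalent (a sketch; the keywords accepted via `kwargs...` are those of the method above this hunk):

    layers_from_pair = conv_norm((3, 3), 16 => 32, relu)
    layers_from_ints = conv_norm((3, 3), 16, 32, relu)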

src/layers/mlp.jl (+4)

@@ -69,6 +69,10 @@ function create_classifier(inplanes, nclasses; pool_layer = AdaptiveMeanPool((1,
                 "Pooling can only be disabled if classifier is also removed or a convolution-based classifier is used"
     end
     flatten_in_pool = !use_conv && pool_layer !== identity
+    if use_conv
+        @assert pool_layer === identity
+                "`pool_layer` must be identity if `use_conv` is true"
+    end
     global_pool = flatten_in_pool ? Chain(pool_layer, MLUtils.flatten) : pool_layer
     # Fully-connected layer
     fc = use_conv ? Conv((1, 1), inplanes => nclasses) : Dense(inplanes => nclasses)
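The added guard means a convolution-based classifier can no longer be combined with a pooling layer. A sketch of the two outcomes, assuming `use_conv` and `pool_layer` are keywords of `create_classifier` as the (truncated) signature and function body suggest:

    clf = create_classifier(512, 1000; use_conv = true, pool_layer = identity)  # passes the new check
    # create_classifier(512, 1000; use_conv = true)  # the default pool_layer would now trip the assertion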

src/vit-based/vit.jl (+5 -5)

@@ -15,8 +15,8 @@ Transformer as used in the base ViT architecture.
 function transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout_rate = 0.0)
     layers = [Chain(SkipConnection(prenorm(planes,
                                            MHAttention(planes, nheads;
-                                                       attn_drop_rate = dropout_rate,
-                                                       proj_drop_rate = dropout_rate)), +),
+                                                       attn_dropout_rate = dropout_rate,
+                                                       proj_dropout_rate = dropout_rate)), +),
                     SkipConnection(prenorm(planes,
                                            mlp_block(planes, floor(Int, mlp_ratio * planes);
                                                      dropout_rate)), +))
@@ -27,7 +27,7 @@ end
 """
     vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} = (16, 16),
         embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, dropout_rate = 0.1,
-        emb_drop_rate = 0.1, pool = :class, nclasses = 1000)
+        emb_dropout_rate = 0.1, pool = :class, nclasses = 1000)
 
 Creates a Vision Transformer (ViT) model.
 ([reference](https://arxiv.org/abs/2010.11929)).
@@ -48,14 +48,14 @@ Creates a Vision Transformer (ViT) model.
 """
 function vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} = (16, 16),
              embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, dropout_rate = 0.1,
-             emb_drop_rate = 0.1, pool = :class, nclasses = 1000)
+             emb_dropout_rate = 0.1, pool = :class, nclasses = 1000)
     @assert pool in [:class, :mean]
             "Pool type must be either `:class` (class token) or `:mean` (mean pooling)"
     npatches = prod(imsize .÷ patch_size)
     return Chain(Chain(PatchEmbedding(imsize; inchannels, patch_size, embedplanes),
                        ClassTokens(embedplanes),
                        ViPosEmbedding(embedplanes, npatches + 1),
-                       Dropout(emb_drop_rate),
+                       Dropout(emb_dropout_rate),
                        transformer_encoder(embedplanes, depth, nheads; mlp_ratio,
                                            dropout_rate),
                        (pool == :class) ? x -> x[:, 1, :] : seconddimmean),
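A call sketch using the renamed keyword; every value below is simply the default already shown in the signature above:

    model = vit((256, 256); inchannels = 3, patch_size = (16, 16), embedplanes = 768,
                depth = 6, nheads = 16, mlp_ratio = 4.0, dropout_rate = 0.1,
                emb_dropout_rate = 0.1, pool = :class, nclasses = 1000)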
