Skip to content

Commit bc934aa

Browse files
committed
Fixup and tweaks
1 parent 3b00364 commit bc934aa

File tree

8 files changed

+49
-34
lines changed

8 files changed

+49
-34
lines changed

src/convnets/inception/googlenet.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ Create an Inception-v1 model (commonly referred to as GoogLeNet)
3636
3737
- `nclasses`: the number of output classes
3838
"""
39-
function googlenet(; inchannels::Integer = 3, nclasses::Integer = 1000)
39+
function googlenet(; dropout_rate = 0.4, inchannels::Integer = 3, nclasses::Integer = 1000)
4040
backbone = Chain(Conv((7, 7), inchannels => 64; stride = 2, pad = 3),
4141
MaxPool((3, 3); stride = 2, pad = 1),
4242
Conv((1, 1), 64 => 64),
@@ -53,7 +53,7 @@ function googlenet(; inchannels::Integer = 3, nclasses::Integer = 1000)
5353
MaxPool((3, 3); stride = 2, pad = 1),
5454
_inceptionblock(832, 256, 160, 320, 32, 128, 128),
5555
_inceptionblock(832, 384, 192, 384, 48, 128, 128))
56-
return Chain(backbone, create_classifier(1024, nclasses; dropout_rate = 0.4))
56+
return Chain(backbone, create_classifier(1024, nclasses; dropout_rate))
5757
end
5858

5959
"""

src/convnets/inception/inceptionresnetv2.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ function inceptionresnetv2(; inchannels::Integer = 3, dropout_rate = 0.0,
9696
end
9797

9898
"""
99-
InceptionResNetv2(; pretrain::Bool = false, inchannels::Integer = 3, dropout_rate = 0.0, nclasses::Integer = 1000)
99+
InceptionResNetv2(; pretrain::Bool = false, inchannels::Integer = 3,
100+
nclasses::Integer = 1000)
100101
101102
Creates an InceptionResNetv2 model.
102103
([reference](https://arxiv.org/abs/1602.07261))
@@ -118,9 +119,8 @@ end
118119
@functor InceptionResNetv2
119120

120121
function InceptionResNetv2(; pretrain::Bool = false, inchannels::Integer = 3,
121-
dropout_rate = 0.0,
122122
nclasses::Integer = 1000)
123-
layers = inceptionresnetv2(; inchannels, dropout_rate, nclasses)
123+
layers = inceptionresnetv2(; inchannels, nclasses)
124124
if pretrain
125125
loadpretrain!(layers, "InceptionResNetv2")
126126
end

src/convnets/inception/inceptionv3.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)).
135135
136136
- `nclasses`: the number of output classes
137137
"""
138-
function inceptionv3(; inchannels::Integer = 3, nclasses::Integer = 1000)
138+
function inceptionv3(; dropout_rate = 0.2, inchannels::Integer = 3, nclasses::Integer = 1000)
139139
backbone = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)...,
140140
conv_norm((3, 3), 32, 32)...,
141141
conv_norm((3, 3), 32, 64; pad = 1)...,
@@ -154,7 +154,7 @@ function inceptionv3(; inchannels::Integer = 3, nclasses::Integer = 1000)
154154
inceptionv3_d(768),
155155
inceptionv3_e(1280),
156156
inceptionv3_e(2048))
157-
return Chain(backbone, create_classifier(2048, nclasses; dropout_rate = 0.2))
157+
return Chain(backbone, create_classifier(2048, nclasses; dropout_rate))
158158
end
159159

160160
"""

src/convnets/inception/xception.jl

+1-2
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,7 @@ function xception(; dropout_rate = 0.0, inchannels::Integer = 3, nclasses::Integ
6666
xception_block(728, 1024, 2; stride = 2, grow_at_start = false),
6767
depthwise_sep_conv_norm((3, 3), 1024, 1536; pad = 1)...,
6868
depthwise_sep_conv_norm((3, 3), 1536, 2048; pad = 1)...)
69-
classifier = create_classifier(2048, nclasses; dropout_rate)
70-
return Chain(backbone, classifier)
69+
return Chain(backbone, create_classifier(2048, nclasses; dropout_rate))
7170
end
7271

7372
"""

src/convnets/resnets/core.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ function downsample_identity(inplanes::Integer, outplanes::Integer; kwargs...)
116116
end
117117

118118
# Shortcut configurations for the ResNet models
119-
const shortcut_dict = Dict(:A => (downsample_identity, downsample_identity),
119+
const SHORTCUT_DICT = Dict(:A => (downsample_identity, downsample_identity),
120120
:B => (downsample_conv, downsample_identity),
121121
:C => (downsample_conv, downsample_conv),
122122
:D => (downsample_pool, downsample_identity))
@@ -288,8 +288,8 @@ function resnet(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Inte
288288
stage_blocks = resnet_stages(get_layers, block_repeats, connection)
289289
backbone = Chain(stem, stage_blocks)
290290
# Build the classifier head
291-
nfeaturemaps = Flux.outputsize(backbone, img_dims; padbatch = true)[3]
292-
classifier = classifier_fn(nfeaturemaps)
291+
# nfeaturemaps = Flux.outputsize(backbone, img_dims; padbatch = true)[3]
292+
classifier = classifier_fn(2048)
293293
return Chain(backbone, classifier)
294294
end
295295

@@ -343,7 +343,7 @@ function resnet(block_type::Symbol, block_repeats::AbstractVector{<:Integer},
343343
connection$activation, classifier_fn)
344344
end
345345
function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...)
346-
return resnet(block_fn, block_repeats, shortcut_dict[downsample_opt]; kwargs...)
346+
return resnet(block_fn, block_repeats, SHORTCUT_DICT[downsample_opt]; kwargs...)
347347
end
348348

349349
# block-layer configurations for ResNet-like models

src/convnets/resnets/res2net.jl

+31-18
Original file line numberDiff line numberDiff line change
@@ -22,32 +22,32 @@ Creates a bottleneck block as described in the Res2Net paper.
2222
"""
2323
function bottle2neck(inplanes::Integer, planes::Integer; stride::Integer = 1,
2424
cardinality::Integer = 1, base_width::Integer = 26,
25-
scale::Integer = 4, activation = relu, norm_layer = BatchNorm,
26-
revnorm::Bool = false, attn_fn = planes -> identity)
25+
scale::Integer = 4, activation = relu, is_first::Bool = false,
26+
norm_layer = BatchNorm, revnorm::Bool = false,
27+
attn_fn = planes -> identity)
2728
width = fld(planes * base_width, 64) * cardinality
2829
outplanes = planes * 4
29-
is_first = stride > 1
3030
pool = is_first && scale > 1 ? MeanPool((3, 3); stride, pad = 1) : identity
3131
conv_bns = [Chain(conv_norm((3, 3), width => width, activation; norm_layer, stride,
3232
pad = 1, groups = cardinality, bias = false)...)
33-
for _ in 1:(max(1, scale - 1))]
33+
for _ in 1:max(1, scale - 1)]
3434
reslayer = is_first ? Parallel(cat_channels, pool, conv_bns...) :
35-
Parallel(cat_channels, identity, PairwiseFusion(+, conv_bns...))
35+
Parallel(cat_channels, identity, Chain(PairwiseFusion(+, conv_bns...)))
3636
tuplify(x) = is_first ? tuple(x...) : tuple(x[1], tuple(x[2:end]...))
37-
return Chain(conv_norm((1, 1), inplanes => width * scale, activation;
38-
norm_layer, revnorm, bias = false)...,
39-
chunk$(; size = width, dims = 3),
40-
tuplify, reslayer,
41-
conv_norm((1, 1), width * scale => outplanes, activation;
42-
norm_layer, revnorm, bias = false)...,
43-
attn_fn(outplanes))
37+
layers = [conv_norm((1, 1), inplanes => width * scale, activation;
38+
norm_layer, revnorm, bias = false)...,
39+
chunk$(; size = width, dims = 3), tuplify, reslayer,
40+
conv_norm((1, 1), width * scale => outplanes, activation;
41+
norm_layer, revnorm, bias = false)...,
42+
attn_fn(outplanes)]
43+
return Chain(filter(!=(identity), layers)...)
4444
end
4545

4646
function bottle2neck_builder(block_repeats::AbstractVector{<:Integer};
4747
inplanes::Integer = 64, cardinality::Integer = 1,
4848
base_width::Integer = 26, scale::Integer = 4,
4949
expansion::Integer = 4, norm_layer = BatchNorm,
50-
revnorm::Bool = false, activation = relu,
50+
revnorm::Bool = false, activation = relu,
5151
attn_fn = planes -> identity,
5252
stride_fn = resnet_stride, planes_fn = resnet_planes,
5353
downsample_tuple = (downsample_conv, downsample_identity))
@@ -63,8 +63,9 @@ function bottle2neck_builder(block_repeats::AbstractVector{<:Integer};
6363
stride = stride_fn(stage_idx, block_idx)
6464
downsample_fn = (stride != 1 || inplanes != planes * expansion) ?
6565
downsample_tuple[1] : downsample_tuple[2]
66+
is_first = (stride > 1 || downsample_fn != downsample_tuple[2]) ? true : false
6667
block = bottle2neck(inplanes, planes; stride, cardinality, base_width, scale,
67-
activation, norm_layer, revnorm, attn_fn)
68+
activation, is_first, norm_layer, revnorm, attn_fn)
6869
downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer,
6970
revnorm)
7071
return block, downsample
@@ -92,19 +93,25 @@ Creates a Res2Net model with the specified depth, scale, and base width.
9293
struct Res2Net
9394
layers::Any
9495
end
96+
@functor Res2Net
9597

9698
function Res2Net(depth::Integer; pretrain::Bool = false, scale::Integer = 4,
9799
base_width::Integer = 26, inchannels::Integer = 3,
98100
nclasses::Integer = 1000)
99101
_checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
100-
layers = resnet(:bottle2neck, RESNET_CONFIGS[depth][2], :C; base_width, scale,
102+
layers = resnet(:bottle2neck, RESNET_CONFIGS[depth][2]; base_width, scale,
101103
inchannels, nclasses)
102104
if pretrain
103105
loadpretrain!(layers, string("Res2Net", depth, "_", base_width, "x", scale))
104106
end
105-
return ResNet(layers)
107+
return Res2Net(layers)
106108
end
107109

110+
(m::Res2Net)(x) = m.layers(x)
111+
112+
backbone(m::Res2Net) = m.layers[1]
113+
classifier(m::Res2Net) = m.layers[2]
114+
108115
"""
109116
Res2NeXt(depth::Integer; pretrain::Bool = false, scale::Integer = 4,
110117
base_width::Integer = 4, cardinality::Integer = 8,
@@ -126,17 +133,23 @@ Creates a Res2NeXt model with the specified depth, scale, base width and cardina
126133
struct Res2NeXt
127134
layers::Any
128135
end
136+
@functor Res2NeXt
129137

130138
function Res2NeXt(depth::Integer; pretrain::Bool = false, scale::Integer = 4,
131139
base_width::Integer = 4, cardinality::Integer = 8,
132140
inchannels::Integer = 3, nclasses::Integer = 1000)
133141
_checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
134-
layers = resnet(:bottle2neck, RESNET_CONFIGS[depth][2], :C; base_width, scale,
142+
layers = resnet(:bottle2neck, RESNET_CONFIGS[depth][2]; base_width, scale,
135143
cardinality, inchannels, nclasses)
136144
if pretrain
137145
loadpretrain!(layers,
138146
string("Res2NeXt", depth, "_", base_width, "x", cardinality,
139147
"x", scale))
140148
end
141-
return ResNet(layers)
149+
return Res2NeXt(layers)
142150
end
151+
152+
(m::Res2NeXt)(x) = m.layers(x)
153+
154+
backbone(m::Res2NeXt) = m.layers[1]
155+
classifier(m::Res2NeXt) = m.layers[2]

src/utilities.jl

+4-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,10 @@ Concatenate `x` and `y` (and any `z`s) along the channel dimension (third dimens
3838
Equivalent to `cat(x, y, zs...; dims=3)`.
3939
Convenient reduction operator for use with `Parallel`.
4040
"""
41-
cat_channels(xy...) = cat(xy...; dims = Val(3))
41+
cat_channels(xs::AbstractArray...) = cat(xs...; dims = Val(3))
42+
cat_channels(x::AbstractArray, y::Tuple) = cat_channels(x, y...)
43+
cat_channels(x::Tuple, y::AbstractArray...) = cat_channels(x..., y...)
44+
cat_channels(x::Tuple) = cat_channels(x...)
4245

4346
"""
4447
swapdims(perm)

test/convnets.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ end
126126
@testset for (base_width, scale) in [(26, 4), (48, 2), (14, 8), (26, 6), (26, 8)]
127127
m = Res2Net(50; base_width, scale)
128128
@test size(m(x_224)) == (1000, 1)
129-
if (Res2Net, depth, cardinality, base_width) in PRETRAINED_MODELS
129+
if (Res2Net, depth, base_width, scale) in PRETRAINED_MODELS
130130
@test acctest(Res2Net(depth, pretrain = true))
131131
else
132132
@test_throws ArgumentError Res2Net(depth, pretrain = true)
@@ -137,7 +137,7 @@ end
137137
@testset for (base_width, scale) in [(26, 4)]
138138
m = Res2Net(101; base_width, scale)
139139
@test size(m(x_224)) == (1000, 1)
140-
if (Res2Net, depth, cardinality, base_width) in PRETRAINED_MODELS
140+
if (Res2Net, depth, base_width, scale) in PRETRAINED_MODELS
141141
@test acctest(Res2Net(depth, pretrain = true))
142142
else
143143
@test_throws ArgumentError Res2Net(depth, pretrain = true)

0 commit comments

Comments
 (0)