Skip to content

Commit adb2ca8

Browse files
theabhirathdarsnack
andcommitted
Hardcode large ResNet model dict for block configs
Also misc. cleanup Co-Authored-By: Kyle Daruwalla <[email protected]>
1 parent cf42dc7 commit adb2ca8

File tree

5 files changed

+20
-12
lines changed

5 files changed

+20
-12
lines changed

src/convnets/resnets/core.jl

+6-2
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ function downsample_identity(inplanes::Integer, outplanes::Integer; kwargs...)
116116
end
117117

118118
# Shortcut configurations for the ResNet models
119-
const SHORTCUT_DICT = Dict(:A => (downsample_identity, downsample_identity),
119+
const RESNET_SHORTCUTS = Dict(:A => (downsample_identity, downsample_identity),
120120
:B => (downsample_conv, downsample_identity),
121121
:C => (downsample_conv, downsample_conv),
122122
:D => (downsample_pool, downsample_identity))
@@ -342,7 +342,7 @@ function resnet(block_type::Symbol, block_repeats::AbstractVector{<:Integer},
342342
connection$activation, classifier_fn)
343343
end
344344
function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...)
345-
return resnet(block_fn, block_repeats, SHORTCUT_DICT[downsample_opt]; kwargs...)
345+
return resnet(block_fn, block_repeats, RESNET_SHORTCUTS[downsample_opt]; kwargs...)
346346
end
347347

348348
# block-layer configurations for ResNet-like models
@@ -351,3 +351,7 @@ const RESNET_CONFIGS = Dict(18 => (:basicblock, [2, 2, 2, 2]),
351351
50 => (:bottleneck, [3, 4, 6, 3]),
352352
101 => (:bottleneck, [3, 4, 23, 3]),
353353
152 => (:bottleneck, [3, 8, 36, 3]))
354+
355+
const LRESNET_CONFIGS = Dict(50 => (:bottleneck, [3, 4, 6, 3]),
356+
101 => (:bottleneck, [3, 4, 23, 3]),
357+
152 => (:bottleneck, [3, 8, 36, 3]))

src/convnets/resnets/res2net.jl

+8-4
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,11 @@ function bottle2neck(inplanes::Integer, planes::Integer; stride::Integer = 1,
3333
for _ in 1:max(1, scale - 1)]
3434
reslayer = is_first ? Parallel(cat_channels, pool, conv_bns...) :
3535
Parallel(cat_channels, identity, Chain(PairwiseFusion(+, conv_bns...)))
36-
tuplify(x) = is_first ? tuple(x...) : tuple(x[1], tuple(x[2:end]...))
36+
if is_first
37+
tuplify(x) = tuple(x...)
38+
else
39+
tuplify(x) = tuple(x[1], tuple(x[2:end]...))
40+
end
3741
layers = [conv_norm((1, 1), inplanes => width * scale, activation;
3842
norm_layer, revnorm, bias = false)...,
3943
chunk$(; size = width, dims = 3), tuplify, reslayer,
@@ -138,8 +142,8 @@ end
138142
function Res2NeXt(depth::Integer; pretrain::Bool = false, scale::Integer = 4,
139143
base_width::Integer = 4, cardinality::Integer = 8,
140144
inchannels::Integer = 3, nclasses::Integer = 1000)
141-
_checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
142-
layers = resnet(:bottle2neck, RESNET_CONFIGS[depth][2]; base_width, scale,
145+
_checkconfig(depth, keys(LRESNET_CONFIGS))
146+
layers = resnet(:bottle2neck, LRESNET_CONFIGS[depth][2]; base_width, scale,
143147
cardinality, inchannels, nclasses)
144148
if pretrain
145149
loadpretrain!(layers,
@@ -152,4 +156,4 @@ end
152156
(m::Res2NeXt)(x) = m.layers(x)
153157

154158
backbone(m::Res2NeXt) = m.layers[1]
155-
classifier(m::Res2NeXt) = m.layers[2]
159+
classifier(m::Res2NeXt) = m.layers[2]

src/convnets/resnets/resnet.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ end
5757

5858
function WideResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer = 3,
5959
nclasses::Integer = 1000)
60-
_checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
61-
layers = resnet(RESNET_CONFIGS[depth]...; base_width = 128, inchannels, nclasses)
60+
_checkconfig(depth, keys(LRESNET_CONFIGS))
61+
layers = resnet(LRESNET_CONFIGS[depth]...; base_width = 128, inchannels, nclasses)
6262
if pretrain
6363
loadpretrain!(layers, string("WideResNet", depth))
6464
end

src/convnets/resnets/resnext.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ end
2929

3030
function ResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer = 32,
3131
base_width::Integer = 4, inchannels::Integer = 3, nclasses::Integer = 1000)
32-
_checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
33-
layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width)
32+
_checkconfig(depth, keys(LRESNET_CONFIGS))
33+
layers = resnet(LRESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width)
3434
if pretrain
3535
loadpretrain!(layers, string("resnext", depth, "_", cardinality, "x", base_width, "d"))
3636
end

src/convnets/resnets/seresnet.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ end
6969

7070
function SEResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer = 32,
7171
base_width::Integer = 4, inchannels::Integer = 3, nclasses::Integer = 1000)
72-
_checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
73-
layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width,
72+
_checkconfig(depth, keys(LRESNET_CONFIGS))
73+
layers = resnet(LRESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width,
7474
attn_fn = squeeze_excite)
7575
if pretrain
7676
loadpretrain!(layers, string("seresnext", depth, "_", cardinality, "x", base_width))

0 commit comments

Comments
 (0)