Skip to content

Commit 76d5b7e

Browse files
committed
Use the functions directly instead of Symbols in resnet
Also fix `tuplify`
1 parent adb2ca8 commit 76d5b7e

File tree

3 files changed

+19
-19
lines changed

3 files changed

+19
-19
lines changed

src/convnets/resnets/core.jl

+12-12
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ function resnet(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Inte
292292
return Chain(backbone, classifier_fn(nfeaturemaps))
293293
end
294294

295-
function resnet(block_type::Symbol, block_repeats::AbstractVector{<:Integer},
295+
function resnet(block_type, block_repeats::AbstractVector{<:Integer},
296296
downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity);
297297
cardinality::Integer = 1, base_width::Integer = 64, inplanes::Integer = 64,
298298
reduction_factor::Integer = 1, imsize::Dims{2} = (256, 256),
@@ -304,7 +304,7 @@ function resnet(block_type::Symbol, block_repeats::AbstractVector{<:Integer},
304304
# Build stem
305305
stem = stem_fn(; inchannels)
306306
# Block builder
307-
if block_type === :basicblock
307+
if block_type == basicblock
308308
@assert cardinality==1 "Cardinality must be 1 for `basicblock`"
309309
@assert base_width==64 "Base width must be 64 for `basicblock`"
310310
get_layers = basicblock_builder(block_repeats; inplanes, reduction_factor,
@@ -314,15 +314,15 @@ function resnet(block_type::Symbol, block_repeats::AbstractVector{<:Integer},
314314
planes_fn = resnet_planes,
315315
downsample_tuple = downsample_opt,
316316
kwargs...)
317-
elseif block_type === :bottleneck
317+
elseif block_type == bottleneck
318318
get_layers = bottleneck_builder(block_repeats; inplanes, cardinality, base_width,
319319
reduction_factor, activation, norm_layer, revnorm,
320320
attn_fn, drop_block_rate, drop_path_rate,
321321
stride_fn = resnet_stride,
322322
planes_fn = resnet_planes,
323323
downsample_tuple = downsample_opt,
324324
kwargs...)
325-
elseif block_type === :bottle2neck
325+
elseif block_type == bottle2neck
326326
@assert drop_block_rate==0.0 "DropBlock not supported for `bottle2neck`"
327327
@assert drop_path_rate==0.0 "DropPath not supported for `bottle2neck`"
328328
@assert reduction_factor==1 "Reduction factor not supported for `bottle2neck`"
@@ -346,12 +346,12 @@ function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...)
346346
end
347347

348348
# block-layer configurations for ResNet-like models
349-
const RESNET_CONFIGS = Dict(18 => (:basicblock, [2, 2, 2, 2]),
350-
34 => (:basicblock, [3, 4, 6, 3]),
351-
50 => (:bottleneck, [3, 4, 6, 3]),
352-
101 => (:bottleneck, [3, 4, 23, 3]),
353-
152 => (:bottleneck, [3, 8, 36, 3]))
349+
const RESNET_CONFIGS = Dict(18 => (basicblock, [2, 2, 2, 2]),
350+
34 => (basicblock, [3, 4, 6, 3]),
351+
50 => (bottleneck, [3, 4, 6, 3]),
352+
101 => (bottleneck, [3, 4, 23, 3]),
353+
152 => (bottleneck, [3, 8, 36, 3]))
354354

355-
const LRESNET_CONFIGS = Dict(50 => (:bottleneck, [3, 4, 6, 3]),
356-
101 => (:bottleneck, [3, 4, 23, 3]),
357-
152 => (:bottleneck, [3, 8, 36, 3]))
355+
const LRESNET_CONFIGS = Dict(50 => (bottleneck, [3, 4, 6, 3]),
356+
101 => (bottleneck, [3, 4, 23, 3]),
357+
152 => (bottleneck, [3, 8, 36, 3]))

src/convnets/resnets/res2net.jl

+6-6
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@ function bottle2neck(inplanes::Integer, planes::Integer; stride::Integer = 1,
3333
for _ in 1:max(1, scale - 1)]
3434
reslayer = is_first ? Parallel(cat_channels, pool, conv_bns...) :
3535
Parallel(cat_channels, identity, Chain(PairwiseFusion(+, conv_bns...)))
36-
if is_first
37-
tuplify(x) = tuple(x...)
36+
tuplify = if is_first
37+
x -> tuple(x...)
3838
else
39-
tuplify(x) = tuple(x[1], tuple(x[2:end]...))
39+
x -> tuple(x[1], tuple(x[2:end]...))
4040
end
4141
layers = [conv_norm((1, 1), inplanes => width * scale, activation;
4242
norm_layer, revnorm, bias = false)...,
@@ -102,8 +102,8 @@ end
102102
function Res2Net(depth::Integer; pretrain::Bool = false, scale::Integer = 4,
103103
base_width::Integer = 26, inchannels::Integer = 3,
104104
nclasses::Integer = 1000)
105-
_checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
106-
layers = resnet(:bottle2neck, RESNET_CONFIGS[depth][2]; base_width, scale,
105+
_checkconfig(depth, keys(LRESNET_CONFIGS))
106+
layers = resnet(bottle2neck, LRESNET_CONFIGS[depth][2]; base_width, scale,
107107
inchannels, nclasses)
108108
if pretrain
109109
loadpretrain!(layers, string("Res2Net", depth, "_", base_width, "x", scale))
@@ -143,7 +143,7 @@ function Res2NeXt(depth::Integer; pretrain::Bool = false, scale::Integer = 4,
143143
base_width::Integer = 4, cardinality::Integer = 8,
144144
inchannels::Integer = 3, nclasses::Integer = 1000)
145145
_checkconfig(depth, keys(LRESNET_CONFIGS))
146-
layers = resnet(:bottle2neck, LRESNET_CONFIGS[depth][2]; base_width, scale,
146+
layers = resnet(bottle2neck, LRESNET_CONFIGS[depth][2]; base_width, scale,
147147
cardinality, inchannels, nclasses)
148148
if pretrain
149149
loadpretrain!(layers,

test/convnets.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ end
3333
end
3434

3535
@testset "resnet" begin
36-
@testset for block_fn in [:basicblock, :bottleneck]
36+
@testset for block_fn in [Metalhead.basicblock, Metalhead.bottleneck]
3737
layer_list = [
3838
[2, 2, 2, 2],
3939
[3, 4, 6, 3],

0 commit comments

Comments
 (0)