Skip to content

Commit c58ba47

Browse files
committed
Cleanup
1 parent afd6f10 commit c58ba47

File tree

8 files changed

+164
-166
lines changed

8 files changed

+164
-166
lines changed

src/convnets/inception.jl

+26-23
Original file line numberDiff line numberDiff line change
@@ -279,18 +279,18 @@ function inceptionv4_c()
279279
end
280280

281281
"""
282-
inceptionv4(; inchannels = 3, dropout = 0.0, nclasses = 1000)
282+
inceptionv4(; inchannels = 3, drop_rate = 0.0, nclasses = 1000)
283283
284284
Create an Inceptionv4 model.
285285
([reference](https://arxiv.org/abs/1602.07261))
286286
287287
# Arguments
288288
289289
- `inchannels`: number of input channels.
290-
- `dropout`: rate of dropout in classifier head.
290+
- `drop_rate`: rate of dropout in classifier head.
291291
- `nclasses`: the number of output classes.
292292
"""
293-
function inceptionv4(; inchannels = 3, dropout = 0.0, nclasses = 1000)
293+
function inceptionv4(; inchannels = 3, drop_rate = 0.0, nclasses = 1000)
294294
body = Chain(conv_bn((3, 3), inchannels, 32; stride = 2)...,
295295
conv_bn((3, 3), 32, 32)...,
296296
conv_bn((3, 3), 32, 64; pad = 1)...,
@@ -313,12 +313,13 @@ function inceptionv4(; inchannels = 3, dropout = 0.0, nclasses = 1000)
313313
inceptionv4_c(),
314314
inceptionv4_c(),
315315
inceptionv4_c())
316-
head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout), Dense(1536, nclasses))
316+
head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(drop_rate),
317+
Dense(1536, nclasses))
317318
return Chain(body, head)
318319
end
319320

320321
"""
321-
Inceptionv4(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000)
322+
Inceptionv4(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000)
322323
323324
Creates an Inceptionv4 model.
324325
([reference](https://arxiv.org/abs/1602.07261))
@@ -327,7 +328,7 @@ Creates an Inceptionv4 model.
327328
328329
- `pretrain`: set to `true` to load the pre-trained weights for ImageNet
329330
- `inchannels`: number of input channels.
330-
- `dropout`: rate of dropout in classifier head.
331+
- `drop_rate`: rate of dropout in classifier head.
331332
- `nclasses`: the number of output classes.
332333
333334
!!! warning
@@ -338,7 +339,7 @@ struct Inceptionv4
338339
layers::Any
339340
end
340341

341-
function Inceptionv4(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000)
342+
function Inceptionv4(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000)
342343
# The keyword was renamed to `drop_rate`; forward the name that is actually in scope.
layers = inceptionv4(; inchannels, drop_rate, nclasses)
343344
pretrain && loadpretrain!(layers, "Inceptionv4")
344345
return Inceptionv4(layers)
@@ -419,18 +420,18 @@ function block8(scale = 1.0f0; activation = identity)
419420
end
420421

421422
"""
422-
inceptionresnetv2(; inchannels = 3, dropout = 0.0, nclasses = 1000)
423+
inceptionresnetv2(; inchannels = 3, drop_rate = 0.0, nclasses = 1000)
423424
424425
Creates an InceptionResNetv2 model.
425426
([reference](https://arxiv.org/abs/1602.07261))
426427
427428
# Arguments
428429
429430
- `inchannels`: number of input channels.
430-
- `dropout`: rate of dropout in classifier head.
431+
- `drop_rate`: rate of dropout in classifier head.
431432
- `nclasses`: the number of output classes.
432433
"""
433-
function inceptionresnetv2(; inchannels = 3, dropout = 0.0, nclasses = 1000)
434+
function inceptionresnetv2(; inchannels = 3, drop_rate = 0.0, nclasses = 1000)
434435
body = Chain(conv_bn((3, 3), inchannels, 32; stride = 2)...,
435436
conv_bn((3, 3), 32, 32)...,
436437
conv_bn((3, 3), 32, 64; pad = 1)...,
@@ -446,12 +447,13 @@ function inceptionresnetv2(; inchannels = 3, dropout = 0.0, nclasses = 1000)
446447
[block8(0.20f0) for _ in 1:9]...,
447448
block8(; activation = relu),
448449
conv_bn((1, 1), 2080, 1536)...)
449-
head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout), Dense(1536, nclasses))
450+
head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(drop_rate),
451+
Dense(1536, nclasses))
450452
return Chain(body, head)
451453
end
452454

453455
"""
454-
InceptionResNetv2(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000)
456+
InceptionResNetv2(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000)
455457
456458
Creates an InceptionResNetv2 model.
457459
([reference](https://arxiv.org/abs/1602.07261))
@@ -460,7 +462,7 @@ Creates an InceptionResNetv2 model.
460462
461463
- `pretrain`: set to `true` to load the pre-trained weights for ImageNet
462464
- `inchannels`: number of input channels.
463-
- `dropout`: rate of dropout in classifier head.
465+
- `drop_rate`: rate of dropout in classifier head.
464466
- `nclasses`: the number of output classes.
465467
466468
!!! warning
@@ -471,9 +473,9 @@ struct InceptionResNetv2
471473
layers::Any
472474
end
473475

474-
function InceptionResNetv2(; pretrain = false, inchannels = 3, dropout = 0.0,
476+
function InceptionResNetv2(; pretrain = false, inchannels = 3, drop_rate = 0.0,
475477
nclasses = 1000)
476-
layers = inceptionresnetv2(; inchannels, dropout, nclasses)
478+
layers = inceptionresnetv2(; inchannels, drop_rate, nclasses)
477479
pretrain && loadpretrain!(layers, "InceptionResNetv2")
478480
return InceptionResNetv2(layers)
479481
end
@@ -533,18 +535,18 @@ function xception_block(inchannels, outchannels, nrepeats; stride = 1,
533535
end
534536

535537
"""
536-
xception(; inchannels = 3, dropout = 0.0, nclasses = 1000)
538+
xception(; inchannels = 3, drop_rate = 0.0, nclasses = 1000)
537539
538540
Creates an Xception model.
539541
([reference](https://arxiv.org/abs/1610.02357))
540542
541543
# Arguments
542544
543545
- `inchannels`: number of input channels.
544-
- `dropout`: rate of dropout in classifier head.
546+
- `drop_rate`: rate of dropout in classifier head.
545547
- `nclasses`: the number of output classes.
546548
"""
547-
function xception(; inchannels = 3, dropout = 0.0, nclasses = 1000)
549+
function xception(; inchannels = 3, drop_rate = 0.0, nclasses = 1000)
548550
body = Chain(conv_bn((3, 3), inchannels, 32; stride = 2, bias = false)...,
549551
conv_bn((3, 3), 32, 64; bias = false)...,
550552
xception_block(64, 128, 2; stride = 2, start_with_relu = false),
@@ -554,7 +556,8 @@ function xception(; inchannels = 3, dropout = 0.0, nclasses = 1000)
554556
xception_block(728, 1024, 2; stride = 2, grow_at_start = false),
555557
depthwise_sep_conv_bn((3, 3), 1024, 1536; pad = 1)...,
556558
depthwise_sep_conv_bn((3, 3), 1536, 2048; pad = 1)...)
557-
head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout), Dense(2048, nclasses))
559+
head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(drop_rate),
560+
Dense(2048, nclasses))
558561
return Chain(body, head)
559562
end
560563

@@ -563,7 +566,7 @@ struct Xception
563566
end
564567

565568
"""
566-
Xception(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000)
569+
Xception(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000)
567570
568571
Creates an Xception model.
569572
([reference](https://arxiv.org/abs/1610.02357))
@@ -572,15 +575,15 @@ Creates an Xception model.
572575
573576
- `pretrain`: set to `true` to load the pre-trained weights for ImageNet.
574577
- `inchannels`: number of input channels.
575-
- `dropout`: rate of dropout in classifier head.
578+
- `drop_rate`: rate of dropout in classifier head.
576579
- `nclasses`: the number of output classes.
577580
578581
!!! warning
579582
580583
`Xception` does not currently support pretrained weights.
581584
"""
582-
function Xception(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000)
583-
layers = xception(; inchannels, dropout, nclasses)
585+
function Xception(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000)
586+
layers = xception(; inchannels, drop_rate, nclasses)
584587
pretrain && loadpretrain!(layers, "xception")
585588
return Xception(layers)
586589
end

src/convnets/resnet.jl

+71-75
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,42 @@
1+
function drop_blocks(drop_prob = 0.0)
2+
return [
3+
identity,
4+
identity,
5+
DropBlock(drop_prob, 5, 0.25),
6+
DropBlock(drop_prob, 3, 1.00),
7+
]
8+
end
9+
10+
function downsample_conv(kernel_size, in_channels, out_channels; stride = 1, dilation = 1,
11+
first_dilation = nothing, norm_layer = BatchNorm)
12+
kernel_size = stride == 1 && dilation == 1 ? 1 : kernel_size
13+
first_dilation = kernel_size[1] > 1 ?
14+
(!isnothing(first_dilation) ? first_dilation : dilation) : 1
15+
pad = ((stride - 1) + dilation * (kernel_size[1] - 1)) ÷ 2
16+
return Chain(Conv(kernel_size, in_channels => out_channels; stride, pad,
17+
dilation = first_dilation, bias = false),
18+
norm_layer(out_channels))
19+
end
20+
21+
function downsample_avg(kernel_size, in_channels, out_channels; stride = 1, dilation = 1,
22+
first_dilation = nothing, norm_layer = BatchNorm)
23+
avg_stride = dilation == 1 ? stride : 1
24+
if stride == 1 && dilation == 1
25+
pool = identity
26+
else
27+
pad = avg_stride == 1 && dilation > 1 ? SamePad() : 0
28+
pool = avg_pool_fn((2, 2); stride = avg_stride, pad)
29+
end
30+
return Chain(pool,
31+
Conv((1, 1), in_channels => out_channels; bias = false),
32+
norm_layer(out_channels))
33+
end
34+
135
function basicblock(inplanes, planes; stride = 1, downsample = identity, cardinality = 1,
2-
base_width = 64,
3-
reduce_first = 1, dilation = 1, first_dilation = nothing,
4-
act_layer = relu, norm_layer = BatchNorm,
36+
base_width = 64, reduce_first = 1, dilation = 1,
37+
first_dilation = nothing, activation = relu, norm_layer = BatchNorm,
538
drop_block = identity, drop_path = identity)
6-
expansion = 1
39+
expansion = expansion_factor(basicblock)
740
@assert cardinality==1 "BasicBlock only supports cardinality of 1"
841
@assert base_width==64 "BasicBlock does not support changing base width"
942
first_planes = planes ÷ reduce_first
@@ -17,16 +50,16 @@ function basicblock(inplanes, planes; stride = 1, downsample = identity, cardina
1750
dilation = dilation, bias = false),
1851
norm_layer(outplanes))
1952
return Chain(Parallel(+, downsample,
20-
Chain(conv_bn1, drop_block, act_layer, conv_bn2, drop_path)),
21-
act_layer)
53+
Chain(conv_bn1, drop_block, activation, conv_bn2, drop_path)),
54+
activation)
2255
end
56+
expansion_factor(::typeof(basicblock)) = 1
2357

2458
function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1,
25-
base_width = 64,
26-
reduce_first = 1, dilation = 1, first_dilation = nothing,
27-
act_layer = relu, norm_layer = BatchNorm,
59+
base_width = 64, reduce_first = 1, dilation = 1,
60+
first_dilation = nothing, activation = relu, norm_layer = BatchNorm,
2861
drop_block = identity, drop_path = identity)
29-
expansion = 4
62+
expansion = expansion_factor(bottleneck)
3063
width = floor(Int, planes * (base_width / 64)) * cardinality
3164
first_planes = width ÷ reduce_first
3265
outplanes = planes * expansion
@@ -39,62 +72,33 @@ function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardina
3972
drop_block = drop_block === identity ? identity : drop_block()
4073
conv_bn3 = Chain(Conv((1, 1), width => outplanes; bias = false), norm_layer(outplanes))
4174
return Chain(Parallel(+, downsample,
42-
Chain(conv_bn1, drop_block, act_layer, conv_bn2, drop_block,
43-
act_layer, conv_bn3, drop_path)),
44-
act_layer)
45-
end
46-
47-
function drop_blocks(drop_prob = 0.0)
48-
return [identity, identity,
49-
drop_prob == 0.0 ? DropBlock(drop_prob, 5, 0.25) : identity,
50-
drop_prob == 0.0 ? DropBlock(drop_prob, 3, 1.00) : identity]
75+
Chain(conv_bn1, drop_block, activation, conv_bn2, drop_block,
76+
activation, conv_bn3, drop_path)),
77+
activation)
5178
end
79+
expansion_factor(::typeof(bottleneck)) = 4
5280

53-
function downsample_conv(kernel_size, in_channels, out_channels; stride = 1, dilation = 1,
54-
first_dilation = nothing, norm_layer = BatchNorm)
55-
kernel_size = stride == 1 && dilation == 1 ? 1 : kernel_size
56-
first_dilation = kernel_size[1] > 1 ?
57-
(!isnothing(first_dilation) ? first_dilation : dilation) : 1
58-
pad = ((stride - 1) + dilation * (kernel_size[1] - 1)) ÷ 2
59-
return Chain(Conv(kernel_size, in_channels => out_channels; stride, pad,
60-
dilation = first_dilation, bias = false),
61-
norm_layer(out_channels))
62-
end
63-
64-
function downsample_avg(kernel_size, in_channels, out_channels; stride = 1, dilation = 1,
65-
first_dilation = nothing, norm_layer = BatchNorm)
66-
avg_stride = dilation == 1 ? stride : 1
67-
if stride == 1 && dilation == 1
68-
pool = identity
69-
else
70-
pad = avg_stride == 1 && dilation > 1 ? SamePad() : 0
71-
pool = avg_pool_fn((2, 2); stride = avg_stride, pad)
72-
end
73-
74-
return Chain(pool,
75-
Conv((1, 1), in_channels => out_channels; stride = 1, pad = 0,
76-
bias = false),
77-
norm_layer(out_channels))
78-
end
79-
80-
function make_blocks(block_fn, channels, block_repeats, inplanes; expansion = 1,
81-
reduce_first = 1, output_stride = 32,
82-
down_kernel_size = 1, avg_down = false, drop_block_rate = 0.0,
83-
drop_path_rate = 0.0, kwargs...)
81+
function make_blocks(block_fn, channels, block_repeats, inplanes;
82+
reduce_first = 1, output_stride = 32, down_kernel_size = 1,
83+
avg_down = false, drop_block_rate = 0.0, drop_path_rate = 0.0,
84+
kwargs...)
85+
expansion = expansion_factor(block_fn)
8486
kwarg_dict = Dict(kwargs...)
8587
stages = []
8688
net_block_idx = 1
8789
net_stride = 4
8890
dilation = prev_dilation = 1
89-
for (stage_idx, (planes, num_blocks, db)) in enumerate(zip(channels, block_repeats,
90-
drop_blocks(drop_block_rate)))
91+
for (stage_idx, (planes, num_blocks, drop_block)) in enumerate(zip(channels,
92+
block_repeats,
93+
drop_blocks(drop_block_rate)))
9194
stride = stage_idx == 1 ? 1 : 2
9295
if net_stride >= output_stride
9396
dilation *= stride
9497
stride = 1
9598
else
9699
net_stride *= stride
97100
end
101+
# first block needs to be handled differently for downsampling
98102
downsample = identity
99103
if stride != 1 || inplanes != planes * expansion
100104
downsample = avg_down ?
@@ -106,7 +110,7 @@ function make_blocks(block_fn, channels, block_repeats, inplanes; expansion = 1,
106110
norm_layer = kwarg_dict[:norm_layer])
107111
end
108112
block_kwargs = Dict(:reduce_first => reduce_first, :dilation => dilation,
109-
:drop_block => db, kwargs...)
113+
:drop_block => drop_block, kwargs...)
110114
blocks = []
111115
for block_idx in 1:num_blocks
112116
downsample = block_idx == 1 ? downsample : identity
@@ -127,15 +131,13 @@ function make_blocks(block_fn, channels, block_repeats, inplanes; expansion = 1,
127131
end
128132

129133
function resnet(block, layers; num_classes = 1000, inchannels = 3, output_stride = 32,
130-
expansion = 1,
131134
cardinality = 1, base_width = 64, stem_width = 64, stem_type = :default,
132-
replace_stem_pool = false, reduce_first = 1,
133-
down_kernel_size = (1, 1), avg_down = false, act_layer = relu,
134-
norm_layer = BatchNorm,
135+
replace_stem_pool = false, reduce_first = 1, down_kernel_size = (1, 1),
136+
avg_down = false, activation = relu, norm_layer = BatchNorm,
135137
drop_rate = 0.0, drop_path_rate = 0.0, drop_block_rate = 0.0,
136138
block_kwargs...)
137-
@assert output_stride in (8, 16, 32)
138-
@assert stem_type in [:default, :deep, :deep_tiered]
139+
@assert output_stride in (8, 16, 32) "Invalid `output_stride`. Must be one of (8, 16, 32)"
140+
@assert stem_type in [:default, :deep, :deep_tiered] "Stem type must be one of [:default, :deep, :deep_tiered]"
139141
# Stem
140142
inplanes = stem_type == :deep ? stem_width * 2 : 64
141143
if stem_type == :deep
@@ -145,38 +147,32 @@ function resnet(block, layers; num_classes = 1000, inchannels = 3, output_stride
145147
end
146148
# Julia indexing is 1-based: the first stem width is `stem_channels[1]`, not `[0]`.
conv1 = Chain(Conv((3, 3), inchannels => stem_channels[1]; stride = 2, pad = 1,
147149
bias = false),
148-
norm_layer(stem_channels[1]),
149-
act_layer(),
150-
Conv((3, 3), stem_channels[1] => stem_channels[1]; stride = 1,
151-
pad = 1, bias = false),
152-
norm_layer(stem_channels[2]),
153-
act_layer(),
154-
Conv((3, 3), stem_channels[2] => inplanes; stride = 1, pad = 1,
155-
bias = false))
150+
norm_layer(stem_channels[1], activation),
151+
Conv((3, 3), stem_channels[1] => stem_channels[1]; pad = 1,
152+
bias = false),
153+
norm_layer(stem_channels[2], activation),
154+
Conv((3, 3), stem_channels[2] => inplanes; pad = 1, bias = false))
156155
else
157156
conv1 = Conv((7, 7), inchannels => inplanes; stride = 2, pad = 3, bias = false)
158157
end
159-
bn1 = norm_layer(inplanes)
160-
act1 = act_layer
158+
bn1 = norm_layer(inplanes, activation)
161159
# Stem pooling
162160
if replace_stem_pool
163161
stempool = Chain(Conv((3, 3), inplanes => inplanes; stride = 2, pad = 1,
164162
bias = false),
165-
norm_layer(inplanes),
166-
act_layer)
163+
norm_layer(inplanes, activation))
167164
else
168165
stempool = MaxPool((3, 3); stride = 2, pad = 1)
169166
end
170-
stem = Chain(conv1, bn1, act1, stempool)
171-
167+
stem = Chain(conv1, bn1, stempool)
172168
# Feature Blocks
173169
channels = [64, 128, 256, 512]
174170
stage_blocks = make_blocks(block, channels, layers, inplanes; cardinality, base_width,
175171
output_stride, reduce_first, avg_down,
176-
down_kernel_size, act_layer, norm_layer,
172+
down_kernel_size, activation, norm_layer,
177173
drop_block_rate, drop_path_rate, block_kwargs...)
178-
179174
# Head (Pooling and Classifier)
175+
expansion = expansion_factor(block)
180176
num_features = 512 * expansion
181177
classifier = Chain(GlobalMeanPool(), Dropout(drop_rate), MLUtils.flatten,
182178
Dense(num_features, num_classes))

0 commit comments

Comments
 (0)