Skip to content

Commit 398912a

Browse files
libi01libi01
libi01
authored and
libi01
committed
🎉 initial commit
1 parent dc6a442 commit 398912a

16 files changed

+1830
-2
lines changed

README.md

+74-2
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,74 @@
1-
# DCQ
2-
Dynamic Class Queue for Large Scale Face Recognition in the Wild
1+
## Introduction
2+
3+
DCQ (Dynamic Class Queue) is a state-of-the-art face recognition method for training million-IDs datasets.
4+
5+
This repo is the official implementation for CVPR 2021 paper: Dynamic Class Queue for Large Scale Face Recognition In the Wild.
6+
[**[paper]**](https://openaccess.thecvf.com/content/CVPR2021/papers/Li_Dynamic_Class_Queue_for_Large_Scale_Face_Recognition_in_the_CVPR_2021_paper.pdf)
7+
8+
## News
9+
10+
**`2021-08-03`**: Initial code release.
11+
12+
## Quick Start
13+
14+
### Prerequisite
15+
16+
Install PaddlePaddle 2.1
17+
18+
https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html
19+
20+
Download the MS1MV2 dataset and common test benchmarks via BaiduYun
21+
22+
url: https://pan.baidu.com/s/1PYY3h-jEVURWwQYvLzE3Mg
23+
password: m2m8
24+
25+
```bash
26+
# untar file
27+
cat xaa xab xac xad xae | tar xf -
28+
```
29+
30+
### Training
31+
32+
```bash
33+
# Train iresnet50
34+
bash train_scripts/train_dcq_ir50_q8192_ms1mv2.sh
35+
36+
# Train iresnet100
37+
bash train_scripts/train_dcq_ir100_q8192_ms1mv2.sh
38+
```
39+
40+
### Evaluation
41+
42+
```bash
43+
data_root=./DCQ_train_test_data/common_test_benchmarks
44+
model=Logs/dcq_ires50_q8192_ms1mv2
45+
epoch=19
46+
python eval/eval_verification.py $model $epoch --save-prefix lfw --filelist '$data_root/lfw.filelist' --label-path '$data_root/lfw_label.npy'
47+
python eval/eval_verification.py $model $epoch --save-prefix cplfw --filelist '$data_root/cplfw.filelist' --label-path '$data_root/cplfw_label.npy'
48+
python eval/eval_verification.py $model $epoch --save-prefix agedb_30 --filelist '$data_root/agedb_30.filelist' --label-path '$data_root/agedb_30_label.npy'
49+
```
50+
51+
## Contributing
52+
53+
Main contributors:
54+
55+
- Bi Li
56+
- Jianwei Li
57+
- Nan Peng
58+
59+
## Credit
60+
61+
This code is largely based on [**moco**](https://github.com/facebookresearch/moco) and [**face.evoLVe**](https://github.com/ZhaoJ9014/face.evoLVe.PyTorch).
62+
63+
## Citation
64+
65+
```
66+
@InProceedings{Li_2021_CVPR,
67+
author = {Li, Bi and Xi, Teng and Zhang, Gang and Feng, Haocheng and Han, Junyu and Liu, Jingtuo and Ding, Errui and Liu, Wenyu},
68+
title = {Dynamic Class Queue for Large Scale Face Recognition in the Wild},
69+
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
70+
month = {June},
71+
year = {2021},
72+
pages = {3763-3772}
73+
}
74+
```

backbone/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .iresnet import *

backbone/iresnet.py

+229
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Improved ResNet backbone"""
16+
17+
from collections import namedtuple
18+
19+
import paddle
20+
import paddle.nn as nn
21+
from paddle.nn import functional as F
22+
from paddle.fluid.initializer import Constant
23+
from paddle.framework import get_default_dtype
24+
from paddle.nn import (
25+
Linear, Conv2D, BatchNorm1D, BatchNorm2D, ReLU,
26+
Sigmoid, Dropout, MaxPool2D, AdaptiveAvgPool2D, Sequential,
27+
Layer, Flatten)
28+
29+
__all__ = ['iresnet34', 'iresnet50', 'iresnet100', 'iresnet50_se', 'iresnet100_se']
30+
31+
32+
class PReLU(Layer):
33+
def __init__(self, num_parameters=1, init=0.25, weight_attr=None,
34+
name=None):
35+
super(PReLU, self).__init__()
36+
self._num_parameters = num_parameters
37+
self._init = init
38+
self._weight_attr = weight_attr
39+
self._name = name
40+
41+
self.weight = self.create_parameter(
42+
attr=self._weight_attr,
43+
shape=[self._num_parameters],
44+
dtype=get_default_dtype(),
45+
is_bias=False,
46+
default_initializer=Constant(self._init))
47+
48+
def forward(self, x):
49+
return F.prelu(x, self.weight)
50+
51+
def extra_repr(self):
52+
name_str = ', name={}'.format(self._name) if self._name else ''
53+
return 'num_parameters={}, init={}, dtype={}{}'.format(
54+
self._num_parameters, self._init, self._dtype, name_str)
55+
56+
57+
def l2_norm(input, axis=1):
58+
norm = paddle.norm(input, 2, axis, True)
59+
output = paddle.divide(input, norm)
60+
return output
61+
62+
63+
class SEModule(Layer):
64+
def __init__(self, channels, reduction):
65+
super(SEModule, self).__init__()
66+
self.avg_pool = AdaptiveAvgPool2D(1)
67+
self.fc1 = Conv2D(
68+
channels, channels // reduction, kernel_size=1, padding=0, bias_attr=False)
69+
self.relu = ReLU()
70+
self.fc2 = Conv2D(
71+
channels // reduction, channels, kernel_size=1, padding=0, bias_attr=False)
72+
self.sigmoid = Sigmoid()
73+
74+
def forward(self, x):
75+
module_input = x
76+
x = self.avg_pool(x)
77+
x = self.fc1(x)
78+
x = self.relu(x)
79+
x = self.fc2(x)
80+
x = self.sigmoid(x)
81+
return module_input * x
82+
83+
84+
class BottleneckIR(Layer):
85+
def __init__(self, in_channel, depth, stride):
86+
super(BottleneckIR, self).__init__()
87+
if in_channel == depth:
88+
self.shortcut_layer = MaxPool2D(1, stride)
89+
else:
90+
self.shortcut_layer = Sequential(
91+
Conv2D(in_channel, depth, (1, 1), stride, bias_attr=False),
92+
BatchNorm2D(depth))
93+
self.res_layer = Sequential(
94+
BatchNorm2D(in_channel),
95+
Conv2D(in_channel, depth, (3, 3), (1, 1), 1, bias_attr=False), PReLU(depth),
96+
Conv2D(depth, depth, (3, 3), stride, 1, bias_attr=False), BatchNorm2D(depth))
97+
98+
def forward(self, x):
99+
shortcut = self.shortcut_layer(x)
100+
res = self.res_layer(x)
101+
return res + shortcut
102+
103+
104+
class BottleneckIRSE(Layer):
105+
def __init__(self, in_channel, depth, stride):
106+
super(BottleneckIRSE, self).__init__()
107+
if in_channel == depth:
108+
self.shortcut_layer = MaxPool2D(1, stride)
109+
else:
110+
self.shortcut_layer = Sequential(
111+
Conv2D(in_channel, depth, (1, 1), stride, bias_attr=False),
112+
BatchNorm2D(depth))
113+
self.res_layer = Sequential(
114+
BatchNorm2D(in_channel),
115+
Conv2D(in_channel, depth, (3, 3), (1, 1), 1, bias_attr=False),
116+
PReLU(depth),
117+
Conv2D(depth, depth, (3, 3), stride, 1, bias_attr=False),
118+
BatchNorm2D(depth),
119+
SEModule(depth, 16)
120+
)
121+
122+
def forward(self, x):
123+
shortcut = self.shortcut_layer(x)
124+
res = self.res_layer(x)
125+
return res + shortcut
126+
127+
128+
class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
129+
'''A named tuple describing a ResNet block.'''
130+
131+
132+
def get_block(in_channel, depth, num_units, stride=2):
133+
return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
134+
135+
136+
def get_blocks(num_layers):
137+
if num_layers == 50:
138+
blocks = [
139+
get_block(in_channel=64, depth=64, num_units=3),
140+
get_block(in_channel=64, depth=128, num_units=4),
141+
get_block(in_channel=128, depth=256, num_units=14),
142+
get_block(in_channel=256, depth=512, num_units=3)
143+
]
144+
elif num_layers == 34:
145+
blocks = [
146+
get_block(in_channel=64, depth=64, num_units=3),
147+
get_block(in_channel=64, depth=128, num_units=4),
148+
get_block(in_channel=128, depth=256, num_units=6),
149+
get_block(in_channel=256, depth=512, num_units=3)
150+
]
151+
elif num_layers == 100:
152+
blocks = [
153+
get_block(in_channel=64, depth=64, num_units=3),
154+
get_block(in_channel=64, depth=128, num_units=13),
155+
get_block(in_channel=128, depth=256, num_units=30),
156+
get_block(in_channel=256, depth=512, num_units=3)
157+
]
158+
else:
159+
raise NotImplementedError
160+
161+
return blocks
162+
163+
164+
class Backbone(Layer):
165+
def __init__(self, input_size, num_layers, out_dim, mode='ir'):
166+
super(Backbone, self).__init__()
167+
assert input_size[0] in [112, 224], "input_size should be [112, 112] or [224, 224]"
168+
assert num_layers in [34, 50, 100, 152], "num_layers should be 50, 100 or 152"
169+
assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
170+
blocks = get_blocks(num_layers)
171+
if mode == 'ir':
172+
unit_module = BottleneckIR
173+
elif mode == 'ir_se':
174+
unit_module = BottleneckIRSE
175+
self.input_layer = Sequential(Conv2D(3, 64, (3, 3), 1, 1, bias_attr=False),
176+
BatchNorm2D(64),
177+
PReLU(64))
178+
if input_size[0] == 112:
179+
self.output_layer = Sequential(BatchNorm2D(512),
180+
Dropout(),
181+
Flatten(),
182+
Linear(512 * 7 * 7, out_dim),
183+
BatchNorm1D(out_dim))
184+
else:
185+
self.output_layer = Sequential(BatchNorm2D(512),
186+
Dropout(),
187+
Flatten(),
188+
Linear(512 * 14 * 14, out_dim),
189+
BatchNorm1D(out_dim))
190+
191+
modules = []
192+
for block in blocks:
193+
for bottleneck in block:
194+
modules.append(
195+
unit_module(bottleneck.in_channel,
196+
bottleneck.depth,
197+
bottleneck.stride))
198+
self.body = Sequential(*modules)
199+
200+
def forward(self, x):
201+
x = self.input_layer(x)
202+
x = self.body(x)
203+
x = self.output_layer(x)
204+
return x
205+
206+
207+
def iresnet34(num_classes, input_size=[112, 112], **kwargs):
208+
model = Backbone(input_size, 34, num_classes, 'ir')
209+
return model
210+
211+
212+
def iresnet50(num_classes, input_size=[112, 112], **kwargs):
213+
model = Backbone(input_size, 50, num_classes, 'ir')
214+
return model
215+
216+
217+
def iresnet100(num_classes, input_size=[112, 112], **kwargs):
218+
model = Backbone(input_size, 100, num_classes, 'ir')
219+
return model
220+
221+
222+
def iresnet50_se(num_classes, input_size=[112, 112], **kwargs):
223+
model = Backbone(input_size, 50, num_classes, 'ir_se')
224+
return model
225+
226+
227+
def iresnet100_se(num_classes, input_size=[112, 112], **kwargs):
228+
model = Backbone(input_size, 100, num_classes, 'ir_se')
229+
return model

data_proc/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)