import argparse
import json

import torch
from PIL import ImageFile
from tqdm import tqdm
from adan_pytorch import Adan

from trainer import Trainer
from arcface import ArcFaceModel
from utils.dataset import FaceDataset, FaceDataloader
from utils.losses import get_loss
from utils.optimizers import SAM, Lamb

# Allow PIL to load images that are truncated on disk instead of raising
ImageFile.LOAD_TRUNCATED_IMAGES = True
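
# The JSON config is expected to expose 'train' and 'test' sections whose keys
# match the lookups in train() and test() below. A minimal illustrative sketch;
# all values and paths are placeholders, not the repository's actual defaults:
#
# {
#   "train": {
#     "root_dir": "data/train", "batch_size_train": 64, "batch_size_val": 64,
#     "num_worker": 4, "loss": "arcface", "backbone": "iresnet50",
#     "use_pretrained": false, "pretrained_backbone_path": "",
#     "freeze_model": false, "n_epochs": 30, "optimizer": "adam",
#     "learning_rate": 0.001, "use_lr_scheduler": false, "scheduler": {},
#     "verbose": true, "save_model": true, "prefix": "arcface"
#   },
#   "test": {
#     "trainset_path": "data/train", "testset_path": "data/test",
#     "batch_size": 64, "num_worker": 4, "backbone": "iresnet50",
#     "pretrained_model_path": "weights/arcface.pth"
#   }
# }
#
# When "optimizer" is "sam" or "adan", a matching "sam_optim"
# ({"momentum", "rho", "adaptive"}) or "adan_optim"
# ({"beta1", "beta2", "beta3", "weight_decay"}) block is also required.
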
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='configs/arcface.json', help='path to the config file')
    parser.add_argument('--phase', type=str, default='train', help='train or test')
    parser.add_argument('--device', type=str, default='0', help='CUDA device index, e.g. 0')
    return parser.parse_args()
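
# Example invocations, using the flags defined above:
#   python main.py --phase train --config configs/arcface.json --device 0
#   python main.py --phase test --config configs/arcface.json --device 0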

def train(args):
    with open(args.config, "r") as jsonfile:
        config = json.load(jsonfile)['train']

    # Build train/val dataloaders from one root directory (fixed 80/20 split)
    dataloader = FaceDataloader(root_dir=config['root_dir'],
                                val_size=0.2,
                                random_seed=0,
                                batch_size_train=config['batch_size_train'],
                                batch_size_val=config['batch_size_val'],
                                save_label_dict=True)
    train_loader, val_loader = dataloader.get_dataloaders(num_worker=config['num_worker'])
    num_classes = dataloader.num_classes

    device = torch.device("cuda:" + args.device if torch.cuda.is_available() else "cpu")
    print("Loss: ", config["loss"])
    print("Device: ", device)
    print("Number of classes: {num_classes}".format(num_classes=num_classes))

    # Get the path of the pretrained backbone, if one is requested
    if config['use_pretrained']:
        pretrained_backbone_path = config['pretrained_backbone_path']
    else:
        pretrained_backbone_path = None

    # Initialize the loss function and the model
    loss_function = get_loss(config['loss'])
    model = ArcFaceModel(backbone_name=config['backbone'],
                         input_size=[112, 112],
                         num_classes=num_classes,
                         use_pretrained=config['use_pretrained'],
                         pretrained_backbone_path=pretrained_backbone_path,
                         freeze=config['freeze_model'],
                         type_of_freeze='body_only')
    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('Number of trainable parameters: ', pytorch_total_params)
    n_epochs = config['n_epochs']

    # Get the optimizer; Adam is the fallback default
    optim_name = config["optimizer"].lower()
    use_sam_optim = False
    if optim_name == "sam":
        # SAM takes a two-step (perturb, then descend) update per batch, so the
        # Trainer is told via use_sam_optim to use its SAM-specific step logic
        optimizer = SAM(model.parameters(),
                        lr=config['learning_rate'],
                        momentum=config['sam_optim']['momentum'],
                        rho=config['sam_optim']['rho'],
                        adaptive=config['sam_optim']['adaptive'])
        use_sam_optim = True
    elif optim_name == "lamb":
        optimizer = Lamb(model.parameters(), lr=config['learning_rate'], weight_decay=1e-5)
    elif optim_name == "adan":
        adan_config = config["adan_optim"]
        betas = (adan_config["beta1"], adan_config["beta2"], adan_config["beta3"])
        optimizer = Adan(model.parameters(),
                         lr=config['learning_rate'],
                         betas=betas,
                         weight_decay=adan_config["weight_decay"])
    else:  # default optimizer is Adam
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=config['learning_rate'],
                                     weight_decay=1e-5)

    # Optional learning rate scheduler config (forwarded to trainer.train below)
    if config['use_lr_scheduler']:
        scheduler_config = config['scheduler']
    else:
        scheduler_config = None

    # Initialize the Trainer and train the model
    trainer = Trainer(model=model,
                      n_epochs=n_epochs,
                      optimizer=optimizer,
                      loss_function=loss_function,
                      device=device,
                      train_loader=train_loader,
                      val_loader=val_loader)
    trained_model = trainer.train(use_sam_optim=use_sam_optim,
                                  verbose=config['verbose'],
                                  scheduler_config=scheduler_config)

    # Save the best model
    if config['save_model']:
        trainer.save_trained_model(trained_model=trained_model,
                                   prefix=config['prefix'],
                                   backbone_name=config['backbone'],
                                   num_classes=num_classes,
                                   split_modules=True)
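
# Note: save_trained_model is called with split_modules=True above, which
# presumably writes the backbone and the classification head as separate
# checkpoints; test() below, by contrast, loads a single state_dict that must
# still include the fully connected layer (see its docstring).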

def test(args):
    '''
    Evaluate classification accuracy on a labelled test set; the trained
    model must still contain its fully connected (classification) layer.
    '''
    device = torch.device("cuda:" + args.device if torch.cuda.is_available() else "cpu")
    with open(args.config, "r") as jsonfile:
        config = json.load(jsonfile)['test']

    # The training set is loaded only to recover the number of classes the
    # model was trained with; evaluation itself runs on the test set
    train_set = FaceDataset(root_dir=config['trainset_path'])
    test_set = FaceDataset(root_dir=config['testset_path'])
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=config['batch_size'],
                                              shuffle=False,
                                              num_workers=config['num_worker'],
                                              drop_last=False)

    model = ArcFaceModel(backbone_name=config['backbone'],
                         input_size=[112, 112],
                         num_classes=train_set.num_classes)
    # map_location keeps the load working on CPU-only machines as well
    model.load_state_dict(torch.load(config['pretrained_model_path'], map_location=device))
    model.to(device)
    model.eval()

    print("Model: ", config['pretrained_model_path'])
    print("Device: ", device)
    print("Test dataset: ", config["testset_path"])
    print("Number of classes: {num_classes}".format(num_classes=test_set.num_classes))

    acc = []
    for images, labels in tqdm(test_loader):
        images = images.to(device)
        labels = labels.to(device)
        with torch.no_grad():
            logits = model(images)
        y_probs = torch.softmax(logits, dim=1)
        correct = (torch.argmax(y_probs, dim=1) == labels).type(torch.FloatTensor)
        batch_accuracy = correct.mean()
        acc.append(batch_accuracy)
    # Unweighted mean of per-batch accuracies; with drop_last=False the last,
    # possibly smaller, batch counts the same as a full one
    test_accuracy = sum(acc) / len(acc)
    print("Accuracy on test set: ", test_accuracy.item())

if __name__ == '__main__':
    args = get_args()
    if args.phase == 'train':
        train(args)
    elif args.phase == 'test':
        test(args)
    else:
        print("phase " + args.phase + " is not available")