-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.cotching.py
More file actions
executable file
·90 lines (78 loc) · 2.77 KB
/
train.cotching.py
File metadata and controls
executable file
·90 lines (78 loc) · 2.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#!/usr/bin/env python3
# Training entry point for the co-teaching ("cotching") face-recognition experiments.
from lz import *  # project utility grab-bag; torch/np/logging appear to come from here — TODO confirm
from config import conf
import argparse
from pathlib import Path

# Let cuDNN auto-tune conv algorithms; fastest when input shapes are fixed across iterations.
torch.backends.cudnn.benchmark = True
def log_conf(conf):
    """Log the training configuration at INFO level.

    Nested dicts and numpy arrays are dropped from the logged copy to keep
    the log line readable; the original conf object is not modified.
    """
    printable = {
        key: value
        for key, value in conf.items()
        if not isinstance(value, (dict, np.ndarray))
    }
    logging.info(f'training conf is {printable}')
from exargs import parser
if __name__ == '__main__':
    # Parse CLI options (parser is defined in exargs); CLI work_path wins over conf.
    args = parser.parse_args()
    if args.work_path:
        # Derive all output directories from the chosen working directory.
        conf.work_path = Path(args.work_path)
        conf.model_path = conf.work_path / 'models'
        conf.log_path = conf.work_path / 'log'
        conf.save_path = conf.work_path / 'save'
    else:
        # Mirror the configured default back onto args so conf.update below is consistent.
        args.work_path = conf.work_path
    # Merge every parsed CLI option into the global conf (CLI values override).
    conf.update(args.__dict__)
    if conf.local_rank is not None:
        # Distributed launch (torch.distributed.launch style): pin this process to
        # its GPU and join the NCCL process group via the env:// rendezvous.
        torch.cuda.set_device(conf.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method="env://")
        if torch.distributed.get_rank() != 0:
            # Non-master ranks only emit warnings, to keep combined output readable.
            set_stream_logger(logging.WARNING)
    from Learner import *
    learner = face_cotching(conf, )
    # learner = face_cotching_head(conf, )
    ress = {}
    # NOTE(review): every entry in this list is commented out, so the loop body
    # never executes; it is kept as a template for resuming earlier checkpoints.
    for p in [
        # 'mbfc.retina.cl.arc.cotch.cont',
        # 'mbfc.cotch.mual.1e-3',
    ]:
        learner.load_state(
            resume_path=Path(f'work_space/{p}/models/'),
            load_optimizer=False,
            load_head=True,  # todo note !!!
            load_imp=False,
            latest=True,
            load_model2=True,
        )
        # res = learner.validate_ori(conf)
        # ress[p] = res
        # logging.warning(f'{p} res: {res}')
    print(ress)
    # learner.calc_img_feas(out='work_space/casia.r50.arc.h5')
    # exit(0)
    # learner.init_lr()
    # conf.tri_wei = 0
    # log_conf(conf)
    # learner.train(conf, 1, name='xent')
    # Reset the learning-rate schedule and record the effective config before training.
    learner.init_lr()
    log_conf(conf)
    # Alternative training entry points kept for reference:
    # learner.warmup(conf, conf.warmup)
    # learner.train(conf, conf.epochs)
    # learner.train_dist(conf, conf.epochs)
    # learner.train_simple(conf, conf.epochs)
    # learner.train_cotching(conf, conf.epochs)
    # learner.train_cotching_accbs(conf, conf.epochs)
    if args.train_mode == 'mual':
        # Mutual-learning variant; otherwise the default co-teaching trainer (v2).
        learner.train_mual(conf, conf.epochs)
    else:
        learner.train_cotching_accbs_v2(conf, conf.epochs)
    # learner.train_ghm(conf, conf.epochs)
    # learner.train_with_wei(conf, conf.epochs)
    # learner.train_use_test(conf, conf.epochs)
    # res = learner.validate_ori(conf)
    # Final evaluation on the IJB-C benchmark once training completes.
    from tools.test_ijbc3 import test_ijbc3
    res = test_ijbc3(conf, learner)
    # log_lrs, losses = learner.find_lr(conf,
    #                                   # final_value=100,
    #                                   num=200,
    #                                   bloding_scale=1000)
    # best_lr = 10 ** (log_lrs[np.argmin(losses)])
    # print(best_lr)
    # conf.lr = best_lr
    # learner.push2redis()