-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patheval.py
162 lines (142 loc) · 6.9 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
"""
Authors: Hui Ren
Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
"""
import argparse
import os
import torch
import yaml
from termcolor import colored
from utils.common_config import get_train_transformations, get_val_transformations, \
get_train_dataset, get_train_dataloader, \
get_val_dataset, get_val_dataloader, \
get_optimizer, get_model
from utils.evaluate_utils import get_predictions, hungarian_evaluate
from PIL import Image
from utils.structure import feature_cluster
import re
from easydict import EasyDict as edict
# Command-line interface. Parsed once at import time into the module-level
# ``args`` used by ``main``.
parser = argparse.ArgumentParser(description='Evaluate models')
parser.add_argument("--output_dir", default="tmp/", type=str, help="output_dir")
parser.add_argument('--visualize_prototypes', action='store_true',
                    help='Show the prototype for each cluster')
parser.add_argument("--train_db_name", default="cifar_im", type=str, help="cifar_im, iNature_im , imagenet-r_im")
parser.add_argument("--val_db_name", default="cifar_im", type=str, help="cifar_im, iNature_im , imagenet-r_im")
parser.add_argument('--imbalance_ratio', default=0.01, type=float, help='imbalance_ratio for dataset')
parser.add_argument("--num_classes", default=[100], type=int, nargs="+")
parser.add_argument("--backbone", default="dino_vitb16", type=str, help="backbone: resnet18/resnet50/dino_vitb16")
parser.add_argument("--num_workers", default=5, type=int)
# main() only knows how to resolve these two checkpoint flavours; reject
# anything else at parse time instead of after the datasets have been built.
parser.add_argument("--model_take", default="select", choices=("ckpt", "select"), help="ckpt/select")
parser.add_argument("--no_train", default=False, action="store_true")
parser.add_argument("--no_test", default=False, action="store_true")
parser.add_argument("--no_cluster", default=False, action="store_true")
parser.add_argument("--no_selflabel", default=False, action="store_true")
parser.add_argument("--eval_batch_size", default=1024, type=int)
# parse_known_args: unrecognised arguments are silently ignored (presumably so
# the script tolerates extra launcher flags — confirm before tightening).
args = parser.parse_known_args()[0]
# def target_num_count(p):
# p['train_db_name'] = p['val_db_name']
# train_transformations = get_train_transformations(p)
# train_dataset = get_train_dataset(p, train_transformations,
# split='train', to_neighbors_dataset = False,indices=None)
# train_dataloader = get_val_dataloader(p, train_dataset)
# nums = torch.zeros(p['num_classes'][0])
# for batch in train_dataloader:
# target = batch['target']
# for i in range(p['num_classes'][0]):
# nums[i] += torch.sum(target==i)
# print(f"train target num count:{nums}")
# return nums
def eval(dataset, model, head_select, config, setname, save_path=None):
    """Evaluate every clustering head of ``model`` on ``dataset`` and print the results.

    Note: this module-level name shadows the ``eval`` builtin inside this
    module; it is kept for compatibility with existing callers.

    Args:
        dataset: Dataset wrapped with ``get_val_dataloader`` when predictions
            have to be computed.
        model: Model passed to ``get_predictions``.
        head_select: Index of the head whose stats are reported as the final
            ``"{setname} result"`` line.
        config: Config mapping passed to the project helpers;
            ``config["num_classes"][0]`` is used as the cluster count when
            clustering backbone features.
        setname: Label used in the printed output.
        save_path: Optional cache file. If it exists, predictions are loaded
            from it instead of being recomputed; otherwise freshly computed
            predictions are saved there.

    Returns:
        The ``hungarian_evaluate`` stats of head ``head_select``, or ``None``
        if no entry with that index carries ``"predictions"``.
    """
    if save_path is not None and os.path.exists(save_path):
        # Cache hit: no dataloader (and no loader workers) is needed.
        predictions = torch.load(save_path)
    else:
        dataloader = get_val_dataloader(config, dataset)
        predictions = get_predictions(config, dataloader, model)
        if save_path is not None:
            torch.save(predictions, save_path)

    # The last entry may also carry raw backbone features; cluster them directly
    # and report that as the backbone result. Guard against an empty list.
    if predictions and "features" in predictions[-1]:
        backbone_preds = feature_cluster(predictions[-1]["features"],
                                         cluster_num=config["num_classes"][0])
        clustering_stats = hungarian_evaluate(
            0,
            [{"predictions": backbone_preds, "targets": predictions[-1]['targets']}],
            compute_confusion_matrix=False,
        )
        print(f"backbone cluster result: {clustering_stats}")

    select_result = None
    for i, head_output in enumerate(predictions):
        # Entries without "predictions" (e.g. feature-only entries) are not heads.
        if "predictions" not in head_output:
            continue
        clustering_stats = hungarian_evaluate(i, predictions, compute_confusion_matrix=False)
        print(f"{setname} head {i} result: {clustering_stats}")
        if i == head_select:
            select_result = clustering_stats
    print(f"{setname} result: {select_result}")
    return select_result
def _resolve_cluster_model_path(output_dir, model_take):
    """Return the cluster-stage checkpoint path selected by ``--model_take``.

    Raises:
        NotImplementedError: if ``model_take`` is neither "ckpt" nor "select".
    """
    if model_take == "ckpt":
        return os.path.join(output_dir, "cluster/checkpoint.pth.tar")
    if model_take == "select":
        return os.path.join(output_dir, "cluster/model.pth.tar")
    raise NotImplementedError


def _split_checkpoint(state_dict):
    """Return ``(model_state, head_select)`` from a loaded checkpoint object.

    A checkpoint with a ``"model"`` entry keeps its weights there, and the head
    to report is read from ``"best_loss_head"``, else ``"head"``, else 0.
    Anything else is treated as a bare state dict with head 0 selected.
    """
    if "model" not in state_dict:
        return state_dict, 0
    if "best_loss_head" in state_dict:
        head_select = state_dict["best_loss_head"]
    elif "head" in state_dict:
        head_select = state_dict["head"]
    else:
        head_select = 0
    return state_dict["model"], head_select


def _infer_head_config(model_state, class_num_pattern=r"head\.(\d+)\."):
    """Infer ``(head_num, num_classes, head_type)`` from state-dict key names.

    ``head_num`` is one more than the largest ``head.<i>.`` index seen on a
    weight key (at least 1). ``num_classes`` collects ``shape[0]`` of every such
    weight tensor in key order. ``head_type`` is "cos" if any head key contains
    ``embedding.weight_v``, else "linear".
    """
    head_num = 1
    head_type = "linear"
    num_classes = []
    for key in model_state.keys():
        if "head" not in key:
            continue
        head_index = re.findall(class_num_pattern, key)
        # re.findall returns a (possibly empty) list, never None: skip keys
        # that carry no numeric head index instead of crashing on [0].
        if head_index and "weight" in key:
            head_num = max(head_num, int(head_index[0]) + 1)
            num_classes.append(model_state[key].shape[0])
        if "embedding.weight_v" in key:
            head_type = "cos"
    return head_num, num_classes, head_type


def _load_weights(model, model_state):
    """Load ``model_state`` non-strictly into a DataParallel-wrapped ``model``.

    Returns the result of ``load_state_dict(strict=False)``, i.e. PyTorch's
    missing/unexpected-keys report.
    """
    first_key = next(iter(model_state), "")
    # Keys containing "module" (presumably the "module." prefix added when a
    # DataParallel model is saved) are loaded into the wrapper; otherwise into
    # the wrapped module. An empty state dict no longer raises IndexError.
    target = model if "module" in first_key else model.module
    return target.load_state_dict(model_state, strict=False)


def main():
    """Evaluate the cluster and/or self-label checkpoints under ``args.output_dir``.

    For each checkpoint that exists and is not disabled by ``--no_cluster`` /
    ``--no_selflabel``, the head layout is inferred from the weights, the model
    is rebuilt with ``get_model``, and it is evaluated on the train and/or test
    split. Predictions are cached next to the checkpoint as
    ``output_train.pth`` / ``output_test.pth``.
    """
    # Validate the checkpoint flavour before the (expensive) dataset setup.
    cluster_model_path = _resolve_cluster_model_path(args.output_dir, args.model_take)
    selflabel_model_path = os.path.join(args.output_dir, "selflabel/checkpoint.pth.tar")

    # Build the config from the command-line arguments.
    config = edict()
    config.update(args.__dict__)
    config['setup'] = "cluster"
    print(config)

    # Get datasets (both splits use the validation transforms).
    print(colored('Get validation dataset ...', 'blue'))
    transforms = get_val_transformations(config)
    dataset_test = get_val_dataset(config, transforms)
    dataset_train = get_train_dataset(config, transforms, split="train")

    path = []
    if not args.no_cluster and os.path.exists(cluster_model_path):
        print(f"cluster model path:{cluster_model_path}")
        path.append(("cluster", cluster_model_path))
    if not args.no_selflabel and os.path.exists(selflabel_model_path):
        print(f"selflabel model path:{selflabel_model_path}")
        path.append(("selflabel", selflabel_model_path))

    for name, model_path in path:
        print(f"========{name}========")
        state_dict = torch.load(model_path, map_location='cpu')
        model_state, head_select = _split_checkpoint(state_dict)
        head_num, num_classes, head_type = _infer_head_config(model_state)
        config['num_classes'] = num_classes
        config['head_type'] = head_type
        config['num_heads'] = head_num
        print(f"head type for {name}:{head_type}")
        print(f"head num for {name}:{head_num}")
        print(f"num classes for {name}:{num_classes}")

        model = torch.nn.DataParallel(get_model(config, model_path))
        missing = _load_weights(model, model_state)
        print(missing)
        model.cuda()

        model_dir = os.path.dirname(model_path)
        if not args.no_train:
            eval(dataset_train, model, head_select, config, "trainset",
                 save_path=os.path.join(model_dir, "output_train.pth"))
        if not args.no_test:
            eval(dataset_test, model, head_select, config, "test",
                 save_path=os.path.join(model_dir, "output_test.pth"))


if __name__ == "__main__":
    print("EVAL_IM")
    main()
    print("Complete.")