1
+ from .earlytrain import EarlyTrain
2
+ from .methods_utils .euclidean import euclidean_dist_pair_np
3
+ from .methods_utils .cossim import cossim_pair_np
4
+ import numpy as np
5
+ import torch
6
+ from .. import nets
7
+ from copy import deepcopy
8
+ from torchvision import transforms
9
+
10
+
11
+ class Cal (EarlyTrain ):
12
+ def __init__ (self , dst_train , args , fraction = 0.5 , random_seed = None , epochs = 200 , specific_model = None ,
13
+ balance = True , metric = "euclidean" , neighbors : int = 10 , pretrain_model : str = "ResNet18" , ** kwargs ):
14
+ super ().__init__ (dst_train , args , fraction , random_seed , epochs , specific_model , ** kwargs )
15
+
16
+ self .balance = balance
17
+
18
+ assert neighbors > 0 and neighbors < 100
19
+ self .neighbors = neighbors
20
+
21
+ if metric == "euclidean" :
22
+ self .metric = euclidean_dist_pair_np
23
+ elif metric == "cossim" :
24
+ self .metric = lambda a , b : - 1. * cossim_pair_np (a , b )
25
+ elif callable (metric ):
26
+ self .metric = metric
27
+ else :
28
+ self .metric = euclidean_dist_pair_np
29
+
30
+ self .pretrain_model = pretrain_model
31
+
32
+ def num_classes_mismatch (self ):
33
+ raise ValueError ("num_classes of pretrain dataset does not match that of the training dataset." )
34
+
35
+ def while_update (self , outputs , loss , targets , epoch , batch_idx , batch_size ):
36
+ if batch_idx % self .args .print_freq == 0 :
37
+ print ('| Epoch [%3d/%3d] Iter[%3d/%3d]\t \t Loss: %.4f' % (
38
+ epoch , self .epochs , batch_idx + 1 , (self .n_pretrain_size // batch_size ) + 1 , loss .item ()))
39
+
40
+ def find_knn (self ):
41
+ """
42
+ Find k-nearest-neighbor data points with the pretrained embedding model
43
+ :return: knn matrix
44
+ """
45
+
46
+ # Initialize pretrained model
47
+ model = nets .__dict__ [self .pretrain_model ](channel = self .args .channel , num_classes = self .args .num_classes ,
48
+ im_size = (224 , 224 ), record_embedding = True , no_grad = True ,
49
+ pretrained = True ).to (self .args .device )
50
+ model .eval ()
51
+
52
+ # Resize dst_train to 224*224
53
+ if self .args .im_size [0 ] != 224 or self .args .im_size [1 ] != 224 :
54
+ dst_train = deepcopy (self .dst_train )
55
+ dst_train .transform = transforms .Compose ([dst_train .transform , transforms .Resize (224 )])
56
+ else :
57
+ dst_train = self .dst_train
58
+
59
+ # Calculate the distance matrix and return knn results
60
+ if self .balance :
61
+ knn = []
62
+ for c in range (self .args .num_classes ):
63
+ class_index = np .arange (self .n_train )[self .dst_train .targets == c ]
64
+
65
+ # Start recording embedding vectors
66
+ embdeddings = []
67
+ batch_loader = torch .utils .data .DataLoader (torch .utils .data .Subset (dst_train , class_index ),
68
+ batch_size = self .args .selection_batch ,
69
+ num_workers = self .args .workers )
70
+ batch_num = len (batch_loader )
71
+ for i , (aa , _ ) in enumerate (batch_loader ):
72
+ if i % self .args .print_freq == 0 :
73
+ print ("| Caculating embeddings for batch [%3d/%3d]" % (i + 1 , batch_num ))
74
+ model (aa .to (self .args .device ))
75
+ embdeddings .append (model .embedding_recorder .embedding .flatten (1 ).cpu ().numpy ())
76
+
77
+ embdeddings = np .concatenate (embdeddings , axis = 0 )
78
+
79
+ knn .append (np .argsort (self .metric (embdeddings ), axis = 1 )[:, 1 :(self .neighbors + 1 )])
80
+ return knn
81
+ else :
82
+ # Start recording embedding vectors
83
+ embdeddings = []
84
+ batch_loader = torch .utils .data .DataLoader (dst_train , batch_size = self .args .selection_batch
85
+ ,num_workers = self .args .workers )
86
+ batch_num = len (batch_loader )
87
+
88
+ for i , (aa , _ ) in enumerate (batch_loader ):
89
+ if i % self .args .print_freq == 0 :
90
+ print ("| Caculating embeddings for batch [%3d/%3d]" % (i + 1 , batch_num ))
91
+ model (aa .to (self .args .device ))
92
+ embdeddings .append (model .embedding_recorder .embedding .flatten (1 ).cpu ().numpy ())
93
+ embdeddings = np .concatenate (embdeddings , axis = 0 )
94
+
95
+ return np .argsort (self .metric (embdeddings ), axis = 1 )[:, 1 :(self .neighbors + 1 )]
96
+
97
+ def calc_kl (self , knn , index = None ):
98
+ self .model .eval ()
99
+ self .model .no_grad = True
100
+ sample_num = self .n_train if index is None else len (index )
101
+ probs = np .zeros ([sample_num , self .args .num_classes ])
102
+
103
+ batch_loader = torch .utils .data .DataLoader (
104
+ self .dst_train if index is None else torch .utils .data .Subset (self .dst_train , index ),
105
+ batch_size = self .args .selection_batch , num_workers = self .args .workers )
106
+ batch_num = len (batch_loader )
107
+
108
+ for i , (inputs , _ ) in enumerate (batch_loader ):
109
+ probs [i * self .args .selection_batch :(i + 1 ) * self .args .selection_batch ] = torch .nn .functional .softmax (
110
+ self .model (inputs .to (self .args .device )), dim = 1 ).detach ().cpu ()
111
+
112
+ s = np .zeros (sample_num )
113
+ for i in range (0 , sample_num , self .args .selection_batch ):
114
+ if i % self .args .print_freq == 0 :
115
+ print ("| Caculating KL-divergence for batch [%3d/%3d]" % (i // self .args .selection_batch + 1 , batch_num ))
116
+ aa = np .expand_dims (probs [i :(i + self .args .selection_batch )], 1 ).repeat (self .neighbors , 1 )
117
+ bb = probs [knn [i :(i + self .args .selection_batch )], :]
118
+ s [i :(i + self .args .selection_batch )] = np .mean (
119
+ np .sum (0.5 * aa * np .log (aa / bb ) + 0.5 * bb * np .log (bb / aa ), axis = 2 ), axis = 1 )
120
+ self .model .no_grad = False
121
+ return s
122
+
123
+ def finish_run (self ):
124
+ scores = []
125
+ if self .balance :
126
+ selection_result = np .array ([], dtype = np .int32 )
127
+ for c , knn in zip (range (self .args .num_classes ), self .knn ):
128
+ class_index = np .arange (self .n_train )[self .dst_train .targets == c ]
129
+ scores .append (self .calc_kl (knn , class_index ))
130
+ selection_result = np .append (selection_result , class_index [np .argsort (
131
+ #self.calc_kl(knn, class_index))[::1][:round(self.fraction * len(class_index))]])
132
+ scores [- 1 ])[::1 ][:round (self .fraction * len (class_index ))]])
133
+ else :
134
+ selection_result = np .argsort (self .calc_kl (self .knn ))[::1 ][:self .coreset_size ]
135
+ return {"indices" : selection_result , "scores" :scores }
136
+
137
+ def select (self , ** kwargs ):
138
+ self .knn = self .find_knn ()
139
+ selection_result = self .run ()
140
+ return selection_result
0 commit comments