-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathselect_ref_images.py
More file actions
186 lines (139 loc) · 7.23 KB
/
select_ref_images.py
File metadata and controls
186 lines (139 loc) · 7.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import os
import shutil
import argparse
import torch
import random
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm
from model import SiameseNetwork
from config import *
class ImageDataset(Dataset):
    """Dataset yielding (transformed image tensor, source path) pairs.

    Paths are returned alongside tensors so callers can map model results
    back to the files they came from.
    """

    def __init__(self, image_paths: list, transform: transforms.Compose):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        # Force RGB so grayscale/RGBA files yield consistent 3-channel tensors.
        tensor = self.transform(Image.open(path).convert('RGB'))
        return tensor, path
def pairwise_dissimilarity_siamese(model: torch.nn.Module, img1: torch.Tensor, img2: torch.Tensor):
    """Return the cosine dissimilarity (1 - similarity) for a pair of image batches.

    The model is assumed to output a cosine-similarity score for the pair,
    so identical inputs map to 0 and orthogonal embeddings map to 1.
    """
    model.eval()
    # Inference only: no_grad skips autograd bookkeeping.
    with torch.no_grad():
        return 1 - model(img1, img2)
def compute_pairwise_dissimilarity(model: torch.nn.Module, images: list):
    """Build a symmetric NxN matrix of pairwise dissimilarities over `images`."""
    n = len(images)
    matrix = np.zeros((n, n))
    for i in tqdm(range(n), desc="Computing pairwise dissimilarities"):
        # The left image is fixed for the whole inner loop; move it once.
        left = images[i].unsqueeze(0).to(DEVICE)
        for j in range(i + 1, n):
            right = images[j].unsqueeze(0).to(DEVICE)
            score = pairwise_dissimilarity_siamese(model, left, right).item()
            # Dissimilarity is symmetric: fill both triangles at once.
            matrix[i, j] = matrix[j, i] = score
    return matrix
def select_top_dissimilar_images(dissimilarity_matrix, topk):
    """Return indices of the `topk` images with the largest total dissimilarity.

    An image's score is the sum of its dissimilarities to every other image;
    a high score means the image sits far from the rest of the set.
    """
    totals = dissimilarity_matrix.sum(axis=1)
    # argsort is ascending, so the last `topk` positions hold the top scorers.
    return np.argsort(totals)[-topk:]
def iterative_dissimilarity_selection(model: torch.nn.Module, image_paths: list, num_refs: int, topk: int, batch_size: int):
    """Walk `image_paths` in chunks of `num_refs`, keeping the `topk` most
    dissimilar images from each chunk. Returns the concatenated keepers.
    """
    selected_images = []
    remaining = image_paths
    while remaining:
        # Peel off the next chunk of up to num_refs images.
        chunk, remaining = remaining[:num_refs], remaining[num_refs:]
        loader = DataLoader(
            ImageDataset(chunk, transform=BASIC_TRANSFORM),
            batch_size=batch_size,
            shuffle=False,
            num_workers=8,
        )
        # Collect tensors and their paths so matrix indices map back to files.
        tensors, ordered_paths = [], []
        for batch, paths in loader:
            tensors.append(batch)
            ordered_paths.extend(paths)
        stacked = torch.cat(tensors, dim=0)
        matrix = compute_pairwise_dissimilarity(model, stacked)
        keep = select_top_dissimilar_images(matrix, topk)
        selected_images.extend(ordered_paths[i] for i in keep)
    return selected_images
def iterative_selection_and_refinement(model: torch.nn.Module, image_paths: list, num_refs: int, topk: int, batch_size: int):
    """Repeatedly prune `image_paths` until at most `num_refs` images remain.

    Each pass keeps the `topk` most dissimilar images out of every
    `num_refs`-sized chunk, shrinking the pool by roughly topk/num_refs.

    Returns:
        A list of at most `num_refs` image paths.
    """
    while len(image_paths) > num_refs:
        previous_count = len(image_paths)
        image_paths = iterative_dissimilarity_selection(
            model, image_paths, num_refs, topk, batch_size
        )
        # Guard against an infinite loop: when topk >= num_refs (the script's
        # defaults are 100/100) each pass keeps every image and the pool never
        # shrinks. In that case truncate to the requested size and stop.
        if len(image_paths) >= previous_count:
            image_paths = image_paths[:num_refs]
            break
    return image_paths
def save_reference_images(reference_images, output_dir):
    """Copy selected reference images into `output_dir`, one sub-folder per class.

    `reference_images` maps class name -> list of source image paths; the
    class folder structure is recreated under `output_dir`.
    """
    os.makedirs(output_dir, exist_ok=True)
    for class_name, image_paths in reference_images.items():
        # Each class gets its own sub-directory, created on demand.
        class_dir = os.path.join(output_dir, class_name)
        os.makedirs(class_dir, exist_ok=True)
        for image_path in image_paths:
            output_image_path = os.path.join(class_dir, os.path.basename(image_path))
            shutil.copy(image_path, output_image_path)
            print(f"Saved {image_path} to {output_image_path}")
def find_highest_dissimilar_images_per_class(
    model: torch.nn.Module,
    data_root: str,
    num_refs: int,
    topk: int,
    batch_size: int,
):
    """For each class in CLASS_NAMES, select up to `num_refs` maximally dissimilar images.

    Scans `<data_root>/<class_name>` for .png files and runs the iterative
    selection independently within each class.

    Returns:
        Dict mapping class name -> list of selected image paths.
    """
    reference_images = {}
    for class_name in CLASS_NAMES:
        class_dir = os.path.join(data_root, class_name)
        # NOTE(review): only lowercase ".png" is matched; ".PNG"/".jpg" files
        # are silently skipped — confirm that is intended for this dataset.
        image_paths = [
            os.path.join(class_dir, f)
            for f in os.listdir(class_dir)
            if f.endswith(".png")
        ]
        print(f"Processing class {class_name} with {len(image_paths)} images...")
        # A class may hold fewer images than the requested reference count.
        num_refs_class = min(num_refs, len(image_paths))
        reference_images[class_name] = iterative_selection_and_refinement(
            model,
            image_paths,
            num_refs_class,
            topk,
            batch_size,
        )
    return reference_images
def get_args():
    """Parse command-line arguments.

    The numeric options carry `type=int`: argparse otherwise delivers
    user-supplied values as strings, which breaks `len(paths) > num_refs`
    comparisons, list slicing, and `DataLoader(batch_size=...)` downstream.
    """
    parser = argparse.ArgumentParser(description='Select the most diverse reference images for Siamese similarity model.')
    parser.add_argument('-d', '--data-path', required=True, help='Path to the root data folder.')
    parser.add_argument('-m', '--model-path', required=True, help='Path to the pre-trained Siamese model.')
    parser.add_argument('-o', '--output-path', required=True, help='Path to save reference images.')
    parser.add_argument('-n', '--num_refs', type=int, default=100, help='Number of reference images per class.')
    parser.add_argument('-k', '--topk', type=int, default=100, help='Number of top dissimilar images per batch.')
    parser.add_argument('-b', '--batch-size', type=int, default=64, help='Batch size for DataLoader.')
    return parser.parse_args()
def main(args):
    """Entry point: load the Siamese model, select reference images, save them."""
    random.seed(42)  # deterministic runs if any downstream code draws randomness
    # Load model weights onto the configured device. map_location lets a
    # checkpoint saved on GPU load on a CPU-only host instead of erroring.
    model = SiameseNetwork(model_name=MODEL_BACKBONE, pretrained=False)
    model.load_state_dict(torch.load(args.model_path, map_location=DEVICE))
    model.to(DEVICE)
    model.eval()  # inference only: freeze dropout/batch-norm behavior
    # Find the most dissimilar images per class.
    reference_images = find_highest_dissimilar_images_per_class(
        model, args.data_path, args.num_refs, args.topk, args.batch_size
    )
    # Save the selected reference images to the output directory.
    save_reference_images(reference_images, args.output_path)
    # Report what was saved for each class.
    for class_name, image_paths in reference_images.items():
        print(f"Most dissimilar images from {class_name} saved to {args.output_path}/{class_name} | {len(image_paths)}")
if __name__ == "__main__":
args = get_args()
main(args)