-
Notifications
You must be signed in to change notification settings - Fork 14
/
dataloader.py
173 lines (143 loc) · 6.51 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import cv2
import numpy as np
import torch
import torchvision
import opencv_transforms.functional as FF
from torchvision import datasets
from PIL import Image
def color_cluster(img, nclusters=9):
    """
    Extract the dominant colors of an image via K-means clustering.

    Args:
        img: Numpy array of shape (H, W, C); C is assumed to be 3.
        nclusters: Number of clusters / dominant colors to extract (default: 9).

    Returns:
        color_palette: list of `nclusters` numpy arrays, each with the same
            shape as `img`, where every pixel is filled with one dominant
            color. E.g. for a (256, 256, 3) input and nclusters=4, the result
            is [color1, color2, color3, color4], each of shape (256, 256, 3).

    Note:
        K-means clustering is computationally intensive, so the image is
        downscaled to 0.25x before clustering; the cluster centers remain
        valid for the full-size image since they live in color space, not
        pixel space.
    """
    # Downscale first: INTER_AREA is the recommended interpolation for shrinking.
    small_img = cv2.resize(img, None, fx=0.25, fy=0.25, interpolation=cv2.INTER_AREA)

    # cv2.kmeans expects a float32 sample matrix of shape (n_pixels, 3).
    sample = small_img.reshape((-1, 3)).astype(np.float32)

    # Stop after 10 iterations or when centers move by less than 1.0.
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
    _, _, centers = cv2.kmeans(sample, nclusters, None, criteria, 10,
                               cv2.KMEANS_PP_CENTERS)
    centers = np.uint8(centers)

    # One solid-color image per cluster center; np.full broadcasts the
    # 3-element center over the full (H, W, 3) canvas.
    return [np.full(img.shape, center, dtype=np.uint8) for center in centers]
class PairImageFolder(datasets.ImageFolder):
    """
    A generic data loader where the images are arranged in this way: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/cat/123.png
        root/cat/nsdf3.png

    This class works properly for paired images in the form
    [color_image | sketch] stored side by side in a single file; only the
    left half (first 512 columns) is kept as the color image.

    Args:
        root (string): Root directory path.
        transform (callable): A function/transform that takes in a numpy
            image and returns a transformed version.
        sketch_net: The network used to convert a color image to a sketch
            (edge) image.
        ncluster: Number of clusters used when extracting the color palette.

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples.

    __getitem__ returns:
        img_edge: Edge (sketch) image tensor.
        img: Color image tensor, normalized to [-1, 1].
        color_palette: List of solid-color palette tensors.
    """

    def __init__(self, root, transform, sketch_net, ncluster):
        super().__init__(root, transform)
        self.ncluster = ncluster
        self.sketch_net = sketch_net
        # Run the sketch network on GPU when one is available.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def __getitem__(self, index):
        # The class label from ImageFolder is intentionally unused here.
        path, _ = self.imgs[index]
        img = np.asarray(self.loader(path))
        # Keep only the left 512-pixel-wide half of the paired image.
        # NOTE(review): assumes a fixed 512-px side-by-side layout — confirm
        # against the dataset format.
        img = img[:, 0:512, :]
        img = self.transform(img)
        color_palette = color_cluster(img, nclusters=self.ncluster)
        img = self.make_tensor(img)
        # Generate the sketch without tracking gradients; move back to CPU
        # as an HWC numpy array for the grayscale conversion below.
        with torch.no_grad():
            img_edge = self.sketch_net(img.unsqueeze(0).to(self.device)).squeeze().permute(1, 2, 0).cpu().numpy()
        img_edge = FF.to_grayscale(img_edge, num_output_channels=3)
        img_edge = FF.to_tensor(img_edge)
        color_palette = [self.make_tensor(color) for color in color_palette]
        return img_edge, img, color_palette

    def make_tensor(self, img):
        """Convert an HWC uint8 numpy image to a CHW tensor normalized to [-1, 1]."""
        img = FF.to_tensor(img)
        return FF.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
class GetImageFolder(datasets.ImageFolder):
    """
    A generic data loader where the images are arranged in this way: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/cat/123.png
        root/cat/nsdf3.png

    Unlike PairImageFolder, the full image is used as the color image
    (no side-by-side crop is applied).

    Args:
        root (string): Root directory path.
        transform (callable): A function/transform that takes in a numpy
            image and returns a transformed version.
        sketch_net: The network used to convert a color image to a sketch
            (edge) image.
        ncluster: Number of clusters used when extracting the color palette.

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples.

    __getitem__ returns:
        img_edge: Edge (sketch) image tensor.
        img: Color image tensor, normalized to [-1, 1].
        color_palette: List of solid-color palette tensors.
    """

    def __init__(self, root, transform, sketch_net, ncluster):
        super().__init__(root, transform)
        self.ncluster = ncluster
        self.sketch_net = sketch_net
        # Run the sketch network on GPU when one is available.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def __getitem__(self, index):
        # The class label from ImageFolder is intentionally unused here.
        path, _ = self.imgs[index]
        img = np.asarray(self.loader(path))
        img = self.transform(img)
        color_palette = color_cluster(img, nclusters=self.ncluster)
        img = self.make_tensor(img)
        # Generate the sketch without tracking gradients; move back to CPU
        # as an HWC numpy array for the grayscale conversion below.
        with torch.no_grad():
            img_edge = self.sketch_net(img.unsqueeze(0).to(self.device)).squeeze().permute(1, 2, 0).cpu().numpy()
        img_edge = FF.to_grayscale(img_edge, num_output_channels=3)
        img_edge = FF.to_tensor(img_edge)
        color_palette = [self.make_tensor(color) for color in color_palette]
        return img_edge, img, color_palette

    def make_tensor(self, img):
        """Convert an HWC uint8 numpy image to a CHW tensor normalized to [-1, 1]."""
        img = FF.to_tensor(img)
        return FF.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))