-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpreprocess_data.py
133 lines (111 loc) · 3.74 KB
/
preprocess_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#!/usr/bin/env python3
"""
Data loading and preprocessing of image files
"""
import numpy as np
import matplotlib.pyplot as plt
import pathlib
from pathlib import Path
import cv2
import os
from PIL import Image
def pre_process(train_path, val_path, random_seed=42, visualize=True):
    """Function to preprocess the training and validation data.

    Class names are taken from the sub-directories of ``train_path`` and are
    sorted alphabetically, so the integer label given to each class is the
    same on every run and identical for the training and validation sets.

    Parameters
    ----------
    train_path: str
        Path for training data; one sub-directory per class.
    val_path: str
        Path for validation data; must contain a sub-directory for every
        class found in ``train_path`` (otherwise loading it raises).
    random_seed: int, optional
        Seed used to shuffle both the training and validation sets
        (default 42).
    visualize: bool, optional
        If True (default), show a few sample images from each set.

    Returns
    -------
    Set of preprocessed data for PlacesCNN model: np array
        x_train, y_train, x_test, y_test

    Raises
    ------
    ValueError
        If ``train_path`` contains no class sub-directories.
    """
    # os.listdir returns entries in arbitrary, filesystem-dependent order;
    # sorting makes the class -> label mapping reproducible. Plain files
    # (e.g. .DS_Store) are not classes, so only directories are kept.
    classes = sorted(
        entry for entry in os.listdir(train_path)
        if os.path.isdir(os.path.join(train_path, entry))
    )
    if not classes:
        raise ValueError(
            "No class sub-directories found in {0}".format(train_path))
    print("Class names: {0} \n Total no of classes: {1}".format(classes, len(classes)))

    train_data, train_labels = load_data(classes, train_path)
    print("Total no. of train images: {0} \n Total no. of train labels: {1}".format(
        len(train_data), len(train_labels)))
    if visualize:
        visualize_images(train_data, train_labels)

    # The validation set reuses the training class list so labels agree.
    val_data, val_labels = load_data(classes, val_path)
    print("Total no. of val images: {0} \n Total no. of val labels: {1}".format(
        len(val_data), len(val_labels)))
    if visualize:
        visualize_images(val_data, val_labels)

    x_train, y_train = process(train_data, train_labels, random_seed)
    x_test, y_test = process(val_data, val_labels, random_seed)
    print("X_train: {0}, y_train: {1}".format(x_train.shape, y_train.shape))
    print("X_test: {0}, y_test: {1}".format(x_test.shape, y_test.shape))
    return x_train, y_train, x_test, y_test
def load_data(classes, path, size=(227, 227)):
    """Helper function that loads image data with its respective labels.

    Every file under ``path/<class>/`` is read with OpenCV, converted from
    BGR to RGB and resized with PIL to ``size`` (default 227 x 227, the input
    size used by the PlacesCNN network). Files OpenCV cannot read are skipped
    with a message instead of aborting the whole load.

    Parameters
    ---------
    classes: list
        Class (sub-directory) names; the index of a class in this list is
        used as its integer label.
    path: str
        Can be train/val path
    size: tuple of int, optional
        Target ``(width, height)`` passed to ``PIL.Image.resize``.

    Returns
    -------
    np array of images and labels
        images has shape (N, height, width, 3) and dtype uint8; labels has
        shape (N,). Both have N == 0 when no readable image was found.
    """
    width, height = size
    images, labels = [], []
    for label, category in enumerate(classes):
        category_dir = os.path.join(path, category)
        # Sorted so images are loaded in the same order on every filesystem.
        for image_name in sorted(os.listdir(category_dir)):
            image_path = os.path.join(category_dir, image_name)
            img = cv2.imread(image_path)
            # cv2.imread signals failure by returning None (non-image or
            # corrupt file); passing None on to cvtColor would raise.
            if img is None:
                print("Skipping unreadable image: {0}".format(image_path))
                continue
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # PIL's resize takes (width, height).
            resized_img = Image.fromarray(img, 'RGB').resize((width, height))
            images.append(np.array(resized_img))
            labels.append(label)
    # reshape keeps the 4-D (N, height, width, 3) layout even when N == 0.
    images = np.array(images, dtype=np.uint8).reshape(-1, height, width, 3)
    return images, np.array(labels, dtype=int)
def visualize_images(images, labels, n_images=6, class_names=None):
    """Helper function to visualize some sample images.

    Draws up to ``n_images`` distinct, randomly chosen images in a grid with
    three columns, each titled with its label, then shows the figure.

    Parameters
    ---------
    images: np array
        Train/val images to be visualized, indexed along the first axis.
    labels: np array
        Respective image labels
    n_images: int, optional
        Maximum number of images to draw (default 6).
    class_names: sequence of str, optional
        If given, ``class_names[label]`` is used in the title instead of the
        raw integer label.

    Returns
    -------
    None
    """
    count = min(n_images, len(images))
    # Nothing to draw; sampling from an empty range would raise.
    if count <= 0:
        return None
    # Sample without replacement so no image is shown twice.
    picks = np.random.choice(len(images), size=count, replace=False)
    cols = 3
    rows = (count + cols - 1) // cols  # ceil division: no empty grid rows
    # A fresh figure per call (rather than reusing figure 1) so consecutive
    # train/val previews never draw over one another.
    plt.figure(figsize=(19, 10))
    for position, idx in enumerate(picks, start=1):
        label = labels[idx]
        shown = class_names[label] if class_names is not None else label
        plt.subplot(rows, cols, position)
        plt.imshow(images[idx])
        plt.title('Places Categories : {}'.format(shown))
        plt.xticks([])
        plt.yticks([])
    plt.subplots_adjust(hspace=0.3, wspace=0.3)
    plt.show()
def process(images, labels, r_seed):
    """Helper function to shuffle and normalize the images.

    Images and labels are reordered with the same random permutation, so
    every image keeps its label. Images are then cast to float32 and divided
    by 255; labels are cast to int32.

    Parameters
    ---------
    images: np array
        Train/val images to be processed, indexed along the first axis.
    labels: np array
        Respective image labels, one per image.
    r_seed: int
        Seed for the shuffle. A private RandomState is used, so the global
        NumPy random state is left untouched.

    Returns
    -------
    Preprocessed images: np array
        images (float32, divided by 255), labels (int32)

    Raises
    ------
    ValueError
        If ``images`` and ``labels`` differ in length.
    """
    if images.shape[0] != labels.shape[0]:
        raise ValueError(
            "images and labels differ in length: {0} vs {1}".format(
                images.shape[0], labels.shape[0]))
    # RandomState(seed).permutation(n) yields exactly the order produced by
    # np.random.seed(seed) + np.random.shuffle(np.arange(n)), but without
    # resetting the global random state for the rest of the program.
    order = np.random.RandomState(r_seed).permutation(images.shape[0])
    images = images[order].astype(np.float32)
    images /= 255
    labels = labels[order].astype(np.int32)
    return images, labels