-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluation.py
126 lines (103 loc) · 4.08 KB
/
evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
import sys
import cv2
import numpy as np
from matplotlib import pyplot as plt
from numpy.linalg import norm
def compare_images(original, predicted):
    """Compare two images by chroma distance and relative saturation change.

    Both images are converted with ``cv2.COLOR_BGR2LAB`` and
    ``cv2.COLOR_BGR2HSV``, so they are expected to be BGR arrays of
    identical shape.

    Args:
        original (image): the original image
        predicted (image): the predicted image

    Returns:
        tuple: ``(L2_norm, sat_diff)`` where ``L2_norm`` is the sum of the
        Euclidean norms of the per-pixel differences of the Lab ``a`` and
        ``b`` channels, and ``sat_diff`` is
        ``|sum(sat_predicted) - sum(sat_original)| / sum(sat_original)``.
        When the original has zero total saturation, ``sat_diff`` is ``0.0``
        if the prediction is also unsaturated and ``inf`` otherwise.

    Raises:
        ValueError: if the two images do not have the same shape.
    """
    if original.shape != predicted.shape:
        raise ValueError(
            f"Image shapes differ: {original.shape} vs {predicted.shape}"
        )

    # convert the images to Lab color space and keep only the chroma channels
    _, a1, b1 = cv2.split(cv2.cvtColor(original, cv2.COLOR_BGR2LAB))
    _, a2, b2 = cv2.split(cv2.cvtColor(predicted, cv2.COLOR_BGR2LAB))

    # Cast to float before subtracting: integer image channels would
    # otherwise wrap around on negative differences and on squaring.
    a_delta = a2.astype(np.float64) - a1.astype(np.float64)
    b_delta = b2.astype(np.float64) - b1.astype(np.float64)
    L2_norm = np.sqrt(np.sum(a_delta ** 2)) + np.sqrt(np.sum(b_delta ** 2))

    # calculate the saturation difference between the two images:
    # convert to HSV and total the saturation channel (index 1)
    sat1 = cv2.cvtColor(original, cv2.COLOR_BGR2HSV)[:, :, 1]
    sat2 = cv2.cvtColor(predicted, cv2.COLOR_BGR2HSV)[:, :, 1]
    # Accumulate in float so the subtraction below cannot wrap when the
    # prediction is less saturated than the original.
    total_sat1 = float(np.sum(sat1, dtype=np.float64))
    total_sat2 = float(np.sum(sat2, dtype=np.float64))

    if total_sat1 == 0.0:
        # The relative difference is undefined for an unsaturated original.
        sat_diff = 0.0 if total_sat2 == 0.0 else float("inf")
    else:
        sat_diff = abs(total_sat2 - total_sat1) / total_sat1

    return float(L2_norm), sat_diff
def read_images(path):
    """Read all ``.jpg`` images in a folder, resized and center-cropped.

    Files are visited in sorted filename order so that two folders holding
    identically named files yield their images in the same order.

    Args:
        path (str): the path to the folder

    Returns:
        list: a list of images, each resized to 256x256 and then center
        cropped to 224x224

    Raises:
        OSError: if a ``.jpg`` entry cannot be decoded by ``cv2.imread``.
    """
    images = []
    # os.listdir gives no ordering guarantee, so sort for a stable order
    for filename in sorted(os.listdir(path)):
        if not filename.endswith(".jpg"):
            continue
        file_path = os.path.join(path, filename)
        image = cv2.imread(file_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising
            raise OSError(f"Could not read image: {file_path}")
        # resize to 256x256 and then center crop to 224x224
        image = cv2.resize(image, (256, 256))[16:240, 16:240]
        images.append(image)
    return images
def compare_folders(original_path, predicted_path):
    """Average the image comparison metrics over two folders.

    Images are loaded with ``read_images`` and paired by position in the
    returned lists; if one folder yields more images than the other, the
    surplus images are ignored.

    Args:
        original_path (str): the path to the original images
        predicted_path (str): the path to the predicted images

    Returns:
        tuple: ``(mean_L2_norm, mean_sat_diff)`` over all compared pairs.

    Raises:
        ValueError: if there is no image pair to compare.
    """
    # read the images in the folders
    original_images = read_images(original_path)
    predicted_images = read_images(predicted_path)

    # zip() stops at the shorter list, so average over that many pairs
    pair_count = min(len(original_images), len(predicted_images))
    if pair_count == 0:
        raise ValueError(
            f"No image pairs to compare in {original_path!r} "
            f"and {predicted_path!r}"
        )

    # Keep the running totals in names distinct from the per-pair results so
    # that unpacking the result of compare_images cannot clobber them.
    total_l2 = 0.0
    total_sat_diff = 0.0
    for original, predicted in zip(original_images, predicted_images):
        pair_l2, pair_sat_diff = compare_images(original, predicted)
        total_l2 += pair_l2
        total_sat_diff += pair_sat_diff

    return total_l2 / pair_count, total_sat_diff / pair_count
def plot_images(original_path, predicted_path, max_pairs=10):
    """Plot original images next to their predicted counterparts.

    Images are loaded with ``read_images`` and paired by position. Each
    grid row holds two original/predicted pairs (four columns). With fewer
    pairs available than ``max_pairs``, only the available ones are drawn;
    with none, nothing is shown.

    Args:
        original_path (str): the path to the original images
        predicted_path (str): the path to the predicted images
        max_pairs (int): the maximum number of pairs to display
    """
    # read the images in the folders
    original_images = read_images(original_path)
    predicted_images = read_images(predicted_path)

    pair_count = min(max_pairs, len(original_images), len(predicted_images))
    if pair_count <= 0:
        return

    # read_images loads via cv2.imread (BGR); convert to RGB for matplotlib
    original_images = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                       for image in original_images[:pair_count]]
    predicted_images = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                        for image in predicted_images[:pair_count]]

    # two pairs per row -> round the row count up for an odd number of pairs
    rows = (pair_count + 1) // 2

    # show original image against predicted image using plt
    for idx in range(pair_count):
        # subplot positions are 1-based: original on odd, predicted on even
        plt.subplot(rows, 4, 2 * idx + 1)
        plt.imshow(original_images[idx])
        plt.title(f"Original image {idx}")
        plt.axis('off')

        plt.subplot(rows, 4, 2 * idx + 2)
        plt.imshow(predicted_images[idx])
        plt.title(f"Predicted image {idx}")
        plt.axis('off')
    plt.show()
def main(argv):
    """Evaluate the performance of the model.

    Prints the averaged metrics returned by ``compare_folders`` and then
    displays the image grid drawn by ``plot_images``.

    Args:
        argv (list): the command-line arguments: program name, path to the
            original images, path to the predicted images

    Raises:
        SystemExit: with status 1 when the argument count is wrong.
    """
    # validate the arguments
    if len(argv) != 3:
        print("Usage: python3 evaluation.py <original_path> <predicted_path>")
        sys.exit(1)

    # get the paths to the original and predicted images
    original_path, predicted_path = argv[1], argv[2]

    # compare the images in the folders
    L2_norm, sat_diff = compare_folders(original_path, predicted_path)
    print(f"L2 norm: {L2_norm}, saturation difference: {sat_diff}")

    # The figure is shown through matplotlib inside plot_images; this script
    # never opens an OpenCV window, so there is no cv2 key wait or window
    # teardown to perform afterwards.
    plot_images(original_path, predicted_path)
# Script entry point: runs only when the file is executed directly (not when
# it is imported) and forwards the raw command-line arguments to main().
if __name__ == "__main__":
    main(sys.argv)