forked from HasnainRaz/Fast-SRGAN
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataloader.py
More file actions
123 lines (98 loc) · 4.42 KB
/
dataloader.py
File metadata and controls
123 lines (98 loc) · 4.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import tensorflow as tf
import os
from tensorflow.python.ops import array_ops, math_ops
class DataLoader(object):
    """Data Loader for the SR GAN, that prepares a tf data object for training.

    Builds a tf.data pipeline that yields batches of (low_res, high_res)
    image pairs, where low_res is a 4x bicubic downsample of a random
    high-res crop.
    """

    def __init__(self, image_dir, hr_image_size):
        """
        Initializes the dataloader.

        Args:
            image_dir: The path to the directory containing high resolution images.
            hr_image_size: Integer, the crop size of the images to train on (High
                resolution images will be cropped to this width and height).

        Returns:
            The dataloader object.
        """
        self.image_paths = [os.path.join(image_dir, x) for x in os.listdir(image_dir)]
        self.image_size = hr_image_size

    def _parse_image(self, image_path):
        """
        Function that loads the images given the path.

        Args:
            image_path: Path to an image file.

        Returns:
            image: A tf tensor of the loaded image, float32 in [0, 1].
        """
        image = tf.io.read_file(image_path)
        image = tf.image.decode_jpeg(image, channels=3)
        image = tf.image.convert_image_dtype(image, tf.float32)

        # If the image is smaller than the crop size in either spatial
        # dimension, resize it up to (image_size, image_size) so the later
        # random crop is always valid. Note: this square resize may distort
        # the aspect ratio of small images.
        # Use the public tf.shape / tf.reduce_all API rather than the private
        # tensorflow.python.ops modules, which are not a stable interface.
        if tf.keras.backend.image_data_format() == 'channels_last':
            spatial_shape = tf.shape(image)[:2]
        else:
            spatial_shape = tf.shape(image)[1:]
        is_large_enough = tf.reduce_all(spatial_shape >= tf.constant(self.image_size))
        image = tf.cond(is_large_enough,
                        lambda: tf.identity(image),
                        lambda: tf.image.resize(image, [self.image_size, self.image_size]))

        return image

    def _random_crop(self, image):
        """
        Function that crops the image according a defined width
        and height.

        Args:
            image: A tf tensor of an image.

        Returns:
            image: A tf tensor containing the randomly cropped
                (image_size, image_size, 3) patch.
        """
        image = tf.image.random_crop(image, [self.image_size, self.image_size, 3])
        return image

    def _high_low_res_pairs(self, high_res):
        """
        Function that generates a low resolution image given the
        high resolution image. The downsampling factor is 4x.

        Args:
            high_res: A tf tensor of the high res image.

        Returns:
            low_res: A tf tensor of the low res image.
            high_res: A tf tensor of the high res image.
        """
        low_res = tf.image.resize(high_res,
                                  [self.image_size // 4, self.image_size // 4],
                                  method='bicubic')
        return low_res, high_res

    def _rescale(self, low_res, high_res):
        """
        Function that rescales the high-res pixel values to the -1 to 1 range,
        for use with the generator output tanh function.

        Only high_res is rescaled: the generator input (low_res) stays in
        [0, 1], while the target must match the tanh output range.

        Args:
            low_res: The tf tensor of the low res image (unchanged, [0, 1]).
            high_res: The tf tensor of the high res image.

        Returns:
            low_res: The tf tensor of the low res image, unchanged.
            high_res: The tf tensor of the high res image, rescaled to [-1, 1].
        """
        high_res = high_res * 2.0 - 1.0

        return low_res, high_res

    def dataset(self, batch_size, threads=4):
        """
        Returns a tf dataset object with specified mappings.

        Args:
            batch_size: Int, The number of elements in a batch returned by the dataset.
            threads: Int, retained for backward compatibility but currently
                unused — all map stages use tf.data AUTOTUNE instead.

        Returns:
            dataset: A tf dataset object yielding (low_res, high_res) batches.
        """
        autotune = tf.data.experimental.AUTOTUNE

        # Generate tf dataset from high res image paths.
        dataset = tf.data.Dataset.from_tensor_slices(self.image_paths)
        # Read the images.
        dataset = dataset.map(self._parse_image, num_parallel_calls=autotune)
        # Crop out a piece for training.
        dataset = dataset.map(self._random_crop, num_parallel_calls=autotune)
        # Generate low resolution by downsampling crop.
        dataset = dataset.map(self._high_low_res_pairs, num_parallel_calls=autotune)
        # Rescale the values in the input.
        dataset = dataset.map(self._rescale, num_parallel_calls=autotune)
        # Batch the input, drop remainder to get a defined batch size.
        # Prefetch the data for optimal GPU utilization.
        dataset = dataset.shuffle(30).batch(batch_size, drop_remainder=True).prefetch(autotune)

        return dataset