diff --git a/Augmentor/ImageUtilities.py b/Augmentor/ImageUtilities.py index 53719be..e0010e8 100644 --- a/Augmentor/ImageUtilities.py +++ b/Augmentor/ImageUtilities.py @@ -27,7 +27,7 @@ class AugmentorImage(object): Each image that is found by Augmentor during the initialisation of a Pipeline object is contained with a new AugmentorImage object. """ - def __init__(self, image_path, output_directory): + def __init__(self, image_path, output_directory,ground_truth_output_directory): """ To initialise an AugmentorImage object for any image, the image's file path is required, as well as that image's output directory, @@ -53,6 +53,7 @@ def __init__(self, image_path, output_directory): # Now we call the setters that we require. self.image_path = image_path self.output_directory = output_directory + self.ground_truth_output_directory = ground_truth_output_directory def __str__(self): return """ @@ -199,9 +200,11 @@ def extract_paths_and_extensions(image_path): return file_name, extension, root_path -def scan(source_directory, output_directory): +def scan(source_directory, output_directory,ground_truth_output_directory): abs_output_directory = os.path.abspath(output_directory) + abs_ground_truth_output_directory = os.path.join(ground_truth_output_directory) + files_and_directories = glob.glob(os.path.join(os.path.abspath(source_directory), '*')) directory_count = 0 @@ -226,7 +229,7 @@ def scan(source_directory, output_directory): parent_directory_name = os.path.basename(os.path.abspath(source_directory)) for image_path in scan_directory(source_directory): - a = AugmentorImage(image_path=image_path, output_directory=abs_output_directory) + a = AugmentorImage(image_path=image_path, output_directory=abs_output_directory,ground_truth_output_directory = abs_ground_truth_output_directory) a.class_label = parent_directory_name a.class_label_int = label_counter a.categorical_label = [label_counter] diff --git a/Augmentor/Pipeline.py b/Augmentor/Pipeline.py index 0d913f9..7ea23c9 100644 --- a/Augmentor/Pipeline.py +++ b/Augmentor/Pipeline.py @@ -43,7 +43,7 @@ class Pipeline(object): _valid_formats = ["PNG", "BMP", "GIF", "JPEG"] _legal_filters = ["NEAREST", "BICUBIC", "ANTIALIAS", "BILINEAR"] - def __init__(self, source_directory=None, output_directory="output", save_format=None): + def __init__(self, source_directory=None, output_directory="output",ground_truth_output_directory='gt',save_format=None): """ Create a new Pipeline object pointing to a directory containing your original image dataset. @@ -85,7 +85,7 @@ def __init__(self, source_directory=None, output_directory="output", save_format self._populate(source_directory=source_directory, output_directory=output_directory, ground_truth_directory=None, - ground_truth_output_directory=output_directory) + ground_truth_output_directory=ground_truth_output_directory) def __call__(self, augmentor_image): """ @@ -138,14 +138,23 @@ def _populate(self, source_directory, output_directory, ground_truth_directory, raise IOError("The ground truth source directory you specified does not exist.") # Get absolute path for output - abs_output_directory = os.path.join(source_directory, output_directory) + if output_directory == 'output': + abs_output_directory = os.path.join(source_directory, output_directory) + else: + abs_output_directory = output_directory + + # Get absolute path for gt + if ground_truth_output_directory == 'gt': + abs_ground_truth_output_directory = os.path.join(source_directory,ground_truth_output_directory) + else: + abs_ground_truth_output_directory = ground_truth_output_directory # Scan the directory that user supplied. - self.augmentor_images, self.class_labels = scan(source_directory, abs_output_directory) + self.augmentor_images, self.class_labels = scan(source_directory, abs_output_directory,abs_ground_truth_output_directory) - self._check_images(abs_output_directory) + self._check_images(abs_output_directory,abs_ground_truth_output_directory) - def _check_images(self, abs_output_directory): + def _check_images(self, abs_output_directory,abs_ground_truth_output_directory): """ Private method. Used to check and get the dimensions of all of the images :param abs_output_directory: the absolute path of the output directory @@ -158,6 +167,14 @@ def _check_images(self, abs_output_directory): os.makedirs(abs_output_directory) except IOError: print("Insufficient rights to read or write output directory (%s)" % abs_output_directory) + + # Check for ground truth directory + if not os.path.exists(abs_ground_truth_output_directory): + try: + os.makedirs(abs_ground_truth_output_directory) + except IOError: + print("Insufficient rights to read or write output directory (%s)" % abs_output_directory) + else: for class_label in self.class_labels: if not os.path.exists(os.path.join(abs_output_directory, str(class_label[0]))): @@ -221,13 +238,20 @@ def _execute(self, augmentor_image, save_to_disk=True, multi_threaded=True): # image = image.convert("RGB") for i in range(len(images)): if i == 0: - save_name = augmentor_image.class_label + "_original_" + os.path.basename(augmentor_image.image_path) + "_" + file_name \ + #save_name = augmentor_image.class_label + "_original_" + os.path.basename(augmentor_image.image_path) + "_" + file_name \ + # + "." + (self.save_format if self.save_format else augmentor_image.file_format) + + save_name = os.path.basename(augmentor_image.image_path)[:-4] + "_" + file_name \ + "." + (self.save_format if self.save_format else augmentor_image.file_format) + images[i].save(os.path.join(augmentor_image.output_directory, save_name)) else: - save_name = "_groundtruth_(" + str(i) + ")_" + augmentor_image.class_label + "_" + os.path.basename(augmentor_image.image_path) + "_" + file_name \ + # save_name = "_groundtruth_(" + str(i) + ")_" + augmentor_image.class_label + "_" + os.path.basename(augmentor_image.image_path) + "_" + file_name \ + # + "." + (self.save_format if self.save_format else augmentor_image.file_format) + save_name = os.path.basename(augmentor_image.image_path)[:-4] + "_" + file_name \ + "." + (self.save_format if self.save_format else augmentor_image.file_format) - images[i].save(os.path.join(augmentor_image.output_directory, save_name)) + + images[i].save(os.path.join(augmentor_image.ground_truth_output_directory, save_name)) except IOError as e: print("Error writing %s, %s. Change save_format to PNG?" % (file_name, e.message)) print("You can change the save format using the set_save_format(save_format) function.")