diff --git a/armory/data/adversarial/carla_over_obj_det_custom.py b/armory/data/adversarial/carla_over_obj_det_custom.py
new file mode 100644
index 000000000..d0619906a
--- /dev/null
+++ b/armory/data/adversarial/carla_over_obj_det_custom.py
@@ -0,0 +1,63 @@
+"""
+Class that handles loading of a CARLA overhead object detection custom dataset
+using PyTorch.
+"""
+
+from pathlib import Path
+from typing import Any, List, Tuple
+
+import numpy as np
+from PIL import Image
+from torchvision.datasets import CocoDetection
+
+class CarlaOverObjtDetCustom(CocoDetection):
+    def __init__(
+        self,
+        root: str,
+        annFile: str,
+        modalities: List[str] = ["rgb", "foreground_mask", "patch_metadata"],
+    ):
+        self.root = Path(root)
+        self.ann_file = Path(annFile)
+
+        self.images = {}
+        for modality in modalities:
+            self.images[modality] = self.root / modality
+            assert self.images[modality].exists()
+
+        # look for RGB or depth images to load
+        self.image_path = None
+        if "rgb" in self.images:
+            self.image_path = self.images["rgb"]
+        elif "depth" in self.images:
+            self.image_path = self.images["depth"]
+
+        super().__init__(root=self.image_path, annFile=self.ann_file)
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        x, y = super().__getitem__(index)
+        image_id = self.ids[index]
+        item_path = Path(self.coco.loadImgs(image_id)[0]["file_name"])
+
+        # Set the depth perturbation bound based on patch placement:
+        # images named 1#######.png have the patch located off the sidewalk/street
+        if item_path.stem[0] == "1":
+            max_depth_perturb_meters = 3.0
+        else:
+            max_depth_perturb_meters = 0.03
+
+        mask = Image.open(self.images["foreground_mask"] / item_path.name)
+
+        patch_metadata = {
+            "gs_coords": np.load(
+                self.images["patch_metadata"] / (item_path.stem + "_coords.npy")
+            ),
+            "avg_patch_depth": np.load(
+                self.images["patch_metadata"] / (item_path.stem + "_avg_depth.npy")
+            ),
+            "mask": np.array(mask),
+            "max_depth_perturb_meters": max_depth_perturb_meters,
+        }
+        y = (y, patch_metadata)
+
+        return (x, y)
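For reviewers, a minimal sketch of how the new dataset class might be exercised. It is not part of the diff, and the paths are hypothetical placeholders for a locally exported dataset (a root directory containing rgb/, foreground_mask/, and patch_metadata/ subfolders, plus a COCO-format annotation file):

    from armory.data.adversarial.carla_over_obj_det_custom import (
        CarlaOverObjtDetCustom,
    )

    # Hypothetical paths; substitute a real export of the dataset.
    ds = CarlaOverObjtDetCustom(
        root="/path/to/carla_over_obj_det",
        annFile="/path/to/annotations.json",
    )

    # CocoDetection yields a PIL image and a list of COCO annotation dicts;
    # this subclass additionally wraps the patch metadata loaded from disk.
    x, (y, patch_metadata) = ds[0]
    print(patch_metadata["gs_coords"])                 # patch corner coordinates
    print(patch_metadata["mask"].shape)                # foreground mask array
    print(patch_metadata["max_depth_perturb_meters"])  # 3.0 or 0.03 by filename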
diff --git a/armory/data/adversarial_datasets.py b/armory/data/adversarial_datasets.py
index 937563aa7..bf8cc6200 100644
--- a/armory/data/adversarial_datasets.py
+++ b/armory/data/adversarial_datasets.py
@@ -29,6 +29,7 @@
 from armory.data.adversarial import imagenet_adversarial as IA  # noqa: F401
 from armory.data.adversarial import librispeech_adversarial as LA  # noqa: F401
 from armory.data.adversarial.apricot_metadata import ADV_PATCH_MAGIC_NUMBER_LABEL_ID
+from armory.data.adversarial import carla_over_obj_det_custom as coodc
 
 # Although these imports are unreferenced in this file, they are required for tfds to know they exist.
@@ -874,6 +875,29 @@ def both_fn(batch):
     )
 
 
+def carla_over_obj_det_custom(
+    epochs: int = 1,
+    batch_size: int = 1,
+    dataset_dir: str = None,
+    ann_file: str = None,
+    preprocessing_fn: Callable = datasets.coco_image_preprocessing,
+    label_preprocessing_fn: Callable = datasets.custom_coco_label_preprocessing,
+    framework: str = "pytorch",
+    **kwargs,
+) -> datasets.ArmoryDataGenerator:
+    ds = coodc.CarlaOverObjtDetCustom(root=dataset_dir, annFile=ann_file)
+    generator = datasets.ArmoryDataGenerator(
+        iter(ds),
+        size=len(ds),
+        batch_size=batch_size,
+        epochs=epochs,
+        preprocessing_fn=preprocessing_fn,
+        label_preprocessing_fn=label_preprocessing_fn,
+    )
+
+    return generator
+
+
 class ClipVideoTrackingLabels:
     """
     Truncate labels for CARLA video tracking, when max_frames is set
diff --git a/armory/data/datasets.py b/armory/data/datasets.py
index cec605402..fb4a09dd1 100644
--- a/armory/data/datasets.py
+++ b/armory/data/datasets.py
@@ -1991,6 +1991,74 @@ def coco2017(
     )
 
 
+def coco_image_preprocessing(x):
+    x = np.array(x)
+    x = x.astype(np.float32)
+    normalized_x = (x - np.min(x)) / (np.max(x) - np.min(x))
+
+    # add a batch dimension
+    normalized_x = np.expand_dims(normalized_x, axis=0)
+
+    return normalized_x
+
+
+def custom_coco_label_preprocessing(x, y):
+    y_tmp = y
+    if isinstance(y, tuple):
+        y_tmp, y_metadata = y
+
+    boxes = []
+    labels = []
+    image_id = []
+    y_transformed = {}
+    for label_dict in y_tmp:
+        # convert bbox format from [x, y, width, height] to [x1, y1, x2, y2]
+        bbox = label_dict.pop("bbox")
+        bbox[2] = bbox[0] + bbox[2]
+        bbox[3] = bbox[1] + bbox[3]
+
+        boxes.append(bbox)
+        labels.append(label_dict.pop("category_id"))
+        image_id.append(label_dict.pop("image_id"))
+
+    y_transformed["boxes"] = np.array(boxes, dtype=np.float32)
+    y_transformed["labels"] = np.array(labels)
+    y_transformed["image_id"] = np.array(image_id)
+
+    if isinstance(y, tuple):
+        y = (y_transformed, y_metadata)
+    else:
+        y = [y_transformed]
+
+    return y
+
+
+def custom_coco_dataset(
+    epochs: int = 1,
+    batch_size: int = 1,
+    dataset_dir: str = None,
+    ann_file: str = None,
+    preprocessing_fn: Callable = coco_image_preprocessing,
+    label_preprocessing_fn: Callable = custom_coco_label_preprocessing,
+    framework: str = "pytorch",
+    **kwargs,
+) -> ArmoryDataGenerator:
+
+    from torchvision.datasets import CocoDetection
+
+    ds = CocoDetection(root=dataset_dir, annFile=ann_file)
+    generator = ArmoryDataGenerator(
+        iter(ds),
+        size=len(ds),
+        batch_size=batch_size,
+        epochs=epochs,
+        preprocessing_fn=preprocessing_fn,
+        label_preprocessing_fn=label_preprocessing_fn,
+    )
+
+    return generator
+
+
 class So2SatContext:
     def __init__(self):
         self.default_type = np.float32
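As a sanity check on the new label pipeline, a small sketch (not part of the diff) of what custom_coco_label_preprocessing does to a COCO-style target; the annotation values are fabricated for illustration:

    from armory.data.datasets import custom_coco_label_preprocessing

    # One fabricated COCO annotation: bbox is [x, y, width, height].
    y = [{"bbox": [10.0, 20.0, 30.0, 40.0], "category_id": 2, "image_id": 7}]

    # The x argument is unused by the function, so None suffices here.
    out = custom_coco_label_preprocessing(None, y)
    print(out[0]["boxes"])     # [[10. 20. 40. 60.]] -- now [x1, y1, x2, y2]
    print(out[0]["labels"])    # [2]
    print(out[0]["image_id"])  # [7]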
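And a sketch of building the generator end to end via custom_coco_dataset. The paths are again hypothetical, and the snippet assumes ArmoryDataGenerator exposes the get_batch() interface used elsewhere in armory.data.datasets:

    from armory.data.datasets import custom_coco_dataset

    generator = custom_coco_dataset(
        epochs=1,
        batch_size=1,
        dataset_dir="/path/to/coco/images",    # hypothetical image directory
        ann_file="/path/to/annotations.json",  # hypothetical annotation file
    )

    # Each batch is (x, y): a min-max normalized image batch of shape
    # (1, H, W, C) from coco_image_preprocessing, plus the transformed labels.
    x, y = generator.get_batch()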