63 changes: 63 additions & 0 deletions armory/data/adversarial/carla_over_obj_det_custom.py
@@ -0,0 +1,63 @@
"""
Dataset class that handles the loading of a custom CARLA overhead object detection
dataset using PyTorch.
"""

from pathlib import Path
from typing import Any, List, Tuple

import numpy as np
from PIL import Image
from torchvision.datasets import CocoDetection

class CarlaOverObjtDetCustom(CocoDetection):
    def __init__(
        self,
        root: str,
        annFile: str,
        modalities: List[str] = ["rgb", "foreground_mask", "patch_metadata"],
    ):
        self.root = Path(root)
        self.ann_file = Path(annFile)

        # map each modality name to its sub-directory under root
        self.images = {}
        for modality in modalities:
            self.images[modality] = self.root / modality
            assert self.images[modality].exists(), f"missing modality directory: {self.images[modality]}"

        # look for RGB or depth images to load; fail early if neither is available
        self.image_path = None
        if "rgb" in self.images:
            self.image_path = self.images["rgb"]
        elif "depth" in self.images:
            self.image_path = self.images["depth"]
        if self.image_path is None:
            raise ValueError("modalities must include 'rgb' or 'depth'")

        super().__init__(root=self.image_path, annFile=self.ann_file)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        x, y = super().__getitem__(index)
        image_id = self.ids[index]
        item_path = Path(self.coco.loadImgs(image_id)[0]["file_name"])

        # Set the depth perturbation bound based on the split:
        # images whose names start with "1" (1#######.png) have the patch
        # located off the sidewalk/street
        if item_path.stem[0] == "1":
            max_depth_perturb_meters = 3.0
        else:
            max_depth_perturb_meters = 0.03

        mask = Image.open(self.images["foreground_mask"] / item_path.name)

        patch_metadata = {
            "gs_coords": np.load(
                self.images["patch_metadata"] / (item_path.stem + "_coords.npy")
            ),
            "avg_patch_depth": np.load(
                self.images["patch_metadata"] / (item_path.stem + "_avg_depth.npy")
            ),
            "mask": np.array(mask),
            "max_depth_perturb_meters": max_depth_perturb_meters,
        }
        y = (y, patch_metadata)

        return (x, y)
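A minimal usage sketch of the class above (the paths are hypothetical), showing what the dataset yields per item: a PIL image plus a tuple of COCO annotations and patch metadata.

from armory.data.adversarial.carla_over_obj_det_custom import CarlaOverObjtDetCustom

# hypothetical paths; root must contain rgb/, foreground_mask/, and patch_metadata/
ds = CarlaOverObjtDetCustom(
    root="/path/to/carla_dataset",
    annFile="/path/to/annotations.json",
)
x, (y, patch_metadata) = ds[0]  # PIL image, (COCO annotations, patch metadata)
print(patch_metadata["max_depth_perturb_meters"])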
24 changes: 24 additions & 0 deletions armory/data/adversarial_datasets.py
@@ -29,6 +29,7 @@
from armory.data.adversarial import imagenet_adversarial as IA # noqa: F401
from armory.data.adversarial import librispeech_adversarial as LA # noqa: F401
from armory.data.adversarial.apricot_metadata import ADV_PATCH_MAGIC_NUMBER_LABEL_ID
from armory.data.adversarial import carla_over_obj_det_custom as coodc

# Although these imports are unreferenced in this file, they are required for tfds to know they exist.

@@ -874,6 +875,29 @@ def both_fn(batch):
)


def carla_over_obj_det_custom(
    epochs: int = 1,
    batch_size: int = 1,
    dataset_dir: str = None,
    ann_file: str = None,
    preprocessing_fn: Callable = datasets.coco_image_preprocessing,
    label_preprocessing_fn: Callable = datasets.custom_coco_label_preprocessing,
    framework: str = "pytorch",
    **kwargs,
) -> datasets.ArmoryDataGenerator:
    # wrap the PyTorch dataset in an ArmoryDataGenerator;
    # note: framework and **kwargs are currently unused
    ds = coodc.CarlaOverObjtDetCustom(root=dataset_dir, annFile=ann_file)
    generator = datasets.ArmoryDataGenerator(
        iter(ds),
        size=len(ds),
        batch_size=batch_size,
        epochs=epochs,
        preprocessing_fn=preprocessing_fn,
        label_preprocessing_fn=label_preprocessing_fn,
    )

    return generator
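A hedged sketch of consuming the generator (paths hypothetical; assumes ArmoryDataGenerator exposes the get_batch() method used elsewhere in armory):

# hypothetical paths
gen = carla_over_obj_det_custom(
    dataset_dir="/path/to/carla_dataset",
    ann_file="/path/to/annotations.json",
)
# get_batch() is assumed from armory's DataGenerator interface
x, y = gen.get_batch()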


class ClipVideoTrackingLabels:
"""
Truncate labels for CARLA video tracking, when max_frames is set
68 changes: 68 additions & 0 deletions armory/data/datasets.py
@@ -1991,6 +1991,74 @@ def coco2017(
)


def coco_image_preprocessing(x):
    # convert the PIL image to a float32 array and min-max normalize it to
    # [0, 1]; assumes the image is not constant-valued (i.e. max > min)
    x = np.array(x)
    x = x.astype(np.float32)
    normalized_x = (x - np.min(x)) / (np.max(x) - np.min(x))

    # add a batch dimension
    normalized_x = np.expand_dims(normalized_x, axis=0)

    return normalized_x
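As an illustrative check of the normalization (arbitrary values), pixel values {0, 127.5, 255} map to {0.0, 0.5, 1.0}, with a leading batch dimension added:

import numpy as np

x = np.array([[0.0, 127.5, 255.0]])
out = coco_image_preprocessing(x)
assert out.shape == (1, 1, 3)
assert out.tolist() == [[[0.0, 0.5, 1.0]]]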


def custom_coco_label_preprocessing(x, y):
    # y is either a plain list of COCO annotation dicts, or a tuple of
    # (annotations, patch_metadata) as produced by CarlaOverObjtDetCustom
    y_tmp = y
    if isinstance(y, tuple):
        y_tmp, y_metadata = y

    boxes = []
    labels = []
    image_id = []
    y_transformed = {}
    for label_dict in y_tmp:
        # convert bbox format from [x, y, width, height] to [x1, y1, x2, y2]
        bbox = label_dict.pop("bbox")
        bbox[2] = bbox[0] + bbox[2]
        bbox[3] = bbox[1] + bbox[3]

        boxes.append(bbox)
        labels.append(label_dict.pop("category_id"))
        image_id.append(label_dict.pop("image_id"))

    y_transformed["boxes"] = np.array(boxes, dtype=np.float32)
    y_transformed["labels"] = np.array(labels)
    y_transformed["image_id"] = np.array(image_id)

    if isinstance(y, tuple):
        y = (y_transformed, y_metadata)
    else:
        y = [y_transformed]

    return y
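As a quick illustrative check of the box conversion (arbitrary values), a COCO box [x, y, width, height] of [10.0, 20.0, 30.0, 40.0] becomes [x1, y1, x2, y2] = [10.0, 20.0, 40.0, 60.0]:

# illustrative annotation; x is unused by the function
y_in = [{"bbox": [10.0, 20.0, 30.0, 40.0], "category_id": 3, "image_id": 7}]
out = custom_coco_label_preprocessing(None, y_in)
assert out[0]["boxes"].tolist() == [[10.0, 20.0, 40.0, 60.0]]
assert out[0]["labels"].tolist() == [3]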


def custom_coco_dataset(
    epochs: int = 1,
    batch_size: int = 1,
    dataset_dir: str = None,
    ann_file: str = None,
    preprocessing_fn: Callable = coco_image_preprocessing,
    label_preprocessing_fn: Callable = custom_coco_label_preprocessing,
    framework: str = "pytorch",
    **kwargs,
) -> ArmoryDataGenerator:
    # note: framework and **kwargs are currently unused
    from torchvision.datasets import CocoDetection

    ds = CocoDetection(root=dataset_dir, annFile=ann_file)
    generator = ArmoryDataGenerator(
        iter(ds),
        size=len(ds),
        batch_size=batch_size,
        epochs=epochs,
        preprocessing_fn=preprocessing_fn,
        label_preprocessing_fn=label_preprocessing_fn,
    )

    return generator


class So2SatContext:
    def __init__(self):
        self.default_type = np.float32