Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def __init__(self, estimator, **kwargs):
kwargs.pop("hsv_upper_bound", [255, 255, 255])
) # [255, 255, 255] means unbounded above

self.perturbations = []

super().__init__(estimator=estimator, **kwargs)

def create_initial_image(self, size, hsv_lower_bound, hsv_upper_bound):
Expand Down Expand Up @@ -449,6 +451,7 @@ def generate(self, x, y=None, y_patch_metadata=None):

num_imgs = x.shape[0]
attacked_images = []
self.perturbations = []
for i in range(num_imgs):
# Adversarial patch attack, when used for object detection, requires ground truth
y_gt = dict()
Expand Down Expand Up @@ -556,6 +559,11 @@ def generate(self, x, y=None, y_patch_metadata=None):

patch, _ = super().generate(np.expand_dims(x[i], axis=0), y=[y_gt])

# Extract perturbation image
perturbation = self._patch.detach().cpu().numpy()
perturbation = np.transpose(perturbation, (1, 2, 0))
self.perturbations.append(perturbation)

# Patch image
x_tensor = torch.tensor(np.expand_dims(x[i], axis=0)).to(
self.estimator.device
Expand Down Expand Up @@ -583,4 +591,5 @@ def generate(self, x, y=None, y_patch_metadata=None):

attacked_images.append(patched_image)

self.perturbations = np.array(self.perturbations)
return np.array(attacked_images)
18 changes: 17 additions & 1 deletion armory/scenarios/carla_object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
Scenario Contributor: MITRE Corporation
"""

from armory.instrument.export import ObjectDetectionExporter
from armory.instrument.export import ExportMeter, ObjectDetectionExporter
from armory.logs import log
from armory.scenarios.object_detection import ObjectDetectionTask

Expand All @@ -22,6 +22,17 @@ def load_dataset(self):
if self.config["dataset"]["batch_size"] != 1:
raise ValueError("batch_size must be 1 for evaluation.")
super().load_dataset(eval_split_default="dev")

def load_export_meters(self):
super().load_export_meters()

export_meter = ExportMeter(
"perturbation_exporter",
self.sample_exporter,
f"scenario.perturbation",
max_batches=self.num_export_batches,
)
self.hub.connect_meter(export_meter, use_default_writers=False)

def load_metrics(self):
super().load_metrics()
Expand Down Expand Up @@ -72,12 +83,17 @@ def run_attack(self):
**self.generate_kwargs,
)

if hasattr(self.attack, "perturbations"):
perturbations = self.attack.perturbations

# Ensure that input sample isn't overwritten by model
self.hub.set_context(stage="adversarial")
x_adv.flags.writeable = False
y_pred_adv = self.model.predict(x_adv, **self.predict_kwargs)

self.probe.update(x_adv=x_adv, y_pred_adv=y_pred_adv)
if hasattr(self.attack, "perturbations"):
self.probe.update(perturbation=perturbations)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The prior conditional could be combined with this one. For example:

if hasattr(self.attack, "perturbations"):
    self.probe.update(perturbation=self.attack.perturbations)

if self.targeted:
self.probe.update(y_target=y_target)

Expand Down