diff --git a/armory/datasets/art_wrapper.py b/armory/datasets/art_wrapper.py
index 30d8562e0..93c05365c 100644
--- a/armory/datasets/art_wrapper.py
+++ b/armory/datasets/art_wrapper.py
@@ -1,3 +1,5 @@
+from armory.datasets import generator
+
 from art.data_generators import DataGenerator
 
 
@@ -6,7 +8,9 @@ class WrappedDataGenerator(DataGenerator):
     Wrap an ArmoryDataGenerator in the ART interface
     """
 
-    def __init__(self, gen):
+    def __init__(self, gen: generator.ArmoryDataGenerator):
+        if gen.output_as_dict or len(gen.output_tuple) != 2:
+            raise ValueError("gen must output (x, y) tuples")
         super().__init__(gen.size, gen.batch_size)
         self._iterator = gen
 
diff --git a/armory/datasets/config_load.py b/armory/datasets/config_load.py
index cf7acbf75..b486d5e2c 100644
--- a/armory/datasets/config_load.py
+++ b/armory/datasets/config_load.py
@@ -16,10 +16,12 @@ def load_dataset(
     preprocessor_name="DEFAULT",
     preprocessor_kwargs=None,
     shuffle_files=False,
-    label_key="label",  # TODO: make this smarter or more flexible
+    label_key=None,
     index=None,
     class_ids=None,
     drop_remainder=False,
+    key_map: dict = None,
+    use_supervised_keys: bool = True,
 ):
     # All are keyword elements by design
     if name is None:
@@ -35,6 +37,21 @@ def load_dataset(
         raise ValueError(
             f"class_ids must be a list, int, or None, not {type(class_ids)}"
         )
+    if label_key is None:
+        if info.supervised_keys is None:
+            raise ValueError(
+                "label_key is None and info.supervised_keys is None."
+                " What label is being filtered on?"
+            )
+        elif len(info.supervised_keys) != 2 or not all(
+            isinstance(k, str) for k in info.supervised_keys
+        ):
+            raise NotImplementedError(
+                f"supervised_keys {info.supervised_keys} is not a 2-tuple of str."
+                " Please specify label_key for filtering."
+            )
+        _, label_key = info.supervised_keys
+
     element_filter = filtering.get_filter_by_class(class_ids, label_key=label_key)
     if index is None:
         index_filter = None
@@ -56,7 +73,9 @@ def load_dataset(
 
     preprocessor = preprocessing.get(preprocessor_name)
     if preprocessor is not None and preprocessor_kwargs is not None:
-        preprocessing_fn = lambda x: preprocessor(x, **preprocessor_kwargs)
+        preprocessing_fn = lambda x: preprocessor(  # noqa: E731
+            x, **preprocessor_kwargs
+        )
     else:
         preprocessing_fn = preprocessor
 
@@ -67,20 +86,24 @@ def load_dataset(
         ds_dict,
         split=split,
         batch_size=batch_size,
-        framework=framework,
+        num_batches=num_batches,
         epochs=epochs,
         drop_remainder=drop_remainder,
-        num_batches=num_batches,
         index_filter=index_filter,
         element_filter=element_filter,
         element_map=preprocessing_fn,
         shuffle_elements=shuffle_elements,
-        key_map=None,
+        framework=framework,
     )
-    return wrap_generator(armory_data_generator)
-
+    # If key_map is not None, use_supervised_keys is ignored
+    if key_map is not None:
+        # ignore use_supervised_keys in this case
+        armory_data_generator.set_key_map(key_map)
+    else:
+        # raises if use_supervised_keys is True but info.supervised_keys is absent
+        armory_data_generator.set_key_map(use_supervised_keys=use_supervised_keys)
 
-def wrap_generator(armory_data_generator):
-    from armory.datasets import art_wrapper
+    # Let the scenario set the desired tuple directly
+    # armory_data_generator.as_tuple()  # NOTE: This will currently fail for adversarial datasets
 
-    return art_wrapper.WrappedDataGenerator(armory_data_generator)
+    return armory_data_generator
diff --git a/armory/datasets/generator.py b/armory/datasets/generator.py
index 202a612e6..1e1e9a7a7 100644
--- a/armory/datasets/generator.py
+++ b/armory/datasets/generator.py
@@ -12,9 +12,13 @@
 """
 
+import math
+from typing import Tuple
+
 import tensorflow as tf
 import tensorflow_datasets as tfds
-import math
+
+from armory.datasets import key_mapping
 
 
 class ArmoryDataGenerator:
@@ -49,15 +53,14 @@ def __init__(
         ds_dict: dict,
         split: str = "test",
         batch_size: int = 1,
-        drop_remainder: bool = False,
-        epochs: int = 1,
         num_batches: int = None,
-        framework: str = "numpy",
-        shuffle_elements: bool = False,
+        epochs: int = 1,
+        drop_remainder: bool = False,
         index_filter: callable = None,
         element_filter: callable = None,
         element_map: callable = None,
-        key_map=None,
+        shuffle_elements: bool = False,
+        framework: str = "numpy",
     ):
         if split not in info.splits:
             raise ValueError(f"split {split} not in info.splits {list(info.splits)}")
@@ -73,9 +76,6 @@ def __init__(
             )
         if framework not in self.FRAMEWORKS:
             raise ValueError(f"framework {framework} not in {self.FRAMEWORKS}")
-        if key_map is not None:
-            # TODO: key mapping from dict to tuples, etc.
-            raise NotImplementedError("key_map argument")
 
         size = info.splits[split].num_examples
         batch_size = int(batch_size)
@@ -129,6 +129,7 @@ def __init__(
             raise NotImplementedError(f"framework {framework}")
 
         self._set_params(
+            info=info,
             iterator=iterator,
             split=split,
             size=size,
@@ -141,17 +142,115 @@ def __init__(
             shuffle_elements=shuffle_elements,
             element_filter=element_filter,
             element_map=element_map,
+            key_map=None,
+            output_as_dict=True,
+            output_tuple=("x", "y"),
         )
 
     def _set_params(self, **kwargs):
         for k, v in kwargs.items():
             setattr(self, k, v)
 
+    def set_key_map(self, key_map: dict = None, use_supervised_keys: bool = False):
+        """
+        key_map - dict that maps the keys of the element dicts to other keys
+            E.g., {"image": "x", "label": "y"}
+
+        if use_supervised_keys, it pulls this from info.supervised_keys
+
+        If any key is not present in the map, those values are omitted
+        """
+        if key_map is not None and use_supervised_keys:
+            raise ValueError("Cannot set both key_map and use_supervised_keys")
+        elif key_map is not None:
+            key_mapping.check_key_map(key_map)
+        elif use_supervised_keys:
+            supervised_keys = self.info.supervised_keys
+            if supervised_keys is None:
+                raise ValueError("supervised_keys is None for the current dataset info")
+            elif len(supervised_keys) != 2 or not all(
+                isinstance(k, str) for k in supervised_keys
+            ):
+                # NOTE: supervised_keys can be more exotic, though it is rare
+                # See the SupervisedKeysType in tfds.core.DatasetInfo
+                #     3-tuples and nested structures are allowed
+                raise NotImplementedError(
+                    f"supervised_keys {supervised_keys} is not a 2-tuple of str"
+                )
+            x, y = supervised_keys
+            key_map = {x: "x", y: "y"}
+        else:  # key_map is None
+            pass
+
+        self.key_map = key_map
+
+    def set_output_tuple(self, output_tuple: Tuple[str, ...]):
+        """
+        Set the keys used to build tuple outputs
+
+        output_tuple - tuple of str keys to select from each element dict
+            Example: ("x", "y")
+            Output batches are tuples of the values at these keys,
+            in the given order
+
+        Used when output_as_dict is False (see as_tuple and as_dict)
+        """
+        if isinstance(output_tuple, str):  # prevent "word" -> ("w", "o", "r", "d")
+            raise ValueError("output_tuple must not be a str")
+        output_tuple = tuple(output_tuple)
+        for k in output_tuple:
+            if not isinstance(k, str):
+                if isinstance(k, tuple):
+                    # NOTE: nested tuples would enable things like:
+                    #     (("x", "x_adv"), ("y", "y_patch_metadata"))
+                    raise NotImplementedError("nested tuples not currently supported")
+                raise ValueError(f"item {k} in output_tuple is not a str")
+        self.output_tuple = output_tuple
+
+    def as_dict(self):
+        """
+        Sets the return type to dict
+        """
+        self.output_as_dict = True
+
+    def as_tuple(self, output_tuple: Tuple[str, ...] = None):
+        """
+        Sets the return type to tuple, according to the given output_tuple
+
+        If output_tuple is None, it defaults to the existing output_tuple
+            The default output_tuple at initialization is ("x", "y")
+        """
+        if output_tuple is not None:
+            self.set_output_tuple(output_tuple)
+        self.output_as_dict = False
+
+    def as_supervised(self):
+        """
+        Convenience function similar to 'as_supervised' in tfds.core.DatasetBuilder
+        Sets key_map and as_tuple to output supervised (x, y) tuples
+        """
+        self.set_key_map(use_supervised_keys=True)
+        self.as_tuple(output_tuple=("x", "y"))
+
     def __iter__(self):
         return self
 
     def __next__(self):
-        return next(self.iterator)
+        element = next(self.iterator)
+        if self.key_map is not None:
+            element = {new_k: element[k] for k, new_k in self.key_map.items()}
+        if not self.output_as_dict:
+            element = tuple(element[k] for k in self.output_tuple)
+        return element
 
     def __len__(self):
         return self.batches_per_epoch * self.epochs
+
+
+def wrap_generator(armory_data_generator):
+    """
+    Wrap an ArmoryDataGenerator as an ART DataGenerator
+    """
+    from armory.datasets import art_wrapper
+
+    return art_wrapper.WrappedDataGenerator(armory_data_generator)
diff --git a/armory/datasets/key_mapping.py b/armory/datasets/key_mapping.py
new file mode 100644
index 000000000..e25dc73dd
--- /dev/null
+++ b/armory/datasets/key_mapping.py
@@ -0,0 +1,45 @@
+import copy
+
+REGISTERED_KEY_MAPS = {}
+DEFAULT = "DEFAULT"
+
+
+def register(name: str, key_map: dict):
+    if not isinstance(name, str):
+        raise ValueError(f"name {name} is not a str")
+    check_key_map(key_map)
+    global REGISTERED_KEY_MAPS
+    REGISTERED_KEY_MAPS[name] = key_map
+
+
+def list_registered():
+    return list(REGISTERED_KEY_MAPS)
+
+
+def get(name):
+    if name not in REGISTERED_KEY_MAPS:
+        raise KeyError(f"key_map {name} not registered. Use one of {list_registered()}")
+    # dicts are mutable, so return a copy
+    return copy.deepcopy(REGISTERED_KEY_MAPS[name])
+
+
+def has(name):
+    return name in REGISTERED_KEY_MAPS
+
+
+def check_key_map(key_map: dict):
+    if not isinstance(key_map, dict):
+        raise ValueError(f"key_map {key_map} must be a dict")
+    for k, v in key_map.items():
+        if not isinstance(k, str):
+            raise ValueError(f"key {k} in key_map is not a str")
+        if not isinstance(v, str):
+            raise ValueError(f"value {v} in key_map is not a str")
+    if len(key_map.values()) != len(set(key_map.values())):
+        raise ValueError("key_map values must be unique")
+
+
+for name in "mnist", "cifar10":
+    register(name, {"image": "x", "label": "y"})
+
+register("digit", {"audio": "x", "label": "y"})
diff --git a/armory/datasets/load.py b/armory/datasets/load.py
index 67e8a7621..0479f2068 100644
--- a/armory/datasets/load.py
+++ b/armory/datasets/load.py
@@ -76,7 +76,11 @@ def load(
         )
 
         ensure_download_extract(
-            name, version=version, verify=verify, overwrite=overwrite
+            name,
+            version=version,
+            verify=verify,
+            overwrite=overwrite,
+            public=public,
         )
     elif name in common.armory_builders() or name in common.tfds_builders():
         raise ValueError(
diff --git a/armory/datasets/preprocessing.py b/armory/datasets/preprocessing.py
index 84325e6ab..885fa9b30 100644
--- a/armory/datasets/preprocessing.py
+++ b/armory/datasets/preprocessing.py
@@ -1,5 +1,8 @@
 """
 Standard preprocessing for different datasets
+
+These are TensorFlow functions that take element dicts and return updated element dicts
+
 Ideally, element keys should not be modified here
 """
 
@@ -34,27 +37,40 @@ def has(name):
     return name in REGISTERED_PREPROCESSORS
 
 
-@register
-def supervised_image_classification(element):
-    return (image_to_canon(element["image"]), element["label"])
+def audio(element, **kwargs):
+    return {
+        k: audio_to_canon(v, **kwargs) if k == "audio" else v
+        for k, v in element.items()
+    }
 
 
-mnist = register(supervised_image_classification, "mnist")
-cifar10 = register(supervised_image_classification, "cifar10")
+def image(element, **kwargs):
+    return {
+        k: image_to_canon(v, **kwargs) if k == "image" else v
+        for k, v in element.items()
+    }
 
 
-@register
-def digit(element):
-    return (audio_to_canon(element["audio"]), element["label"])
+def video(element, **kwargs):
+    return {
+        k: video_to_canon(v, **kwargs) if k == "video" else v
+        for k, v in element.items()
+    }
+
+
+digit = register(audio, "digit")
+mnist = register(image, "mnist")
+cifar10 = register(image, "cifar10")
 
 
 @register
 def carla_over_obj_det_dev(element, modality="rgb"):
-    return carla_over_obj_det_image(
-        element["image"], modality=modality
-    ), carla_over_obj_det_dev_label(
+    out = {}
+    out["image"] = carla_over_obj_det_image(element["image"], modality=modality)
+    out["objects"], out["patch_metadata"] = carla_over_obj_det_dev_label(
         element["image"], element["objects"], element["patch_metadata"]
     )
+    return out
 
 
 def image_to_canon(image, resize=None, target_dtype=tf.float32, input_type="uint8"):
@@ -91,14 +107,6 @@ def audio_to_canon(audio, resample=None, target_dtype=tf.float32, input_type="in
     return audio
 
 
-# config = {
-#     "preprocessor": "mnist(max_frames=1)"
-#     "preprocessor_kwargs": {
-#         "max_frames": null,
-#     }
-# }
-
-
 def video_to_canon(
     video,
     resize=None,
diff --git a/armory/scenarios/poisoning_witches_brew.py b/armory/scenarios/poisoning_witches_brew.py
index 296d14957..bbba56c26 100644
--- a/armory/scenarios/poisoning_witches_brew.py
+++ b/armory/scenarios/poisoning_witches_brew.py
@@ -354,7 +354,7 @@ def poison_dataset(self):
     def load_test_dataset(self, test_split_default="test"):
         # Over-ridden because we need batch_size = 1 for the test set for this attack.
         if self.config["dataset"].get("test").get("batch_size") != 1:
-            raise ValueError(f"batch_size must be set to 1 for test set")
+            raise ValueError("batch_size must be set to 1 for test set")
         super().load_test_dataset(test_split_default=test_split_default)
 
     def load_metrics(self):
diff --git a/armory/scenarios/scenario.py b/armory/scenarios/scenario.py
index 4a36dcc4d..19067ac93 100644
--- a/armory/scenarios/scenario.py
+++ b/armory/scenarios/scenario.py
@@ -12,7 +12,7 @@
 
 import armory
 from armory import Config, paths, metrics
-from armory.datasets import config_load
+from armory.datasets import config_load, generator
 from armory.instrument import get_hub, get_probe, del_globals, MetricsLogger
 from armory.instrument.export import ExportMeter, PredictionMeter
 from armory.metrics import compute
@@ -148,18 +148,26 @@ def load_train_dataset(self, train_split_default="train"):
 
         log.info(f"Loading train dataset {name} with kwargs {kwargs}")
         self.train_dataset = config_load.load_dataset(**kwargs)
+        self.train_dataset.as_tuple()
         self.num_train_epochs = kwargs.get("epochs", 1)
 
     def fit(self):
+        log.info("Wrapping ArmoryDataGenerator with ART DataGenerator class")
+        self.train_art_generator = generator.wrap_generator(self.train_dataset)
+
         if self.defense_type == "Trainer":
            log.info(f"Training with {type(self.trainer)} Trainer defense...")
             self.trainer.fit_generator(
-                self.train_dataset, nb_epochs=self.num_train_epochs, **self.fit_kwargs
+                self.train_art_generator,
+                nb_epochs=self.num_train_epochs,
+                **self.fit_kwargs,
             )
         else:
             log.info(f"Fitting model {self.model_name}...")
             self.model.fit_generator(
-                self.train_dataset, nb_epochs=self.num_train_epochs, **self.fit_kwargs
+                self.train_art_generator,
+                nb_epochs=self.num_train_epochs,
+                **self.fit_kwargs,
             )
 
     def load_attack(self):
@@ -225,6 +233,7 @@ def load_test_dataset(self, test_split_default="test"):
 
         log.info(f"Loading test dataset {name} with kwargs {kwargs}")
         self.test_dataset = config_load.load_dataset(**kwargs)
+        self.test_dataset.as_tuple()
         self.i = -1
 
     def load_metrics(self):
diff --git a/armory/utils/config_schema.json b/armory/utils/config_schema.json
index 56fb9dffa..f64a8190a 100644
--- a/armory/utils/config_schema.json
+++ b/armory/utils/config_schema.json
@@ -36,6 +36,7 @@
             "type": "object"
         },
         "dataset": {
+            "additionalProperties": false,
             "properties": {
                 "test": {
                     "$comment": "See load_dataset() kwargs in armory/datasets/config_load.py",
@@ -75,7 +76,6 @@
             "required": [
                 "test"
             ],
-            "additionalProperties": false,
             "type": "object"
         },
         "defense": {
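
---

Usage sketch (not part of the patch): how the new output controls compose, using
only the API added above. The dataset name "mnist" and the batch_size kwarg are
illustrative assumptions about load_dataset's full signature; set_key_map,
as_tuple, and wrap_generator are taken from the diff.

    from armory.datasets import config_load, generator

    # use_supervised_keys=True (the default) pulls info.supervised_keys,
    # here ("image", "label"), and maps them onto ("x", "y") via set_key_map
    ds = config_load.load_dataset(name="mnist", batch_size=16)

    batch = next(ds)  # output_as_dict is True at init: {"x": ..., "y": ...}

    ds.as_tuple()  # switch to the default ("x", "y") tuple output
    x, y = next(ds)

    # WrappedDataGenerator requires 2-tuple output, so call as_tuple() first
    art_gen = generator.wrap_generator(ds)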
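Key maps can also be registered centrally and passed explicitly. A short sketch
of the key_mapping registry added above; "my_video_dataset" is a hypothetical
name used only for illustration:

    from armory.datasets import key_mapping

    key_mapping.register("my_video_dataset", {"video": "x", "label": "y"})
    key_map = key_mapping.get("my_video_dataset")  # deep copy of the registered dict

    # check_key_map raises on non-str keys/values or duplicate values
    key_mapping.check_key_map({"image": "x", "label": "x"})  # ValueError

    # load_dataset(..., key_map=key_map) applies it via set_key_map,
    # in which case use_supervised_keys is ignored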