6 changes: 5 additions & 1 deletion armory/datasets/art_wrapper.py
@@ -1,3 +1,5 @@
from armory.datasets import generator

from art.data_generators import DataGenerator


@@ -6,7 +8,9 @@ class WrappedDataGenerator(DataGenerator):
Wrap an ArmoryDataGenerator in the ART interface
"""

def __init__(self, gen):
def __init__(self, gen: generator.ArmoryDataGenerator):
if gen.output_as_dict or len(gen.output_tuple) != 2:
raise ValueError("gen must output (x, y) tuples")
super().__init__(gen.size, gen.batch_size)
self._iterator = gen

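Note on usage: with the new type hint and the (x, y) check, an ArmoryDataGenerator must be switched to tuple output before wrapping. A minimal sketch of the intended call order, assuming a generator `gen` already constructed by this package (variable names are illustrative, not part of this diff):

    from armory.datasets import generator

    # gen defaults to dict output (output_as_dict=True), which the wrapper
    # rejects; as_supervised() sets key_map from info.supervised_keys and
    # output_tuple to ("x", "y")
    gen.as_supervised()
    art_gen = generator.wrap_generator(gen)  # ValueError if gen still yields dicts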
43 changes: 33 additions & 10 deletions armory/datasets/config_load.py
@@ -16,10 +16,12 @@ def load_dataset(
preprocessor_name="DEFAULT",
preprocessor_kwargs=None,
shuffle_files=False,
label_key="label", # TODO: make this smarter or more flexible
label_key=None,
index=None,
class_ids=None,
drop_remainder=False,
key_map: dict = None,
use_supervised_keys: bool = True,
):
# All are keyword arguments by design
if name is None:
@@ -35,6 +37,21 @@
raise ValueError(
f"class_ids must be a list, int, or None, not {type(class_ids)}"
)
if label_key is None:
if info.supervised_keys is None:
raise ValueError(
"label_key is None and info.supervised_keys is None."
" What label is being filtered on?"
)
elif len(info.supervised_keys) != 2 or not all(
isinstance(k, str) for k in info.supervised_keys
):
raise NotImplementedError(
f"supervised_keys {info.supervised_keys} is not a 2-tuple of str."
" Please specify label_key for filtering."
)
_, label_key = info.supervised_keys

element_filter = filtering.get_filter_by_class(class_ids, label_key=label_key)
if index is None:
index_filter = None
@@ -56,7 +73,9 @@
preprocessor = preprocessing.get(preprocessor_name)

if preprocessor is not None and preprocessor_kwargs is not None:
preprocessing_fn = lambda x: preprocessor(x, **preprocessor_kwargs)
preprocessing_fn = lambda x: preprocessor( # noqa: E731
x, **preprocessor_kwargs
)
else:
preprocessing_fn = preprocessor

@@ -67,20 +86,24 @@
ds_dict,
split=split,
batch_size=batch_size,
framework=framework,
num_batches=num_batches,
epochs=epochs,
drop_remainder=drop_remainder,
num_batches=num_batches,
index_filter=index_filter,
element_filter=element_filter,
element_map=preprocessing_fn,
shuffle_elements=shuffle_elements,
key_map=None,
framework=framework,
)
return wrap_generator(armory_data_generator)

# If key_map is not None, use_supervised_keys is ignored
if key_map is not None:
# ignore use_supervised_keys in this case
armory_data_generator.set_key_map(key_map)
else:
# raises if use_supervised_keys is True but info has no supervised_keys
armory_data_generator.set_key_map(use_supervised_keys=use_supervised_keys)

def wrap_generator(armory_data_generator):
from armory.datasets import art_wrapper
# Let the scenario set the desired tuple directly
# armory_data_generator.as_tuple() # NOTE: This will currently fail for adversarial datasets

return art_wrapper.WrappedDataGenerator(armory_data_generator)
return armory_data_generator
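An illustrative call of the updated entry point, assuming the elided parameters of load_dataset (name, split, batch_size, framework) keep their existing behavior; "mnist" is used only because a key_map for it is registered in key_mapping.py:

    from armory.datasets.config_load import load_dataset

    gen = load_dataset(
        name="mnist",
        split="test",
        batch_size=16,
        class_ids=[0, 1],  # label_key is resolved from info.supervised_keys
        key_map={"image": "x", "label": "y"},  # when set, use_supervised_keys is ignored
    )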
119 changes: 109 additions & 10 deletions armory/datasets/generator.py
@@ -12,9 +12,13 @@

"""

import math
from typing import Tuple

import tensorflow as tf
import tensorflow_datasets as tfds
import math

from armory.datasets import key_mapping


class ArmoryDataGenerator:
@@ -49,15 +53,14 @@ def __init__(
ds_dict: dict,
split: str = "test",
batch_size: int = 1,
drop_remainder: bool = False,
epochs: int = 1,
num_batches: int = None,
framework: str = "numpy",
shuffle_elements: bool = False,
epochs: int = 1,
drop_remainder: bool = False,
index_filter: callable = None,
element_filter: callable = None,
element_map: callable = None,
key_map=None,
shuffle_elements: bool = False,
framework: str = "numpy",
):
if split not in info.splits:
raise ValueError(f"split {split} not in info.splits {list(info.splits)}")
@@ -73,9 +76,6 @@ def __init__(
)
if framework not in self.FRAMEWORKS:
raise ValueError(f"framework {framework} not in {self.FRAMEWORKS}")
if key_map is not None:
# TODO: key mapping from dict to tuples, etc.
raise NotImplementedError("key_map argument")

size = info.splits[split].num_examples
batch_size = int(batch_size)
@@ -129,6 +129,7 @@ def __init__(
raise NotImplementedError(f"framework {framework}")

self._set_params(
info=info,
iterator=iterator,
split=split,
size=size,
@@ -141,17 +142,115 @@
shuffle_elements=shuffle_elements,
element_filter=element_filter,
element_map=element_map,
key_map=None,
output_as_dict=True,
output_tuple=("x", "y"),
)

def _set_params(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)

def set_key_map(self, key_map: dict = None, use_supervised_keys: bool = False):
"""
key_map - dict that maps keys of the element dict to new keys
E.g., {"image": "x", "label": "y"}

If use_supervised_keys is True, the map is derived from info.supervised_keys

Element keys that are not present in the map are omitted from the output
"""
if key_map is not None and use_supervised_keys:
raise ValueError("Cannot set both key_map and use_supervised_keys")
elif key_map is not None:
key_mapping.check_key_map(key_map)
elif use_supervised_keys:
supervised_keys = self.info.supervised_keys
if supervised_keys is None:
raise ValueError("supervised_keys are None for current dataset info")
elif len(supervised_keys) != 2 or not all(
isinstance(k, str) for k in supervised_keys
):
# NOTE: supervised_keys can be more exotic, though it is rare
# See the SupervisedKeysType in tfds.core.DatasetInfo
# 3-tuples and nested structures are allowed
raise NotImplementedError(
f"supervised_keys {supervised_keys} is not a 2-tuple of str"
)
x, y = supervised_keys
key_map = {x: "x", y: "y"}
else: # key_map is None
pass

self.key_map = key_map
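A short sketch of the three branches, assuming a generator `gen` over a dataset whose info.supervised_keys is ("image", "label"):

    gen.set_key_map({"image": "x", "label": "y"})  # explicit map, validated by key_mapping.check_key_map
    gen.set_key_map(use_supervised_keys=True)      # derives the same map from info.supervised_keys
    gen.set_key_map()                              # key_map stays None; elements pass through unrenamed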

def set_output_tuple(self, output_tuple: Tuple[str]):
"""
output_tuple - tuple of str element keys, e.g. ("x", "y")
When output_as_dict is False (see as_tuple), batches are emitted as
tuples of the element values in this key order
"""
if isinstance(output_tuple, str): # prevent "word" -> ("w", "o", "r", "d")
raise ValueError("output_tuple must not be a str")
output_tuple = tuple(output_tuple)
for k in output_tuple:
if not isinstance(k, str):
if isinstance(k, tuple):
# NOTE: nested tuples would enable things like:
# (("x", "x_adv"), ("y", "y_patch_metadata"))
raise NotImplementedError("nested tuples not currently supported")
raise ValueError(f"item {k} in output_tuple is not a str")
self.output_tuple = output_tuple
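The guards above are worth illustrating, since strings are themselves iterables of str (hypothetical calls on a generator `gen`):

    gen.set_output_tuple(("x", "y"))    # ok
    gen.set_output_tuple(["x", "y"])    # ok: any iterable of str is coerced to a tuple
    gen.set_output_tuple("xy")          # ValueError: would otherwise silently become ("x", "y")
    gen.set_output_tuple((("x", "x_adv"), "y"))  # NotImplementedError: nested tuples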

def as_dict(self):
"""
Sets the return type to dict
"""
self.output_as_dict = True

def as_tuple(self, output_tuple: Tuple[str] = None):
"""
Sets the return type to tuple, according to given output_tuple

If output_tuple is None, it defaults to the existing output_tuple
The default output_tuple at initialization is ("x", "y")
"""
if output_tuple is not None:
self.set_output_tuple(output_tuple)
self.output_as_dict = False

def as_supervised(self):
"""
Convenience function similar to 'as_supervised' in tfds.core.DatasetBuilder
sets key_map and as_tuple to output supervised tuples
"""
self.set_key_map(use_supervised_keys=True)
self.as_tuple(output_tuple=("x", "y"))

def __iter__(self):
return self

def __next__(self):
return next(self.iterator)
element = next(self.iterator)
if self.key_map is not None:
element = {new_k: element[k] for k, new_k in self.key_map.items()}
if not self.output_as_dict:
element = tuple(element[k] for k in self.output_tuple)
return element

def __len__(self):
return self.batches_per_epoch * self.epochs


def wrap_generator(armory_data_generator):
"""
Wrap an ArmoryDataGenerator as an art DataGenerator
"""
from armory.datasets import art_wrapper

return art_wrapper.WrappedDataGenerator(armory_data_generator)
45 changes: 45 additions & 0 deletions armory/datasets/key_mapping.py
@@ -0,0 +1,45 @@
import copy

REGISTERED_KEY_MAPS = {}
DEFAULT = "DEFAULT"


def register(name: str, key_map: dict):
if not isinstance(name, str):
raise ValueError(f"name {name} is not a str")
check_key_map(key_map)
global REGISTERED_KEY_MAPS
REGISTERED_KEY_MAPS[name] = key_map


def list_registered():
return list(REGISTERED_KEY_MAPS)


def get(name):
if name not in REGISTERED_KEY_MAPS:
raise KeyError(f"key_map {name} not registered. Use one of {list_registered()}")
# dicts are malleable, so return a copy
return copy.deepcopy(REGISTERED_KEY_MAPS[name])


def has(name):
return name in REGISTERED_KEY_MAPS


def check_key_map(key_map: dict):
if not isinstance(key_map, dict):
raise ValueError(f"key_map {key_map} must be None or a dict")
for k, v in key_map.items():
if not isinstance(k, str):
raise ValueError(f"key {k} in key_map is not a str")
if not isinstance(v, str):
raise ValueError(f"value {v} in key_map is not a str")
if len(key_map.values()) != len(set(key_map.values())):
raise ValueError("key_map values must be unique")


for name in "mnist", "cifar10":
register(name, {"image": "x", "label": "y"})

register("digit", {"audio": "x", "label": "y"})
6 changes: 5 additions & 1 deletion armory/datasets/load.py
@@ -76,7 +76,11 @@ def load(
)

ensure_download_extract(
name, version=version, verify=verify, overwrite=overwrite
name,
version=version,
verify=verify,
overwrite=overwrite,
public=public,
)
elif name in common.armory_builders() or name in common.tfds_builders():
raise ValueError(