Commit 2b62537
AntonioCarta committed Feb 27, 2024
1 parent 4bace17 commit 2b62537
Showing 25 changed files with 192 additions and 116 deletions.
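Every hunk below applies the same mechanical reflow: `...`-bodied `@overload` stubs and over-long annotated assignments change from a one-line layout to a multi-line one (presumably to match the code-formatter version used by the project; that motive is an assumption, not stated in the commit). Both layouts are equivalent at runtime and for type checkers. A minimal sketch of the two forms, with invented names:

    from typing import Union, overload

    # The "-" form in the hunks: ellipsis body on the signature line.
    @overload
    def first(x: list) -> object: ...

    # The "+" form: ellipsis body indented on its own line.
    @overload
    def first(x: tuple) -> object:
        ...

    def first(x: Union[list, tuple]) -> object:
        # Runtime implementation; the @overload stubs above never execute.
        return x[0]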
@@ -257,10 +257,12 @@ def __len__(self) -> int:
         return len(self._benchmark.streams[self._stream])
 
     @overload
-    def __getitem__(self, exp_id: int) -> Optional[Set[int]]: ...
+    def __getitem__(self, exp_id: int) -> Optional[Set[int]]:
+        ...
 
     @overload
-    def __getitem__(self, exp_id: slice) -> Tuple[Optional[Set[int]], ...]: ...
+    def __getitem__(self, exp_id: slice) -> Tuple[Optional[Set[int]], ...]:
+        ...
 
     def __getitem__(self, exp_id: Union[int, slice]) -> LazyClassesInExpsRet:
         indexing_collate = _LazyClassesInClassificationExps._slice_collate
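This int/slice overload pair is a recurring pattern in the hunks that follow: `__getitem__` narrows its return type by index type, while one runtime body dispatches with `isinstance`. A minimal self-contained sketch (the class and its collate behaviour are invented stand-ins, not the Avalanche API):

    from typing import Optional, Sequence, Set, Tuple, Union, overload

    class LazyClassesSeq:
        # Hypothetical stand-in for the lazy per-experience class sets.
        def __init__(self, class_sets: Sequence[Optional[Set[int]]]) -> None:
            self._class_sets = list(class_sets)

        def __len__(self) -> int:
            return len(self._class_sets)

        @overload
        def __getitem__(self, exp_id: int) -> Optional[Set[int]]:
            ...

        @overload
        def __getitem__(self, exp_id: slice) -> Tuple[Optional[Set[int]], ...]:
            ...

        def __getitem__(
            self, exp_id: Union[int, slice]
        ) -> Union[Optional[Set[int]], Tuple[Optional[Set[int]], ...]]:
            if isinstance(exp_id, slice):
                # Slices collate the selected items into a tuple, mirroring
                # the role of the _slice_collate helper above.
                return tuple(self._class_sets[exp_id])
            return self._class_sets[exp_id]

    seq = LazyClassesSeq([{0, 1}, {2, 3}, None])
    assert seq[0] == {0, 1}           # int index -> one (optional) set
    assert seq[1:] == ({2, 3}, None)  # slice -> tuple of sets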
18 changes: 9 additions & 9 deletions avalanche/benchmarks/scenarios/deprecated/dataset_scenario.py
@@ -184,17 +184,17 @@ def __init__(
         invoking the super constructor) to specialize the experience class.
         """
 
-        self.experience_factory: Callable[[TCLStream, int], TDatasetExperience] = (
-            experience_factory
-        )
+        self.experience_factory: Callable[
+            [TCLStream, int], TDatasetExperience
+        ] = experience_factory
 
-        self.stream_factory: Callable[[str, TDatasetScenario], TCLStream] = (
-            stream_factory
-        )
+        self.stream_factory: Callable[
+            [str, TDatasetScenario], TCLStream
+        ] = stream_factory
 
-        self.stream_definitions: Dict[str, StreamDef[TCLDataset]] = (
-            DatasetScenario._check_stream_definitions(stream_definitions)
-        )
+        self.stream_definitions: Dict[
+            str, StreamDef[TCLDataset]
+        ] = DatasetScenario._check_stream_definitions(stream_definitions)
         """
         A structure containing the definition of the streams.
         """
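All three rewrites in this hunk break the same construct differently: an attribute assignment whose type annotation makes the line overflow. A hedged sketch of the two wrappings, with placeholder names and types:

    from typing import Callable

    # "-" layout: annotation on one line, value wrapped in parentheses.
    handler_a: Callable[[str, int], bool] = (
        lambda name, count: count > 0
    )

    # "+" layout: the subscripted annotation is split across lines instead.
    handler_b: Callable[
        [str, int], bool
    ] = lambda name, count: count > 0

    assert handler_a("x", 1) and handler_b("x", 1)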
@@ -143,9 +143,9 @@ def create_multi_dataset_generic_benchmark(
         "complete_test_set_only is True"
     )
 
-    stream_definitions: Dict[str, Tuple[Iterable[TaskAwareClassificationDataset]]] = (
-        dict()
-    )
+    stream_definitions: Dict[
+        str, Tuple[Iterable[TaskAwareClassificationDataset]]
+    ] = dict()
 
     for stream_name, dataset_list in input_streams.items():
         initial_transform_group = "train"
18 changes: 10 additions & 8 deletions avalanche/benchmarks/scenarios/deprecated/lazy_dataset_sequence.py
@@ -99,9 +99,9 @@ def __init__(
         now, including the ones of dropped experiences.
         """
 
-        self.task_labels_field_sequence: Dict[int, Optional[Sequence[int]]] = (
-            defaultdict(lambda: None)
-        )
+        self.task_labels_field_sequence: Dict[
+            int, Optional[Sequence[int]]
+        ] = defaultdict(lambda: None)
         """
         A dictionary mapping each experience to its `targets_task_labels` field.
@@ -118,10 +118,12 @@ def __len__(self) -> int:
         return self._stream_length
 
     @overload
-    def __getitem__(self, exp_idx: int) -> TCLDataset: ...
+    def __getitem__(self, exp_idx: int) -> TCLDataset:
+        ...
 
     @overload
-    def __getitem__(self, exp_idx: slice) -> Sequence[TCLDataset]: ...
+    def __getitem__(self, exp_idx: slice) -> Sequence[TCLDataset]:
+        ...
 
     def __getitem__(
         self, exp_idx: Union[int, slice]
@@ -133,9 +135,9 @@ def __getitem__(
         :return: The dataset associated to the experience.
         """
         # A lot of unuseful lines needed for MyPy -_-
-        indexing_collate: Callable[[Iterable[TCLDataset]], Sequence[TCLDataset]] = (
-            lambda x: list(x)
-        )
+        indexing_collate: Callable[
+            [Iterable[TCLDataset]], Sequence[TCLDataset]
+        ] = lambda x: list(x)
         result = manage_advanced_indexing(
             exp_idx,
             self._get_experience_and_load_if_needed,
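The `defaultdict(lambda: None)` above makes the mapping total: task-label lookups for experiences that were never loaded return None rather than raising KeyError. A tiny reproduction (variable name shortened):

    from collections import defaultdict
    from typing import Dict, Optional, Sequence

    # Unseen experience IDs resolve to None instead of raising KeyError.
    task_labels: Dict[int, Optional[Sequence[int]]] = defaultdict(lambda: None)
    task_labels[0] = [0, 0, 1]
    assert task_labels[0] == [0, 0, 1]
    assert task_labels[7] is None  # default factory supplies None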
@@ -327,9 +327,9 @@ class "34" will be mapped to "1", class "11" to "2" and so on.
                 # used, the user may have defined an amount of classes less than
                 # the overall amount of classes in the dataset.
                 if class_id in self.classes_order_original_ids:
-                    self.class_mapping[class_id] = (
-                        self.classes_order_original_ids.index(class_id)
-                    )
+                    self.class_mapping[
+                        class_id
+                    ] = self.classes_order_original_ids.index(class_id)
         elif self.class_ids_from_zero_in_each_exp:
             # Method 2: remap class IDs so that they appear in range [0, N] in
             # each experience
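The remapping above replaces each original class ID with its position in the user-chosen class order. A minimal reproduction of that indexing step, using zero-based positions (a sketch with an invented order, not the exact Avalanche semantics):

    # Hypothetical fixed order chosen by the user (Method 1 above):
    classes_order_original_ids = [34, 11, 7]

    # Each original class ID maps to its position in that order.
    class_mapping = {
        class_id: classes_order_original_ids.index(class_id)
        for class_id in classes_order_original_ids
    }
    assert class_mapping == {34: 0, 11: 1, 7: 2}

Calling `.index()` per class is quadratic in the number of classes, which is harmless at typical class counts.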
6 changes: 4 additions & 2 deletions avalanche/benchmarks/scenarios/detection_scenario.py
@@ -254,10 +254,12 @@ def __len__(self):
         return len(self._benchmark.streams[self._stream])
 
     @overload
-    def __getitem__(self, exp_id: int) -> Optional[Set[int]]: ...
+    def __getitem__(self, exp_id: int) -> Optional[Set[int]]:
+        ...
 
     @overload
-    def __getitem__(self, exp_id: slice) -> Tuple[Optional[Set[int]], ...]: ...
+    def __getitem__(self, exp_id: slice) -> Tuple[Optional[Set[int]], ...]:
+        ...
 
     def __getitem__(self, exp_id: Union[int, slice]) -> LazyClassesInExpsRet:
         indexing_collate = _LazyClassesInDetectionExps._slice_collate
6 changes: 4 additions & 2 deletions avalanche/benchmarks/scenarios/generic_scenario.py
@@ -427,10 +427,12 @@ def __iter__(self) -> Iterator[TCLExperience]:
             yield exp
 
     @overload
-    def __getitem__(self, item: int) -> TCLExperience: ...
+    def __getitem__(self, item: int) -> TCLExperience:
+        ...
 
     @overload
-    def __getitem__(self: TSequenceCLStream, item: slice) -> TSequenceCLStream: ...
+    def __getitem__(self: TSequenceCLStream, item: slice) -> TSequenceCLStream:
+        ...
 
     @final
     def __getitem__(
5 changes: 3 additions & 2 deletions avalanche/benchmarks/scenarios/supervised.py
@@ -418,8 +418,9 @@ def _decorate_stream(obj: CLStream):
         "Unsupported object type: must be one of {CLScenario, CLStream}"
     )
 
+
 __all__ = [
     "class_incremental_benchmark",
     "new_instances_benchmark",
-    "with_classes_timeline"
-]
+    "with_classes_timeline",
+]
35 changes: 22 additions & 13 deletions avalanche/benchmarks/utils/classification_dataset.py
@@ -175,7 +175,8 @@ def _make_taskaware_classification_dataset(
     task_labels: Optional[Union[int, Sequence[int]]] = None,
     targets: Optional[Sequence[TTargetType]] = None,
     collate_fn: Optional[Callable[[List], Any]] = None
-) -> TaskAwareSupervisedClassificationDataset: ...
+) -> TaskAwareSupervisedClassificationDataset:
+    ...
 
 
 @overload
@@ -189,7 +190,8 @@ def _make_taskaware_classification_dataset(
     task_labels: Union[int, Sequence[int]],
     targets: Sequence[TTargetType],
     collate_fn: Optional[Callable[[List], Any]] = None
-) -> TaskAwareSupervisedClassificationDataset: ...
+) -> TaskAwareSupervisedClassificationDataset:
+    ...
 
 
 @overload
@@ -203,7 +205,8 @@ def _make_taskaware_classification_dataset(
     task_labels: Optional[Union[int, Sequence[int]]] = None,
     targets: Optional[Sequence[TTargetType]] = None,
     collate_fn: Optional[Callable[[List], Any]] = None
-) -> TaskAwareClassificationDataset: ...
+) -> TaskAwareClassificationDataset:
+    ...
 
 
 def _make_taskaware_classification_dataset(
@@ -383,7 +386,8 @@ def _taskaware_classification_subset(
     task_labels: Optional[Union[int, Sequence[int]]] = None,
     targets: Optional[Sequence[TTargetType]] = None,
     collate_fn: Optional[Callable[[List], Any]] = None
-) -> TaskAwareSupervisedClassificationDataset: ...
+) -> TaskAwareSupervisedClassificationDataset:
+    ...
 
 
 @overload
@@ -399,7 +403,8 @@ def _taskaware_classification_subset(
     task_labels: Union[int, Sequence[int]],
    targets: Sequence[TTargetType],
     collate_fn: Optional[Callable[[List], Any]] = None
-) -> TaskAwareSupervisedClassificationDataset: ...
+) -> TaskAwareSupervisedClassificationDataset:
+    ...
 
 
 @overload
@@ -415,7 +420,8 @@ def _taskaware_classification_subset(
     task_labels: Optional[Union[int, Sequence[int]]] = None,
     targets: Optional[Sequence[TTargetType]] = None,
     collate_fn: Optional[Callable[[List], Any]] = None
-) -> TaskAwareClassificationDataset: ...
+) -> TaskAwareClassificationDataset:
+    ...
 
 
 def _taskaware_classification_subset(
@@ -613,7 +619,8 @@ def _make_taskaware_tensor_classification_dataset(
     task_labels: Union[int, Sequence[int]],
     targets: Union[Sequence[TTargetType], int],
     collate_fn: Optional[Callable[[List], Any]] = None
-) -> TaskAwareSupervisedClassificationDataset: ...
+) -> TaskAwareSupervisedClassificationDataset:
+    ...
 
 
 @overload
@@ -626,9 +633,8 @@ def _make_taskaware_tensor_classification_dataset(
     task_labels: Optional[Union[int, Sequence[int]]] = None,
     targets: Optional[Union[Sequence[TTargetType], int]] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
-) -> Union[
-    TaskAwareClassificationDataset, TaskAwareSupervisedClassificationDataset
-]: ...
+) -> Union[TaskAwareClassificationDataset, TaskAwareSupervisedClassificationDataset]:
+    ...
 
 
 def _make_taskaware_tensor_classification_dataset(
@@ -753,7 +759,8 @@ def _concat_taskaware_classification_datasets(
         Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]]
     ] = None,
     collate_fn: Optional[Callable[[List], Any]] = None
-) -> TaskAwareSupervisedClassificationDataset: ...
+) -> TaskAwareSupervisedClassificationDataset:
+    ...
 
 
 @overload
@@ -767,7 +774,8 @@ def _concat_taskaware_classification_datasets(
     task_labels: Union[int, Sequence[int], Sequence[Sequence[int]]],
     targets: Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]],
     collate_fn: Optional[Callable[[List], Any]] = None
-) -> TaskAwareSupervisedClassificationDataset: ...
+) -> TaskAwareSupervisedClassificationDataset:
+    ...
 
 
 @overload
@@ -783,7 +791,8 @@ def _concat_taskaware_classification_datasets(
     targets: Optional[
         Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]]
     ] = None,
     collate_fn: Optional[Callable[[List], Any]] = None
-) -> TaskAwareClassificationDataset: ...
+) -> TaskAwareClassificationDataset:
+    ...
 
 
 def _concat_taskaware_classification_datasets(
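The overloads in this file all encode the same typing idea: the declared return type is the supervised dataset class only when `targets`/`task_labels` are certainly provided, and the plain task-aware class otherwise, while a single runtime body serves every signature. A reduced sketch with hypothetical names:

    from typing import List, Optional, Sequence, Union, overload

    class ClassificationDataset:
        """Hypothetical stand-in for TaskAwareClassificationDataset."""

    class SupervisedClassificationDataset(ClassificationDataset):
        """Hypothetical stand-in for the supervised variant."""

    @overload
    def make_dataset(
        data: List[int], targets: Sequence[int]
    ) -> SupervisedClassificationDataset:
        ...

    @overload
    def make_dataset(
        data: List[int], targets: None = None
    ) -> ClassificationDataset:
        ...

    def make_dataset(
        data: List[int], targets: Optional[Sequence[int]] = None
    ) -> Union[ClassificationDataset, SupervisedClassificationDataset]:
        # One runtime body behind both typed signatures.
        if targets is not None:
            return SupervisedClassificationDataset()
        return ClassificationDataset()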
6 changes: 4 additions & 2 deletions avalanche/benchmarks/utils/data.py
@@ -344,10 +344,12 @@ def __eq__(self, other: object):
         )
 
     @overload
-    def __getitem__(self, exp_id: int) -> T_co: ...
+    def __getitem__(self, exp_id: int) -> T_co:
+        ...
 
     @overload
-    def __getitem__(self: TAvalancheDataset, exp_id: slice) -> TAvalancheDataset: ...
+    def __getitem__(self: TAvalancheDataset, exp_id: slice) -> TAvalancheDataset:
+        ...
 
     def __getitem__(
         self: TAvalancheDataset, idx: Union[int, slice]
6 changes: 4 additions & 2 deletions avalanche/benchmarks/utils/data_attribute.py
@@ -62,10 +62,12 @@ def __iter__(self):
            yield self[i]
 
     @overload
-    def __getitem__(self, item: int) -> T_co: ...
+    def __getitem__(self, item: int) -> T_co:
+        ...
 
     @overload
-    def __getitem__(self, item: slice) -> Sequence[T_co]: ...
+    def __getitem__(self, item: slice) -> Sequence[T_co]:
+        ...
 
     def __getitem__(self, item: Union[int, slice]) -> Union[T_co, Sequence[T_co]]:
         return self.data[item]
6 changes: 4 additions & 2 deletions avalanche/benchmarks/utils/dataset_definitions.py
@@ -38,9 +38,11 @@ class IDataset(Protocol[T_co]):
     Note: no __add__ method is defined.
     """
 
-    def __getitem__(self, index: int) -> T_co: ...
+    def __getitem__(self, index: int) -> T_co:
+        ...
 
-    def __len__(self) -> int: ...
+    def __len__(self) -> int:
+        ...
 
 
 class IDatasetWithTargets(IDataset[T_co], Protocol[T_co, TTargetType_co]):
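`IDataset` is a structural protocol: any object exposing a compatible `__getitem__` and `__len__` satisfies it, with no inheritance required. A minimal sketch (`PairDataset` and `size` are invented for illustration):

    from typing import Protocol, Tuple, TypeVar

    T_co = TypeVar("T_co", covariant=True)

    class IDataset(Protocol[T_co]):
        def __getitem__(self, index: int) -> T_co:
            ...

        def __len__(self) -> int:
            ...

    class PairDataset:
        # No inheritance from IDataset: structural typing makes it compatible.
        def __init__(self) -> None:
            self._rows = [(0, "a"), (1, "b")]

        def __getitem__(self, index: int) -> Tuple[int, str]:
            return self._rows[index]

        def __len__(self) -> int:
            return len(self._rows)

    def size(ds: IDataset[Tuple[int, str]]) -> int:
        return len(ds)

    assert size(PairDataset()) == 2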
6 changes: 4 additions & 2 deletions avalanche/benchmarks/utils/dataset_utils.py
@@ -64,10 +64,12 @@ def __iter__(self) -> Iterator[TData]:
             yield el
 
     @overload
-    def __getitem__(self, item: int) -> TData: ...
+    def __getitem__(self, item: int) -> TData:
+        ...
 
     @overload
-    def __getitem__(self: TSliceSequence, item: slice) -> TSliceSequence: ...
+    def __getitem__(self: TSliceSequence, item: slice) -> TSliceSequence:
+        ...
 
     @final
     def __getitem__(