Commit
Fix style
Hamed committed Oct 27, 2023
1 parent f531eb6 commit dacee46
Showing 3 changed files with 67 additions and 61 deletions.
118 changes: 63 additions & 55 deletions avalanche/benchmarks/scenarios/online.py
@@ -54,8 +54,7 @@ def __init__(self, n_samples, shuffle=True, rng=None):
def _reset_indices(self):
self.indices = torch.arange(self.n_samples).tolist()
if self.shuffle:
self.indices = torch.randperm(self.n_samples,
generator=self.rng).tolist()
self.indices = torch.randperm(self.n_samples, generator=self.rng).tolist()

def __iter__(self):
while True:
@@ -70,7 +69,7 @@ def __len__(self):
class BoundaryAware(Protocol):
"""Boundary-aware experiences have attributes with task boundary knowledge.
Online streams may have boundary attributes to help training or
metrics logging.
Task boundaries denote changes of the underlying data distribution used
@@ -89,7 +88,7 @@ def is_last_subexp(self) -> bool:

@property
def sub_stream_length(self) -> int:
"""Number of experiences with the same distribution of the current
"""Number of experiences with the same distribution of the current
experience."""
return 0

@@ -98,7 +97,7 @@ def access_task_boundaries(self) -> bool:
"""True if the model has access to task boundaries.
If the model is boundary-agnostic, task boundaries are available only
for logging by setting the experience in logging mode
`experience.logging()`.
"""
return False
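
Aside (not part of this commit): the boundary attributes documented in this protocol are meant to be read by training or logging code. A minimal sketch of how a boundary-aware loop might consume them, using only the attribute names shown in this file:

```python
# Sketch only: the attribute names (access_task_boundaries, is_first_subexp,
# is_last_subexp) come from the protocol and constructor in this file;
# the surrounding function is illustrative, not Avalanche API.
def handle_boundaries(experience):
    if getattr(experience, "access_task_boundaries", False):
        if experience.is_first_subexp:
            print("first sub-experience of a new distribution")
        if experience.is_last_subexp:
            print("last sub-experience of the current distribution")
```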
@@ -121,14 +120,14 @@ def __init__(
sub_stream_length: Optional[int] = None,
access_task_boundaries: bool = False,
):
"""A class representing a continual learning experience in an online
"""A class representing a continual learning experience in an online
setting.
:param current_experience: The index of the current experience.
:type current_experience: int
:param dataset: The dataset containing the experience.
:type dataset: TCLDataset
:param origin_experience: The original experience from which this
experience was derived.
:type origin_experience: DatasetExperience
:param is_first_subexp: Whether this is the first sub-experience.
@@ -169,13 +168,13 @@ def __init__(
drop_last: bool = False,
access_task_boundaries: bool = False,
) -> None:
"""Returns a lazy stream generated by splitting an experience into
"""Returns a lazy stream generated by splitting an experience into
smaller ones.
Splits the experience into smaller experiences of size `experience_size`.
Experience decorators (e.g. class attributes) will be stripped from the
experience. You will need to re-apply them to the resulting experiences
if you need them.
:param experience: The experience to split.
@@ -226,8 +225,7 @@ def __iter__(self) -> Generator[OnlineCLExperience, None, None]:
is_last = True

# check is_last when drop_last=True
if self.drop_last and \
(final_idx + self.experience_size > len(exp_indices)):
if self.drop_last and (final_idx + self.experience_size > len(exp_indices)):
is_last = True

sub_exp_subset = exp_dataset.subset(exp_indices[init_idx:final_idx])
@@ -295,7 +293,7 @@ def split_online_stream(
`experience_size` instances, then the last experience will be dropped.
Defaults to False. Ignored if `experience_split_strategy` is used.
:param experience_split_strategy: A function that implements a custom
splitting strategy. The function must accept an experience and return
an experience's iterator. Defaults to None, which means
that the standard splitting strategy will be used (which creates
experiences of size `experience_size`).
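
Aside (not part of this commit): a minimal usage sketch of `split_online_stream` based on the parameters documented above; the benchmark used here is only an example, and the exact keyword names should be checked against your Avalanche version.

```python
# Sketch: split a batch train stream into an online stream of 10-sample
# sub-experiences. SplitMNIST is used purely as a convenient example benchmark.
from avalanche.benchmarks.classic import SplitMNIST
from avalanche.benchmarks.scenarios.online import split_online_stream

bm = SplitMNIST(n_experiences=5)
online_stream = split_online_stream(
    bm.train_stream,
    experience_size=10,  # size of each sub-experience
    drop_last=False,     # keep the final, possibly smaller, sub-experience
)
for sub_exp in online_stream:
    assert len(sub_exp.dataset) <= 10
```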
@@ -332,7 +330,7 @@ def exps_iter():

def _fixed_size_split(
online_benchmark: "OnlineCLScenario", # TODO: Deprecated
# and unused. Remove.
experience_size: int,
access_task_boundaries: bool,
shuffle: bool,
@@ -348,6 +346,7 @@ def _fixed_size_split(

# ========== Continuous linear decay splits


def create_sub_exp_from_multi_exps(
original_stream: Iterable[DatasetExperience[TCLDataset]],
samplers: Iterable[CyclicSampler],
@@ -407,7 +406,7 @@ def split_continuous_linear_decay_stream(
beta: float,
shuffle: bool,
) -> CLStream[DatasetExperience[TCLDataset]]:
"""Creates a stream of sub-experiences from a list of overlapped
"""Creates a stream of sub-experiences from a list of overlapped
experiences with a linear decay in the overlapping areas.
:param experience_size: The size of each sub-experience.
@@ -421,41 +420,42 @@
:return: A stream of sub-experiences.
"""

def _get_linear_line(start, end, direction="up"):
if direction == "up":
return torch.FloatTensor([(i - start) / (end - start)
for i in range(start, end)])
return torch.FloatTensor([1 - ((i - start) / (end - start))
for i in range(start, end)])
return torch.FloatTensor(
[(i - start) / (end - start) for i in range(start, end)]
)
return torch.FloatTensor(
[1 - ((i - start) / (end - start)) for i in range(start, end)]
)

def _create_task_probs(iters, tasks, task_id, beta=3):

if beta <= 1:
peak_start = int((task_id / tasks) * iters)
peak_end = int(((task_id + 1) / tasks) * iters)
start = peak_start
end = peak_end
else:
start = max(int(((beta * task_id - 1) * iters) / (beta * tasks)),
0)
start = max(int(((beta * task_id - 1) * iters) / (beta * tasks)), 0)
peak_start = int(((beta * task_id + 1) * iters) / (beta * tasks))
peak_end = int(((beta * task_id + (beta - 1))
* iters) / (beta * tasks))
end = min(int(((beta * task_id + (beta + 1))
* iters) / (beta * tasks)), iters)
peak_end = int(((beta * task_id + (beta - 1)) * iters) / (beta * tasks))
end = min(
int(((beta * task_id + (beta + 1)) * iters) / (beta * tasks)), iters
)

probs = torch.zeros(iters, dtype=torch.float)
if task_id == 0:
probs[start:peak_start].add_(1)
else:
probs[start:peak_start] = _get_linear_line(
start, peak_start, direction="up")
start, peak_start, direction="up"
)
probs[peak_start:peak_end].add_(1)
if task_id == tasks - 1:
probs[peak_end:end].add_(1)
else:
probs[peak_end:end] = _get_linear_line(peak_end, end,
direction="down")
probs[peak_end:end] = _get_linear_line(peak_end, end, direction="down")
return probs

# Total number of iterations
@@ -465,26 +465,28 @@ def _create_task_probs(iters, tasks, task_id, beta=3):
n_experiences = len(original_stream)
tasks_probs_over_iterations = [
_create_task_probs(total_iters, n_experiences, exp_id, beta=beta)
for exp_id in range(n_experiences)]
for exp_id in range(n_experiences)
]

# Normalize probabilities
normalize_probs = torch.zeros_like(tasks_probs_over_iterations[0])
for probs in tasks_probs_over_iterations:
normalize_probs.add_(probs)
for probs in tasks_probs_over_iterations:
probs.div_(normalize_probs)
tasks_probs_over_iterations = torch.cat(
tasks_probs_over_iterations
).view(-1, tasks_probs_over_iterations[0].shape[0])
tasks_probs_over_iterations = torch.cat(tasks_probs_over_iterations).view(
-1, tasks_probs_over_iterations[0].shape[0]
)
tasks_probs_over_iterations_lst = []
for col in range(tasks_probs_over_iterations.shape[1]):
tasks_probs_over_iterations_lst.append(
tasks_probs_over_iterations[:, col])
tasks_probs_over_iterations_lst.append(tasks_probs_over_iterations[:, col])
tasks_probs_over_iterations = tasks_probs_over_iterations_lst

# Random cyclic samplers over the datasets of all experiences in the stream
samplers = [iter(CyclicSampler(len(exp.dataset), shuffle=shuffle))
for exp in original_stream]
samplers = [
iter(CyclicSampler(len(exp.dataset), shuffle=shuffle))
for exp in original_stream
]

# The main iterator for the stream
def exps_iter():
@@ -500,12 +502,14 @@ def exps_iter():
probs=tasks_probs_over_iterations[sub_exp_id]
).sample(n_samples)

yield create_sub_exp_from_multi_exps(original_stream,
samplers,
exp_per_sample_list,
total_iters,
is_first_sub_exp,
is_last_sub_exp)
yield create_sub_exp_from_multi_exps(
original_stream,
samplers,
exp_per_sample_list,
total_iters,
is_first_sub_exp,
is_last_sub_exp,
)

stream_name: str = getattr(original_stream, "name", "train")
return CLStream(
@@ -517,20 +521,20 @@ def exps_iter():

# ========== Online CL scenario


class OnlineCLScenario(CLScenario):
def __init__(
self,
original_streams: Iterable[CLStream[DatasetExperience[TCLDataset]]],
experiences: Optional[
Union[
DatasetExperience[TCLDataset],
Iterable[DatasetExperience[TCLDataset]]
DatasetExperience[TCLDataset], Iterable[DatasetExperience[TCLDataset]]
]
] = None,
experience_size: int = 10,
stream_split_strategy: Literal[
"fixed_size_split",
"continuous_linear_decay"] = "fixed_size_split",
"fixed_size_split", "continuous_linear_decay"
] = "fixed_size_split",
access_task_boundaries: bool = False,
shuffle: bool = True,
overlap_factor: int = 4,
@@ -559,8 +563,8 @@ def __init__(
instances in each experience. Defaults to True.
:param overlap_factor: The overlap factor between consecutive
experiences. Defaults to 4.
:param iters_per_virtual_epoch: The number of iterations per virtual
epoch for each experience
:param iters_per_virtual_epoch: The number of iterations per virtual epoch
for each experience
"""
warnings.warn(
"Deprecated. Use `split_online_stream` or similar methods to split"
@@ -569,16 +573,21 @@

if stream_split_strategy == "fixed_size_split":
split_strat = partial(
_fixed_size_split, self, experience_size,
access_task_boundaries, shuffle
_fixed_size_split,
self,
experience_size,
access_task_boundaries,
shuffle,
)
elif stream_split_strategy == "continuous_linear_decay":
assert access_task_boundaries is False

split_strat = partial(
split_online_stream, experience_size=experience_size,
split_online_stream,
experience_size=experience_size,
iters_per_virtual_epoch=iters_per_virtual_epoch,
beta=overlap_factor, shuffle=True
beta=overlap_factor,
shuffle=True,
)
else:
raise ValueError("Unknown experience split strategy")
@@ -596,8 +605,7 @@ def __init__(
streams: List[CLStream] = [online_train_stream]
for s in original_streams:
s_wrapped = wrap_stream(
new_name="original_" + s.name,
new_benchmark=self, wrapped_stream=s
new_name="original_" + s.name, new_benchmark=self, wrapped_stream=s
)

streams.append(s_wrapped)
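For reference, the sampling schedule that `_create_task_probs` builds in the hunks above can be summarised as follows (reconstructed from the code shown here; β is the `beta`/overlap factor, T the number of experiences, N the total number of iterations, and the β ≤ 1 branch degenerates to disjoint, non-overlapping windows):

```latex
% Unnormalised probability of drawing a sample from experience i at iteration t
% (beta > 1 case). The boundaries follow the integer arithmetic in the code:
%   start_i      = max((beta*i - 1) * N / (beta*T), 0)
%   peak_start_i = (beta*i + 1) * N / (beta*T)
%   peak_end_i   = (beta*i + beta - 1) * N / (beta*T)
%   end_i        = min((beta*i + beta + 1) * N / (beta*T), N)
p_i(t) =
\begin{cases}
  \dfrac{t - \mathrm{start}_i}{\mathrm{peak\_start}_i - \mathrm{start}_i}
    & \mathrm{start}_i \le t < \mathrm{peak\_start}_i \\
  1 & \mathrm{peak\_start}_i \le t < \mathrm{peak\_end}_i \\
  1 - \dfrac{t - \mathrm{peak\_end}_i}{\mathrm{end}_i - \mathrm{peak\_end}_i}
    & \mathrm{peak\_end}_i \le t < \mathrm{end}_i \\
  0 & \text{otherwise}
\end{cases}
```

The ramp-up is replaced by a constant 1 for the first experience and the ramp-down by a constant 1 for the last one; the per-iteration values are then normalised across experiences so that they sum to 1 at every iteration, which is exactly what the normalisation loop following `_create_task_probs` does.
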
6 changes: 2 additions & 4 deletions examples/online_continuous_linear_decay.py
@@ -48,9 +48,7 @@ def main(args):
interactive_logger = InteractiveLogger()

eval_plugin = EvaluationPlugin(
accuracy_metrics(
minibatch=True, epoch=True, experience=True, stream=True
),
accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
forgetting_metrics(experience=True),
loggers=[interactive_logger],
@@ -75,7 +73,7 @@ def main(args):
experience_size=10,
iters_per_virtual_epoch=100,
beta=0.5,
shuffle=True
shuffle=True,
)

# Train the strategy on the continuous stream
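Aside (not part of this diff): the training loop that the truncated example refers to follows the usual Avalanche pattern; the variable names below are illustrative, since the actual assignments are hidden by the diff truncation.

```python
# Sketch of the part cut off above: train on each sub-experience of the
# continuous stream, then evaluate on the original test stream.
# `online_train_stream`, `cl_strategy`, and `benchmark` are assumed names.
for online_exp in online_train_stream:   # stream returned by the split above
    cl_strategy.train(online_exp)
cl_strategy.eval(benchmark.test_stream)
```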
4 changes: 2 additions & 2 deletions tests/benchmarks/scenarios/test_online_scenario.py
@@ -5,7 +5,7 @@
from tests.unit_tests_utils import dummy_classification_datasets, get_fast_benchmark
from avalanche.benchmarks.scenarios.online import (
split_online_stream,
split_continuous_linear_decay_stream
split_continuous_linear_decay_stream,
)


@@ -86,7 +86,7 @@ def test_split_online_stream_continuous_linear_decay(self):
experience_size=10,
iters_per_virtual_epoch=100,
beta=0.5,
shuffle=True
shuffle=True,
)

expected_length = len(bm.train_stream) * iter_per_virt_epoch
