Skip to content

Commit 8f0e61f

Browse files
authored
Merge pull request #1646 from vlomonaco/master
task_labels experience attribute type fix: now it is a list, not a set
2 parents 0d6f715 + 12a6a0f commit 8f0e61f

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

avalanche/benchmarks/scenarios/task_aware.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,10 @@ def with_task_labels(obj):
9696

9797
def _add_task_labels(exp):
9898
tls = exp.dataset.targets_task_labels.uniques
99+
# tls is a set, we need to convert to list to call __getitem__
100+
tls = list(tls)
99101
if len(tls) == 1:
100-
# tls is a set. we need to convert to list to call __getitem__
101-
exp.task_label = list(tls)[0]
102+
exp.task_label = tls[0]
102103
exp.task_labels = tls
103104
return exp
104105

tests/benchmarks/scenarios/test_task_aware.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
class TestsTaskAware(unittest.TestCase):
1616
def test_taskaware(self):
17-
"""Common use case: add tas labels to class-incremental benchmark."""
17+
"""Common use case: add task labels to class-incremental benchmark."""
1818
n_classes, n_samples_per_class, n_features = 10, 3, 7
1919

2020
for _ in range(10000):
@@ -58,6 +58,7 @@ def test_taskaware(self):
5858
ci_train = bm_ci.train_stream
5959
for eid, exp in enumerate(bm_ti.train_stream):
6060
assert exp.task_label == eid
61+
assert isinstance(exp.task_labels, list)
6162
assert len(ci_train[eid].dataset) == len(exp.dataset)
6263

6364

0 commit comments

Comments
 (0)