@@ -84,9 +84,10 @@ class SlicingSpec:
84
84
# When is set to true, one of the slices is the whole dataset.
85
85
entire_dataset : bool = True
86
86
87
- # Used in classification tasks for slicing by classes. It is assumed that
88
- # classes are integers 0, 1, ... number of classes. When true one slice per
89
- # each class is generated.
87
+ # Used in classification tasks for slicing by classes. When true one slice per
88
+ # each class is generated. Classes can either be
89
+ # - integers 0, 1, ..., (for single label) or
90
+ # - an array of integers (for multi-label).
90
91
by_class : Union [bool , Iterable [int ], int ] = False
91
92
92
93
# if true, it generates 10 slices for percentiles of the loss - 0-10%, 10-20%,
@@ -238,8 +239,10 @@ class AttackInputData:
238
239
probs_train : Optional [np .ndarray ] = None
239
240
probs_test : Optional [np .ndarray ] = None
240
241
241
- # Contains ground-truth classes. Classes are assumed to be integers starting
242
- # from 0.
242
+ # Contains ground-truth classes. For single-label classification, classes are
243
+ # assumed to be integers starting from 0. For multi-label classification,
244
+ # label is assumed to be multi-hot, i.e., labels is a binary array of shape
245
+ # (num_examples, num_classes).
243
246
labels_train : Optional [np .ndarray ] = None
244
247
labels_test : Optional [np .ndarray ] = None
245
248
@@ -290,7 +293,9 @@ def num_classes(self):
290
293
raise ValueError (
291
294
'Can\' t identify the number of classes as no labels were provided. '
292
295
'Please set labels_train and labels_test' )
293
- return int (max (np .max (self .labels_train ), np .max (self .labels_test ))) + 1
296
+ if not self .multilabel_data :
297
+ return int (max (np .max (self .labels_train ), np .max (self .labels_test ))) + 1
298
+ return self .labels_train .shape [1 ]
294
299
295
300
@property
296
301
def logits_or_probs_train (self ):
@@ -586,6 +591,8 @@ def validate(self):
586
591
_is_array_two_dimensional (self .entropy_test , 'entropy_test' )
587
592
_is_array_two_dimensional (self .labels_train , 'labels_train' )
588
593
_is_array_two_dimensional (self .labels_test , 'labels_test' )
594
+ self .is_multihot_labels (self .labels_train , 'labels_train' )
595
+ self .is_multihot_labels (self .labels_test , 'labels_test' )
589
596
else :
590
597
_is_array_one_dimensional (self .loss_train , 'loss_train' )
591
598
_is_array_one_dimensional (self .loss_test , 'loss_test' )
0 commit comments