Skip to content

Commit e362f51

Browse files
shs037tensorflower-gardener
authored andcommitted
Supports slicing for multi-label data.
PiperOrigin-RevId: 523846333
1 parent d5e41e2 commit e362f51

File tree

3 files changed

+37
-12
lines changed

3 files changed

+37
-12
lines changed

tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,10 @@ class SlicingSpec:
8484
# When is set to true, one of the slices is the whole dataset.
8585
entire_dataset: bool = True
8686

87-
# Used in classification tasks for slicing by classes. It is assumed that
88-
# classes are integers 0, 1, ... number of classes. When true one slice per
89-
# each class is generated.
87+
# Used in classification tasks for slicing by classes. When true one slice per
88+
# each class is generated. Classes can either be
89+
# - integers 0, 1, ..., (for single label) or
90+
# - an array of integers (for multi-label).
9091
by_class: Union[bool, Iterable[int], int] = False
9192

9293
# if true, it generates 10 slices for percentiles of the loss - 0-10%, 10-20%,
@@ -238,8 +239,10 @@ class AttackInputData:
238239
probs_train: Optional[np.ndarray] = None
239240
probs_test: Optional[np.ndarray] = None
240241

241-
# Contains ground-truth classes. Classes are assumed to be integers starting
242-
# from 0.
242+
# Contains ground-truth classes. For single-label classification, classes are
243+
# assumed to be integers starting from 0. For multi-label classification,
244+
# label is assumed to be multi-hot, i.e., labels is a binary array of shape
245+
# (num_examples, num_classes).
243246
labels_train: Optional[np.ndarray] = None
244247
labels_test: Optional[np.ndarray] = None
245248

@@ -290,7 +293,9 @@ def num_classes(self):
290293
raise ValueError(
291294
'Can\'t identify the number of classes as no labels were provided. '
292295
'Please set labels_train and labels_test')
293-
return int(max(np.max(self.labels_train), np.max(self.labels_test))) + 1
296+
if not self.multilabel_data:
297+
return int(max(np.max(self.labels_train), np.max(self.labels_test))) + 1
298+
return self.labels_train.shape[1]
294299

295300
@property
296301
def logits_or_probs_train(self):
@@ -586,6 +591,8 @@ def validate(self):
586591
_is_array_two_dimensional(self.entropy_test, 'entropy_test')
587592
_is_array_two_dimensional(self.labels_train, 'labels_train')
588593
_is_array_two_dimensional(self.labels_test, 'labels_test')
594+
self.is_multihot_labels(self.labels_train, 'labels_train')
595+
self.is_multihot_labels(self.labels_test, 'labels_test')
589596
else:
590597
_is_array_one_dimensional(self.loss_train, 'loss_train')
591598
_is_array_one_dimensional(self.loss_test, 'loss_test')

tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,18 @@ def _slice_data_by_indices(data: AttackInputData, idx_train,
7272

7373

7474
def _slice_by_class(data: AttackInputData, class_value: int) -> AttackInputData:
75-
if data.is_multilabel_data():
76-
raise ValueError("Slicing by class not supported for multilabel data.")
77-
idx_train = data.labels_train == class_value
78-
idx_test = data.labels_test == class_value
75+
"""Gets the indices (boolean) for examples belonging to the given class."""
76+
if not data.is_multilabel_data():
77+
idx_train = data.labels_train == class_value
78+
idx_test = data.labels_test == class_value
79+
else:
80+
if class_value >= data.num_classes:
81+
raise ValueError(
82+
f"class_value ({class_value}) is larger than the number of classes"
83+
" (data.num_classes)."
84+
)
85+
idx_train = data.labels_train[:, class_value].astype(bool)
86+
idx_test = data.labels_test[:, class_value].astype(bool)
7987
return _slice_data_by_indices(data, idx_train, idx_test)
8088

8189

tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing_test.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -358,10 +358,20 @@ def test_slice_entire_dataset(self):
358358
expected.slice_spec = entire_dataset_slice
359359
self.assertTrue(_are_all_fields_equal(output, self.input_data))
360360

361-
def test_slice_by_class_fails(self):
361+
def test_slice_by_class(self):
362362
class_index = 1
363363
class_slice = SingleSliceSpec(SlicingFeature.CLASS, class_index)
364-
self.assertRaises(ValueError, get_slice, self.input_data, class_slice)
364+
output = get_slice(self.input_data, class_slice)
365+
expected_indices_train = np.array([0, 2, 3])
366+
expected_indices_test = np.array([1, 2])
367+
368+
np.testing.assert_array_equal(
369+
output.logits_train,
370+
self.input_data.logits_train[expected_indices_train],
371+
)
372+
np.testing.assert_array_equal(
373+
output.logits_test, self.input_data.logits_test[expected_indices_test]
374+
)
365375

366376
@mock.patch('logging.Logger.info', wraps=logging.Logger)
367377
def test_slice_by_percentile_logs_multilabel_data(self, mock_logger):

0 commit comments

Comments
 (0)