Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions jax_privacy/accounting.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,9 @@ def truncated_dpsgd_event(
A DpEvent object.
"""
_validate_poisson_args(noise_multiplier, iterations, sampling_prob)
_validate.non_negative(
num_examples=num_examples, truncated_batch_size=truncated_batch_size
)
sampled_gaussian = dp_accounting.TruncatedSubsampledGaussianDpEvent(
dataset_size=num_examples,
sampling_probability=sampling_prob,
Expand Down Expand Up @@ -264,6 +267,7 @@ def amplified_bandmf_event(
A DpEvent object.
"""
_validate_poisson_args(noise_multiplier, iterations, sampling_prob)
_validate.positive(num_bands=num_bands)
rounds = math.ceil(iterations / num_bands)
return dpsgd_event(
noise_multiplier=noise_multiplier,
Expand Down Expand Up @@ -307,6 +311,7 @@ def truncated_amplified_bandmf_event(
A DpEvent object.
"""
_validate_poisson_args(noise_multiplier, iterations, sampling_prob)
_validate.positive(num_bands=num_bands)
return truncated_dpsgd_event(
noise_multiplier=noise_multiplier,
sampling_prob=sampling_prob,
Expand Down
89 changes: 89 additions & 0 deletions tests/accounting_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,95 @@ def test_use_zcdp_matches_gaussian_privacy(self):
# as the continuous Gaussian with the same sigma.
self.assertAlmostEqual(eps_continuous, eps_discrete, places=4)

# -- Structural-argument validation ----------------------------------------

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think none of the comments introduced in this CL are necessary, and would suggest removing them all.

From the test names I think it is pretty easy to infer what each test is supposed to do, and the logic can be inferred from reading the file being tested. We generally prefer not to include comments if they only explain something that can be inferred by someone with knowledge of Python.

# num_bands flows into math.ceil(iterations / num_bands); the truncated sizes
# flow into TruncatedSubsampledGaussianDpEvent. These were previously
# unvalidated: num_bands=0 raised an opaque ZeroDivisionError, num_bands<0
# raised a misleading "iterations=..." error, and negative sizes were silently
# accepted and produced a nonsensical event.

@parameterized.parameters(0, -1, -16)
def test_amplified_bandmf_event_rejects_nonpositive_num_bands(
    self, num_bands
):
  # A zero or negative band count must fail with a ValueError whose message
  # matches the "num_bands=<value> > 0" pattern.
  expected_message = rf"num_bands={num_bands} > 0"
  self.assertRaisesRegex(
      ValueError,
      expected_message,
      accounting.amplified_bandmf_event,
      1.0,
      128,
      num_bands=num_bands,
      sampling_prob=0.01,
  )

@parameterized.parameters(0, -1, -16)
def test_truncated_amplified_bandmf_rejects_nonpositive_num_bands(
    self, num_bands
):
  # Same contract as the non-truncated variant: a zero or negative band
  # count must fail with a ValueError matching "num_bands=<value> > 0".
  expected_message = rf"num_bands={num_bands} > 0"
  self.assertRaisesRegex(
      ValueError,
      expected_message,
      accounting.truncated_amplified_bandmf_event,
      1.0,
      128,
      num_bands=num_bands,
      sampling_prob=0.01,
      largest_group_size=1000,
      truncated_batch_size=16,
  )

@parameterized.parameters(
    dict(num_examples=-1, truncated_batch_size=16),
    dict(num_examples=1000, truncated_batch_size=-1),
    dict(num_examples=-5, truncated_batch_size=-3),
)
def test_truncated_dpsgd_event_rejects_negative_sizes(
    self, num_examples, truncated_batch_size
):
  # Each case makes at least one of the two sizes negative; the call must
  # fail with a ValueError whose message contains ">= 0".
  size_kwargs = dict(
      num_examples=num_examples,
      truncated_batch_size=truncated_batch_size,
  )
  with self.assertRaisesRegex(ValueError, r">= 0"):
    accounting.truncated_dpsgd_event(
        1.0, 10, sampling_prob=0.1, **size_kwargs
    )

def test_truncated_amplified_bandmf_rejects_negative_sizes(self):
  # A negative largest_group_size passed alongside an otherwise valid
  # configuration must surface as a ValueError mentioning ">= 0".
  self.assertRaisesRegex(
      ValueError,
      r">= 0",
      accounting.truncated_amplified_bandmf_event,
      1.0,
      128,
      num_bands=16,
      sampling_prob=0.01,
      largest_group_size=-1,
      truncated_batch_size=16,
  )

def test_amplified_bandmf_event_valid_num_bands_unchanged(self):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the '_unchanged' suffix here and in the test names below is unnecessary and, if anything, adds confusion.

# Regression: positive num_bands is unaffected. The mechanism runs
# rounds = ceil(iterations / num_bands) DP-SGD steps.
event = accounting.amplified_bandmf_event(
1.0, 128, num_bands=16, sampling_prob=0.01
)
self.assertIsInstance(event, dp_accounting.dp_event.SelfComposedDpEvent)
self.assertEqual(event.count, 8) # ceil(128 / 16)

def test_truncated_amplified_bandmf_valid_args_unchanged(self):
  # With valid arguments the call must still return a self-composed event
  # whose count is 8, i.e. ceil(128 / 16).
  valid_kwargs = dict(
      num_bands=16,
      sampling_prob=0.01,
      largest_group_size=1000,
      truncated_batch_size=16,
  )
  event = accounting.truncated_amplified_bandmf_event(
      1.0, 128, **valid_kwargs
  )
  self.assertIsInstance(event, dp_accounting.dp_event.SelfComposedDpEvent)
  self.assertEqual(event.count, 8)

def test_truncated_dpsgd_event_valid_args_unchanged(self):
  # Positive sizes (1000 examples, truncated batch of 16) must still build
  # a self-composed event; its count equals the 10 passed as second argument.
  iterations = 10
  event = accounting.truncated_dpsgd_event(
      1.0,
      iterations,
      sampling_prob=0.1,
      num_examples=1000,
      truncated_batch_size=16,
  )
  self.assertIsInstance(event, dp_accounting.dp_event.SelfComposedDpEvent)
  self.assertEqual(event.count, iterations)


# Hand control to the absltest runner only when executed as a script.
if __name__ == "__main__":
  absltest.main()