diff --git a/jax_privacy/accounting.py b/jax_privacy/accounting.py index d1e2d414..baf68dcc 100644 --- a/jax_privacy/accounting.py +++ b/jax_privacy/accounting.py @@ -225,6 +225,9 @@ def truncated_dpsgd_event( A DpEvent object. """ _validate_poisson_args(noise_multiplier, iterations, sampling_prob) + _validate.non_negative( + num_examples=num_examples, truncated_batch_size=truncated_batch_size + ) sampled_gaussian = dp_accounting.TruncatedSubsampledGaussianDpEvent( dataset_size=num_examples, sampling_probability=sampling_prob, @@ -264,6 +267,7 @@ def amplified_bandmf_event( A DpEvent object. """ _validate_poisson_args(noise_multiplier, iterations, sampling_prob) + _validate.positive(num_bands=num_bands) rounds = math.ceil(iterations / num_bands) return dpsgd_event( noise_multiplier=noise_multiplier, @@ -307,6 +311,7 @@ def truncated_amplified_bandmf_event( A DpEvent object. """ _validate_poisson_args(noise_multiplier, iterations, sampling_prob) + _validate.positive(num_bands=num_bands) return truncated_dpsgd_event( noise_multiplier=noise_multiplier, sampling_prob=sampling_prob, diff --git a/tests/accounting_test.py b/tests/accounting_test.py index dc4ce2fa..08c9c409 100644 --- a/tests/accounting_test.py +++ b/tests/accounting_test.py @@ -233,6 +233,95 @@ def test_use_zcdp_matches_gaussian_privacy(self): # as the continuous Gaussian with the same sigma. self.assertAlmostEqual(eps_continuous, eps_discrete, places=4) + # -- Structural-argument validation ---------------------------------------- + # num_bands flows into math.ceil(iterations / num_bands); the truncated sizes + # flow into TruncatedSubsampledGaussianDpEvent. These were previously + # unvalidated: num_bands=0 raised an opaque ZeroDivisionError, num_bands<0 + # raised a misleading "iterations=..." error, and negative sizes were silently + # accepted and produced a nonsensical event. + + @parameterized.parameters(0, -1, -16) + def test_amplified_bandmf_event_rejects_nonpositive_num_bands( + self, num_bands + ): + with self.assertRaisesRegex(ValueError, rf"num_bands={num_bands} > 0"): + accounting.amplified_bandmf_event( + 1.0, 128, num_bands=num_bands, sampling_prob=0.01 + ) + + @parameterized.parameters(0, -1, -16) + def test_truncated_amplified_bandmf_rejects_nonpositive_num_bands( + self, num_bands + ): + with self.assertRaisesRegex(ValueError, rf"num_bands={num_bands} > 0"): + accounting.truncated_amplified_bandmf_event( + 1.0, + 128, + num_bands=num_bands, + sampling_prob=0.01, + largest_group_size=1000, + truncated_batch_size=16, + ) + + @parameterized.parameters( + dict(num_examples=-1, truncated_batch_size=16), + dict(num_examples=1000, truncated_batch_size=-1), + dict(num_examples=-5, truncated_batch_size=-3), + ) + def test_truncated_dpsgd_event_rejects_negative_sizes( + self, num_examples, truncated_batch_size + ): + with self.assertRaisesRegex(ValueError, r">= 0"): + accounting.truncated_dpsgd_event( + 1.0, + 10, + sampling_prob=0.1, + num_examples=num_examples, + truncated_batch_size=truncated_batch_size, + ) + + def test_truncated_amplified_bandmf_rejects_negative_sizes(self): + # largest_group_size / truncated_batch_size are forwarded to + # truncated_dpsgd_event and validated there. + with self.assertRaisesRegex(ValueError, r">= 0"): + accounting.truncated_amplified_bandmf_event( + 1.0, + 128, + num_bands=16, + sampling_prob=0.01, + largest_group_size=-1, + truncated_batch_size=16, + ) + + def test_amplified_bandmf_event_valid_num_bands_unchanged(self): + # Regression: positive num_bands is unaffected. The mechanism runs + # rounds = ceil(iterations / num_bands) DP-SGD steps. + event = accounting.amplified_bandmf_event( + 1.0, 128, num_bands=16, sampling_prob=0.01 + ) + self.assertIsInstance(event, dp_accounting.dp_event.SelfComposedDpEvent) + self.assertEqual(event.count, 8) # ceil(128 / 16) + + def test_truncated_amplified_bandmf_valid_args_unchanged(self): + event = accounting.truncated_amplified_bandmf_event( + 1.0, + 128, + num_bands=16, + sampling_prob=0.01, + largest_group_size=1000, + truncated_batch_size=16, + ) + self.assertIsInstance(event, dp_accounting.dp_event.SelfComposedDpEvent) + self.assertEqual(event.count, 8) # ceil(128 / 16) + + def test_truncated_dpsgd_event_valid_args_unchanged(self): + # Regression: non-negative (incl. zero) sizes still build the event. + event = accounting.truncated_dpsgd_event( + 1.0, 10, sampling_prob=0.1, num_examples=1000, truncated_batch_size=16 + ) + self.assertIsInstance(event, dp_accounting.dp_event.SelfComposedDpEvent) + self.assertEqual(event.count, 10) + if __name__ == "__main__": absltest.main()