diff --git a/jax_privacy/experimental/training.py b/jax_privacy/experimental/training.py
index dd95f96b..8cca4eee 100644
--- a/jax_privacy/experimental/training.py
+++ b/jax_privacy/experimental/training.py
@@ -135,7 +135,8 @@ class DPTrainer:
     optimizer: An ``AugmentedGradientTransformation`` or a plain
       ``optax.GradientTransformation``.
     padding_multiple: If set, batch sizes are padded to a multiple of this
-      value.
+      value. Padding reduces JIT recompilations from variable Poisson batch
+      sizes, since each unique batch shape triggers a separate XLA compilation.
   """
 
   plan: execution_plan.DPExecutionPlan
@@ -144,7 +145,7 @@ class DPTrainer:
       aug_optimizers.AugmentedGradientTransformation
       | optax.GradientTransformation
   )
-  padding_multiple: int = 1
+  padding_multiple: int = 32
 
   def train_step(
       self,
diff --git a/tests/experimental/training_test.py b/tests/experimental/training_test.py
index 6d18fa4d..aaf0956c 100644
--- a/tests/experimental/training_test.py
+++ b/tests/experimental/training_test.py
@@ -126,6 +126,16 @@ def test_padding_multiple(self):
 
     self.assertEqual(int(state.step), 2)
 
+  def test_default_padding_multiple(self):
+    """Default padding_multiple should be 32 to reduce JIT recompilations."""
+    plan = _make_plan(iterations=1)
+    trainer = training.DPTrainer(
+        plan=plan,
+        loss_fn=_quadratic_loss,
+        optimizer=optax.sgd(0.01),
+    )
+    self.assertEqual(trainer.padding_multiple, 32)
+
   def test_zero_iterations_config_raises(self):
     """BandMFConfig requires iterations >= 1."""
     with self.assertRaises(Exception):