Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions jax_privacy/experimental/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ class DPTrainer:
optimizer: An ``AugmentedGradientTransformation`` or a plain
``optax.GradientTransformation``.
padding_multiple: If set, batch sizes are padded to a multiple of this
value.
value. Padding reduces JIT recompilations from variable Poisson batch
sizes, since each unique batch shape triggers a separate XLA compilation.
"""

plan: execution_plan.DPExecutionPlan
Expand All @@ -144,7 +145,7 @@ class DPTrainer:
aug_optimizers.AugmentedGradientTransformation
| optax.GradientTransformation
)
padding_multiple: int = 1
padding_multiple: int = 32

def train_step(
self,
Expand Down
10 changes: 10 additions & 0 deletions tests/experimental/training_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,16 @@ def test_padding_multiple(self):

self.assertEqual(int(state.step), 2)

def test_default_padding_multiple(self):
"""Default padding_multiple should be 32 to prevent JIT recompilation."""
plan = _make_plan(iterations=1)
trainer = training.DPTrainer(
plan=plan,
loss_fn=_quadratic_loss,
optimizer=optax.sgd(0.01),
)
self.assertEqual(trainer.padding_multiple, 32)

def test_zero_iterations_config_raises(self):
"""BandMFConfig requires iterations >= 1."""
with self.assertRaises(Exception):
Expand Down
Loading