diff --git a/jax_privacy/training.py b/jax_privacy/training.py index f8ca8b3a..37f8c804 100644 --- a/jax_privacy/training.py +++ b/jax_privacy/training.py @@ -21,9 +21,13 @@ """ from collections.abc import Callable +import concurrent.futures +import copy import dataclasses +import functools from typing import Protocol, TypeAlias +from absl import logging import chex import jax import jax_privacy @@ -34,6 +38,7 @@ import numpy as np import optax + # Re-export key symbols so users can access them via jax_privacy.training. BandMFConfig = execution_plan.BandMFConfig DPExecutionPlan = execution_plan.DPExecutionPlan @@ -43,9 +48,14 @@ Aux: TypeAlias = chex.ArrayTree PerExampleAux: TypeAlias = jax_privacy.clipping.AuxiliaryOutput Batch: TypeAlias = chex.ArrayTree +Dataset: TypeAlias = chex.ArrayTree Params: TypeAlias = chex.ArrayTree OptState: TypeAlias = chex.ArrayTree NoiseState: TypeAlias = chex.ArrayTree +PrecompiledFuture: TypeAlias = concurrent.futures.Future[jax.stages.Compiled] + +# Shared thread pool for background ahead-of-time compilation. +_COMPILE_POOL = concurrent.futures.ThreadPoolExecutor(max_workers=1) class LossFn(Protocol): @@ -100,7 +110,7 @@ class TrainingState: CallbackFn: TypeAlias = Callable[[int, TrainingState, PerExampleAux], None] -def _get_batch(dataset: Batch, indices: np.ndarray) -> tuple[Batch, np.ndarray]: +def _get_batch(dataset: Batch, indices: np.ndarray) -> tuple[Batch, jax.Array]: """Retrieves a batch from a PyTree dataset, zeroing padding examples. Args: @@ -115,11 +125,11 @@ def _get_batch(dataset: Batch, indices: np.ndarray) -> tuple[Batch, np.ndarray]: """ is_padding = indices == -1 - def _index_and_zero(x: np.ndarray) -> np.ndarray: + def _index_and_zero(x): mask = np.expand_dims(is_padding, tuple(range(1, x.ndim))) - return np.where(mask, 0, x[indices]) + return jax.device_put(np.where(mask, 0, x[indices])) - return jax.tree.map(_index_and_zero, dataset), is_padding + return jax.tree.map(_index_and_zero, dataset), jax.device_put(is_padding) @dataclasses.dataclass(frozen=True, kw_only=True) @@ -154,13 +164,23 @@ class DPTrainer: ) padding_multiple: int = 32 + def init(self, params: Params) -> TrainingState: + """Initialize a ``TrainingState`` at step 0.""" + optimizer = aug_optimizers.as_augmented_optimizer(self.optimizer) + return TrainingState( + step=0, + params=params, + opt_state=optimizer.init(params), + noise_state=self.plan.noise_addition_transform.init(params), + ) + + @jax.jit(static_argnames=["self"], donate_argnames=["state"]) def train_step( self, state: TrainingState, batch: Batch, is_padding_example: jax.Array, - *, - loss_rng: jax.Array, + prng_key: jax.Array, ) -> tuple[TrainingState, PerExampleAux]: """Executes a single DP training step. @@ -172,8 +192,8 @@ def train_step( batch: A PyTree of arrays representing the current mini-batch. is_padding_example: A boolean array indicating which examples in ``batch`` are padding (and should be ignored). - loss_rng: Base PRNG key; a step-specific key is derived via - ``jax.random.fold_in(loss_rng, state.step)``. + prng_key: Base PRNG key; a step-specific key is derived via + ``jax.random.fold_in(prng_key, state.step)``. Returns: A tuple ``(new_state, aux)`` where ``new_state`` is the updated @@ -191,9 +211,9 @@ def train_step( prng_argnum=2, ) - rng = jax.random.fold_in(loss_rng, state.step) + loss_prng = jax.random.fold_in(prng_key, state.step) clipped_grad_sum, aux = grad_fn( - state.params, batch, rng, is_padding_example=is_padding_example + state.params, batch, loss_prng, is_padding_example=is_padding_example ) dp_grad, new_noise_state = self.plan.noise_addition_transform.update( @@ -212,55 +232,99 @@ def train_step( ) return new_state, aux + def _precompile( + self, + dataset: Dataset, + params: Params, + *, + rng_or_seed: np.random.Generator | int | None = None, + ) -> dict[int, PrecompiledFuture]: + """[ADVANCED] Warm up the JIT cache for ``train_step`` asynchronously.""" + # With the same rng passed to _precompile and fit, the exact same + # batches will be sampled in this dry-run as in the actual training loop, + # guaranteeing JIT cache hits. + rng = copy.deepcopy(np.random.default_rng(rng_or_seed)) + seed = rng.integers(2**63) + n = _validate.batch(dataset) + + state = jax.eval_shape(self.init, params) + key = jax.eval_shape(lambda x: x, jax.random.key(seed)) + + futures: dict[int, PrecompiledFuture] = {} + + def _resize(size, x): + return jax.ShapeDtypeStruct((size, *x.shape[1:]), x.dtype) + + for idx in self.plan.batch_selection_strategy.batch_iterator(n, rng=rng): + padded = batch_selection.pad_to_multiple_of(idx, self.padding_multiple) + batch_size = padded.size + batch = jax.tree.map(functools.partial(_resize, batch_size), dataset) + padding = jax.ShapeDtypeStruct((batch_size,), np.bool_) + + lowered = self.train_step.lower(self, state, batch, padding, key) + logging.info("AOT-compiling train_step for batch size %d", batch_size) + futures[batch_size] = _COMPILE_POOL.submit(lowered.compile) + + return futures + def fit( self, - dataset: Batch, + dataset: Dataset, params: Params, *, callback: CallbackFn | None = None, - rng: np.random.Generator | int | None = None, + rng_or_seed: np.random.Generator | int | None = None, + precompile: bool = True, ) -> TrainingState: """Runs an end-to-end differentially private training loop. Args: - dataset: The training dataset, as a PyTree of arrays. + dataset: The training dataset, as a PyTree of arrays where the first axis + of each leaf is the batch / example dimension. params: Initial parameter PyTree. callback: Called after each step as ``callback(step, state, aux)``. ``step`` is a Python int. - rng: Optional random seed or ``numpy.random.Generator`` for - reproducibility. + rng_or_seed: Optional random seed or ``numpy.random.Generator``, used for + sampling batches (impacting privacy) and initializing the loss PRNG key + (potentially impacting utility). Does not influence the noise addition + transform, which is configured via the DPExecutionPlan. + precompile: A boolean indicating whether to asyncronously precompile + ``train_step`` for the batch sizes encountered, instead of just-in-time + compiling on the fly, which can idle accelerators during training. Can + also directly pass in a dict of precompiled futures, e.g. from + :py:meth:`precompile()`. Returns: Final ``TrainingState``. """ - rng = np.random.default_rng(rng) - loss_rng = jax.random.key(int(rng.integers(2**63))) - - num_examples = _validate.batch(dataset) + futures: dict[int, PrecompiledFuture] = {} + if precompile: + futures = self._precompile(dataset, params, rng_or_seed=rng_or_seed) - optimizer = aug_optimizers.as_augmented_optimizer(self.optimizer) + # We need tight alignement between how rng is used here and in precompile(). + rng = np.random.default_rng(rng_or_seed) + prng_key = jax.random.key(int(rng.integers(2**63))) - state = TrainingState( - step=0, - params=params, - opt_state=optimizer.init(params), - noise_state=self.plan.noise_addition_transform.init(params), - ) + num_examples = _validate.batch(dataset) + state = self.init(jax.tree.map(jax.numpy.copy, params)) batch_iterator = self.plan.batch_selection_strategy.batch_iterator( num_examples, rng=rng ) - jit_step = jax.jit(self.train_step) - step = 0 for indices in batch_iterator: indices = batch_selection.pad_to_multiple_of( indices, self.padding_multiple ) batch, is_padding_example = _get_batch(dataset, indices) + if indices.size in futures: + step_fn = futures[indices.size].result() + else: + logging.info("JIT-compiling train_step for batch size %d", indices.size) + step_fn = self.train_step - state, aux = jit_step(state, batch, is_padding_example, loss_rng=loss_rng) + state, aux = step_fn(state, batch, is_padding_example, prng_key) step += 1 del indices, batch, is_padding_example diff --git a/tests/training_test.py b/tests/training_test.py index da347e09..a7338c8b 100644 --- a/tests/training_test.py +++ b/tests/training_test.py @@ -13,10 +13,12 @@ # limitations under the License. +import dataclasses from absl.testing import absltest from absl.testing import parameterized import jax import jax.numpy as jnp +from jax_privacy import batch_selection from jax_privacy import execution_plan from jax_privacy import training import numpy as np @@ -63,7 +65,7 @@ def test_basic_training_runs(self): loss_fn=_quadratic_loss, optimizer=optimizer, ) - state = trainer.fit(dataset, params, rng=0) + state = trainer.fit(dataset, params, rng_or_seed=0) self.assertIsInstance(state, training.TrainingState) self.assertEqual(int(state.step), 3) @@ -80,7 +82,7 @@ def test_params_change_after_training(self): loss_fn=_quadratic_loss, optimizer=optimizer, ) - state = trainer.fit(dataset, params, rng=42) + state = trainer.fit(dataset, params, rng_or_seed=42) self.assertFalse(jnp.allclose(state.params, params)) @@ -104,7 +106,12 @@ def callback(step, state, aux): loss_fn=_quadratic_loss, optimizer=optimizer, ) - trainer.fit(dataset, params, callback=callback, rng=0) + trainer.fit( + dataset, + params, + callback=callback, + rng_or_seed=0, + ) self.assertLen(callback_log, iterations) self.assertEqual([s for s, _ in callback_log], [1, 2, 3]) @@ -122,7 +129,7 @@ def test_padding_multiple(self): optimizer=optimizer, padding_multiple=4, ) - state = trainer.fit(dataset, params, rng=0) + state = trainer.fit(dataset, params, rng_or_seed=0) self.assertEqual(int(state.step), 2) @@ -143,7 +150,7 @@ def test_single_iteration(self): loss_fn=_quadratic_loss, optimizer=optimizer, ) - state = trainer.fit(dataset, params, rng=0) + state = trainer.fit(dataset, params, rng_or_seed=0) self.assertEqual(int(state.step), 1) @@ -169,7 +176,7 @@ def counting_loss(params, batch, prng): loss_fn=counting_loss, optimizer=optimizer, ) - state = trainer.fit(dataset, params, rng=0) + state = trainer.fit(dataset, params, rng_or_seed=0) self.assertEqual(int(state.step), 3) self.assertLess(trace_count[0], 3 * 2) @@ -188,18 +195,16 @@ def test_train_step_callable_directly(self): state = training.TrainingState( step=0, - params=params, + params=jnp.copy(params), opt_state=optimizer.init(params), noise_state=plan.noise_addition_transform.init(params), ) batch = jnp.array([[1.0, 0.0], [0.0, 1.0]]) is_padding = jnp.array([False, False]) - loss_rng = jax.random.key(0) + prng_key = jax.random.key(0) - new_state, _ = trainer.train_step( - state, batch, is_padding, loss_rng=loss_rng - ) + new_state, _ = trainer.train_step(state, batch, is_padding, prng_key) self.assertEqual(int(new_state.step), 1) self.assertFalse(jnp.allclose(new_state.params, params)) @@ -225,10 +230,10 @@ def test_train_step_jit_compilable(self): batch = jnp.array([[1.0], [0.0]]) is_padding = jnp.array([False, False]) - loss_rng = jax.random.key(0) + prng_key = jax.random.key(0) - jit_step = jax.jit(trainer.train_step) - new_state, _ = jit_step(state, batch, is_padding, loss_rng=loss_rng) + # train_step is already @jax.jit decorated; call it directly. + new_state, _ = trainer.train_step(state, batch, is_padding, prng_key) self.assertEqual(int(new_state.step), 1) @@ -248,7 +253,7 @@ def test_epsilon_zero_high_noise(self): loss_fn=_quadratic_loss, optimizer=optimizer, ) - state = trainer.fit(dataset, params, rng=0) + state = trainer.fit(dataset, params, rng_or_seed=0) self.assertEqual(int(state.step), 2) self.assertTrue(jnp.all(jnp.isfinite(state.params))) @@ -265,7 +270,7 @@ def test_epsilon_inf_no_noise(self): loss_fn=_quadratic_loss, optimizer=optimizer, ) - state = trainer.fit(dataset, params, rng=0) + state = trainer.fit(dataset, params, rng_or_seed=0) self.assertEqual(int(state.step), 3) self.assertLess( @@ -285,7 +290,7 @@ def test_single_example_dataset(self): loss_fn=_quadratic_loss, optimizer=optimizer, ) - state = trainer.fit(dataset, params, rng=0) + state = trainer.fit(dataset, params, rng_or_seed=0) self.assertEqual(int(state.step), 2) self.assertTrue(jnp.all(jnp.isfinite(state.params))) @@ -308,7 +313,7 @@ def dict_loss(params, batch, prng): loss_fn=dict_loss, optimizer=optimizer, ) - state = trainer.fit(dataset, params, rng=0) + state = trainer.fit(dataset, params, rng_or_seed=0) self.assertEqual(int(state.step), 2) @@ -329,10 +334,176 @@ def test_bfloat16_params_preserved(self): loss_fn=_quadratic_loss, optimizer=optimizer, ) - state = trainer.fit(dataset, params, rng=0) + state = trainer.fit(dataset, params, rng_or_seed=0) self.assertEqual(state.params.dtype, jnp.bfloat16) +class DPTrainerInitTest(parameterized.TestCase): + """Tests for DPTrainer.init.""" + + def test_init_returns_training_state(self): + """init() should return a TrainingState at step 0.""" + params = jnp.array([1.0, 2.0]) + plan = _make_plan(iterations=3) + optimizer = optax.sgd(0.01) + + trainer = training.DPTrainer( + plan=plan, + loss_fn=_quadratic_loss, + optimizer=optimizer, + ) + state = trainer.init(params) + + self.assertIsInstance(state, training.TrainingState) + self.assertEqual(int(state.step), 0) + np.testing.assert_array_equal(state.params, params) + + +class DPTrainerPrecompileTest(parameterized.TestCase): + """Tests for DPTrainer.precompile.""" + + def test_precompile_returns_futures(self): + """precompile() should return a dict of batch_size -> Future.""" + params = jnp.array([1.0, 2.0]) + dataset = np.array([[0.0, 0.0]] * 10) # 10 examples. + plan = _make_plan(iterations=5) + optimizer = optax.sgd(0.01) + + trainer = training.DPTrainer( + plan=plan, + loss_fn=_quadratic_loss, + optimizer=optimizer, + ) + futures = trainer._precompile(dataset, params, rng_or_seed=42) + + self.assertIsInstance(futures, dict) + self.assertNotEmpty(futures) + for size, future in futures.items(): + self.assertIsInstance(size, int) + self.assertGreater(size, 0) + # Compilation should complete without error. + future.result() + + def test_precompile_sizes_are_padded(self): + """All precompiled sizes should be multiples of padding_multiple.""" + params = jnp.array([1.0]) + dataset = np.array([[0.0]] * 20) # 20 examples. + plan = _make_plan(iterations=10) + optimizer = optax.sgd(0.01) + padding_multiple = 8 + + trainer = training.DPTrainer( + plan=plan, + loss_fn=_quadratic_loss, + optimizer=optimizer, + padding_multiple=padding_multiple, + ) + futures = trainer._precompile(dataset, params, rng_or_seed=0) + + for size in futures: + self.assertEqual(size % padding_multiple, 0) + + # Wait for all compilations. + for future in futures.values(): + future.result() + + def test_precompile_rng_not_consumed(self): + """precompile should deep-copy the RNG, not consume the caller's.""" + params = jnp.array([1.0]) + dataset = np.array([[0.0]] * 5) # 5 examples. + plan = _make_plan(iterations=3) + optimizer = optax.sgd(0.01) + + trainer = training.DPTrainer( + plan=plan, + loss_fn=_quadratic_loss, + optimizer=optimizer, + ) + + rng = np.random.default_rng(42) + state_before = rng.__getstate__() + futures = trainer._precompile(dataset, params, rng_or_seed=rng) + state_after = rng.__getstate__() + + # RNG should not have been consumed. + np.testing.assert_equal(state_before, state_after) + + for future in futures.values(): + future.result() + + def test_precompile_with_shape_dtype_struct(self): + """precompile() should work with abstract ShapeDtypeStruct inputs.""" + params = jax.ShapeDtypeStruct((3,), jnp.float32) + dataset = jax.ShapeDtypeStruct((5, 3), jnp.float32) + plan = _make_plan(iterations=3) + optimizer = optax.sgd(0.01) + + trainer = training.DPTrainer( + plan=plan, + loss_fn=_quadratic_loss, + optimizer=optimizer, + ) + futures = trainer._precompile(dataset, params, rng_or_seed=0) + + self.assertNotEmpty(futures) + for future in futures.values(): + future.result() + + def test_fit_precompile_aot_compiles_all_sizes(self): + """precompile=True should AOT-compile once per unique batch size.""" + trace_count = [0] + + def loss_fn(params, batch, _): + trace_count[0] += 1 + return jnp.mean((params - batch) ** 2), {} + + params = jnp.array([1.0]) + dataset = np.array([[i] for i in range(50)]) + + plan = dataclasses.replace( + _make_plan(iterations=5), + batch_selection_strategy=batch_selection.CyclicPoissonSampling(0.5, 5), + ) + trainer = training.DPTrainer( + plan=plan, loss_fn=loss_fn, optimizer=optax.sgd(1), padding_multiple=1 + ) + + with self.assertLogs(level='INFO') as logs: + trainer.fit(dataset, params, rng_or_seed=0, precompile=True) + for log in logs.output: + self.assertIn('AOT-compiling train_step for batch size', log) + self.assertNotIn('JIT-compiling train_step for batch size', log) + self.assertEqual(trace_count[0], 5) + self.assertLen(logs.output, 5) + + def test_fit_no_precompile_jit_compiles_all_sizes(self): + """precompile=False should JIT-compile once per unique batch size.""" + trace_count = [0] + + def loss_fn(params, batch, _): + trace_count[0] += 1 + return jnp.mean((params - batch) ** 2), {} + + params = jnp.array([1.0]) + dataset = np.array([[i] for i in range(50)]) + + plan = dataclasses.replace( + _make_plan(iterations=5), + batch_selection_strategy=batch_selection.CyclicPoissonSampling(0.5, 5), + ) + trainer = training.DPTrainer( + plan=plan, loss_fn=loss_fn, optimizer=optax.sgd(1), padding_multiple=1 + ) + + with self.assertLogs(level='INFO') as logs: + trainer.fit(dataset, params, rng_or_seed=0, precompile=False) + for log in logs.output: + self.assertNotIn('AOT-compiling train_step for batch size', log) + self.assertIn('JIT-compiling train_step for batch size', log) + self.assertEqual(trace_count[0], 5) + self.assertLen(logs.output, 5) + + if __name__ == '__main__': absltest.main()