Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 86 additions & 11 deletions jax_privacy/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
"""

from collections.abc import Callable
import concurrent.futures
import copy
import dataclasses
import functools
from typing import Protocol, TypeAlias

import chex
Expand All @@ -46,6 +49,10 @@
# PyTree aliases; each one is a ``chex.ArrayTree``.
Params: TypeAlias = chex.ArrayTree
OptState: TypeAlias = chex.ArrayTree
NoiseState: TypeAlias = chex.ArrayTree
# Future whose ``result()`` is a ``jax.stages.Compiled`` executable.
PrecompiledFuture: TypeAlias = concurrent.futures.Future[jax.stages.Compiled]

# Shared thread pool for background JIT compilation. It is created once, when
# this module is imported, and runs at most two compilations concurrently.
_COMPILE_POOL = concurrent.futures.ThreadPoolExecutor(max_workers=2)


class LossFn(Protocol):
Expand Down Expand Up @@ -154,6 +161,16 @@ class DPTrainer:
)
padding_multiple: int = 32

def init(self, params: Params) -> TrainingState:
  """Builds the step-0 ``TrainingState`` for ``params``.

  Args:
    params: Parameter PyTree that the optimizer state and the noise state
      are initialised from.

  Returns:
    A ``TrainingState`` with ``step=0``, the given ``params``, the augmented
    optimizer's initial state and the plan's initial noise state.
  """
  # Wrap the user optimizer before asking it for its initial state.
  augmented = aug_optimizers.as_augmented_optimizer(self.optimizer)
  initial_opt_state = augmented.init(params)
  initial_noise_state = self.plan.noise_addition_transform.init(params)
  return TrainingState(
      step=0,
      params=params,
      opt_state=initial_opt_state,
      noise_state=initial_noise_state,
  )

def train_step(
self,
state: TrainingState,
Expand Down Expand Up @@ -212,13 +229,73 @@ def train_step(
)
return new_state, aux

def precompile(
    self,
    dataset: Batch,
    params: Params,
    *,
    rng: np.random.Generator | int | None = None,
) -> dict[int, PrecompiledFuture]:
  """Starts ahead-of-time compilation of ``train_step`` for every batch size.

  Takes the same ``dataset`` / ``params`` / ``rng`` arguments as ``fit``.
  The batch selection strategy is replayed on a private copy of the RNG to
  collect every distinct padded batch size. For each size, ``train_step`` is
  lowered on the calling thread against abstract inputs, and the resulting
  ``compile`` call is handed to the shared background thread pool.

  Only shapes and dtypes of ``dataset`` and ``params`` are used, so
  ``jax.ShapeDtypeStruct`` PyTrees are accepted in place of real arrays.

  Args:
    dataset: Training dataset PyTree (arrays or ``jax.ShapeDtypeStruct``).
      The leading dimension of its first leaf is taken as the number of
      examples.
    params: Parameter PyTree (arrays or ``jax.ShapeDtypeStruct``).
    rng: Seed or ``numpy.random.Generator``. A generator passed in is
      deep-copied, so the caller's instance is not advanced. Use the same
      value that will be given to ``fit`` so the discovered batch sizes
      line up with the real ones.

  Returns:
    A dict from padded batch size to a ``concurrent.futures.Future``.
    ``future.result()`` blocks until that size's compilation is done; the
    futures may also simply be ignored.
  """
  # Operate on a copy so a caller-supplied generator keeps its state.
  shadow_rng = copy.deepcopy(np.random.default_rng(rng))
  shadow_rng.integers(2**63)  # Advance past the loss_rng draw.
  num_examples = jax.tree.leaves(dataset)[0].shape[0]

  # Replay batch selection purely to learn which padded sizes will occur.
  strategy = self.plan.batch_selection_strategy
  padded_sizes: set[int] = {
      len(batch_selection.pad_to_multiple_of(indices, self.padding_multiple))
      for indices in strategy.batch_iterator(num_examples, rng=shadow_rng)
  }

  abstract_state = jax.eval_shape(self.init, params)
  placeholder_key = jax.random.key(0)
  jitted_step = jax.jit(self.train_step)

  def _abstract_batch(leading_dim):
    """Returns ``dataset`` as shape structs with a new leading dimension."""

    def _with_leading_dim(leaf):
      return jax.ShapeDtypeStruct((leading_dim, *leaf.shape[1:]), leaf.dtype)

    return jax.tree.map(_with_leading_dim, dataset)

  pending: dict[int, PrecompiledFuture] = {}
  for leading_dim in sorted(padded_sizes):
    abstract_batch = _abstract_batch(leading_dim)
    padding_mask = jax.ShapeDtypeStruct((leading_dim,), np.bool_)
    lowered = jitted_step.lower(
        abstract_state, abstract_batch, padding_mask, loss_rng=placeholder_key
    )
    # Only the compile step is deferred to the worker threads.
    pending[leading_dim] = _COMPILE_POOL.submit(lowered.compile)

  return pending

def fit(
self,
dataset: Batch,
params: Params,
*,
callback: CallbackFn | None = None,
rng: np.random.Generator | int | None = None,
precompiled_futures: dict[int, PrecompiledFuture] | None = None,
) -> TrainingState:
"""Runs an end-to-end differentially private training loop.

Expand All @@ -229,38 +306,36 @@ def fit(
``step`` is a Python int.
rng: Optional random seed or ``numpy.random.Generator`` for
reproducibility.
precompiled_futures: Optional dict mapping batch size to
``PrecompiledFuture``, e.g. from ``precompile()``.

Returns:
Final ``TrainingState``.
"""
rng = np.random.default_rng(rng)

if precompiled_futures is None:
precompiled_futures = self.precompile(dataset, params, rng=rng)

loss_rng = jax.random.key(int(rng.integers(2**63)))

num_examples = _validate.batch(dataset)

optimizer = aug_optimizers.as_augmented_optimizer(self.optimizer)

state = TrainingState(
step=0,
params=params,
opt_state=optimizer.init(params),
noise_state=self.plan.noise_addition_transform.init(params),
)
state = self.init(params)

batch_iterator = self.plan.batch_selection_strategy.batch_iterator(
num_examples, rng=rng
)

jit_step = jax.jit(self.train_step)

step = 0
for indices in batch_iterator:
indices = batch_selection.pad_to_multiple_of(
indices, self.padding_multiple
)
batch, is_padding_example = _get_batch(dataset, indices)
step_fn = precompiled_futures.get(indices.size, jax.jit(self.train_step))

state, aux = jit_step(state, batch, is_padding_example, loss_rng=loss_rng)
state, aux = step_fn(state, batch, is_padding_example, loss_rng=loss_rng)
step += 1

del indices, batch, is_padding_example
Expand Down
138 changes: 138 additions & 0 deletions tests/training_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,5 +334,143 @@ def test_bfloat16_params_preserved(self):
self.assertEqual(state.params.dtype, jnp.bfloat16)


class DPTrainerInitTest(parameterized.TestCase):
  """Covers ``DPTrainer.init``."""

  def test_init_returns_training_state(self):
    """A freshly initialised state is a ``TrainingState`` at step 0."""
    initial_params = jnp.array([1.0, 2.0])
    trainer = training.DPTrainer(
        plan=_make_plan(iterations=3),
        loss_fn=_quadratic_loss,
        optimizer=optax.sgd(0.01),
    )

    state = trainer.init(initial_params)

    self.assertIsInstance(state, training.TrainingState)
    self.assertEqual(int(state.step), 0)
    np.testing.assert_array_equal(state.params, initial_params)

  def test_init_matches_fit_initial_state(self):
    """init() + one train_step reaches the same step count as fit()."""
    initial_params = jnp.array([5.0, 5.0])
    examples = [[1.0, 0.0], [0.0, 1.0]]
    trainer = training.DPTrainer(
        plan=_make_plan(iterations=1, noise_multiplier=0.0),
        loss_fn=_quadratic_loss,
        optimizer=optax.sgd(0.01),
    )

    # Manual path: explicit init followed by a single train_step.
    manual_state, _ = trainer.train_step(
        trainer.init(initial_params),
        jnp.array(examples),
        jnp.array([False, False]),
        loss_rng=jax.random.key(0),
    )

    # Reference path: the one-iteration plan driven through fit.
    fit_state = trainer.fit(np.array(examples), initial_params, rng=0)

    # Only the step counters are compared here.
    self.assertEqual(int(manual_state.step), int(fit_state.step))


class DPTrainerPrecompileTest(parameterized.TestCase):
  """Tests for DPTrainer.precompile."""

  def test_precompile_returns_futures(self):
    """precompile() should return a dict of batch_size -> Future."""
    params = jnp.array([1.0, 2.0])
    dataset = np.array([[0.0, 0.0]] * 10)  # 10 examples.
    plan = _make_plan(iterations=5)
    optimizer = optax.sgd(0.01)

    trainer = training.DPTrainer(
        plan=plan,
        loss_fn=_quadratic_loss,
        optimizer=optimizer,
    )
    futures = trainer.precompile(dataset, params, rng=42)

    self.assertIsInstance(futures, dict)
    self.assertNotEmpty(futures)
    for size, future in futures.items():
      self.assertIsInstance(size, int)
      self.assertGreater(size, 0)
      # Compilation should complete without error.
      future.result()

  def test_precompile_sizes_are_padded(self):
    """All precompiled sizes should be multiples of padding_multiple."""
    params = jnp.array([1.0])
    dataset = np.array([[0.0]] * 20)  # 20 examples.
    plan = _make_plan(iterations=10)
    optimizer = optax.sgd(0.01)
    padding_multiple = 8

    trainer = training.DPTrainer(
        plan=plan,
        loss_fn=_quadratic_loss,
        optimizer=optimizer,
        padding_multiple=padding_multiple,
    )
    futures = trainer.precompile(dataset, params, rng=0)

    for size in futures:
      self.assertEqual(size % padding_multiple, 0)

    # Wait for all compilations.
    for future in futures.values():
      future.result()

  def test_precompile_rng_not_consumed(self):
    """precompile should deep-copy the RNG, not consume the caller's."""
    params = jnp.array([1.0])
    dataset = np.array([[0.0]] * 5)  # 5 examples.
    plan = _make_plan(iterations=3)
    optimizer = optax.sgd(0.01)
    seed = 42

    trainer = training.DPTrainer(
        plan=plan,
        loss_fn=_quadratic_loss,
        optimizer=optimizer,
    )

    rng = np.random.default_rng(seed)
    # Snapshot the bit generator's state through its public property.
    # ``Generator.__getstate__`` is a pickling hook, not a state accessor, and
    # may return ``None`` on newer NumPy, which would make this check vacuous.
    state_before = rng.bit_generator.state
    futures = trainer.precompile(dataset, params, rng=rng)
    state_after = rng.bit_generator.state

    # RNG should not have been consumed.
    np.testing.assert_equal(state_before, state_after)
    # Behavioural cross-check: the next draw must equal the first draw of an
    # untouched generator built from the same seed.
    self.assertEqual(
        int(rng.integers(2**63)),
        int(np.random.default_rng(seed).integers(2**63)),
    )

    for future in futures.values():
      future.result()

  def test_precompile_with_shape_dtype_struct(self):
    """precompile() should work with abstract ShapeDtypeStruct inputs."""
    abstract_params = jax.ShapeDtypeStruct((3,), jnp.float32)
    abstract_dataset = jax.ShapeDtypeStruct((5, 3), jnp.float32)
    plan = _make_plan(iterations=3)
    optimizer = optax.sgd(0.01)

    trainer = training.DPTrainer(
        plan=plan,
        loss_fn=_quadratic_loss,
        optimizer=optimizer,
    )
    futures = trainer.precompile(abstract_dataset, abstract_params, rng=0)

    self.assertNotEmpty(futures)
    for future in futures.values():
      future.result()


# Run the absltest runner when this file is executed as a script.
if __name__ == '__main__':
absltest.main()
Loading