# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""PyGrain data loader for differentially private training.

This module provides batch iterators that integrate a PyGrain
``MapDataset`` with :mod:`jax_privacy.batch_selection` strategies.

Two loading modes are supported:

* **Preload** (``preload=True`` or auto-detected): Materializes the full
  dataset into memory as a stacked NumPy PyTree, then uses fast numpy
  fancy-indexing per batch. Best for datasets that comfortably fit in
  host RAM. In multi-controller JAX setups, use the combination of
  shard_options=grain.ShardByJaxProcess() and
  jax.make_array_from_process_local_data().

* **Streaming** (``preload=False``): Loads elements on-demand using a
  ``ThreadPoolExecutor`` for concurrent reads. This is the right choice
  for large or remote datasets (e.g., ArrayRecord on disk/remote storage)
  that should not be held entirely in memory.

When ``preload`` is left as ``None`` (the default), the loader estimates
the in-memory size from the first element's PyTree and the dataset length.
If the estimate is at most ``_PRELOAD_THRESHOLD_BYTES`` (1 GiB), the dataset
is preloaded; otherwise it streams.

Both modes honor ``shard_options`` and both fill padding entries with zeros.

This module is intentionally *not* re-exported from any ``__init__.py`` —
users who do not have PyGrain installed will never import it.
"""

from __future__ import annotations

from collections.abc import Generator
import concurrent.futures
import copy
import logging
from typing import Any

import jax
from jax_privacy import batch_selection
import numpy as np

_PRELOAD_THRESHOLD_BYTES = 1 << 30  # 1 GiB


def is_pygrain_map_dataset(dataset: Any) -> bool:
  """Checks whether ``dataset`` is a PyGrain MapDataset, by class name.

  The check walks the MRO of ``type(dataset)`` and compares class names
  only, so it never has to import grain.

  Args:
    dataset: Any object.

  Returns:
    True if any class in the MRO of ``type(dataset)`` is named
    ``MapDataset``.
  """
  return any(cls.__name__ == "MapDataset" for cls in type(dataset).__mro__)


# ---------------------------------------------------------------------------
# Size estimation
# ---------------------------------------------------------------------------


def _estimate_dataset_bytes(dataset) -> int:
  """Estimates the total in-memory size of the dataset in bytes.

  Computes the size of the first element's PyTree leaves and multiplies it
  by the number of elements.

  Args:
    dataset: A PyGrain ``MapDataset`` supporting ``len`` and ``__getitem__``.

  Returns:
    Estimated size in bytes; 0 for an empty dataset.
  """
  num_elements = len(dataset)
  if num_elements == 0:
    # There is no first element to measure.
    return 0
  # np.asarray lets plain Python scalars / lists count too; they have no
  # ``nbytes`` attribute of their own.
  element_bytes = sum(
      np.asarray(leaf).nbytes for leaf in jax.tree.leaves(dataset[0])
  )
  return element_bytes * num_elements


def _should_preload(dataset) -> bool:
  """Decides whether to preload based on estimated dataset size.

  Args:
    dataset: A PyGrain ``MapDataset`` supporting ``len`` and ``__getitem__``.

  Returns:
    True if the estimate is at most ``_PRELOAD_THRESHOLD_BYTES``.
  """
  estimated = _estimate_dataset_bytes(dataset)
  decision = estimated <= _PRELOAD_THRESHOLD_BYTES
  logging.info(
      "Dataset size estimate: %.2f MiB (%d elements). Preload: %s.",
      estimated / (1 << 20),
      len(dataset),
      decision,
  )
  return decision


# ---------------------------------------------------------------------------
# Sharding helpers (shared by the preload and streaming paths)
# ---------------------------------------------------------------------------


def _validate_sharding(shard_options: Any, pad_to_multiple_of: int) -> None:
  """Raises if padded batches cannot be split evenly across shards.

  Args:
    shard_options: ``None`` or an object exposing ``shard_count``.
    pad_to_multiple_of: The batch padding multiple.

  Raises:
    ValueError: If ``pad_to_multiple_of`` is not a multiple of
      ``shard_options.shard_count``.
  """
  if (
      shard_options is not None
      and pad_to_multiple_of % shard_options.shard_count != 0
  ):
    raise ValueError(
        f"{pad_to_multiple_of=} must be a multiple of"
        f" {shard_options.shard_count=}"
    )


def _select_shard(indices, shard_options: Any):
  """Returns the contiguous slice of ``indices`` owned by this shard.

  Args:
    indices: 1-D array of (padded) example indices for the global batch.
    shard_options: ``None`` (no sharding) or an object exposing
      ``shard_index`` and ``shard_count``.

  Returns:
    ``indices`` unchanged when ``shard_options`` is None, otherwise the
    ``shard_index``-th of ``shard_count`` equal-sized contiguous slices.
  """
  if shard_options is None:
    return indices
  shard_size = len(indices) // shard_options.shard_count
  start_idx = shard_options.shard_index * shard_size
  return indices[start_idx : start_idx + shard_size]


# ---------------------------------------------------------------------------
# Preload helpers
# ---------------------------------------------------------------------------


def _preload(dataset, max_workers: int | None = None):
  """Materializes all elements into a stacked PyTree for fast indexing.

  Args:
    dataset: A PyGrain ``MapDataset`` supporting ``len`` and ``__getitem__``.
    max_workers: Thread pool size used to read elements concurrently.

  Returns:
    A PyTree with the structure of one element, each leaf stacked along a
    new leading axis of length ``len(dataset)``.

  Raises:
    ValueError: If the dataset is empty (there is nothing to stack).
  """
  n = len(dataset)
  if n == 0:
    raise ValueError("Cannot preload an empty dataset.")
  with concurrent.futures.ThreadPoolExecutor(max_workers) as executor:
    elements = list(executor.map(dataset.__getitem__, range(n)))
  return jax.tree.map(lambda *leaves: np.stack(leaves), *elements)


def _get_batch_preloaded(stacked, indices):
  """Indexes into a stacked PyTree with padding support.

  Args:
    stacked: PyTree of arrays whose leading axis indexes examples.
    indices: 1-D integer array of example indices; ``-1`` marks padding.

  Returns:
    ``(batch, is_padding_example)``. Each leaf of ``batch`` has leading
    dimension ``len(indices)`` and the dtype of the corresponding leaf of
    ``stacked``; rows at padding positions are zero.
  """
  indices = np.asarray(indices)
  is_padding_example = indices == -1
  # Replace -1 with 0 for safe indexing, then zero out padding entries.
  safe = np.where(is_padding_example, 0, indices)

  def _index_and_zero(x):
    # Integer-array indexing returns a fresh copy, so zeroing it in place
    # neither touches ``stacked`` nor promotes the dtype (np.where with a
    # literal 0 would turn e.g. bool leaves into integers).
    rows = x[safe]
    rows[is_padding_example] = np.zeros((), dtype=rows.dtype)
    return rows

  return jax.tree.map(_index_and_zero, stacked), is_padding_example


# ---------------------------------------------------------------------------
# Streaming iterator (ThreadPoolExecutor)
# ---------------------------------------------------------------------------


class PrivateBatchIterator:
  """A batch iterator that uses jax_privacy BatchSelectionStrategy.

  This iterator yields batches of data from the given dataset, along with
  a boolean mask indicating which examples in the batch are padding examples.
  ``get_state`` and ``set_state`` are implemented to allow for easy and
  lightweight checkpointing of this batch iterator.

  The iterator owns a thread pool. Call ``close()`` (or use the iterator as
  a context manager) to release it once iteration is finished.
  """

  def __init__(
      self,
      dataset: Any,
      strategy: batch_selection.BatchSelectionStrategy,
      rng: np.random.Generator | int,
      *,
      shard_options: Any = None,
      pad_to_multiple_of: int = 1,
      max_workers: int | None = None,
  ):
    """Initializes the PrivateBatchIterator.

    Args:
      dataset: The dataset from which to draw samples. Each example in the
        dataset should be a PyTree of numpy arrays with common
        structure/shapes.
      strategy: A BatchSelectionStrategy defining how batches should be formed.
      rng: The random number generator or seed to use to sample minibatches.
      shard_options: If specified, only a subset of the batch will be loaded
        based on shard_index and shard_count. In multi-controller JAX setups,
        use grain.ShardByJaxProcess() to have each process load a disjoint
        subset of the batch.
      pad_to_multiple_of: If provided, pad the batch to a multiple of this
        number. Larger values reduce the number of compilations needed in
        downstream JAX code.
      max_workers: The number of workers to use for parallel loading. The
        behavior of the default `max_workers=None` is version-dependent, and
        typically depends on the number of available CPU cores. See
        https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.ThreadPoolExecutor
        for more details.

    Raises:
      ValueError: If ``pad_to_multiple_of`` is not a multiple of
        ``shard_options.shard_count``.
    """
    _validate_sharding(shard_options, pad_to_multiple_of)
    self._dataset = dataset
    self._strategy = strategy
    self._shard_options = shard_options
    self._pad_to_multiple_of = pad_to_multiple_of
    # Number of index arrays drawn from the strategy so far. Draws that are
    # skipped because they are empty are counted too, so that ``set_state``
    # can fast-forward by exactly this many draws.
    self._iteration = 0
    # If rng is used outside of this iterator, our checkpointing logic will
    # fail. We therefore make a deepcopy here to avoid this.
    self._initial_rng = copy.deepcopy(rng)
    self._batch_generator = self._new_batch_generator(self._initial_rng)
    # Pre-fetch the first element to use as a template for padding. Zeros
    # (rather than uninitialized memory) keep padding rows deterministic and
    # consistent with the preloaded path.
    self._padding_element = jax.tree.map(np.zeros_like, dataset[0])
    self._max_workers = max_workers
    self._executor = concurrent.futures.ThreadPoolExecutor(max_workers)

  def _new_batch_generator(self, rng):
    """Builds a fresh index generator from a private copy of ``rng``."""
    return self._strategy.batch_iterator(
        num_examples=len(self._dataset), rng=copy.deepcopy(rng)
    )

  def _get_element(self, idx):
    """Returns the dataset element at ``idx``, or the padding template."""
    # It might be better in some cases to use batched indexing (via grains
    # private interface SupportsBatchedReadRandomAccessDataSource).
    # In this simple benchmark, we do not see significant performance gains with
    # this, but it may work better in other settings.
    if idx == -1:
      return self._padding_element
    return self._dataset[idx]

  def __iter__(self):
    return self

  def __next__(self) -> tuple[Any, np.ndarray]:
    """Returns the next non-empty ``(batch, is_padding_example)`` pair.

    Raises:
      StopIteration: When the underlying strategy is exhausted.
    """
    # Loop (instead of recursing) past empty index arrays so that a long run
    # of empty draws cannot exhaust the Python recursion limit.
    while True:
      indices = batch_selection.pad_to_multiple_of(
          next(self._batch_generator), self._pad_to_multiple_of
      )
      self._iteration += 1
      indices = _select_shard(indices, self._shard_options)
      if indices.size != 0:
        break
    is_padding_example = indices == -1

    batch_elements = list(self._executor.map(self._get_element, indices))
    batch = jax.tree.map(lambda *leaves: np.stack(leaves), *batch_elements)
    return batch, is_padding_example

  def get_state(self) -> dict[str, Any]:
    """Returns a lightweight checkpoint: draw count and the initial rng."""
    return {
        "iteration": self._iteration,
        "initial_rng": self._initial_rng,
    }

  def set_state(self, state: dict[str, Any]):
    """Restores a checkpoint produced by ``get_state``.

    The index generator is rebuilt from the saved initial rng and advanced
    by the saved number of draws. The iterator is only modified if the
    fast-forward succeeds.

    Args:
      state: A dict with keys ``"iteration"`` and ``"initial_rng"``.

    Raises:
      ValueError: If ``state["iteration"]`` exceeds the number of batches
        the strategy produces.
    """
    iteration = state["iteration"]
    initial_rng = state["initial_rng"]
    generator = self._new_batch_generator(initial_rng)
    # Fast-forward the generator
    try:
      for _ in range(iteration):
        next(generator)
    except StopIteration:
      # Do not leak StopIteration out of a non-iterator method.
      raise ValueError(
          f"Cannot restore iteration={iteration}: the batch selection"
          " strategy produces fewer batches."
      ) from None
    self._iteration = iteration
    self._initial_rng = initial_rng
    self._batch_generator = generator

  def close(self) -> None:
    """Shuts down the thread pool; the iterator must not be used after."""
    self._executor.shutdown(wait=True)

  def __enter__(self) -> PrivateBatchIterator:
    return self

  def __exit__(self, exc_type, exc_value, traceback) -> None:
    self.close()


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def iterate_batches(
    dataset: Any,
    strategy: batch_selection.BatchSelectionStrategy,
    rng: np.random.Generator | int,
    *,
    shard_options: Any = None,
    pad_to_multiple_of: int = 1,
    preload: bool | None = None,
    max_workers: int | None = None,
) -> Generator[tuple[Any, np.ndarray], None, None]:
  """Yields ``(batch, is_padding_example)`` tuples from a PyGrain MapDataset.

  Index arrays that are empty after padding and sharding are skipped in
  both modes. The caller's ``rng`` is deep-copied and never advanced.

  Args:
    dataset: A PyGrain ``MapDataset`` supporting ``len`` and ``__getitem__``.
    strategy: A ``BatchSelectionStrategy`` that produces index arrays.
    rng: A NumPy random generator (or seed) for the batch strategy.
    shard_options: If specified, only a subset of the batch will be loaded based
      on shard_index and shard_count. In multi-controller JAX setups, use
      grain.ShardByJaxProcess() to have each process load a disjoint subset of
      the batch. Applied in both preload and streaming modes.
    pad_to_multiple_of: If provided, pad the batch to a multiple of this number.
      Larger values reduce the number of compilations needed in downstream JAX
      code.
    preload: Whether to materialize the full dataset into memory. ``True``
      forces preloading, ``False`` forces streaming, and ``None`` (default)
      auto-decides based on estimated dataset size (preloads if <= 1 GiB).
    max_workers: Maximum thread pool workers. Used for parallel element loading
      in *both* preload and streaming modes.

  Yields:
    ``(batch, is_padding_example)`` where ``batch`` is a stacked PyTree and
    ``is_padding_example`` is a boolean array flagging padding entries.

  Raises:
    ValueError: If ``pad_to_multiple_of`` is not a multiple of
      ``shard_options.shard_count``, or if preloading an empty dataset.
  """
  if preload is None:
    preload = _should_preload(dataset)

  if not preload:
    iterator = PrivateBatchIterator(
        dataset,
        strategy,
        rng,
        shard_options=shard_options,
        pad_to_multiple_of=pad_to_multiple_of,
        max_workers=max_workers,
    )
    try:
      yield from iterator
    finally:
      # Also runs when the consumer abandons the generator early.
      iterator.close()
    return

  # Same contract as the streaming path: validate, then load only this
  # process's slice of every padded batch.
  _validate_sharding(shard_options, pad_to_multiple_of)
  stacked = _preload(dataset, max_workers=max_workers)
  for indices in strategy.batch_iterator(
      len(dataset), rng=copy.deepcopy(rng)
  ):
    indices = batch_selection.pad_to_multiple_of(indices, pad_to_multiple_of)
    indices = _select_shard(indices, shard_options)
    if indices.size == 0:
      continue
    yield _get_batch_preloaded(stacked, indices)
Defaults to no sharding. + preload: Whether to materialize a PyGrain ``MapDataset`` into host memory + for fast numpy indexing. ``True`` forces preloading, ``False`` forces + streaming, and ``None`` (default) auto-decides based on estimated + dataset size (preloads if < 1 GiB). + max_workers: Maximum thread pool workers for concurrent element loading. + Used in both preload and streaming modes. Ignored when the dataset is + an in-memory PyTree. Returns: Final ``TrainingState``. @@ -230,8 +247,6 @@ def fit( rng = np.random.default_rng(rng) loss_rng = jax.random.key(int(rng.integers(2**63))) - num_examples = _validate.batch(dataset) - optimizer = aug_optimizers.as_augmented_optimizer(self.optimizer) state = TrainingState( @@ -241,25 +256,45 @@ def fit( noise_state=self.plan.noise_addition_transform.init(params), ) - batch_iterator = self.plan.batch_selection_strategy.batch_iterator( - num_examples, rng=rng - ) - jit_step = jax.jit(self.train_step) - step = 0 - for indices in batch_iterator: - indices = batch_selection.pad_to_multiple_of( - indices, self.padding_multiple + # Lazy import: only pull in the data loader when the dataset is a + # PyGrain MapDataset. Detection is by class name, not import, so + # users who don't have grain installed never trigger this path. 
+ from jax_privacy.experimental import _data_loader # pylint: disable=g-import-not-at-top,import-outside-toplevel,protected-access + + if _data_loader.is_pygrain_map_dataset(dataset): + batches = _data_loader.iterate_batches( + dataset, + self.plan.batch_selection_strategy, + rng, + shard_options=shard_options, + pad_to_multiple_of=self.padding_multiple, + preload=preload, + max_workers=max_workers, ) - batch, is_padding_example = _get_batch(dataset, indices) + else: + num_examples = _validate.batch(dataset) + batches = self._in_memory_batches(dataset, num_examples, rng) + step = 0 + for batch, is_padding_example in batches: state, aux = jit_step(state, batch, is_padding_example, loss_rng=loss_rng) step += 1 - del indices, batch, is_padding_example + del batch, is_padding_example if callback is not None: callback(step, state, aux) return state + + def _in_memory_batches(self, dataset, num_examples, rng): + """Yields ``(batch, is_padding)`` tuples from an in-memory PyTree.""" + for indices in self.plan.batch_selection_strategy.batch_iterator( + num_examples, rng=rng + ): + indices = batch_selection.pad_to_multiple_of( + indices, self.padding_multiple + ) + yield _get_batch(dataset, indices) diff --git a/tests/experimental/_data_loader_test.py b/tests/experimental/_data_loader_test.py new file mode 100644 index 00000000..fcaaf7f0 --- /dev/null +++ b/tests/experimental/_data_loader_test.py @@ -0,0 +1,367 @@ +# Copyright 2026 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the PyGrain data loader and its use from ``DPTrainer.fit``."""

import concurrent.futures

from absl.testing import absltest
from absl.testing import parameterized
import jax.numpy as jnp
from jax_privacy import batch_selection
from jax_privacy import execution_plan
from jax_privacy.experimental import _data_loader
from jax_privacy.experimental import training
import numpy as np
import optax
import pytest

grain = pytest.importorskip('grain.python')


def _make_index_dataset(num_examples=10, dim=2):
  """Creates a MapDataset of dicts where example ``i`` is ``{'x': [i]*dim}``."""
  examples = [
      {'x': np.full(dim, i, dtype=np.float32)} for i in range(num_examples)
  ]
  return grain.MapDataset.source(examples)


def _poisson(iterations, sampling_prob=1.0):
  """Shorthand for the CyclicPoissonSampling strategy used by these tests."""
  return batch_selection.CyclicPoissonSampling(
      iterations=iterations, sampling_prob=sampling_prob
  )


# ---------------------------------------------------------------------------
# Standalone _data_loader tests
# ---------------------------------------------------------------------------


class IsPygrainMapDatasetTest(absltest.TestCase):
  """Tests for the string-based MapDataset type check."""

  def test_grain_map_dataset_detected(self):
    ds = grain.MapDataset.source([1, 2, 3])
    self.assertTrue(_data_loader.is_pygrain_map_dataset(ds))

  def test_transformed_map_dataset_detected(self):
    ds = grain.MapDataset.source([1, 2, 3]).map(lambda x: x + 1)
    self.assertTrue(_data_loader.is_pygrain_map_dataset(ds))

  def test_numpy_array_rejected(self):
    self.assertFalse(_data_loader.is_pygrain_map_dataset(np.zeros((3, 2))))

  def test_python_list_rejected(self):
    self.assertFalse(_data_loader.is_pygrain_map_dataset([1, 2, 3]))

  def test_dict_rejected(self):
    self.assertFalse(_data_loader.is_pygrain_map_dataset({'x': np.zeros((3,))}))


class IterateBatchesTest(parameterized.TestCase):
  """Tests for iterate_batches with a real grain MapDataset."""

  def test_yields_correct_number_of_batches(self):
    ds = _make_index_dataset(num_examples=8)
    rng = np.random.default_rng(0)

    batches = list(
        _data_loader.iterate_batches(
            ds, _poisson(iterations=5), rng, pad_to_multiple_of=1
        )
    )
    self.assertLen(batches, 5)

  def test_auto_preload_small_dataset(self):
    """Small datasets should auto-preload (default preload=None)."""
    ds = _make_index_dataset(num_examples=4)
    rng = np.random.default_rng(0)
    # Should auto-detect as preload=True and produce valid batches.
    batches = list(
        _data_loader.iterate_batches(
            ds, _poisson(iterations=2), rng, pad_to_multiple_of=1
        )
    )
    self.assertLen(batches, 2)

  def test_estimate_dataset_bytes(self):
    ds = _make_index_dataset(num_examples=10, dim=4)
    # Each element is {'x': float32[4]} = 16 bytes, so total = 160 bytes.
    estimated = _data_loader._estimate_dataset_bytes(ds)
    self.assertEqual(estimated, 160)

  def test_batch_contains_correct_keys(self):
    ds = _make_index_dataset(num_examples=4)
    rng = np.random.default_rng(42)

    ((batch, is_padding),) = list(
        _data_loader.iterate_batches(
            ds, _poisson(iterations=1), rng, pad_to_multiple_of=1
        )
    )
    self.assertIn('x', batch)
    self.assertEqual(batch['x'].ndim, 2)
    self.assertEqual(is_padding.ndim, 1)

  def test_padding_entries_are_zeroed(self):
    ds = _make_index_dataset(num_examples=3, dim=1)
    # sampling_prob=1.0 guarantees all 3 examples are selected.
    # pad_to_multiple_of=4 pads the batch from 3 → 4, so one entry is padding.
    rng = np.random.default_rng(0)

    ((batch, is_padding),) = list(
        _data_loader.iterate_batches(
            ds, _poisson(iterations=1), rng, pad_to_multiple_of=4
        )
    )
    self.assertTrue(is_padding.any())
    np.testing.assert_array_equal(batch['x'][is_padding], 0.0)

  @parameterized.parameters(1, 4, 8)
  def test_padding_multiple_respected(self, padding_multiple):
    ds = _make_index_dataset(num_examples=10)
    strategy = _poisson(iterations=3, sampling_prob=0.5)
    rng = np.random.default_rng(7)

    for batch, _ in _data_loader.iterate_batches(
        ds, strategy, rng, pad_to_multiple_of=padding_multiple
    ):
      self.assertEqual(batch['x'].shape[0] % padding_multiple, 0)

  @parameterized.parameters(True, False)
  def test_each_preload_mode_yields_well_formed_batches(self, preload):
    """Each mode yields the expected number of consistently shaped batches.

    Cross-mode equality of batch contents is covered by
    ``PrivateBatchIteratorTest.test_streaming_matches_preload``.
    """
    ds = _make_index_dataset(num_examples=6, dim=3)
    rng = np.random.default_rng(99)

    batches = list(
        _data_loader.iterate_batches(
            ds,
            _poisson(iterations=4),
            rng,
            pad_to_multiple_of=4,
            preload=preload,
        )
    )
    self.assertLen(batches, 4)
    for batch, is_padding in batches:
      self.assertIn('x', batch)
      self.assertEqual(batch['x'].shape[0] % 4, 0)
      self.assertEqual(batch['x'].shape[1:], (3,))
      self.assertEqual(is_padding.shape, (batch['x'].shape[0],))
      if preload:
        # Preloaded path explicitly zeros padding entries.
        np.testing.assert_array_equal(batch['x'][is_padding], 0.0)


class PrivateBatchIteratorTest(parameterized.TestCase):
  """Tests for the streaming PrivateBatchIterator."""

  def test_yields_correct_number_of_batches(self):
    ds = _make_index_dataset(num_examples=8)
    rng = np.random.default_rng(0)

    it = _data_loader.PrivateBatchIterator(
        ds, _poisson(iterations=5), rng, pad_to_multiple_of=1
    )
    batches = list(it)
    self.assertLen(batches, 5)

  def test_padding_is_flagged_and_batch_is_padded(self):
    ds = _make_index_dataset(num_examples=3, dim=1)
    rng = np.random.default_rng(0)

    it = _data_loader.PrivateBatchIterator(
        ds, _poisson(iterations=1), rng, pad_to_multiple_of=4
    )
    batch, is_padding = next(it)
    self.assertTrue(is_padding.any())
    # Only the mask and the padded shape are checked here; the values stored
    # in padding rows are deliberately not part of this test.
    self.assertEqual(batch['x'].shape[0], 4)
    self.assertEqual(is_padding.shape, (4,))

  def test_get_set_state_roundtrip(self):
    ds = _make_index_dataset(num_examples=6)
    strategy = _poisson(iterations=4)
    rng = np.random.default_rng(42)

    it = _data_loader.PrivateBatchIterator(
        ds, strategy, rng, pad_to_multiple_of=1
    )
    # Consume 2 batches.
    _, _ = next(it)
    _, _ = next(it)
    state = it.get_state()

    # Create new iterator and restore state.
    it2 = _data_loader.PrivateBatchIterator(
        ds, strategy, rng, pad_to_multiple_of=1
    )
    it2.set_state(state)
    batch3, padding3 = next(it2)
    batch3_orig, padding3_orig = next(it)
    np.testing.assert_array_equal(batch3['x'], batch3_orig['x'])
    np.testing.assert_array_equal(padding3, padding3_orig)

  def test_streaming_matches_preload(self):
    """Streaming and preload should produce the same batch contents."""
    ds = _make_index_dataset(num_examples=6, dim=3)
    strategy = _poisson(iterations=4)
    rng = np.random.default_rng(99)

    preloaded = list(
        _data_loader.iterate_batches(
            ds, strategy, rng, pad_to_multiple_of=4, preload=True
        )
    )
    streamed = list(
        _data_loader.iterate_batches(
            ds, strategy, rng, pad_to_multiple_of=4, preload=False
        )
    )
    self.assertLen(preloaded, len(streamed))
    for (pb, pp), (sb, sp) in zip(preloaded, streamed):
      np.testing.assert_array_equal(pp, sp)
      # Only compare non-padding entries so that this test does not depend
      # on how either path fills padding rows.
      non_pad = ~pp
      np.testing.assert_array_equal(pb['x'][non_pad], sb['x'][non_pad])

  def test_max_workers_respected(self):
    """Ensures max_workers parameter is passed through."""
    ds = _make_index_dataset(num_examples=4)
    rng = np.random.default_rng(0)

    it = _data_loader.PrivateBatchIterator(
        ds, _poisson(iterations=1), rng, pad_to_multiple_of=1, max_workers=2
    )
    self.assertIsInstance(it._executor, concurrent.futures.ThreadPoolExecutor)
    self.assertEqual(it._max_workers, 2)
    batch, _ = next(it)
    self.assertIn('x', batch)


# ---------------------------------------------------------------------------
# End-to-end training with a PyGrain MapDataset
# ---------------------------------------------------------------------------


def _quadratic_loss(params, batch, prng):
  """Per-example quadratic loss."""
  del prng
  loss = jnp.mean((params - batch['x']) ** 2)
  return loss, {'loss': loss}


def _make_plan(iterations, noise_multiplier=1.0, sampling_prob=1.0):
  """Builds a single-band BandMF execution plan for the trainer tests."""
  config = execution_plan.BandMFExecutionPlanConfig.default(
      num_bands=1,
      iterations=iterations,
      noise_multiplier=noise_multiplier,
      sampling_prob=sampling_prob,
  )
  return config.make()


class DPTrainerPygrainTest(parameterized.TestCase):
  """End-to-end tests for DPTrainer.fit() with a PyGrain MapDataset."""

  def _make_dataset(self, num_examples=6, dim=2):
    """Creates a MapDataset of seeded standard-normal ``{'x': ...}`` dicts."""
    examples = [
        {'x': np.random.default_rng(i).standard_normal(dim).astype(np.float32)}
        for i in range(num_examples)
    ]
    return grain.MapDataset.source(examples)

  def test_basic_training_completes(self):
    params = jnp.array([5.0, 5.0])
    ds = self._make_dataset(num_examples=4, dim=2)
    plan = _make_plan(iterations=3)
    optimizer = optax.sgd(0.01)

    trainer = training.DPTrainer(
        plan=plan, loss_fn=_quadratic_loss, optimizer=optimizer
    )
    state = trainer.fit(ds, params, rng=0)

    self.assertIsInstance(state, training.TrainingState)
    self.assertEqual(int(state.step), 3)

  def test_params_change(self):
    params = jnp.array([10.0, 10.0])
    ds = self._make_dataset(num_examples=4, dim=2)
    plan = _make_plan(iterations=5, noise_multiplier=0.0)
    optimizer = optax.sgd(0.1)

    trainer = training.DPTrainer(
        plan=plan, loss_fn=_quadratic_loss, optimizer=optimizer
    )
    state = trainer.fit(ds, params, rng=42)

    self.assertFalse(jnp.allclose(state.params, params))

  def test_callback_invoked(self):
    params = jnp.array([1.0, 1.0])
    ds = self._make_dataset(num_examples=4, dim=2)
    iterations = 3
    plan = _make_plan(iterations=iterations)
    optimizer = optax.sgd(0.01)
    log = []

    trainer = training.DPTrainer(
        plan=plan, loss_fn=_quadratic_loss, optimizer=optimizer
    )
    trainer.fit(
        ds,
        params,
        callback=lambda step, state, aux: log.append(step),
        rng=0,
    )

    self.assertEqual(log, [1, 2, 3])

  def test_in_memory_path_still_works(self):
    """Ensure the in-memory (non-pygrain) path is not broken."""
    params = jnp.array([5.0, 5.0])
    dataset = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
    plan = _make_plan(iterations=3)
    optimizer = optax.sgd(0.01)

    def loss_fn(params, batch, prng):
      del prng
      loss = jnp.mean((params - batch) ** 2)
      return loss, {}

    trainer = training.DPTrainer(
        plan=plan, loss_fn=loss_fn, optimizer=optimizer
    )
    state = trainer.fit(dataset, params, rng=0)

    self.assertEqual(int(state.step), 3)


if __name__ == '__main__':
  absltest.main()