Add PyGrain MapDataset integration to jax_privacy experimental training #279
Open
copybara-service[bot] wants to merge 1 commit into
1e8c4a4 to e739f72
Integrates PyGrain as an optional dependency for jax_privacy's experimental training API, enabling DP training on datasets backed by PyGrain `MapDataset` (e.g., ArrayRecord on disk).

Key changes:

- New `experimental/_data_loader.py`: Private module that provides a batch iterator for PyGrain `MapDataset` using `BatchSelectionStrategy`. Auto-detects whether to preload: estimates in-memory size from the first element's PyTree structure × dataset length, and preloads if under 1 GiB. Both preload and streaming modes use a `ThreadPoolExecutor` for parallel element loading.
- Updated `experimental/training.py`: `DPTrainer.fit()` now detects PyGrain `MapDataset` inputs (by class name string, not import) and routes to the new loader. The module is never imported unless needed.
- New `experimental/_data_loader_test.py`: Standalone and end-to-end tests for the data loader integration.

PiperOrigin-RevId: 934664099
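For reviewers who want a concrete picture of the preload heuristic and the import-free detection described above, here is a minimal sketch. It is not the code in `_data_loader.py`. Every function name, the `max_workers` default, the MRO-based name check, and the stand-in `MapDataset` class are assumptions made for illustration. Only the 1 GiB threshold, the "first element × length" estimate, and the use of `ThreadPoolExecutor` come from the description.

```python
from concurrent.futures import ThreadPoolExecutor

import jax
import numpy as np

PRELOAD_LIMIT_BYTES = 1 << 30  # 1 GiB, the threshold named in the description


def estimate_dataset_bytes(dataset) -> int:
    """Estimates in-memory size as (bytes of the first element) * len(dataset)."""
    leaves = jax.tree_util.tree_leaves(dataset[0])
    element_bytes = sum(np.asarray(leaf).nbytes for leaf in leaves)
    return element_bytes * len(dataset)


def should_preload(dataset, limit_bytes: int = PRELOAD_LIMIT_BYTES) -> bool:
    """Returns True if the estimated size is strictly under the limit."""
    return estimate_dataset_bytes(dataset) < limit_bytes


def looks_like_map_dataset(obj) -> bool:
    """Matches on class names in the MRO, so PyGrain is never imported here.

    Assumption: the real check may compare names differently.
    """
    return any(cls.__name__ == "MapDataset" for cls in type(obj).__mro__)


def load_elements(dataset, indices, max_workers: int = 8) -> list:
    """Fetches dataset[i] for each index in parallel, preserving index order."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(dataset.__getitem__, indices))


class MapDataset:
    """Hypothetical stand-in with the same class name; not the PyGrain class."""

    def __init__(self, n: int):
        self._n = n

    def __len__(self) -> int:
        return self._n

    def __getitem__(self, i: int) -> dict:
        return {"x": np.full((4,), i, dtype=np.float32), "y": np.int32(i)}


ds = MapDataset(10)
if looks_like_map_dataset(ds) and should_preload(ds):
    # Small dataset: pull every element into memory once.
    elements = load_elements(ds, range(len(ds)))
```

Estimating from a single element assumes all elements have roughly the same shape and dtype. Datasets with variable-length records could be under- or over-estimated, which may be worth a note in the module docstring.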
e739f72 to 291cb23
This PR has been idle for 7 days. Please provide an update or review.