Skip to content

Add PyGrain MapDataset integration to jax_privacy experimental training#279

Open
copybara-service[bot] wants to merge 1 commit into
mainfrom
test_934664099
Open

Add PyGrain MapDataset integration to jax_privacy experimental training#279
copybara-service[bot] wants to merge 1 commit into
mainfrom
test_934664099

Conversation

@copybara-service

@copybara-service copybara-service Bot commented Jun 19, 2026

Copy link
Copy Markdown
Contributor

Add PyGrain MapDataset integration to jax_privacy experimental training

Integrates PyGrain as an optional dependency for jax_privacy's
experimental training API, enabling DP training on datasets backed
by PyGrain MapDataset (e.g., ArrayRecord on disk).

Key changes:

  • New experimental/_data_loader.py: Private module that provides a
    batch iterator for PyGrain MapDataset using BatchSelectionStrategy.
    Auto-detects whether to preload: estimates in-memory size from the
    first element's PyTree structure × dataset length, and preloads if
    under 1 GiB. Both preload and streaming modes use a
    ThreadPoolExecutor for parallel element loading.
  • Updated experimental/training.py: DPTrainer.fit() now detects
    PyGrain MapDataset inputs (by class name string, not import) and
    routes to the new loader. The module is never imported unless needed.
  • New experimental/_data_loader_test.py: Standalone and end-to-end
    tests for the data loader integration.

@copybara-service copybara-service Bot force-pushed the test_934664099 branch 14 times, most recently from 1e8c4a4 to e739f72 Compare June 19, 2026 14:30
Integrates PyGrain as an optional dependency for jax_privacy's
experimental training API, enabling DP training on datasets backed
by PyGrain MapDataset (e.g., ArrayRecord on disk).

Key changes:
- New `experimental/_data_loader.py`: Private module that provides a
  batch iterator for PyGrain `MapDataset` using `BatchSelectionStrategy`.
  Auto-detects whether to preload: estimates in-memory size from the
  first element's PyTree structure × dataset length, and preloads if
  under 1 GiB.  Both preload and streaming modes use a
  `ThreadPoolExecutor` for parallel element loading.
- Updated `experimental/training.py`: `DPTrainer.fit()` now detects
  PyGrain `MapDataset` inputs (by class name string, not import) and
  routes to the new loader. The module is never imported unless needed.
- New `experimental/_data_loader_test.py`: Standalone and end-to-end
  tests for the data loader integration.

PiperOrigin-RevId: 934664099
@github-actions

Copy link
Copy Markdown

This PR has been idle for 7 days. Please provide an update or review.

@github-actions github-actions Bot added the Stale label Jun 27, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

0 participants