From 1b735c66dd7cc69a556c200755da32a480e7a75e Mon Sep 17 00:00:00 2001
From: Dariush Wahdany
Date: Mon, 11 Sep 2023 15:29:17 +0200
Subject: [PATCH 1/3] feat: update lightning example to lightning 2.0

---
 examples/mnist_lightning.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/examples/mnist_lightning.py b/examples/mnist_lightning.py
index 42d93a3d..2155feb0 100644
--- a/examples/mnist_lightning.py
+++ b/examples/mnist_lightning.py
@@ -99,11 +99,8 @@ def configure_optimizers(self):
         optimizer = optim.SGD(self.parameters(), lr=self.lr, momentum=0)
 
         if self.enable_dp:
-            data_loader = (
-                # soon there will be a fancy way to access train dataloader,
-                # see https://github.com/PyTorchLightning/pytorch-lightning/issues/10430
-                self.trainer._data_connector._train_dataloader_source.dataloader()
-            )
+            self.trainer.fit_loop.setup_data()
+            dataloader = self.trainer.train_dataloader
 
             # transform (model, optimizer, dataloader) to DP-versions
             if hasattr(self, "dp"):

From c314d42ad6eddeffe6b2a5ecbcd1a82c927295ab Mon Sep 17 00:00:00 2001
From: Dariush Wahdany <86673488+lsc64@users.noreply.github.com>
Date: Thu, 5 Oct 2023 23:59:31 +0200
Subject: [PATCH 2/3] Update examples/mnist_lightning.py

Co-authored-by: Karthik Prasad
---
 examples/mnist_lightning.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/mnist_lightning.py b/examples/mnist_lightning.py
index 2155feb0..9313dbcd 100644
--- a/examples/mnist_lightning.py
+++ b/examples/mnist_lightning.py
@@ -100,7 +100,7 @@ def configure_optimizers(self):
 
         if self.enable_dp:
             self.trainer.fit_loop.setup_data()
-            dataloader = self.trainer.train_dataloader
+            data_loader = self.trainer.train_dataloader
 
             # transform (model, optimizer, dataloader) to DP-versions
             if hasattr(self, "dp"):

From c35262d778545e468e8526b4acd3e9a91690f954 Mon Sep 17 00:00:00 2001
From: Dariush Wahdany
Date: Fri, 22 Mar 2024 18:26:08 +0100
Subject: [PATCH 3/3] Fix BatchMemoryManager length

---
 opacus/utils/batch_memory_manager.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/opacus/utils/batch_memory_manager.py b/opacus/utils/batch_memory_manager.py
index c5d6dcc0..a2e2de62 100644
--- a/opacus/utils/batch_memory_manager.py
+++ b/opacus/utils/batch_memory_manager.py
@@ -16,12 +16,13 @@
 from typing import List
 
 import numpy as np
+from torch.utils.data import BatchSampler, DataLoader, Sampler
+
 from opacus.optimizers import DPOptimizer
 from opacus.utils.uniform_sampler import (
     DistributedUniformWithReplacementSampler,
     UniformWithReplacementSampler,
 )
-from torch.utils.data import BatchSampler, DataLoader, Sampler
 
 
 class BatchSplittingSampler(Sampler[List[int]]):
@@ -71,13 +72,17 @@ def __iter__(self):
     def __len__(self):
         if isinstance(self.sampler, BatchSampler):
             return int(
-                len(self.sampler) * (self.sampler.batch_size / self.max_batch_size)
+                np.ceil(
+                    len(self.sampler) * (self.sampler.batch_size / self.max_batch_size)
+                )
             )
         elif isinstance(self.sampler, UniformWithReplacementSampler) or isinstance(
             self.sampler, DistributedUniformWithReplacementSampler
         ):
             expected_batch_size = self.sampler.sample_rate * self.sampler.num_samples
-            return int(len(self.sampler) * (expected_batch_size / self.max_batch_size))
+            return int(
+                np.ceil(len(self.sampler) * (expected_batch_size / self.max_batch_size))
+            )
         return len(self.sampler)
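
For context, a minimal sketch of the pattern the first two patches converge on inside the example's configure_optimizers hook: under Lightning 2.0 the train dataloader is a plain trainer attribute once fit_loop.setup_data() has run, replacing the removed private _data_connector access. The make_private call mirrors Opacus' public PrivacyEngine API; the class name and the attributes enable_dp, noise_multiplier, and max_grad_norm are illustrative stand-ins for whatever the example script actually defines.

    import pytorch_lightning as pl
    import torch.optim as optim
    from opacus import PrivacyEngine


    class LitSketch(pl.LightningModule):
        # Hypothetical module: only the optimizer hook matters for this sketch.
        def configure_optimizers(self):
            optimizer = optim.SGD(self.parameters(), lr=0.1, momentum=0)
            if self.enable_dp:
                # Lightning 2.0: setup_data() populates trainer.train_dataloader,
                # which the patched example reads directly.
                self.trainer.fit_loop.setup_data()
                data_loader = self.trainer.train_dataloader
                # Transform (model, optimizer, dataloader) into DP versions.
                model, optimizer, data_loader = PrivacyEngine().make_private(
                    module=self,
                    optimizer=optimizer,
                    data_loader=data_loader,
                    noise_multiplier=self.noise_multiplier,
                    max_grad_norm=self.max_grad_norm,
                )
            return optimizer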
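And a quick numeric check of why the third patch wraps the length estimate in np.ceil (the figures below are made up for illustration). With Poisson sampling the expected logical batch size is rarely an integer multiple of the physical cap, so plain int() truncates the estimate downward, and a trainer that trusts len(dataloader), as Lightning's fit loop does, can stop before the sampler is exhausted.

    import numpy as np

    # Illustrative numbers: 15 logical Poisson-sampled batches, an expected
    # logical batch size of 100 samples (sample_rate * num_samples), and a
    # memory cap of 48 samples per physical batch.
    n_logical, expected_bs, max_physical_bs = 15, 100.0, 48

    estimate = n_logical * (expected_bs / max_physical_bs)  # 31.25
    print(int(estimate))           # 31 -- old __len__: truncates downward
    print(int(np.ceil(estimate)))  # 32 -- patched __len__: rounds upward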