diff --git a/test/stateful_dataloader/test_sampler.py b/test/stateful_dataloader/test_sampler.py
index 7b172c8c0..2a215e1bb 100644
--- a/test/stateful_dataloader/test_sampler.py
+++ b/test/stateful_dataloader/test_sampler.py
@@ -54,7 +54,7 @@ def test_initialization_StatefulDistributedSampler(self):
             seed=42,
             drop_last=False,
         )
-        self.assertEqual(sampler.dataset, self.dataset)
+        self.assertEqual(sampler.dataset_size, len(self.dataset))
         self.assertEqual(sampler.num_replicas, 10)
         self.assertEqual(sampler.rank, 0)
         self.assertFalse(sampler.shuffle)
@@ -232,6 +232,76 @@ def test_seed_replicability(self):
         self.assertEqual(results1, results2, "Data should be replicable with same seed")
         self.assertNotEqual(results1, results3, "Data should not be replicable with different seed")
 
+    def test_StatefulDistributedSampler_initialization_with_dataset_size(self):
+        sampler = StatefulDistributedSampler(dataset_size=100, num_replicas=2, rank=0, shuffle=False)
+        self.assertEqual(sampler.dataset_size, 100)
+        indices = list(iter(sampler))
+        expected_indices = list(range(0, 100, 2))
+        self.assertEqual(indices, expected_indices)
+
+    def test_StatefulDistributedSampler_mismatched_dataset_and_dataset_size(self):
+        dataset = MockDataset(100)
+        with self.assertRaises(ValueError):
+            StatefulDistributedSampler(dataset=dataset, dataset_size=50)
+
+    def test_StatefulDistributedSampler_no_dataset_or_dataset_size(self):
+        with self.assertRaises(ValueError):
+            StatefulDistributedSampler()
+
+    def test_StatefulDistributedSampler_drop_last_with_dataset_size(self):
+        dataset_size = 100
+        num_replicas = 3
+        sampler = StatefulDistributedSampler(
+            dataset_size=dataset_size,
+            num_replicas=num_replicas,
+            rank=0,
+            drop_last=True,
+            shuffle=False,
+        )
+        self.assertEqual(sampler.num_samples, 33)
+        indices = list(iter(sampler))
+        self.assertEqual(len(indices), 33)
+        expected_indices = list(range(0, 99, num_replicas))
+        self.assertEqual(indices, expected_indices)
+
+    def test_StatefulDistributedSampler_dataloader_state_dict_with_dataset_size(self):
+        dataset_size = 100
+        sampler = StatefulDistributedSampler(dataset_size=dataset_size, num_replicas=1, rank=0, shuffle=False)
+        dataset = MockDataset(dataset_size)
+        dataloader = StatefulDataLoader(dataset, batch_size=10, sampler=sampler)
+        iter_count = 5
+        for i, _ in enumerate(dataloader):
+            if i == iter_count - 1:
+                break
+        state_dict = dataloader.state_dict()
+        new_sampler = StatefulDistributedSampler(dataset_size=dataset_size, num_replicas=1, rank=0, shuffle=False)
+        new_dataloader = StatefulDataLoader(MockDataset(dataset_size), batch_size=10, sampler=new_sampler)
+        new_dataloader.load_state_dict(state_dict)
+        resumed_data = []
+        for data in new_dataloader:
+            resumed_data.append(data.tolist())
+        expected_data = []
+        full_dataloader = StatefulDataLoader(MockDataset(dataset_size), batch_size=10, sampler=sampler)
+        for data in full_dataloader:
+            expected_data.append(data.tolist())
+        self.assertEqual(resumed_data, expected_data[iter_count:])
+
+    def test_StatefulDistributedSampler_dataset_size_zero(self):
+        sampler = StatefulDistributedSampler(dataset_size=0, num_replicas=1, rank=0)
+        self.assertEqual(len(sampler), 0)
+        indices = list(iter(sampler))
+        self.assertEqual(len(indices), 0)
+
+    def test_StatefulDistributedSampler_shuffle_with_dataset_size(self):
+        dataset_size = 100
+        sampler = StatefulDistributedSampler(dataset_size=dataset_size, num_replicas=1, rank=0, shuffle=True, seed=42)
+        indices = list(iter(sampler))
+        self.assertEqual(len(indices), dataset_size)
+        self.assertEqual(sorted(indices), list(range(dataset_size)))
+        sampler.set_epoch(1)
+        indices_epoch_1 = list(iter(sampler))
+        self.assertNotEqual(indices, indices_epoch_1)
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torchdata/stateful_dataloader/sampler.py b/torchdata/stateful_dataloader/sampler.py
index cacb1d12c..be2bfe388 100644
--- a/torchdata/stateful_dataloader/sampler.py
+++ b/torchdata/stateful_dataloader/sampler.py
@@ -5,8 +5,11 @@
 # LICENSE file in the root directory of this source tree.
 
 import itertools
+import math
 from typing import Any, Dict, Iterator, List, Optional, Sized
 
+import torch.distributed as dist
+
 import torch.utils.data.sampler
 from torch.utils.data import Dataset
 from torch.utils.data.dataloader import _InfiniteConstantSampler
@@ -179,19 +182,66 @@ def __iter__(self):
         )
 
 
-class StatefulDistributedSampler(torch.utils.data.distributed.DistributedSampler):
+class StatefulDistributedSampler(Sampler[int]):
     _YIELDED = "yielded"
 
     def __init__(
         self,
-        dataset: Dataset,
+        dataset: Optional[Dataset] = None,
         num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
         shuffle: bool = True,
         seed: int = 0,
         drop_last: bool = False,
+        dataset_size: Optional[int] = None,
     ) -> None:
-        super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
+
+        # Validate inputs
+        if dataset is None and dataset_size is None:
+            raise ValueError("Either dataset or dataset_size must be provided.")
+
+        if dataset_size is not None:
+            if dataset is not None and (hasattr(dataset, "__len__") and dataset_size != len(dataset)):
+                raise ValueError(
+                    f"dataset_size must match the length of the dataset. {dataset_size=} and {len(dataset)=}"
+                )
+            self.dataset_size = dataset_size
+        else:
+            if dataset is not None and hasattr(dataset, "__len__"):
+                self.dataset_size = len(dataset)
+            else:
+                raise ValueError("Either a dataset with the __len__ method or dataset_size must be provided.")
+
+        if num_replicas is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = dist.get_world_size()
+        if rank is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = dist.get_rank()
+        if rank >= num_replicas or rank < 0:
+            raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")
+
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.epoch = 0
+        self.drop_last = drop_last
+        # If the dataset length is evenly divisible by # of replicas, then there
+        # is no need to drop any data, since the dataset will be split equally.
+        if self.drop_last and self.dataset_size % self.num_replicas != 0:  # type: ignore[arg-type]
+            # Split to nearest available length that is evenly divisible.
+            # This is to ensure each rank receives the same amount of data when
+            # using this Sampler.
+            self.num_samples = math.ceil(
+                (self.dataset_size - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
+            )
+        else:
+            self.num_samples = math.ceil(self.dataset_size / self.num_replicas)  # type: ignore[arg-type]
+        self.total_size = self.num_samples * self.num_replicas
+        self.shuffle = shuffle
+        self.seed = seed
+
         self.yielded = 0
         self.next_yielded = None
 
@@ -200,11 +250,52 @@ def __iter__(self):
         if self.next_yielded is not None:
             self.yielded = self.next_yielded
             self.next_yielded = None
-        it = super().__iter__()
+        if self.shuffle:
+            # deterministically shuffle based on epoch and seed
+            g = torch.Generator()
+            g.manual_seed(self.seed + self.epoch)
+            indices = torch.randperm(self.dataset_size, generator=g).tolist()  # type: ignore[arg-type]
+        else:
+            indices = list(range(self.dataset_size))  # type: ignore[arg-type]
+
+        if not self.drop_last:
+            # add extra samples to make it evenly divisible
+            padding_size = self.total_size - len(indices)
+            if padding_size <= len(indices):
+                indices += indices[:padding_size]
+            else:
+                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
+        else:
+            # remove tail of data to make it evenly divisible.
+            indices = indices[: self.total_size]
+        assert len(indices) == self.total_size
+
+        # subsample
+        indices = indices[self.rank : self.total_size : self.num_replicas]
+        assert len(indices) == self.num_samples
+
+        it = iter(indices)
+
         for idx in itertools.islice(it, self.yielded, None):
             self.yielded += 1
             yield idx
 
+    def __len__(self) -> int:
+        return self.num_samples
+
+    def set_epoch(self, epoch: int) -> None:
+        r"""
+        Set the epoch for this sampler.
+
+        When :attr:`shuffle=True`, this ensures all replicas
+        use a different random ordering for each epoch. Otherwise, the next iteration of this
+        sampler will yield the same ordering.
+
+        Args:
+            epoch (int): Epoch number.
+        """
+        self.epoch = epoch
+
     def state_dict(self) -> Dict[str, Any]:
         return {self._YIELDED: self.yielded}
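
Usage sketch of the new dataset_size path (not part of the diff; RangeDataset and the sizes below are illustrative assumptions, and the expected outputs assume the resume behavior the tests above exercise): with num_replicas and rank passed explicitly, the sampler can be built without a length-aware dataset object and still drive StatefulDataLoader checkpoint/resume.

# Illustrative sketch only; RangeDataset is a made-up toy dataset for this note.
from torch.utils.data import Dataset

from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler


class RangeDataset(Dataset):
    """Toy map-style dataset that returns its index as the sample."""

    def __init__(self, size: int) -> None:
        self.size = size

    def __len__(self) -> int:
        return self.size

    def __getitem__(self, idx: int) -> int:
        return idx


size = 10
# Rank 0 of 2 replicas, constructed from dataset_size alone; no process group is
# needed because num_replicas and rank are explicit. With shuffle=False this rank
# iterates indices 0, 2, 4, 6, 8.
sampler = StatefulDistributedSampler(dataset_size=size, num_replicas=2, rank=0, shuffle=False)
loader = StatefulDataLoader(RangeDataset(size), batch_size=2, sampler=sampler)

# Consume one batch, checkpoint, then resume a fresh loader from that point.
it = iter(loader)
first = next(it)  # tensor([0, 2])
state = loader.state_dict()

resumed_sampler = StatefulDistributedSampler(dataset_size=size, num_replicas=2, rank=0, shuffle=False)
resumed_loader = StatefulDataLoader(RangeDataset(size), batch_size=2, sampler=resumed_sampler)
resumed_loader.load_state_dict(state)
print([batch.tolist() for batch in resumed_loader])  # expected: [[4, 6], [8]]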