Skip to content

Commit 16b5594

Browse files
authored
Add support for exact iteration (#139)
1 parent 26bf6b2 commit 16b5594

File tree

8 files changed

+60
-20
lines changed

8 files changed

+60
-20
lines changed

src/litdata/streaming/combined.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,15 +80,27 @@ def __init__(
8080
self._use_streaming_dataloader = False
8181
self._num_samples_yielded: Optional[List[int]] = None
8282
self._current_epoch = 0
83+
self.num_workers = 1
84+
self.batch_size = 1
8385

84-
def __len__(self) -> Optional[int]:
86+
def get_len(self, num_workers: int, batch_size: int) -> Optional[int]:
87+
self.num_workers = num_workers
88+
self.batch_size = batch_size
8589
if self._iterate_over_all:
8690
return self._get_total_length()
8791
return None
8892

93+
def __len__(self) -> Optional[int]:
94+
return self.get_len(1, 1)
95+
8996
# total length of the datasets
9097
def _get_total_length(self) -> int:
91-
return sum(len(d) for d in self._datasets)
98+
return sum(self._get_len(d) for d in self._datasets)
99+
100+
def _get_len(self, d: Any) -> int:
101+
if isinstance(d, StreamingDataset):
102+
return d.get_len(self.num_workers, self.batch_size)
103+
return len(d)
92104

93105
def set_epoch(self, current_epoch: int) -> None:
94106
"""Set the current epoch to the datasets on epoch starts.

src/litdata/streaming/dataloader.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,16 @@ def __iter__(self) -> Any:
615615

616616
self.restore = False
617617

618+
def __len__(self) -> int:
619+
if self._dataset_kind == _DatasetKind.Iterable:
620+
length = self._IterableDataset_len_called = self.dataset.get_len(self.num_workers, self.batch_size)
621+
if self.batch_size is not None: # IterableDataset doesn't allow custom sampler or batch_sampler
622+
from math import ceil
623+
624+
return length // self.batch_size if self.drop_last else ceil(length / self.batch_size)
625+
return length
626+
return len(self._index_sampler)
627+
618628
def state_dict(self) -> Dict[str, Any]:
619629
if isinstance(self.dataset, StreamingDataset):
620630
assert self.batch_size

src/litdata/streaming/dataset.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ def __init__(
107107
self.shuffler: Optional[Shuffle] = None
108108
self.serializers = serializers
109109
self._state_dict: Optional[Dict[str, Any]] = None
110+
self.num_workers: Optional[int] = None
111+
self.batch_size: Optional[int] = None
110112

111113
def set_shuffle(self, shuffle: bool) -> None:
112114
self.shuffle = shuffle
@@ -157,10 +159,16 @@ def _create_shuffler(self, cache: Cache) -> Shuffle:
157159
return FullShuffle(cache, seed, drop_last) if self.shuffle else NoShuffle(cache, seed, drop_last)
158160

159161
def __len__(self) -> int:
162+
return self.get_len(1, 1)
163+
164+
def get_len(self, num_workers: int, batch_size: int) -> int:
165+
self.num_workers = num_workers
166+
self.batch_size = batch_size
167+
worker_env = _WorkerEnv.detect()
160168
if self.shuffler is None:
161-
cache = self._create_cache(worker_env=_WorkerEnv.detect())
169+
cache = self._create_cache(worker_env=worker_env)
162170
self.shuffler = self._create_shuffler(cache)
163-
return self.shuffler.get_len(self.distributed_env, self.current_epoch)
171+
return self.shuffler.get_len(self.distributed_env, num_workers, batch_size, self.current_epoch)
164172

165173
def __iter__(self) -> "StreamingDataset":
166174
# When the StreamingDataset is used within map or optimize, let's refetch the distributed env.
@@ -178,7 +186,7 @@ def __iter__(self) -> "StreamingDataset":
178186
self.current_epoch = state["current_epoch"]
179187

180188
chunks_per_replica, intervals_per_replica = self.shuffler.get_chunks_and_intervals_per_ranks(
181-
self.distributed_env, self.current_epoch
189+
self.distributed_env, self.worker_env.world_size, self.batch_size or 1, self.current_epoch
182190
)
183191
chunks_replica = chunks_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size]
184192
intervals_replica = intervals_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size]
@@ -187,10 +195,6 @@ def __iter__(self) -> "StreamingDataset":
187195
if self._state_dict:
188196
self._resume(chunks_replica, intervals_replica)
189197
else:
190-
chunks_per_replica, intervals_per_replica = self.shuffler.get_chunks_and_intervals_per_ranks(
191-
self.distributed_env, self.current_epoch
192-
)
193-
194198
# Find the chunks shared across multiple ranks.
195199
# For each shared chunk, find the rank to use the chunk last and prevent deletion
196200
# for the other ranks.

src/litdata/streaming/shuffle.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@ def __init__(self, cache: Cache, seed: int, drop_last: bool):
3131
self.drop_last = drop_last
3232

3333
@lru_cache(maxsize=10)
34-
def get_len(self, distributed_env: _DistributedEnv, current_epoch: int) -> int:
35-
_, intervals_per_ranks = self.get_chunks_and_intervals_per_ranks(distributed_env, current_epoch)
34+
def get_len(self, distributed_env: _DistributedEnv, num_workers: int, batch_size: int, current_epoch: int) -> int:
35+
_, intervals_per_ranks = self.get_chunks_and_intervals_per_ranks(
36+
distributed_env, num_workers, batch_size, current_epoch
37+
)
3638

3739
if self.drop_last:
3840
items_per_process = [
@@ -46,7 +48,9 @@ def get_len(self, distributed_env: _DistributedEnv, current_epoch: int) -> int:
4648
return sum((interval[-1] - interval[0]) for interval in intervals_per_ranks[distributed_env.global_rank])
4749

4850
@abstractmethod
49-
def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
51+
def get_chunks_and_intervals_per_ranks(
52+
self, distributed_env: _DistributedEnv, num_workers: int, batch_size: int, current_epoch: int
53+
) -> Any:
5054
pass
5155

5256
@abstractmethod
@@ -59,14 +63,16 @@ class NoShuffle(Shuffle):
5963
is True."""
6064

6165
@lru_cache(maxsize=10)
62-
def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
66+
def get_chunks_and_intervals_per_ranks(
67+
self, distributed_env: _DistributedEnv, num_workers: int, batch_size: int, current_epoch: int
68+
) -> Any:
6369
# 1. Get the intervals
6470
chunk_intervals = self.cache.get_chunk_intervals()
6571
indexes = range(len(chunk_intervals))
6672

6773
# 2. Compute the items budget of each rank
6874
chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_ranks(
69-
distributed_env, indexes, chunk_intervals, self.drop_last
75+
distributed_env, indexes, chunk_intervals, self.drop_last, num_workers, batch_size
7076
)
7177

7278
return chunks_per_ranks, intervals_per_ranks
@@ -94,7 +100,9 @@ class FullShuffle(Shuffle):
94100
"""
95101

96102
@lru_cache(maxsize=10)
97-
def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
103+
def get_chunks_and_intervals_per_ranks(
104+
self, distributed_env: _DistributedEnv, num_workers: int, batch_size: int, current_epoch: int
105+
) -> Any:
98106
# 1. Get the intervals
99107
chunk_intervals = self.cache.get_chunk_intervals()
100108

@@ -113,7 +121,7 @@ def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, c
113121

114122
# 3. Compute the items budget of each rank
115123
chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_ranks(
116-
distributed_env, shuffled_indexes, shuffled_chunk_intervals, self.drop_last
124+
distributed_env, shuffled_indexes, shuffled_chunk_intervals, self.drop_last, num_workers, batch_size
117125
)
118126

119127
# For the first epoch, no need of further shuffling
@@ -126,7 +134,7 @@ def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, c
126134
shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes].tolist()
127135

128136
chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_ranks(
129-
distributed_env, shuffled_indexes, shuffled_chunk_intervals, self.drop_last
137+
distributed_env, shuffled_indexes, shuffled_chunk_intervals, self.drop_last, num_workers, batch_size
130138
)
131139

132140
return chunks_per_ranks, intervals_per_ranks

src/litdata/utilities/shuffle.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ def _associate_chunks_and_internals_to_ranks(
4646
indexes: Any,
4747
chunk_intervals: Any,
4848
drop_last: bool,
49+
num_workers: int = 1,
50+
batch_size: int = 1,
4951
) -> Tuple[List[List[int]], List[Any]]:
5052
num_items = sum([(interval[-1] - interval[0]) for interval in chunk_intervals])
5153
num_items_per_ranks: List[int] = [
@@ -54,6 +56,10 @@ def _associate_chunks_and_internals_to_ranks(
5456
else num_items // distributed_env.world_size
5557
for rank in range(distributed_env.world_size)
5658
]
59+
if drop_last:
60+
ratio = num_workers * batch_size
61+
num_items_per_ranks = [ratio * int(item // ratio) for item in num_items_per_ranks]
62+
5763
chunks_per_ranks: List[List[int]] = [[] for _ in range(distributed_env.world_size)]
5864
intervals_per_ranks: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)]
5965

tests/processing/test_data_processor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1208,7 +1208,7 @@ def fetch_from_dataset(batch, output_dir):
12081208
f.write("Hello World!")
12091209

12101210

1211-
@pytest.mark.skipif(sys.platform == "win32", reason="skip windows")
1211+
@pytest.mark.skipif(sys.platform == "win32" or sys.platform == "darwin", reason="skip windows")
12121212
def test_streaming_dataset_in_map(tmpdir):
12131213
seed_everything(42)
12141214

tests/streaming/test_combined.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def test_combined_dataset_with_dataloader_and_one_worker(batch_size):
262262
}
263263

264264

265-
@pytest.mark.skipif(sys.platform == "win32" or sys.platform == "darwin", reason="too slow in CI")
265+
@pytest.mark.skipif(sys.platform == "win32", reason="too slow in CI")
266266
def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
267267
data_dir_1 = os.path.join(tmpdir, "data_1")
268268
data_dir_2 = os.path.join(tmpdir, "data_2")

tests/streaming/test_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir, compression
171171
assert len(process_2_2) == 50 + int(not drop_last)
172172

173173
_, intervals_per_ranks = dataset.shuffler.get_chunks_and_intervals_per_ranks(
174-
dataset.distributed_env, dataset.current_epoch
174+
dataset.distributed_env, 1, 1, dataset.current_epoch
175175
)
176176

177177
assert process_1_1 == process_1_2

0 commit comments

Comments (0)