 import shutil
 import sys
 import tempfile
-from copy import deepcopy
 from dataclasses import dataclass
 from datetime import datetime
 from time import time
 from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 import torch
-from torch.utils.data import IterableDataset
+from torch.utils.data import IterableDataset, get_worker_info
 
 from lightning.data.streaming import Cache
 from lightning.data.streaming.constants import (
@@ -56,7 +55,7 @@ def __init__(
         drop_last: bool = False,
         seed: int = 42,
         serializers: Optional[Dict[str, Serializer]] = None,
-        checkpoint_interval: int = 60 * 5,
+        checkpoint_interval: Optional[int] = None,
     ) -> None:
         """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.
 
@@ -93,15 +92,19 @@ def __init__(
         self.worker_intervals: List[List[int]] = []
         self.current_indexes: List[int] = []
         self.chunk_index = 0
+        self.num_chunks: Optional[int] = None
         self.global_index = 0
         self.index = 0
         self.has_triggered_download = False
         self.min_items_per_replica: Optional[int] = None
-        self.current_epoch = 0
+        self.current_epoch = 1
         self.random_state = None
         self.shuffler: Optional[Shuffle] = None
         self.serializers = serializers
-        self.checkpoint_interval = checkpoint_interval
+        if sys.platform == "win32":
+            if checkpoint_interval is not None:
+                raise ValueError("The argument `checkpoint_interval` isn't supported on Windows.")
+        self.checkpoint_interval = checkpoint_interval or 60
         self._state_dict: Optional[Dict[str, Dict[str, Any]]] = None
 
     def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
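The constructor change above makes an explicit checkpoint interval optional (with a 60-second fallback) and rejects it on Windows. A minimal usage sketch, assuming the dataset is built from an `input_dir` argument as elsewhere in the class; the import path and directory value are illustrative, not taken from this diff:

from lightning.data.streaming.dataset import StreamingDataset

# No `checkpoint_interval` passed: the dataset falls back to 60 seconds
# internally (`checkpoint_interval or 60`).
dataset = StreamingDataset(input_dir="s3://my-bucket/optimised-data", seed=42)

# On Windows, passing an explicit value raises a ValueError, since
# checkpointing is not supported on that platform.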
@@ -170,14 +173,16 @@ def __iter__(self) -> "StreamingDataset":
             self.worker_chunks.append(chunk_index)
             self.worker_intervals.append(chunk_interval)
 
+        self.num_chunks = len(self.worker_chunks)
+
         # Handle restart
         if self._state_dict:
             state = self._state_dict[str(self.cache.rank)]
 
             # re-generate indexes
             interval = self.worker_intervals[self.chunk_index]
             current_indexes = np.arange(interval[0], interval[1])
-            current_indexes = self.shuffler(current_indexes, self.current_epoch, self.chunk_index)
+            current_indexes = self.shuffler(current_indexes, self.num_chunks, self.current_epoch, self.chunk_index)
             self.current_indexes = current_indexes[state["index"] :]
 
             # Bump the chunk_index
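The shuffler is now also handed the number of chunks assigned to the worker. A toy callable with the matching call signature (a stand-in, not the real `Shuffle` implementation) illustrates the argument order used at the call site:

import numpy as np

def noop_shuffle(indexes: np.ndarray, num_chunks: int, current_epoch: int, chunk_index: int) -> np.ndarray:
    # Same positional order as `self.shuffler(...)` above; a real shuffler
    # would permute `indexes` based on the epoch and chunk position.
    return indexes

shuffled = noop_shuffle(np.arange(0, 128), num_chunks=4, current_epoch=1, chunk_index=0)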
@@ -210,21 +215,22 @@ def __next__(self) -> Any:
 
         # Lazily re-populate the interval to reduce memory usage.
         if len(self.current_indexes) == 0:
-            if self.chunk_index == len(self.worker_intervals):
+            if self.chunk_index == self.num_chunks:
                 self.current_epoch += 1
                 raise StopIteration
 
             # reset index
             self.index = 0
 
             # Checkpoint when reaching a new chunk
-            self.checkpoint(self.chunk_index)
+            self._checkpoint(self.chunk_index)
 
             interval = self.worker_intervals[self.chunk_index]
             current_indexes = np.arange(interval[0], interval[1])
 
             assert self.shuffler is not None
-            self.current_indexes = self.shuffler(current_indexes, self.current_epoch, self.chunk_index)
+            assert self.num_chunks is not None
+            self.current_indexes = self.shuffler(current_indexes, self.num_chunks, self.current_epoch, self.chunk_index)
 
             self.chunk_index += 1
 
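For context on the lazy re-population above: each entry of `worker_intervals` is a `[start, stop)` pair of sample indexes, and only the interval of the chunk currently being read is expanded into concrete indexes. A small sketch with made-up values:

import numpy as np

worker_intervals = [[0, 4], [4, 9]]  # hypothetical per-chunk intervals
current_indexes = np.arange(worker_intervals[0][0], worker_intervals[0][1])
print(current_indexes)  # [0 1 2 3] -- only the active chunk is materialised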
@@ -238,7 +244,7 @@ def __next__(self) -> Any:
                 chunk_index=self.worker_chunks[self.chunk_index - 1],
                 # We provide the chunks indexes only on the first
                 chunk_indexes=None if self.has_triggered_download else self.worker_chunks,
-                last_index=(self.chunk_index - 1) == len(self.worker_intervals) and len(self.current_indexes) == 1,
+                is_last_index=(self.chunk_index - 1) == len(self.worker_intervals) and len(self.current_indexes) == 1,
             )
         )
 
@@ -247,14 +253,16 @@ def __next__(self) -> Any:
         self.index += 1
 
         # Checkpoint based on time
-        if (self.last_time - time()) > self.checkpoint_interval:
-            self.checkpoint(self.chunk_index - 1)
+        if self.checkpoint_interval and (self.last_time - time()) > self.checkpoint_interval:
+            self._checkpoint(self.chunk_index - 1)
 
         return data
 
-    def checkpoint(self, chunk_index: int) -> None:
-        # Checkpointing isn't supported for windows
-        if sys.platform == "win32":
+    def _checkpoint(self, chunk_index: int) -> None:
+        if self.checkpoint_interval is None:
+            return
+
+        if not _is_in_dataloader_worker():
             return
 
         assert self.cache
@@ -284,55 +292,29 @@ def checkpoint(self, chunk_index: int) -> None:
                     f,
                 )
 
-            # 3. Move the file to avoid corrupted read from the main thread.
-            now = datetime.now().strftime(_TIME_FORMAT)
-            checkpoint_path = os.path.join(self.cache.checkpoint_rank_dir, f"checkpoint-{now}.json")
-
             # 4. Move the file to its target position
-            shutil.move(tmp_checkpoint_path, checkpoint_path)
+            shutil.move(tmp_checkpoint_path, os.path.join(self.cache.checkpoint_rank_dir, "checkpoint.json"))
 
         self.last_time = time()
 
     def state_dict(self) -> Dict[str, Any]:
+        if _is_in_dataloader_worker():
+            raise RuntimeError("The method `state_dict` should only be called in the main process.")
+
         if self.cache is None:
             self.worker_env = _WorkerEnv.detect()
             self.cache = self._create_cache(worker_env=self.worker_env)
 
         state_dict: Dict[str, Any] = {}
-        worker_env = _WorkerEnv.detect()
-        if worker_env.world_size == 1:
-            # 1. Check whether the checkpoint_dir exists
-            if not os.path.exists(self.cache.checkpoint_dir):
-                return state_dict
-
-            # 2. Iterate through the workers and read the latest checkpoint
-            for worker_idx in os.listdir(self.cache.checkpoint_dir):
-                checkpoints = os.listdir(os.path.join(self.cache.checkpoint_dir, str(worker_idx)))
-                checkpoints = sorted(checkpoints, key=_string_to_datetime)
-
-                # Load the latest checkpoint for this worker
-                checkpoint_path = os.path.join(self.cache.checkpoint_dir, str(worker_idx), checkpoints[-1])
-                with open(checkpoint_path) as f:
-                    state_dict[worker_idx] = json.load(f)
-
-            _state_dict = deepcopy(state_dict)
-
-            if self.distributed_env.world_size > 1:
-                # TODO: Move this to fabric.
-                num_devices = torch.cuda.device_count() or 1
-                node_ranks = []
-                for index in range(self.distributed_env.world_size):
-                    node_rank = index // num_devices
-                    if node_rank in node_ranks:
-                        continue
-                    state = {}
-                    obj = [_state_dict]
-                    torch.distributed.broadcast_object_list(obj, index, group=_group.WORLD)
-                    state = obj[0]
-                    state_dict.update(**state)
-                    node_ranks.append(node_rank)
-        else:
-            raise NotImplementedError("The `state_dict` should be called on the main thread.")
+
+        # 1. Check whether the checkpoint_dir exists
+        if not os.path.exists(self.cache.checkpoint_dir):
+            return state_dict
+
+        state_dict = _load_state_dict_from_checkpoint_dir(self.cache.checkpoint_dir)
+
+        if self.distributed_env.world_size > 1:
+            return _collect_distributed_state_dict(state_dict, self.distributed_env.world_size)
         return state_dict
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
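Putting the pieces together, a hedged end-to-end sketch (the import path, `input_dir` value, and surrounding loop are assumptions, not part of this diff): workers write per-rank `checkpoint.json` files while iterating, the main process gathers them with `state_dict()`, and `load_state_dict()` restores them before the next iteration:

from torch.utils.data import DataLoader
from lightning.data.streaming.dataset import StreamingDataset

dataset = StreamingDataset(input_dir="s3://my-bucket/optimised-data", checkpoint_interval=120)
dataloader = DataLoader(dataset, num_workers=2)

for batch in dataloader:
    ...  # training step; workers checkpoint their progress periodically

# Main process only: `state_dict()` raises a RuntimeError inside a worker.
state = dataset.state_dict()

# Later, restore before re-iterating to resume from the saved chunk/index positions.
dataset.load_state_dict(state)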
@@ -413,6 +395,42 @@ def _string_to_datetime(item: str) -> datetime:
     return datetime.strptime(item.split("checkpoint-")[1].split(".json")[0], _TIME_FORMAT)
 
 
+def _load_state_dict_from_checkpoint_dir(checkpoint_dir: str) -> Dict[str, Any]:
+    state_dict: Dict[str, Any] = {}
+    if not os.path.exists(checkpoint_dir):
+        return state_dict
+    for worker_idx in os.listdir(checkpoint_dir):
+        checkpoint_filepath = os.path.join(checkpoint_dir, str(worker_idx), "checkpoint.json")
+        if not os.path.exists(checkpoint_filepath):
+            state_dict[worker_idx] = {}
+        else:
+            with open(checkpoint_filepath) as f:
+                state_dict[worker_idx] = json.load(f)
+    return state_dict
+
+
+def _collect_distributed_state_dict(state_dict: Dict[str, Any], world_size: int) -> Dict[str, Any]:
+    state_dict_out: Dict[str, Any] = {}
+    # TODO: Move this to fabric to support all accelerators
+    num_devices = torch.cuda.device_count() or 1
+    node_ranks = []
+    for index in range(world_size):
+        node_rank = index // num_devices
+        if node_rank in node_ranks:
+            continue
+        state = {}
+        obj = [state_dict]
+        torch.distributed.broadcast_object_list(obj, index, group=_group.WORLD)
+        state = obj[0]
+        state_dict_out.update(**state)
+        node_ranks.append(node_rank)
+    return state_dict_out
+
+
+def _is_in_dataloader_worker() -> bool:
+    return get_worker_info() is not None
+
+
 @dataclass
 class RemoteDir:
     """Holds a remote URL to a directory and a cache directory where the data will be downloaded."""
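The new `_is_in_dataloader_worker` helper relies on `torch.utils.data.get_worker_info()`, which returns `None` in the main process and a populated `WorkerInfo` inside a DataLoader worker process. A small self-contained probe (the class name is made up) shows the behaviour the guard depends on:

from torch.utils.data import DataLoader, IterableDataset, get_worker_info

class Probe(IterableDataset):
    # Throwaway dataset whose only purpose is to report where it runs.
    def __iter__(self):
        yield get_worker_info() is not None

if __name__ == "__main__":
    print(get_worker_info() is not None)  # False: main process
    print(next(iter(DataLoader(Probe(), num_workers=1))))  # tensor([True]): inside a worker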