 import shutil
 import sys
 import tempfile
-from copy import deepcopy
 from dataclasses import dataclass
 from datetime import datetime
 from time import time
 from typing import Any, Dict, List, Optional, Union

 import numpy as np
 import torch
-from torch.utils.data import IterableDataset
+from torch.utils.data import IterableDataset, get_worker_info

 from lightning.data.streaming import Cache
 from lightning.data.streaming.constants import (
@@ -56,7 +55,7 @@ def __init__(
         drop_last: bool = False,
         seed: int = 42,
         serializers: Optional[Dict[str, Serializer]] = None,
-        checkpoint_interval: int = 60 * 5,
+        checkpoint_interval: Optional[int] = None,
     ) -> None:
         """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.

@@ -93,15 +92,19 @@ def __init__(
         self.worker_intervals: List[List[int]] = []
         self.current_indexes: List[int] = []
         self.chunk_index = 0
+        self.num_chunks: Optional[int] = None
         self.global_index = 0
         self.index = 0
         self.has_triggered_download = False
         self.min_items_per_replica: Optional[int] = None
-        self.current_epoch = 0
+        self.current_epoch = 1
         self.random_state = None
         self.shuffler: Optional[Shuffle] = None
         self.serializers = serializers
-        self.checkpoint_interval = checkpoint_interval
+        if sys.platform == "win32":
+            if checkpoint_interval is not None:
+                raise ValueError("The argument `checkpoint_interval` isn't supported on Windows.")
+        self.checkpoint_interval = checkpoint_interval or 60
         self._state_dict: Optional[Dict[str, Dict[str, Any]]] = None

     def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
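
With this change, `checkpoint_interval` defaults to `None`, which resolves to 60 seconds on supported platforms, while any explicit value on Windows raises. A minimal sketch of the resulting behaviour (the `input_dir` argument name and import path are assumptions, not shown in this diff):

```python
# Sketch only: `input_dir` and the import path are assumed, not taken from this diff.
from lightning.data.streaming.dataset import StreamingDataset

ds = StreamingDataset(input_dir="s3://my-bucket/optimised")  # checkpoint_interval resolves to 60s
ds = StreamingDataset(input_dir="s3://my-bucket/optimised", checkpoint_interval=300)  # every 5 min
# On Windows, passing any non-None checkpoint_interval raises ValueError.
```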
@@ -170,14 +173,16 @@ def __iter__(self) -> "StreamingDataset":
             self.worker_chunks.append(chunk_index)
             self.worker_intervals.append(chunk_interval)

+        self.num_chunks = len(self.worker_chunks)
+
         # Handle restart
         if self._state_dict:
             state = self._state_dict[str(self.cache.rank)]

             # re-generate indexes
             interval = self.worker_intervals[self.chunk_index]
             current_indexes = np.arange(interval[0], interval[1])
-            current_indexes = self.shuffler(current_indexes, self.current_epoch, self.chunk_index)
+            current_indexes = self.shuffler(current_indexes, self.num_chunks, self.current_epoch, self.chunk_index)
             self.current_indexes = current_indexes[state["index"]:]

             # Bump the chunk_index
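
The shuffler is now invoked with four arguments: `(indexes, num_chunks, current_epoch, chunk_index)`. A hypothetical stand-in matching that signature shows why determinism matters for restarts: replaying the same `(epoch, chunk)` pair must reproduce the same permutation before it is sliced at `state["index"]`:

```python
import numpy as np

def toy_shuffle(indexes: np.ndarray, num_chunks: int, current_epoch: int, chunk_index: int) -> np.ndarray:
    # Deterministic per (epoch, chunk): a restart regenerates the identical order,
    # so `current_indexes[state["index"]:]` resumes at exactly the right sample.
    # `num_chunks` is accepted only to match the new call signature.
    rng = np.random.default_rng([current_epoch, chunk_index])
    return rng.permutation(indexes)
```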
@@ -210,21 +215,22 @@ def __next__(self) -> Any:

         # Lazily re-populate the interval to reduce memory usage.
         if len(self.current_indexes) == 0:
-            if self.chunk_index == len(self.worker_intervals):
+            if self.chunk_index == self.num_chunks:
                 self.current_epoch += 1
                 raise StopIteration

             # reset index
             self.index = 0

             # Checkpoint when reaching a new chunk
-            self.checkpoint(self.chunk_index)
+            self._checkpoint(self.chunk_index)

             interval = self.worker_intervals[self.chunk_index]
             current_indexes = np.arange(interval[0], interval[1])

             assert self.shuffler is not None
-            self.current_indexes = self.shuffler(current_indexes, self.current_epoch, self.chunk_index)
+            assert self.num_chunks is not None
+            self.current_indexes = self.shuffler(current_indexes, self.num_chunks, self.current_epoch, self.chunk_index)

             self.chunk_index += 1

@@ -238,7 +244,7 @@ def __next__(self) -> Any:
                 chunk_index=self.worker_chunks[self.chunk_index - 1],
                 # We provide the chunk indexes only on the first access
                 chunk_indexes=None if self.has_triggered_download else self.worker_chunks,
-                last_index=(self.chunk_index - 1) == len(self.worker_intervals) and len(self.current_indexes) == 1,
+                is_last_index=(self.chunk_index - 1) == len(self.worker_intervals) and len(self.current_indexes) == 1,
             )
         )
@@ -247,14 +253,16 @@ def __next__(self) -> Any:
         self.index += 1

         # Checkpoint based on time
-        if (self.last_time - time()) > self.checkpoint_interval:
-            self.checkpoint(self.chunk_index - 1)
+        if self.checkpoint_interval and (time() - self.last_time) > self.checkpoint_interval:
+            self._checkpoint(self.chunk_index - 1)

         return data

-    def checkpoint(self, chunk_index: int) -> None:
-        # Checkpointing isn't supported for windows
-        if sys.platform == "win32":
+    def _checkpoint(self, chunk_index: int) -> None:
+        if self.checkpoint_interval is None:
+            return
+
+        if not _is_in_dataloader_worker():
             return

         assert self.cache
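
Time-based checkpointing follows the standard elapsed-time pattern: compare `time() - last_time` against the interval and skip entirely when the interval is unset. A self-contained sketch of that trigger logic (names are illustrative, not from this diff):

```python
from time import time
from typing import Optional

class CheckpointTimer:
    """Illustrative helper mirroring the time-based checkpoint trigger."""

    def __init__(self, interval: Optional[int]) -> None:
        self.interval = interval  # None disables time-based checkpointing
        self.last_time = time()

    def should_checkpoint(self) -> bool:
        # Fire once at least `interval` seconds have elapsed since the last checkpoint.
        return bool(self.interval) and (time() - self.last_time) > self.interval

    def mark(self) -> None:
        self.last_time = time()  # reset after a checkpoint has been written
```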
@@ -284,55 +292,29 @@ def checkpoint(self, chunk_index: int) -> None:
                 f,
             )

-        # 3. Move the file to avoid corrupted read from the main thread.
-        now = datetime.now().strftime(_TIME_FORMAT)
-        checkpoint_path = os.path.join(self.cache.checkpoint_rank_dir, f"checkpoint-{now}.json")
-
         # 4. Move the file to its target position
-        shutil.move(tmp_checkpoint_path, checkpoint_path)
+        shutil.move(tmp_checkpoint_path, os.path.join(self.cache.checkpoint_rank_dir, "checkpoint.json"))

         self.last_time = time()

     def state_dict(self) -> Dict[str, Any]:
+        if _is_in_dataloader_worker():
+            raise RuntimeError("The method `state_dict` should only be called in the main process.")
+
         if self.cache is None:
             self.worker_env = _WorkerEnv.detect()
             self.cache = self._create_cache(worker_env=self.worker_env)

         state_dict: Dict[str, Any] = {}
-        worker_env = _WorkerEnv.detect()
-        if worker_env.world_size == 1:
-            # 1. Check whether the checkpoint_dir exists
-            if not os.path.exists(self.cache.checkpoint_dir):
-                return state_dict
-
-            # 2. Iterate through the workers and read the latest checkpoint
-            for worker_idx in os.listdir(self.cache.checkpoint_dir):
-                checkpoints = os.listdir(os.path.join(self.cache.checkpoint_dir, str(worker_idx)))
-                checkpoints = sorted(checkpoints, key=_string_to_datetime)
-
-                # Load the latest checkpoint for this worker
-                checkpoint_path = os.path.join(self.cache.checkpoint_dir, str(worker_idx), checkpoints[-1])
-                with open(checkpoint_path) as f:
-                    state_dict[worker_idx] = json.load(f)
-
-            _state_dict = deepcopy(state_dict)
-
-            if self.distributed_env.world_size > 1:
-                # TODO: Move this to fabric.
-                num_devices = torch.cuda.device_count() or 1
-                node_ranks = []
-                for index in range(self.distributed_env.world_size):
-                    node_rank = index // num_devices
-                    if node_rank in node_ranks:
-                        continue
-                    state = {}
-                    obj = [_state_dict]
-                    torch.distributed.broadcast_object_list(obj, index, group=_group.WORLD)
-                    state = obj[0]
-                    state_dict.update(**state)
-                    node_ranks.append(node_rank)
-        else:
-            raise NotImplementedError("The `state_dict` should be called on the main thread.")
+
+        # 1. Check whether the checkpoint_dir exists
+        if not os.path.exists(self.cache.checkpoint_dir):
+            return state_dict
+
+        state_dict = _load_state_dict_from_checkpoint_dir(self.cache.checkpoint_dir)
+
+        if self.distributed_env.world_size > 1:
+            return _collect_distributed_state_dict(state_dict, self.distributed_env.world_size)
         return state_dict

     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
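
The checkpoint itself is written with the write-to-temp-then-move idiom, so a reader polling `checkpoint.json` never observes a partially written file; when source and destination share a filesystem, the move is an atomic rename on POSIX. A minimal standalone sketch of the same idiom (the diff does not show where `tmp_checkpoint_path` is created, so this sketch assumes the temp file lives in the destination directory):

```python
import json
import os
import shutil
import tempfile

def write_checkpoint_atomically(state: dict, rank_dir: str) -> None:
    os.makedirs(rank_dir, exist_ok=True)
    # 1. Serialize into a temp file on the same filesystem as the destination.
    fd, tmp_path = tempfile.mkstemp(dir=rank_dir, suffix=".json.tmp")
    with os.fdopen(fd, "w") as f:
        json.dump(state, f)
    # 2. Move onto the fixed name readers expect; same-filesystem moves are atomic renames.
    shutil.move(tmp_path, os.path.join(rank_dir, "checkpoint.json"))
```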
@@ -413,6 +395,42 @@ def _string_to_datetime(item: str) -> datetime:
     return datetime.strptime(item.split("checkpoint-")[1].split(".json")[0], _TIME_FORMAT)


+def _load_state_dict_from_checkpoint_dir(checkpoint_dir: str) -> Dict[str, Any]:
+    state_dict: Dict[str, Any] = {}
+    if not os.path.exists(checkpoint_dir):
+        return state_dict
+    for worker_idx in os.listdir(checkpoint_dir):
+        checkpoint_filepath = os.path.join(checkpoint_dir, str(worker_idx), "checkpoint.json")
+        if not os.path.exists(checkpoint_filepath):
+            state_dict[worker_idx] = {}
+        else:
+            with open(checkpoint_filepath) as f:
+                state_dict[worker_idx] = json.load(f)
+    return state_dict
+
+
+def _collect_distributed_state_dict(state_dict: Dict[str, Any], world_size: int) -> Dict[str, Any]:
+    state_dict_out: Dict[str, Any] = {}
+    # TODO: Move this to fabric to support all accelerators
+    num_devices = torch.cuda.device_count() or 1
+    node_ranks = []
+    for index in range(world_size):
+        node_rank = index // num_devices
+        if node_rank in node_ranks:
+            continue
+        state = {}
+        obj = [state_dict]
+        torch.distributed.broadcast_object_list(obj, index, group=_group.WORLD)
+        state = obj[0]
+        state_dict_out.update(**state)
+        node_ranks.append(node_rank)
+    return state_dict_out
+
+
+def _is_in_dataloader_worker() -> bool:
+    return get_worker_info() is not None
+
+
 @dataclass
 class RemoteDir:
     """Holds a remote URL to a directory and a cache directory where the data will be downloaded."""