Commit a6da1e3

tchaton and thomas authored
Add fault tolerance Streaming Dataset 2/n (#19052)
Co-authored-by: thomas <[email protected]>
1 parent bf54a1d commit a6da1e3

File tree

8 files changed: +111, -84 lines

src/lightning/data/streaming/cache.py

Lines changed: 14 additions & 9 deletions
@@ -26,7 +26,7 @@
 from lightning.data.streaming.sampler import ChunkedIndex
 from lightning.data.streaming.serializers import Serializer
 from lightning.data.streaming.writer import BinaryWriter
-from lightning.data.utilities.env import _DistributedEnv
+from lightning.data.utilities.env import _DistributedEnv, _WorkerEnv
 from lightning.data.utilities.format import _convert_bytes_to_int

 logger = logging.Logger(__name__)
@@ -93,10 +93,15 @@ def __init__(
         )
         self._is_done = False
         self._distributed_env = _DistributedEnv.detect()
+        self._rank: Optional[int] = None

     @property
     def rank(self) -> int:
-        return self._reader.rank
+        """Returns the rank of the Cache."""
+        if self._rank is None:
+            self._worker_env = _WorkerEnv.detect()
+            self._rank = self._distributed_env.global_rank * self._worker_env.world_size + self._worker_env.rank
+        return self._rank

     @property
     def filled(self) -> bool:
@@ -109,16 +114,16 @@ def filled(self) -> bool:
     @property
     def checkpoint_dir(self) -> str:
         checkpoint_dir = os.path.join(self._cache_dir, "checkpoints")
-        if not os.path.exists(checkpoint_dir):
-            os.makedirs(checkpoint_dir, exist_ok=True)
-        return checkpoint_dir
+        return self._try_create(checkpoint_dir)

     @property
     def checkpoint_rank_dir(self) -> str:
-        checkpoint_rank_dir = os.path.join(self.checkpoint_dir, str(self.rank))
-        if not os.path.exists(checkpoint_rank_dir):
-            os.makedirs(checkpoint_rank_dir, exist_ok=True)
-        return checkpoint_rank_dir
+        checkpoint_rank_dir = os.path.join(self._cache_dir, "checkpoints", str(self.rank))
+        return self._try_create(checkpoint_rank_dir)
+
+    def _try_create(self, path: str) -> str:
+        os.makedirs(path, exist_ok=True)
+        return path

     def __setitem__(self, index: int, data: Any) -> None:
         """Store an item in the writer."""

src/lightning/data/streaming/dataset.py

Lines changed: 72 additions & 54 deletions
@@ -17,15 +17,14 @@
 import shutil
 import sys
 import tempfile
-from copy import deepcopy
 from dataclasses import dataclass
 from datetime import datetime
 from time import time
 from typing import Any, Dict, List, Optional, Union

 import numpy as np
 import torch
-from torch.utils.data import IterableDataset
+from torch.utils.data import IterableDataset, get_worker_info

 from lightning.data.streaming import Cache
 from lightning.data.streaming.constants import (
@@ -56,7 +55,7 @@ def __init__(
         drop_last: bool = False,
         seed: int = 42,
         serializers: Optional[Dict[str, Serializer]] = None,
-        checkpoint_interval: int = 60 * 5,
+        checkpoint_interval: Optional[int] = None,
     ) -> None:
         """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.
@@ -93,15 +92,19 @@ def __init__(
         self.worker_intervals: List[List[int]] = []
         self.current_indexes: List[int] = []
         self.chunk_index = 0
+        self.num_chunks: Optional[int] = None
         self.global_index = 0
         self.index = 0
         self.has_triggered_download = False
         self.min_items_per_replica: Optional[int] = None
-        self.current_epoch = 0
+        self.current_epoch = 1
         self.random_state = None
         self.shuffler: Optional[Shuffle] = None
         self.serializers = serializers
-        self.checkpoint_interval = checkpoint_interval
+        if sys.platform == "win32":
+            if checkpoint_interval is not None:
+                raise ValueError("The argument `checkpoint_interval` isn't suported on Windows.")
+        self.checkpoint_interval = checkpoint_interval or 60
         self._state_dict: Optional[Dict[str, Dict[str, Any]]] = None

     def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
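
Distilled from the __init__ change above, a sketch of the new checkpoint_interval handling (the helper name is ours, not the library's):

import sys
from typing import Optional

def resolve_checkpoint_interval(checkpoint_interval: Optional[int]) -> int:
    # None now means "checkpoint every 60 seconds"; an explicit value is
    # rejected on Windows, where checkpointing is unsupported.
    if sys.platform == "win32" and checkpoint_interval is not None:
        raise ValueError("`checkpoint_interval` isn't supported on Windows.")
    return checkpoint_interval or 60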
@@ -170,14 +173,16 @@ def __iter__(self) -> "StreamingDataset":
             self.worker_chunks.append(chunk_index)
             self.worker_intervals.append(chunk_interval)

+        self.num_chunks = len(self.worker_chunks)
+
         # Handle restart
         if self._state_dict:
             state = self._state_dict[str(self.cache.rank)]

             # re-generate indexes
             interval = self.worker_intervals[self.chunk_index]
             current_indexes = np.arange(interval[0], interval[1])
-            current_indexes = self.shuffler(current_indexes, self.current_epoch, self.chunk_index)
+            current_indexes = self.shuffler(current_indexes, self.num_chunks, self.current_epoch, self.chunk_index)
             self.current_indexes = current_indexes[state["index"] :]

             # Bump the chunk_index
@@ -210,21 +215,22 @@ def __next__(self) -> Any:

         # Lazily re-populate the interval to reduce memory usage.
         if len(self.current_indexes) == 0:
-            if self.chunk_index == len(self.worker_intervals):
+            if self.chunk_index == self.num_chunks:
                 self.current_epoch += 1
                 raise StopIteration

             # reset index
             self.index = 0

             # Checkpoint when reaching a new chunk
-            self.checkpoint(self.chunk_index)
+            self._checkpoint(self.chunk_index)

             interval = self.worker_intervals[self.chunk_index]
             current_indexes = np.arange(interval[0], interval[1])

             assert self.shuffler is not None
-            self.current_indexes = self.shuffler(current_indexes, self.current_epoch, self.chunk_index)
+            assert self.num_chunks is not None
+            self.current_indexes = self.shuffler(current_indexes, self.num_chunks, self.current_epoch, self.chunk_index)

             self.chunk_index += 1

@@ -238,7 +244,7 @@ def __next__(self) -> Any:
                 chunk_index=self.worker_chunks[self.chunk_index - 1],
                 # We provide the chunks indexes only one the first
                 chunk_indexes=None if self.has_triggered_download else self.worker_chunks,
-                last_index=(self.chunk_index - 1) == len(self.worker_intervals) and len(self.current_indexes) == 1,
+                is_last_index=(self.chunk_index - 1) == len(self.worker_intervals) and len(self.current_indexes) == 1,
             )
         )

@@ -247,14 +253,16 @@ def __next__(self) -> Any:
         self.index += 1

         # Checkpoint based on time
-        if (self.last_time - time()) > self.checkpoint_interval:
-            self.checkpoint(self.chunk_index - 1)
+        if self.checkpoint_interval and (self.last_time - time()) > self.checkpoint_interval:
+            self._checkpoint(self.chunk_index - 1)

         return data

-    def checkpoint(self, chunk_index: int) -> None:
-        # Checkpointing isn't supported for windows
-        if sys.platform == "win32":
+    def _checkpoint(self, chunk_index: int) -> None:
+        if self.checkpoint_interval is None:
+            return
+
+        if not _is_in_dataloader_worker():
             return

         assert self.cache
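
A sketch of the time-based trigger in isolation, assuming last_time records the previous checkpoint and a None interval disables checkpointing (the class is ours; note the conventional elapsed-time test is time() - last_time):

from time import time
from typing import Callable, Optional

class PeriodicCheckpointer:
    def __init__(self, interval: Optional[int]) -> None:
        self.interval = interval  # None disables checkpointing entirely
        self.last_time = time()

    def maybe_checkpoint(self, save: Callable[[], None]) -> None:
        # Only checkpoint once more than `interval` seconds have elapsed.
        if self.interval is not None and (time() - self.last_time) > self.interval:
            save()
            self.last_time = time()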
@@ -284,55 +292,29 @@ def checkpoint(self, chunk_index: int) -> None:
                 f,
             )

-        # 3. Move the file to avoid corrupted read from the main thread.
-        now = datetime.now().strftime(_TIME_FORMAT)
-        checkpoint_path = os.path.join(self.cache.checkpoint_rank_dir, f"checkpoint-{now}.json")
-
         # 4. Move the file to its target position
-        shutil.move(tmp_checkpoint_path, checkpoint_path)
+        shutil.move(tmp_checkpoint_path, os.path.join(self.cache.checkpoint_rank_dir, "checkpoint.json"))

         self.last_time = time()

     def state_dict(self) -> Dict[str, Any]:
+        if _is_in_dataloader_worker():
+            raise RuntimeError("The method `state_dict` should only be called in the main process.")
+
         if self.cache is None:
             self.worker_env = _WorkerEnv.detect()
             self.cache = self._create_cache(worker_env=self.worker_env)

         state_dict: Dict[str, Any] = {}
-        worker_env = _WorkerEnv.detect()
-        if worker_env.world_size == 1:
-            # 1. Check whether the checkpoint_dir exists
-            if not os.path.exists(self.cache.checkpoint_dir):
-                return state_dict
-
-            # 2. Iterate through the workers and read the latest checkpoint
-            for worker_idx in os.listdir(self.cache.checkpoint_dir):
-                checkpoints = os.listdir(os.path.join(self.cache.checkpoint_dir, str(worker_idx)))
-                checkpoints = sorted(checkpoints, key=_string_to_datetime)
-
-                # Load the latest checkpoint for this worker
-                checkpoint_path = os.path.join(self.cache.checkpoint_dir, str(worker_idx), checkpoints[-1])
-                with open(checkpoint_path) as f:
-                    state_dict[worker_idx] = json.load(f)
-
-            _state_dict = deepcopy(state_dict)
-
-            if self.distributed_env.world_size > 1:
-                # TODO: Move this to fabric.
-                num_devices = torch.cuda.device_count() or 1
-                node_ranks = []
-                for index in range(self.distributed_env.world_size):
-                    node_rank = index // num_devices
-                    if node_rank in node_ranks:
-                        continue
-                    state = {}
-                    obj = [_state_dict]
-                    torch.distributed.broadcast_object_list(obj, index, group=_group.WORLD)
-                    state = obj[0]
-                    state_dict.update(**state)
-                    node_ranks.append(node_rank)
-        else:
-            raise NotImplementedError("The `state_dict` should be called on the main thread.")
+
+        # 1. Check whether the checkpoint_dir exists
+        if not os.path.exists(self.cache.checkpoint_dir):
+            return state_dict
+
+        state_dict = _load_state_dict_from_checkpoint_dir(self.cache.checkpoint_dir)
+
+        if self.distributed_env.world_size > 1:
+            return _collect_distributed_state_dict(state_dict, self.distributed_env.world_size)
         return state_dict

     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
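
The checkpoint is written to a temporary location first and then moved into place, so a concurrent reader never observes a half-written JSON file. A self-contained sketch of that pattern (the function name is ours; the move is only atomic when source and destination share a filesystem):

import json
import os
import shutil
import tempfile

def write_checkpoint_atomically(state: dict, checkpoint_rank_dir: str) -> None:
    with tempfile.TemporaryDirectory() as tmpdir:
        tmp_path = os.path.join(tmpdir, "checkpoint.json")
        with open(tmp_path, "w") as f:
            json.dump(state, f)
        # The move makes the complete file appear at its final path in one step.
        shutil.move(tmp_path, os.path.join(checkpoint_rank_dir, "checkpoint.json"))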
@@ -413,6 +395,42 @@ def _string_to_datetime(item: str) -> datetime:
     return datetime.strptime(item.split("checkpoint-")[1].split(".json")[0], _TIME_FORMAT)


+def _load_state_dict_from_checkpoint_dir(checkpoint_dir: str) -> Dict[str, Any]:
+    state_dict: Dict[str, Any] = {}
+    if not os.path.exists(checkpoint_dir):
+        return state_dict
+    for worker_idx in os.listdir(checkpoint_dir):
+        checkpoint_filepath = os.path.join(checkpoint_dir, str(worker_idx), "checkpoint.json")
+        if not os.path.exists(checkpoint_filepath):
+            state_dict[worker_idx] = {}
+        else:
+            with open(checkpoint_filepath) as f:
+                state_dict[worker_idx] = json.load(f)
+    return state_dict
+
+
+def _collect_distributed_state_dict(state_dict: Dict[str, Any], world_size: int) -> Dict[str, Any]:
+    state_dict_out: Dict[str, Any] = {}
+    # TODO: Move this to fabric to support all accelerators
+    num_devices = torch.cuda.device_count() or 1
+    node_ranks = []
+    for index in range(world_size):
+        node_rank = index // num_devices
+        if node_rank in node_ranks:
+            continue
+        state = {}
+        obj = [state_dict]
+        torch.distributed.broadcast_object_list(obj, index, group=_group.WORLD)
+        state = obj[0]
+        state_dict_out.update(**state)
+        node_ranks.append(node_rank)
+    return state_dict_out
+
+
+def _is_in_dataloader_worker() -> bool:
+    return get_worker_info() is not None
+
+
 @dataclass
 class RemoteDir:
     """Holds a remote URL to a directory and a cache directory where the data will be downloaded."""

src/lightning/data/streaming/reader.py

Lines changed: 1 addition & 1 deletion
@@ -228,7 +228,7 @@ def read(self, index: ChunkedIndex) -> Any:
         chunk_filepath, begin, _ = self.config[index]
         item = self._item_loader.load_item_from_chunk(index.index, index.chunk_index, chunk_filepath, begin)

-        if index.last_index and self._prepare_thread:
+        if index.is_last_index and self._prepare_thread:
             self._prepare_thread.stop()
             self._prepare_thread = None

src/lightning/data/streaming/sampler.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ class ChunkedIndex:
     index: int
     chunk_index: int
     chunk_indexes: Optional[List[int]] = None
-    last_index: bool = False
+    is_last_index: bool = False


 class CacheBatchSampler:
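
The renamed flag in use, with a minimal copy of the dataclass above (field values are illustrative):

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ChunkedIndex:
    index: int
    chunk_index: int
    chunk_indexes: Optional[List[int]] = None
    is_last_index: bool = False

last = ChunkedIndex(index=7, chunk_index=2, is_last_index=True)  # marks the final sample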

src/lightning/data/streaming/serializers.py

Lines changed: 12 additions & 8 deletions
@@ -149,10 +149,10 @@ class TensorSerializer(Serializer):

     def __init__(self) -> None:
         super().__init__()
-        self._dtype_to_indice = {v: k for k, v in _TORCH_DTYPES_MAPPING.items()}
+        self._dtype_to_indices = {v: k for k, v in _TORCH_DTYPES_MAPPING.items()}

     def serialize(self, item: torch.Tensor) -> Tuple[bytes, Optional[str]]:
-        dtype_indice = self._dtype_to_indice[item.dtype]
+        dtype_indice = self._dtype_to_indices[item.dtype]
         data = [np.uint32(dtype_indice).tobytes()]
         data.append(np.uint32(len(item.shape)).tobytes())
         for dim in item.shape:
@@ -182,14 +182,14 @@ class NoHeaderTensorSerializer(Serializer):

     def __init__(self) -> None:
         super().__init__()
-        self._dtype_to_indice = {v: k for k, v in _TORCH_DTYPES_MAPPING.items()}
+        self._dtype_to_indices = {v: k for k, v in _TORCH_DTYPES_MAPPING.items()}
         self._dtype: Optional[torch.dtype] = None

     def setup(self, data_format: str) -> None:
         self._dtype = _TORCH_DTYPES_MAPPING[int(data_format.split(":")[1])]

     def serialize(self, item: torch.Tensor) -> Tuple[bytes, Optional[str]]:
-        dtype_indice = self._dtype_to_indice[item.dtype]
+        dtype_indice = self._dtype_to_indices[item.dtype]
         return item.numpy().tobytes(order="C"), f"no_header_tensor:{dtype_indice}"

     def deserialize(self, data: bytes) -> torch.Tensor:
@@ -205,10 +205,10 @@ class NumpySerializer(Serializer):

     def __init__(self) -> None:
         super().__init__()
-        self._dtype_to_indice = {v: k for k, v in _NUMPY_DTYPES_MAPPING.items()}
+        self._dtype_to_indices = {v: k for k, v in _NUMPY_DTYPES_MAPPING.items()}

     def serialize(self, item: np.ndarray) -> Tuple[bytes, Optional[str]]:
-        dtype_indice = self._dtype_to_indice[item.dtype]
+        dtype_indice = self._dtype_to_indices[item.dtype]
         data = [np.uint32(dtype_indice).tobytes()]
         data.append(np.uint32(len(item.shape)).tobytes())
         for dim in item.shape:
@@ -221,8 +221,12 @@ def deserialize(self, data: bytes) -> np.ndarray:
         dtype = _NUMPY_DTYPES_MAPPING[dtype_indice]
         shape_size = np.frombuffer(data[4:8], np.uint32).item()
         shape = []
+        # deserialize the shape header
+        # Note: The start position of the shape value: 8 (dtype + shape length) + 4 * shape_idx
         for shape_idx in range(shape_size):
             shape.append(np.frombuffer(data[8 + 4 * shape_idx : 8 + 4 * (shape_idx + 1)], np.uint32).item())
+
+        # deserialize the numpy array bytes
         tensor = np.frombuffer(data[8 + 4 * (shape_idx + 1) : len(data)], dtype=dtype)
         if tensor.shape == shape:
             return tensor
@@ -237,14 +241,14 @@ class NoHeaderNumpySerializer(Serializer):

     def __init__(self) -> None:
         super().__init__()
-        self._dtype_to_indice = {v: k for k, v in _NUMPY_DTYPES_MAPPING.items()}
+        self._dtype_to_indices = {v: k for k, v in _NUMPY_DTYPES_MAPPING.items()}
         self._dtype: Optional[np.dtype] = None

     def setup(self, data_format: str) -> None:
         self._dtype = _NUMPY_DTYPES_MAPPING[int(data_format.split(":")[1])]

     def serialize(self, item: np.ndarray) -> Tuple[bytes, Optional[str]]:
-        dtype_indice: int = self._dtype_to_indice[item.dtype]
+        dtype_indice: int = self._dtype_to_indices[item.dtype]
         return item.tobytes(order="C"), f"no_header_numpy:{dtype_indice}"

     def deserialize(self, data: bytes) -> np.ndarray:
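
The new comments document the byte layout these serializers share: a uint32 dtype index, a uint32 ndim, ndim uint32 dims, then the raw array bytes. A self-contained round trip over that layout, where the small dtype table stands in for _NUMPY_DTYPES_MAPPING:

import numpy as np

DTYPES = {0: np.dtype("float32"), 1: np.dtype("int64")}
DTYPE_TO_INDEX = {v: k for k, v in DTYPES.items()}

def serialize(item: np.ndarray) -> bytes:
    # Header: [uint32 dtype index][uint32 ndim][uint32 per dim], then raw bytes.
    header = [np.uint32(DTYPE_TO_INDEX[item.dtype]).tobytes(), np.uint32(item.ndim).tobytes()]
    header += [np.uint32(dim).tobytes() for dim in item.shape]
    return b"".join(header) + item.tobytes(order="C")

def deserialize(data: bytes) -> np.ndarray:
    dtype = DTYPES[np.frombuffer(data[0:4], np.uint32).item()]
    ndim = np.frombuffer(data[4:8], np.uint32).item()
    # Shape values start at byte 8 (after dtype + ndim), 4 bytes each.
    shape = [np.frombuffer(data[8 + 4 * i : 12 + 4 * i], np.uint32).item() for i in range(ndim)]
    return np.frombuffer(data[8 + 4 * ndim :], dtype=dtype).reshape(shape)

arr = np.arange(6, dtype=np.float32).reshape(2, 3)
assert np.array_equal(deserialize(serialize(arr)), arr)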
