
Commit 5cae73c

Authored by: deependujha, pre-commit-ci[bot], and bhimrazy
Fix: Chunks deletion issue (#375)
* not sure if it works. hehe

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* added test to check behaviour in CI

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* update

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Bhimraj Yadav <[email protected]>
1 parent: b039b64 · commit: 5cae73c
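Context for the change: `optimize()` runs worker processes that own uploader and downloader threads plus a chunk remover. Before this commit, a worker's `run()` could return while its remover was still deleting `.bin` chunks, so the later `_done()` check on the cache directory could race with those deletions (issue #367). Below is a minimal, self-contained sketch of that race using hypothetical names, not LitData's actual classes, just to show why joining the background worker before inspecting the cache is the right ordering:

```python
import os
import tempfile
import threading


def remover(cache_dir: str) -> None:
    # Stand-in for LitData's remover worker: deletes cached chunks in the background.
    for name in os.listdir(cache_dir):
        if name.endswith(".bin"):
            os.remove(os.path.join(cache_dir, name))


cache_dir = tempfile.mkdtemp()
for i in range(5):
    open(os.path.join(cache_dir, f"chunk-{i}.bin"), "wb").close()

t = threading.Thread(target=remover, args=(cache_dir,))
t.start()

# Without this join, ".bin" files may still be visible below, which is the
# spurious "All the chunks should have been deleted" failure this commit fixes.
t.join()
assert not [f for f in os.listdir(cache_dir) if f.endswith(".bin")]
```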

File tree

2 files changed: +75 -6 lines changed

src/litdata/processing/data_processor.py
tests/processing/test_functions.py

src/litdata/processing/data_processor.py

Lines changed: 20 additions & 6 deletions
```diff
@@ -456,6 +456,7 @@ def run(self) -> None:
         try:
             self._setup()
             self._loop()
+            self._terminate()
         except Exception:
             traceback_format = traceback.format_exc()
             self.error_queue.put(traceback_format)
```
```diff
@@ -469,6 +470,19 @@ def _setup(self) -> None:
         self._start_uploaders()
         self._start_remover()
 
+    def _terminate(self) -> None:
+        """Make sure all the uploaders, downloaders and removers are terminated."""
+        for uploader in self.uploaders:
+            if uploader.is_alive():
+                uploader.join()
+
+        for downloader in self.downloaders:
+            if downloader.is_alive():
+                downloader.join()
+
+        if self.remover and self.remover.is_alive():
+            self.remover.join()
+
     def _loop(self) -> None:
         num_downloader_finished = 0
 
```
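A note on the `_terminate()` pattern above: calling `join()` on a thread or process that has already exited returns immediately, so the `is_alive()` guards are a cheap short-circuit rather than a correctness requirement. A tiny standalone demonstration with plain `threading` (not LitData code):

```python
import threading

t = threading.Thread(target=lambda: None)
t.start()
t.join()  # waits for the worker to finish
t.join()  # joining an already-finished thread returns immediately, no error
```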

```diff
@@ -795,7 +809,7 @@ def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Result:
 
         chunks = [file for file in os.listdir(cache_dir) if file.endswith(".bin")]
         if chunks and delete_cached_files and output_dir.path is not None:
-            raise RuntimeError(f"All the chunks should have been deleted. Found {chunks}")
+            raise RuntimeError(f"All the chunks should have been deleted. Found {chunks} in cache: {cache_dir}")
 
         merge_cache = Cache(cache_dir, chunk_bytes=1)
         node_rank = _get_node_rank()
```
```diff
@@ -1110,6 +1124,10 @@ def run(self, data_recipe: DataRecipe) -> None:
 
             current_total = new_total
             if current_total == num_items:
+                # make sure all processes are terminated
+                for w in self.workers:
+                    if w.is_alive():
+                        w.join()
                 break
 
             if _IS_IN_STUDIO and node_rank == 0 and _ENABLE_STATUS:
```
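This loop change guarantees that every worker process has fully exited, and has therefore run the new `_terminate()`, before the progress loop breaks and `_done()` inspects the cache. A heavily simplified sketch of that ordering with `multiprocessing`, using a hypothetical progress queue rather than the real `DataProcessor` internals:

```python
import multiprocessing as mp


def worker(progress):
    # Report one finished item, like LitData's progress queue.
    progress.put(1)


if __name__ == "__main__":
    progress = mp.Queue()
    workers = [mp.Process(target=worker, args=(progress,)) for _ in range(4)]
    for w in workers:
        w.start()

    done = 0
    while done < len(workers):
        done += progress.get()  # drain progress updates

    # The commit's addition: join every worker before moving on, so none is
    # still uploading or removing chunks when post-processing starts.
    for w in workers:
        if w.is_alive():
            w.join()
```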
```diff
@@ -1118,17 +1136,13 @@ def run(self, data_recipe: DataRecipe) -> None:
 
             # Exit early if all the workers are done.
             # This means there were some kinda of errors.
+            # TODO: Check whether this is still required.
             if all(not w.is_alive() for w in self.workers):
                 raise RuntimeError("One of the worker has failed")
 
         if _TQDM_AVAILABLE:
             pbar.close()
 
-        # TODO: Check whether this is still required.
-        if num_nodes == 1:
-            for w in self.workers:
-                w.join()
-
         print("Workers are finished.")
         result = data_recipe._done(len(user_items), self.delete_cached_files, self.output_dir)
 
```

tests/processing/test_functions.py

Lines changed: 55 additions & 0 deletions
```diff
@@ -1,10 +1,15 @@
+import glob
 import os
+import random
+import shutil
 import sys
+from pathlib import Path
 from unittest import mock
 
 import cryptography
 import numpy as np
 import pytest
+import requests
 from litdata import StreamingDataset, merge_datasets, optimize, walk
 from litdata.processing.functions import _get_input_dir, _resolve_dir
 from litdata.streaming.cache import Cache
```
```diff
@@ -475,3 +480,53 @@ def test_optimize_with_rsa_encryption(tmpdir):
     #     encryption=rsa,
     #     mode="overwrite",
     # )
+
+
+def tokenize(filename: str):
+    with open(filename, encoding="utf-8") as file:
+        text = file.read()
+    text = text.strip().split(" ")
+    word_to_int = {word: random.randint(1, 1000) for word in set(text)}  # noqa: S311
+    tokenized = [word_to_int[word] for word in text]
+
+    yield tokenized
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows")
+def test_optimize_race_condition(tmpdir):
+    # issue: https://github.com/Lightning-AI/litdata/issues/367
+    # run_commands = [
+    #     "mkdir -p tempdir/custom_texts",
+    #     "curl https://www.gutenberg.org/cache/epub/24440/pg24440.txt --output tempdir/custom_texts/book1.txt",
+    #     "curl https://www.gutenberg.org/cache/epub/26393/pg26393.txt --output tempdir/custom_texts/book2.txt",
+    # ]
+    shutil.rmtree(f"{tmpdir}/custom_texts", ignore_errors=True)
+    os.makedirs(f"{tmpdir}/custom_texts", exist_ok=True)
+
+    urls = [
+        "https://www.gutenberg.org/cache/epub/24440/pg24440.txt",
+        "https://www.gutenberg.org/cache/epub/26393/pg26393.txt",
+    ]
+
+    for i, url in enumerate(urls):
+        print(f"downloading {i+1} file")
+        with requests.get(url, stream=True, timeout=10) as r:
+            r.raise_for_status()  # Raise an exception for bad status codes
+
+            with open(f"{tmpdir}/custom_texts/book{i+1}.txt", "wb") as f:
+                for chunk in r.iter_content(chunk_size=8192):
+                    f.write(chunk)
+
+    print("=" * 100)
+
+    train_files = sorted(glob.glob(str(Path(f"{tmpdir}/custom_texts") / "*.txt")))
+    print("=" * 100)
+    print(train_files)
+    print("=" * 100)
+    optimize(
+        fn=tokenize,
+        inputs=train_files,
+        output_dir=f"{tmpdir}/temp",
+        num_workers=1,
+        chunk_bytes="50MB",
+    )
```
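The regression test downloads two public-domain books, tokenizes them with a generator `fn`, and runs `optimize()` end to end; per the linked issue, this setup previously tripped the leftover-chunks `RuntimeError`. One way to run just this test locally (assuming `pytest` and `requests` are installed and gutenberg.org is reachable) is:

```python
import pytest

# Select the single regression test by node id; -s surfaces the download prints.
raise SystemExit(pytest.main(["tests/processing/test_functions.py::test_optimize_race_condition", "-s"]))
```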
