Skip to content

Commit fdc8290

Browse files
authored
fix: StreamingRawDataset Async Handling (#661)
* feat(streaming): simplify async handling in StreamingRawDataset by using asyncio.run() * enhance StreamingRawDataset tests with async and thread safety checks
1 parent 68b74f3 commit fdc8290

File tree

2 files changed

+69
-17
lines changed

2 files changed

+69
-17
lines changed

src/litdata/streaming/raw_dataset.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -339,17 +339,8 @@ def __getitem__(self, index: int) -> Any:
339339

340340
def __getitems__(self, indices: list[int]) -> list[Any]:
341341
"""Asynchronously download multiple items by index."""
342-
return self._run_async(self._download_batch(indices))
343-
344-
def _run_async(self, coro: Any) -> Any:
345-
"""Runs a coroutine, attaching to an existing event loop if one is running."""
346-
try:
347-
loop = asyncio.get_event_loop()
348-
except RuntimeError:
349-
loop = asyncio.new_event_loop()
350-
asyncio.set_event_loop(loop)
351-
352-
return loop.run_until_complete(coro)
342+
# asyncio.run() handles loop creation, execution, and teardown cleanly.
343+
return asyncio.run(self._download_batch(indices))
353344

354345
async def _download_batch(self, indices: list[int]) -> list[Any]:
355346
"""Asynchronously download and transform items."""

tests/streaming/test_raw_dataset.py

Lines changed: 67 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import os
22
import sys
3+
import threading
4+
from pathlib import Path
35
from unittest.mock import Mock, patch
46

57
import pytest
@@ -33,6 +35,7 @@ def test_file_metadata():
3335

3436

3537
def test_file_indexer_init():
38+
"""Test FileIndexer initialization."""
3639
indexer = FileIndexer()
3740
assert indexer.max_depth == 5
3841
assert indexer.extensions == []
@@ -206,7 +209,6 @@ def mock_download_fileobj(file_path, file_obj):
206209
mock_downloader.download_fileobj.side_effect = mock_download_fileobj
207210

208211
input_dir = "s3://bucket/dataset"
209-
210212
manager = CacheManager(input_dir=input_dir)
211213

212214
file_path = "s3://bucket/dataset/file.jpg"
@@ -241,13 +243,14 @@ def test_streaming_raw_dataset_getitem_index_error(tmp_path):
241243

242244
@pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
243245
def test_streaming_raw_dataset_getitems(tmp_path):
244-
"""Test batch item access."""
246+
"""Test synchronous batch item access."""
245247
test_contents = [b"content1", b"content2", b"content3"]
246248
for i, content in enumerate(test_contents):
247249
(tmp_path / f"file{i}.jpg").write_bytes(content)
248250

249251
dataset = StreamingRawDataset(input_dir=str(tmp_path), cache_files=False)
250252

253+
# Mock _download_batch to return test contents
251254
async def mock_download_batch(indices):
252255
return [test_contents[i] for i in indices]
253256

@@ -256,6 +259,66 @@ async def mock_download_batch(indices):
256259
assert items == [test_contents[0], test_contents[2]]
257260

258261

262+
@pytest.mark.asyncio
263+
@pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
264+
async def test_download_batch(tmp_path):
265+
"""Test asynchronous batch download functionality."""
266+
# Create test files with predefined content
267+
test_contents = {
268+
str(tmp_path / "file0.jpg"): b"content1",
269+
str(tmp_path / "file1.jpg"): b"content2",
270+
str(tmp_path / "file2.jpg"): b"content3",
271+
}
272+
for file_path, content in test_contents.items():
273+
Path(file_path).write_bytes(content)
274+
275+
# Initialize the dataset
276+
dataset = StreamingRawDataset(input_dir=str(tmp_path))
277+
278+
# Find indices for specific files
279+
file0_path = str(tmp_path / "file0.jpg")
280+
file2_path = str(tmp_path / "file2.jpg")
281+
indices = [
282+
next(i for i, f in enumerate(dataset.files) if f.path == file0_path),
283+
next(i for i, f in enumerate(dataset.files) if f.path == file2_path),
284+
]
285+
286+
# Mock _process_item to return content based on file path
287+
async def mock_process_item(file_path):
288+
return test_contents[file_path]
289+
290+
# Patch and test _download_batch
291+
with patch.object(dataset, "_process_item", side_effect=mock_process_item):
292+
items = await dataset._download_batch(indices)
293+
assert items == [test_contents[file0_path], test_contents[file2_path]]
294+
295+
296+
@pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
297+
def test_thread_safety(tmp_path):
298+
"""Test thread safety in multi-threaded environments."""
299+
test_contents = [b"content1", b"content2", b"content3"]
300+
for i, content in enumerate(test_contents):
301+
(tmp_path / f"file{i}.jpg").write_bytes(content)
302+
303+
dataset = StreamingRawDataset(input_dir=str(tmp_path), cache_files=False)
304+
305+
# Mock _download_batch to return test contents
306+
async def mock_download_batch(indices):
307+
return [test_contents[i] for i in indices]
308+
309+
with patch.object(dataset, "_download_batch", side_effect=mock_download_batch):
310+
311+
def worker():
312+
items = dataset.__getitems__([0, 2])
313+
assert items == [test_contents[0], test_contents[2]]
314+
315+
threads = [threading.Thread(target=worker) for _ in range(3)]
316+
for thread in threads:
317+
thread.start()
318+
for thread in threads:
319+
thread.join()
320+
321+
259322
@pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
260323
def test_streaming_raw_dataset_getitems_type_error(tmp_path):
261324
"""Test type error for invalid indices type."""
@@ -311,15 +374,14 @@ def transform(x):
311374
@pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
312375
def test_streaming_raw_dataset_with_dataloader(tmp_path):
313376
"""Test dataset integration with PyTorch DataLoader."""
314-
# Create test files
315377
test_contents = [b"content1", b"content2", b"content3", b"content4"]
316378
for i, content in enumerate(test_contents):
317379
(tmp_path / f"file{i}.jpg").write_bytes(content)
318380

319381
dataset = StreamingRawDataset(input_dir=str(tmp_path))
320382

321-
# Mock download to return test content
322-
def mock_download_async(file_path):
383+
# Mock async download to return test content
384+
async def mock_download_async(file_path):
323385
index = int(file_path.split("file")[1].split(".")[0])
324386
return test_contents[index]
325387

@@ -335,7 +397,6 @@ def mock_download_async(file_path):
335397
@pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
336398
def test_streaming_raw_dataset_no_files_error(tmp_path):
337399
"""Test error when no files are found."""
338-
# Create empty directory
339400
empty_dir = tmp_path / "empty"
340401
empty_dir.mkdir()
341402

0 commit comments

Comments
 (0)