11import os
22import sys
3+ import threading
4+ from pathlib import Path
35from unittest .mock import Mock , patch
46
57import pytest
@@ -33,6 +35,7 @@ def test_file_metadata():
3335
3436
3537def test_file_indexer_init ():
38+ """Test FileIndexer initialization."""
3639 indexer = FileIndexer ()
3740 assert indexer .max_depth == 5
3841 assert indexer .extensions == []
@@ -206,7 +209,6 @@ def mock_download_fileobj(file_path, file_obj):
206209 mock_downloader .download_fileobj .side_effect = mock_download_fileobj
207210
208211 input_dir = "s3://bucket/dataset"
209-
210212 manager = CacheManager (input_dir = input_dir )
211213
212214 file_path = "s3://bucket/dataset/file.jpg"
@@ -241,13 +243,14 @@ def test_streaming_raw_dataset_getitem_index_error(tmp_path):
241243
242244@pytest .mark .skipif (condition = sys .platform == "win32" , reason = "Not supported on windows" )
243245def test_streaming_raw_dataset_getitems (tmp_path ):
244- """Test batch item access."""
246+ """Test synchronous batch item access."""
245247 test_contents = [b"content1" , b"content2" , b"content3" ]
246248 for i , content in enumerate (test_contents ):
247249 (tmp_path / f"file{ i } .jpg" ).write_bytes (content )
248250
249251 dataset = StreamingRawDataset (input_dir = str (tmp_path ), cache_files = False )
250252
253+ # Mock _download_batch to return test contents
251254 async def mock_download_batch (indices ):
252255 return [test_contents [i ] for i in indices ]
253256
@@ -256,6 +259,66 @@ async def mock_download_batch(indices):
256259 assert items == [test_contents [0 ], test_contents [2 ]]
257260
258261
262+ @pytest .mark .asyncio
263+ @pytest .mark .skipif (condition = sys .platform == "win32" , reason = "Not supported on windows" )
264+ async def test_download_batch (tmp_path ):
265+ """Test asynchronous batch download functionality."""
266+ # Create test files with predefined content
267+ test_contents = {
268+ str (tmp_path / "file0.jpg" ): b"content1" ,
269+ str (tmp_path / "file1.jpg" ): b"content2" ,
270+ str (tmp_path / "file2.jpg" ): b"content3" ,
271+ }
272+ for file_path , content in test_contents .items ():
273+ Path (file_path ).write_bytes (content )
274+
275+ # Initialize the dataset
276+ dataset = StreamingRawDataset (input_dir = str (tmp_path ))
277+
278+ # Find indices for specific files
279+ file0_path = str (tmp_path / "file0.jpg" )
280+ file2_path = str (tmp_path / "file2.jpg" )
281+ indices = [
282+ next (i for i , f in enumerate (dataset .files ) if f .path == file0_path ),
283+ next (i for i , f in enumerate (dataset .files ) if f .path == file2_path ),
284+ ]
285+
286+ # Mock _process_item to return content based on file path
287+ async def mock_process_item (file_path ):
288+ return test_contents [file_path ]
289+
290+ # Patch and test _download_batch
291+ with patch .object (dataset , "_process_item" , side_effect = mock_process_item ):
292+ items = await dataset ._download_batch (indices )
293+ assert items == [test_contents [file0_path ], test_contents [file2_path ]]
294+
295+
296+ @pytest .mark .skipif (condition = sys .platform == "win32" , reason = "Not supported on windows" )
297+ def test_thread_safety (tmp_path ):
298+ """Test thread safety in multi-threaded environments."""
299+ test_contents = [b"content1" , b"content2" , b"content3" ]
300+ for i , content in enumerate (test_contents ):
301+ (tmp_path / f"file{ i } .jpg" ).write_bytes (content )
302+
303+ dataset = StreamingRawDataset (input_dir = str (tmp_path ), cache_files = False )
304+
305+ # Mock _download_batch to return test contents
306+ async def mock_download_batch (indices ):
307+ return [test_contents [i ] for i in indices ]
308+
309+ with patch .object (dataset , "_download_batch" , side_effect = mock_download_batch ):
310+
311+ def worker ():
312+ items = dataset .__getitems__ ([0 , 2 ])
313+ assert items == [test_contents [0 ], test_contents [2 ]]
314+
315+ threads = [threading .Thread (target = worker ) for _ in range (3 )]
316+ for thread in threads :
317+ thread .start ()
318+ for thread in threads :
319+ thread .join ()
320+
321+
259322@pytest .mark .skipif (condition = sys .platform == "win32" , reason = "Not supported on windows" )
260323def test_streaming_raw_dataset_getitems_type_error (tmp_path ):
261324 """Test type error for invalid indices type."""
@@ -311,15 +374,14 @@ def transform(x):
311374@pytest .mark .skipif (condition = sys .platform == "win32" , reason = "Not supported on windows" )
312375def test_streaming_raw_dataset_with_dataloader (tmp_path ):
313376 """Test dataset integration with PyTorch DataLoader."""
314- # Create test files
315377 test_contents = [b"content1" , b"content2" , b"content3" , b"content4" ]
316378 for i , content in enumerate (test_contents ):
317379 (tmp_path / f"file{ i } .jpg" ).write_bytes (content )
318380
319381 dataset = StreamingRawDataset (input_dir = str (tmp_path ))
320382
321- # Mock download to return test content
322- def mock_download_async (file_path ):
383+ # Mock async download to return test content
384+ async def mock_download_async (file_path ):
323385 index = int (file_path .split ("file" )[1 ].split ("." )[0 ])
324386 return test_contents [index ]
325387
@@ -335,7 +397,6 @@ def mock_download_async(file_path):
335397@pytest .mark .skipif (condition = sys .platform == "win32" , reason = "Not supported on windows" )
336398def test_streaming_raw_dataset_no_files_error (tmp_path ):
337399 """Test error when no files are found."""
338- # Create empty directory
339400 empty_dir = tmp_path / "empty"
340401 empty_dir .mkdir ()
341402
0 commit comments