55from pathlib import Path
66from typing import Any , List
77from unittest import mock
8+ from unittest .mock import ANY , Mock
89
910import numpy as np
1011import pytest
@@ -901,7 +902,7 @@ def map_fn_index(index, output_dir):
901902
902903
903904@pytest .mark .skipif (condition = not _PIL_AVAILABLE or sys .platform == "win32" , reason = "Requires: ['pil']" )
904- def test_data_processing_map_without_input_dir (monkeypatch , tmpdir ):
905+ def test_data_processing_map_without_input_dir_local (monkeypatch , tmpdir ):
905906 cache_dir = os .path .join (tmpdir , "cache" )
906907 output_dir = os .path .join (tmpdir , "target_dir" )
907908 os .makedirs (output_dir , exist_ok = True )
@@ -920,6 +921,46 @@ def test_data_processing_map_without_input_dir(monkeypatch, tmpdir):
920921 assert sorted (os .listdir (output_dir )) == ["0.JPEG" , "1.JPEG" , "2.JPEG" , "3.JPEG" , "4.JPEG" ]
921922
922923
924+ @pytest .mark .skipif (sys .platform == "win32" , reason = "Windows not supported" )
925+ def test_data_processing_map_without_input_dir_remote (monkeypatch , tmpdir ):
926+ cache_dir = os .path .join (tmpdir , "cache" )
927+ output_dir = os .path .join ("/teamspace" , "datasets" , "target_dir" )
928+
929+ monkeypatch .setenv ("DATA_OPTIMIZER_CACHE_FOLDER" , cache_dir )
930+ monkeypatch .setenv ("DATA_OPTIMIZER_DATA_CACHE_FOLDER" , cache_dir )
931+
932+ create_dataset_mock = Mock ()
933+ monkeypatch .setenv ("LIGHTNING_CLUSTER_ID" , "1" )
934+ monkeypatch .setenv ("LIGHTNING_CLOUD_PROJECT_ID" , "2" )
935+ monkeypatch .setenv ("LIGHTNING_CLOUD_SPACE_ID" , "3" )
936+ monkeypatch .setattr ("litdata.processing.data_processor._IS_IN_STUDIO" , True )
937+ monkeypatch .setattr (
938+ "litdata.streaming.resolver._resolve_datasets" ,
939+ Mock (return_value = Dir (path = tmpdir / "output" , url = "url" )),
940+ )
941+ monkeypatch .setattr ("litdata.processing.data_processor._create_dataset" , create_dataset_mock )
942+
943+ map (
944+ map_fn_index ,
945+ list (range (5 )),
946+ output_dir = output_dir ,
947+ num_workers = 1 ,
948+ )
949+
950+ create_dataset_mock .assert_called_with (
951+ input_dir = None ,
952+ storage_dir = str (tmpdir / "output" ),
953+ dataset_type = ANY ,
954+ empty = ANY ,
955+ size = ANY ,
956+ num_bytes = ANY ,
957+ data_format = ANY ,
958+ compression = ANY ,
959+ num_chunks = ANY ,
960+ num_bytes_per_chunk = ANY ,
961+ )
962+
963+
923964@pytest .mark .skipif (condition = not _PIL_AVAILABLE or sys .platform == "win32" , reason = "Requires: ['pil']" )
924965def test_data_processing_map_weights_mismatch (monkeypatch , tmpdir ):
925966 cache_dir = os .path .join (tmpdir , "cache" )
0 commit comments