Skip to content

Commit 577e181

Browse files
authored
Fix map() failing to create dataset when input_dir is None (#100)
1 parent 58f7aeb commit 577e181

File tree

2 files changed

+43
-2
lines changed

2 files changed

+43
-2
lines changed

src/litdata/processing/data_processor.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1019,7 +1019,7 @@ def run(self, data_recipe: DataRecipe) -> None:
10191019
print("Workers are finished.")
10201020
result = data_recipe._done(len(user_items), self.delete_cached_files, self.output_dir)
10211021

1022-
if num_nodes == node_rank + 1 and self.output_dir.url and _IS_IN_STUDIO and self.input_dir.path:
1022+
if num_nodes == node_rank + 1 and self.output_dir.url and _IS_IN_STUDIO:
10231023
assert self.output_dir.path
10241024
_create_dataset(
10251025
input_dir=self.input_dir.path,

tests/processing/test_data_processor.py

Lines changed: 42 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
from pathlib import Path
66
from typing import Any, List
77
from unittest import mock
8+
from unittest.mock import ANY, Mock
89

910
import numpy as np
1011
import pytest
@@ -901,7 +902,7 @@ def map_fn_index(index, output_dir):
901902

902903

903904
@pytest.mark.skipif(condition=not _PIL_AVAILABLE or sys.platform == "win32", reason="Requires: ['pil']")
904-
def test_data_processing_map_without_input_dir(monkeypatch, tmpdir):
905+
def test_data_processing_map_without_input_dir_local(monkeypatch, tmpdir):
905906
cache_dir = os.path.join(tmpdir, "cache")
906907
output_dir = os.path.join(tmpdir, "target_dir")
907908
os.makedirs(output_dir, exist_ok=True)
@@ -920,6 +921,46 @@ def test_data_processing_map_without_input_dir(monkeypatch, tmpdir):
920921
assert sorted(os.listdir(output_dir)) == ["0.JPEG", "1.JPEG", "2.JPEG", "3.JPEG", "4.JPEG"]
921922

922923

924+
@pytest.mark.skipif(sys.platform == "win32", reason="Windows not supported")
def test_data_processing_map_without_input_dir_remote(monkeypatch, tmpdir):
    """Check that ``map`` calls ``_create_dataset`` when no ``input_dir`` is given.

    ``_IS_IN_STUDIO`` is patched to ``True`` and the ``LIGHTNING_*`` environment
    variables are set; ``_resolve_datasets`` and ``_create_dataset`` are replaced
    by mocks. The test then asserts that ``_create_dataset`` was last called
    with ``input_dir=None`` and the mocked output path as ``storage_dir``.
    """
    cache_folder = os.path.join(tmpdir, "cache")
    remote_output = os.path.join("/teamspace", "datasets", "target_dir")
    resolved_output = tmpdir / "output"
    fake_create_dataset = Mock()

    # Environment expected by the processor: cache locations first, then the
    # Lightning identifiers (insertion order matches the original setup).
    environment = {
        "DATA_OPTIMIZER_CACHE_FOLDER": cache_folder,
        "DATA_OPTIMIZER_DATA_CACHE_FOLDER": cache_folder,
        "LIGHTNING_CLUSTER_ID": "1",
        "LIGHTNING_CLOUD_PROJECT_ID": "2",
        "LIGHTNING_CLOUD_SPACE_ID": "3",
    }
    for variable, value in environment.items():
        monkeypatch.setenv(variable, value)

    # Swap module attributes for test doubles.
    # NOTE(review): presumably `_resolve_datasets` is what turns the /teamspace
    # path into a Dir -- confirm against litdata.streaming.resolver.
    patches = (
        ("litdata.processing.data_processor._IS_IN_STUDIO", True),
        (
            "litdata.streaming.resolver._resolve_datasets",
            Mock(return_value=Dir(path=resolved_output, url="url")),
        ),
        ("litdata.processing.data_processor._create_dataset", fake_create_dataset),
    )
    for target, replacement in patches:
        monkeypatch.setattr(target, replacement)

    map(
        map_fn_index,
        list(range(5)),
        output_dir=remote_output,
        num_workers=1,
    )

    # Only `input_dir` and `storage_dir` are pinned; every other keyword must be
    # present but may hold any value.
    unchecked_keywords = (
        "dataset_type",
        "empty",
        "size",
        "num_bytes",
        "data_format",
        "compression",
        "num_chunks",
        "num_bytes_per_chunk",
    )
    fake_create_dataset.assert_called_with(
        input_dir=None,
        storage_dir=str(resolved_output),
        **dict.fromkeys(unchecked_keywords, ANY),
    )
962+
963+
923964
@pytest.mark.skipif(condition=not _PIL_AVAILABLE or sys.platform == "win32", reason="Requires: ['pil']")
924965
def test_data_processing_map_weights_mismatch(monkeypatch, tmpdir):
925966
cache_dir = os.path.join(tmpdir, "cache")

0 commit comments

Comments
 (0)