Commit 19d9eab

Enable map over inputs without files input (#19285)
1 parent 4996965 commit 19d9eab

4 files changed: +80 −13 lines changed
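
This commit lets `map` run over items that contain no filepaths at all (for example a plain list of integers): no input directory is resolved and the download/remove machinery is skipped entirely. A minimal local sketch modelled on the new test added below; the output path, cache folder, and writer function are illustrative choices, not part of the changed API:

import os
import tempfile

from lightning.data.streaming.functions import map  # same import path the tests use


def write_index(output_dir: str, index: int) -> None:
    # The user function receives the output directory plus one element of `inputs`.
    with open(os.path.join(output_dir, f"{index}.txt"), "w") as f:
        f.write(str(index))


if __name__ == "__main__":
    cache_dir = os.path.join(tempfile.gettempdir(), "map_cache")
    output_dir = os.path.join(tempfile.gettempdir(), "map_output")
    os.makedirs(output_dir, exist_ok=True)

    # The new test points the optimizer cache folders at a temporary directory; same idea here.
    os.environ["DATA_OPTIMIZER_CACHE_FOLDER"] = cache_dir
    os.environ["DATA_OPTIMIZER_DATA_CACHE_FOLDER"] = cache_dir

    # Items are plain integers: no filepaths, hence no input_dir and nothing to download.
    map(write_index, list(range(5)), output_dir=output_dir, num_workers=1)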

Diff for: src/lightning/data/streaming/constants.py

+1 −1

@@ -26,7 +26,7 @@
 # This is required for full pytree serialization / deserialization support
 _TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0")
 _VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer")
-_LIGHTNING_CLOUD_LATEST = RequirementCache("lightning-cloud>=0.5.57")
+_LIGHTNING_CLOUD_LATEST = RequirementCache("lightning-cloud>=0.5.58")
 _BOTO3_AVAILABLE = RequirementCache("boto3")
 
 # DON'T CHANGE ORDER

Diff for: src/lightning/data/streaming/data_processor.py

+15 −8

@@ -261,7 +261,7 @@ def _get_item_filesizes(items: List[Any], base_path: str = "") -> List[int]:
         flattened_item, _ = tree_flatten(item)
 
         num_bytes = 0
-        for index, element in enumerate(flattened_item):
+        for element in flattened_item:
             if isinstance(element, str) and element.startswith(base_path) and os.path.exists(element):
                 file_bytes = os.path.getsize(element)
                 if file_bytes == 0:
@@ -358,7 +358,7 @@ def _loop(self) -> None:
                         for uploader in self.uploaders:
                             uploader.join()
 
-                    if self.remove:
+                    if self.remove and self.input_dir.path is not None:
                         assert self.remover
                         self.remove_queue.put(None)
                         self.remover.join()
@@ -380,7 +380,7 @@ def _loop(self) -> None:
                 self.progress_queue.put((self.worker_index, self._counter))
                 self._last_time = time()
 
-            if self.remove:
+            if self.remove and self.input_dir.path is not None:
                 self.remove_queue.put(self.paths[index])
 
             try:
@@ -420,6 +420,13 @@ def _try_upload(self, filepath: Optional[str]) -> None:
         self.to_upload_queues[self._counter % self.num_uploaders].put(filepath)
 
     def _collect_paths(self) -> None:
+        if self.input_dir.path is None:
+            for index in range(len(self.items)):
+                self.ready_to_process_queue.put(index)
+            for _ in range(self.num_downloaders):
+                self.ready_to_process_queue.put(None)
+            return
+
         items = []
         for item in self.items:
             flattened_item, spec = tree_flatten(item)
@@ -456,6 +463,8 @@ def _collect_paths(self) -> None:
         self.items = items
 
     def _start_downloaders(self) -> None:
+        if self.input_dir.path is None:
+            return
         for _ in range(self.num_downloaders):
             to_download_queue: Queue = Queue()
             p = Process(
@@ -478,8 +487,9 @@ def _start_downloaders(self) -> None:
             self.to_download_queues[downloader_index].put(None)
 
     def _start_remover(self) -> None:
-        if not self.remove:
+        if not self.remove or self.input_dir.path is None:
             return
+
         self.remover = Process(
             target=_remove_target,
             args=(
@@ -548,9 +558,6 @@ def _handle_data_transform_recipe(self, index: int) -> None:
             for filename in filenames:
                 filepaths.append(os.path.join(directory, filename))
 
-        if len(filepaths) == 0:
-            raise RuntimeError("You haven't saved any files under the `output_dir`.")
-
         for filepath in filepaths:
             self._try_upload(filepath)
 
@@ -804,7 +811,7 @@ def run(self, data_recipe: DataRecipe) -> None:
         if not isinstance(user_items, list):
             raise ValueError("The `prepare_structure` should return a list of item metadata.")
 
-        if self.reorder_files:
+        if self.reorder_files and self.input_dir.path:
             # TODO: Only do this on node 0, and broadcast the item sizes to the other nodes.
             item_sizes = _get_item_filesizes(user_items, base_path=self.input_dir.path)
             workers_user_items = _map_items_to_workers_weighted(

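The heart of the change above: when `input_dir.path` is `None`, a worker skips path collection, downloaders, and the remover, and instead feeds the raw item indices straight into the processing queue, followed by one `None` sentinel per downloader slot so `_loop` still terminates. A simplified, standalone sketch of that dispatch step, not the actual `BaseWorker` method:

from multiprocessing import Queue
from typing import Any, List, Optional


def fill_ready_queue(
    items: List[Any],
    input_dir_path: Optional[str],
    num_downloaders: int,
    ready_to_process_queue: Queue,
) -> None:
    # Mirrors the new early-return branch in `_collect_paths`: with no input
    # directory there is nothing to download, so indices go straight to the
    # processing queue, terminated by one sentinel per downloader.
    if input_dir_path is None:
        for index in range(len(items)):
            ready_to_process_queue.put(index)
        for _ in range(num_downloaders):
            ready_to_process_queue.put(None)
        return
    # Otherwise the original path-collection / download pipeline runs (omitted here).
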
Diff for: src/lightning/data/streaming/functions.py

+19 −4

@@ -16,7 +16,7 @@
 from datetime import datetime
 from pathlib import Path
 from types import FunctionType
-from typing import Any, Callable, Optional, Sequence, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Union
 
 import torch
 
@@ -30,16 +30,28 @@
 from torch.utils._pytree import tree_flatten
 
 
-def _get_input_dir(inputs: Sequence[Any]) -> str:
-    flattened_item, _ = tree_flatten(inputs[0])
+def _get_indexed_paths(data: Any) -> Dict[int, str]:
+    flattened_item, _ = tree_flatten(data)
 
     indexed_paths = {
         index: element
         for index, element in enumerate(flattened_item)
         if isinstance(element, str) and os.path.exists(element)
     }
 
+    return indexed_paths
+
+
+def _get_input_dir(inputs: Sequence[Any]) -> Optional[str]:
+    indexed_paths = _get_indexed_paths(inputs[0])
+
     if len(indexed_paths) == 0:
+        # Check whether the second element has any input_path
+        indexed_paths = _get_indexed_paths(inputs[1])
+        if len(indexed_paths) == 0:
+            return None
+
+        # Every element should have filepaths if any contains one.
         raise ValueError(f"The provided item {inputs[0]} didn't contain any filepaths.")
 
     absolute_path = str(Path(list(indexed_paths.values())[0]).resolve())
@@ -129,6 +141,7 @@ def map(
     machine: Optional[str] = None,
     num_downloaders: Optional[int] = None,
     reorder_files: bool = True,
+    error_when_not_empty: bool = False,
 ) -> None:
     """This function map a callbable over a collection of files possibly in a distributed way.
 
@@ -144,6 +157,7 @@ def map(
         num_downloaders: The number of downloaders per worker.
         reorder_files: By default, reorders the files by file size to distribute work equally among all workers.
             Set this to ``False`` if the order in which samples are processed should be preserved.
+        error_when_not_empty: Whether we should error if the output folder isn't empty.
 
     """
     if not isinstance(inputs, Sequence):
@@ -161,7 +175,8 @@ def map(
             " HINT: You can either use `/teamspace/s3_connections/...` or `/teamspace/datasets/...`."
         )
 
-    _assert_dir_is_empty(output_dir)
+    if error_when_not_empty:
+        _assert_dir_is_empty(output_dir)
 
     input_dir = _resolve_dir(_get_input_dir(inputs))
 

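With this refactor `_get_input_dir` can return `None`: the first item (and, as a fallback, the second) is scanned for existing filepaths via `_get_indexed_paths`, and only when neither contains one does `map` proceed without an input directory. A rough illustration of that detection rule, using the same `tree_flatten` + `os.path.exists` logic; the sample items are made up and the result of the first call depends on what exists on your machine:

import os

from torch.utils._pytree import tree_flatten


def contains_existing_paths(item) -> bool:
    # Same rule as `_get_indexed_paths`: flatten the item and look for strings
    # that point at files which actually exist on disk.
    flattened, _ = tree_flatten(item)
    return any(isinstance(el, str) and os.path.exists(el) for el in flattened)


print(contains_existing_paths({"image": "/etc/hostname", "label": 3}))  # True if that file exists
print(contains_existing_paths(42))                                      # False -> no input_dir resolved

The new `error_when_not_empty` flag makes the previously unconditional `_assert_dir_is_empty(output_dir)` check opt-in: by default, `map` no longer refuses to write into a non-empty output directory.
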
Diff for: tests/tests_data/streaming/test_data_processor.py

+45 −0

@@ -4,6 +4,7 @@
 from typing import Any, List
 from unittest import mock
 
+import lightning_cloud
 import numpy as np
 import pytest
 import torch
@@ -25,6 +26,7 @@
     _wait_for_file_to_exist,
 )
 from lightning.data.streaming.functions import LambdaDataTransformRecipe, map, optimize
+from lightning_cloud import resolver
 from lightning_utilities.core.imports import RequirementCache
 
 _PIL_AVAILABLE = RequirementCache("PIL")
@@ -872,3 +874,46 @@ def test_get_item_filesizes(tmp_path):
     assert os.path.getsize(tmp_path / "empty_file") == 0
     with pytest.raises(RuntimeError, match="has 0 bytes!"):
         _get_item_filesizes([str(tmp_path / "empty_file")])
+
+
+def map_fn_index(output_dir, index):
+    with open(os.path.join(output_dir, f"{index}.JPEG"), "w") as f:
+        f.write("Hello")
+
+
+@pytest.mark.skipif(condition=not _PIL_AVAILABLE or sys.platform == "win32", reason="Requires: ['pil']")
+def test_data_processing_map_without_input_dir(monkeypatch, tmpdir):
+    cache_dir = os.path.join(tmpdir, "cache")
+    output_dir = os.path.join(tmpdir, "target_dir")
+    os.makedirs(output_dir, exist_ok=True)
+    monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", cache_dir)
+    monkeypatch.setenv("DATA_OPTIMIZER_DATA_CACHE_FOLDER", cache_dir)
+
+    map(map_fn_index, list(range(5)), output_dir=output_dir, num_workers=1, reorder_files=True)
+
+    assert sorted(os.listdir(output_dir)) == ["0.JPEG", "1.JPEG", "2.JPEG", "3.JPEG", "4.JPEG"]
+
+
+@pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
+def test_map_error_when_not_empty(monkeypatch, tmpdir):
+    boto3 = mock.MagicMock()
+    client_s3_mock = mock.MagicMock()
+    client_s3_mock.list_objects_v2.return_value = {"KeyCount": 1, "Contents": []}
+    boto3.client.return_value = client_s3_mock
+    monkeypatch.setattr(resolver, "boto3", boto3)
+
+    with pytest.raises(RuntimeError, match="data and datasets are meant to be immutable"):
+        map(
+            map_fn,
+            [0, 1],
+            output_dir=lightning_cloud.resolver.Dir(path=None, url="s3://bucket"),
+            error_when_not_empty=True,
+        )
+
+    with pytest.raises(OSError, match="cache"):
+        map(
+            map_fn,
+            [0, 1],
+            output_dir=lightning_cloud.resolver.Dir(path=None, url="s3://bucket"),
+            error_when_not_empty=False,
+        )
