
Commit 37a521c

map operator: Add weights to evenly distribute work among workers (#19365)
1 parent 9d35c61 commit 37a521c

File tree

3 files changed: +43 -4 lines

  src/lightning/data/streaming/data_processor.py (+21 -3)
  src/lightning/data/streaming/functions.py (+2)
  tests/tests_data/streaming/test_data_processor.py (+20 -1)

src/lightning/data/streaming/data_processor.py (+21 -3)

@@ -241,7 +241,10 @@ def _map_items_to_workers_sequentially(num_workers: int, user_items: List[Any])
 
 
 def _map_items_to_workers_weighted(
-    num_workers: int, user_items: List[Any], weights: Optional[List[int]] = None
+    num_workers: int,
+    user_items: List[Any],
+    weights: Optional[List[int]] = None,
+    file_size: bool = True,
 ) -> List[List[Any]]:
     # Associate the items to the workers based on number of nodes and node rank.
     weights = [1] * len(user_items) if weights is None else weights

@@ -255,7 +258,11 @@ def _map_items_to_workers_weighted(
     for worker_id, size in worker_weights.items():
         if worker_id not in worker_ids_this_node:
             continue
-        print(f"Worker {worker_id} gets {size / 1e6:.1f} MB ({len(worker_items[worker_id])} files)")
+
+        if file_size:
+            print(f"Worker {worker_id} gets {size / 1e6:.1f} MB ({len(worker_items[worker_id])} files)")
+        else:
+            print(f"Worker {worker_id} gets ({len(worker_items[worker_id])}) items for a total weight of {size}.")
 
     return [worker_items[worker_id] for worker_id in worker_ids_this_node]
 

@@ -769,6 +776,7 @@ def __init__(
         fast_dev_run: Optional[Union[bool, int]] = None,
         random_seed: Optional[int] = 42,
         reorder_files: bool = True,
+        weights: Optional[List[int]] = None,
     ):
         """The `DatasetOptimiser` provides an efficient way to process data across multiple machine into chunks to make
         training faster.

@@ -784,6 +792,8 @@ def __init__(
             random_seed: The random seed to be set before shuffling the data.
             reorder_files: By default, reorders the files by file size to distribute work equally among all workers.
                 Set this to ``False`` if the order in which samples are processed should be preserved.
+            weights: Provide a list of weights associated to the inputs.
+                This is used to evenly split the work among the workers.
 
         """
         self.input_dir = _resolve_dir(input_dir)

@@ -799,6 +809,7 @@ def __init__(
         self.error_queue: Queue = Queue()
         self.stop_queues: List[Queue] = []
         self.reorder_files = reorder_files
+        self.weights = weights
 
         # Ensure the input dir is the same across all nodes
         self.input_dir = broadcast_object("input_dir", self.input_dir)

@@ -827,7 +838,14 @@ def run(self, data_recipe: DataRecipe) -> None:
         if not isinstance(user_items, list):
             raise ValueError("The `prepare_structure` should return a list of item metadata.")
 
-        if self.reorder_files and self.input_dir.path:
+        if self.weights is not None:
+            if len(self.weights) != len(user_items):
+                raise ValueError("The provided weights length should match the inputs' length.")
+            workers_user_items = _map_items_to_workers_weighted(
+                num_workers=self.num_workers, user_items=user_items, weights=self.weights, file_size=False
+            )
+
+        elif self.reorder_files and self.input_dir.path:
             # TODO: Only do this on node 0, and broadcast the item sizes to the other nodes.
             item_sizes = _get_item_filesizes(user_items, base_path=self.input_dir.path)
             workers_user_items = _map_items_to_workers_weighted(
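The body of `_map_items_to_workers_weighted` is mostly outside this diff. As a rough sketch of the technique the commit relies on, a greedy "heaviest item to the lightest worker" split produces the kind of even totals the `weights` argument is meant to achieve. The function name `greedy_weighted_split` and the exact strategy below are assumptions for illustration, not the library's implementation:

from typing import Any, List, Optional


def greedy_weighted_split(
    num_workers: int, user_items: List[Any], weights: Optional[List[int]] = None
) -> List[List[Any]]:
    # Default to uniform weights, mirroring `weights = [1] * len(user_items)` above.
    weights = [1] * len(user_items) if weights is None else weights
    worker_items: List[List[Any]] = [[] for _ in range(num_workers)]
    worker_weights = [0] * num_workers
    # Greedy "longest processing time first" heuristic: hand the heaviest
    # remaining item to whichever worker currently has the smallest total.
    for i in sorted(range(len(user_items)), key=lambda i: weights[i], reverse=True):
        lightest = min(range(num_workers), key=lambda w: worker_weights[w])
        worker_items[lightest].append(user_items[i])
        worker_weights[lightest] += weights[i]
    return worker_items


# 5 items, one much heavier than the rest, split across 2 workers:
print(greedy_weighted_split(2, ["a", "b", "c", "d", "e"], weights=[5, 1, 1, 1, 1]))
# -> [['a'], ['b', 'c', 'd', 'e']]  (totals: 5 vs 4)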

src/lightning/data/streaming/functions.py (+2)

@@ -149,6 +149,7 @@ def map(
     fn: Callable[[str, Any], None],
     inputs: Sequence[Any],
     output_dir: Union[str, Dir],
+    weights: Optional[List[int]] = None,
     num_workers: Optional[int] = None,
     fast_dev_run: Union[bool, int] = False,
     num_nodes: Optional[int] = None,

@@ -201,6 +202,7 @@ def map(
             fast_dev_run=fast_dev_run,
             num_downloaders=num_downloaders,
             reorder_files=reorder_files,
+            weights=weights,
         )
         return data_processor.run(LambdaDataTransformRecipe(fn, inputs))
     return _execute(
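With this plumbing, callers can pass per-input weights straight through `map`. A minimal usage sketch, assuming `map` is exposed as `lightning.data.map` (the import path and the `process` function are illustrative, not part of this diff):

# Sketch only: the import path and `process` body are assumptions.
from lightning.data import map


def process(output_dir: str, item: str) -> None:
    # Do per-item work and write results under `output_dir` (illustrative).
    ...


inputs = ["a.txt", "b.txt", "huge.bin"]

map(
    process,
    inputs,
    output_dir="my_output_dir",
    num_workers=2,
    weights=[1, 1, 10],  # one weight per input; lengths must match or run() raises ValueError
)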

tests/tests_data/streaming/test_data_processor.py (+20 -1)

@@ -907,11 +907,30 @@ def test_data_processing_map_without_input_dir(monkeypatch, tmpdir):
     monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", cache_dir)
     monkeypatch.setenv("DATA_OPTIMIZER_DATA_CACHE_FOLDER", cache_dir)
 
-    map(map_fn_index, list(range(5)), output_dir=output_dir, num_workers=1, reorder_files=True)
+    map(
+        map_fn_index,
+        list(range(5)),
+        output_dir=output_dir,
+        num_workers=1,
+        reorder_files=True,
+        weights=[1 for _ in range(5)],
+    )
 
     assert sorted(os.listdir(output_dir)) == ["0.JPEG", "1.JPEG", "2.JPEG", "3.JPEG", "4.JPEG"]
 
 
+@pytest.mark.skipif(condition=not _PIL_AVAILABLE or sys.platform == "win32", reason="Requires: ['pil']")
+def test_data_processing_map_weights_mismatch(monkeypatch, tmpdir):
+    cache_dir = os.path.join(tmpdir, "cache")
+    output_dir = os.path.join(tmpdir, "target_dir")
+    os.makedirs(output_dir, exist_ok=True)
+    monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", cache_dir)
+    monkeypatch.setenv("DATA_OPTIMIZER_DATA_CACHE_FOLDER", cache_dir)
+
+    with pytest.raises(ValueError, match="The provided weights length"):
+        map(map_fn_index, list(range(5)), output_dir=output_dir, num_workers=1, reorder_files=True, weights=[1])
+
+
 @pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
 def test_map_error_when_not_empty(monkeypatch, tmpdir):
     boto3 = mock.MagicMock()
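For reference, the validation the new test exercises boils down to a length comparison; this standalone snippet, with a hypothetical `check_weights` helper, reproduces the error path outside the test suite:

import pytest


def check_weights(weights, user_items):
    # Mirrors the length check this commit adds to DataProcessor.run.
    if len(weights) != len(user_items):
        raise ValueError("The provided weights length should match the inputs' length.")


with pytest.raises(ValueError, match="The provided weights length"):
    check_weights([1], list(range(5)))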
