@@ -241,7 +241,10 @@ def _map_items_to_workers_sequentially(num_workers: int, user_items: List[Any])


 def _map_items_to_workers_weighted(
-    num_workers: int, user_items: List[Any], weights: Optional[List[int]] = None
+    num_workers: int,
+    user_items: List[Any],
+    weights: Optional[List[int]] = None,
+    file_size: bool = True,
 ) -> List[List[Any]]:
     # Associate the items to the workers based on number of nodes and node rank.
     weights = [1] * len(user_items) if weights is None else weights
@@ -255,7 +258,11 @@ def _map_items_to_workers_weighted(
     for worker_id, size in worker_weights.items():
         if worker_id not in worker_ids_this_node:
             continue
-        print(f"Worker {worker_id} gets {size / 1e6:.1f} MB ({len(worker_items[worker_id])} files)")
+
+        if file_size:
+            print(f"Worker {worker_id} gets {size / 1e6:.1f} MB ({len(worker_items[worker_id])} files)")
+        else:
+            print(f"Worker {worker_id} gets {len(worker_items[worker_id])} items for a total weight of {size}.")

     return [worker_items[worker_id] for worker_id in worker_ids_this_node]

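Note: a minimal sketch of what the new `file_size` flag changes, using hypothetical items and weights; the exact item-to-worker assignment below is illustrative, since the splitting logic itself is outside this hunk.

    # Hypothetical call: four items with user-provided weights, two workers.
    items = ["a", "b", "c", "d"]
    weights = [10, 1, 1, 8]
    worker_items = _map_items_to_workers_weighted(
        num_workers=2, user_items=items, weights=weights, file_size=False
    )
    # With file_size=False, the per-worker summary reports item counts and
    # total weight, e.g. "Worker 0 gets 2 items for a total weight of 11."
    # With file_size=True (the default), it reports megabytes, as before.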
@@ -769,6 +776,7 @@ def __init__(
         fast_dev_run: Optional[Union[bool, int]] = None,
         random_seed: Optional[int] = 42,
         reorder_files: bool = True,
+        weights: Optional[List[int]] = None,
     ):
         """The `DatasetOptimiser` provides an efficient way to process data across multiple machines into chunks to make
         training faster.
@@ -784,6 +792,8 @@ def __init__(
             random_seed: The random seed to be set before shuffling the data.
             reorder_files: By default, reorders the files by file size to distribute work equally among all workers.
                 Set this to ``False`` if the order in which samples are processed should be preserved.
+            weights: Provide a list of weights associated with the inputs.
+                This is used to evenly split the work among the workers.

         """
         self.input_dir = _resolve_dir(input_dir)
@@ -799,6 +809,7 @@ def __init__(
         self.error_queue: Queue = Queue()
         self.stop_queues: List[Queue] = []
         self.reorder_files = reorder_files
+        self.weights = weights

         # Ensure the input dir is the same across all nodes
         self.input_dir = broadcast_object("input_dir", self.input_dir)
@@ -827,7 +838,14 @@ def run(self, data_recipe: DataRecipe) -> None:
         if not isinstance(user_items, list):
             raise ValueError("The `prepare_structure` should return a list of item metadata.")

-        if self.reorder_files and self.input_dir.path:
+        if self.weights is not None:
+            if len(self.weights) != len(user_items):
+                raise ValueError("The provided weights length should match the inputs' length.")
+            workers_user_items = _map_items_to_workers_weighted(
+                num_workers=self.num_workers, user_items=user_items, weights=self.weights, file_size=False
+            )
+
+        elif self.reorder_files and self.input_dir.path:
             # TODO: Only do this on node 0, and broadcast the item sizes to the other nodes.
             item_sizes = _get_item_filesizes(user_items, base_path=self.input_dir.path)
             workers_user_items = _map_items_to_workers_weighted(
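Note: a minimal usage sketch of the new code path; the recipe class and the `input_dir` value are assumptions for illustration, not part of this diff. Passing `weights` bypasses the file-size lookup and reordering entirely and splits the items by the provided weights instead.

    # Hypothetical usage; MyRecipe and the input_dir value are assumed.
    optimiser = DatasetOptimiser(input_dir="data/", weights=[5, 1, 1, 3])
    optimiser.run(MyRecipe())
    # run() raises the ValueError added above if len(weights) does not match
    # the number of items returned by prepare_structure().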