diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py index 9d06a878458..a1dc754a353 100644 --- a/data_juicer/config/config.py +++ b/data_juicer/config/config.py @@ -208,6 +208,14 @@ def build_base_parser() -> ArgumentParser: "(e.g., block_size for JSON reading). This configuration is " "especially useful when reading large JSON files.", ) + parser.add_argument( + "--override_num_blocks", + type=Optional[int], + default=None, + help="Override the number of output blocks for Ray Data read " + "operations. Useful for controlling parallelism when reading " + "large datasets.", + ) parser.add_argument( "--work_dir", type=str, diff --git a/data_juicer/core/data/load_strategy.py b/data_juicer/core/data/load_strategy.py index d07f7184d23..fb755e3238c 100644 --- a/data_juicer/core/data/load_strategy.py +++ b/data_juicer/core/data/load_strategy.py @@ -201,6 +201,8 @@ class RayLocalJsonDataLoadStrategy(RayDataLoadStrategy): def load_data(self, **kwargs): from data_juicer.core.data.ray_dataset import RayDataset + override_num_blocks = kwargs.pop("override_num_blocks", None) + path = self.ds_config["path"] # Convert to absolute path if relative @@ -281,7 +283,7 @@ def load_data(self, **kwargs): else: logger.info(f"Loading {data_format} data.") try: - dataset = RayDataset.read(data_format, path) + dataset = RayDataset.read(data_format, path, override_num_blocks=override_num_blocks) return RayDataset(dataset, dataset_path=path, cfg=self.cfg) except Exception as e: if auto_detect: diff --git a/data_juicer/core/data/ray_dataset.py b/data_juicer/core/data/ray_dataset.py index 538a90bbb95..3ff71e16773 100644 --- a/data_juicer/core/data/ray_dataset.py +++ b/data_juicer/core/data/ray_dataset.py @@ -347,11 +347,14 @@ def count(self) -> int: return self.data.count() @classmethod - def read(cls, data_format: str, paths: Union[str, List[str]]) -> RayDataset: + def read(cls, data_format: str, paths: Union[str, List[str]], **kwargs) -> RayDataset: + # Pop override_num_blocks since native Ray read functions don't support it + override_num_blocks = kwargs.pop("override_num_blocks", None) + if data_format in {"json", "jsonl", "json.gz", "jsonl.gz", "json.zst", "jsonl.zst"}: - return RayDataset.read_json(paths) + return RayDataset.read_json(paths, override_num_blocks=override_num_blocks, **kwargs) elif data_format == "webdataset": - return RayDataset.read_webdataset(paths) + return RayDataset.read_webdataset(paths, override_num_blocks=override_num_blocks, **kwargs) elif data_format in { "parquet", "images", @@ -369,23 +372,35 @@ def read(cls, data_format: str, paths: Union[str, List[str]]) -> RayDataset: from data_juicer.utils.lazy_loader import LazyLoader LazyLoader.check_packages(["pylance"]) - return getattr(ray.data, f"read_{data_format}")(paths) + dataset = getattr(ray.data, f"read_{data_format}")(paths, **kwargs) + if override_num_blocks: + dataset = dataset.repartition(override_num_blocks) + return dataset @classmethod - def read_json(cls, paths: Union[str, List[str]]) -> RayDataset: + def read_json( + cls, paths: Union[str, List[str]], override_num_blocks: Optional[int] = None, **kwargs + ) -> ray.data.Dataset: # Note: a temp solution for reading json stream # TODO: replace with ray.data.read_json_stream once it is available import pyarrow.json as js try: js.open_json - return read_json_stream(paths) + return read_json_stream(paths, override_num_blocks=override_num_blocks, **kwargs) except AttributeError: - return ray.data.read_json(paths) + dataset = ray.data.read_json(paths, **kwargs) + if override_num_blocks: + dataset = dataset.repartition(override_num_blocks) + return dataset @classmethod - def read_webdataset(cls, paths: Union[str, List[str]]) -> RayDataset: - return ray.data.read_webdataset(paths, decoder=partial(_custom_default_decoder, format="PIL")) + def read_webdataset(cls, paths: Union[str, List[str]], **kwargs) -> ray.data.Dataset: + override_num_blocks = kwargs.pop("override_num_blocks", None) + dataset = ray.data.read_webdataset(paths, decoder=partial(_custom_default_decoder, format="PIL"), **kwargs) + if override_num_blocks: + dataset = dataset.repartition(override_num_blocks) + return dataset def to_list(self) -> list: return self.data.to_pandas().to_dict(orient="records") diff --git a/data_juicer/core/executor/ray_executor_partitioned.py b/data_juicer/core/executor/ray_executor_partitioned.py index c9bb55b50d5..6b4a0ceafdd 100644 --- a/data_juicer/core/executor/ray_executor_partitioned.py +++ b/data_juicer/core/executor/ray_executor_partitioned.py @@ -414,7 +414,8 @@ def _run_impl(self, load_data_np: Optional[PositiveInt] = None, skip_return=Fals # Load the full dataset using a single DatasetBuilder logger.info("Loading dataset with single DatasetBuilder...") - dataset = self.datasetbuilder.load_dataset(num_proc=load_data_np) + override_num_blocks = getattr(self.cfg, "override_num_blocks", None) + dataset = self.datasetbuilder.load_dataset(num_proc=load_data_np, override_num_blocks=override_num_blocks) columns = dataset.schema().columns # Prepare operations diff --git a/demos/agent/scripts/run_bad_case_pipeline.sh b/demos/agent/scripts/run_bad_case_pipeline.sh old mode 100755 new mode 100644 diff --git a/demos/agent/scripts/verify_bad_case_export.py b/demos/agent/scripts/verify_bad_case_export.py old mode 100755 new mode 100644 diff --git a/demos/partition_and_checkpoint/robustness_benchmark.py b/demos/partition_and_checkpoint/robustness_benchmark.py old mode 100755 new mode 100644 diff --git a/demos/partition_and_checkpoint/run_demo.py b/demos/partition_and_checkpoint/run_demo.py old mode 100755 new mode 100644 diff --git a/thirdparty/LLM_ecosystems/setup_helm.sh b/thirdparty/LLM_ecosystems/setup_helm.sh old mode 100755 new mode 100644 diff --git a/thirdparty/LLM_ecosystems/setup_megatron.sh b/thirdparty/LLM_ecosystems/setup_megatron.sh old mode 100755 new mode 100644 diff --git a/thirdparty/models/setup_easyanimate.sh b/thirdparty/models/setup_easyanimate.sh old mode 100755 new mode 100644 diff --git a/tools/check_s3_integration.py b/tools/check_s3_integration.py old mode 100755 new mode 100644