feat: Wire override_num_blocks through full call chain for Ray Data read operations #984
fengrui-z wants to merge 7 commits into
…on check to support local model paths
- config.py: add --override_num_blocks CLI argument
- ray_executor_partitioned.py: pass override_num_blocks to load_dataset()
- load_strategy.py: extract and forward override_num_blocks in RayLocalJsonDataLoadStrategy
- ray_dataset.py: add kwargs to read(), add override_num_blocks to read_json()
- Enables users to control Ray Data block parallelism for large datasets (5PB+)
Code Review
This pull request introduces a new configuration option, --override_num_blocks, to control the number of output blocks during Ray Data read operations, enhancing parallelism control. The changes involve propagating this parameter through the RayDataset loading methods and the partitioned executor. Feedback highlights several potential TypeError risks because standard Ray read functions do not natively support the override_num_blocks parameter; reviewers suggest using the .repartition() method as a workaround. Additionally, the review identifies missing **kwargs in method signatures and incorrect return type hints that need to be addressed for consistency and correctness.
     @classmethod
-    def read_json(cls, paths: Union[str, List[str]]) -> RayDataset:
+    def read_json(cls, paths: Union[str, List[str]], override_num_blocks: Optional[int] = None) -> RayDataset:
The read_json method is missing **kwargs in its signature. Since RayDataset.read (line 350) now passes **kwargs to read_json, any additional arguments (such as read_options from the config) will cause a TypeError. Also, the return type hint should be updated to ray.data.Dataset.
-    def read_json(cls, paths: Union[str, List[str]], override_num_blocks: Optional[int] = None) -> RayDataset:
+    def read_json(cls, paths: Union[str, List[str]], override_num_blocks: Optional[int] = None, **kwargs) -> ray.data.Dataset:
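To illustrate the failure mode this comment describes, here is a minimal sketch that does not depend on Ray. The functions `read_json_strict`, `read_json_tolerant`, `read`, and `forwarding_fails` are hypothetical stand-ins, not the project's actual methods; they only show that forwarding `**kwargs` into a signature that lacks `**kwargs` raises a TypeError as soon as an extra keyword such as `read_options` is present.

```python
from typing import Any, Callable, Dict, List, Optional, Union

Paths = Union[str, List[str]]


def read_json_strict(paths: Paths, override_num_blocks: Optional[int] = None) -> Dict[str, Any]:
    """Stand-in reader WITHOUT **kwargs: any extra keyword is rejected."""
    return {"paths": paths, "override_num_blocks": override_num_blocks}


def read_json_tolerant(
    paths: Paths, override_num_blocks: Optional[int] = None, **kwargs: Any
) -> Dict[str, Any]:
    """Stand-in reader WITH **kwargs: extra keywords are accepted and kept."""
    return {"paths": paths, "override_num_blocks": override_num_blocks, "extra": kwargs}


def read(reader: Callable[..., Dict[str, Any]], paths: Paths, **kwargs: Any) -> Dict[str, Any]:
    """Stand-in dispatcher that forwards every keyword argument to the reader."""
    return reader(paths, **kwargs)


def forwarding_fails(reader: Callable[..., Dict[str, Any]], **kwargs: Any) -> bool:
    """Return True if forwarding kwargs through read() raises TypeError."""
    try:
        read(reader, "data.jsonl", **kwargs)
    except TypeError:
        return True
    return False
```

With only `override_num_blocks` passed, both stand-ins succeed; the strict one fails only when an additional keyword is forwarded, which is why the problem can stay hidden until a config carries extra read options.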
    def read_webdataset(cls, paths: Union[str, List[str]], **kwargs) -> RayDataset:
        return ray.data.read_webdataset(paths, decoder=partial(_custom_default_decoder, format="PIL"), **kwargs)
ray.data.read_webdataset does not support override_num_blocks. If this argument is present in kwargs, the call will fail. Additionally, the return type hint should be corrected.
-    def read_webdataset(cls, paths: Union[str, List[str]], **kwargs) -> RayDataset:
-        return ray.data.read_webdataset(paths, decoder=partial(_custom_default_decoder, format="PIL"), **kwargs)
+    @classmethod
+    def read_webdataset(cls, paths: Union[str, List[str]], **kwargs) -> ray.data.Dataset:
+        override_num_blocks = kwargs.pop("override_num_blocks", None)
+        dataset = ray.data.read_webdataset(paths, decoder=partial(_custom_default_decoder, format="PIL"), **kwargs)
+        if override_num_blocks:
+            dataset = dataset.repartition(override_num_blocks)
+        return dataset
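The pop-then-repartition pattern in this suggestion can be exercised without Ray. The sketch below uses hypothetical stand-ins (`FakeDataset`, `fake_reader`, `read_with_override`) rather than real Ray APIs. It also differs from the suggestion in one respect that is an assumption on my part, not something the review asks for: it checks `is not None` and rejects non-positive values, whereas `if override_num_blocks:` silently treats `0` as "no override".

```python
from typing import Any, List, Optional


class FakeDataset:
    """Hypothetical stand-in for a dataset object; it only tracks a block count."""

    def __init__(self, rows: List[Any], num_blocks: int = 1) -> None:
        self.rows = rows
        self.num_blocks = num_blocks

    def repartition(self, num_blocks: int) -> "FakeDataset":
        # Return a new object instead of mutating in place.
        return FakeDataset(self.rows, num_blocks)


def fake_reader(paths: str, **kwargs: Any) -> FakeDataset:
    """Hypothetical reader that does not accept override_num_blocks."""
    if "override_num_blocks" in kwargs:
        raise TypeError("unexpected keyword argument 'override_num_blocks'")
    return FakeDataset(rows=[paths], num_blocks=4)


def read_with_override(paths: str, **kwargs: Any) -> FakeDataset:
    # Remove the override so the underlying reader never sees it.
    override_num_blocks: Optional[int] = kwargs.pop("override_num_blocks", None)
    dataset = fake_reader(paths, **kwargs)
    # Distinguish "not set" (None) from an invalid explicit value such as 0.
    if override_num_blocks is not None:
        if override_num_blocks <= 0:
            raise ValueError("override_num_blocks must be a positive integer")
        dataset = dataset.repartition(override_num_blocks)
    return dataset
```

Note that in this pattern the block count is adjusted after the read has been set up, so it changes the layout of the result rather than the way the read itself is parallelized.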
     @classmethod
-    def read(cls, data_format: str, paths: Union[str, List[str]]) -> RayDataset:
+    def read(cls, data_format: str, paths: Union[str, List[str]], **kwargs) -> RayDataset:
The return type hint RayDataset is incorrect. This method returns a ray.data.Dataset object, which is then wrapped by the caller (e.g., in load_strategy.py). This inconsistency exists in several methods in this class.
-    def read(cls, data_format: str, paths: Union[str, List[str]], **kwargs) -> RayDataset:
+    def read(cls, data_format: str, paths: Union[str, List[str]], **kwargs) -> ray.data.Dataset:
…ay read calls
- read(): pop override_num_blocks from kwargs, use repartition() for non-JSON formats
- read_json(): add **kwargs, use repartition() in fallback path instead of passing to ray.data.read_json
- read_webdataset(): pop override_num_blocks, apply repartition() after read
- Fix return type hints to ray.data.Dataset
From the current config surface, this looks like the first top-level config that maps specifically to a … For example:
ray_read_options:
  override_num_blocks: 1000
  concurrency: 64
  ray_remote_args:
    num_cpus: 0.25
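As a rough illustration of how a grouped block like the one above might be consumed, the sketch below uses a plain dictionary in place of the parsed config. The key `ray_read_options` comes from the example; the helper `split_read_options` is hypothetical and does not exist in the project.

```python
from typing import Any, Dict, Optional, Tuple


def split_read_options(cfg: Dict[str, Any]) -> Tuple[Optional[int], Dict[str, Any]]:
    """Separate override_num_blocks from the rest of a grouped read-options mapping."""
    # Copy the mapping so the caller's config is not mutated by pop().
    options = dict(cfg.get("ray_read_options") or {})
    override_num_blocks = options.pop("override_num_blocks", None)
    return override_num_blocks, options


# Plain-dict equivalent of the YAML example.
cfg = {
    "ray_read_options": {
        "override_num_blocks": 1000,
        "concurrency": 64,
        "ray_remote_args": {"num_cpus": 0.25},
    }
}
num_blocks, remaining = split_read_options(cfg)
```

A missing or empty `ray_read_options` key yields `(None, {})`, which corresponds to the "no override" default.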
Good suggestion — I considered this but decided to keep a top-level argument for the following reasons:
Happy to revisit if we find ourselves adding a second or third Ray read param in the near future.
Summary
Enable users to control Ray Data's block parallelism via a new --override_num_blocks CLI argument. This parameter was already implemented at the lowest layer (read_json_stream()) but was never wired through the upstream call chain, making it inaccessible without monkey-patching.
Motivation
When processing very large datasets (billions of records / PB-scale), Ray Data's default block size (128MB) creates an excessive number of blocks (~40M blocks for 5PB), leading to:
Previously, the only workaround was to monkey-patch RayLocalJsonDataLoadStrategy.load_data() (as seen in benchmark scripts). This PR provides a clean, user-facing configuration path.
Changes
| File | Change |
| --- | --- |
| data_juicer/config/config.py | Add --override_num_blocks CLI argument (Optional[int], default None) |
| data_juicer/core/executor/ray_executor_partitioned.py | Read cfg.override_num_blocks and pass to load_dataset() |
| data_juicer/core/data/load_strategy.py | RayLocalJsonDataLoadStrategy.load_data() extracts and forwards override_num_blocks |
| data_juicer/core/data/ray_dataset.py | read() accepts **kwargs; read_json() accepts override_num_blocks and passes to read_json_stream() |

Usage
Backward Compatibility
- Default None (no override): existing behavior is unchanged
- **kwargs pass-through or Optional parameters with None defaults
Testing
- python -c "import data_juicer" passes
- cfg → PartitionedRayExecutor._run_impl() → DatasetBuilder.load_dataset() → RayLocalJsonDataLoadStrategy.load_data() → RayDataset.read() → RayDataset.read_json() → read_json_stream() → ray.data.read_datasource(override_num_blocks=N)
Related