Add support for parallelizing the processing of parquet files across workers and nodes. #19400
Merged
Changes shown from 13 of 17 commits:

- 7d5c13c update (tchaton)
- 77ed46a update (tchaton)
- c7517cf [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- 90ff250 update
- ff6cf65 update
- 895c6c3 update
- f7ef40e update
- 7c0829c update
- 92eec05 Merge branch 'master' into add_parquet_reader (tchaton)
- 25ad0c9 update
- a8b216e Merge branch 'add_parquet_reader' of https://github.com/Lightning-AI/…
- 9851e3d update
- 9b0eaed update
- 39dae84 update
- ae015d2 update (tchaton)
- 8fc93da [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- 0b3c09c update
Test requirements file:

```diff
@@ -5,3 +5,5 @@ pytest-timeout ==2.1.0
 pytest-rerunfailures ==12.0
 pytest-random-order ==1.1.0
 viztracer
+pyarrow
+polars
```
Empty file.
New file (47 lines added), an httpx-based image downloader helper:

```python
import io
from typing import Optional, Tuple

from lightning_utilities.core.imports import RequirementCache

_HTTPX_AVAILABLE = RequirementCache("httpx")

# Credit to the https://github.com/rom1504/pytorch GitHub repo,
# from which this code was taken.


def _download_image(
    url: str,
    timeout: int = 10,
    user_agent_token: str = "pytorch-lightning",
) -> Tuple[Optional[io.BytesIO], Optional[Exception]]:
    """Download a single image with httpx."""
    img_stream = None
    user_agent_string = "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:72.0) Gecko/20100101 Firefox/72.0"
    if user_agent_token:
        user_agent_string += f" (compatible; {user_agent_token}; +https://github.com/Lightning-AI/pytorch-lightning)"
    import httpx

    try:
        with httpx.Client(http2=True) as client:
            r = client.get(url, headers={"User-Agent": user_agent_string}, timeout=timeout)
            img_stream = io.BytesIO(r.read())
        return img_stream, None
    except Exception as err:  # pylint: disable=broad-except
        if img_stream is not None:
            img_stream.close()
        return None, err


def download_image(
    url: str,
    retries: int = 0,
    timeout: int = 10,
    user_agent_token: str = "pytorch-lightning",
) -> Tuple[Optional[io.BytesIO], Optional[Exception]]:
    """Download an image, retrying up to ``retries`` additional times on failure."""
    if not _HTTPX_AVAILABLE:
        raise ModuleNotFoundError("Run: `pip install httpx`.")
    err: Optional[Exception] = None
    for _ in range(retries + 1):
        img_stream, err = _download_image(url, timeout, user_agent_token)
        if img_stream is not None:
            return img_stream, err
    return None, err
```
New file (120 lines added), the parquet reader:

```python
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, List, Optional

from lightning_utilities.core.imports import RequirementCache

from lightning.data.streaming.shuffle import _associate_chunks_and_internals_to_ranks
from lightning.data.utilities.env import _DistributedEnv

_POLARS_AVAILABLE = RequirementCache("polars")
_PYARROW_AVAILABLE = RequirementCache("pyarrow")


class BaseReader(ABC):
    def get_num_nodes(self) -> int:
        return int(os.getenv("DATA_OPTIMIZER_NUM_NODES", 1))

    @abstractmethod
    def to_workers_user_items(self, items: List[Any], num_workers: int) -> List[List[Any]]:
        """Convert the items provided by the user into items to be processed by the workers."""

    @abstractmethod
    def read(self, item: Any) -> Any:
        """Read the data associated with an item."""


@dataclass
class ParquetSlice:
    """Keep track of a parquet file slice with its filepath, start and end."""

    filepath: str
    start: int
    end: int


class ParquetReader(BaseReader):
    def __init__(self, num_rows: Optional[int] = 2048, to_pandas: bool = True) -> None:
        self.num_rows = num_rows
        self.to_pandas = to_pandas

        if not _PYARROW_AVAILABLE or not _POLARS_AVAILABLE:
            raise ModuleNotFoundError("Run: `pip install pyarrow polars`.")

    def _get_num_rows(self, path: str) -> int:
        if _PYARROW_AVAILABLE:
            import pyarrow.dataset as ds

            df = ds.dataset(path).scanner()
            return df.count_rows()

        # NOTE: A bug in polars can make `read_parquet` hang,
        # so `scan_parquet` is used instead.
        if _POLARS_AVAILABLE:
            import polars as pol

            df = pol.scan_parquet(path)
            num_rows = df.select(pol.len()).collect().item()
            return num_rows

        raise RuntimeError("Install either pyarrow or polars.")

    def read(self, item: ParquetSlice) -> Any:
        if _POLARS_AVAILABLE:
            import polars as pol

            # polars `slice` takes (offset, length), while `item.end` is an
            # absolute row index, hence the subtraction.
            df = pol.scan_parquet(item.filepath).slice(item.start, item.end - item.start).collect()

            if self.to_pandas:
                df = df.to_pandas()

            return df

        if _PYARROW_AVAILABLE:
            import pyarrow.dataset as ds

            scanner = ds.dataset(item.filepath).scanner()

            # Take every row within the slice (the original called
            # `take([item.start, item.end])`, which reads only those two rows).
            df = scanner.take(list(range(item.start, item.end)))

            if self.to_pandas:
                df = df.to_pandas()

            return df

        raise RuntimeError("Install either pyarrow or polars.")

    def to_workers_user_items(self, items: List[Any], num_workers: int) -> List[List[ParquetSlice]]:
        intervals = [(0, self._get_num_rows(item)) for item in items]

        world_size = self.get_num_nodes() * num_workers

        fake_distributed_env = _DistributedEnv(world_size, 0, self.get_num_nodes())
        parquet_indexes_per_worker, parquet_slices_per_worker = _associate_chunks_and_internals_to_ranks(
            fake_distributed_env, list(range(len(items))), intervals, False
        )

        workers_user_items: List[List[ParquetSlice]] = [[] for _ in range(world_size)]

        iterator = enumerate(zip(parquet_indexes_per_worker, parquet_slices_per_worker))

        for worker_idx, (parquet_indexes, parquet_slices) in iterator:
            if self.num_rows:
                # Split each per-worker interval into chunks of at most `num_rows`
                # rows, clamping the final chunk to the interval end.
                workers_user_items[worker_idx].extend([
                    ParquetSlice(
                        items[parquet_index],
                        parquet_slice_start,
                        min(parquet_slice_start + self.num_rows, parquet_slice[1]),
                    )
                    for parquet_index, parquet_slice in zip(parquet_indexes, parquet_slices)
                    for parquet_slice_start in range(parquet_slice[0], parquet_slice[1] + self.num_rows, self.num_rows)
                    if parquet_slice_start < parquet_slice[1]
                ])
            else:
                workers_user_items[worker_idx].extend([
                    ParquetSlice(items[parquet_index], *parquet_slice)
                    for parquet_index, parquet_slice in zip(parquet_indexes, parquet_slices)
                ])

        return workers_user_items
```
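The core of `ParquetReader.to_workers_user_items` is stepping through a `(start, end)` row interval in strides of `num_rows` and clamping the final slice to the interval end. That arithmetic in isolation (the `split_interval` helper is illustrative, not part of the PR):

```python
from typing import List, Tuple


def split_interval(start: int, end: int, num_rows: int) -> List[Tuple[int, int]]:
    """Split [start, end) into consecutive slices of at most `num_rows` rows,
    using the same comprehension shape as `to_workers_user_items`."""
    return [
        (s, min(s + num_rows, end))
        for s in range(start, end + num_rows, num_rows)
        if s < end
    ]


# A 5000-row interval with the default num_rows=2048 yields three slices,
# the last one shorter than the others.
slices = split_interval(0, 5000, 2048)
# → [(0, 2048), (2048, 4096), (4096, 5000)]
```

Bounding each slice at `num_rows` keeps per-worker reads small enough to overlap I/O with processing, while the `min(...)` clamp guarantees no slice runs past the file's row count.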
Empty file.