Separate io_func from ReadParquet._plan for better caching #367

Merged (5 commits) on Oct 27, 2023
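Context for the diffs below: previously the globally cached `_plan` (stored in `_cached_plan`, keyed by a dataset token) embedded a `ParquetFunctionWrapper` built for one specific column projection, so expressions that differed only in their projection could not share a cached plan. This PR moves the wrapper into a per-expression `cached_property`, `_io_func`, and keeps only projection-independent data (`parts`, `divisions`, `statistics`, `common_kwargs`, `empty`) in the shared cache. A minimal sketch of the pattern, with hypothetical names throughout (only `_cached_plan` mirrors the real code):

```python
from functools import cached_property

# Shared plan cache keyed by a dataset token, as in the diff's _cached_plan.
_cached_plan: dict = {}


def _load(part, columns):
    # Hypothetical stand-in for ParquetFunctionWrapper.
    return {"part": part, "columns": columns}


class ReadDataset:
    def __init__(self, token, columns):
        self.token = token
        self.columns = columns

    @cached_property
    def _plan(self):
        # Expensive, projection-independent work: computed once per dataset
        # token and shared across expressions with different projections.
        if self.token not in _cached_plan:
            _cached_plan[self.token] = {
                "parts": ["part-0", "part-1"],
                "divisions": (None, None, None),
            }
        return _cached_plan[self.token]

    @cached_property
    def _io_func(self):
        # Cheap, projection-dependent wrapper built per expression, so the
        # column selection never leaks into the shared plan cache.
        columns = self.columns
        return lambda part: _load(part, columns)


a = ReadDataset("tok", ["x"])
b = ReadDataset("tok", ["x", "y"])
assert a._plan is b._plan                        # one plan per dataset token
assert a._io_func("part-0")["columns"] == ["x"]  # projection stays per-expr
```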
dask_expr/io/io.py (8 changes: 5 additions, 3 deletions)
```diff
@@ -93,7 +93,6 @@ def _combine_similar(self, root: Expr):
         if columns_operand is None:
             columns_operand = self.columns
         columns = set(columns_operand)
-        ops = [self] + alike
         for op in alike:
             op_columns = op.operand("columns")
             if op_columns is None:
@@ -109,14 +108,17 @@ def _combine_similar(self, root: Expr):
             return
 
         # Check if we have the operation we want elsewhere in the graph
-        for op in ops:
-            if op.columns == columns and not op.operand("_series"):
+        for op in alike:
+            if set(op.columns) == set(columns) and not op.operand("_series"):
                 return (
                     op[columns_operand[0]]
                     if self._series
                     else op[columns_operand]
                 )
 
+        if set(self.columns) == set(columns):
+            return  # Skip unnecessary projection change
+
         # Create the "combined" ReadParquet operation
         subs = {"columns": columns}
         if self._series:
```
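The reuse check now iterates `alike` (the similar operations found elsewhere in the graph) and compares column sets rather than raw operands, so two reads projecting the same columns in a different order are recognized as equivalent; the new early `return` also avoids rewriting an expression into an identical projection. A small illustration of the order-insensitivity, with made-up column names:

```python
columns_a = ["x", "y"]
columns_b = ["y", "x"]
assert columns_a != columns_b            # list comparison is order-sensitive
assert set(columns_a) == set(columns_b)  # the comparison the new check uses
```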
dask_expr/io/parquet.py (58 changes: 33 additions, 25 deletions)
```diff
@@ -542,25 +542,44 @@ def _dataset_info(self):
         meta = self.engine._create_dd_meta(dataset_info)
         index = dataset_info["index"]
         index = [index] if isinstance(index, str) else index
-        meta, index, columns = set_index_columns(
-            meta, index, self.operand("columns"), auto_index_allowed
+        meta, index, all_columns = set_index_columns(
+            meta, index, None, auto_index_allowed
         )
         if meta.index.name == NONE_LABEL:
             meta.index.name = None
-        dataset_info["meta"] = meta
+        dataset_info["base_meta"] = meta
         dataset_info["index"] = index
-        dataset_info["columns"] = columns
+        dataset_info["all_columns"] = all_columns
 
         return dataset_info
 
     @property
     def _meta(self):
-        meta = self._dataset_info["meta"]
+        meta = self._dataset_info["base_meta"]
+        columns = _convert_to_list(self.operand("columns"))
         if self._series:
-            column = _convert_to_list(self.operand("columns"))[0]
-            return meta[column]
+            assert len(columns) > 0
+            return meta[columns[0]]
+        elif columns is not None:
+            return meta[columns]
         return meta
 
+    @cached_property
+    def _io_func(self):
+        if self._plan["empty"]:
+            return lambda x: x
+        dataset_info = self._dataset_info
+        return ParquetFunctionWrapper(
+            self.engine,
+            dataset_info["fs"],
+            dataset_info["base_meta"],
+            self.columns,
+            dataset_info["index"],
+            dataset_info["kwargs"]["dtype_backend"],
+            {},  # All kwargs should now be in `common_kwargs`
+            self._plan["common_kwargs"],
+        )
+
     @cached_property
     def _plan(self):
         dataset_info = self._dataset_info
```
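With this hunk, `_dataset_info` stores the unprojected metadata under `base_meta` (note the `None` now passed as the requested columns), and `_meta` applies the expression's own projection on access. Roughly how the three branches resolve against a pandas meta frame, with made-up columns and dtypes:

```python
import pandas as pd

base_meta = pd.DataFrame(
    {"x": pd.Series(dtype="int64"), "y": pd.Series(dtype="float64")}
)

# _series: a single column is selected, yielding a Series meta.
assert base_meta["x"].dtype == "int64"
# columns is not None: a list selection yields a projected DataFrame meta.
assert list(base_meta[["y", "x"]].columns) == ["y", "x"]
# columns is None: the full base_meta is returned unchanged.
```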
```diff
@@ -579,39 +598,28 @@ def _plan(self):
         # Use statistics to calculate divisions
         divisions = _calculate_divisions(stats, dataset_info, len(parts))
 
-        meta = dataset_info["meta"]
+        empty = False
         if len(divisions) < 2:
             # empty dataframe - just use meta
             divisions = (None, None)
-            io_func = lambda x: x
-            parts = [meta]
-        else:
-            # Use IO function wrapper
-            io_func = ParquetFunctionWrapper(
-                self.engine,
-                dataset_info["fs"],
-                meta,
-                dataset_info["columns"],
-                dataset_info["index"],
-                dataset_info["kwargs"]["dtype_backend"],
-                {},  # All kwargs should now be in `common_kwargs`
-                common_kwargs,
-            )
+            parts = [self._meta]
+            empty = True
 
         _control_cached_plan(dataset_token)
         _cached_plan[dataset_token] = {
-            "func": io_func,
+            "empty": empty,
             "parts": parts,
             "statistics": stats,
             "divisions": divisions,
+            "common_kwargs": common_kwargs,
         }
         return _cached_plan[dataset_token]
 
     def _divisions(self):
         return self._plan["divisions"]
 
     def _filtered_task(self, index: int):
-        tsk = (self._plan["func"], self._plan["parts"][index])
+        tsk = (self._io_func, self._plan["parts"][index])
         if self._series:
             return (operator.getitem, tsk, self.columns[0])
         return tsk
```
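`_filtered_task` builds a dask task tuple whose callable is now the per-expression `_io_func`; for a series read it wraps that tuple in an `operator.getitem` task. A toy graph showing how such task tuples evaluate, using a hypothetical stand-in loader (the real code nests the tuple directly rather than referencing it by key):

```python
import operator
from dask.core import get

def io_func(part):
    # Stand-in for the real IO wrapper; returns a loaded "partition".
    return {"x": [1, 2], "y": [3, 4]}

dsk = {
    "frame-0": (io_func, "part-0"),
    "series-0": (operator.getitem, "frame-0", "x"),
}
assert get(dsk, "series-0") == [1, 2]
```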
```diff
@@ -863,7 +871,7 @@ def _collect_pq_statistics(
         raise ValueError(f"columns={columns} must be a subset of {allowed}")
 
     # Collect statistics using layer information
-    fs = expr._plan["func"].fs
+    fs = expr._io_func.fs
     parts = [
         part
         for i, part in enumerate(expr._plan["parts"])
```