diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py
index 44523f435..266bbdca6 100644
--- a/dask_expr/io/io.py
+++ b/dask_expr/io/io.py
@@ -93,7 +93,6 @@ def _combine_similar(self, root: Expr):
         if columns_operand is None:
             columns_operand = self.columns
         columns = set(columns_operand)
-        ops = [self] + alike
         for op in alike:
             op_columns = op.operand("columns")
             if op_columns is None:
@@ -109,14 +108,17 @@ def _combine_similar(self, root: Expr):
                 return
 
         # Check if we have the operation we want elsewhere in the graph
-        for op in ops:
-            if op.columns == columns and not op.operand("_series"):
+        for op in alike:
+            if set(op.columns) == set(columns) and not op.operand("_series"):
                 return (
                     op[columns_operand[0]]
                     if self._series
                     else op[columns_operand]
                 )
 
+        if set(self.columns) == set(columns):
+            return  # Skip unnecessary projection change
+
         # Create the "combined" ReadParquet operation
         subs = {"columns": columns}
         if self._series:
diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py
index 0ffb0ecf2..24c289268 100644
--- a/dask_expr/io/parquet.py
+++ b/dask_expr/io/parquet.py
@@ -542,25 +542,44 @@ def _dataset_info(self):
         meta = self.engine._create_dd_meta(dataset_info)
         index = dataset_info["index"]
         index = [index] if isinstance(index, str) else index
-        meta, index, columns = set_index_columns(
-            meta, index, self.operand("columns"), auto_index_allowed
+        meta, index, all_columns = set_index_columns(
+            meta, index, None, auto_index_allowed
         )
         if meta.index.name == NONE_LABEL:
             meta.index.name = None
-        dataset_info["meta"] = meta
+        dataset_info["base_meta"] = meta
         dataset_info["index"] = index
-        dataset_info["columns"] = columns
+        dataset_info["all_columns"] = all_columns
 
         return dataset_info
 
     @property
     def _meta(self):
-        meta = self._dataset_info["meta"]
+        meta = self._dataset_info["base_meta"]
+        columns = _convert_to_list(self.operand("columns"))
         if self._series:
-            column = _convert_to_list(self.operand("columns"))[0]
-            return meta[column]
+            assert len(columns) > 0
+            return meta[columns[0]]
+        elif columns is not None:
+            return meta[columns]
         return meta
 
+    @cached_property
+    def _io_func(self):
+        if self._plan["empty"]:
+            return lambda x: x
+        dataset_info = self._dataset_info
+        return ParquetFunctionWrapper(
+            self.engine,
+            dataset_info["fs"],
+            dataset_info["base_meta"],
+            self.columns,
+            dataset_info["index"],
+            dataset_info["kwargs"]["dtype_backend"],
+            {},  # All kwargs should now be in `common_kwargs`
+            self._plan["common_kwargs"],
+        )
+
     @cached_property
     def _plan(self):
         dataset_info = self._dataset_info
@@ -579,31 +598,20 @@ def _plan(self):
             # Use statistics to calculate divisions
             divisions = _calculate_divisions(stats, dataset_info, len(parts))
 
-        meta = dataset_info["meta"]
+        empty = False
         if len(divisions) < 2:
             # empty dataframe - just use meta
             divisions = (None, None)
-            io_func = lambda x: x
-            parts = [meta]
-        else:
-            # Use IO function wrapper
-            io_func = ParquetFunctionWrapper(
-                self.engine,
-                dataset_info["fs"],
-                meta,
-                dataset_info["columns"],
-                dataset_info["index"],
-                dataset_info["kwargs"]["dtype_backend"],
-                {},  # All kwargs should now be in `common_kwargs`
-                common_kwargs,
-            )
+            parts = [self._meta]
+            empty = True
 
         _control_cached_plan(dataset_token)
         _cached_plan[dataset_token] = {
-            "func": io_func,
+            "empty": empty,
             "parts": parts,
             "statistics": stats,
             "divisions": divisions,
+            "common_kwargs": common_kwargs,
         }
         return _cached_plan[dataset_token]
 
@@ -611,7 +619,7 @@ def _divisions(self):
         return self._plan["divisions"]
 
     def _filtered_task(self, index: int):
-        tsk = (self._plan["func"], self._plan["parts"][index])
+        tsk = (self._io_func, self._plan["parts"][index])
         if self._series:
             return (operator.getitem, tsk, self.columns[0])
         return tsk
@@ -863,7 +871,7 @@ def _collect_pq_statistics(
         raise ValueError(f"columns={columns} must be a subset of {allowed}")
 
     # Collect statistics using layer information
-    fs = expr._plan["func"].fs
+    fs = expr._io_func.fs
     parts = [
         part
         for i, part in enumerate(expr._plan["parts"])