diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs
index 06d08d2e9f96..1dc4814c01ae 100644
--- a/datafusion/core/tests/sql/subqueries.rs
+++ b/datafusion/core/tests/sql/subqueries.rs
@@ -179,3 +179,37 @@ async fn in_subquery_with_same_table() -> Result<()> {
 
     Ok(())
 }
+
+#[tokio::test]
+async fn invalid_scalar_subquery() -> Result<()> {
+    let ctx = create_join_context("t1_id", "t2_id", true)?;
+
+    let sql = "SELECT t1_id, t1_name, t1_int, (select t2_id, t2_name FROM t2 WHERE t2.t2_id = t1.t1_int) FROM t1";
+    let msg = format!("Creating logical plan for '{sql}'");
+    let dataframe = ctx.sql(sql).await.expect(&msg);
+    let err = dataframe.into_optimized_plan().err().unwrap();
+    assert_eq!(
+        "Plan(\"Scalar subquery should only return one column\")",
+        &format!("{err:?}")
+    );
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn subquery_not_allowed() -> Result<()> {
+    let ctx = create_join_context("t1_id", "t2_id", true)?;
+
+    // An IN/EXISTS subquery in the ORDER BY clause must be rejected when the plan is optimized.
+    let sql = "SELECT t1_id, t1_name, t1_int FROM t1 order by t1_int in (SELECT t2_int FROM t2 WHERE t1.t1_id > t1.t1_int)";
+    let msg = format!("Creating logical plan for '{sql}'");
+    let dataframe = ctx.sql(sql).await.expect(&msg);
+    let err = dataframe.into_optimized_plan().err().unwrap();
+
+    assert_eq!(
+        "Plan(\"In/Exist subquery can not be used in Sort plan nodes\")",
+        &format!("{err:?}")
+    );
+
+    Ok(())
+}
diff --git a/datafusion/optimizer/src/analyzer.rs b/datafusion/optimizer/src/analyzer.rs
new file mode 100644
index 000000000000..f2a1ba9d64bb
--- /dev/null
+++ b/datafusion/optimizer/src/analyzer.rs
@@ -0,0 +1,202 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::rewrite::TreeNodeRewritable;
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::{DataFusionError, Result};
+use datafusion_expr::expr_visitor::inspect_expr_pre;
+use datafusion_expr::{Expr, LogicalPlan};
+use log::{debug, trace};
+use std::sync::Arc;
+use std::time::Instant;
+
+/// `AnalyzerRule` transforms the unresolved [`LogicalPlan`]s and unresolved [`Expr`]s into
+/// the resolved form.
+pub trait AnalyzerRule {
+    /// Rewrite `plan` and return the rewritten plan, or an error if the rule fails
+    fn analyze(&self, plan: &LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan>;
+
+    /// A human readable name for this analyzer rule
+    fn name(&self) -> &str;
+}
+/// A rule-based Analyzer.
+#[derive(Clone)]
+pub struct Analyzer {
+    /// All rules to apply, in the order they are run by `execute_and_check`
+    pub rules: Vec<Arc<dyn AnalyzerRule + Send + Sync>>,
+}
+
+impl Default for Analyzer {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl Analyzer {
+    /// Create a new analyzer using the recommended list of rules (currently empty)
+    pub fn new() -> Self {
+        let rules = vec![];
+        Self::with_rules(rules)
+    }
+
+    /// Create a new analyzer with the given rules
+    pub fn with_rules(rules: Vec<Arc<dyn AnalyzerRule + Send + Sync>>) -> Self {
+        Self { rules }
+    }
+
+    /// Analyze the logical plan by applying each analyzer rule in order, then run
+    /// `check_plan` on the result and return an error for an invalid plan
+    pub fn execute_and_check(
+        &self,
+        plan: &LogicalPlan,
+        config: &ConfigOptions,
+    ) -> Result<LogicalPlan> {
+        let start_time = Instant::now();
+        let mut new_plan = plan.clone();
+
+        // TODO add common rule executor for Analyzer and Optimizer
+        for rule in &self.rules {
+            new_plan = rule.analyze(&new_plan, config)?;
+        }
+        check_plan(&new_plan)?;
+        log_plan("Final analyzed plan", &new_plan);
+        debug!("Analyzer took {} ms", start_time.elapsed().as_millis());
+        Ok(new_plan)
+    }
+}
+
+/// Log the plan at debug level (indented) and at trace level (indented, with schema)
+fn log_plan(description: &str, plan: &LogicalPlan) {
+    debug!("{description}:\n{}\n", plan.display_indent());
+    trace!("{description}::\n{}\n", plan.display_indent_schema());
+}
+
+/// Walk the plan bottom-up and validate every subquery expression found in its nodes
+fn check_plan(plan: &LogicalPlan) -> Result<()> {
+    plan.for_each_up(&|plan: &LogicalPlan| {
+        plan.expressions().into_iter().try_for_each(|expr| {
+            // recursively look for subqueries nested anywhere inside this expression
+            inspect_expr_pre(&expr, |expr| match expr {
+                Expr::Exists { subquery, .. }
+                | Expr::InSubquery { subquery, .. }
+                | Expr::ScalarSubquery(subquery) => {
+                    check_subquery_expr(plan, &subquery.subquery, expr)
+                }
+                _ => Ok(()),
+            })
+        })
+    })
+}
+
+/// Do necessary check on subquery expressions and fail the invalid plan
+/// 1) Check whether the outer plan is in the allowed outer plans list to use subquery expressions,
+///    the allowed whitelist: [Projection, Filter, Window, Aggregate, Sort, Join].
+/// 2) Check whether the inner plan is in the allowed inner plans list to use correlated(outer) expressions.
+/// 3) Check and validate unsupported cases to use the correlated(outer) expressions inside the subquery(inner) plans/inner expressions.
+///    For example, we do not want to support using correlated expressions as the Join conditions in the subquery plan when the Join
+///    is a Full Outer Join
+fn check_subquery_expr(
+    outer_plan: &LogicalPlan,
+    inner_plan: &LogicalPlan,
+    expr: &Expr,
+) -> Result<()> {
+    check_plan(inner_plan)?;
+
+    // Scalar subquery should only return one column
+    if matches!(expr, Expr::ScalarSubquery(subquery) if subquery.subquery.schema().fields().len() > 1)
+    {
+        return Err(DataFusionError::Plan(
+            "Scalar subquery should only return one column".to_string(),
+        ));
+    }
+
+    match outer_plan {
+        LogicalPlan::Projection(_)
+        | LogicalPlan::Filter(_)
+        | LogicalPlan::Window(_)
+        | LogicalPlan::Aggregate(_)
+        | LogicalPlan::Join(_) => Ok(()),
+        LogicalPlan::Sort(_) => match expr {
+            Expr::InSubquery { .. } | Expr::Exists { .. } => Err(DataFusionError::Plan(
+                "In/Exist subquery can not be used in Sort plan nodes".to_string(),
+            )),
+            Expr::ScalarSubquery(_) => Ok(()),
+            _ => Ok(()),
+        },
+        _ => Err(DataFusionError::Plan(
+            "Subquery can only be used in Projection, Filter, \
+            Window functions, Aggregate, Sort and Join plan nodes"
+                .to_string(),
+        )),
+    }?;
+    check_correlations_in_subquery(outer_plan, inner_plan, expr, true)
+}
+
+/// Recursively check the unsupported outer references in the sub query plan.
+fn check_correlations_in_subquery(
+    outer_plan: &LogicalPlan,
+    inner_plan: &LogicalPlan,
+    expr: &Expr,
+    can_contain_outer_ref: bool,
+) -> Result<()> {
+    // Fail only when outer references are disallowed at this node and one is actually present
+    if !can_contain_outer_ref && contains_outer_reference(outer_plan, inner_plan, expr) {
+        return Err(DataFusionError::Plan(
+            "Accessing outer reference column is not allowed in the plan".to_string(),
+        ));
+    }
+    match inner_plan {
+        LogicalPlan::Projection(_)
+        | LogicalPlan::Filter(_)
+        | LogicalPlan::Window(_)
+        | LogicalPlan::Aggregate(_)
+        | LogicalPlan::Distinct(_)
+        | LogicalPlan::Sort(_)
+        | LogicalPlan::CrossJoin(_)
+        | LogicalPlan::Union(_)
+        | LogicalPlan::TableScan(_)
+        | LogicalPlan::EmptyRelation(_)
+        | LogicalPlan::Limit(_)
+        | LogicalPlan::Subquery(_)
+        | LogicalPlan::SubqueryAlias(_) => inner_plan.apply_children(|plan| {
+            check_correlations_in_subquery(outer_plan, plan, expr, can_contain_outer_ref)
+        }),
+        LogicalPlan::Join(_) => {
+            // TODO support correlation columns in the subquery join; for now recurse like the operators above
+            inner_plan.apply_children(|plan| {
+                check_correlations_in_subquery(
+                    outer_plan,
+                    plan,
+                    expr,
+                    can_contain_outer_ref,
+                )
+            })
+        }
+        _ => Err(DataFusionError::Plan(
+            "Unsupported operator in the subquery plan.".to_string(),
+        )),
+    }
+}
+
+fn contains_outer_reference(
+    _outer_plan: &LogicalPlan,
+    _inner_plan: &LogicalPlan,
+    _expr: &Expr,
+) -> bool {
+    // TODO check outer references; this placeholder reports no outer reference for any input
+    false
+}
diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index 4bbbb4645af3..7f930ae3a8d0 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 pub mod alias;
+pub mod analyzer;
 pub mod common_subexpr_eliminate;
 pub mod decorrelate_where_exists;
 pub mod decorrelate_where_in;
@@ -35,6 +36,7 @@ pub mod push_down_filter;
 pub mod push_down_limit;
 pub mod push_down_projection;
 pub mod replace_distinct_aggregate;
+pub mod rewrite;
 pub mod rewrite_disjunctive_predicate;
 pub mod scalar_subquery_to_join;
 pub mod simplify_expressions;
diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs
index c1baa25d4363..64445948e116 100644
--- a/datafusion/optimizer/src/optimizer.rs
+++ b/datafusion/optimizer/src/optimizer.rs
@@ -17,6 +17,7 @@
 
 //! Query optimizer traits
 
+use crate::analyzer::Analyzer;
 use crate::common_subexpr_eliminate::CommonSubexprEliminate;
 use crate::decorrelate_where_exists::DecorrelateWhereExists;
 use crate::decorrelate_where_in::DecorrelateWhereIn;
@@ -266,9 +267,10 @@ impl Optimizer {
         F: FnMut(&LogicalPlan, &dyn OptimizerRule),
     {
         let options = config.options();
+        let analyzed_plan = Analyzer::default().execute_and_check(plan, options)?;
         let start_time = Instant::now();
-        let mut old_plan = Cow::Borrowed(plan);
-        let mut new_plan = plan.clone();
+        let mut old_plan = Cow::Borrowed(&analyzed_plan);
+        let mut new_plan = analyzed_plan.clone();
         let mut i = 0;
         while i < options.optimizer.max_passes {
             log_plan(&format!("Optimizer input (pass {i})"), &new_plan);
diff --git a/datafusion/optimizer/src/rewrite.rs b/datafusion/optimizer/src/rewrite.rs
new file mode 100644
index 000000000000..4a2d35de0086
--- /dev/null
+++ b/datafusion/optimizer/src/rewrite.rs
@@ -0,0 +1,199 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Trait to make LogicalPlan rewritable
+
+use datafusion_common::Result;
+
+use datafusion_expr::LogicalPlan;
+
+/// A trait for marking tree node types that are rewritable
+pub trait TreeNodeRewritable: Clone {
+    /// Transform the tree node using the given [TreeNodeRewriter]
+    /// It performs a depth first walk of a node and its children.
+    ///
+    /// For a node tree such as
+    /// ```text
+    /// ParentNode
+    ///    left: ChildNode1
+    ///    right: ChildNode2
+    /// ```
+    ///
+    /// The nodes are visited using the following order
+    /// ```text
+    /// pre_visit(ParentNode)
+    /// pre_visit(ChildNode1)
+    /// mutate(ChildNode1)
+    /// pre_visit(ChildNode2)
+    /// mutate(ChildNode2)
+    /// mutate(ParentNode)
+    /// ```
+    ///
+    /// If an Err result is returned, recursion is stopped immediately
+    ///
+    /// The [`RewriteRecursion`] returned by `pre_visit` steers the walk: `Stop` returns the node
+    /// unchanged and `Mutate` calls `mutate` on it, both without visiting its children;
+    /// `Skip` still rewrites the children but does not call `mutate` on the node itself
+    ///
+    fn transform_using<R: TreeNodeRewriter<Self>>(
+        self,
+        rewriter: &mut R,
+    ) -> Result<Self> {
+        let need_mutate = match rewriter.pre_visit(&self)? {
+            RewriteRecursion::Mutate => return rewriter.mutate(self),
+            RewriteRecursion::Stop => return Ok(self),
+            RewriteRecursion::Continue => true,
+            RewriteRecursion::Skip => false,
+        };
+
+        let after_op_children =
+            self.map_children(|node| node.transform_using(rewriter))?;
+
+        // now rewrite this node itself (unless `pre_visit` asked to skip it)
+        if need_mutate {
+            rewriter.mutate(after_op_children)
+        } else {
+            Ok(after_op_children)
+        }
+    }
+
+    /// Convenience utils for writing optimizers rule: recursively apply the given `op` to the node tree.
+    /// When `op` does not apply to a given node (it returns `None`), it is left unchanged.
+    /// The default tree traversal direction is transform_up(Postorder Traversal).
+    fn transform<F>(self, op: &F) -> Result<Self>
+    where
+        F: Fn(Self) -> Result<Option<Self>>,
+    {
+        self.transform_up(op)
+    }
+
+    /// Convenience utils for writing optimizers rule: recursively apply the given `op` to the node and all of its
+    /// children(Preorder Traversal).
+    /// When the `op` does not apply to a given node (it returns `None`), it is left unchanged.
+    fn transform_down<F>(self, op: &F) -> Result<Self>
+    where
+        F: Fn(Self) -> Result<Option<Self>>,
+    {
+        let node_cloned = self.clone();
+        let after_op = match op(node_cloned)? {
+            Some(value) => value,
+            None => self,
+        };
+        after_op.map_children(|node| node.transform_down(op))
+    }
+
+    /// Convenience utils for writing optimizers rule: recursively apply the given `op` first to all of its
+    /// children and then itself(Postorder Traversal).
+    /// When the `op` does not apply to a given node (it returns `None`), it is left unchanged.
+    fn transform_up<F>(self, op: &F) -> Result<Self>
+    where
+        F: Fn(Self) -> Result<Option<Self>>,
+    {
+        let after_op_children = self.map_children(|node| node.transform_up(op))?;
+
+        let after_op_children_clone = after_op_children.clone();
+        let new_node = match op(after_op_children)? {
+            Some(value) => value,
+            None => after_op_children_clone,
+        };
+        Ok(new_node)
+    }
+
+    /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder)
+    fn map_children<F>(self, transform: F) -> Result<Self>
+    where
+        F: FnMut(Self) -> Result<Self>;
+
+    /// Apply the given function `func` to this node and recursively apply to the node's children
+    fn for_each<F>(&self, func: &F) -> Result<()>
+    where
+        F: Fn(&Self) -> Result<()>,
+    {
+        func(self)?;
+        self.apply_children(|node| node.for_each(func))
+    }
+
+    /// Recursively apply the given function `func` to the node's children and to this node
+    fn for_each_up<F>(&self, func: &F) -> Result<()>
+    where
+        F: Fn(&Self) -> Result<()>,
+    {
+        self.apply_children(|node| node.for_each_up(func))?;
+        func(self)
+    }
+
+    /// Apply the given function `func` to the node's children
+    fn apply_children<F>(&self, func: F) -> Result<()>
+    where
+        F: Fn(&Self) -> Result<()>;
+}
+
+/// Trait for potentially recursively transforming a [`TreeNodeRewritable`] node
+/// tree. When passed to `TreeNodeRewritable::transform_using`, `TreeNodeRewriter::mutate` is
+/// invoked recursively on all nodes of a tree.
+pub trait TreeNodeRewriter<N: TreeNodeRewritable>: Sized {
+    /// Invoked before (Preorder) any children of `node` are rewritten /
+    /// visited. Default implementation returns `Ok(RewriteRecursion::Continue)`
+    fn pre_visit(&mut self, _node: &N) -> Result<RewriteRecursion> {
+        Ok(RewriteRecursion::Continue)
+    }
+
+    /// Invoked after (Postorder) all children of `node` have been mutated and
+    /// returns a potentially modified node.
+    fn mutate(&mut self, node: N) -> Result<N>;
+}
+
+/// Controls how the [TreeNodeRewriter] recursion should proceed.
+#[allow(dead_code)]
+pub enum RewriteRecursion {
+    /// Continue rewrite / visit this node tree.
+    Continue,
+    /// Call `mutate` on this node immediately and return, without visiting its children.
+    Mutate,
+    /// Return this node unchanged: neither it nor its children are rewritten / visited.
+    Stop,
+    /// Keep recursing into the children but skip calling `mutate` on this node
+    Skip,
+}
+
+impl TreeNodeRewritable for LogicalPlan {
+    fn map_children<F>(self, transform: F) -> Result<Self>
+    where
+        F: FnMut(Self) -> Result<Self>,
+    {
+        let children = self.inputs().into_iter().cloned().collect::<Vec<_>>();
+        if !children.is_empty() {
+            let new_children: Result<Vec<_>> =
+                children.into_iter().map(transform).collect();
+            self.with_new_inputs(new_children?.as_slice())
+        } else {
+            Ok(self)
+        }
+    }
+
+    fn apply_children<F>(&self, func: F) -> Result<()>
+    where
+        F: Fn(&Self) -> Result<()>,
+    {
+        let children = self.inputs();
+        if !children.is_empty() {
+            children.into_iter().try_for_each(func)
+        } else {
+            Ok(())
+        }
+    }
+}