From afd9c481c11a27fd9fab244e95a5511ebe66afd9 Mon Sep 17 00:00:00 2001 From: zml1206 Date: Fri, 27 Dec 2024 13:37:08 +0800 Subject: [PATCH] Inline the common expression in With if used once --- .../optimizer/RewriteWithExpression.scala | 21 +++++++++++++------ .../RewriteWithExpressionSuite.scala | 14 +++++++++++-- 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala index 40189a9f61021..5d85e89e1eabe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala @@ -68,9 +68,15 @@ object RewriteWithExpression extends Rule[LogicalPlan] { private def applyInternal(p: LogicalPlan): LogicalPlan = { val inputPlans = p.children + val commonExprIdSet = p.expressions + .flatMap(_.collect { case r: CommonExpressionRef => r.id }) + .groupBy(identity) + .transform((_, v) => v.size) + .filter(_._2 > 1) + .keySet val commonExprsPerChild = Array.fill(inputPlans.length)(mutable.ListBuffer.empty[(Alias, Long)]) var newPlan: LogicalPlan = p.mapExpressions { expr => - rewriteWithExprAndInputPlans(expr, inputPlans, commonExprsPerChild) + rewriteWithExprAndInputPlans(expr, inputPlans, commonExprsPerChild, commonExprIdSet) } val newChildren = inputPlans.zip(commonExprsPerChild).map { case (inputPlan, commonExprs) => if (commonExprs.isEmpty) { @@ -96,6 +102,7 @@ object RewriteWithExpression extends Rule[LogicalPlan] { e: Expression, inputPlans: Seq[LogicalPlan], commonExprsPerChild: Array[mutable.ListBuffer[(Alias, Long)]], + commonExprIdSet: Set[CommonExpressionId], isNestedWith: Boolean = false): Expression = { if (!e.containsPattern(WITH_EXPRESSION)) return e e match { @@ -103,9 +110,9 @@ object RewriteWithExpression extends Rule[LogicalPlan] { case w: With if !isNestedWith => // Rewrite nested With expressions first val child = rewriteWithExprAndInputPlans( - w.child, inputPlans, commonExprsPerChild, isNestedWith = true) + w.child, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith = true) val defs = w.defs.map(rewriteWithExprAndInputPlans( - _, inputPlans, commonExprsPerChild, isNestedWith = true)) + _, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith = true)) val refToExpr = mutable.HashMap.empty[CommonExpressionId, Expression] defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id), index) => @@ -114,7 +121,7 @@ object RewriteWithExpression extends Rule[LogicalPlan] { "Cannot rewrite canonicalized Common expression definitions") } - if (CollapseProject.isCheap(child)) { + if (CollapseProject.isCheap(child) || !commonExprIdSet.contains(id)) { refToExpr(id) = child } else { val childPlanIndex = inputPlans.indexWhere( @@ -171,7 +178,8 @@ object RewriteWithExpression extends Rule[LogicalPlan] { case c: ConditionalExpression => val newAlwaysEvaluatedInputs = c.alwaysEvaluatedInputs.map( - rewriteWithExprAndInputPlans(_, inputPlans, commonExprsPerChild, isNestedWith)) + rewriteWithExprAndInputPlans( + _, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith)) val newExpr = c.withNewAlwaysEvaluatedInputs(newAlwaysEvaluatedInputs) // Use transformUp to handle nested With. newExpr.transformUpWithPruning(_.containsPattern(WITH_EXPRESSION)) { @@ -185,7 +193,8 @@ object RewriteWithExpression extends Rule[LogicalPlan] { } case other => other.mapChildren( - rewriteWithExprAndInputPlans(_, inputPlans, commonExprsPerChild, isNestedWith) + rewriteWithExprAndInputPlans( + _, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith) ) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala index 9f0a7fdaf3152..8918b58ca1b56 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala @@ -140,7 +140,7 @@ class RewriteWithExpressionSuite extends PlanTest { val commonExprDef2 = CommonExpressionDef(a + a, CommonExpressionId(2)) val ref2 = new CommonExpressionRef(commonExprDef2) // The inner main expression references the outer expression - val innerExpr2 = With(ref2 + outerRef, Seq(commonExprDef2)) + val innerExpr2 = With(ref2 + ref2 + outerRef, Seq(commonExprDef2)) val outerExpr2 = With(outerRef + innerExpr2, Seq(outerCommonExprDef)) comparePlans( Optimizer.execute(testRelation.select(outerExpr2.as("col"))), @@ -152,7 +152,8 @@ class RewriteWithExpressionSuite extends PlanTest { .select(star(), (a + a).as("_common_expr_2")) // The final Project contains the final result expression, which references both common // expressions. - .select(($"_common_expr_0" + ($"_common_expr_2" + $"_common_expr_0")).as("col")) + .select(($"_common_expr_0" + + ($"_common_expr_2" + $"_common_expr_2" + $"_common_expr_0")).as("col")) .analyze ) } @@ -490,4 +491,13 @@ class RewriteWithExpressionSuite extends PlanTest { val wrongPlan = testRelation.select(expr1.as("c1"), expr3.as("c3")).analyze intercept[AssertionError](Optimizer.execute(wrongPlan)) } + + test("SPARK-50683: inline the common expression in With if used once") { + val a = testRelation.output.head + val exprDef = CommonExpressionDef(a + a) + val exprRef = new CommonExpressionRef(exprDef) + val expr = With(exprRef + 1, Seq(exprDef)) + val plan = testRelation.select(expr.as("col")) + comparePlans(Optimizer.execute(plan), testRelation.select((a + a + 1).as("col"))) + } }