Skip to content

Commit

Permalink
Inline the common expression in With if used once
Browse files Browse the repository at this point in the history
  • Loading branch information
zml1206 committed Dec 27, 2024
1 parent 2372bc0 commit afd9c48
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,15 @@ object RewriteWithExpression extends Rule[LogicalPlan] {

private def applyInternal(p: LogicalPlan): LogicalPlan = {
val inputPlans = p.children
val commonExprIdSet = p.expressions
.flatMap(_.collect { case r: CommonExpressionRef => r.id })
.groupBy(identity)
.transform((_, v) => v.size)
.filter(_._2 > 1)
.keySet
val commonExprsPerChild = Array.fill(inputPlans.length)(mutable.ListBuffer.empty[(Alias, Long)])
var newPlan: LogicalPlan = p.mapExpressions { expr =>
rewriteWithExprAndInputPlans(expr, inputPlans, commonExprsPerChild)
rewriteWithExprAndInputPlans(expr, inputPlans, commonExprsPerChild, commonExprIdSet)
}
val newChildren = inputPlans.zip(commonExprsPerChild).map { case (inputPlan, commonExprs) =>
if (commonExprs.isEmpty) {
Expand All @@ -96,16 +102,17 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
e: Expression,
inputPlans: Seq[LogicalPlan],
commonExprsPerChild: Array[mutable.ListBuffer[(Alias, Long)]],
commonExprIdSet: Set[CommonExpressionId],
isNestedWith: Boolean = false): Expression = {
if (!e.containsPattern(WITH_EXPRESSION)) return e
e match {
// Do not handle nested With in one pass. Leave it to the next rule executor batch.
case w: With if !isNestedWith =>
// Rewrite nested With expressions first
val child = rewriteWithExprAndInputPlans(
w.child, inputPlans, commonExprsPerChild, isNestedWith = true)
w.child, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith = true)
val defs = w.defs.map(rewriteWithExprAndInputPlans(
_, inputPlans, commonExprsPerChild, isNestedWith = true))
_, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith = true))
val refToExpr = mutable.HashMap.empty[CommonExpressionId, Expression]

defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id), index) =>
Expand All @@ -114,7 +121,7 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
"Cannot rewrite canonicalized Common expression definitions")
}

if (CollapseProject.isCheap(child)) {
if (CollapseProject.isCheap(child) || !commonExprIdSet.contains(id)) {
refToExpr(id) = child
} else {
val childPlanIndex = inputPlans.indexWhere(
Expand Down Expand Up @@ -171,7 +178,8 @@ object RewriteWithExpression extends Rule[LogicalPlan] {

case c: ConditionalExpression =>
val newAlwaysEvaluatedInputs = c.alwaysEvaluatedInputs.map(
rewriteWithExprAndInputPlans(_, inputPlans, commonExprsPerChild, isNestedWith))
rewriteWithExprAndInputPlans(
_, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith))
val newExpr = c.withNewAlwaysEvaluatedInputs(newAlwaysEvaluatedInputs)
// Use transformUp to handle nested With.
newExpr.transformUpWithPruning(_.containsPattern(WITH_EXPRESSION)) {
Expand All @@ -185,7 +193,8 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
}

case other => other.mapChildren(
rewriteWithExprAndInputPlans(_, inputPlans, commonExprsPerChild, isNestedWith)
rewriteWithExprAndInputPlans(
_, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith)
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class RewriteWithExpressionSuite extends PlanTest {
val commonExprDef2 = CommonExpressionDef(a + a, CommonExpressionId(2))
val ref2 = new CommonExpressionRef(commonExprDef2)
// The inner main expression references the outer expression
val innerExpr2 = With(ref2 + outerRef, Seq(commonExprDef2))
val innerExpr2 = With(ref2 + ref2 + outerRef, Seq(commonExprDef2))
val outerExpr2 = With(outerRef + innerExpr2, Seq(outerCommonExprDef))
comparePlans(
Optimizer.execute(testRelation.select(outerExpr2.as("col"))),
Expand All @@ -152,7 +152,8 @@ class RewriteWithExpressionSuite extends PlanTest {
.select(star(), (a + a).as("_common_expr_2"))
// The final Project contains the final result expression, which references both common
// expressions.
.select(($"_common_expr_0" + ($"_common_expr_2" + $"_common_expr_0")).as("col"))
.select(($"_common_expr_0" +
($"_common_expr_2" + $"_common_expr_2" + $"_common_expr_0")).as("col"))
.analyze
)
}
Expand Down Expand Up @@ -490,4 +491,13 @@ class RewriteWithExpressionSuite extends PlanTest {
val wrongPlan = testRelation.select(expr1.as("c1"), expr3.as("c3")).analyze
intercept[AssertionError](Optimizer.execute(wrongPlan))
}

test("SPARK-50683: inline the common expression in With if used once") {
val a = testRelation.output.head
val exprDef = CommonExpressionDef(a + a)
val exprRef = new CommonExpressionRef(exprDef)
val expr = With(exprRef + 1, Seq(exprDef))
val plan = testRelation.select(expr.as("col"))
comparePlans(Optimizer.execute(plan), testRelation.select((a + a + 1).as("col")))
}
}

0 comments on commit afd9c48

Please sign in to comment.