Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-50683][SQL] Inline the common expression in With if used once #49310

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,15 @@ object RewriteWithExpression extends Rule[LogicalPlan] {

private def applyInternal(p: LogicalPlan): LogicalPlan = {
val inputPlans = p.children
val commonExprIdSet = p.expressions
.flatMap(_.collect { case r: CommonExpressionRef => r.id })
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should avoid collecting references from nested With, as this rule skips nested With and leaves it to the next iteration.

Copy link
Contributor Author

@zml1206 zml1206 Dec 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think collecting references from nested With is necessary. For example, in With(ref0 + With(ref1 + ref1, Seq(def(ref0+b,1))), Seq(def(a+a,0))), the first iteration will replace every ref0, including the one inside the nested With's definition.
This also collects the extra ref1 references, but that does not affect the result.

.groupBy(identity)
.transform((_, v) => v.size)
.filter(_._2 > 1)
.keySet
val commonExprsPerChild = Array.fill(inputPlans.length)(mutable.ListBuffer.empty[(Alias, Long)])
var newPlan: LogicalPlan = p.mapExpressions { expr =>
rewriteWithExprAndInputPlans(expr, inputPlans, commonExprsPerChild)
rewriteWithExprAndInputPlans(expr, inputPlans, commonExprsPerChild, commonExprIdSet)
}
val newChildren = inputPlans.zip(commonExprsPerChild).map { case (inputPlan, commonExprs) =>
if (commonExprs.isEmpty) {
Expand All @@ -96,16 +102,17 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
e: Expression,
inputPlans: Seq[LogicalPlan],
commonExprsPerChild: Array[mutable.ListBuffer[(Alias, Long)]],
commonExprIdSet: Set[CommonExpressionId],
isNestedWith: Boolean = false): Expression = {
if (!e.containsPattern(WITH_EXPRESSION)) return e
e match {
// Do not handle nested With in one pass. Leave it to the next rule executor batch.
case w: With if !isNestedWith =>
// Rewrite nested With expressions first
val child = rewriteWithExprAndInputPlans(
w.child, inputPlans, commonExprsPerChild, isNestedWith = true)
w.child, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith = true)
val defs = w.defs.map(rewriteWithExprAndInputPlans(
_, inputPlans, commonExprsPerChild, isNestedWith = true))
_, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith = true))
val refToExpr = mutable.HashMap.empty[CommonExpressionId, Expression]

defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id), index) =>
Expand All @@ -114,7 +121,7 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
"Cannot rewrite canonicalized Common expression definitions")
}

if (CollapseProject.isCheap(child)) {
if (CollapseProject.isCheap(child) || !commonExprIdSet.contains(id)) {
refToExpr(id) = child
} else {
val childPlanIndex = inputPlans.indexWhere(
Expand Down Expand Up @@ -171,7 +178,8 @@ object RewriteWithExpression extends Rule[LogicalPlan] {

case c: ConditionalExpression =>
val newAlwaysEvaluatedInputs = c.alwaysEvaluatedInputs.map(
rewriteWithExprAndInputPlans(_, inputPlans, commonExprsPerChild, isNestedWith))
rewriteWithExprAndInputPlans(
_, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith))
val newExpr = c.withNewAlwaysEvaluatedInputs(newAlwaysEvaluatedInputs)
// Use transformUp to handle nested With.
newExpr.transformUpWithPruning(_.containsPattern(WITH_EXPRESSION)) {
Expand All @@ -185,7 +193,8 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
}

case other => other.mapChildren(
rewriteWithExprAndInputPlans(_, inputPlans, commonExprsPerChild, isNestedWith)
rewriteWithExprAndInputPlans(
_, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith)
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class RewriteWithExpressionSuite extends PlanTest {
val commonExprDef2 = CommonExpressionDef(a + a, CommonExpressionId(2))
val ref2 = new CommonExpressionRef(commonExprDef2)
// The inner main expression references the outer expression
val innerExpr2 = With(ref2 + outerRef, Seq(commonExprDef2))
val innerExpr2 = With(ref2 + ref2 + outerRef, Seq(commonExprDef2))
val outerExpr2 = With(outerRef + innerExpr2, Seq(outerCommonExprDef))
comparePlans(
Optimizer.execute(testRelation.select(outerExpr2.as("col"))),
Expand All @@ -152,7 +152,8 @@ class RewriteWithExpressionSuite extends PlanTest {
.select(star(), (a + a).as("_common_expr_2"))
// The final Project contains the final result expression, which references both common
// expressions.
.select(($"_common_expr_0" + ($"_common_expr_2" + $"_common_expr_0")).as("col"))
.select(($"_common_expr_0" +
($"_common_expr_2" + $"_common_expr_2" + $"_common_expr_0")).as("col"))
.analyze
)
}
Expand Down Expand Up @@ -490,4 +491,13 @@ class RewriteWithExpressionSuite extends PlanTest {
val wrongPlan = testRelation.select(expr1.as("c1"), expr3.as("c3")).analyze
intercept[AssertionError](Optimizer.execute(wrongPlan))
}

test("SPARK-50683: inline the common expression in With if used once") {
  // Build a With whose single definition is referenced exactly once by the main expression.
  val attr = testRelation.output.head
  val singleUseDef = CommonExpressionDef(attr + attr)
  val singleUseRef = new CommonExpressionRef(singleUseDef)
  val withExpr = With(singleUseRef + 1, Seq(singleUseDef))
  val input = testRelation.select(withExpr.as("col"))
  val optimized = Optimizer.execute(input)
  // Expected shape: one select with the definition substituted in place of its reference.
  val expected = testRelation.select((attr + attr + 1).as("col"))
  comparePlans(optimized, expected)
}
}
Loading