Skip to content

Add ReplaceRedundantJoinWithProject rule to optimizer #25169

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@
import com.facebook.presto.sql.planner.iterative.rule.RemoveUnreferencedScalarLateralNodes;
import com.facebook.presto.sql.planner.iterative.rule.RemoveUnsupportedDynamicFilters;
import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins;
import com.facebook.presto.sql.planner.iterative.rule.ReplaceRedundantJoinWithProject;
import com.facebook.presto.sql.planner.iterative.rule.RewriteAggregationIfToFilter;
import com.facebook.presto.sql.planner.iterative.rule.RewriteCaseExpressionPredicate;
import com.facebook.presto.sql.planner.iterative.rule.RewriteCaseToMap;
Expand Down Expand Up @@ -417,6 +418,7 @@ public PlanOptimizers(
new PushLimitThroughSemiJoin(),
new PushLimitThroughUnion(),
new RemoveTrivialFilters(),
new ReplaceRedundantJoinWithProject(),
new ImplementFilteredAggregations(metadata.getFunctionAndTypeManager()),
new SingleDistinctAggregationToGroupBy(),
new MultipleDistinctAggregationToMarkDistinct(),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Rule;

import java.util.List;
import java.util.stream.Collectors;

import static com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil.isEmpty;
import static com.facebook.presto.sql.planner.plan.Patterns.join;

public class ReplaceRedundantJoinWithProject
implements Rule<JoinNode>
{
private static final Pattern<JoinNode> PATTERN = join();

@Override
public Pattern<JoinNode> getPattern()
{
return PATTERN;
}

@Override
public Result apply(JoinNode node, Captures captures, Context context)
{
Lookup lookup = context.getLookup();
PlanNode left = node.getLeft();
PlanNode right = node.getRight();

List<VariableReferenceExpression> leftOutputVariables = node.getOutputVariables().stream()
.filter(variable -> left.getOutputVariables().contains(variable))
.collect(Collectors.toList());

List<VariableReferenceExpression> rightOutputVariables = node.getOutputVariables().stream()
.filter(variable -> right.getOutputVariables().contains(variable))
.collect(Collectors.toList());

switch (node.getType()) {
case INNER:
return Result.empty();
case LEFT:
return !isEmpty(left, lookup) && isEmpty(right, lookup) ?
Result.ofPlanNode(appendNulls(
left,
leftOutputVariables,
rightOutputVariables,
context.getIdAllocator()
)) :
Result.empty();
case RIGHT:
return isEmpty(left, lookup) && !isEmpty(right, lookup) ?
Result.ofPlanNode(appendNulls(
right,
rightOutputVariables,
leftOutputVariables,
context.getIdAllocator()
)) :
Result.empty();
case FULL:
if (isEmpty(left, lookup) && !isEmpty(right, lookup)) {
return Result.ofPlanNode(appendNulls(
right,
rightOutputVariables,
leftOutputVariables,
context.getIdAllocator()));
}
if (!isEmpty(left, lookup) && isEmpty(right, lookup)) {
return Result.ofPlanNode(appendNulls(
left,
leftOutputVariables,
rightOutputVariables,
context.getIdAllocator()));
}
return Result.empty();
default:
throw new IllegalArgumentException();
}
}

private static ProjectNode appendNulls(PlanNode source, List<VariableReferenceExpression> sourceOutputs, List<VariableReferenceExpression> nullVariables, PlanNodeIdAllocator idAllocator)
{
Assignments.Builder assignments = Assignments.builder()
.putIdentities(sourceOutputs);
nullVariables
.forEach(variable -> assignments.put(variable, new ConstantExpression(null, variable.getType())));

return new ProjectNode(idAllocator.getNextId(), source, assignments.build());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ public static boolean isAtMost(PlanNode node, Lookup lookup, long maxCardinality
return Range.closed(0L, maxCardinality).encloses(extractCardinality(node, lookup));
}

public static boolean isEmpty(PlanNode node, Lookup lookup)
{
return isAtMost(node, lookup, 0);
}

public static Range<Long> extractCardinality(PlanNode node)
{
return extractCardinality(node, noLookup());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.GenericLiteral;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.StringLiteral;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
Expand Down Expand Up @@ -82,6 +83,9 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses
.map(rowExpression -> {
ConstantExpression expression = (ConstantExpression) rowExpression;
if (expression.getType().getJavaType() == boolean.class) {
if (expression.isNull()) {
return new NullLiteral();
}
return new BooleanLiteral(String.valueOf(expression.getValue()));
}
if (expression.getType() instanceof ShortDecimalType) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
import com.facebook.presto.sql.tree.NullLiteral;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.Test;

import static com.facebook.presto.spi.plan.JoinType.FULL;
import static com.facebook.presto.spi.plan.JoinType.INNER;
import static com.facebook.presto.spi.plan.JoinType.LEFT;
import static com.facebook.presto.spi.plan.JoinType.RIGHT;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values;
import static java.util.Collections.nCopies;

public class TestReplaceRedundantJoinWithProject
extends BaseRuleTest
{
@Test
public void testDoesNotFireOnInnerJoin()
{
tester().assertThat(new ReplaceRedundantJoinWithProject())
.on(p ->
p.join(
INNER,
p.values(0, p.variable("a")),
p.values(0, p.variable("b"))))
.doesNotFire();
}

@Test
public void testDoesNotFireWhenOuterSourceEmpty()
{
tester().assertThat(new ReplaceRedundantJoinWithProject())
.on(p ->
p.join(
LEFT,
p.values(0, p.variable("a")),
p.values(0, p.variable("b"))))
.doesNotFire();

tester().assertThat(new ReplaceRedundantJoinWithProject())
.on(p ->
p.join(
RIGHT,
p.values(0, p.variable("a")),
p.values(0, p.variable("b"))))
.doesNotFire();
}

@Test
public void testDoesNotFireOnFullJoinWithBothSourcesEmpty()
{
tester().assertThat(new ReplaceRedundantJoinWithProject())
.on(p ->
p.join(
FULL,
p.values(0, p.variable("a")),
p.values(0, p.variable("b"))))
.doesNotFire();
}

@Test
public void testReplaceLeftJoin()
{
tester().assertThat(new ReplaceRedundantJoinWithProject())
.on(p ->
p.join(
LEFT,
p.values(10, p.variable("a")),
p.values(0, p.variable("b"))))
.matches(
project(
ImmutableMap.of(
"a", expression("a"),
"b", expression("null")),
values(ImmutableList.of("a"), nCopies(10, ImmutableList.of(new NullLiteral())))));
}

@Test
public void testReplaceRightJoin()
{
tester().assertThat(new ReplaceRedundantJoinWithProject())
.on(p ->
p.join(
RIGHT,
p.values(0, p.variable("a")),
p.values(10, p.variable("b"))))
.matches(
project(
ImmutableMap.of(
"a", expression("null"),
"b", expression("b")),
values(ImmutableList.of("b"), nCopies(10, ImmutableList.of(new NullLiteral())))));
}

@Test
public void testReplaceFULLJoin()
{
tester().assertThat(new ReplaceRedundantJoinWithProject())
.on(p ->
p.join(
FULL,
p.values(10, p.variable("a")),
p.values(0, p.variable("b"))))
.matches(
project(
ImmutableMap.of(
"a", expression("a"),
"b", expression("null")),
values(ImmutableList.of("a"), nCopies(10, ImmutableList.of(new NullLiteral())))));

tester().assertThat(new ReplaceRedundantJoinWithProject())
.on(p ->
p.join(
FULL,
p.values(0, p.variable("a")),
p.values(10, p.variable("b"))))
.matches(
project(
ImmutableMap.of(
"a", expression("null"),
"b", expression("b")),
values(ImmutableList.of("b"), nCopies(10, ImmutableList.of(new NullLiteral())))));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,20 @@ public Builder put(Entry<VariableReferenceExpression, RowExpression> assignment)
return this;
}

public Builder putIdentities(Iterable<VariableReferenceExpression> variables)
{
for (VariableReferenceExpression variable : variables) {
putIdentity(variable);
}
return this;
}

public Builder putIdentity(VariableReferenceExpression variable)
{
put(variable, variable);
return this;
}

public Assignments build()
{
return new Assignments(assignments);
Expand Down
Loading