Skip to content

Commit 6785d1f

Browse files
authored
[feat](nerieds)Push encode slot (#45958)
### What problem does this PR solve? Issue Number: close #xxx Related PR: #xxx Problem Summary: ### Release note None ### Check List (For Author) - Test <!-- At least one of them must be included. --> - [x] Regression test - [ ] Unit Test - [ ] Manual test (add detailed scripts or steps below) - [ ] No need to test or manual test. Explain why: - [ ] This is a refactor/code format and no logic has been changed. - [ ] Previous test can cover this change. - [ ] No code files have been changed. - [ ] Other reason <!-- Add your reason? --> - Behavior changed: - [x] No. - [ ] Yes. <!-- Explain the behavior change --> - Does this need documentation? - [x] No. - [ ] Yes. <!-- Add document PR link here. eg: apache/doris-website#1214 --> ### Check List (For Reviewer who merge this PR) - [ ] Confirm the release note - [ ] Confirm test cases - [ ] Confirm document - [ ] Add branch pick label <!-- Add branch pick label that this PR should merge into -->
1 parent 10eac22 commit 6785d1f

26 files changed

+1320
-40
lines changed

fe/fe-core/src/main/java/org/apache/doris/mtmv/MTMVPlanUtil.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ public static ConnectContext createMTMVContext(MTMV mtmv) {
6262
ctx.getSessionVariable().allowModifyMaterializedViewData = true;
6363
// Disable add default limit rule to avoid refresh data wrong
6464
ctx.getSessionVariable().setDisableNereidsRules(
65-
String.join(",", ImmutableSet.of(RuleType.ADD_DEFAULT_LIMIT.name())));
65+
String.join(",", ImmutableSet.of(
66+
"COMPRESSED_MATERIALIZE_AGG", "COMPRESSED_MATERIALIZE_SORT",
67+
RuleType.ADD_DEFAULT_LIMIT.name())));
6668
Optional<String> workloadGroup = mtmv.getWorkloadGroup();
6769
if (workloadGroup.isPresent()) {
6870
ctx.getSessionVariable().setWorkloadGroup(workloadGroup.get());

fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java

+17-2
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
import org.apache.doris.nereids.rules.rewrite.CountDistinctRewrite;
5656
import org.apache.doris.nereids.rules.rewrite.CountLiteralRewrite;
5757
import org.apache.doris.nereids.rules.rewrite.CreatePartitionTopNFromWindow;
58+
import org.apache.doris.nereids.rules.rewrite.DecoupleEncodeDecode;
5859
import org.apache.doris.nereids.rules.rewrite.DeferMaterializeTopNResult;
5960
import org.apache.doris.nereids.rules.rewrite.EliminateAggCaseWhen;
6061
import org.apache.doris.nereids.rules.rewrite.EliminateAggregate;
@@ -106,6 +107,7 @@
106107
import org.apache.doris.nereids.rules.rewrite.PruneOlapScanTablet;
107108
import org.apache.doris.nereids.rules.rewrite.PullUpCteAnchor;
108109
import org.apache.doris.nereids.rules.rewrite.PullUpJoinFromUnionAll;
110+
import org.apache.doris.nereids.rules.rewrite.PullUpProjectBetweenTopNAndAgg;
109111
import org.apache.doris.nereids.rules.rewrite.PullUpProjectUnderApply;
110112
import org.apache.doris.nereids.rules.rewrite.PullUpProjectUnderLimit;
111113
import org.apache.doris.nereids.rules.rewrite.PullUpProjectUnderTopN;
@@ -115,6 +117,7 @@
115117
import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoinOneSide;
116118
import org.apache.doris.nereids.rules.rewrite.PushDownAggWithDistinctThroughJoinOneSide;
117119
import org.apache.doris.nereids.rules.rewrite.PushDownDistinctThroughJoin;
120+
import org.apache.doris.nereids.rules.rewrite.PushDownEncodeSlot;
118121
import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughProject;
119122
import org.apache.doris.nereids.rules.rewrite.PushDownLimit;
120123
import org.apache.doris.nereids.rules.rewrite.PushDownLimitDistinctThroughJoin;
@@ -253,6 +256,13 @@ public class Rewriter extends AbstractBatchJobExecutor {
253256
new CountLiteralRewrite(),
254257
new NormalizeSort()
255258
),
259+
260+
topDown(// must behind NormalizeAggregate/NormalizeSort
261+
new MergeProjects(),
262+
new PushDownEncodeSlot(),
263+
new DecoupleEncodeDecode()
264+
),
265+
256266
topic("Window analysis",
257267
topDown(
258268
new ExtractAndNormalizeWindowExpression(),
@@ -372,9 +382,11 @@ public class Rewriter extends AbstractBatchJobExecutor {
372382
// generate one PhysicalLimit if current distribution is gather or two
373383
// PhysicalLimits with gather exchange
374384
topDown(new LimitSortToTopN()),
375-
topDown(new SimplifyEncodeDecode()),
376-
topDown(new LimitAggToTopNAgg()),
377385
topDown(new MergeTopNs()),
386+
topDown(new SimplifyEncodeDecode(),
387+
new MergeProjects()
388+
),
389+
topDown(new LimitAggToTopNAgg()),
378390
topDown(new SplitLimit()),
379391
topDown(
380392
new PushDownLimit(),
@@ -466,6 +478,9 @@ public class Rewriter extends AbstractBatchJobExecutor {
466478
custom(RuleType.ADD_PROJECT_FOR_JOIN, AddProjectForJoin::new),
467479
topDown(new MergeProjects())
468480
),
481+
topic("Adjust topN project",
482+
topDown(new MergeProjects(),
483+
new PullUpProjectBetweenTopNAndAgg())),
469484
// this rule batch must keep at the end of rewrite to do some plan check
470485
topic("Final rewrite and check",
471486
custom(RuleType.CHECK_DATA_TYPES, CheckDataTypes::new),

fe/fe-core/src/main/java/org/apache/doris/nereids/properties/LogicalProperties.java

+1-7
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,14 @@ public class LogicalProperties {
4444
protected final Supplier<DataTrait> dataTraitSupplier;
4545
private Integer hashCode = null;
4646

47-
public LogicalProperties(Supplier<List<Slot>> outputSupplier,
48-
Supplier<DataTrait> dataTraitSupplier) {
49-
this(outputSupplier, dataTraitSupplier, ImmutableList::of);
50-
}
51-
5247
/**
5348
* constructor of LogicalProperties.
5449
*
5550
* @param outputSupplier provide the output. Supplier can lazy compute output without
5651
* throw exception for which children have UnboundRelation
5752
*/
5853
public LogicalProperties(Supplier<List<Slot>> outputSupplier,
59-
Supplier<DataTrait> dataTraitSupplier,
60-
Supplier<List<Slot>> nonUserVisibleOutputSupplier) {
54+
Supplier<DataTrait> dataTraitSupplier) {
6155
this.outputSupplier = Suppliers.memoize(
6256
Objects.requireNonNull(outputSupplier, "outputSupplier can not be null")
6357
);

fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java

+4
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@ public enum RuleType {
114114
// rewrite rules
115115
COMPRESSED_MATERIALIZE_AGG(RuleTypeClass.REWRITE),
116116
COMPRESSED_MATERIALIZE_SORT(RuleTypeClass.REWRITE),
117+
COMPRESSED_MATERIALIZE_REPEAT(RuleTypeClass.REWRITE),
118+
PUSH_DOWN_ENCODE_SLOT(RuleTypeClass.REWRITE),
119+
ADJUST_TOPN_PROJECT(RuleTypeClass.REWRITE),
120+
DECOUPLE_DECODE_ENCODE_SLOT(RuleTypeClass.REWRITE),
117121
SIMPLIFY_ENCODE_DECODE(RuleTypeClass.REWRITE),
118122
NORMALIZE_AGGREGATE(RuleTypeClass.REWRITE),
119123
NORMALIZE_SORT(RuleTypeClass.REWRITE),

fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CompressedMaterialize.java

+52
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsSmallInt;
3232
import org.apache.doris.nereids.trees.plans.Plan;
3333
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
34+
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
3435
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
3536
import org.apache.doris.nereids.types.DataType;
3637
import org.apache.doris.nereids.types.coercion.CharacterType;
@@ -101,6 +102,9 @@ private LogicalSort<Plan> compressMaterializeSort(LogicalSort<Plan> sort) {
101102
}
102103

103104
private Optional<Expression> getEncodeExpression(Expression expression) {
105+
if (expression.isConstant()) {
106+
return Optional.empty();
107+
}
104108
DataType type = expression.getDataType();
105109
Expression encodeExpr = null;
106110
if (type instanceof CharacterType) {
@@ -169,4 +173,52 @@ private LogicalAggregate<Plan> compressedMaterializeAggregate(LogicalAggregate<P
169173
}
170174
return aggregate;
171175
}
176+
177+
private Map<Expression, Expression> getEncodeGroupingSets(LogicalRepeat<Plan> repeat) {
178+
Map<Expression, Expression> encode = Maps.newHashMap();
179+
// the first grouping set contains all group by keys
180+
for (Expression gb : repeat.getGroupingSets().get(0)) {
181+
Optional<Expression> encodeExpr = getEncodeExpression(gb);
182+
encodeExpr.ifPresent(expression -> encode.put(gb, expression));
183+
}
184+
return encode;
185+
}
186+
187+
private LogicalRepeat<Plan> compressMaterializeRepeat(LogicalRepeat<Plan> repeat) {
188+
Map<Expression, Expression> encode = getEncodeGroupingSets(repeat);
189+
if (encode.isEmpty()) {
190+
return repeat;
191+
}
192+
List<List<Expression>> newGroupingSets = Lists.newArrayList();
193+
for (int i = 0; i < repeat.getGroupingSets().size(); i++) {
194+
List<Expression> grouping = Lists.newArrayList();
195+
for (int j = 0; j < repeat.getGroupingSets().get(i).size(); j++) {
196+
Expression groupingExpr = repeat.getGroupingSets().get(i).get(j);
197+
grouping.add(encode.getOrDefault(groupingExpr, groupingExpr));
198+
}
199+
newGroupingSets.add(grouping);
200+
}
201+
List<NamedExpression> newOutputs = Lists.newArrayList();
202+
Map<Expression, Expression> decodeMap = new HashMap<>();
203+
for (Expression gp : encode.keySet()) {
204+
decodeMap.put(gp, new DecodeAsVarchar(encode.get(gp)));
205+
}
206+
for (NamedExpression out : repeat.getOutputExpressions()) {
207+
Expression replaced = ExpressionUtils.replace(out, decodeMap);
208+
if (out != replaced) {
209+
if (out instanceof SlotReference) {
210+
newOutputs.add(new Alias(out.getExprId(), replaced, out.getName()));
211+
} else if (out instanceof Alias) {
212+
newOutputs.add(((Alias) out).withChildren(replaced.children()));
213+
} else {
214+
// should not reach here
215+
Preconditions.checkArgument(false, "output abnormal: " + repeat);
216+
}
217+
} else {
218+
newOutputs.add(out);
219+
}
220+
}
221+
repeat = repeat.withGroupSetsAndOutput(newGroupingSets, newOutputs);
222+
return repeat;
223+
}
172224
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
package org.apache.doris.nereids.rules.rewrite;
19+
20+
import org.apache.doris.nereids.rules.Rule;
21+
import org.apache.doris.nereids.rules.RuleType;
22+
import org.apache.doris.nereids.trees.expressions.Alias;
23+
import org.apache.doris.nereids.trees.expressions.Expression;
24+
import org.apache.doris.nereids.trees.expressions.NamedExpression;
25+
import org.apache.doris.nereids.trees.expressions.functions.scalar.DecodeAsVarchar;
26+
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeString;
27+
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
28+
import org.apache.doris.qe.ConnectContext;
29+
30+
import com.google.common.collect.Lists;
31+
32+
import java.util.List;
33+
34+
/**
35+
* in project:
36+
* decode_as_varchar(encode_as_xxx(v)) => v
37+
*/
38+
public class DecoupleEncodeDecode extends OneRewriteRuleFactory {
39+
@Override
40+
public Rule build() {
41+
return logicalProject()
42+
.when(topN -> ConnectContext.get() != null
43+
&& ConnectContext.get().getSessionVariable().enableCompressMaterialize)
44+
.then(this::rewrite)
45+
.toRule(RuleType.DECOUPLE_DECODE_ENCODE_SLOT);
46+
}
47+
48+
private LogicalProject<?> rewrite(LogicalProject<?> project) {
49+
List<NamedExpression> newProjections = Lists.newArrayList();
50+
boolean hasNewProjections = false;
51+
for (NamedExpression e : project.getProjects()) {
52+
boolean changed = false;
53+
if (e instanceof Alias) {
54+
Alias alias = (Alias) e;
55+
Expression body = alias.child();
56+
if (body instanceof DecodeAsVarchar && body.child(0) instanceof EncodeString) {
57+
Expression encodeBody = body.child(0).child(0);
58+
newProjections.add((NamedExpression) alias.withChildren(encodeBody));
59+
changed = true;
60+
} else if (body instanceof EncodeString && body.child(0) instanceof DecodeAsVarchar) {
61+
Expression decodeBody = body.child(0).child(0);
62+
newProjections.add((NamedExpression) alias.withChildren(decodeBody));
63+
changed = true;
64+
}
65+
}
66+
if (!changed) {
67+
newProjections.add(e);
68+
hasNewProjections = true;
69+
}
70+
}
71+
if (hasNewProjections) {
72+
project = project.withProjects(newProjections);
73+
}
74+
return project;
75+
}
76+
77+
}

fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpJoinFromUnionAll.java

+11-4
Original file line numberDiff line numberDiff line change
@@ -586,10 +586,17 @@ boolean comparePlan(Plan plan1, Plan plan2) {
586586
isEqual = false;
587587
}
588588
for (int i = 0; isEqual && i < plan2.getOutput().size(); i++) {
589-
NamedExpression expr = ((LogicalProject<?>) plan1).getProjects().get(i);
590-
NamedExpression replacedExpr = (NamedExpression)
591-
expr.rewriteUp(e -> plan1ToPlan2.getOrDefault(e, e));
592-
if (!replacedExpr.equals(((LogicalProject<?>) plan2).getProjects().get(i))) {
589+
Expression expr1 = ((LogicalProject<?>) plan1).getProjects().get(i);
590+
Expression expr2 = ((LogicalProject<?>) plan2).getProjects().get(i);
591+
if (expr1 instanceof Alias) {
592+
if (!(expr2 instanceof Alias)) {
593+
return false;
594+
}
595+
expr1 = expr1.child(0);
596+
expr2 = expr2.child(0);
597+
}
598+
Expression replacedExpr = expr1.rewriteUp(e -> plan1ToPlan2.getOrDefault(e, e));
599+
if (!replacedExpr.equals(expr2)) {
593600
isEqual = false;
594601
}
595602
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
package org.apache.doris.nereids.rules.rewrite;
19+
20+
import org.apache.doris.nereids.properties.OrderKey;
21+
import org.apache.doris.nereids.rules.Rule;
22+
import org.apache.doris.nereids.rules.RuleType;
23+
import org.apache.doris.nereids.trees.expressions.Alias;
24+
import org.apache.doris.nereids.trees.expressions.Expression;
25+
import org.apache.doris.nereids.trees.expressions.NamedExpression;
26+
import org.apache.doris.nereids.trees.expressions.Slot;
27+
import org.apache.doris.nereids.trees.expressions.SlotReference;
28+
import org.apache.doris.nereids.trees.plans.Plan;
29+
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
30+
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
31+
import org.apache.doris.qe.ConnectContext;
32+
33+
import org.apache.logging.log4j.LogManager;
34+
import org.apache.logging.log4j.Logger;
35+
36+
import java.util.ArrayList;
37+
import java.util.HashMap;
38+
import java.util.List;
39+
import java.util.Map;
40+
import java.util.Set;
41+
42+
/**
43+
*
44+
* try to reduce shuffle cost of topN operator, used to optimize plan after applying Compress_materialize
45+
*
46+
* topn(orderKey=[a])
47+
* --> project(a+1 as x, a+2 as y, a)
48+
* --> any(output(a))
49+
* =>
50+
* project(a+1 as x, a+2 as y, a)
51+
* --> topn(orderKey=[a])
52+
* --> any(output(a))
53+
*
54+
*/
55+
public class PullUpProjectBetweenTopNAndAgg extends OneRewriteRuleFactory {
56+
public static final Logger LOG = LogManager.getLogger(PullUpProjectBetweenTopNAndAgg.class);
57+
58+
@Override
59+
public Rule build() {
60+
return logicalTopN(logicalProject(logicalAggregate()))
61+
.when(topN -> ConnectContext.get() != null
62+
&& ConnectContext.get().getSessionVariable().enableCompressMaterialize)
63+
.then(topN -> adjust(topN)).toRule(RuleType.ADJUST_TOPN_PROJECT);
64+
}
65+
66+
private Plan adjust(LogicalTopN<? extends Plan> topN) {
67+
LogicalProject<Plan> project = (LogicalProject<Plan>) topN.child();
68+
Set<Slot> projectInputSlots = project.getInputSlots();
69+
Map<SlotReference, SlotReference> keyAsKey = new HashMap<>();
70+
for (NamedExpression proj : project.getProjects()) {
71+
if (proj instanceof Alias && ((Alias) proj).child(0) instanceof SlotReference) {
72+
keyAsKey.put((SlotReference) ((Alias) proj).toSlot(), (SlotReference) ((Alias) proj).child());
73+
}
74+
}
75+
boolean match = true;
76+
List<OrderKey> newOrderKeys = new ArrayList<>();
77+
for (OrderKey orderKey : topN.getOrderKeys()) {
78+
Expression orderExpr = orderKey.getExpr();
79+
if (orderExpr instanceof SlotReference) {
80+
if (projectInputSlots.contains(orderExpr)) {
81+
newOrderKeys.add(orderKey);
82+
} else if (keyAsKey.containsKey(orderExpr)) {
83+
newOrderKeys.add(orderKey.withExpression(keyAsKey.get(orderExpr)));
84+
} else {
85+
match = false;
86+
break;
87+
}
88+
} else {
89+
match = false;
90+
break;
91+
}
92+
}
93+
if (match) {
94+
if (project.getProjects().size() >= project.getInputSlots().size()) {
95+
LOG.info("$$$$ before: project.getProjects() = " + project.getProjects());
96+
LOG.info("$$$$ before: project.getInputSlots() = " + project.getInputSlots());
97+
LOG.info("$$$$ before: " + topN.treeString());
98+
topN = topN.withChildren(project.children()).withOrderKeys(newOrderKeys);
99+
project = (LogicalProject<Plan>) project.withChildren(topN);
100+
LOG.info("$$$$ after:" + project.treeString());
101+
return project;
102+
}
103+
}
104+
return topN;
105+
}
106+
}

0 commit comments

Comments
 (0)