Skip to content

Commit e322aec

Browse files
committed
Extend filtered aggregation optimizer to support only masked partial aggregation cases
1 parent f419d2f commit e322aec

File tree

3 files changed

+161
-9
lines changed

3 files changed

+161
-9
lines changed

presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/MergePartialAggregationsWithFilter.java

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,12 @@
3636

3737
import java.util.HashMap;
3838
import java.util.HashSet;
39+
import java.util.LinkedList;
3940
import java.util.List;
4041
import java.util.Map;
4142
import java.util.Set;
4243
import java.util.stream.Collectors;
44+
import java.util.stream.IntStream;
4345

4446
import static com.facebook.presto.SystemSessionProperties.isMergeAggregationsWithAndWithoutFilter;
4547
import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
@@ -123,11 +125,13 @@ private static class Context
123125
{
124126
private final Map<VariableReferenceExpression, VariableReferenceExpression> partialResultToMask;
125127
private final Map<VariableReferenceExpression, VariableReferenceExpression> partialOutputMapping;
128+
private final List<VariableReferenceExpression> newAggregationOutput;
126129

127130
public Context()
128131
{
129132
partialResultToMask = new HashMap<>();
130133
partialOutputMapping = new HashMap<>();
134+
newAggregationOutput = new LinkedList<>();
131135
}
132136

133137
public boolean isEmpty()
@@ -139,6 +143,7 @@ public void clear()
139143
{
140144
partialResultToMask.clear();
141145
partialOutputMapping.clear();
146+
newAggregationOutput.clear();
142147
}
143148

144149
public Map<VariableReferenceExpression, VariableReferenceExpression> getPartialOutputMapping()
@@ -150,6 +155,11 @@ public Map<VariableReferenceExpression, VariableReferenceExpression> getPartialR
150155
{
151156
return partialResultToMask;
152157
}
158+
159+
public List<VariableReferenceExpression> getNewAggregationOutput()
160+
{
161+
return newAggregationOutput;
162+
}
153163
}
154164

155165
private static class Rewriter
@@ -218,17 +228,32 @@ else if (node.getStep().equals(FINAL)) {
218228
private AggregationNode createPartialAggregationNode(AggregationNode node, PlanNode rewrittenSource, RewriteContext<Context> context)
219229
{
220230
checkState(context.get().isEmpty(), "There should be no partial aggregation left unmerged for a partial aggregation node");
231+
221232
Map<AggregationNode.Aggregation, VariableReferenceExpression> aggregationsWithoutMaskToOutput = node.getAggregations().entrySet().stream()
222233
.filter(x -> !x.getValue().getMask().isPresent())
223-
.collect(toImmutableMap(x -> x.getValue(), x -> x.getKey(), (a, b) -> a));
234+
.collect(toImmutableMap(Map.Entry::getValue, Map.Entry::getKey, (a, b) -> a));
224235
Map<AggregationNode.Aggregation, VariableReferenceExpression> aggregationsToMergeOutput = node.getAggregations().entrySet().stream()
225236
.filter(x -> x.getValue().getMask().isPresent() && aggregationsWithoutMaskToOutput.containsKey(removeFilterAndMask(x.getValue())))
226-
.collect(toImmutableMap(x -> x.getValue(), x -> x.getKey()));
227-
228-
context.get().getPartialResultToMask().putAll(aggregationsToMergeOutput.entrySet().stream()
229-
.collect(toImmutableMap(x -> x.getValue(), x -> x.getKey().getMask().get())));
230-
context.get().getPartialOutputMapping().putAll(aggregationsToMergeOutput.entrySet().stream()
231-
.collect(toImmutableMap(x -> x.getValue(), x -> aggregationsWithoutMaskToOutput.get(removeFilterAndMask(x.getKey())))));
237+
.collect(toImmutableMap(Map.Entry::getValue, Map.Entry::getKey));
238+
239+
List<List<AggregationNode.Aggregation>> aggregationWithMaskNotMatched = node.getAggregations().entrySet().stream().map(Map.Entry::getValue).filter(x -> x.getMask().isPresent() && !aggregationsToMergeOutput.containsKey(x))
240+
.collect(Collectors.groupingBy(AggregationNodeUtils::removeFilterAndMask)).values()
241+
.stream().filter(x -> x.size() > 1).collect(Collectors.toList());
242+
List<AggregationNode.Aggregation> newAggregationNoMask = aggregationWithMaskNotMatched.stream().map(x -> removeFilterAndMask(x.get(0))).collect(Collectors.toList());
243+
List<VariableReferenceExpression> variableForNewAggregationNoMask = newAggregationNoMask.stream().map(x -> variableAllocator.newVariable(x.getCall())).collect(Collectors.toList());
244+
Map<VariableReferenceExpression, AggregationNode.Aggregation> aggregationsAdded = IntStream.range(0, newAggregationNoMask.size()).boxed().collect(toImmutableMap(variableForNewAggregationNoMask::get, newAggregationNoMask::get));
245+
246+
Map<AggregationNode.Aggregation, VariableReferenceExpression> aggregationsWithoutMaskToOutputCombined = ImmutableMap.<AggregationNode.Aggregation, VariableReferenceExpression>builder().putAll(aggregationsWithoutMaskToOutput)
247+
.putAll(aggregationsAdded.entrySet().stream().collect(toImmutableMap(Map.Entry::getValue, Map.Entry::getKey))).build();
248+
Map<AggregationNode.Aggregation, VariableReferenceExpression> aggregationsToMergeOutputCombined = node.getAggregations().entrySet().stream()
249+
.filter(x -> x.getValue().getMask().isPresent() && aggregationsWithoutMaskToOutputCombined.containsKey(removeFilterAndMask(x.getValue())))
250+
.collect(toImmutableMap(Map.Entry::getValue, Map.Entry::getKey));
251+
252+
context.get().getNewAggregationOutput().addAll(variableForNewAggregationNoMask);
253+
context.get().getPartialResultToMask().putAll(aggregationsToMergeOutputCombined.entrySet().stream()
254+
.collect(toImmutableMap(Map.Entry::getValue, x -> x.getKey().getMask().get())));
255+
context.get().getPartialOutputMapping().putAll(aggregationsToMergeOutputCombined.entrySet().stream()
256+
.collect(toImmutableMap(Map.Entry::getValue, x -> aggregationsWithoutMaskToOutputCombined.get(removeFilterAndMask(x.getKey())))));
232257

233258
Set<VariableReferenceExpression> maskVariables = new HashSet<>(context.get().getPartialResultToMask().values());
234259
if (maskVariables.isEmpty()) {
@@ -242,9 +267,11 @@ private AggregationNode createPartialAggregationNode(AggregationNode node, PlanN
242267
AggregationNode.GroupingSetDescriptor partialGroupingSetDescriptor = new AggregationNode.GroupingSetDescriptor(
243268
groupingVariables.build(), groupingSetDescriptor.getGroupingSetCount(), groupingSetDescriptor.getGlobalGroupingSets());
244269

245-
Set<VariableReferenceExpression> partialResultToMerge = new HashSet<>(aggregationsToMergeOutput.values());
246-
Map<VariableReferenceExpression, AggregationNode.Aggregation> newAggregations = node.getAggregations().entrySet().stream()
270+
Set<VariableReferenceExpression> partialResultToMerge = new HashSet<>(aggregationsToMergeOutputCombined.values());
271+
Map<VariableReferenceExpression, AggregationNode.Aggregation> aggregationsRemained = node.getAggregations().entrySet().stream()
247272
.filter(x -> !partialResultToMerge.contains(x.getKey())).collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
273+
Map<VariableReferenceExpression, AggregationNode.Aggregation> newAggregations = ImmutableMap.<VariableReferenceExpression, AggregationNode.Aggregation>builder()
274+
.putAll(aggregationsRemained).putAll(aggregationsAdded).build();
248275

249276
return new AggregationNode(
250277
node.getSourceLocation(),
@@ -331,6 +358,7 @@ public PlanNode visitProject(ProjectNode node, RewriteContext<Context> context)
331358
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
332359
assignments.putAll(excludeMergedAssignments);
333360
assignments.putAll(identityAssignments(context.get().getPartialResultToMask().values()));
361+
assignments.putAll(identityAssignments(context.get().getNewAggregationOutput()));
334362
return new ProjectNode(
335363
node.getSourceLocation(),
336364
node.getId(),

presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergePartialAggregationsWithFilter.java

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,36 @@ public void testOptimizationApplied()
8787
false);
8888
}
8989

90+
@Test
91+
public void testOptimizationAppliedAllHasMask()
92+
{
93+
assertPlan("SELECT partkey, sum(quantity) filter (where orderkey > 10), sum(quantity) filter (where orderkey > 0) from lineitem group by partkey",
94+
enableOptimization(),
95+
anyTree(
96+
aggregation(
97+
singleGroupingSet("partkey"),
98+
ImmutableMap.of(Optional.of("maskFinalSum"), functionCall("sum", ImmutableList.of("maskPartialSum")),
99+
Optional.of("maskFinalSum2"), functionCall("sum", ImmutableList.of("maskPartialSum2"))),
100+
ImmutableMap.of(),
101+
Optional.empty(),
102+
AggregationNode.Step.FINAL,
103+
project(
104+
ImmutableMap.of("maskPartialSum", expression("IF(expr, partialSum, null)"),
105+
"maskPartialSum2", expression("IF(expr2, partialSum, null)")),
106+
anyTree(
107+
aggregation(
108+
singleGroupingSet("partkey", "expr", "expr2"),
109+
ImmutableMap.of(Optional.of("partialSum"), functionCall("sum", ImmutableList.of("quantity"))),
110+
ImmutableMap.of(),
111+
Optional.empty(),
112+
AggregationNode.Step.PARTIAL,
113+
anyTree(
114+
project(
115+
ImmutableMap.of("expr", expression("orderkey > 0"), "expr2", expression("orderkey >10")),
116+
tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "partkey", "partkey", "quantity", "quantity"))))))))),
117+
false);
118+
}
119+
90120
@Test
91121
public void testOptimizationDisabled()
92122
{
@@ -188,6 +218,55 @@ public void testAggregationsMultipleLevel()
188218
false);
189219
}
190220

221+
@Test
222+
public void testAggregationsMultipleLevelAllAggWithMask()
223+
{
224+
assertPlan("select partkey, avg(sum) filter (where suppkey > 10), avg(sum) filter (where suppkey > 0), avg(filtersum) from (select partkey, suppkey, sum(quantity) filter (where orderkey > 10) sum, sum(quantity) filter (where orderkey > 0) filtersum from lineitem group by partkey, suppkey) t group by partkey",
225+
enableOptimization(),
226+
anyTree(
227+
aggregation(
228+
singleGroupingSet("partkey"),
229+
ImmutableMap.of(Optional.of("finalAvg"), functionCall("avg", ImmutableList.of("maskPartialAvg_g10")), Optional.of("maskFinalAvg"), functionCall("avg", ImmutableList.of("maskPartialAvg")),
230+
Optional.of("finalFilterAvg"), functionCall("avg", ImmutableList.of("partialFilterAvg"))),
231+
ImmutableMap.of(),
232+
Optional.empty(),
233+
AggregationNode.Step.FINAL,
234+
project(
235+
ImmutableMap.of("maskPartialAvg", expression("IF(expr_2, partialAvg, null)"),
236+
"maskPartialAvg_g10", expression("IF(expr_2_g10, partialAvg, null)")),
237+
anyTree(
238+
aggregation(
239+
singleGroupingSet("partkey", "expr_2", "expr_2_g10"),
240+
ImmutableMap.of(Optional.of("partialAvg"), functionCall("avg", ImmutableList.of("finalSum_g10")), Optional.of("partialFilterAvg"), functionCall("avg", ImmutableList.of("maskFinalSum"))),
241+
ImmutableMap.of(),
242+
Optional.empty(),
243+
AggregationNode.Step.PARTIAL,
244+
anyTree(
245+
project(
246+
ImmutableMap.of("expr_2", expression("suppkey > 0"), "expr_2_g10", expression("suppkey > 10")),
247+
aggregation(
248+
singleGroupingSet("partkey", "suppkey"),
249+
ImmutableMap.of(Optional.of("finalSum_g10"), functionCall("sum", ImmutableList.of("maskPartialSum_g10")), Optional.of("maskFinalSum"), functionCall("sum", ImmutableList.of("maskPartialSum"))),
250+
ImmutableMap.of(),
251+
Optional.empty(),
252+
AggregationNode.Step.FINAL,
253+
project(
254+
ImmutableMap.of("maskPartialSum", expression("IF(expr, partialSum, null)"),
255+
"maskPartialSum_g10", expression("IF(expr_g10, partialSum, null)")),
256+
anyTree(
257+
aggregation(
258+
singleGroupingSet("partkey", "suppkey", "expr", "expr_g10"),
259+
ImmutableMap.of(Optional.of("partialSum"), functionCall("sum", ImmutableList.of("quantity"))),
260+
ImmutableMap.of(),
261+
Optional.empty(),
262+
AggregationNode.Step.PARTIAL,
263+
anyTree(
264+
project(
265+
ImmutableMap.of("expr", expression("orderkey > 0"), "expr_g10", expression("orderkey > 10")),
266+
tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "partkey", "partkey", "quantity", "quantity", "suppkey", "suppkey"))))))))))))))),
267+
false);
268+
}
269+
191270
@Test
192271
public void testGlobalOptimization()
193272
{

0 commit comments

Comments
 (0)