Skip to content

Commit 40b134f

Browse files
committed
Extend filtered aggregation optimizer to also support cases where all partial aggregations are masked
1 parent f419d2f commit 40b134f

File tree

3 files changed

+200
-10
lines changed

3 files changed

+200
-10
lines changed

presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/MergePartialAggregationsWithFilter.java

Lines changed: 73 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,16 @@
3636

3737
import java.util.HashMap;
3838
import java.util.HashSet;
39+
import java.util.LinkedList;
3940
import java.util.List;
4041
import java.util.Map;
42+
import java.util.Optional;
4143
import java.util.Set;
44+
import java.util.function.Function;
4245
import java.util.stream.Collectors;
4346

4447
import static com.facebook.presto.SystemSessionProperties.isMergeAggregationsWithAndWithoutFilter;
48+
import static com.facebook.presto.expressions.LogicalRowExpressions.or;
4549
import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
4650
import static com.facebook.presto.spi.plan.AggregationNode.Step.FINAL;
4751
import static com.facebook.presto.spi.plan.AggregationNode.Step.PARTIAL;
@@ -123,11 +127,13 @@ private static class Context
123127
{
124128
private final Map<VariableReferenceExpression, VariableReferenceExpression> partialResultToMask;
125129
private final Map<VariableReferenceExpression, VariableReferenceExpression> partialOutputMapping;
130+
private final List<VariableReferenceExpression> newAggregationOutput;
126131

127132
public Context()
128133
{
129134
partialResultToMask = new HashMap<>();
130135
partialOutputMapping = new HashMap<>();
136+
newAggregationOutput = new LinkedList<>();
131137
}
132138

133139
public boolean isEmpty()
@@ -139,6 +145,7 @@ public void clear()
139145
{
140146
partialResultToMask.clear();
141147
partialOutputMapping.clear();
148+
newAggregationOutput.clear();
142149
}
143150

144151
public Map<VariableReferenceExpression, VariableReferenceExpression> getPartialOutputMapping()
@@ -150,6 +157,11 @@ public Map<VariableReferenceExpression, VariableReferenceExpression> getPartialR
150157
{
151158
return partialResultToMask;
152159
}
160+
161+
public List<VariableReferenceExpression> getNewAggregationOutput()
162+
{
163+
return newAggregationOutput;
164+
}
153165
}
154166

155167
private static class Rewriter
@@ -218,17 +230,60 @@ else if (node.getStep().equals(FINAL)) {
218230
private AggregationNode createPartialAggregationNode(AggregationNode node, PlanNode rewrittenSource, RewriteContext<Context> context)
219231
{
220232
checkState(context.get().isEmpty(), "There should be no partial aggregation left unmerged for a partial aggregation node");
233+
221234
Map<AggregationNode.Aggregation, VariableReferenceExpression> aggregationsWithoutMaskToOutput = node.getAggregations().entrySet().stream()
222235
.filter(x -> !x.getValue().getMask().isPresent())
223-
.collect(toImmutableMap(x -> x.getValue(), x -> x.getKey(), (a, b) -> a));
236+
.collect(toImmutableMap(Map.Entry::getValue, Map.Entry::getKey, (a, b) -> a));
224237
Map<AggregationNode.Aggregation, VariableReferenceExpression> aggregationsToMergeOutput = node.getAggregations().entrySet().stream()
225238
.filter(x -> x.getValue().getMask().isPresent() && aggregationsWithoutMaskToOutput.containsKey(removeFilterAndMask(x.getValue())))
226-
.collect(toImmutableMap(x -> x.getValue(), x -> x.getKey()));
239+
.collect(toImmutableMap(Map.Entry::getValue, Map.Entry::getKey));
240+
241+
ImmutableMap.Builder<AggregationNode.Aggregation, VariableReferenceExpression> partialAggregationToOutputBuilder = ImmutableMap.builder();
242+
partialAggregationToOutputBuilder.putAll(aggregationsToMergeOutput.keySet().stream().collect(toImmutableMap(Function.identity(), x -> aggregationsWithoutMaskToOutput.get(removeFilterAndMask(x)))));
243+
244+
List<List<AggregationNode.Aggregation>> candidateAggregationsWithMaskNotMatched = node.getAggregations().entrySet().stream().map(Map.Entry::getValue)
245+
.filter(x -> x.getMask().isPresent() && !aggregationsToMergeOutput.containsKey(x))
246+
.collect(Collectors.groupingBy(AggregationNodeUtils::removeFilterAndMask)).values()
247+
.stream().filter(x -> x.size() > 1).collect(toImmutableList());
248+
249+
Map<AggregationNode.Aggregation, VariableReferenceExpression> aggregationsWithMaskToMerge = node.getAggregations().entrySet().stream()
250+
.filter(x -> aggregationsToMergeOutput.containsKey(x.getValue()) || candidateAggregationsWithMaskNotMatched.stream().anyMatch(aggregations -> aggregations.contains(x.getValue())))
251+
.collect(toImmutableMap(Map.Entry::getValue, Map.Entry::getKey));
252+
ImmutableMap.Builder<VariableReferenceExpression, RowExpression> newMaskAssignmentsBuilder = ImmutableMap.builder();
253+
ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> aggregationsAddedBuilder = ImmutableMap.builder();
254+
List<AggregationNode.Aggregation> newAggregationAdded = candidateAggregationsWithMaskNotMatched.stream()
255+
.map(aggregations ->
256+
{
257+
List<VariableReferenceExpression> maskVariables = aggregations.stream().map(x -> x.getMask().get()).collect(toImmutableList());
258+
RowExpression orMaskVariables = or(maskVariables);
259+
VariableReferenceExpression newMaskVariable = variableAllocator.newVariable(orMaskVariables);
260+
newMaskAssignmentsBuilder.put(newMaskVariable, orMaskVariables);
261+
AggregationNode.Aggregation newAggregation = new AggregationNode.Aggregation(
262+
aggregations.get(0).getCall(),
263+
Optional.empty(),
264+
aggregations.get(0).getOrderBy(),
265+
aggregations.get(0).isDistinct(),
266+
Optional.of(newMaskVariable));
267+
VariableReferenceExpression newAggregationVariable = variableAllocator.newVariable(newAggregation.getCall());
268+
aggregationsAddedBuilder.put(newAggregationVariable, newAggregation);
269+
aggregations.forEach(x -> partialAggregationToOutputBuilder.put(x, newAggregationVariable));
270+
return newAggregation;
271+
})
272+
.collect(toImmutableList());
273+
Map<VariableReferenceExpression, RowExpression> newMaskAssignments = newMaskAssignmentsBuilder.build();
274+
Map<VariableReferenceExpression, AggregationNode.Aggregation> aggregationsAdded = aggregationsAddedBuilder.build();
275+
Map<AggregationNode.Aggregation, VariableReferenceExpression> partialAggregationToOutput = partialAggregationToOutputBuilder.build();
276+
277+
Map<AggregationNode.Aggregation, VariableReferenceExpression> aggregationsToMergeOutputCombined =
278+
node.getAggregations().entrySet().stream()
279+
.filter(x -> x.getValue().getMask().isPresent() && aggregationsToMergeOutput.containsKey(x.getValue()) || candidateAggregationsWithMaskNotMatched.stream().anyMatch(aggregations -> aggregations.contains(x.getValue())))
280+
.collect(toImmutableMap(Map.Entry::getValue, Map.Entry::getKey));
227281

228-
context.get().getPartialResultToMask().putAll(aggregationsToMergeOutput.entrySet().stream()
229-
.collect(toImmutableMap(x -> x.getValue(), x -> x.getKey().getMask().get())));
230-
context.get().getPartialOutputMapping().putAll(aggregationsToMergeOutput.entrySet().stream()
231-
.collect(toImmutableMap(x -> x.getValue(), x -> aggregationsWithoutMaskToOutput.get(removeFilterAndMask(x.getKey())))));
282+
context.get().getNewAggregationOutput().addAll(aggregationsAdded.keySet());
283+
context.get().getPartialResultToMask().putAll(aggregationsWithMaskToMerge.entrySet().stream()
284+
.collect(toImmutableMap(Map.Entry::getValue, x -> x.getKey().getMask().get())));
285+
context.get().getPartialOutputMapping().putAll(aggregationsWithMaskToMerge.entrySet().stream()
286+
.collect(toImmutableMap(Map.Entry::getValue, x -> partialAggregationToOutput.get(x.getKey()))));
232287

233288
Set<VariableReferenceExpression> maskVariables = new HashSet<>(context.get().getPartialResultToMask().values());
234289
if (maskVariables.isEmpty()) {
@@ -242,14 +297,21 @@ private AggregationNode createPartialAggregationNode(AggregationNode node, PlanN
242297
AggregationNode.GroupingSetDescriptor partialGroupingSetDescriptor = new AggregationNode.GroupingSetDescriptor(
243298
groupingVariables.build(), groupingSetDescriptor.getGroupingSetCount(), groupingSetDescriptor.getGlobalGroupingSets());
244299

245-
Set<VariableReferenceExpression> partialResultToMerge = new HashSet<>(aggregationsToMergeOutput.values());
246-
Map<VariableReferenceExpression, AggregationNode.Aggregation> newAggregations = node.getAggregations().entrySet().stream()
300+
Set<VariableReferenceExpression> partialResultToMerge = new HashSet<>(aggregationsToMergeOutputCombined.values());
301+
Map<VariableReferenceExpression, AggregationNode.Aggregation> aggregationsRemained = node.getAggregations().entrySet().stream()
247302
.filter(x -> !partialResultToMerge.contains(x.getKey())).collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
303+
Map<VariableReferenceExpression, AggregationNode.Aggregation> newAggregations = ImmutableMap.<VariableReferenceExpression, AggregationNode.Aggregation>builder()
304+
.putAll(aggregationsRemained).putAll(aggregationsAdded).build();
305+
306+
PlanNode newChild = rewrittenSource;
307+
if (!newMaskAssignments.isEmpty()) {
308+
newChild = addProjections(newChild, planNodeIdAllocator, newMaskAssignments);
309+
}
248310

249311
return new AggregationNode(
250312
node.getSourceLocation(),
251313
node.getId(),
252-
rewrittenSource,
314+
newChild,
253315
newAggregations,
254316
partialGroupingSetDescriptor,
255317
node.getPreGroupedVariables(),
@@ -265,7 +327,7 @@ private AggregationNode createFinalAggregationNode(AggregationNode node, PlanNod
265327
return (AggregationNode) node.replaceChildren(ImmutableList.of(rewrittenSource));
266328
}
267329
List<VariableReferenceExpression> intermediateVariables = node.getAggregations().values().stream()
268-
.map(x -> (VariableReferenceExpression) x.getArguments().get(0)).collect(Collectors.toList());
330+
.map(x -> (VariableReferenceExpression) x.getArguments().get(0)).collect(toImmutableList());
269331
checkState(intermediateVariables.containsAll(context.get().partialResultToMask.keySet()));
270332

271333
ImmutableList.Builder<RowExpression> projectionsFromPartialAgg = ImmutableList.builder();
@@ -331,6 +393,7 @@ public PlanNode visitProject(ProjectNode node, RewriteContext<Context> context)
331393
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
332394
assignments.putAll(excludeMergedAssignments);
333395
assignments.putAll(identityAssignments(context.get().getPartialResultToMask().values()));
396+
assignments.putAll(identityAssignments(context.get().getNewAggregationOutput()));
334397
return new ProjectNode(
335398
node.getSourceLocation(),
336399
node.getId(),

presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergePartialAggregationsWithFilter.java

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,37 @@ public void testOptimizationApplied()
8787
false);
8888
}
8989

90+
@Test
91+
public void testOptimizationAppliedAllHasMask()
92+
{
93+
assertPlan("SELECT partkey, sum(quantity) filter (where orderkey > 10), sum(quantity) filter (where orderkey > 0) from lineitem group by partkey",
94+
enableOptimization(),
95+
anyTree(
96+
aggregation(
97+
singleGroupingSet("partkey"),
98+
ImmutableMap.of(Optional.of("maskFinalSum"), functionCall("sum", ImmutableList.of("maskPartialSum")),
99+
Optional.of("maskFinalSum2"), functionCall("sum", ImmutableList.of("maskPartialSum2"))),
100+
ImmutableMap.of(),
101+
Optional.empty(),
102+
AggregationNode.Step.FINAL,
103+
project(
104+
ImmutableMap.of("maskPartialSum", expression("IF(expr, partialSum, null)"),
105+
"maskPartialSum2", expression("IF(expr2, partialSum, null)")),
106+
anyTree(
107+
aggregation(
108+
singleGroupingSet("partkey", "expr", "expr2"),
109+
ImmutableMap.of(Optional.of("partialSum"), functionCall("sum", ImmutableList.of("quantity"))),
110+
ImmutableMap.of(new Symbol("partialSum"), new Symbol("expr_or")),
111+
Optional.empty(),
112+
AggregationNode.Step.PARTIAL,
113+
project(
114+
ImmutableMap.of("expr_or", expression("expr or expr2")),
115+
project(
116+
ImmutableMap.of("expr", expression("orderkey > 0"), "expr2", expression("orderkey >10")),
117+
tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "partkey", "partkey", "quantity", "quantity"))))))))),
118+
false);
119+
}
120+
90121
@Test
91122
public void testOptimizationDisabled()
92123
{
@@ -188,6 +219,57 @@ public void testAggregationsMultipleLevel()
188219
false);
189220
}
190221

222+
@Test
223+
public void testAggregationsMultipleLevelAllAggWithMask()
224+
{
225+
assertPlan("select partkey, avg(sum) filter (where suppkey > 10), avg(sum) filter (where suppkey > 0), avg(filtersum) from (select partkey, suppkey, sum(quantity) filter (where orderkey > 10) sum, sum(quantity) filter (where orderkey > 0) filtersum from lineitem group by partkey, suppkey) t group by partkey",
226+
enableOptimization(),
227+
anyTree(
228+
aggregation(
229+
singleGroupingSet("partkey"),
230+
ImmutableMap.of(Optional.of("finalAvg"), functionCall("avg", ImmutableList.of("maskPartialAvg_g10")), Optional.of("maskFinalAvg"), functionCall("avg", ImmutableList.of("maskPartialAvg")),
231+
Optional.of("finalFilterAvg"), functionCall("avg", ImmutableList.of("partialFilterAvg"))),
232+
ImmutableMap.of(),
233+
Optional.empty(),
234+
AggregationNode.Step.FINAL,
235+
project(
236+
ImmutableMap.of("maskPartialAvg", expression("IF(expr_2, partialAvg, null)"),
237+
"maskPartialAvg_g10", expression("IF(expr_2_g10, partialAvg, null)")),
238+
anyTree(
239+
aggregation(
240+
singleGroupingSet("partkey", "expr_2", "expr_2_g10"),
241+
ImmutableMap.of(Optional.of("partialAvg"), functionCall("avg", ImmutableList.of("finalSum_g10")), Optional.of("partialFilterAvg"), functionCall("avg", ImmutableList.of("maskFinalSum"))),
242+
ImmutableMap.of(new Symbol("partialAvg"), new Symbol("expr_2_or")),
243+
Optional.empty(),
244+
AggregationNode.Step.PARTIAL,
245+
project(
246+
ImmutableMap.of("expr_2_or", expression("expr_2 or expr_2_g10")),
247+
project(
248+
ImmutableMap.of("expr_2", expression("suppkey > 0"), "expr_2_g10", expression("suppkey > 10")),
249+
aggregation(
250+
singleGroupingSet("partkey", "suppkey"),
251+
ImmutableMap.of(Optional.of("finalSum_g10"), functionCall("sum", ImmutableList.of("maskPartialSum_g10")), Optional.of("maskFinalSum"), functionCall("sum", ImmutableList.of("maskPartialSum"))),
252+
ImmutableMap.of(),
253+
Optional.empty(),
254+
AggregationNode.Step.FINAL,
255+
project(
256+
ImmutableMap.of("maskPartialSum", expression("IF(expr, partialSum, null)"),
257+
"maskPartialSum_g10", expression("IF(expr_g10, partialSum, null)")),
258+
anyTree(
259+
aggregation(
260+
singleGroupingSet("partkey", "suppkey", "expr", "expr_g10"),
261+
ImmutableMap.of(Optional.of("partialSum"), functionCall("sum", ImmutableList.of("quantity"))),
262+
ImmutableMap.of(new Symbol("partialSum"), new Symbol("expr_or")),
263+
Optional.empty(),
264+
AggregationNode.Step.PARTIAL,
265+
project(
266+
ImmutableMap.of("expr_or", expression("expr or expr_g10")),
267+
project(
268+
ImmutableMap.of("expr", expression("orderkey > 0"), "expr_g10", expression("orderkey > 10")),
269+
tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "partkey", "partkey", "quantity", "quantity", "suppkey", "suppkey"))))))))))))))),
270+
false);
271+
}
272+
191273
@Test
192274
public void testGlobalOptimization()
193275
{

0 commit comments

Comments
 (0)