36
36
37
37
import java .util .HashMap ;
38
38
import java .util .HashSet ;
39
+ import java .util .LinkedList ;
39
40
import java .util .List ;
40
41
import java .util .Map ;
42
+ import java .util .Optional ;
41
43
import java .util .Set ;
44
+ import java .util .function .Function ;
42
45
import java .util .stream .Collectors ;
43
46
44
47
import static com .facebook .presto .SystemSessionProperties .isMergeAggregationsWithAndWithoutFilter ;
48
+ import static com .facebook .presto .expressions .LogicalRowExpressions .or ;
45
49
import static com .facebook .presto .spi .StandardErrorCode .GENERIC_INTERNAL_ERROR ;
46
50
import static com .facebook .presto .spi .plan .AggregationNode .Step .FINAL ;
47
51
import static com .facebook .presto .spi .plan .AggregationNode .Step .PARTIAL ;
@@ -123,11 +127,13 @@ private static class Context
123
127
{
124
128
private final Map <VariableReferenceExpression , VariableReferenceExpression > partialResultToMask ;
125
129
private final Map <VariableReferenceExpression , VariableReferenceExpression > partialOutputMapping ;
130
+ private final List <VariableReferenceExpression > newAggregationOutput ;
126
131
127
132
public Context ()
128
133
{
129
134
partialResultToMask = new HashMap <>();
130
135
partialOutputMapping = new HashMap <>();
136
+ newAggregationOutput = new LinkedList <>();
131
137
}
132
138
133
139
public boolean isEmpty ()
@@ -139,6 +145,7 @@ public void clear()
139
145
{
140
146
partialResultToMask .clear ();
141
147
partialOutputMapping .clear ();
148
+ newAggregationOutput .clear ();
142
149
}
143
150
144
151
public Map <VariableReferenceExpression , VariableReferenceExpression > getPartialOutputMapping ()
@@ -150,6 +157,11 @@ public Map<VariableReferenceExpression, VariableReferenceExpression> getPartialR
150
157
{
151
158
return partialResultToMask ;
152
159
}
160
+
161
+ public List <VariableReferenceExpression > getNewAggregationOutput ()
162
+ {
163
+ return newAggregationOutput ;
164
+ }
153
165
}
154
166
155
167
private static class Rewriter
@@ -218,17 +230,60 @@ else if (node.getStep().equals(FINAL)) {
218
230
private AggregationNode createPartialAggregationNode (AggregationNode node , PlanNode rewrittenSource , RewriteContext <Context > context )
219
231
{
220
232
checkState (context .get ().isEmpty (), "There should be no partial aggregation left unmerged for a partial aggregation node" );
233
+
221
234
Map <AggregationNode .Aggregation , VariableReferenceExpression > aggregationsWithoutMaskToOutput = node .getAggregations ().entrySet ().stream ()
222
235
.filter (x -> !x .getValue ().getMask ().isPresent ())
223
- .collect (toImmutableMap (x -> x . getValue (), x -> x . getKey () , (a , b ) -> a ));
236
+ .collect (toImmutableMap (Map . Entry :: getValue , Map . Entry :: getKey , (a , b ) -> a ));
224
237
Map <AggregationNode .Aggregation , VariableReferenceExpression > aggregationsToMergeOutput = node .getAggregations ().entrySet ().stream ()
225
238
.filter (x -> x .getValue ().getMask ().isPresent () && aggregationsWithoutMaskToOutput .containsKey (removeFilterAndMask (x .getValue ())))
226
- .collect (toImmutableMap (x -> x .getValue (), x -> x .getKey ()));
239
+ .collect (toImmutableMap (Map .Entry ::getValue , Map .Entry ::getKey ));
240
+
241
+ ImmutableMap .Builder <AggregationNode .Aggregation , VariableReferenceExpression > partialAggregationToOutputBuilder = ImmutableMap .builder ();
242
+ partialAggregationToOutputBuilder .putAll (aggregationsToMergeOutput .keySet ().stream ().collect (toImmutableMap (Function .identity (), x -> aggregationsWithoutMaskToOutput .get (removeFilterAndMask (x )))));
243
+
244
+ List <List <AggregationNode .Aggregation >> candidateAggregationsWithMaskNotMatched = node .getAggregations ().entrySet ().stream ().map (Map .Entry ::getValue )
245
+ .filter (x -> x .getMask ().isPresent () && !aggregationsToMergeOutput .containsKey (x ))
246
+ .collect (Collectors .groupingBy (AggregationNodeUtils ::removeFilterAndMask )).values ()
247
+ .stream ().filter (x -> x .size () > 1 ).collect (toImmutableList ());
248
+
249
+ Map <AggregationNode .Aggregation , VariableReferenceExpression > aggregationsWithMaskToMerge = node .getAggregations ().entrySet ().stream ()
250
+ .filter (x -> aggregationsToMergeOutput .containsKey (x .getValue ()) || candidateAggregationsWithMaskNotMatched .stream ().anyMatch (aggregations -> aggregations .contains (x .getValue ())))
251
+ .collect (toImmutableMap (Map .Entry ::getValue , Map .Entry ::getKey ));
252
+ ImmutableMap .Builder <VariableReferenceExpression , RowExpression > newMaskAssignmentsBuilder = ImmutableMap .builder ();
253
+ ImmutableMap .Builder <VariableReferenceExpression , AggregationNode .Aggregation > aggregationsAddedBuilder = ImmutableMap .builder ();
254
+ List <AggregationNode .Aggregation > newAggregationAdded = candidateAggregationsWithMaskNotMatched .stream ()
255
+ .map (aggregations ->
256
+ {
257
+ List <VariableReferenceExpression > maskVariables = aggregations .stream ().map (x -> x .getMask ().get ()).collect (toImmutableList ());
258
+ RowExpression orMaskVariables = or (maskVariables );
259
+ VariableReferenceExpression newMaskVariable = variableAllocator .newVariable (orMaskVariables );
260
+ newMaskAssignmentsBuilder .put (newMaskVariable , orMaskVariables );
261
+ AggregationNode .Aggregation newAggregation = new AggregationNode .Aggregation (
262
+ aggregations .get (0 ).getCall (),
263
+ Optional .empty (),
264
+ aggregations .get (0 ).getOrderBy (),
265
+ aggregations .get (0 ).isDistinct (),
266
+ Optional .of (newMaskVariable ));
267
+ VariableReferenceExpression newAggregationVariable = variableAllocator .newVariable (newAggregation .getCall ());
268
+ aggregationsAddedBuilder .put (newAggregationVariable , newAggregation );
269
+ aggregations .forEach (x -> partialAggregationToOutputBuilder .put (x , newAggregationVariable ));
270
+ return newAggregation ;
271
+ })
272
+ .collect (toImmutableList ());
273
+ Map <VariableReferenceExpression , RowExpression > newMaskAssignments = newMaskAssignmentsBuilder .build ();
274
+ Map <VariableReferenceExpression , AggregationNode .Aggregation > aggregationsAdded = aggregationsAddedBuilder .build ();
275
+ Map <AggregationNode .Aggregation , VariableReferenceExpression > partialAggregationToOutput = partialAggregationToOutputBuilder .build ();
276
+
277
+ Map <AggregationNode .Aggregation , VariableReferenceExpression > aggregationsToMergeOutputCombined =
278
+ node .getAggregations ().entrySet ().stream ()
279
+ .filter (x -> x .getValue ().getMask ().isPresent () && aggregationsToMergeOutput .containsKey (x .getValue ()) || candidateAggregationsWithMaskNotMatched .stream ().anyMatch (aggregations -> aggregations .contains (x .getValue ())))
280
+ .collect (toImmutableMap (Map .Entry ::getValue , Map .Entry ::getKey ));
227
281
228
- context .get ().getPartialResultToMask ().putAll (aggregationsToMergeOutput .entrySet ().stream ()
229
- .collect (toImmutableMap (x -> x .getValue (), x -> x .getKey ().getMask ().get ())));
230
- context .get ().getPartialOutputMapping ().putAll (aggregationsToMergeOutput .entrySet ().stream ()
231
- .collect (toImmutableMap (x -> x .getValue (), x -> aggregationsWithoutMaskToOutput .get (removeFilterAndMask (x .getKey ())))));
282
+ context .get ().getNewAggregationOutput ().addAll (aggregationsAdded .keySet ());
283
+ context .get ().getPartialResultToMask ().putAll (aggregationsWithMaskToMerge .entrySet ().stream ()
284
+ .collect (toImmutableMap (Map .Entry ::getValue , x -> x .getKey ().getMask ().get ())));
285
+ context .get ().getPartialOutputMapping ().putAll (aggregationsWithMaskToMerge .entrySet ().stream ()
286
+ .collect (toImmutableMap (Map .Entry ::getValue , x -> partialAggregationToOutput .get (x .getKey ()))));
232
287
233
288
Set <VariableReferenceExpression > maskVariables = new HashSet <>(context .get ().getPartialResultToMask ().values ());
234
289
if (maskVariables .isEmpty ()) {
@@ -242,14 +297,21 @@ private AggregationNode createPartialAggregationNode(AggregationNode node, PlanN
242
297
AggregationNode .GroupingSetDescriptor partialGroupingSetDescriptor = new AggregationNode .GroupingSetDescriptor (
243
298
groupingVariables .build (), groupingSetDescriptor .getGroupingSetCount (), groupingSetDescriptor .getGlobalGroupingSets ());
244
299
245
- Set <VariableReferenceExpression > partialResultToMerge = new HashSet <>(aggregationsToMergeOutput .values ());
246
- Map <VariableReferenceExpression , AggregationNode .Aggregation > newAggregations = node .getAggregations ().entrySet ().stream ()
300
+ Set <VariableReferenceExpression > partialResultToMerge = new HashSet <>(aggregationsToMergeOutputCombined .values ());
301
+ Map <VariableReferenceExpression , AggregationNode .Aggregation > aggregationsRemained = node .getAggregations ().entrySet ().stream ()
247
302
.filter (x -> !partialResultToMerge .contains (x .getKey ())).collect (toImmutableMap (Map .Entry ::getKey , Map .Entry ::getValue ));
303
+ Map <VariableReferenceExpression , AggregationNode .Aggregation > newAggregations = ImmutableMap .<VariableReferenceExpression , AggregationNode .Aggregation >builder ()
304
+ .putAll (aggregationsRemained ).putAll (aggregationsAdded ).build ();
305
+
306
+ PlanNode newChild = rewrittenSource ;
307
+ if (!newMaskAssignments .isEmpty ()) {
308
+ newChild = addProjections (newChild , planNodeIdAllocator , newMaskAssignments );
309
+ }
248
310
249
311
return new AggregationNode (
250
312
node .getSourceLocation (),
251
313
node .getId (),
252
- rewrittenSource ,
314
+ newChild ,
253
315
newAggregations ,
254
316
partialGroupingSetDescriptor ,
255
317
node .getPreGroupedVariables (),
@@ -265,7 +327,7 @@ private AggregationNode createFinalAggregationNode(AggregationNode node, PlanNod
265
327
return (AggregationNode ) node .replaceChildren (ImmutableList .of (rewrittenSource ));
266
328
}
267
329
List <VariableReferenceExpression > intermediateVariables = node .getAggregations ().values ().stream ()
268
- .map (x -> (VariableReferenceExpression ) x .getArguments ().get (0 )).collect (Collectors . toList ());
330
+ .map (x -> (VariableReferenceExpression ) x .getArguments ().get (0 )).collect (toImmutableList ());
269
331
checkState (intermediateVariables .containsAll (context .get ().partialResultToMask .keySet ()));
270
332
271
333
ImmutableList .Builder <RowExpression > projectionsFromPartialAgg = ImmutableList .builder ();
@@ -331,6 +393,7 @@ public PlanNode visitProject(ProjectNode node, RewriteContext<Context> context)
331
393
.collect (toImmutableMap (Map .Entry ::getKey , Map .Entry ::getValue ));
332
394
assignments .putAll (excludeMergedAssignments );
333
395
assignments .putAll (identityAssignments (context .get ().getPartialResultToMask ().values ()));
396
+ assignments .putAll (identityAssignments (context .get ().getNewAggregationOutput ()));
334
397
return new ProjectNode (
335
398
node .getSourceLocation (),
336
399
node .getId (),
0 commit comments