36
36
37
37
import java .util .HashMap ;
38
38
import java .util .HashSet ;
39
+ import java .util .LinkedList ;
39
40
import java .util .List ;
40
41
import java .util .Map ;
41
42
import java .util .Set ;
42
43
import java .util .stream .Collectors ;
44
+ import java .util .stream .IntStream ;
43
45
44
46
import static com .facebook .presto .SystemSessionProperties .isMergeAggregationsWithAndWithoutFilter ;
45
47
import static com .facebook .presto .spi .StandardErrorCode .GENERIC_INTERNAL_ERROR ;
@@ -123,11 +125,13 @@ private static class Context
123
125
{
124
126
private final Map <VariableReferenceExpression , VariableReferenceExpression > partialResultToMask ;
125
127
private final Map <VariableReferenceExpression , VariableReferenceExpression > partialOutputMapping ;
128
+ private final List <VariableReferenceExpression > newAggregationOutput ;
126
129
127
130
public Context ()
128
131
{
129
132
partialResultToMask = new HashMap <>();
130
133
partialOutputMapping = new HashMap <>();
134
+ newAggregationOutput = new LinkedList <>();
131
135
}
132
136
133
137
public boolean isEmpty ()
@@ -139,6 +143,7 @@ public void clear()
139
143
{
140
144
partialResultToMask .clear ();
141
145
partialOutputMapping .clear ();
146
+ newAggregationOutput .clear ();
142
147
}
143
148
144
149
public Map <VariableReferenceExpression , VariableReferenceExpression > getPartialOutputMapping ()
@@ -150,6 +155,11 @@ public Map<VariableReferenceExpression, VariableReferenceExpression> getPartialR
150
155
{
151
156
return partialResultToMask ;
152
157
}
158
+
159
+ public List <VariableReferenceExpression > getNewAggregationOutput ()
160
+ {
161
+ return newAggregationOutput ;
162
+ }
153
163
}
154
164
155
165
private static class Rewriter
@@ -218,17 +228,32 @@ else if (node.getStep().equals(FINAL)) {
218
228
private AggregationNode createPartialAggregationNode (AggregationNode node , PlanNode rewrittenSource , RewriteContext <Context > context )
219
229
{
220
230
checkState (context .get ().isEmpty (), "There should be no partial aggregation left unmerged for a partial aggregation node" );
231
+
221
232
Map <AggregationNode .Aggregation , VariableReferenceExpression > aggregationsWithoutMaskToOutput = node .getAggregations ().entrySet ().stream ()
222
233
.filter (x -> !x .getValue ().getMask ().isPresent ())
223
- .collect (toImmutableMap (x -> x . getValue (), x -> x . getKey () , (a , b ) -> a ));
234
+ .collect (toImmutableMap (Map . Entry :: getValue , Map . Entry :: getKey , (a , b ) -> a ));
224
235
Map <AggregationNode .Aggregation , VariableReferenceExpression > aggregationsToMergeOutput = node .getAggregations ().entrySet ().stream ()
225
236
.filter (x -> x .getValue ().getMask ().isPresent () && aggregationsWithoutMaskToOutput .containsKey (removeFilterAndMask (x .getValue ())))
226
- .collect (toImmutableMap (x -> x .getValue (), x -> x .getKey ()));
227
-
228
- context .get ().getPartialResultToMask ().putAll (aggregationsToMergeOutput .entrySet ().stream ()
229
- .collect (toImmutableMap (x -> x .getValue (), x -> x .getKey ().getMask ().get ())));
230
- context .get ().getPartialOutputMapping ().putAll (aggregationsToMergeOutput .entrySet ().stream ()
231
- .collect (toImmutableMap (x -> x .getValue (), x -> aggregationsWithoutMaskToOutput .get (removeFilterAndMask (x .getKey ())))));
237
+ .collect (toImmutableMap (Map .Entry ::getValue , Map .Entry ::getKey ));
238
+
239
+ List <List <AggregationNode .Aggregation >> aggregationWithMaskNotMatched = node .getAggregations ().entrySet ().stream ().map (Map .Entry ::getValue ).filter (x -> x .getMask ().isPresent () && !aggregationsToMergeOutput .containsKey (x ))
240
+ .collect (Collectors .groupingBy (AggregationNodeUtils ::removeFilterAndMask )).values ()
241
+ .stream ().filter (x -> x .size () > 1 ).collect (Collectors .toList ());
242
+ List <AggregationNode .Aggregation > newAggregationNoMask = aggregationWithMaskNotMatched .stream ().map (x -> removeFilterAndMask (x .get (0 ))).collect (Collectors .toList ());
243
+ List <VariableReferenceExpression > variableForNewAggregationNoMask = newAggregationNoMask .stream ().map (x -> variableAllocator .newVariable (x .getCall ())).collect (Collectors .toList ());
244
+ Map <VariableReferenceExpression , AggregationNode .Aggregation > aggregationsAdded = IntStream .range (0 , newAggregationNoMask .size ()).boxed ().collect (toImmutableMap (variableForNewAggregationNoMask ::get , newAggregationNoMask ::get ));
245
+
246
+ Map <AggregationNode .Aggregation , VariableReferenceExpression > aggregationsWithoutMaskToOutputCombined = ImmutableMap .<AggregationNode .Aggregation , VariableReferenceExpression >builder ().putAll (aggregationsWithoutMaskToOutput )
247
+ .putAll (aggregationsAdded .entrySet ().stream ().collect (toImmutableMap (Map .Entry ::getValue , Map .Entry ::getKey ))).build ();
248
+ Map <AggregationNode .Aggregation , VariableReferenceExpression > aggregationsToMergeOutputCombined = node .getAggregations ().entrySet ().stream ()
249
+ .filter (x -> x .getValue ().getMask ().isPresent () && aggregationsWithoutMaskToOutputCombined .containsKey (removeFilterAndMask (x .getValue ())))
250
+ .collect (toImmutableMap (Map .Entry ::getValue , Map .Entry ::getKey ));
251
+
252
+ context .get ().getNewAggregationOutput ().addAll (variableForNewAggregationNoMask );
253
+ context .get ().getPartialResultToMask ().putAll (aggregationsToMergeOutputCombined .entrySet ().stream ()
254
+ .collect (toImmutableMap (Map .Entry ::getValue , x -> x .getKey ().getMask ().get ())));
255
+ context .get ().getPartialOutputMapping ().putAll (aggregationsToMergeOutputCombined .entrySet ().stream ()
256
+ .collect (toImmutableMap (Map .Entry ::getValue , x -> aggregationsWithoutMaskToOutputCombined .get (removeFilterAndMask (x .getKey ())))));
232
257
233
258
Set <VariableReferenceExpression > maskVariables = new HashSet <>(context .get ().getPartialResultToMask ().values ());
234
259
if (maskVariables .isEmpty ()) {
@@ -242,9 +267,11 @@ private AggregationNode createPartialAggregationNode(AggregationNode node, PlanN
242
267
AggregationNode .GroupingSetDescriptor partialGroupingSetDescriptor = new AggregationNode .GroupingSetDescriptor (
243
268
groupingVariables .build (), groupingSetDescriptor .getGroupingSetCount (), groupingSetDescriptor .getGlobalGroupingSets ());
244
269
245
- Set <VariableReferenceExpression > partialResultToMerge = new HashSet <>(aggregationsToMergeOutput .values ());
246
- Map <VariableReferenceExpression , AggregationNode .Aggregation > newAggregations = node .getAggregations ().entrySet ().stream ()
270
+ Set <VariableReferenceExpression > partialResultToMerge = new HashSet <>(aggregationsToMergeOutputCombined .values ());
271
+ Map <VariableReferenceExpression , AggregationNode .Aggregation > aggregationsRemained = node .getAggregations ().entrySet ().stream ()
247
272
.filter (x -> !partialResultToMerge .contains (x .getKey ())).collect (toImmutableMap (Map .Entry ::getKey , Map .Entry ::getValue ));
273
+ Map <VariableReferenceExpression , AggregationNode .Aggregation > newAggregations = ImmutableMap .<VariableReferenceExpression , AggregationNode .Aggregation >builder ()
274
+ .putAll (aggregationsRemained ).putAll (aggregationsAdded ).build ();
248
275
249
276
return new AggregationNode (
250
277
node .getSourceLocation (),
@@ -331,6 +358,7 @@ public PlanNode visitProject(ProjectNode node, RewriteContext<Context> context)
331
358
.collect (toImmutableMap (Map .Entry ::getKey , Map .Entry ::getValue ));
332
359
assignments .putAll (excludeMergedAssignments );
333
360
assignments .putAll (identityAssignments (context .get ().getPartialResultToMask ().values ()));
361
+ assignments .putAll (identityAssignments (context .get ().getNewAggregationOutput ()));
334
362
return new ProjectNode (
335
363
node .getSourceLocation (),
336
364
node .getId (),
0 commit comments