Skip to content

Commit

Permalink
[Enhancement] Disable extract agg columns on complex and json type columns for pruning columns (StarRocks#53991)
Browse files Browse the repository at this point in the history

Signed-off-by: stephen <[email protected]>
  • Loading branch information
stephen-shelby authored Dec 17, 2024
1 parent a85d304 commit e78d2f6
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import com.starrocks.sql.optimizer.OptExpressionVisitor;
import com.starrocks.sql.optimizer.base.ColumnRefFactory;
import com.starrocks.sql.optimizer.base.ColumnRefSet;
import com.starrocks.sql.optimizer.operator.OperatorType;
import com.starrocks.sql.optimizer.operator.Projection;
import com.starrocks.sql.optimizer.operator.physical.PhysicalHashAggregateOperator;
import com.starrocks.sql.optimizer.operator.scalar.CallOperator;
Expand Down Expand Up @@ -69,6 +68,18 @@ private boolean hasNonColumnRefParameter(PhysicalHashAggregateOperator aggregate
return false;
}

/**
 * Reports whether the given scalar expression tree references any column whose
 * type is a complex type or a JSON type.
 *
 * <p>The operator itself is inspected first: it matches only when it is a column
 * reference and its type reports {@code isComplexType()} or {@code isJsonType()}.
 * If it does not match, every child is searched recursively in the same way.
 *
 * @param operator root of the scalar expression tree to inspect
 * @return {@code true} if the root or any descendant is a column reference of a
 *         complex or JSON type; {@code false} otherwise
 */
private boolean hasComplexOrJsonTypeColumn(ScalarOperator operator) {
    // Only column references are type-checked directly; the type is not
    // consulted at all for other kinds of operators.
    if (operator.isColumnRef()) {
        boolean complexOrJson = operator.getType().isComplexType() || operator.getType().isJsonType();
        if (complexOrJson) {
            return true;
        }
    }

    // Depth-first search through the children, stopping at the first match.
    boolean foundInChildren = false;
    for (ScalarOperator childOperator : operator.getChildren()) {
        if (hasComplexOrJsonTypeColumn(childOperator)) {
            foundInChildren = true;
            break;
        }
    }
    return foundInChildren;
}

private void rewriteAggregateOperator(PhysicalHashAggregateOperator aggregateOperator, Projection projection) {
Map<ColumnRefOperator, ScalarOperator> columnRefMap = projection.getColumnRefMap();
Map<ColumnRefOperator, ScalarOperator> rewriteMap = Maps.newHashMap();
Expand All @@ -92,7 +103,7 @@ private void rewriteAggregateOperator(PhysicalHashAggregateOperator aggregateOpe
return;
}
if (!scalarOperator.isColumnRef() && !hasDictMappingOperator(scalarOperator) &&
!(scalarOperator.getOpType() == OperatorType.SUBFIELD)) {
!hasComplexOrJsonTypeColumn(scalarOperator)) {
rewriteMap.put(childRef, scalarOperator);
extractedColumns.add(childRef);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ public void testSelectMap() throws Exception {

sql = "select avg(c1[1]) from test_map where c1[1] is not null";
assertPlanContains(sql, "2:AGGREGATE (update finalize)\n" +
" | output: avg(2: c1[1])");
" | output: avg(5: expr)");

sql = "select c2[2][1] from test_map";
assertPlanContains(sql, "<slot 5> : 3: c2[2][1]");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1658,11 +1658,11 @@ public void testAggregateDuplicatedExprs() throws Exception {
"sum(arrays_overlap(v3, [1])) as q2, " +
"sum(arrays_overlap(v3, [1])) as q3 FROM tarray;");
assertContains(plan, " 2:AGGREGATE (update finalize)\n" +
" | output: sum(arrays_overlap(3: v3, CAST([1] AS ARRAY<BIGINT>)))\n" +
" | output: sum(4: arrays_overlap)\n" +
" | group by: \n" +
" | \n" +
" 1:Project\n" +
" | <slot 3> : 3: v3\n" +
" | <slot 4> : arrays_overlap(3: v3, CAST([1] AS ARRAY<BIGINT>))\n" +
" | \n" +
" 0:OlapScanNode");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -738,7 +738,12 @@ public void testCountDistinctLambdaGlobalAgg() throws Exception {
"count(distinct array_length(array_map(x -> x + 1, d_2))) from adec";
String plan = getFragmentPlan(sql);
assertCContains(plan, " 2:AGGREGATE (update finalize)\n" +
" | output: multi_distinct_count(array_length(array_map" +
"(<slot 10> -> CAST(<slot 10> AS DECIMAL64(13,3)) + 1, 5: d_2)))");
" | output: multi_distinct_count(11: array_length)\n" +
" | group by: \n" +
" | \n" +
" 1:Project\n" +
" | <slot 11> : array_length(array_map(<slot 10> -> CAST(<slot 10> AS DECIMAL64(13,3)) + 1, 5: d_2))\n" +
" | \n" +
" 0:OlapScanNode");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -728,9 +728,11 @@ public void testLambdaWithAggAndWindowFunctions() throws Exception {
String sql = "select array_agg(array_length(array_map(x->x*2, c2))) from test_array12";
String plan = getFragmentPlan(sql);
Assert.assertTrue(plan.contains(" 2:AGGREGATE (update finalize)\n" +
" | output: array_agg(array_length(array_map(<slot 4> -> CAST(<slot 4> AS BIGINT) * 2, 3: c2)))"));
Assert.assertTrue(plan.contains(" 1:Project\n" +
" | <slot 3> : 3: c2"));
" | output: array_agg(5: array_length)\n" +
" | group by: \n" +
" | \n" +
" 1:Project\n" +
" | <slot 5> : array_length(array_map(<slot 4> -> CAST(<slot 4> AS BIGINT) * 2, 3: c2))"));

sql = "select array_map(x->x > count(c1), c2) from test_array12 group by c2";
plan = getFragmentPlan(sql);
Expand Down

0 comments on commit e78d2f6

Please sign in to comment.