Skip to content

Commit a776a99

Browse files
committed
fix: UDFIT
1 parent 72abcbb commit a776a99

File tree

3 files changed

+158
-61
lines changed

3 files changed

+158
-61
lines changed

core/src/main/java/cn/edu/tsinghua/iginx/engine/physical/utils/PhysicalExpressionUtils.java

-48
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
import cn.edu.tsinghua.iginx.engine.shared.function.manager.FunctionManager;
3434
import java.util.ArrayList;
3535
import java.util.List;
36-
import java.util.concurrent.atomic.AtomicBoolean;
3736
import org.apache.arrow.vector.types.pojo.Schema;
3837

3938
public class PhysicalExpressionUtils {
@@ -360,51 +359,4 @@ public static List<ScalarExpression<?>> getRowMappingFunctionArgumentExpressions
360359
}
361360
return scalarExpressions;
362361
}
363-
364-
public static boolean containSystemFunctionOnly(Expression expression) {
365-
AtomicBoolean containUdf = new AtomicBoolean(false);
366-
expression.accept(
367-
new ExpressionVisitor() {
368-
@Override
369-
public void visit(BaseExpression expression) {}
370-
371-
@Override
372-
public void visit(BinaryExpression expression) {}
373-
374-
@Override
375-
public void visit(BracketExpression expression) {}
376-
377-
@Override
378-
public void visit(ConstantExpression expression) {}
379-
380-
@Override
381-
public void visit(FromValueExpression expression) {}
382-
383-
@Override
384-
public void visit(FuncExpression expression) {
385-
if (FunctionManager.getInstance()
386-
.getFunction(expression.getFuncName())
387-
.getFunctionType()
388-
!= FunctionType.System) {
389-
containUdf.set(true);
390-
}
391-
}
392-
393-
@Override
394-
public void visit(MultipleExpression expression) {}
395-
396-
@Override
397-
public void visit(UnaryExpression expression) {}
398-
399-
@Override
400-
public void visit(CaseWhenExpression expression) {}
401-
402-
@Override
403-
public void visit(KeyExpression expression) {}
404-
405-
@Override
406-
public void visit(SequenceExpression expression) {}
407-
});
408-
return !containUdf.get();
409-
}
410362
}

optimizer/src/main/java/cn/edu/tsinghua/iginx/physical/optimizer/naive/NaivePhysicalPlanner.java

+25-13
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
import cn.edu.tsinghua.iginx.engine.physical.task.memory.row.RowToArrowUnaryMemoryPhysicalTask;
3737
import cn.edu.tsinghua.iginx.engine.physical.task.memory.row.UnaryRowMemoryPhysicalTask;
3838
import cn.edu.tsinghua.iginx.engine.physical.task.utils.PhysicalCloseable;
39-
import cn.edu.tsinghua.iginx.engine.physical.utils.PhysicalExpressionUtils;
4039
import cn.edu.tsinghua.iginx.engine.physical.utils.PhysicalJoinUtils;
4140
import cn.edu.tsinghua.iginx.engine.shared.RequestContext;
4241
import cn.edu.tsinghua.iginx.engine.shared.data.read.BatchStream;
@@ -45,13 +44,13 @@
4544
import cn.edu.tsinghua.iginx.engine.shared.expr.Expression;
4645
import cn.edu.tsinghua.iginx.engine.shared.function.FunctionCall;
4746
import cn.edu.tsinghua.iginx.engine.shared.function.FunctionParams;
48-
import cn.edu.tsinghua.iginx.engine.shared.function.FunctionType;
4947
import cn.edu.tsinghua.iginx.engine.shared.function.system.ArithmeticExpr;
5048
import cn.edu.tsinghua.iginx.engine.shared.operator.*;
5149
import cn.edu.tsinghua.iginx.engine.shared.operator.type.JoinAlgType;
5250
import cn.edu.tsinghua.iginx.engine.shared.operator.type.OperatorType;
5351
import cn.edu.tsinghua.iginx.engine.shared.source.*;
5452
import cn.edu.tsinghua.iginx.physical.optimizer.naive.initializer.*;
53+
import cn.edu.tsinghua.iginx.physical.optimizer.naive.util.UDFDetector;
5554
import java.util.ArrayList;
5655
import java.util.Collections;
5756
import java.util.List;
@@ -293,9 +292,7 @@ public PhysicalTask<BatchStream> construct(AddSchemaPrefix operator, RequestCont
293292
}
294293

295294
public PhysicalTask<?> construct(RowTransform operator, RequestContext context) {
296-
if (operator.getFunctionCallList().stream()
297-
.map(FunctionCall::getFunction)
298-
.anyMatch(f -> f.getFunctionType() != FunctionType.System)) {
295+
if (operator.getFunctionCallList().stream().anyMatch(UDFDetector::containNonSystemFunction)) {
299296
return constructRow(operator, context);
300297
}
301298

@@ -324,6 +321,10 @@ public PhysicalTask<?> construct(Select operator, RequestContext context) {
324321
return storageTask;
325322
}
326323

324+
if (UDFDetector.containNonSystemFunction(operator.getFilter())) {
325+
return constructRow(operator, context);
326+
}
327+
327328
if (sourceTask.getResultClass() == RowStream.class) {
328329
return new UnaryRowMemoryPhysicalTask(
329330
convert(sourceTask, context, RowStream.class), operator, context);
@@ -400,9 +401,7 @@ public PhysicalTask<?> construct(SetTransform operator, RequestContext context)
400401
return storageTask;
401402
}
402403

403-
if (operator.getFunctionCallList().stream()
404-
.map(FunctionCall::getFunction)
405-
.anyMatch(f -> f.getFunctionType() != FunctionType.System)) {
404+
if (operator.getFunctionCallList().stream().anyMatch(UDFDetector::containNonSystemFunction)) {
406405
return constructRow(operator, context);
407406
}
408407

@@ -421,13 +420,10 @@ public PhysicalTask<?> construct(GroupBy operator, RequestContext context) {
421420
return storageTask;
422421
}
423422

424-
if (operator.getFunctionCallList().stream()
425-
.map(FunctionCall::getFunction)
426-
.anyMatch(f -> f.getFunctionType() != FunctionType.System)) {
423+
if (operator.getFunctionCallList().stream().anyMatch(UDFDetector::containNonSystemFunction)) {
427424
return constructRow(operator, context);
428425
}
429-
if (!operator.getGroupByExpressions().stream()
430-
.allMatch(PhysicalExpressionUtils::containSystemFunctionOnly)) {
426+
if (operator.getGroupByExpressions().stream().anyMatch(UDFDetector::containNonSystemFunction)) {
431427
return constructRow(operator, context);
432428
}
433429

@@ -482,6 +478,10 @@ public PhysicalTask<?> construct(InnerJoin operator, RequestContext context) {
482478
return constructRow(operator, context);
483479
}
484480

481+
if (UDFDetector.containNonSystemFunction(operator.getFilter())) {
482+
return constructRow(operator, context);
483+
}
484+
485485
// NOTE: The order of left and right task is reversed in InnerJoin
486486
// 这里以及后面交换了左右两个表的顺序,原因是在之前基于行的实现中,右表是BuildSide,左表是ProbeSide
487487
// 现在基于列的实现中,左表是BuildSide,右表是ProbeSide
@@ -503,6 +503,10 @@ public PhysicalTask<?> construct(OuterJoin operator, RequestContext context) {
503503
return constructRow(operator, context);
504504
}
505505

506+
if (UDFDetector.containNonSystemFunction(operator.getFilter())) {
507+
return constructRow(operator, context);
508+
}
509+
506510
operator = PhysicalJoinUtils.reverse(operator);
507511

508512
PhysicalTask<BatchStream> leftTask = fetchAsync(operator.getSourceA(), context);
@@ -521,6 +525,10 @@ public PhysicalTask<?> construct(MarkJoin operator, RequestContext context) {
521525
return constructRow(operator, context);
522526
}
523527

528+
if (UDFDetector.containNonSystemFunction(operator.getFilter())) {
529+
return constructRow(operator, context);
530+
}
531+
524532
operator = PhysicalJoinUtils.reverse(operator);
525533

526534
PhysicalTask<BatchStream> leftTask = fetchAsync(operator.getSourceA(), context);
@@ -539,6 +547,10 @@ public PhysicalTask<?> construct(SingleJoin operator, RequestContext context) {
539547
return constructRow(operator, context);
540548
}
541549

550+
if (UDFDetector.containNonSystemFunction(operator.getFilter())) {
551+
return constructRow(operator, context);
552+
}
553+
542554
operator = PhysicalJoinUtils.reverse(operator);
543555

544556
PhysicalTask<BatchStream> leftTask = fetchAsync(operator.getSourceA(), context);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
/*
2+
* IGinX - the polystore system with high performance
3+
* Copyright (C) Tsinghua University
4+
5+
*
6+
* This program is free software; you can redistribute it and/or
7+
* modify it under the terms of the GNU Lesser General Public
8+
* License as published by the Free Software Foundation; either
9+
* version 3 of the License, or (at your option) any later version.
10+
*
11+
* This program is distributed in the hope that it will be useful,
12+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14+
* Lesser General Public License for more details.
15+
*
16+
* You should have received a copy of the GNU Lesser General Public License
17+
* along with this program; if not, write to the Free Software Foundation,
18+
* Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
19+
*/
20+
package cn.edu.tsinghua.iginx.physical.optimizer.naive.util;
21+
22+
import cn.edu.tsinghua.iginx.engine.shared.expr.*;
23+
import cn.edu.tsinghua.iginx.engine.shared.function.Function;
24+
import cn.edu.tsinghua.iginx.engine.shared.function.FunctionCall;
25+
import cn.edu.tsinghua.iginx.engine.shared.function.FunctionParams;
26+
import cn.edu.tsinghua.iginx.engine.shared.function.FunctionType;
27+
import cn.edu.tsinghua.iginx.engine.shared.function.manager.FunctionManager;
28+
import cn.edu.tsinghua.iginx.engine.shared.operator.filter.*;
29+
import java.util.concurrent.atomic.AtomicBoolean;
30+
31+
public class UDFDetector {
32+
33+
public static boolean containNonSystemFunction(FunctionCall functionCall) {
34+
if (!containNonSystemFunction(functionCall.getFunction())) {
35+
return true;
36+
}
37+
FunctionParams params = functionCall.getParams();
38+
for (Expression expression : params.getExpressions()) {
39+
if (containNonSystemFunction(expression)) {
40+
return true;
41+
}
42+
}
43+
return false;
44+
}
45+
46+
public static boolean containNonSystemFunction(Function function) {
47+
return function.getFunctionType() != FunctionType.System;
48+
}
49+
50+
public static boolean containNonSystemFunction(Filter filter) {
51+
boolean[] result = new boolean[1];
52+
filter.accept(
53+
new FilterVisitor() {
54+
@Override
55+
public void visit(AndFilter filter) {}
56+
57+
@Override
58+
public void visit(OrFilter filter) {}
59+
60+
@Override
61+
public void visit(NotFilter filter) {}
62+
63+
@Override
64+
public void visit(KeyFilter filter) {}
65+
66+
@Override
67+
public void visit(ValueFilter filter) {}
68+
69+
@Override
70+
public void visit(PathFilter filter) {}
71+
72+
@Override
73+
public void visit(BoolFilter filter) {}
74+
75+
@Override
76+
public void visit(ExprFilter filter) {
77+
if (containNonSystemFunction(filter.getExpressionA())
78+
|| containNonSystemFunction(filter.getExpressionB())) {
79+
result[0] = true;
80+
}
81+
}
82+
83+
@Override
84+
public void visit(InFilter filter) {}
85+
});
86+
return result[0];
87+
}
88+
89+
public static boolean containNonSystemFunction(Expression expression) {
90+
AtomicBoolean containUdf = new AtomicBoolean(false);
91+
expression.accept(
92+
new ExpressionVisitor() {
93+
@Override
94+
public void visit(BaseExpression expression) {}
95+
96+
@Override
97+
public void visit(BinaryExpression expression) {}
98+
99+
@Override
100+
public void visit(BracketExpression expression) {}
101+
102+
@Override
103+
public void visit(ConstantExpression expression) {}
104+
105+
@Override
106+
public void visit(FromValueExpression expression) {}
107+
108+
@Override
109+
public void visit(FuncExpression expression) {
110+
if (containNonSystemFunction(
111+
FunctionManager.getInstance().getFunction(expression.getFuncName()))) {
112+
containUdf.set(true);
113+
}
114+
}
115+
116+
@Override
117+
public void visit(MultipleExpression expression) {}
118+
119+
@Override
120+
public void visit(UnaryExpression expression) {}
121+
122+
@Override
123+
public void visit(CaseWhenExpression expression) {}
124+
125+
@Override
126+
public void visit(KeyExpression expression) {}
127+
128+
@Override
129+
public void visit(SequenceExpression expression) {}
130+
});
131+
return containUdf.get();
132+
}
133+
}

0 commit comments

Comments
 (0)