Skip to content

Commit d48945e

Browse files
committed
[FLINK-37780][3/N] ml builtin sql functions and validator change
1 parent 458f450 commit d48945e

File tree

8 files changed

+389
-3
lines changed

8 files changed

+389
-3
lines changed
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
package org.apache.calcite.sql;
2+
3+
import org.apache.flink.table.planner.catalog.CatalogSchemaModel;
4+
import org.apache.flink.table.planner.plan.FlinkCalciteCatalogReader;
5+
6+
import org.apache.calcite.rel.type.RelDataType;
7+
import org.apache.calcite.sql.validate.SqlValidator;
8+
import org.apache.calcite.sql.validate.SqlValidatorCatalogReader;
9+
import org.apache.calcite.sql.validate.SqlValidatorScope;
10+
import org.checkerframework.checker.nullness.qual.Nullable;
11+
12+
import static org.apache.calcite.util.Static.RESOURCE;
13+
14+
/** SqlModelCall to fetch and reference model based on identifier. */
15+
public class SqlModelCall extends SqlBasicCall {
16+
17+
private @Nullable CatalogSchemaModel model = null;
18+
19+
public SqlModelCall(SqlExplicitModelCall modelCall) {
20+
super(
21+
SqlModelOperator.create(
22+
modelCall.getOperator().getName(),
23+
modelCall.getOperator().getKind(),
24+
modelCall.getOperator().getLeftPrec(),
25+
modelCall.getOperator().getRightPrec(),
26+
(SqlIdentifier) modelCall.getOperandList().get(0)),
27+
modelCall.getOperandList(),
28+
modelCall.getParserPosition(),
29+
modelCall.getFunctionQuantifier());
30+
}
31+
32+
@Override
33+
public void validate(SqlValidator validator, SqlValidatorScope scope) {
34+
if (model != null) {
35+
return;
36+
}
37+
38+
SqlIdentifier modelIdentifier = (SqlIdentifier) getOperandList().get(0);
39+
SqlValidatorCatalogReader catalogReader = validator.getCatalogReader();
40+
assert catalogReader instanceof FlinkCalciteCatalogReader;
41+
42+
model = ((FlinkCalciteCatalogReader) catalogReader).getModel(modelIdentifier.names);
43+
if (model == null) {
44+
throw SqlUtil.newContextException(
45+
modelIdentifier.getParserPosition(),
46+
RESOURCE.objectNotFound(modelIdentifier.toString()));
47+
}
48+
}
49+
50+
public RelDataType getInputType(SqlValidator validator) {
51+
assert model != null;
52+
return model.getOutputRowType(validator.getTypeFactory());
53+
}
54+
55+
public RelDataType getOutputType(SqlValidator validator) {
56+
assert model != null;
57+
return model.getOutputRowType(validator.getTypeFactory());
58+
}
59+
60+
private static class SqlModelOperator extends SqlOperator {
61+
62+
private SqlIdentifier modelIdentifier;
63+
64+
private static SqlModelOperator create(
65+
String name,
66+
SqlKind kind,
67+
int leftPrecedence,
68+
int rightPrecedence,
69+
SqlIdentifier identifier) {
70+
return new SqlModelOperator(name, kind, leftPrecedence, rightPrecedence, identifier);
71+
}
72+
73+
private SqlModelOperator(
74+
String name,
75+
SqlKind kind,
76+
int leftPrecedence,
77+
int rightPrecedence,
78+
SqlIdentifier identifier) {
79+
super(name, kind, leftPrecedence, rightPrecedence, null, null, null);
80+
this.modelIdentifier = identifier;
81+
}
82+
83+
@Override
84+
public SqlSyntax getSyntax() {
85+
return SqlSyntax.PREFIX;
86+
}
87+
88+
@Override
89+
public RelDataType deriveType(
90+
SqlValidator validator, SqlValidatorScope scope, SqlCall call) {
91+
SqlValidatorCatalogReader catalogReader = validator.getCatalogReader();
92+
assert catalogReader instanceof FlinkCalciteCatalogReader;
93+
94+
CatalogSchemaModel model =
95+
((FlinkCalciteCatalogReader) catalogReader).getModel(modelIdentifier.names);
96+
if (model == null) {
97+
throw SqlUtil.newContextException(
98+
modelIdentifier.getParserPosition(),
99+
RESOURCE.objectNotFound(modelIdentifier.toString()));
100+
}
101+
return model.getOutputRowType(validator.getTypeFactory());
102+
}
103+
}
104+
}

flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/ProcedureNamespace.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import org.apache.flink.annotation.Internal;
2020
import org.apache.flink.table.planner.calcite.FlinkSqlCallBinding;
21+
import org.apache.flink.table.planner.functions.sql.ml.SqlMlTableFunction;
2122

2223
import org.apache.calcite.rel.type.RelDataType;
2324
import org.apache.calcite.sql.SqlCall;
@@ -61,7 +62,7 @@ public RelDataType validateImpl(RelDataType targetRowType) {
6162
final SqlOperator operator = call.getOperator();
6263
final SqlCallBinding callBinding = new FlinkSqlCallBinding(validator, scope, call);
6364
final SqlCall permutedCall = callBinding.permutedCall();
64-
if (operator instanceof SqlWindowTableFunction) {
65+
if (operator instanceof SqlWindowTableFunction || operator instanceof SqlMlTableFunction) {
6566
permutedCall.validate(validator, scope);
6667
}
6768

flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.apache.calcite.sql.validate;
1818

1919
import org.apache.flink.table.planner.calcite.FlinkSqlCallBinding;
20+
import org.apache.flink.table.planner.functions.sql.ml.SqlMlTableFunction;
2021

2122
import com.google.common.annotations.VisibleForTesting;
2223
import com.google.common.base.Preconditions;
@@ -1332,6 +1333,7 @@ private void handleOffsetFetch(@Nullable SqlNode offset, @Nullable SqlNode fetch
13321333
if (node instanceof SqlMerge) {
13331334
validatingSqlMerge = true;
13341335
}
1336+
13351337
SqlCall call = (SqlCall) node;
13361338
final SqlKind kind = call.getKind();
13371339
final List<SqlNode> operands = call.getOperandList();
@@ -2573,7 +2575,7 @@ private SqlNode registerFrom(
25732575
if (operand instanceof SqlBasicCall) {
25742576
final SqlBasicCall call1 = (SqlBasicCall) operand;
25752577
final SqlOperator op = call1.getOperator();
2576-
if (op instanceof SqlWindowTableFunction
2578+
if ((op instanceof SqlWindowTableFunction || op instanceof SqlMlTableFunction)
25772579
&& call1.operand(0).getKind() == SqlKind.SELECT) {
25782580
scopes.put(node, getSelectScope(call1.operand(0)));
25792581
return newNode;

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/calcite/FlinkCalciteSqlValidator.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.apache.flink.table.catalog.ResolvedSchema;
2828
import org.apache.flink.table.data.TimestampData;
2929
import org.apache.flink.table.planner.catalog.CatalogSchemaTable;
30+
import org.apache.flink.table.planner.functions.sql.ml.SqlMlTableFunction;
3031
import org.apache.flink.table.planner.plan.utils.FlinkRexUtil;
3132
import org.apache.flink.table.planner.utils.ShortcutUtils;
3233
import org.apache.flink.table.types.logical.DecimalType;
@@ -43,12 +44,14 @@
4344
import org.apache.calcite.sql.SqlAsOperator;
4445
import org.apache.calcite.sql.SqlBasicCall;
4546
import org.apache.calcite.sql.SqlCall;
47+
import org.apache.calcite.sql.SqlExplicitModelCall;
4648
import org.apache.calcite.sql.SqlFunction;
4749
import org.apache.calcite.sql.SqlFunctionCategory;
4850
import org.apache.calcite.sql.SqlIdentifier;
4951
import org.apache.calcite.sql.SqlJoin;
5052
import org.apache.calcite.sql.SqlKind;
5153
import org.apache.calcite.sql.SqlLiteral;
54+
import org.apache.calcite.sql.SqlModelCall;
5255
import org.apache.calcite.sql.SqlNode;
5356
import org.apache.calcite.sql.SqlNodeList;
5457
import org.apache.calcite.sql.SqlOperator;
@@ -371,7 +374,14 @@ protected void addToSelectList(
371374
final SqlBasicCall call = (SqlBasicCall) node;
372375
final SqlOperator operator = call.getOperator();
373376

374-
if (operator instanceof SqlWindowTableFunction) {
377+
if (node instanceof SqlExplicitModelCall) {
378+
// Convert it so that model can be accessed in planner. SqlExplicitModelCall
379+
// from parser can't access model.
380+
SqlExplicitModelCall modelCall = (SqlExplicitModelCall) node;
381+
return new SqlModelCall(modelCall);
382+
}
383+
384+
if (operator instanceof SqlWindowTableFunction || operator instanceof SqlMlTableFunction) {
375385
if (tableArgs.stream().allMatch(Objects::isNull)) {
376386
return rewritten;
377387
}

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.apache.flink.table.api.TableException;
2222
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
2323
import org.apache.flink.table.planner.functions.sql.internal.SqlAuxiliaryGroupAggFunction;
24+
import org.apache.flink.table.planner.functions.sql.ml.SqlMlPredictTableFunction;
2425
import org.apache.flink.table.planner.plan.type.FlinkReturnTypes;
2526
import org.apache.flink.table.planner.plan.type.NumericExceptFirstOperandChecker;
2627

@@ -1341,6 +1342,9 @@ public List<SqlGroupedWindowFunction> getAuxiliaryFunctions() {
13411342
public static final SqlFunction CUMULATE = new SqlCumulateTableFunction();
13421343
public static final SqlFunction SESSION = new SqlSessionTableFunction();
13431344

1345+
// MODEL TABLE FUNCTIONS
1346+
public static final SqlFunction ML_PREDICT = new SqlMlPredictTableFunction();
1347+
13441348
// Catalog Functions
13451349
public static final SqlFunction CURRENT_DATABASE =
13461350
BuiltInSqlFunction.newBuilder()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to you under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.flink.table.planner.functions.sql.ml;
19+
20+
import org.apache.calcite.rel.type.RelDataType;
21+
import org.apache.calcite.rel.type.RelDataTypeFactory;
22+
import org.apache.calcite.sql.SqlCallBinding;
23+
import org.apache.calcite.sql.SqlOperandCountRange;
24+
import org.apache.calcite.sql.SqlOperator;
25+
import org.apache.calcite.sql.SqlOperatorBinding;
26+
import org.apache.calcite.sql.type.SqlOperandCountRanges;
27+
import org.apache.calcite.sql.type.SqlOperandMetadata;
28+
import org.apache.calcite.sql.type.SqlTypeName;
29+
30+
import java.util.Collections;
31+
import java.util.List;
32+
33+
/**
34+
* SqlMlPredictTableFunction implements an operator for prediction.
35+
*
36+
* <p>It allows four parameters:
37+
*
38+
* <ol>
39+
* <li>a table
40+
* <li>a model name
41+
* <li>a descriptor to provide a column name from the input table
42+
* <li>an optional config map
43+
* </ol>
44+
*/
45+
public class SqlMlPredictTableFunction extends SqlMlTableFunction {
46+
47+
public SqlMlPredictTableFunction() {
48+
super("ML_PREDICT", new PredictOperandMetadata());
49+
}
50+
51+
/**
52+
* {@inheritDoc}
53+
*
54+
* <p>Overrides because the first parameter of table-value function windowing is an explicit
55+
* TABLE parameter, which is not scalar.
56+
*/
57+
@Override
58+
public boolean argumentMustBeScalar(int ordinal) {
59+
return ordinal != 0;
60+
}
61+
62+
@Override
63+
protected RelDataType inferRowType(SqlOperatorBinding opBinding) {
64+
// TODO: output type based on table schema and model output schema
65+
// model output schema to be available after integrated with SqlExplicitModelCall
66+
return opBinding.getOperandType(1);
67+
}
68+
69+
private static class PredictOperandMetadata implements SqlOperandMetadata {
70+
private final List<String> paramNames;
71+
private final int mandatoryParamCount;
72+
73+
PredictOperandMetadata() {
74+
paramNames = List.of(PARAM_DATA, PARAM_MODEL, PARAM_COLUMN, PARAM_CONFIG);
75+
// Config is optional
76+
mandatoryParamCount = 3;
77+
}
78+
79+
@Override
80+
public List<RelDataType> paramTypes(RelDataTypeFactory typeFactory) {
81+
return Collections.nCopies(
82+
paramNames.size(), typeFactory.createSqlType(SqlTypeName.ANY));
83+
}
84+
85+
@Override
86+
public List<String> paramNames() {
87+
return paramNames;
88+
}
89+
90+
@Override
91+
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
92+
// TODO: Check operand types after integrated with SqlExplicitModelCall in validator
93+
return false;
94+
}
95+
96+
@Override
97+
public SqlOperandCountRange getOperandCountRange() {
98+
return SqlOperandCountRanges.between(mandatoryParamCount, paramNames.size());
99+
}
100+
101+
@Override
102+
public String getAllowedSignatures(SqlOperator op, String opName) {
103+
return opName
104+
+ "(TABLE table_name, MODEL model_name, DESCRIPTOR(input_columns), [MAP[]]";
105+
}
106+
}
107+
}

0 commit comments

Comments
 (0)