Skip to content

Commit 42fcaec

Browse files
committed
[FLINK-37780][2/N] ml builtin sql functions
1 parent ef934ca commit 42fcaec

File tree

2 files changed

+153
-0
lines changed

2 files changed

+153
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to you under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.flink.table.planner.functions.sql.ml;
19+
20+
import org.apache.calcite.rel.type.RelDataType;
21+
import org.apache.calcite.rel.type.RelDataTypeFactory;
22+
import org.apache.calcite.sql.SqlCallBinding;
23+
import org.apache.calcite.sql.SqlOperandCountRange;
24+
import org.apache.calcite.sql.SqlOperator;
25+
import org.apache.calcite.sql.SqlOperatorBinding;
26+
import org.apache.calcite.sql.type.SqlOperandCountRanges;
27+
import org.apache.calcite.sql.type.SqlOperandMetadata;
28+
import org.apache.calcite.sql.type.SqlTypeName;
29+
30+
import java.util.Collections;
31+
import java.util.List;
32+
33+
/**
34+
* SqlMlPredictTableFunction implements an operator for prediction.
35+
*
36+
* <p>It allows four parameters:
37+
*
38+
* <ol>
39+
* <li>a table
40+
* <li>a model name
41+
* <li>a descriptor to provide a column name from the input table
42+
* <li>an optional config map
43+
* </ol>
44+
*/
45+
public class SqlMlPredictTableFunction extends SqlMlTableFunction {
46+
47+
public SqlMlPredictTableFunction(String name) {
48+
super(name, new PredictOperandMetadata());
49+
}
50+
51+
@Override
52+
protected RelDataType inferRowType(SqlOperatorBinding opBinding) {
53+
// TODO: output type based on table schema and model output schema
54+
// model output schema to be available after integrated with SqlExplicitModelCall
55+
return null;
56+
}
57+
58+
private static class PredictOperandMetadata implements SqlOperandMetadata {
59+
private final List<String> paramNames;
60+
private final int mandatoryParamCount;
61+
62+
PredictOperandMetadata() {
63+
paramNames = List.of(PARAM_DATA, PARAM_MODEL, PARAM_COLUMN, PARAM_CONFIG);
64+
// Config is optional
65+
mandatoryParamCount = 3;
66+
}
67+
68+
@Override
69+
public List<RelDataType> paramTypes(RelDataTypeFactory typeFactory) {
70+
return Collections.nCopies(
71+
paramNames.size(), typeFactory.createSqlType(SqlTypeName.ANY));
72+
}
73+
74+
@Override
75+
public List<String> paramNames() {
76+
return paramNames;
77+
}
78+
79+
@Override
80+
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
81+
// TODO: Check operand types after integrated with SqlExplicitModelCall in validator
82+
return false;
83+
}
84+
85+
@Override
86+
public SqlOperandCountRange getOperandCountRange() {
87+
return SqlOperandCountRanges.between(mandatoryParamCount, paramNames.size());
88+
}
89+
90+
@Override
91+
public String getAllowedSignatures(SqlOperator op, String opName) {
92+
return opName
93+
+ "(TABLE table_name, MODEL model_name, DESCRIPTOR(input_columns), [MAP[]]";
94+
}
95+
}
96+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to you under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.flink.table.planner.functions.sql.ml;
19+
20+
import org.apache.calcite.rel.type.RelDataType;
21+
import org.apache.calcite.sql.SqlFunction;
22+
import org.apache.calcite.sql.SqlFunctionCategory;
23+
import org.apache.calcite.sql.SqlKind;
24+
import org.apache.calcite.sql.SqlOperatorBinding;
25+
import org.apache.calcite.sql.SqlTableFunction;
26+
import org.apache.calcite.sql.type.ReturnTypes;
27+
import org.apache.calcite.sql.type.SqlOperandMetadata;
28+
import org.apache.calcite.sql.type.SqlReturnTypeInference;
29+
30+
/**
31+
* Base class for a table-valued function that works with models. Examples include {@code
32+
* ML_PREDICT}.
33+
*/
34+
public abstract class SqlMlTableFunction extends SqlFunction implements SqlTableFunction {
35+
36+
protected static final String PARAM_DATA = "data";
37+
protected static final String PARAM_MODEL = "model";
38+
protected static final String PARAM_COLUMN = "input_column";
39+
protected static final String PARAM_CONFIG = "config";
40+
41+
public SqlMlTableFunction(String name, SqlOperandMetadata operandMetadata) {
42+
super(
43+
name,
44+
SqlKind.OTHER_FUNCTION,
45+
ReturnTypes.CURSOR,
46+
null,
47+
operandMetadata,
48+
SqlFunctionCategory.SYSTEM);
49+
}
50+
51+
@Override
52+
public SqlReturnTypeInference getRowTypeInference() {
53+
return this::inferRowType;
54+
}
55+
56+
protected abstract RelDataType inferRowType(SqlOperatorBinding opBinding);
57+
}

0 commit comments

Comments
 (0)