forked from substrait-io/substrait-java
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: enable conversion of SQL expressions to Substrait ExtendedExpre…
…ssions (substrait-io#191) Introduces the SqlExpressionToSubstrait class for converting SQL expression to Substrait --------- Co-authored-by: Dane Pitkin <[email protected]> Co-authored-by: Victor Barua <[email protected]> Co-authored-by: Vibhatha Lakmal Abeykoon <[email protected]>
- Loading branch information
1 parent
5180103
commit 80f648a
Showing
5 changed files
with
328 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,11 @@ | ||
repos: | ||
- repo: https://github.com/adrienverge/yamllint.git | ||
rev: v1.26.0 | ||
rev: v1.33.0 | ||
hooks: | ||
- id: yamllint | ||
args: [-c=.yamllint.yaml] | ||
- repo: https://github.com/alessandrojcm/commitlint-pre-commit-hook | ||
rev: v8.0.0 | ||
rev: v9.9.0 | ||
hooks: | ||
- id: commitlint | ||
stages: [commit-msg] |
172 changes: 172 additions & 0 deletions
172
isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
package io.substrait.isthmus; | ||
|
||
import com.github.bsideup.jabel.Desugar; | ||
import io.substrait.extendedexpression.ExtendedExpressionProtoConverter; | ||
import io.substrait.extendedexpression.ImmutableExpressionReference; | ||
import io.substrait.extendedexpression.ImmutableExtendedExpression; | ||
import io.substrait.extension.SimpleExtension; | ||
import io.substrait.isthmus.expression.RexExpressionConverter; | ||
import io.substrait.isthmus.expression.ScalarFunctionConverter; | ||
import io.substrait.proto.ExtendedExpression; | ||
import io.substrait.type.NamedStruct; | ||
import io.substrait.type.Type; | ||
import java.util.ArrayList; | ||
import java.util.HashMap; | ||
import java.util.LinkedHashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
import org.apache.calcite.jdbc.CalciteSchema; | ||
import org.apache.calcite.prepare.CalciteCatalogReader; | ||
import org.apache.calcite.rel.type.RelDataType; | ||
import org.apache.calcite.rel.type.RelDataTypeField; | ||
import org.apache.calcite.rex.RexInputRef; | ||
import org.apache.calcite.rex.RexNode; | ||
import org.apache.calcite.sql.SqlNode; | ||
import org.apache.calcite.sql.parser.SqlParseException; | ||
import org.apache.calcite.sql.parser.SqlParser; | ||
import org.apache.calcite.sql.validate.SqlValidator; | ||
import org.apache.calcite.sql2rel.SqlToRelConverter; | ||
import org.apache.calcite.sql2rel.StandardConvertletTable; | ||
|
||
public class SqlExpressionToSubstrait extends SqlConverterBase { | ||
|
||
protected final RexExpressionConverter rexConverter; | ||
|
||
public SqlExpressionToSubstrait() { | ||
this(FEATURES_DEFAULT, EXTENSION_COLLECTION); | ||
} | ||
|
||
public SqlExpressionToSubstrait( | ||
FeatureBoard features, SimpleExtension.ExtensionCollection extensions) { | ||
super(features); | ||
ScalarFunctionConverter scalarFunctionConverter = | ||
new ScalarFunctionConverter(extensions.scalarFunctions(), factory); | ||
this.rexConverter = new RexExpressionConverter(scalarFunctionConverter); | ||
} | ||
|
||
@Desugar | ||
private record Result( | ||
SqlValidator validator, | ||
CalciteCatalogReader catalogReader, | ||
Map<String, RelDataType> nameToTypeMap, | ||
Map<String, RexNode> nameToNodeMap) {} | ||
|
||
/** | ||
* Converts the given SQL expression string to an {@link io.substrait.proto.ExtendedExpression } | ||
* | ||
* @param sqlExpression a SQL expression | ||
* @param createStatements table creation statements defining fields referenced by the expression | ||
* @return a {@link io.substrait.proto.ExtendedExpression } | ||
* @throws SqlParseException | ||
*/ | ||
public ExtendedExpression convert(String sqlExpression, List<String> createStatements) | ||
throws SqlParseException { | ||
var result = registerCreateTablesForExtendedExpression(createStatements); | ||
return executeInnerSQLExpression( | ||
sqlExpression, | ||
result.validator(), | ||
result.catalogReader(), | ||
result.nameToTypeMap(), | ||
result.nameToNodeMap()); | ||
} | ||
|
||
private ExtendedExpression executeInnerSQLExpression( | ||
String sqlExpression, | ||
SqlValidator validator, | ||
CalciteCatalogReader catalogReader, | ||
Map<String, RelDataType> nameToTypeMap, | ||
Map<String, RexNode> nameToNodeMap) | ||
throws SqlParseException { | ||
RexNode rexNode = | ||
sqlToRexNode(sqlExpression, validator, catalogReader, nameToTypeMap, nameToNodeMap); | ||
NamedStruct namedStruct = toNamedStruct(nameToTypeMap); | ||
|
||
ImmutableExpressionReference expressionReference = | ||
ImmutableExpressionReference.builder() | ||
.expression(rexNode.accept(this.rexConverter)) | ||
.addOutputNames("new-column") | ||
.build(); | ||
|
||
List<io.substrait.extendedexpression.ExtendedExpression.ExpressionReference> | ||
expressionReferences = new ArrayList<>(); | ||
expressionReferences.add(expressionReference); | ||
|
||
ImmutableExtendedExpression.Builder extendedExpression = | ||
ImmutableExtendedExpression.builder() | ||
.referredExpressions(expressionReferences) | ||
.baseSchema(namedStruct); | ||
|
||
return new ExtendedExpressionProtoConverter().toProto(extendedExpression.build()); | ||
} | ||
|
||
private RexNode sqlToRexNode( | ||
String sql, | ||
SqlValidator validator, | ||
CalciteCatalogReader catalogReader, | ||
Map<String, RelDataType> nameToTypeMap, | ||
Map<String, RexNode> nameToNodeMap) | ||
throws SqlParseException { | ||
SqlParser parser = SqlParser.create(sql, parserConfig); | ||
SqlNode sqlNode = parser.parseExpression(); | ||
SqlNode validSqlNode = validator.validateParameterizedExpression(sqlNode, nameToTypeMap); | ||
SqlToRelConverter converter = | ||
new SqlToRelConverter( | ||
null, | ||
validator, | ||
catalogReader, | ||
relOptCluster, | ||
StandardConvertletTable.INSTANCE, | ||
converterConfig); | ||
return converter.convertExpression(validSqlNode, nameToNodeMap); | ||
} | ||
|
||
private Result registerCreateTablesForExtendedExpression(List<String> tables) | ||
throws SqlParseException { | ||
Map<String, RelDataType> nameToTypeMap = new LinkedHashMap<>(); | ||
Map<String, RexNode> nameToNodeMap = new HashMap<>(); | ||
CalciteSchema rootSchema = CalciteSchema.createRootSchema(false); | ||
CalciteCatalogReader catalogReader = | ||
new CalciteCatalogReader(rootSchema, List.of(), factory, config); | ||
SqlValidator validator = Validator.create(factory, catalogReader, SqlValidator.Config.DEFAULT); | ||
if (tables != null) { | ||
for (String tableDef : tables) { | ||
List<DefinedTable> tList = parseCreateTable(factory, validator, tableDef); | ||
for (DefinedTable t : tList) { | ||
rootSchema.add(t.getName(), t); | ||
for (RelDataTypeField field : t.getRowType(factory).getFieldList()) { | ||
nameToTypeMap.merge( // to validate the sql expression tree | ||
field.getName(), | ||
field.getType(), | ||
(v1, v2) -> { | ||
throw new IllegalArgumentException( | ||
"There is no support for duplicate column names: " + field.getName()); | ||
}); | ||
nameToNodeMap.merge( // to convert sql expression into RexNode | ||
field.getName(), | ||
new RexInputRef(field.getIndex(), field.getType()), | ||
(v1, v2) -> { | ||
throw new IllegalArgumentException( | ||
"There is no support for duplicate column names: " + field.getName()); | ||
}); | ||
} | ||
} | ||
} | ||
} else { | ||
throw new IllegalArgumentException( | ||
"Information regarding the data and types must be passed."); | ||
} | ||
return new Result(validator, catalogReader, nameToTypeMap, nameToNodeMap); | ||
} | ||
|
||
private NamedStruct toNamedStruct(Map<String, RelDataType> nameToTypeMap) { | ||
var names = new ArrayList<String>(); | ||
var types = new ArrayList<Type>(); | ||
for (Map.Entry<String, RelDataType> entry : nameToTypeMap.entrySet()) { | ||
String k = entry.getKey(); | ||
RelDataType v = entry.getValue(); | ||
names.add(k); | ||
types.add(TypeConverter.DEFAULT.toSubstrait(v)); | ||
} | ||
return NamedStruct.of(names, Type.Struct.builder().fields(types).nullable(false).build()); | ||
} | ||
} |
72 changes: 72 additions & 0 deletions
72
isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
package io.substrait.isthmus; | ||
|
||
import com.google.common.base.Charsets; | ||
import com.google.common.io.Resources; | ||
import io.substrait.extendedexpression.ExtendedExpressionProtoConverter; | ||
import io.substrait.extendedexpression.ProtoExtendedExpressionConverter; | ||
import io.substrait.proto.ExtendedExpression; | ||
import java.io.IOException; | ||
import java.util.Arrays; | ||
import java.util.List; | ||
import org.apache.calcite.sql.parser.SqlParseException; | ||
import org.junit.jupiter.api.Assertions; | ||
|
||
public class ExtendedExpressionTestBase { | ||
public static String asString(String resource) throws IOException { | ||
return Resources.toString(Resources.getResource(resource), Charsets.UTF_8); | ||
} | ||
|
||
public static List<String> tpchSchemaCreateStatements(String schemaToLoad) throws IOException { | ||
String[] values = asString(schemaToLoad).split(";"); | ||
return Arrays.stream(values) | ||
.filter(t -> !t.trim().isBlank()) | ||
.collect(java.util.stream.Collectors.toList()); | ||
} | ||
|
||
public static List<String> tpchSchemaCreateStatements() throws IOException { | ||
return tpchSchemaCreateStatements("tpch/schema.sql"); | ||
} | ||
|
||
protected ExtendedExpression assertProtoExtendedExpressionRoundtrip(String query) | ||
throws IOException, SqlParseException { | ||
return assertProtoExtendedExpressionRoundtrip(query, new SqlExpressionToSubstrait()); | ||
} | ||
|
||
protected ExtendedExpression assertProtoExtendedExpressionRoundtrip( | ||
String query, String schemaToLoad) throws IOException, SqlParseException { | ||
return assertProtoExtendedExpressionRoundtrip( | ||
query, new SqlExpressionToSubstrait(), schemaToLoad); | ||
} | ||
|
||
protected ExtendedExpression assertProtoExtendedExpressionRoundtrip( | ||
String query, SqlExpressionToSubstrait s) throws IOException, SqlParseException { | ||
return assertProtoExtendedExpressionRoundtrip(query, s, tpchSchemaCreateStatements()); | ||
} | ||
|
||
protected ExtendedExpression assertProtoExtendedExpressionRoundtrip( | ||
String query, SqlExpressionToSubstrait s, String schemaToLoad) | ||
throws IOException, SqlParseException { | ||
return assertProtoExtendedExpressionRoundtrip( | ||
query, s, tpchSchemaCreateStatements(schemaToLoad)); | ||
} | ||
|
||
protected ExtendedExpression assertProtoExtendedExpressionRoundtrip( | ||
String query, SqlExpressionToSubstrait s, List<String> creates) | ||
throws SqlParseException, IOException { | ||
// proto initial extended expression | ||
ExtendedExpression extendedExpressionProtoInitial = s.convert(query, creates); | ||
|
||
// pojo final extended expression | ||
io.substrait.extendedexpression.ExtendedExpression extendedExpressionPojoFinal = | ||
new ProtoExtendedExpressionConverter().from(extendedExpressionProtoInitial); | ||
|
||
// proto final extended expression | ||
ExtendedExpression extendedExpressionProtoFinal = | ||
new ExtendedExpressionProtoConverter().toProto(extendedExpressionPojoFinal); | ||
|
||
// round-trip to validate extended expression proto initial equals to final | ||
Assertions.assertEquals(extendedExpressionProtoFinal, extendedExpressionProtoInitial); | ||
|
||
return extendedExpressionProtoInitial; | ||
} | ||
} |
46 changes: 46 additions & 0 deletions
46
isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
package io.substrait.isthmus; | ||
|
||
import static org.junit.jupiter.api.Assertions.assertThrows; | ||
import static org.junit.jupiter.api.Assertions.assertTrue; | ||
|
||
import java.io.IOException; | ||
import java.util.stream.Stream; | ||
import org.apache.calcite.sql.parser.SqlParseException; | ||
import org.junit.jupiter.params.ParameterizedTest; | ||
import org.junit.jupiter.params.provider.Arguments; | ||
import org.junit.jupiter.params.provider.MethodSource; | ||
|
||
public class SimpleExtendedExpressionsTest extends ExtendedExpressionTestBase { | ||
|
||
private static Stream<Arguments> expressionTypeProvider() { | ||
return Stream.of( | ||
Arguments.of("2"), // I32LiteralExpression | ||
Arguments.of("L_ORDERKEY"), // FieldReferenceExpression | ||
Arguments.of("L_ORDERKEY > 10"), // ScalarFunctionExpressionFilter | ||
Arguments.of("L_ORDERKEY + 10"), // ScalarFunctionExpressionProjection | ||
Arguments.of("L_ORDERKEY IN (10, 20)"), // ScalarFunctionExpressionIn | ||
Arguments.of("L_ORDERKEY is not null"), // ScalarFunctionExpressionIsNotNull | ||
Arguments.of("L_ORDERKEY is null") // ScalarFunctionExpressionIsNull | ||
); | ||
} | ||
|
||
@ParameterizedTest | ||
@MethodSource("expressionTypeProvider") | ||
public void testExtendedExpressionsRoundTrip(String sqlExpression) | ||
throws SqlParseException, IOException { | ||
assertProtoExtendedExpressionRoundtrip(sqlExpression); | ||
} | ||
|
||
@ParameterizedTest | ||
@MethodSource("expressionTypeProvider") | ||
public void testExtendedExpressionsRoundTripDuplicateColumnIdentifier(String sqlExpression) { | ||
IllegalArgumentException illegalArgumentException = | ||
assertThrows( | ||
IllegalArgumentException.class, | ||
() -> assertProtoExtendedExpressionRoundtrip(sqlExpression, "tpch/schema_error.sql")); | ||
assertTrue( | ||
illegalArgumentException | ||
.getMessage() | ||
.startsWith("There is no support for duplicate column names")); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
CREATE TABLE LINEITEM ( | ||
L_ORDERKEY BIGINT NOT NULL, | ||
L_PARTKEY BIGINT NOT NULL, | ||
L_SUPPKEY BIGINT NOT NULL, | ||
L_LINENUMBER INTEGER, | ||
L_QUANTITY DECIMAL, | ||
L_EXTENDEDPRICE DECIMAL, | ||
L_DISCOUNT DECIMAL, | ||
L_TAX DECIMAL, | ||
L_RETURNFLAG CHAR(1), | ||
L_LINESTATUS CHAR(1), | ||
L_SHIPDATE DATE, | ||
L_COMMITDATE DATE, | ||
L_RECEIPTDATE DATE, | ||
L_SHIPINSTRUCT CHAR(25), | ||
L_SHIPMODE CHAR(10), | ||
L_COMMENT VARCHAR(44) | ||
); | ||
CREATE TABLE LINEITEM_DUPLICATED ( | ||
L_ORDERKEY BIGINT NOT NULL, | ||
L_PARTKEY BIGINT NOT NULL, | ||
L_SUPPKEY BIGINT NOT NULL, | ||
L_LINENUMBER INTEGER, | ||
L_QUANTITY DECIMAL, | ||
L_EXTENDEDPRICE DECIMAL, | ||
L_DISCOUNT DECIMAL, | ||
L_TAX DECIMAL, | ||
L_RETURNFLAG CHAR(1), | ||
L_LINESTATUS CHAR(1), | ||
L_SHIPDATE DATE, | ||
L_COMMITDATE DATE, | ||
L_RECEIPTDATE DATE, | ||
L_SHIPINSTRUCT CHAR(25), | ||
L_SHIPMODE CHAR(10), | ||
L_COMMENT VARCHAR(44) | ||
); |