Skip to content

Commit

Permalink
feat: enable conversion of SQL expressions to Substrait ExtendedExpre…
Browse files Browse the repository at this point in the history
…ssions (substrait-io#191)

Introduces the SqlExpressionToSubstrait class for converting SQL expression to Substrait

---------

Co-authored-by: Dane Pitkin <[email protected]>
Co-authored-by: Victor Barua <[email protected]>
Co-authored-by: Vibhatha Lakmal Abeykoon <[email protected]>
  • Loading branch information
4 people authored Jan 18, 2024
1 parent 5180103 commit 80f648a
Show file tree
Hide file tree
Showing 5 changed files with 328 additions and 2 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
repos:
- repo: https://github.com/adrienverge/yamllint.git
rev: v1.26.0
rev: v1.33.0
hooks:
- id: yamllint
args: [-c=.yamllint.yaml]
- repo: https://github.com/alessandrojcm/commitlint-pre-commit-hook
rev: v8.0.0
rev: v9.9.0
hooks:
- id: commitlint
stages: [commit-msg]
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
package io.substrait.isthmus;

import com.github.bsideup.jabel.Desugar;
import io.substrait.extendedexpression.ExtendedExpressionProtoConverter;
import io.substrait.extendedexpression.ImmutableExpressionReference;
import io.substrait.extendedexpression.ImmutableExtendedExpression;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.expression.RexExpressionConverter;
import io.substrait.isthmus.expression.ScalarFunctionConverter;
import io.substrait.proto.ExtendedExpression;
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.calcite.jdbc.CalciteSchema;
import org.apache.calcite.prepare.CalciteCatalogReader;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql2rel.SqlToRelConverter;
import org.apache.calcite.sql2rel.StandardConvertletTable;

public class SqlExpressionToSubstrait extends SqlConverterBase {

protected final RexExpressionConverter rexConverter;

public SqlExpressionToSubstrait() {
this(FEATURES_DEFAULT, EXTENSION_COLLECTION);
}

public SqlExpressionToSubstrait(
FeatureBoard features, SimpleExtension.ExtensionCollection extensions) {
super(features);
ScalarFunctionConverter scalarFunctionConverter =
new ScalarFunctionConverter(extensions.scalarFunctions(), factory);
this.rexConverter = new RexExpressionConverter(scalarFunctionConverter);
}

@Desugar
private record Result(
SqlValidator validator,
CalciteCatalogReader catalogReader,
Map<String, RelDataType> nameToTypeMap,
Map<String, RexNode> nameToNodeMap) {}

/**
* Converts the given SQL expression string to an {@link io.substrait.proto.ExtendedExpression }
*
* @param sqlExpression a SQL expression
* @param createStatements table creation statements defining fields referenced by the expression
* @return a {@link io.substrait.proto.ExtendedExpression }
* @throws SqlParseException
*/
public ExtendedExpression convert(String sqlExpression, List<String> createStatements)
throws SqlParseException {
var result = registerCreateTablesForExtendedExpression(createStatements);
return executeInnerSQLExpression(
sqlExpression,
result.validator(),
result.catalogReader(),
result.nameToTypeMap(),
result.nameToNodeMap());
}

private ExtendedExpression executeInnerSQLExpression(
String sqlExpression,
SqlValidator validator,
CalciteCatalogReader catalogReader,
Map<String, RelDataType> nameToTypeMap,
Map<String, RexNode> nameToNodeMap)
throws SqlParseException {
RexNode rexNode =
sqlToRexNode(sqlExpression, validator, catalogReader, nameToTypeMap, nameToNodeMap);
NamedStruct namedStruct = toNamedStruct(nameToTypeMap);

ImmutableExpressionReference expressionReference =
ImmutableExpressionReference.builder()
.expression(rexNode.accept(this.rexConverter))
.addOutputNames("new-column")
.build();

List<io.substrait.extendedexpression.ExtendedExpression.ExpressionReference>
expressionReferences = new ArrayList<>();
expressionReferences.add(expressionReference);

ImmutableExtendedExpression.Builder extendedExpression =
ImmutableExtendedExpression.builder()
.referredExpressions(expressionReferences)
.baseSchema(namedStruct);

return new ExtendedExpressionProtoConverter().toProto(extendedExpression.build());
}

private RexNode sqlToRexNode(
String sql,
SqlValidator validator,
CalciteCatalogReader catalogReader,
Map<String, RelDataType> nameToTypeMap,
Map<String, RexNode> nameToNodeMap)
throws SqlParseException {
SqlParser parser = SqlParser.create(sql, parserConfig);
SqlNode sqlNode = parser.parseExpression();
SqlNode validSqlNode = validator.validateParameterizedExpression(sqlNode, nameToTypeMap);
SqlToRelConverter converter =
new SqlToRelConverter(
null,
validator,
catalogReader,
relOptCluster,
StandardConvertletTable.INSTANCE,
converterConfig);
return converter.convertExpression(validSqlNode, nameToNodeMap);
}

private Result registerCreateTablesForExtendedExpression(List<String> tables)
throws SqlParseException {
Map<String, RelDataType> nameToTypeMap = new LinkedHashMap<>();
Map<String, RexNode> nameToNodeMap = new HashMap<>();
CalciteSchema rootSchema = CalciteSchema.createRootSchema(false);
CalciteCatalogReader catalogReader =
new CalciteCatalogReader(rootSchema, List.of(), factory, config);
SqlValidator validator = Validator.create(factory, catalogReader, SqlValidator.Config.DEFAULT);
if (tables != null) {
for (String tableDef : tables) {
List<DefinedTable> tList = parseCreateTable(factory, validator, tableDef);
for (DefinedTable t : tList) {
rootSchema.add(t.getName(), t);
for (RelDataTypeField field : t.getRowType(factory).getFieldList()) {
nameToTypeMap.merge( // to validate the sql expression tree
field.getName(),
field.getType(),
(v1, v2) -> {
throw new IllegalArgumentException(
"There is no support for duplicate column names: " + field.getName());
});
nameToNodeMap.merge( // to convert sql expression into RexNode
field.getName(),
new RexInputRef(field.getIndex(), field.getType()),
(v1, v2) -> {
throw new IllegalArgumentException(
"There is no support for duplicate column names: " + field.getName());
});
}
}
}
} else {
throw new IllegalArgumentException(
"Information regarding the data and types must be passed.");
}
return new Result(validator, catalogReader, nameToTypeMap, nameToNodeMap);
}

private NamedStruct toNamedStruct(Map<String, RelDataType> nameToTypeMap) {
var names = new ArrayList<String>();
var types = new ArrayList<Type>();
for (Map.Entry<String, RelDataType> entry : nameToTypeMap.entrySet()) {
String k = entry.getKey();
RelDataType v = entry.getValue();
names.add(k);
types.add(TypeConverter.DEFAULT.toSubstrait(v));
}
return NamedStruct.of(names, Type.Struct.builder().fields(types).nullable(false).build());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package io.substrait.isthmus;

import com.google.common.base.Charsets;
import com.google.common.io.Resources;
import io.substrait.extendedexpression.ExtendedExpressionProtoConverter;
import io.substrait.extendedexpression.ProtoExtendedExpressionConverter;
import io.substrait.proto.ExtendedExpression;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import org.apache.calcite.sql.parser.SqlParseException;
import org.junit.jupiter.api.Assertions;

public class ExtendedExpressionTestBase {
public static String asString(String resource) throws IOException {
return Resources.toString(Resources.getResource(resource), Charsets.UTF_8);
}

public static List<String> tpchSchemaCreateStatements(String schemaToLoad) throws IOException {
String[] values = asString(schemaToLoad).split(";");
return Arrays.stream(values)
.filter(t -> !t.trim().isBlank())
.collect(java.util.stream.Collectors.toList());
}

public static List<String> tpchSchemaCreateStatements() throws IOException {
return tpchSchemaCreateStatements("tpch/schema.sql");
}

protected ExtendedExpression assertProtoExtendedExpressionRoundtrip(String query)
throws IOException, SqlParseException {
return assertProtoExtendedExpressionRoundtrip(query, new SqlExpressionToSubstrait());
}

protected ExtendedExpression assertProtoExtendedExpressionRoundtrip(
String query, String schemaToLoad) throws IOException, SqlParseException {
return assertProtoExtendedExpressionRoundtrip(
query, new SqlExpressionToSubstrait(), schemaToLoad);
}

protected ExtendedExpression assertProtoExtendedExpressionRoundtrip(
String query, SqlExpressionToSubstrait s) throws IOException, SqlParseException {
return assertProtoExtendedExpressionRoundtrip(query, s, tpchSchemaCreateStatements());
}

protected ExtendedExpression assertProtoExtendedExpressionRoundtrip(
String query, SqlExpressionToSubstrait s, String schemaToLoad)
throws IOException, SqlParseException {
return assertProtoExtendedExpressionRoundtrip(
query, s, tpchSchemaCreateStatements(schemaToLoad));
}

protected ExtendedExpression assertProtoExtendedExpressionRoundtrip(
String query, SqlExpressionToSubstrait s, List<String> creates)
throws SqlParseException, IOException {
// proto initial extended expression
ExtendedExpression extendedExpressionProtoInitial = s.convert(query, creates);

// pojo final extended expression
io.substrait.extendedexpression.ExtendedExpression extendedExpressionPojoFinal =
new ProtoExtendedExpressionConverter().from(extendedExpressionProtoInitial);

// proto final extended expression
ExtendedExpression extendedExpressionProtoFinal =
new ExtendedExpressionProtoConverter().toProto(extendedExpressionPojoFinal);

// round-trip to validate extended expression proto initial equals to final
Assertions.assertEquals(extendedExpressionProtoFinal, extendedExpressionProtoInitial);

return extendedExpressionProtoInitial;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package io.substrait.isthmus;

import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.io.IOException;
import java.util.stream.Stream;
import org.apache.calcite.sql.parser.SqlParseException;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

public class SimpleExtendedExpressionsTest extends ExtendedExpressionTestBase {

private static Stream<Arguments> expressionTypeProvider() {
return Stream.of(
Arguments.of("2"), // I32LiteralExpression
Arguments.of("L_ORDERKEY"), // FieldReferenceExpression
Arguments.of("L_ORDERKEY > 10"), // ScalarFunctionExpressionFilter
Arguments.of("L_ORDERKEY + 10"), // ScalarFunctionExpressionProjection
Arguments.of("L_ORDERKEY IN (10, 20)"), // ScalarFunctionExpressionIn
Arguments.of("L_ORDERKEY is not null"), // ScalarFunctionExpressionIsNotNull
Arguments.of("L_ORDERKEY is null") // ScalarFunctionExpressionIsNull
);
}

@ParameterizedTest
@MethodSource("expressionTypeProvider")
public void testExtendedExpressionsRoundTrip(String sqlExpression)
throws SqlParseException, IOException {
assertProtoExtendedExpressionRoundtrip(sqlExpression);
}

@ParameterizedTest
@MethodSource("expressionTypeProvider")
public void testExtendedExpressionsRoundTripDuplicateColumnIdentifier(String sqlExpression) {
IllegalArgumentException illegalArgumentException =
assertThrows(
IllegalArgumentException.class,
() -> assertProtoExtendedExpressionRoundtrip(sqlExpression, "tpch/schema_error.sql"));
assertTrue(
illegalArgumentException
.getMessage()
.startsWith("There is no support for duplicate column names"));
}
}
36 changes: 36 additions & 0 deletions isthmus/src/test/resources/tpch/schema_error.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
CREATE TABLE LINEITEM (
L_ORDERKEY BIGINT NOT NULL,
L_PARTKEY BIGINT NOT NULL,
L_SUPPKEY BIGINT NOT NULL,
L_LINENUMBER INTEGER,
L_QUANTITY DECIMAL,
L_EXTENDEDPRICE DECIMAL,
L_DISCOUNT DECIMAL,
L_TAX DECIMAL,
L_RETURNFLAG CHAR(1),
L_LINESTATUS CHAR(1),
L_SHIPDATE DATE,
L_COMMITDATE DATE,
L_RECEIPTDATE DATE,
L_SHIPINSTRUCT CHAR(25),
L_SHIPMODE CHAR(10),
L_COMMENT VARCHAR(44)
);
CREATE TABLE LINEITEM_DUPLICATED (
L_ORDERKEY BIGINT NOT NULL,
L_PARTKEY BIGINT NOT NULL,
L_SUPPKEY BIGINT NOT NULL,
L_LINENUMBER INTEGER,
L_QUANTITY DECIMAL,
L_EXTENDEDPRICE DECIMAL,
L_DISCOUNT DECIMAL,
L_TAX DECIMAL,
L_RETURNFLAG CHAR(1),
L_LINESTATUS CHAR(1),
L_SHIPDATE DATE,
L_COMMITDATE DATE,
L_RECEIPTDATE DATE,
L_SHIPINSTRUCT CHAR(25),
L_SHIPMODE CHAR(10),
L_COMMENT VARCHAR(44)
);

0 comments on commit 80f648a

Please sign in to comment.