Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: convert sql expression into proto extended expressions #191

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
333989f
feat: convert sql expression into proto extended expressions
davisusanibar Oct 24, 2023
f4b6581
fix: implement nameToNodeMap and nameToTypeMap dyamically instead of …
davisusanibar Oct 26, 2023
a79f57d
fix: cover support also for project extended expression
davisusanibar Oct 26, 2023
a37be92
fix: cover support also for project extended expression
davisusanibar Oct 27, 2023
9f6aaf3
fix: create schema dynamically
davisusanibar Nov 15, 2023
52b41e3
fix: set function reference and extensions dinamically
davisusanibar Nov 16, 2023
74d13d3
Merge branch 'main' into feature/sql_to_extended_expression
davisusanibar Nov 16, 2023
3d80d1f
fix: clean code
davisusanibar Nov 16, 2023
ae84176
Merge branch 'main' into feature/sql_to_extended_expression
davisusanibar Nov 16, 2023
5954a62
fix: clean code
davisusanibar Nov 16, 2023
fc33a32
fix: rename variables to clean code
davisusanibar Nov 17, 2023
217f2a0
fix: from/to pojo/protobuf
davisusanibar Nov 23, 2023
75e4f48
feat: enable support from/to pojo/protobuf for extended expressions
davisusanibar Nov 24, 2023
1d23187
Merge branch 'main' into feature/from_to_protobuf_pojo
davisusanibar Nov 24, 2023
5adc79f
fix: consume core module for proto/pojo conversions
davisusanibar Nov 24, 2023
940f703
fix: clean code redundant method
davisusanibar Nov 25, 2023
e281f2f
Merge branch 'main' into feature/sql_to_extended_expression
davisusanibar Nov 25, 2023
f817eb0
fix: apply suggestions from code review
davisusanibar Nov 29, 2023
b1c96bd
fix: code review core module
davisusanibar Nov 29, 2023
3d9b927
fix: code review core module testing side
davisusanibar Nov 29, 2023
e790492
feat: support aggregation function in extended expression from/to poj…
davisusanibar Dec 6, 2023
ef7c076
fix: merge from/to proto/pojo
davisusanibar Dec 6, 2023
d1b4efb
fix: merge from/to proto/pojo
davisusanibar Dec 6, 2023
c26fecd
fix: merge from/to proto/pojo + solve comments on the PR
davisusanibar Dec 6, 2023
bdde874
Merge branch 'main' into feature/sql_to_extended_expression
davisusanibar Dec 6, 2023
0fa69c8
fix: code review suggestion
davisusanibar Dec 6, 2023
92d2cc5
refactor: bind instanceof checked variables
vbarua Dec 7, 2023
379b83f
fix: adding Aggregate.Measure POJO instead of Proto
davisusanibar Dec 8, 2023
e415785
fix: simplify extended expression immutable class
davisusanibar Dec 8, 2023
c27dd37
fix: clean code
davisusanibar Dec 8, 2023
a5d8126
fix: support any kind of expression type on extended expression conve…
davisusanibar Dec 9, 2023
50602f2
fix: error scalar function test case
davisusanibar Dec 9, 2023
f57322a
Merge branch 'feature/from_to_protobuf_pojo' into feature/sql_to_exte…
davisusanibar Dec 12, 2023
1c8b8b5
fix: support any kind of expression type on extended expression conve…
davisusanibar Dec 12, 2023
c71bff1
fix: consolidate PR and resolve conflicting files
davisusanibar Dec 14, 2023
183dcb6
fix: addressing PR comments
davisusanibar Dec 19, 2023
3835a4f
docs: update SqlExpressionToSubstrait#convert docs
vbarua Dec 20, 2023
6327a17
fix: commit suggestion code
davisusanibar Jan 10, 2024
8658558
fix: addressing PR comments
davisusanibar Jan 11, 2024
ecf3133
Merge branch 'main' into feature/sql_to_extended_expression
davisusanibar Jan 11, 2024
3ff586a
fix: delete integration with arrow project
davisusanibar Jan 11, 2024
6151bca
fix: apply suggestions from code review
davisusanibar Jan 12, 2024
96a2f25
fix: addressing PR comments
davisusanibar Jan 15, 2024
579d2af
Merge branch 'main' into feature/sql_to_extended_expression
davisusanibar Jan 17, 2024
3b110f3
refactor: remove unused nation.parquet data
vbarua Jan 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions core/src/main/java/io/substrait/extension/ExtensionCollector.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.substrait.extension;

import io.substrait.proto.ExtendedExpression;
import io.substrait.proto.Plan;
import io.substrait.proto.SimpleExtensionDeclaration;
import io.substrait.proto.SimpleExtensionURI;
Expand Down Expand Up @@ -98,6 +99,54 @@ public void addExtensionsToPlan(Plan.Builder builder) {
builder.addAllExtensions(extensionList);
}

public void addExtensionsToExtendedExpression(ExtendedExpression.Builder builder) {
var uriPos = new AtomicInteger(1);
var uris = new HashMap<String, SimpleExtensionURI>();

var extensionList = new ArrayList<SimpleExtensionDeclaration>();
for (var e : funcMap.forwardMap.entrySet()) {
SimpleExtensionURI uri =
uris.computeIfAbsent(
e.getValue().namespace(),
k ->
SimpleExtensionURI.newBuilder()
.setExtensionUriAnchor(uriPos.getAndIncrement())
.setUri(k)
.build());
var decl =
SimpleExtensionDeclaration.newBuilder()
.setExtensionFunction(
SimpleExtensionDeclaration.ExtensionFunction.newBuilder()
.setFunctionAnchor(e.getKey())
.setName(e.getValue().key())
.setExtensionUriReference(uri.getExtensionUriAnchor()))
.build();
extensionList.add(decl);
}
for (var e : typeMap.forwardMap.entrySet()) {
SimpleExtensionURI uri =
uris.computeIfAbsent(
e.getValue().namespace(),
k ->
SimpleExtensionURI.newBuilder()
.setExtensionUriAnchor(uriPos.getAndIncrement())
.setUri(k)
.build());
var decl =
SimpleExtensionDeclaration.newBuilder()
.setExtensionType(
SimpleExtensionDeclaration.ExtensionType.newBuilder()
.setTypeAnchor(e.getKey())
.setName(e.getValue().key())
.setExtensionUriReference(uri.getExtensionUriAnchor()))
.build();
extensionList.add(decl);
}

builder.addAllExtensionUris(uris.values());
builder.addAllExtensions(extensionList);
}

/** We don't depend on guava... */
private static class BidiMap<T1, T2> {
private final Map<T1, T2> forwardMap;
Expand Down
196 changes: 196 additions & 0 deletions isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java
Original file line number Diff line number Diff line change
@@ -1,22 +1,44 @@
package io.substrait.isthmus;

import com.github.bsideup.jabel.Desugar;
import com.google.common.annotations.VisibleForTesting;
import io.substrait.extension.ExtensionCollector;
import io.substrait.proto.Expression;
import io.substrait.proto.Expression.ScalarFunction;
import io.substrait.proto.ExpressionReference;
import io.substrait.proto.ExtendedExpression;
vbarua marked this conversation as resolved.
Show resolved Hide resolved
import io.substrait.proto.FunctionArgument;
import io.substrait.proto.Plan;
import io.substrait.proto.PlanRel;
import io.substrait.proto.SimpleExtensionDeclaration;
import io.substrait.proto.SimpleExtensionURI;
import io.substrait.relation.RelProtoConverter;
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import io.substrait.type.proto.TypeProtoConverter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import org.apache.calcite.plan.hep.HepPlanner;
import org.apache.calcite.plan.hep.HepProgram;
import org.apache.calcite.prepare.CalciteCatalogReader;
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.schema.Schema;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql2rel.SqlToRelConverter;
import org.apache.calcite.sql2rel.StandardConvertletTable;
Expand Down Expand Up @@ -48,6 +70,12 @@ public Plan execute(String sql, String name, Schema schema) throws SqlParseExcep
return executeInner(sql, factory, pair.left, pair.right);
}

public ExtendedExpression executeExpression(String expr, List<String> tables)
throws SqlParseException {
var pair = registerCreateTables(tables);
return executeInnerExpression(expr, pair.left, pair.right);
}
vbarua marked this conversation as resolved.
Show resolved Hide resolved

// Package protected for testing
List<RelRoot> sqlToRelNode(String sql, List<String> tables) throws SqlParseException {
var pair = registerCreateTables(tables);
Expand Down Expand Up @@ -91,6 +119,138 @@ private Plan executeInner(
return plan.build();
}

private ExtendedExpression executeInnerExpression(
String sql, SqlValidator validator, CalciteCatalogReader catalogReader)
throws SqlParseException {
ExtendedExpression.Builder extendedExpressionBuilder = ExtendedExpression.newBuilder();
ExtensionCollector functionCollector = new ExtensionCollector();
sqlToRexNode(sql, validator, catalogReader)
.forEach(
rexNode -> {
// FIXME! Implement it dynamically for more expression types
ResulTraverseRowExpression result = TraverseRexNode.getRowExpression(rexNode);

// FIXME! Get output type dynamically:
// final static Map<String, Type> getTypeCreator = new HashMap<>(){{put("BOOLEAN",
// TypeCreator.of(true).BOOLEAN);}};
// getTypeCreator.get(rexNode.getType()).accept(...)
io.substrait.proto.Type output =
TypeCreator.NULLABLE.BOOLEAN.accept(new TypeProtoConverter(functionCollector));

// FIXME! setFunctionReference, addArguments(index: 0, 1)
Expression.Builder expressionBuilder =
Expression.newBuilder()
.setScalarFunction(
ScalarFunction.newBuilder()
.setFunctionReference(1)
.setOutputType(output)
.addArguments(
0,
FunctionArgument.newBuilder().setValue(result.referenceBuilder()))
.addArguments(
1,
FunctionArgument.newBuilder()
.setValue(result.expressionBuilderLiteral())));
ExpressionReference.Builder expressionReferenceBuilder =
ExpressionReference.newBuilder()
.setExpression(expressionBuilder)
.addOutputNames(result.ref().getName());

// FIXME! Get schema dynamically
// (as the same for Plan with:
// TypeConverter.DEFAULT.toNamedStruct(rexNode.getType());)
List<String> columnNames =
Arrays.asList("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT");
List<Type> dataTypes =
Arrays.asList(
TypeCreator.NULLABLE.I32,
TypeCreator.NULLABLE.STRING,
TypeCreator.NULLABLE.I32,
TypeCreator.NULLABLE.STRING);
NamedStruct namedStruct =
NamedStruct.of(
columnNames, Type.Struct.builder().fields(dataTypes).nullable(false).build());

extendedExpressionBuilder
.addReferredExpr(0, expressionReferenceBuilder)
.setBaseSchema(namedStruct.toProto(new TypeProtoConverter(functionCollector)));

// Extensions URI FIXME! Populate/create this dynamically
HashMap<String, SimpleExtensionURI> extensionUris = new HashMap<>();
extensionUris.put(
"key-001",
SimpleExtensionURI.newBuilder()
.setExtensionUriAnchor(1)
.setUri("/functions_comparison.yaml")
.build());

// Extensions FIXME! Populate/create this dynamically, maybe use rexNode.getKind()
ArrayList<SimpleExtensionDeclaration> extensions = new ArrayList<>();
SimpleExtensionDeclaration extensionFunctionLowerThan =
SimpleExtensionDeclaration.newBuilder()
.setExtensionFunction(
SimpleExtensionDeclaration.ExtensionFunction.newBuilder()
.setFunctionAnchor(1)
.setName("gt:any_any")
.setExtensionUriReference(1))
.build();
extensions.add(extensionFunctionLowerThan);

extendedExpressionBuilder.addAllExtensionUris(extensionUris.values());
extendedExpressionBuilder.addAllExtensions(extensions);
});
return extendedExpressionBuilder.build();
}

static class TraverseRexNode {
static RexInputRef ref = null;
static Expression.Builder referenceBuilder = null;
static Expression.Builder expressionBuilderLiteral = null;

static ResulTraverseRowExpression getRowExpression(RexNode rexNode) {

switch (rexNode.getClass().getSimpleName().toUpperCase()) {
case "REXCALL":
for (RexNode rexInternal : ((RexCall) rexNode).operands) {
getRowExpression(rexInternal);
}
;
break;
case "REXINPUTREF":
ref = (RexInputRef) rexNode;
referenceBuilder =
Expression.newBuilder()
.setSelection(
Expression.FieldReference.newBuilder()
.setDirectReference(
Expression.ReferenceSegment.newBuilder()
.setStructField(
Expression.ReferenceSegment.StructField.newBuilder()
.setField(ref.getIndex()))));
break;
case "REXLITERAL":
RexLiteral literal = (RexLiteral) rexNode;
expressionBuilderLiteral =
Expression.newBuilder()
.setLiteral(
Expression.Literal.newBuilder().setI32(literal.getValueAs(Integer.class)));
break;
default:
throw new AssertionError(
"Unsupported type for: " + rexNode.getClass().getSimpleName().toUpperCase());
}
ResulTraverseRowExpression result =
new ResulTraverseRowExpression(ref, referenceBuilder, expressionBuilderLiteral);
return result;
}
}

@Desugar
private record ResulTraverseRowExpression(
RexInputRef ref,
Expression.Builder referenceBuilder,
Expression.Builder expressionBuilderLiteral) {}

private List<RelRoot> sqlToRelNode(
String sql, SqlValidator validator, CalciteCatalogReader catalogReader)
throws SqlParseException {
Expand All @@ -107,6 +267,42 @@ private List<RelRoot> sqlToRelNode(
return roots;
}

private List<RexNode> sqlToRexNode(
String sql, SqlValidator validator, CalciteCatalogReader catalogReader)
throws SqlParseException {
SqlParser parser = SqlParser.create(sql, parserConfig);
SqlNode sqlNode = parser.parseExpression();
Result result = getResult(validator);
SqlNode validSQLNode =
validator.validateParameterizedExpression(
sqlNode,
result.nameToTypeMap()); // FIXME! It may be optional to include this validation
SqlToRelConverter converter = createSqlToRelConverter(validator, catalogReader);
RexNode rexNode = converter.convertExpression(validSQLNode, result.nameToNodeMap());

return Collections.singletonList(rexNode);
}

private static Result getResult(SqlValidator validator) {
// FIXME! Needs to be created dinamycally, this is for PoC purpose
HashMap<String, RexNode> nameToNodeMap = new HashMap<>();
nameToNodeMap.put(
"N_NATIONKEY",
new RexInputRef(0, validator.getTypeFactory().createSqlType(SqlTypeName.BIGINT)));
nameToNodeMap.put(
"N_REGIONKEY",
new RexInputRef(1, validator.getTypeFactory().createSqlType(SqlTypeName.BIGINT)));
final Map<String, RelDataType> nameToTypeMap = new HashMap<>();
for (Map.Entry<String, RexNode> entry : nameToNodeMap.entrySet()) {
nameToTypeMap.put(entry.getKey(), entry.getValue().getType());
}
Result result = new Result(nameToNodeMap, nameToTypeMap);
return result;
}

private @Desugar record Result(
HashMap<String, RexNode> nameToNodeMap, Map<String, RelDataType> nameToTypeMap) {}

@VisibleForTesting
SqlToRelConverter createSqlToRelConverter(
SqlValidator validator, CalciteCatalogReader catalogReader) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package io.substrait.isthmus;

import com.google.common.base.Charsets;
import com.google.common.io.Resources;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.util.JsonFormat;
import io.substrait.proto.ExtendedExpression;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import org.apache.calcite.sql.parser.SqlParseException;

public class ExtendedExpressionTestBase {
public static String asString(String resource) throws IOException {
return Resources.toString(Resources.getResource(resource), Charsets.UTF_8);
}

public static List<String> tpchSchemaCreateStatements() throws IOException {
String[] values = asString("tpch/schema.sql").split(";");
return Arrays.stream(values)
.filter(t -> !t.trim().isBlank())
.collect(java.util.stream.Collectors.toList());
}

protected ExtendedExpression assertProtoExtendedExpressionRoundrip(String query)
davisusanibar marked this conversation as resolved.
Show resolved Hide resolved
throws IOException, SqlParseException {
return assertProtoExtendedExpressionRoundrip(query, new SqlToSubstrait());
}

protected ExtendedExpression assertProtoExtendedExpressionRoundrip(String query, SqlToSubstrait s)
throws IOException, SqlParseException {
return assertProtoExtendedExpressionRoundrip(query, s, tpchSchemaCreateStatements());
}

protected ExtendedExpression assertProtoExtendedExpressionRoundrip(
davisusanibar marked this conversation as resolved.
Show resolved Hide resolved
String query, SqlToSubstrait s, List<String> creates) throws SqlParseException {
io.substrait.proto.ExtendedExpression protoExtendedExpression =
s.executeExpression(query, creates);

try {
String ee = JsonFormat.printer().print(protoExtendedExpression);
System.out.println("Proto Extended Expression: \n" + ee);
vbarua marked this conversation as resolved.
Show resolved Hide resolved

// FIXME! Implement test validation as the same as proto Plan implementation
vbarua marked this conversation as resolved.
Show resolved Hide resolved
} catch (InvalidProtocolBufferException e) {
throw new RuntimeException(e);
}

return protoExtendedExpression;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package io.substrait.isthmus;

import java.io.IOException;
import org.apache.calcite.sql.parser.SqlParseException;
import org.junit.jupiter.api.Test;

public class SimpleExtendedExpressionsTest extends ExtendedExpressionTestBase {

@Test
public void filter() throws IOException, SqlParseException {
assertProtoExtendedExpressionRoundrip("N_NATIONKEY > 18");
vbarua marked this conversation as resolved.
Show resolved Hide resolved
}
}