Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: map switch expression to a Calcite CASE statement #189

Merged
merged 5 commits into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 42 additions & 8 deletions core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@
import io.substrait.expression.AggregateFunctionInvocation;
import io.substrait.expression.Expression;
import io.substrait.expression.Expression.FailureBehavior;
import io.substrait.expression.Expression.IfClause;
import io.substrait.expression.Expression.IfThen;
import io.substrait.expression.Expression.SwitchClause;
import io.substrait.expression.FieldReference;
import io.substrait.expression.ImmutableExpression.Cast;
import io.substrait.expression.ImmutableExpression.SingleOrList;
import io.substrait.expression.ImmutableExpression.Switch;
import io.substrait.expression.ImmutableFieldReference;
import io.substrait.extension.SimpleExtension;
import io.substrait.function.ToTypeString;
Expand Down Expand Up @@ -42,6 +46,7 @@ public class SubstraitBuilder {

private static final String FUNCTIONS_AGGREGATE_GENERIC = "/functions_aggregate_generic.yaml";
private static final String FUNCTIONS_ARITHMETIC = "/functions_arithmetic.yaml";
private static final String FUNCTIONS_BOOLEAN = "/functions_boolean.yaml";
private static final String FUNCTIONS_COMPARISON = "/functions_comparison.yaml";

private final SimpleExtension.ExtensionCollection extensions;
Expand Down Expand Up @@ -301,6 +306,14 @@ public Expression.I32Literal i32(int v) {
return Expression.I32Literal.builder().value(v).build();
}

/**
 * Builds a cast of {@code input} to {@code type}.
 *
 * <p>The failure behavior of the cast is always set to {@link FailureBehavior#UNSPECIFIED}.
 *
 * @param input the expression whose value is cast
 * @param type the type to cast to
 * @return the resulting cast expression
 */
public Expression cast(Expression input, Type type) {
  Expression castExpression =
      Cast.builder()
          .input(input)
          .type(type)
          .failureBehavior(FailureBehavior.UNSPECIFIED)
          .build();
  return castExpression;
}

/**
 * Creates a reference into {@code input} for the given position by delegating to {@link
 * ImmutableFieldReference#newInputRelReference}.
 *
 * @param input the relation being referenced
 * @param index the position passed through to the reference factory
 * @return the field reference produced by the factory
 */
public FieldReference fieldReference(Rel input, int index) {
  FieldReference reference = ImmutableFieldReference.newInputRelReference(index, input);
  return reference;
}
Expand All @@ -321,12 +334,16 @@ public List<FieldReference> fieldReferences(List<Rel> inputs, int... indexes) {
.collect(java.util.stream.Collectors.toList());
}

public Expression cast(Expression input, Type type) {
return Cast.builder()
.input(input)
.type(type)
.failureBehavior(FailureBehavior.UNSPECIFIED)
.build();
/**
 * Builds an {@link IfThen} expression from the supplied clauses and else expression.
 *
 * @param ifClauses the if clauses, added in iteration order
 * @param elseClause the expression used as the else clause
 * @return the assembled {@link IfThen}
 */
public IfThen ifThen(Iterable<? extends IfClause> ifClauses, Expression elseClause) {
  IfThen conditional =
      IfThen.builder()
          .addAllIfClauses(ifClauses)
          .elseClause(elseClause)
          .build();
  return conditional;
}

/**
 * Builds a single {@link IfClause} pairing a condition with its result expression.
 *
 * @param condition the condition of the clause
 * @param then the expression stored as the clause's {@code then} branch
 * @return the assembled {@link IfClause}
 */
public IfClause ifClause(Expression condition, Expression then) {
  IfClause clause =
      IfClause.builder()
          .condition(condition)
          .then(then)
          .build();
  return clause;
}

/**
 * Builds a {@link SingleOrList} expression from a condition and a list of options.
 *
 * @param condition the expression stored as the condition
 * @param options the option expressions, added in the order given
 * @return the assembled expression
 */
public Expression singleOrList(Expression condition, Expression... options) {
  Expression orList =
      SingleOrList.builder()
          .condition(condition)
          .addOptions(options)
          .build();
  return orList;
}

public List<Expression.SortField> sortFields(Rel input, int... indexes) {
Expand All @@ -340,8 +357,17 @@ public List<Expression.SortField> sortFields(Rel input, int... indexes) {
.collect(java.util.stream.Collectors.toList());
}

public Expression singleOrList(Expression condition, Expression... options) {
return SingleOrList.builder().condition(condition).addOptions(options).build();
/**
 * Builds a single {@link SwitchClause} pairing a literal condition with its result expression.
 *
 * @param condition the literal stored as the clause's condition
 * @param then the expression stored as the clause's {@code then} branch
 * @return the assembled {@link SwitchClause}
 */
public SwitchClause switchClause(Expression.Literal condition, Expression then) {
  SwitchClause clause =
      SwitchClause.builder()
          .condition(condition)
          .then(then)
          .build();
  return clause;
}

/**
 * Builds a {@link Switch} expression.
 *
 * @param match the expression stored as the switch's match input
 * @param clauses the switch clauses, added in iteration order
 * @param defaultClause the expression stored as the default clause
 * @return the assembled {@link Switch}
 */
public Switch switchExpression(
    Expression match, Iterable<? extends SwitchClause> clauses, Expression defaultClause) {
  Switch switchExpr =
      Switch.builder()
          .match(match)
          .addAllSwitchClauses(clauses)
          .defaultClause(defaultClause)
          .build();
  return switchExpr;
}

// Aggregate Functions
Expand Down Expand Up @@ -436,6 +462,14 @@ public Expression.ScalarFunctionInvocation equal(Expression left, Expression rig
return scalarFn(FUNCTIONS_COMPARISON, "equal:any_any", R.BOOLEAN, left, right);
}

/**
 * Builds an invocation of the {@code or:bool} scalar function over the given arguments.
 *
 * <p>The declared output type is nullable boolean when at least one argument has a nullable
 * type, and required boolean otherwise.
 *
 * @param args the operands of the disjunction
 * @return the scalar function invocation
 */
public Expression.ScalarFunctionInvocation or(Expression... args) {
  // A single nullable operand is enough to make the result potentially null,
  // for example: false or null = null
  boolean anyNullableArg = false;
  for (Expression arg : args) {
    if (arg.getType().nullable()) {
      anyNullableArg = true;
      break;
    }
  }
  Type resultType = anyNullableArg ? N.BOOLEAN : R.BOOLEAN;
  return scalarFn(FUNCTIONS_BOOLEAN, "or:bool", resultType, args);
}

public Expression.ScalarFunctionInvocation scalarFn(
String namespace, String key, Type outputType, Expression... args) {
var declaration =
Expand Down
2 changes: 2 additions & 0 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,8 @@ public <R, E extends Throwable> R accept(ExpressionVisitor<R, E> visitor) throws

@Value.Immutable
abstract static class Switch implements Expression {
public abstract Expression match();

public abstract List<SwitchClause> switchClauses();

public abstract Expression defaultClause();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,16 +210,20 @@ public static Expression.StructLiteral struct(
}

public static Expression.Switch switchStatement(
Expression defaultExpression, Expression.SwitchClause... conditionClauses) {
Expression match, Expression defaultExpression, Expression.SwitchClause... conditionClauses) {
return Expression.Switch.builder()
.match(match)
.defaultClause(defaultExpression)
.addSwitchClauses(conditionClauses)
.build();
}

public static Expression.Switch switchStatement(
Expression defaultExpression, Iterable<? extends Expression.SwitchClause> conditionClauses) {
Expression match,
Expression defaultExpression,
Iterable<? extends Expression.SwitchClause> conditionClauses) {
return Expression.Switch.builder()
.match(match)
.defaultClause(defaultExpression)
.addAllSwitchClauses(conditionClauses)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ public Expression visit(io.substrait.expression.Expression.Switch expr) {
return Expression.newBuilder()
.setSwitchExpression(
Expression.SwitchExpression.newBuilder()
.setMatch(expr.match().accept(this))
.addAllIfs(clauses)
.setElse(expr.defaultClause().accept(this)))
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
* Converts from {@link io.substrait.proto.Expression} to {@link io.substrait.expression.Expression}
*/
public class ProtoExpressionConverter {

static final org.slf4j.Logger logger =
org.slf4j.LoggerFactory.getLogger(ProtoExpressionConverter.class);

Expand Down Expand Up @@ -168,7 +169,8 @@ public Expression from(io.substrait.proto.Expression expr) {
switchExpr.getIfsList().stream()
.map(t -> ExpressionCreator.switchClause(from(t.getIf()), from(t.getThen())))
.collect(java.util.stream.Collectors.toList());
yield ExpressionCreator.switchStatement(from(switchExpr.getElse()), clauses);
yield ExpressionCreator.switchStatement(
from(switchExpr.getMatch()), from(switchExpr.getElse()), clauses);
}
case SINGULAR_OR_LIST -> {
var orList = expr.getSingularOrList();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ public Optional<Expression> visitFallback(Expression expr) {

@Override
public Optional<Expression> visit(Expression.Switch expr) throws RuntimeException {
var matchExpr = expr.match().accept(this);
var defaultClause = expr.defaultClause().accept(this);
var switchClauses =
transformList(
Expand All @@ -234,6 +235,7 @@ public Optional<Expression> visit(Expression.Switch expr) throws RuntimeExceptio
return Optional.of(
Expression.Switch.builder()
.from(expr)
.match(matchExpr.orElse(expr.match()))
.defaultClause(defaultClause.orElse(expr.defaultClause()))
.switchClauses(switchClauses.orElse(expr.switchClauses()))
.build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import io.substrait.expression.EnumArg;
import io.substrait.expression.Expression;
import io.substrait.expression.Expression.SingleOrList;
import io.substrait.expression.Expression.Switch;
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.WindowBound;
Expand Down Expand Up @@ -261,6 +262,22 @@ public RexNode visit(Expression.IfThen expr) throws RuntimeException {
return rexBuilder.makeCall(SqlStdOperatorTable.CASE, args);
}

/**
 * Converts a switch expression into a call to {@link SqlStdOperatorTable#CASE}.
 *
 * <p>Each switch clause contributes two operands: an {@link SqlStdOperatorTable#EQUALS} call
 * comparing the converted match expression with the converted clause condition, followed by the
 * converted {@code then} expression. The converted default clause is the final operand.
 *
 * @param expr the switch expression to convert
 * @return the CASE call built from the clauses and the default
 */
@Override
public RexNode visit(Switch expr) throws RuntimeException {
  RexNode matchOperand = expr.match().accept(this);
  // Convert the default up front, before any clause, but append it last.
  RexNode elseOperand = expr.defaultClause().accept(this);

  List<RexNode> caseOperands = new java.util.ArrayList<>();
  for (Expression.SwitchClause clause : expr.switchClauses()) {
    RexNode comparand = clause.condition().accept(this);
    caseOperands.add(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, matchOperand, comparand));
    caseOperands.add(clause.then().accept(this));
  }
  caseOperands.add(elseOperand);

  return rexBuilder.makeCall(SqlStdOperatorTable.CASE, caseOperands);
}

@Override
public RexNode visit(Expression.ScalarFunctionInvocation expr) throws RuntimeException {
SqlOperator operator =
Expand Down
Original file line number Diff line number Diff line change
@@ -1,38 +1,52 @@
package io.substrait.isthmus;

import static io.substrait.isthmus.expression.CallConverters.CASE;
import static io.substrait.isthmus.expression.CallConverters.CREATE_SEARCH_CONV;
import static org.junit.jupiter.api.Assertions.assertEquals;

import io.substrait.dsl.SubstraitBuilder;
import io.substrait.expression.Expression;
import io.substrait.expression.proto.ExpressionProtoConverter;
import io.substrait.extension.ExtensionCollector;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.expression.ExpressionRexConverter;
import io.substrait.isthmus.expression.RexExpressionConverter;
import io.substrait.isthmus.expression.ScalarFunctionConverter;
import io.substrait.plan.Plan;
import io.substrait.isthmus.expression.WindowFunctionConverter;
import io.substrait.relation.Rel;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.io.IOException;
import java.util.List;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.parser.SqlParseException;
import org.junit.jupiter.api.Test;

/** Tests which test that an expression can be converted to and from Calcite expressions. */
public class ExpressionConvertabilityTest extends PlanTestBase {

static final TypeCreator R = TypeCreator.of(false);
static final TypeCreator N = TypeCreator.of(true);

final SubstraitBuilder b = new SubstraitBuilder(extensions);

final ExpressionProtoConverter expressionProtoConverter =
new ExpressionProtoConverter(new ExtensionCollector(), null);

final ExpressionRexConverter converter =
new ExpressionRexConverter(
typeFactory,
new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory),
new WindowFunctionConverter(extensions.windowFunctions(), typeFactory),
TypeConverter.DEFAULT);

final RexBuilder rexBuilder = new RexBuilder(typeFactory);

// Define a shared table (i.e. a NamedScan) for use in tests.
final List<Type> commonTableType = List.of(R.I32, R.FP32, N.STRING, N.BOOLEAN);
final Rel commonTable =
b.namedScan(List.of("example"), List.of("a", "b", "c", "d"), commonTableType);

final SubstraitToCalcite converter = new SubstraitToCalcite(extensions, typeFactory);

@Test
public void listLiteral() throws IOException, SqlParseException {
assertFullRoundTrip("select ARRAY[1,2,3] from ORDERS");
Expand All @@ -44,40 +58,50 @@ public void mapLiteral() throws IOException, SqlParseException {
}

@Test
public void singleOrList() throws IOException {
Plan.Root root =
b.root(
b.filter(
input -> b.singleOrList(b.fieldReference(input, 0), b.i32(5), b.i32(10)),
commonTable));
var relNode = converter.convert(root.getInput());
var expression =
((Filter) relNode)
.getCondition()
.accept(
new RexExpressionConverter(
CREATE_SEARCH_CONV.apply(relNode.getCluster().getRexBuilder()),
new ScalarFunctionConverter(
SimpleExtension.loadDefaults().scalarFunctions(), typeFactory)));
var to = new ExpressionProtoConverter(new ExtensionCollector(), null);
public void singleOrList() {
Expression singleOrList = b.singleOrList(b.fieldReference(commonTable, 0), b.i32(5), b.i32(10));
RexNode rexNode = singleOrList.accept(converter);
Expression substraitExpression =
rexNode.accept(
new RexExpressionConverter(
CREATE_SEARCH_CONV.apply(rexBuilder),
new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory)));

// cannot roundtrip test singleOrList because Calcite simplifies the representation
assertExpressionEquality(
b.or(
b.equal(b.fieldReference(commonTable, 0), b.i32(5)),
b.equal(b.fieldReference(commonTable, 0), b.i32(10))),
substraitExpression);
}

/**
 * Converts a Substrait switch expression to Calcite and back, then checks that the result equals
 * the corresponding if/then expression built from equality comparisons.
 */
@Test
public void switchExpression() {
  List<Expression.SwitchClause> switchClauses =
      List.of(b.switchClause(b.i32(5), b.i32(1)), b.switchClause(b.i32(10), b.i32(2)));
  Expression original =
      b.switchExpression(b.fieldReference(commonTable, 0), switchClauses, b.i32(3));

  RexNode calciteExpression = original.accept(converter);
  RexExpressionConverter toSubstrait =
      new RexExpressionConverter(
          CASE, new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory));
  Expression converted = calciteExpression.accept(toSubstrait);

  // cannot roundtrip test switchExpression because Calcite simplifies the representation,
  // so compare against the equivalent if/then form instead
  List<Expression.IfClause> expectedClauses =
      List.of(
          b.ifClause(b.equal(b.fieldReference(commonTable, 0), b.i32(5)), b.i32(1)),
          b.ifClause(b.equal(b.fieldReference(commonTable, 0), b.i32(10)), b.i32(2)));
  Expression expected = b.ifThen(expectedClauses, b.i32(3));

  assertExpressionEquality(expected, converted);
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests were creating a full plan and converting it, but we only actually want/need to verify the expression conversion.

Removed the plan-related noise by converting the expression directly. Also used a bunch of our nice utils (including one I added) to make the tests more readable.


void assertExpressionEquality(Expression expected, Expression actual) {
// go the extra mile and convert both inputs to protobuf
// helps verify that the protobuf conversion is not broken
assertEquals(
expression.accept(to),
b.scalarFn(
"/functions_boolean.yaml",
"or:bool",
R.BOOLEAN,
b.scalarFn(
"/functions_comparison.yaml",
"equal:any_any",
R.BOOLEAN,
b.fieldReference(commonTable, 0),
b.i32(5)),
b.scalarFn(
"/functions_comparison.yaml",
"equal:any_any",
R.BOOLEAN,
b.fieldReference(commonTable, 0),
b.i32(10)))
.accept(to));
expected.accept(expressionProtoConverter), actual.accept(expressionProtoConverter));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package io.substrait.isthmus;

import static org.junit.jupiter.api.Assertions.assertEquals;

import io.substrait.dsl.SubstraitBuilder;
import io.substrait.isthmus.expression.ExpressionRexConverter;
import io.substrait.isthmus.expression.ScalarFunctionConverter;
import io.substrait.isthmus.expression.WindowFunctionConverter;
import io.substrait.relation.Rel;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.util.List;
import org.apache.calcite.rel.type.RelDataType;
import org.junit.jupiter.api.Test;

/** Tests for properties of the Calcite expressions produced from Substrait expressions. */
public class SubstraitExpressionConverterTest extends PlanTestBase {

  // Type creators for required (non-nullable) and nullable types respectively.
  static final TypeCreator R = TypeCreator.of(false);
  static final TypeCreator N = TypeCreator.of(true);

  final SubstraitBuilder b = new SubstraitBuilder(extensions);

  final ExpressionRexConverter converter =
      new ExpressionRexConverter(
          typeFactory,
          new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory),
          new WindowFunctionConverter(extensions.windowFunctions(), typeFactory),
          TypeConverter.DEFAULT);

  // Shared named scan used as the input relation for field references.
  final List<Type> commonTableType = List.of(R.I32, R.FP32, N.STRING, N.BOOLEAN);
  final Rel commonTable =
      b.namedScan(List.of("example"), List.of("a", "b", "c", "d"), commonTableType);

  /**
   * A switch whose only clause yields the nullable boolean column (index 3) must convert to a
   * Calcite expression whose type maps back to nullable boolean.
   */
  @Test
  public void switchExpression() {
    var matchField = b.fieldReference(commonTable, 0);
    var clauses = List.of(b.switchClause(b.i32(0), b.fieldReference(commonTable, 3)));
    var fallback = b.bool(false);
    var switchExpr = b.switchExpression(matchField, clauses, fallback);

    assertTypeMatch(switchExpr.accept(converter).getType(), N.BOOLEAN);
  }

  /**
   * Asserts that the Calcite type, once converted to Substrait via {@link
   * TypeConverter#DEFAULT}, equals the expected Substrait type.
   */
  void assertTypeMatch(RelDataType actual, Type expected) {
    assertEquals(expected, TypeConverter.DEFAULT.toSubstrait(actual));
  }
}