Skip to content

Commit b938573

Browse files
carlyeks and vbarua authored
fix: map switch expression to a Calcite CASE statement (#189)
* feat: added `or` util to builder

---------

Co-authored-by: Victor Barua <[email protected]>
1 parent b66d5b1 commit b938573

9 files changed

Lines changed: 186 additions & 50 deletions

File tree

core/src/main/java/io/substrait/dsl/SubstraitBuilder.java

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,13 @@
44
import io.substrait.expression.AggregateFunctionInvocation;
55
import io.substrait.expression.Expression;
66
import io.substrait.expression.Expression.FailureBehavior;
7+
import io.substrait.expression.Expression.IfClause;
8+
import io.substrait.expression.Expression.IfThen;
9+
import io.substrait.expression.Expression.SwitchClause;
710
import io.substrait.expression.FieldReference;
811
import io.substrait.expression.ImmutableExpression.Cast;
912
import io.substrait.expression.ImmutableExpression.SingleOrList;
13+
import io.substrait.expression.ImmutableExpression.Switch;
1014
import io.substrait.expression.ImmutableFieldReference;
1115
import io.substrait.extension.SimpleExtension;
1216
import io.substrait.function.ToTypeString;
@@ -42,6 +46,7 @@ public class SubstraitBuilder {
4246

4347
private static final String FUNCTIONS_AGGREGATE_GENERIC = "/functions_aggregate_generic.yaml";
4448
private static final String FUNCTIONS_ARITHMETIC = "/functions_arithmetic.yaml";
49+
private static final String FUNCTIONS_BOOLEAN = "/functions_boolean.yaml";
4550
private static final String FUNCTIONS_COMPARISON = "/functions_comparison.yaml";
4651

4752
private final SimpleExtension.ExtensionCollection extensions;
@@ -301,6 +306,14 @@ public Expression.I32Literal i32(int v) {
301306
return Expression.I32Literal.builder().value(v).build();
302307
}
303308

309+
/**
 * Builds a {@code Cast} expression that converts {@code input} to {@code type}.
 *
 * <p>The cast is always created with {@code FailureBehavior.UNSPECIFIED}.
 *
 * @param input the expression to cast
 * @param type the target type of the cast
 * @return the built cast expression
 */
public Expression cast(Expression input, Type type) {
310+
return Cast.builder()
311+
.input(input)
312+
.type(type)
313+
.failureBehavior(FailureBehavior.UNSPECIFIED)
314+
.build();
315+
}
316+
304317
public FieldReference fieldReference(Rel input, int index) {
305318
return ImmutableFieldReference.newInputRelReference(index, input);
306319
}
@@ -321,12 +334,16 @@ public List<FieldReference> fieldReferences(List<Rel> inputs, int... indexes) {
321334
.collect(java.util.stream.Collectors.toList());
322335
}
323336

324-
public Expression cast(Expression input, Type type) {
325-
return Cast.builder()
326-
.input(input)
327-
.type(type)
328-
.failureBehavior(FailureBehavior.UNSPECIFIED)
329-
.build();
337+
/**
 * Builds an {@code IfThen} expression from a sequence of if-clauses and an else expression.
 *
 * @param ifClauses the clauses added to the expression, in iteration order
 * @param elseClause the expression set as the else branch
 * @return the built {@code IfThen} expression
 */
public IfThen ifThen(Iterable<? extends IfClause> ifClauses, Expression elseClause) {
338+
return IfThen.builder().addAllIfClauses(ifClauses).elseClause(elseClause).build();
339+
}
340+
341+
/**
 * Builds a single {@code IfClause} pairing a condition with the expression to use for it.
 *
 * @param condition the expression set as the clause's condition
 * @param then the expression set as the clause's result
 * @return the built {@code IfClause}
 */
public IfClause ifClause(Expression condition, Expression then) {
342+
return IfClause.builder().condition(condition).then(then).build();
343+
}
344+
345+
/**
 * Builds a {@code SingleOrList} expression from a condition expression and its options.
 *
 * <p>NOTE(review): presumably this matches when {@code condition} equals any of the options; the
 * builder itself only stores the values, so confirm against the expression definition.
 *
 * @param condition the expression set as the condition
 * @param options the expressions added as options
 * @return the built {@code SingleOrList} expression
 */
public Expression singleOrList(Expression condition, Expression... options) {
346+
return SingleOrList.builder().condition(condition).addOptions(options).build();
330347
}
331348

332349
public List<Expression.SortField> sortFields(Rel input, int... indexes) {
@@ -340,8 +357,17 @@ public List<Expression.SortField> sortFields(Rel input, int... indexes) {
340357
.collect(java.util.stream.Collectors.toList());
341358
}
342359

343-
public Expression singleOrList(Expression condition, Expression... options) {
344-
return SingleOrList.builder().condition(condition).addOptions(options).build();
360+
/**
 * Builds a single {@code SwitchClause} pairing a literal with the expression to use for it.
 *
 * @param condition the literal set as the clause's condition; the parameter type restricts it to
 *     {@code Expression.Literal}
 * @param then the expression set as the clause's result
 * @return the built {@code SwitchClause}
 */
public SwitchClause switchClause(Expression.Literal condition, Expression then) {
361+
return SwitchClause.builder().condition(condition).then(then).build();
362+
}
363+
364+
/**
 * Builds a {@code Switch} expression from a match expression, its clauses and a default.
 *
 * @param match the expression set as the switch's match value
 * @param clauses the switch clauses added to the expression, in iteration order
 * @param defaultClause the expression set as the default clause
 * @return the built {@code Switch} expression
 */
public Switch switchExpression(
365+
Expression match, Iterable<? extends SwitchClause> clauses, Expression defaultClause) {
366+
return Switch.builder()
367+
.match(match)
368+
.addAllSwitchClauses(clauses)
369+
.defaultClause(defaultClause)
370+
.build();
345371
}
346372

347373
// Aggregate Functions
@@ -436,6 +462,14 @@ public Expression.ScalarFunctionInvocation equal(Expression left, Expression rig
436462
return scalarFn(FUNCTIONS_COMPARISON, "equal:any_any", R.BOOLEAN, left, right);
437463
}
438464

465+
/**
 * Builds a scalar invocation of the {@code "or:bool"} function from the boolean function
 * extension over the given arguments.
 *
 * <p>The declared output type is {@code N.BOOLEAN} when at least one argument's type reports
 * itself as nullable, and {@code R.BOOLEAN} otherwise.
 *
 * @param args the operands passed to the function
 * @return the invocation produced by {@code scalarFn}
 */
public Expression.ScalarFunctionInvocation or(Expression... args) {
466+
// If any argument is nullable, the result of `or` is potentially nullable.
467+
// For example: false or null = null
468+
var isOutputNullable = Arrays.stream(args).anyMatch(a -> a.getType().nullable());
469+
var outputType = isOutputNullable ? N.BOOLEAN : R.BOOLEAN;
470+
return scalarFn(FUNCTIONS_BOOLEAN, "or:bool", outputType, args);
471+
}
472+
439473
public Expression.ScalarFunctionInvocation scalarFn(
440474
String namespace, String key, Type outputType, Expression... args) {
441475
var declaration =

core/src/main/java/io/substrait/expression/Expression.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,8 @@ public <R, E extends Throwable> R accept(ExpressionVisitor<R, E> visitor) throws
469469

470470
@Value.Immutable
471471
abstract static class Switch implements Expression {
472+
public abstract Expression match();
473+
472474
public abstract List<SwitchClause> switchClauses();
473475

474476
public abstract Expression defaultClause();

core/src/main/java/io/substrait/expression/ExpressionCreator.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,16 +210,20 @@ public static Expression.StructLiteral struct(
210210
}
211211

212212
public static Expression.Switch switchStatement(
213-
Expression defaultExpression, Expression.SwitchClause... conditionClauses) {
213+
Expression match, Expression defaultExpression, Expression.SwitchClause... conditionClauses) {
214214
return Expression.Switch.builder()
215+
.match(match)
215216
.defaultClause(defaultExpression)
216217
.addSwitchClauses(conditionClauses)
217218
.build();
218219
}
219220

220221
public static Expression.Switch switchStatement(
221-
Expression defaultExpression, Iterable<? extends Expression.SwitchClause> conditionClauses) {
222+
Expression match,
223+
Expression defaultExpression,
224+
Iterable<? extends Expression.SwitchClause> conditionClauses) {
222225
return Expression.Switch.builder()
226+
.match(match)
223227
.defaultClause(defaultExpression)
224228
.addAllSwitchClauses(conditionClauses)
225229
.build();

core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ public Expression visit(io.substrait.expression.Expression.Switch expr) {
237237
return Expression.newBuilder()
238238
.setSwitchExpression(
239239
Expression.SwitchExpression.newBuilder()
240+
.setMatch(expr.match().accept(this))
240241
.addAllIfs(clauses)
241242
.setElse(expr.defaultClause().accept(this)))
242243
.build();

core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
* Converts from {@link io.substrait.proto.Expression} to {@link io.substrait.expression.Expression}
2222
*/
2323
public class ProtoExpressionConverter {
24+
2425
static final org.slf4j.Logger logger =
2526
org.slf4j.LoggerFactory.getLogger(ProtoExpressionConverter.class);
2627

@@ -168,7 +169,8 @@ public Expression from(io.substrait.proto.Expression expr) {
168169
switchExpr.getIfsList().stream()
169170
.map(t -> ExpressionCreator.switchClause(from(t.getIf()), from(t.getThen())))
170171
.collect(java.util.stream.Collectors.toList());
171-
yield ExpressionCreator.switchStatement(from(switchExpr.getElse()), clauses);
172+
yield ExpressionCreator.switchStatement(
173+
from(switchExpr.getMatch()), from(switchExpr.getElse()), clauses);
172174
}
173175
case SINGULAR_OR_LIST -> {
174176
var orList = expr.getSingularOrList();

core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ public Optional<Expression> visitFallback(Expression expr) {
220220

221221
@Override
222222
public Optional<Expression> visit(Expression.Switch expr) throws RuntimeException {
223+
var matchExpr = expr.match().accept(this);
223224
var defaultClause = expr.defaultClause().accept(this);
224225
var switchClauses =
225226
transformList(
@@ -234,6 +235,7 @@ public Optional<Expression> visit(Expression.Switch expr) throws RuntimeExceptio
234235
return Optional.of(
235236
Expression.Switch.builder()
236237
.from(expr)
238+
.match(matchExpr.orElse(expr.match()))
237239
.defaultClause(defaultClause.orElse(expr.defaultClause()))
238240
.switchClauses(switchClauses.orElse(expr.switchClauses()))
239241
.build());

isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import io.substrait.expression.EnumArg;
66
import io.substrait.expression.Expression;
77
import io.substrait.expression.Expression.SingleOrList;
8+
import io.substrait.expression.Expression.Switch;
89
import io.substrait.expression.FieldReference;
910
import io.substrait.expression.FunctionArg;
1011
import io.substrait.expression.WindowBound;
@@ -261,6 +262,22 @@ public RexNode visit(Expression.IfThen expr) throws RuntimeException {
261262
return rexBuilder.makeCall(SqlStdOperatorTable.CASE, args);
262263
}
263264

265+
/**
 * Converts a {@code Switch} expression into a Calcite {@code CASE} call.
 *
 * <p>Each switch clause contributes two operands: an {@code EQUALS} call comparing the converted
 * match expression with the clause's condition, followed by the clause's converted result. The
 * converted default clause is appended as the final operand.
 *
 * @param expr the switch expression to convert
 * @return the {@code CASE} call built by {@code rexBuilder}
 */
@Override
266+
public RexNode visit(Switch expr) throws RuntimeException {
267+
// Convert the match expression once; the same RexNode is reused in every EQUALS comparison.
RexNode match = expr.match().accept(this);
268+
Stream<RexNode> caseThenArgs =
269+
expr.switchClauses().stream()
270+
.flatMap(
271+
clause ->
272+
Stream.of(
273+
rexBuilder.makeCall(
274+
SqlStdOperatorTable.EQUALS, match, clause.condition().accept(this)),
275+
clause.then().accept(this)));
276+
Stream<RexNode> defaultArg = Stream.of(expr.defaultClause().accept(this));
277+
// Operand order: (comparison, result) pairs for every clause, then the default.
List<RexNode> args = Stream.concat(caseThenArgs, defaultArg).collect(Collectors.toList());
278+
return rexBuilder.makeCall(SqlStdOperatorTable.CASE, args);
279+
}
280+
264281
@Override
265282
public RexNode visit(Expression.ScalarFunctionInvocation expr) throws RuntimeException {
266283
SqlOperator operator =

isthmus/src/test/java/io/substrait/isthmus/ExpressionConvertabilityTest.java

Lines changed: 63 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,52 @@
11
package io.substrait.isthmus;
22

3+
import static io.substrait.isthmus.expression.CallConverters.CASE;
34
import static io.substrait.isthmus.expression.CallConverters.CREATE_SEARCH_CONV;
45
import static org.junit.jupiter.api.Assertions.assertEquals;
56

67
import io.substrait.dsl.SubstraitBuilder;
8+
import io.substrait.expression.Expression;
79
import io.substrait.expression.proto.ExpressionProtoConverter;
810
import io.substrait.extension.ExtensionCollector;
9-
import io.substrait.extension.SimpleExtension;
11+
import io.substrait.isthmus.expression.ExpressionRexConverter;
1012
import io.substrait.isthmus.expression.RexExpressionConverter;
1113
import io.substrait.isthmus.expression.ScalarFunctionConverter;
12-
import io.substrait.plan.Plan;
14+
import io.substrait.isthmus.expression.WindowFunctionConverter;
1315
import io.substrait.relation.Rel;
1416
import io.substrait.type.Type;
1517
import io.substrait.type.TypeCreator;
1618
import java.io.IOException;
1719
import java.util.List;
18-
import org.apache.calcite.rel.core.Filter;
20+
import org.apache.calcite.rex.RexBuilder;
21+
import org.apache.calcite.rex.RexNode;
1922
import org.apache.calcite.sql.parser.SqlParseException;
2023
import org.junit.jupiter.api.Test;
2124

2225
/** Tests verifying that expressions can be converted to and from Calcite expressions. */
2326
public class ExpressionConvertabilityTest extends PlanTestBase {
27+
2428
static final TypeCreator R = TypeCreator.of(false);
2529
static final TypeCreator N = TypeCreator.of(true);
2630

2731
final SubstraitBuilder b = new SubstraitBuilder(extensions);
2832

33+
final ExpressionProtoConverter expressionProtoConverter =
34+
new ExpressionProtoConverter(new ExtensionCollector(), null);
35+
36+
final ExpressionRexConverter converter =
37+
new ExpressionRexConverter(
38+
typeFactory,
39+
new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory),
40+
new WindowFunctionConverter(extensions.windowFunctions(), typeFactory),
41+
TypeConverter.DEFAULT);
42+
43+
final RexBuilder rexBuilder = new RexBuilder(typeFactory);
44+
2945
// Define a shared table (i.e. a NamedScan) for use in tests.
3046
final List<Type> commonTableType = List.of(R.I32, R.FP32, N.STRING, N.BOOLEAN);
3147
final Rel commonTable =
3248
b.namedScan(List.of("example"), List.of("a", "b", "c", "d"), commonTableType);
3349

34-
final SubstraitToCalcite converter = new SubstraitToCalcite(extensions, typeFactory);
35-
3650
@Test
3751
public void listLiteral() throws IOException, SqlParseException {
3852
assertFullRoundTrip("select ARRAY[1,2,3] from ORDERS");
@@ -44,40 +58,50 @@ public void mapLiteral() throws IOException, SqlParseException {
4458
}
4559

4660
@Test
47-
public void singleOrList() throws IOException {
48-
Plan.Root root =
49-
b.root(
50-
b.filter(
51-
input -> b.singleOrList(b.fieldReference(input, 0), b.i32(5), b.i32(10)),
52-
commonTable));
53-
var relNode = converter.convert(root.getInput());
54-
var expression =
55-
((Filter) relNode)
56-
.getCondition()
57-
.accept(
58-
new RexExpressionConverter(
59-
CREATE_SEARCH_CONV.apply(relNode.getCluster().getRexBuilder()),
60-
new ScalarFunctionConverter(
61-
SimpleExtension.loadDefaults().scalarFunctions(), typeFactory)));
62-
var to = new ExpressionProtoConverter(new ExtensionCollector(), null);
61+
public void singleOrList() {
62+
Expression singleOrList = b.singleOrList(b.fieldReference(commonTable, 0), b.i32(5), b.i32(10));
63+
RexNode rexNode = singleOrList.accept(converter);
64+
Expression substraitExpression =
65+
rexNode.accept(
66+
new RexExpressionConverter(
67+
CREATE_SEARCH_CONV.apply(rexBuilder),
68+
new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory)));
69+
70+
// cannot roundtrip test singleOrList because Calcite simplifies the representation
71+
assertExpressionEquality(
72+
b.or(
73+
b.equal(b.fieldReference(commonTable, 0), b.i32(5)),
74+
b.equal(b.fieldReference(commonTable, 0), b.i32(10))),
75+
substraitExpression);
76+
}
77+
78+
/**
 * Converts a Switch expression to a Calcite RexNode and back, then checks that the result equals
 * the IfThen expression with one equality condition per switch clause.
 */
@Test
79+
public void switchExpression() {
80+
Expression switchExpression =
81+
b.switchExpression(
82+
b.fieldReference(commonTable, 0),
83+
List.of(b.switchClause(b.i32(5), b.i32(1)), b.switchClause(b.i32(10), b.i32(2))),
84+
b.i32(3));
85+
RexNode rexNode = switchExpression.accept(converter);
86+
Expression expression =
87+
rexNode.accept(
88+
new RexExpressionConverter(
89+
CASE, new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory)));
90+
91+
// The Switch cannot be round-tripped: it becomes a Calcite CASE call, which converts back to
// an IfThen, so the result is compared against the equivalent IfThen instead.
92+
assertExpressionEquality(
93+
b.ifThen(
94+
List.of(
95+
b.ifClause(b.equal(b.fieldReference(commonTable, 0), b.i32(5)), b.i32(1)),
96+
b.ifClause(b.equal(b.fieldReference(commonTable, 0), b.i32(10)), b.i32(2))),
97+
b.i32(3)),
98+
expression);
99+
}
100+
101+
void assertExpressionEquality(Expression expected, Expression actual) {
102+
// go the extra mile and convert both inputs to protobuf
103+
// helps verify that the protobuf conversion is not broken
63104
assertEquals(
64-
expression.accept(to),
65-
b.scalarFn(
66-
"/functions_boolean.yaml",
67-
"or:bool",
68-
R.BOOLEAN,
69-
b.scalarFn(
70-
"/functions_comparison.yaml",
71-
"equal:any_any",
72-
R.BOOLEAN,
73-
b.fieldReference(commonTable, 0),
74-
b.i32(5)),
75-
b.scalarFn(
76-
"/functions_comparison.yaml",
77-
"equal:any_any",
78-
R.BOOLEAN,
79-
b.fieldReference(commonTable, 0),
80-
b.i32(10)))
81-
.accept(to));
105+
expected.accept(expressionProtoConverter), actual.accept(expressionProtoConverter));
82106
}
83107
}

isthmus/src/test/java/io/substrait/isthmus/SubstraitExpressionConverterTest.java

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package io.substrait.isthmus;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
5+
import io.substrait.dsl.SubstraitBuilder;
6+
import io.substrait.isthmus.expression.ExpressionRexConverter;
7+
import io.substrait.isthmus.expression.ScalarFunctionConverter;
8+
import io.substrait.isthmus.expression.WindowFunctionConverter;
9+
import io.substrait.relation.Rel;
10+
import io.substrait.type.Type;
11+
import io.substrait.type.TypeCreator;
12+
import java.util.List;
13+
import org.apache.calcite.rel.type.RelDataType;
14+
import org.junit.jupiter.api.Test;
15+
16+
public class SubstraitExpressionConverterTest extends PlanTestBase {
17+
18+
static final TypeCreator R = TypeCreator.of(false);
19+
static final TypeCreator N = TypeCreator.of(true);
20+
21+
final SubstraitBuilder b = new SubstraitBuilder(extensions);
22+
23+
final ExpressionRexConverter converter =
24+
new ExpressionRexConverter(
25+
typeFactory,
26+
new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory),
27+
new WindowFunctionConverter(extensions.windowFunctions(), typeFactory),
28+
TypeConverter.DEFAULT);
29+
30+
final List<Type> commonTableType = List.of(R.I32, R.FP32, N.STRING, N.BOOLEAN);
31+
final Rel commonTable =
32+
b.namedScan(List.of("example"), List.of("a", "b", "c", "d"), commonTableType);
33+
34+
/**
 * Converts a Switch expression to Calcite and checks the type of the resulting expression.
 *
 * <p>The single clause returns field 3 of the common table (declared as {@code N.BOOLEAN}) and
 * the default is the literal {@code b.bool(false)}; the converted expression's type is expected
 * to map back to {@code N.BOOLEAN}.
 */
@Test
35+
public void switchExpression() {
36+
var expr =
37+
b.switchExpression(
38+
b.fieldReference(commonTable, 0),
39+
List.of(b.switchClause(b.i32(0), b.fieldReference(commonTable, 3))),
40+
b.bool(false));
41+
var calciteExpr = expr.accept(converter);
42+
43+
assertTypeMatch(calciteExpr.getType(), N.BOOLEAN);
44+
}
46+
/**
 * Asserts that a Calcite type converts to the expected Substrait type.
 *
 * @param actual the Calcite type, converted with {@code TypeConverter.DEFAULT.toSubstrait}
 * @param expected the Substrait type the conversion result must equal
 */
void assertTypeMatch(RelDataType actual, Type expected) {
47+
Type type = TypeConverter.DEFAULT.toSubstrait(actual);
48+
assertEquals(expected, type);
49+
}
50+
}

0 commit comments

Comments
 (0)