Skip to content

Commit 966340a

Browse files
committed
feat(isthmus): add support for scalar subqueries
Signed-off-by: Niels Pardon <[email protected]>
1 parent 9e4afb9 commit 966340a

File tree

4 files changed

+103
-18
lines changed

4 files changed

+103
-18
lines changed

Diff for: isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java

+7
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import io.substrait.expression.EnumArg;
66
import io.substrait.expression.Expression;
77
import io.substrait.expression.Expression.FailureBehavior;
8+
import io.substrait.expression.Expression.ScalarSubquery;
89
import io.substrait.expression.Expression.SingleOrList;
910
import io.substrait.expression.Expression.Switch;
1011
import io.substrait.expression.FieldReference;
@@ -538,4 +539,10 @@ public RexNode visitEnumArg(SimpleExtension.Function fnDef, int argIdx, EnumArg
538539
"EnumArg(value=%s) not handled by visitor type %s.",
539540
e.value(), this.getClass().getCanonicalName())));
540541
}
542+
543+
@Override
544+
public RexNode visit(ScalarSubquery expr) throws RuntimeException {
545+
RelNode inputRelnode = expr.input().accept(relNodeConverter);
546+
return RexSubQuery.scalar(inputRelnode);
547+
}
541548
}

Diff for: isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java

+11-6
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535

3636
public class PlanTestBase {
3737
protected final SimpleExtension.ExtensionCollection extensions = SimpleExtension.loadDefaults();
38-
protected final RelCreator creator = new RelCreator();
38+
protected final RelCreator creator = new RelCreator(tpchSchemaCreateStatements());
3939
protected final RelBuilder builder = creator.createRelBuilder();
4040
protected final RexBuilder rex = creator.rex();
4141
protected final RelDataTypeFactory typeFactory = creator.typeFactory();
@@ -47,11 +47,16 @@ public static String asString(String resource) throws IOException {
4747
return Resources.toString(Resources.getResource(resource), Charsets.UTF_8);
4848
}
4949

50-
public static List<String> tpchSchemaCreateStatements() throws IOException {
51-
String[] values = asString("tpch/schema.sql").split(";");
52-
return Arrays.stream(values)
53-
.filter(t -> !t.trim().isBlank())
54-
.collect(java.util.stream.Collectors.toList());
50+
public static List<String> tpchSchemaCreateStatements() {
51+
String[] values;
52+
try {
53+
values = asString("tpch/schema.sql").split(";");
54+
return Arrays.stream(values)
55+
.filter(t -> !t.trim().isBlank())
56+
.collect(java.util.stream.Collectors.toList());
57+
} catch (IOException e) {
58+
throw new RuntimeException(e);
59+
}
5560
}
5661

5762
protected Plan assertProtoPlanRoundrip(String query) throws IOException, SqlParseException {

Diff for: isthmus/src/test/java/io/substrait/isthmus/RelCreator.java

+21-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package io.substrait.isthmus;
22

33
import java.util.Arrays;
4+
import java.util.List;
45
import org.apache.calcite.config.CalciteConnectionConfig;
56
import org.apache.calcite.config.CalciteConnectionProperty;
67
import org.apache.calcite.jdbc.CalciteSchema;
@@ -25,19 +26,37 @@
2526
import org.apache.calcite.sql2rel.SqlToRelConverter;
2627
import org.apache.calcite.sql2rel.StandardConvertletTable;
2728
import org.apache.calcite.tools.RelBuilder;
29+
import org.apache.calcite.util.Pair;
2830

29-
public class RelCreator {
31+
public class RelCreator extends SqlConverterBase {
3032
static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(RelCreator.class);
3133

3234
private RelOptCluster cluster;
3335
private CalciteCatalogReader catalog;
36+
private SqlValidator validator;
3437

3538
public RelCreator() {
39+
super(SqlConverterBase.FEATURES_DEFAULT);
3640
CalciteSchema schema = CalciteSchema.createRootSchema(false);
3741
RelDataTypeFactory factory = new JavaTypeFactoryImpl(SubstraitTypeSystem.TYPE_SYSTEM);
3842
CalciteConnectionConfig config =
3943
CalciteConnectionConfig.DEFAULT.set(CalciteConnectionProperty.CASE_SENSITIVE, "false");
40-
catalog = new CalciteCatalogReader(schema, Arrays.asList(), factory, config);
44+
this.validator = new Validator(catalog, cluster.getTypeFactory(), SqlValidator.Config.DEFAULT);
45+
this.catalog = new CalciteCatalogReader(schema, Arrays.asList(), factory, config);
46+
VolcanoPlanner planner = new VolcanoPlanner(RelOptCostImpl.FACTORY, Contexts.EMPTY_CONTEXT);
47+
cluster = RelOptCluster.create(planner, new RexBuilder(factory));
48+
}
49+
50+
public RelCreator(List<String> creates) {
51+
super(SqlConverterBase.FEATURES_DEFAULT);
52+
RelDataTypeFactory factory = new JavaTypeFactoryImpl(SubstraitTypeSystem.TYPE_SYSTEM);
53+
try {
54+
Pair<SqlValidator, CalciteCatalogReader> pair = this.registerCreateTables(creates);
55+
this.validator = pair.left;
56+
this.catalog = pair.right;
57+
} catch (SqlParseException e) {
58+
throw new RuntimeException(e);
59+
}
4160
VolcanoPlanner planner = new VolcanoPlanner(RelOptCostImpl.FACTORY, Contexts.EMPTY_CONTEXT);
4261
cluster = RelOptCluster.create(planner, new RexBuilder(factory));
4362
}
@@ -51,8 +70,6 @@ public RelRoot parse(String sql) {
5170
() ->
5271
new RelMetadataQuery(
5372
new ProxyingMetadataHandlerProvider(DefaultRelMetadataProvider.INSTANCE)));
54-
SqlValidator validator =
55-
new Validator(catalog, cluster.getTypeFactory(), SqlValidator.Config.DEFAULT);
5673

5774
SqlToRelConverter.Config converterConfig =
5875
SqlToRelConverter.config().withTrimUnusedFields(true).withExpand(false);

Diff for: isthmus/src/test/java/io/substrait/isthmus/SubstraitExpressionConverterTest.java

+64-8
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
package io.substrait.isthmus;
22

33
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
5+
import static org.junit.jupiter.api.Assertions.assertTrue;
46

57
import io.substrait.dsl.SubstraitBuilder;
8+
import io.substrait.expression.Expression;
69
import io.substrait.isthmus.expression.ExpressionRexConverter;
7-
import io.substrait.isthmus.expression.ScalarFunctionConverter;
8-
import io.substrait.isthmus.expression.WindowFunctionConverter;
910
import io.substrait.relation.Rel;
11+
import io.substrait.relation.Rel.Remap;
1012
import io.substrait.type.Type;
1113
import io.substrait.type.TypeCreator;
1214
import java.util.List;
15+
import org.apache.calcite.rel.RelNode;
1316
import org.apache.calcite.rel.type.RelDataType;
17+
import org.apache.calcite.rex.RexNode;
18+
import org.apache.calcite.rex.RexSubQuery;
19+
import org.apache.calcite.sql.SqlKind;
1420
import org.junit.jupiter.api.Test;
1521

1622
public class SubstraitExpressionConverterTest extends PlanTestBase {
@@ -20,17 +26,19 @@ public class SubstraitExpressionConverterTest extends PlanTestBase {
2026

2127
final SubstraitBuilder b = new SubstraitBuilder(extensions);
2228

23-
final ExpressionRexConverter converter =
24-
new ExpressionRexConverter(
25-
typeFactory,
26-
new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory),
27-
new WindowFunctionConverter(extensions.windowFunctions(), typeFactory),
28-
TypeConverter.DEFAULT);
29+
final ExpressionRexConverter converter;
2930

3031
final List<Type> commonTableType = List.of(R.I32, R.FP32, N.STRING, N.BOOLEAN);
3132
final Rel commonTable =
3233
b.namedScan(List.of("example"), List.of("a", "b", "c", "d"), commonTableType);
3334

35+
final SubstraitRelNodeConverter relNodeConverter =
36+
new SubstraitRelNodeConverter(extensions, typeFactory, builder);
37+
38+
public SubstraitExpressionConverterTest() {
39+
converter = relNodeConverter.expressionRexConverter;
40+
}
41+
3442
@Test
3543
public void switchExpression() {
3644
var expr =
@@ -43,6 +51,54 @@ public void switchExpression() {
4351
assertTypeMatch(calciteExpr.getType(), N.BOOLEAN);
4452
}
4553

54+
@Test
55+
public void scalarSubQuery() {
56+
/*
57+
* equivalent to:
58+
*
59+
* select
60+
* r_regionkey
61+
* from
62+
* region
63+
* where
64+
* r_name = 'EUROPE'
65+
*/
66+
Rel subQueryRel =
67+
b.project(
68+
input -> {
69+
return List.of(b.fieldReference(input, 0));
70+
},
71+
// currently, all columns of the input are emitted first and then the expressions of
72+
// this project unless remap is configured
73+
Remap.of(List.of(3)),
74+
b.filter(
75+
input -> {
76+
return b.equal(
77+
b.fieldReference(input, 1),
78+
Expression.StrLiteral.builder().nullable(false).value("EUROPE").build());
79+
},
80+
b.namedScan(
81+
List.of("REGION"),
82+
List.of("r_regionkey", "r_name", "r_comment"),
83+
List.of(
84+
TypeCreator.REQUIRED.I64,
85+
TypeCreator.NULLABLE.fixedChar(25),
86+
TypeCreator.NULLABLE.varChar(152)))));
87+
88+
Expression.ScalarSubquery expr =
89+
Expression.ScalarSubquery.builder()
90+
.type(TypeCreator.REQUIRED.I64)
91+
.input(subQueryRel)
92+
.build();
93+
94+
RexNode calciteExpr = expr.accept(converter);
95+
assertEquals(SqlKind.SCALAR_QUERY, calciteExpr.getKind());
96+
assertInstanceOf(RexSubQuery.class, calciteExpr);
97+
98+
RelNode subQueryRelNode = subQueryRel.accept(relNodeConverter);
99+
assertTrue(subQueryRelNode.deepEquals(((RexSubQuery) calciteExpr).rel));
100+
}
101+
46102
void assertTypeMatch(RelDataType actual, Type expected) {
47103
Type type = TypeConverter.DEFAULT.toSubstrait(actual);
48104
assertEquals(expected, type);

0 commit comments

Comments
 (0)