Skip to content

Commit 62f552d

Browse files
committed
fix(isthmus): fix nullability for grouping sets
1 parent 9c0248a commit 62f552d

File tree

3 files changed

+34
-11
lines changed

3 files changed

+34
-11
lines changed

core/src/main/java/io/substrait/relation/Aggregate.java

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,24 @@ public abstract class Aggregate extends SingleInputRel implements HasExtension {
2121

2222
@Override
2323
protected Type.Struct deriveRecordType() {
24-
return TypeCreator.REQUIRED.struct(
25-
Stream.concat(
26-
// unique grouping expressions
27-
getGroupings().stream()
28-
.flatMap(g -> g.getExpressions().stream())
29-
.collect(Collectors.toCollection(LinkedHashSet::new))
30-
.stream()
31-
.map(Expression::getType),
24+
boolean isGroupingSet = getGroupings().size() > 1;
3225

33-
// measures
34-
getMeasures().stream().map(t -> t.getFunction().getType())));
26+
Stream<Type> groupingTypes =
27+
getGroupings().stream()
28+
.flatMap(g -> g.getExpressions().stream())
29+
.collect(Collectors.toCollection(LinkedHashSet::new))
30+
.stream()
31+
.map(
32+
expr -> {
33+
if (isGroupingSet) {
34+
return TypeCreator.asNullable(expr.getType());
35+
}
36+
return expr.getType();
37+
});
38+
39+
Stream<Type> measureTypes = getMeasures().stream().map(t -> t.getFunction().getType());
40+
41+
return TypeCreator.REQUIRED.struct(Stream.concat(groupingTypes, measureTypes));
3542
}
3643

3744
@Override

isthmus/src/test/java/io/substrait/isthmus/Substrait2SqlTest.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package io.substrait.isthmus;
22

3+
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
34
import static org.junit.jupiter.api.Assertions.assertEquals;
45
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
56

67
import io.substrait.isthmus.utils.SetUtils;
8+
import io.substrait.plan.Plan;
79
import io.substrait.relation.Set;
810
import org.apache.calcite.rel.RelNode;
911
import org.apache.calcite.rel.RelRoot;
@@ -14,6 +16,14 @@
1416
import org.junit.jupiter.params.provider.MethodSource;
1517

1618
public class Substrait2SqlTest extends PlanTestBase {
19+
private void assertSqlRoundTripViaPojoAndProto(String inputSql) {
20+
Plan plan =
21+
assertDoesNotThrow(() -> toSubstraitPlan(inputSql, TPCH_CATALOG), "SQL to Substrait POJO");
22+
assertDoesNotThrow(() -> toSql(plan), "Substrait POJO to SQL");
23+
io.substrait.proto.Plan proto =
24+
assertDoesNotThrow(() -> toProto(plan), "Substrait POJO to Substrait PROTO");
25+
assertDoesNotThrow(() -> toSql(proto), "Substrait PROTO to SQL");
26+
}
1727

1828
@Test
1929
public void simpleTest() throws Exception {
@@ -84,6 +94,12 @@ public void simpleTestGroupingSets() throws Exception {
8494
"select sum(l_discount) from lineitem group by grouping sets ((l_orderkey, L_COMMITDATE), (l_orderkey, L_COMMITDATE, l_linestatus), l_shipdate, ())");
8595
}
8696

97+
@Test
98+
void testRollup() {
99+
assertSqlRoundTripViaPojoAndProto(
100+
"select charcol from (select charcol, count(*) from (values('a')) as t(charcol) group by rollup(charcol))");
101+
}
102+
87103
@Test
88104
public void simpleTestAggFilter() throws Exception {
89105
assertSqlSubstraitRelRoundTrip(

isthmus/src/test/java/io/substrait/isthmus/TpcdsQueryTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
public class TpcdsQueryTest extends PlanTestBase {
1616
private static final Set<Integer> toSubstraitExclusions = Set.of(9, 27, 36, 70, 86);
1717
private static final Set<Integer> fromSubstraitPojoExclusions = Set.of(1, 30, 81);
18-
private static final Set<Integer> fromSubstraitProtoExclusions = Set.of(1, 30, 67, 81);
18+
private static final Set<Integer> fromSubstraitProtoExclusions = Set.of(1, 30, 81);
1919

2020
static IntStream testCases() {
2121
return IntStream.rangeClosed(1, 99).filter(n -> !toSubstraitExclusions.contains(n));

0 commit comments

Comments
 (0)