Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 36 additions & 11 deletions core/src/main/java/io/substrait/relation/Aggregate.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,42 @@ public abstract class Aggregate extends SingleInputRel implements HasExtension {

@Override
protected Type.Struct deriveRecordType() {
return TypeCreator.REQUIRED.struct(
Stream.concat(
// unique grouping expressions
getGroupings().stream()
.flatMap(g -> g.getExpressions().stream())
.collect(Collectors.toCollection(LinkedHashSet::new))
.stream()
.map(Expression::getType),

// measures
getMeasures().stream().map(t -> t.getFunction().getType())));
// If there's only one grouping set (or none), the nullability rule doesn't apply.
if (getGroupings().size() <= 1) {
final Stream<Type> groupingTypes =
getGroupings().stream()
.flatMap(g -> g.getExpressions().stream())
.map(Expression::getType);
final Stream<Type> measureTypes = getMeasures().stream().map(t -> t.getFunction().getType());
return TypeCreator.REQUIRED.struct(Stream.concat(groupingTypes, measureTypes));
}

final LinkedHashSet<Expression> uniqueExpressions =
getGroupings().stream()
.flatMap(g -> g.getExpressions().stream())
.collect(Collectors.toCollection(LinkedHashSet::new));

// For each unique expression, determine its final nullability based on the spec.
final Stream<Type> groupingTypes =
uniqueExpressions.stream()
.map(
expr -> {
// the code below implements the following statement from the spec
// (https://substrait.io/relations/logical_relations/#aggregate-operation):
// "The values for the grouping expression columns that are not
// part of the grouping set for a particular record will be set to null."
final boolean appearsInAllSets =
getGroupings().stream().allMatch(g -> g.getExpressions().contains(expr));
if (appearsInAllSets) {
return expr.getType();
} else {
return TypeCreator.asNullable(expr.getType());
}
});

final Stream<Type> measureTypes = getMeasures().stream().map(t -> t.getFunction().getType());

return TypeCreator.REQUIRED.struct(Stream.concat(groupingTypes, measureTypes));
}

@Override
Expand Down
16 changes: 16 additions & 0 deletions isthmus/src/test/java/io/substrait/isthmus/Substrait2SqlTest.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package io.substrait.isthmus;

import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;

import io.substrait.isthmus.utils.SetUtils;
import io.substrait.plan.Plan;
import io.substrait.relation.Set;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelRoot;
Expand All @@ -14,6 +16,14 @@
import org.junit.jupiter.params.provider.MethodSource;

public class Substrait2SqlTest extends PlanTestBase {
private void assertSqlRoundTripViaPojoAndProto(String inputSql) {
Plan plan =
assertDoesNotThrow(() -> toSubstraitPlan(inputSql, TPCH_CATALOG), "SQL to Substrait POJO");
assertDoesNotThrow(() -> toSql(plan), "Substrait POJO to SQL");
io.substrait.proto.Plan proto =
assertDoesNotThrow(() -> toProto(plan), "Substrait POJO to Substrait PROTO");
assertDoesNotThrow(() -> toSql(proto), "Substrait PROTO to SQL");
}

@Test
public void simpleTest() throws Exception {
Expand Down Expand Up @@ -84,6 +94,12 @@ public void simpleTestGroupingSets() throws Exception {
"select sum(l_discount) from lineitem group by grouping sets ((l_orderkey, L_COMMITDATE), (l_orderkey, L_COMMITDATE, l_linestatus), l_shipdate, ())");
}

@Test
void testRollup() {
assertSqlRoundTripViaPojoAndProto(
"select charcol from (select charcol, count(*) from (values('a')) as t(charcol) group by rollup(charcol))");
}

@Test
public void simpleTestAggFilter() throws Exception {
assertSqlSubstraitRelRoundTrip(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public class TpcdsQueryTest extends PlanTestBase {
private static final Set<Integer> alternateForms = Set.of(27, 36, 70, 86);
private static final Set<Integer> toSubstraitExclusions = Set.of(9);
private static final Set<Integer> fromSubstraitPojoExclusions = Set.of(1, 30, 81);
private static final Set<Integer> fromSubstraitProtoExclusions = Set.of(1, 30, 67, 81);
private static final Set<Integer> fromSubstraitProtoExclusions = Set.of(1, 30, 81);

static IntStream testCases() {
return IntStream.rangeClosed(1, 99).filter(n -> !toSubstraitExclusions.contains(n));
Expand Down