Skip to content

Commit 1a9fd66

Browse files
committed
fix tpcds query 67
1 parent 9c0248a commit 1a9fd66

File tree

2 files changed

+58
-2
lines changed

2 files changed

+58
-2
lines changed

isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,9 +244,65 @@ public Rel visit(org.apache.calcite.rel.core.Minus minus) {
244244
return Set.builder().inputs(inputs).setOp(setOp).build();
245245
}
246246

247+
/**
248+
* Pre-processes the input to an Aggregate relation to handle nullability changes introduced by
249+
* ROLLUP/CUBE/GROUPING SETS.
250+
*
251+
* @param aggregate The original Calcite aggregate node.
252+
* @return A Substrait Rel node that is correctly typed to be the input to the Substrait
253+
* Aggregate.
254+
*/
255+
private Rel handleRollupCorrection(org.apache.calcite.rel.core.Aggregate aggregate) {
256+
Rel originalInput = apply(aggregate.getInput());
257+
258+
// Determine the correct final output type for the aggregate, which accounts for nullability.
259+
NamedStruct aggregateOutputType = typeConverter.toNamedStruct(aggregate.getRowType());
260+
List<Integer> groupKeyIndices = aggregate.getGroupSet().asList();
261+
262+
// Create a list of expressions to cast the original input to the correct final type if needed.
263+
List<Expression> castExpressions = new ArrayList<>();
264+
265+
boolean needsCasting = false;
266+
for (int i = 0; i < originalInput.getRecordType().fields().size(); i++) {
267+
Expression fieldReference = FieldReference.newInputRelReference(i, originalInput);
268+
269+
if (groupKeyIndices.contains(i)) {
270+
int groupKeyOutputIndex = groupKeyIndices.indexOf(i);
271+
Type finalType = aggregateOutputType.struct().fields().get(groupKeyOutputIndex);
272+
273+
if (finalType.nullable() && !fieldReference.getType().nullable()) {
274+
needsCasting = true; // Mark that a cast is necessary.
275+
castExpressions.add(
276+
Expression.Cast.builder()
277+
.type(finalType)
278+
.input(fieldReference)
279+
.failureBehavior(Expression.FailureBehavior.RETURN_NULL)
280+
.build());
281+
} else {
282+
castExpressions.add(fieldReference);
283+
}
284+
} else {
285+
castExpressions.add(fieldReference);
286+
}
287+
}
288+
289+
// Only add the extra Project node if a cast was actually needed.
290+
if (needsCasting) {
291+
int originalFieldCount = originalInput.getRecordType().fields().size();
292+
return Project.builder()
293+
.input(originalInput)
294+
.expressions(castExpressions)
295+
.remap(Rel.Remap.offset(originalFieldCount, castExpressions.size()))
296+
.build();
297+
}
298+
299+
// If no casting was needed, just return the original converted input.
300+
return originalInput;
301+
}
302+
247303
@Override
248304
public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) {
249-
Rel input = apply(aggregate.getInput());
305+
Rel input = handleRollupCorrection(aggregate);
250306
Stream<ImmutableBitSet> sets;
251307
if (aggregate.groupSets != null) {
252308
sets = aggregate.groupSets.stream();

isthmus/src/test/java/io/substrait/isthmus/TpcdsQueryTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
public class TpcdsQueryTest extends PlanTestBase {
1616
private static final Set<Integer> toSubstraitExclusions = Set.of(9, 27, 36, 70, 86);
1717
private static final Set<Integer> fromSubstraitPojoExclusions = Set.of(1, 30, 81);
18-
private static final Set<Integer> fromSubstraitProtoExclusions = Set.of(1, 30, 67, 81);
18+
private static final Set<Integer> fromSubstraitProtoExclusions = Set.of(1, 30, 81);
1919

2020
static IntStream testCases() {
2121
return IntStream.rangeClosed(1, 99).filter(n -> !toSubstraitExclusions.contains(n));

0 commit comments

Comments
 (0)