@@ -244,9 +244,65 @@ public Rel visit(org.apache.calcite.rel.core.Minus minus) {
244
244
return Set .builder ().inputs (inputs ).setOp (setOp ).build ();
245
245
}
246
246
247
+ /**
248
+ * Pre-processes the input to an Aggregate relation to handle nullability changes introduced by
249
+ * ROLLUP/CUBE/GROUPING SETS.
250
+ *
251
+ * @param aggregate The original Calcite aggregate node.
252
+ * @return A Substrait Rel node that is correctly typed to be the input to the Substrait
253
+ * Aggregate.
254
+ */
255
+ private Rel handleRollupCorrection (org .apache .calcite .rel .core .Aggregate aggregate ) {
256
+ Rel originalInput = apply (aggregate .getInput ());
257
+
258
+ // Determine the correct final output type for the aggregate, which accounts for nullability.
259
+ NamedStruct aggregateOutputType = typeConverter .toNamedStruct (aggregate .getRowType ());
260
+ List <Integer > groupKeyIndices = aggregate .getGroupSet ().asList ();
261
+
262
+ // Create a list of expressions to cast the original input to the correct final type if needed.
263
+ List <Expression > castExpressions = new ArrayList <>();
264
+
265
+ boolean needsCasting = false ;
266
+ for (int i = 0 ; i < originalInput .getRecordType ().fields ().size (); i ++) {
267
+ Expression fieldReference = FieldReference .newInputRelReference (i , originalInput );
268
+
269
+ if (groupKeyIndices .contains (i )) {
270
+ int groupKeyOutputIndex = groupKeyIndices .indexOf (i );
271
+ Type finalType = aggregateOutputType .struct ().fields ().get (groupKeyOutputIndex );
272
+
273
+ if (finalType .nullable () && !fieldReference .getType ().nullable ()) {
274
+ needsCasting = true ; // Mark that a cast is necessary.
275
+ castExpressions .add (
276
+ Expression .Cast .builder ()
277
+ .type (finalType )
278
+ .input (fieldReference )
279
+ .failureBehavior (Expression .FailureBehavior .RETURN_NULL )
280
+ .build ());
281
+ } else {
282
+ castExpressions .add (fieldReference );
283
+ }
284
+ } else {
285
+ castExpressions .add (fieldReference );
286
+ }
287
+ }
288
+
289
+ // Only add the extra Project node if a cast was actually needed.
290
+ if (needsCasting ) {
291
+ int originalFieldCount = originalInput .getRecordType ().fields ().size ();
292
+ return Project .builder ()
293
+ .input (originalInput )
294
+ .expressions (castExpressions )
295
+ .remap (Rel .Remap .offset (originalFieldCount , castExpressions .size ()))
296
+ .build ();
297
+ }
298
+
299
+ // If no casting was needed, just return the original converted input.
300
+ return originalInput ;
301
+ }
302
+
247
303
@ Override
248
304
public Rel visit (org .apache .calcite .rel .core .Aggregate aggregate ) {
249
- Rel input = apply (aggregate . getInput () );
305
+ Rel input = handleRollupCorrection (aggregate );
250
306
Stream <ImmutableBitSet > sets ;
251
307
if (aggregate .groupSets != null ) {
252
308
sets = aggregate .groupSets .stream ();
0 commit comments