@@ -244,9 +244,65 @@ public Rel visit(org.apache.calcite.rel.core.Minus minus) {
244244 return Set .builder ().inputs (inputs ).setOp (setOp ).build ();
245245 }
246246
247+ /**
248+ * Pre-processes the input to an Aggregate relation to handle nullability changes introduced by
249+ * ROLLUP/CUBE/GROUPING SETS.
250+ *
251+ * @param aggregate The original Calcite aggregate node.
252+ * @return A Substrait Rel node that is correctly typed to be the input to the Substrait
253+ * Aggregate.
254+ */
255+ private Rel handleRollupCorrection (org .apache .calcite .rel .core .Aggregate aggregate ) {
256+ Rel originalInput = apply (aggregate .getInput ());
257+
258+ // Determine the correct final output type for the aggregate, which accounts for nullability.
259+ NamedStruct aggregateOutputType = typeConverter .toNamedStruct (aggregate .getRowType ());
260+ List <Integer > groupKeyIndices = aggregate .getGroupSet ().asList ();
261+
262+ // Create a list of expressions to cast the original input to the correct final type if needed.
263+ List <Expression > castExpressions = new ArrayList <>();
264+
265+ boolean needsCasting = false ;
266+ for (int i = 0 ; i < originalInput .getRecordType ().fields ().size (); i ++) {
267+ Expression fieldReference = FieldReference .newInputRelReference (i , originalInput );
268+
269+ if (groupKeyIndices .contains (i )) {
270+ int groupKeyOutputIndex = groupKeyIndices .indexOf (i );
271+ Type finalType = aggregateOutputType .struct ().fields ().get (groupKeyOutputIndex );
272+
273+ if (finalType .nullable () && !fieldReference .getType ().nullable ()) {
274+ needsCasting = true ; // Mark that a cast is necessary.
275+ castExpressions .add (
276+ Expression .Cast .builder ()
277+ .type (finalType )
278+ .input (fieldReference )
279+ .failureBehavior (Expression .FailureBehavior .RETURN_NULL )
280+ .build ());
281+ } else {
282+ castExpressions .add (fieldReference );
283+ }
284+ } else {
285+ castExpressions .add (fieldReference );
286+ }
287+ }
288+
289+ // Only add the extra Project node if a cast was actually needed.
290+ if (needsCasting ) {
291+ int originalFieldCount = originalInput .getRecordType ().fields ().size ();
292+ return Project .builder ()
293+ .input (originalInput )
294+ .expressions (castExpressions )
295+ .remap (Rel .Remap .offset (originalFieldCount , castExpressions .size ()))
296+ .build ();
297+ }
298+
299+ // If no casting was needed, just return the original converted input.
300+ return originalInput ;
301+ }
302+
247303 @ Override
248304 public Rel visit (org .apache .calcite .rel .core .Aggregate aggregate ) {
249- Rel input = apply (aggregate . getInput () );
305+ Rel input = handleRollupCorrection (aggregate );
250306 Stream <ImmutableBitSet > sets ;
251307 if (aggregate .groupSets != null ) {
252308 sets = aggregate .groupSets .stream ();
0 commit comments