diff --git a/integration_tests/src/main/python/array_test.py b/integration_tests/src/main/python/array_test.py index db9a22885b7..ecd891607b7 100644 --- a/integration_tests/src/main/python/array_test.py +++ b/integration_tests/src/main/python/array_test.py @@ -721,6 +721,23 @@ def do_it(spark): assert_gpu_and_cpu_are_equal_collect(do_it) +@disable_ansi_mode +def test_array_heterogeneous_elementwise_hof_mixed_project(): + data_gen = ArrayGen(IntegerGen(min_val=-10, max_val=10), max_length=8) + def do_it(spark): + outer_gen = IntegerGen(min_val=-5, max_val=5) + return three_col_df(spark, data_gen, outer_gen, outer_gen).selectExpr( + 'a', + 'b', + 'c', + 'transform(a, item -> item + b) as plus_b', + 'transform(a, item -> item + c) as plus_c', + 'filter(a, item -> item is not null and item + b >= c) as filtered_b_ge_c', + 'exists(a, item -> item is not null and item + c < b) as has_c_less_b') + + assert_gpu_and_cpu_are_equal_collect(do_it) + + array_zips_gen = array_gens_sample + [ArrayGen(map_string_string_gen[0], max_length=5), ArrayGen(BinaryGen(max_length=5), max_length=5)] diff --git a/integration_tests/src/main/python/higher_order_functions_test.py b/integration_tests/src/main/python/higher_order_functions_test.py index 55151a3a562..c5a3a290c8d 100644 --- a/integration_tests/src/main/python/higher_order_functions_test.py +++ b/integration_tests/src/main/python/higher_order_functions_test.py @@ -126,6 +126,34 @@ def test_array_aggregate_count_if_int(): 'aggregate(a, 0L, (acc, x) -> acc + CAST(CASE WHEN x IS NULL THEN 1 ELSE 0 END as BIGINT)) as null_cnt')) +@disable_ansi_mode +def test_array_hof_mixed_project_with_aggregate(): + data_gen = ArrayGen(IntegerGen(min_val=-10, max_val=10), max_length=8) + def do_it(spark): + return unary_op_df(spark, data_gen).selectExpr( + 'transform(a, x -> x + 1) as plus_one', + 'filter(a, x -> x is not null and x >= 0) as non_negative', + 'exists(a, x -> x is not null and x < 0) as has_negative', + '''aggregate(a, 0L, + (acc, x) -> acc + CAST(CASE WHEN x IS NULL THEN 0 ELSE x END AS BIGINT)) + as sum_or_zero''') + + assert_gpu_and_cpu_are_equal_collect(do_it) + + +@disable_ansi_mode +def test_array_hof_mixed_project_with_indexed_lambdas(): + data_gen = ArrayGen(IntegerGen(min_val=-10, max_val=10), max_length=8) + outer_gen = IntegerGen(min_val=-3, max_val=3, nullable=False) + def do_it(spark): + return two_col_df(spark, data_gen, outer_gen).selectExpr( + 'transform(a, (x, i) -> coalesce(x, 0) + i + b) as indexed_add', + 'filter(a, (x, i) -> x is not null and x + i + b >= 0) as indexed_filter', + 'transform(a, (x, i) -> i - coalesce(x, 0)) as index_minus_value') + + assert_gpu_and_cpu_are_equal_collect(do_it) + + # `if(cond, acc + t, acc)` shape — branches lifted via op identity. Same count-if # pattern as above but written naturally instead of using `CASE WHEN ... THEN 1 ELSE 0`. @disable_ansi_mode diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala index afab2a5ae09..d75e00a4ffe 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala @@ -131,8 +131,10 @@ object GpuProjectExec { // different vector length, thus not able to reuse cached vectors. GpuExpressionsUtils.cachedNullVectors.get.clear() - val newColumns = boundExprs.safeMap(_.columnarEval(cb)).toArray[ColumnVector] - new ColumnarBatch(newColumns, cb.numRows()) + GpuArrayHofFusion.project(cb, boundExprs).getOrElse { + val newColumns = boundExprs.safeMap(_.columnarEval(cb)).toArray[ColumnVector] + new ColumnarBatch(newColumns, cb.numRows()) + } } finally { GpuExpressionsUtils.cachedNullVectors.get.clear() } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala index ae23430a90b..dd36f0a8034 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala @@ -21,7 +21,8 @@ import scala.collection.mutable import ai.rapids.cudf import ai.rapids.cudf.{DType, Table} import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} -import com.nvidia.spark.rapids.RapidsPluginImplicits.ReallyAGpuExpression +import com.nvidia.spark.rapids.RapidsPluginImplicits.{AutoCloseableProducingSeq, + ReallyAGpuExpression} import com.nvidia.spark.rapids.jni.GpuMapZipWithUtils import com.nvidia.spark.rapids.shims.ShimExpression @@ -29,7 +30,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.expressions.{Add, And, ArrayAggregate, Attribute, AttributeReference, AttributeSeq, CaseWhen, Cast, Expression, ExprId, Greatest, If, LambdaFunction, Least, Literal, Multiply, NamedExpression, NamedLambdaVariable, Or} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, Metadata, NumericType, ShortType, StructField, StructType} -import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} /** * A named lambda variable. In Spark on the CPU this includes an AtomicReference to the value that @@ -228,59 +229,71 @@ trait GpuArrayTransformBase extends GpuSimpleHigherOrderFunction { boundIntermediate.map(_.dataType) ++ lambdaFunction.arguments.map(_.dataType) } + private[rapids] def lambdaArgumentCount: Int = lambdaFunction.arguments.length + + private[rapids] def lambdaInputTypes: Seq[DataType] = inputToLambda + protected def makeElementProjectBatch( inputBatch: ColumnarBatch, argColumn: GpuColumnVector): ColumnarBatch = { - assert(argColumn.getBase.getType.equals(DType.LIST)) assert(isBound, "Trying to execute an un-bound transform expression") + GpuArrayTransformBase.makeExplodedElementBatch( + inputBatch, argColumn, boundIntermediate, inputToLambda, lambdaArgumentCount) + } + +} + +private[rapids] object GpuArrayTransformBase { + private[rapids] def makeExplodedElementBatch( + inputBatch: ColumnarBatch, + argColumn: GpuColumnVector, + intermediate: Seq[GpuExpression], + elementTypes: Seq[DataType], + lambdaArgumentCount: Int): ColumnarBatch = { + assert(argColumn.getBase.getType.equals(DType.LIST)) def projectAndExplode(explodeOp: Table => Table): Table = { - withResource(GpuProjectExec.project(inputBatch, boundIntermediate)) { - intermediateBatch => - withResource(GpuColumnVector.appendColumns(intermediateBatch, argColumn)) { - projectedBatch => - withResource(GpuColumnVector.from(projectedBatch)) { projectedTable => - explodeOp(projectedTable) - } - } + withResource(GpuProjectExec.project(inputBatch, intermediate)) { intermediateBatch => + withResource(GpuColumnVector.appendColumns(intermediateBatch, argColumn)) { + projectedBatch => + withResource(GpuColumnVector.from(projectedBatch)) { projectedTable => + explodeOp(projectedTable) + } + } } } - if (function.asInstanceOf[GpuLambdaFunction].arguments.length >= 2) { - // Need to do an explodePosition + if (lambdaArgumentCount >= 2) { val explodedTable = projectAndExplode { projectedTable => - projectedTable.explodePosition(boundIntermediate.length) + projectedTable.explodePosition(intermediate.length) } val reorderedTable = withResource(explodedTable) { explodedTable => - // The column order is wrong after an explodePosition. It is - // [other_columns*, position, entry] - // but we want - // [other_columns*, entry, position] - // So we have to remap it - val cols = new Array[cudf.ColumnVector](explodedTable.getNumberOfColumns) - val numOtherColumns = explodedTable.getNumberOfColumns - 2 - (0 until numOtherColumns).foreach { index => - cols(index) = explodedTable.getColumn(index) - } - cols(numOtherColumns) = explodedTable.getColumn(numOtherColumns + 1) - cols(numOtherColumns + 1) = explodedTable.getColumn(numOtherColumns) - - new cudf.Table(cols: _*) + reorderExplodePositionOutput(explodedTable) } withResource(reorderedTable) { reorderedTable => - GpuColumnVector.from(reorderedTable, inputToLambda.toArray) + GpuColumnVector.from(reorderedTable, elementTypes.toArray) } } else { - // Need to do an explode val explodedTable = projectAndExplode { projectedTable => - projectedTable.explode(boundIntermediate.length) + projectedTable.explode(intermediate.length) } withResource(explodedTable) { explodedTable => - GpuColumnVector.from(explodedTable, inputToLambda.toArray) + GpuColumnVector.from(explodedTable, elementTypes.toArray) } } } + private def reorderExplodePositionOutput(explodedTable: Table): Table = { + val cols = new Array[cudf.ColumnVector](explodedTable.getNumberOfColumns) + val numOtherColumns = explodedTable.getNumberOfColumns - 2 + (0 until numOtherColumns).foreach { index => + cols(index) = explodedTable.getColumn(index) + } + cols(numOtherColumns) = explodedTable.getColumn(numOtherColumns + 1) + cols(numOtherColumns + 1) = explodedTable.getColumn(numOtherColumns) + + new cudf.Table(cols: _*) + } } /** @@ -298,18 +311,213 @@ trait GpuArrayElementWiseTransform extends GpuArrayTransformBase { lambdaTransformedCV: cudf.ColumnView, arg: cudf.ColumnView): GpuColumnVector + private[rapids] def transformElementResults( + lambdaTransformed: GpuColumnVector, + arg: GpuColumnVector): GpuColumnVector = { + withResource(GpuListUtils.replaceListDataColumnAsView( + arg.getBase, lambdaTransformed.getBase)) { cv => + transformListColumnView(cv, arg.getBase) + } + } + override def columnarEval(batch: ColumnarBatch): GpuColumnVector = { withResource(argument.columnarEval(batch)) { arg => val dataCol = withResource(makeElementProjectBatch(batch, arg)) { cb => function.columnarEval(cb) } withResource(dataCol) { _ => - val cv = GpuListUtils.replaceListDataColumnAsView(arg.getBase, dataCol.getBase) - withResource(cv) { cv => - transformListColumnView(cv, arg.getBase) + transformElementResults(dataCol, arg) + } + } + } +} + +private[rapids] object GpuArrayHofFusion { + private case class HofInProject(index: Int, hof: GpuArrayTransformBase) + private case class HofGroup(hofs: Seq[HofInProject]) { + val startIndex: Int = hofs.head.index + val outputIndexes: Seq[Int] = hofs.map(_.index) + lazy val transforms: Seq[GpuArrayTransformBase] = hofs.map(_.hof) + lazy val first: GpuArrayTransformBase = transforms.head + lazy val unionIntermediate: Seq[GpuExpression] = collectUnionIntermediate(transforms) + lazy val elementTypes: Seq[DataType] = { + val lambdaArgTypes = first.lambdaInputTypes.drop(first.boundIntermediate.length) + unionIntermediate.map(_.dataType) ++ lambdaArgTypes + } + } + + def project( + batch: ColumnarBatch, + boundExprs: Seq[Expression]): Option[ColumnarBatch] = { + val fusedGroups = findFusedGroups(boundExprs) + if (fusedGroups.isEmpty) { + None + } else { + Some(projectWithFusedGroups(batch, boundExprs, fusedGroups)) + } + } + + private[rapids] def findFusedGroupIndexes( + boundExprs: Seq[Expression]): Seq[Seq[Int]] = + findFusedGroups(boundExprs).map(_.outputIndexes) + + private def findFusedGroups( + boundExprs: Seq[Expression]): Seq[HofGroup] = { + val fusedGroups = mutable.ArrayBuffer[HofGroup]() + val groups = mutable.ArrayBuffer[mutable.ArrayBuffer[HofInProject]]() + + def flushGroups(): Unit = { + fusedGroups ++= groups.filter(_.length > 1).map(g => HofGroup(g.toSeq)) + groups.clear() + } + + boundExprs.zipWithIndex.foreach { + case (expr, index) => + extractHof(expr).filter(canFuse) match { + case Some(hof) => + groups.find(g => canShareExplode(g.head.hof, hof)) match { + case Some(group) => group += HofInProject(index, hof) + case None => groups += mutable.ArrayBuffer(HofInProject(index, hof)) + } + case None if !canReorderExpression(expr) => + flushGroups() + case None => + } + } + flushGroups() + fusedGroups.toSeq + } + + private def extractHof( + expr: Expression): Option[GpuArrayTransformBase] = expr match { + case GpuAlias(transform: GpuArrayTransformBase, _) => Some(transform) + case transform: GpuArrayTransformBase => Some(transform) + case _ => None + } + + private def canFuse(transform: GpuArrayTransformBase): Boolean = { + isSupportedTransform(transform) && + transform.isBound && + transform.deterministic && + !transform.hasSideEffects && + transform.argument.deterministic && + transform.boundIntermediate.forall(_.deterministic) && + (transform.lambdaArgumentCount == 1 || transform.lambdaArgumentCount == 2) + } + + private def isSupportedTransform(transform: GpuArrayTransformBase): Boolean = transform match { + case _: GpuArrayElementWiseTransform | _: GpuArrayAggregate => true + case _ => false + } + + private def canShareExplode( + left: GpuArrayTransformBase, + right: GpuArrayTransformBase): Boolean = { + left.lambdaArgumentCount == right.lambdaArgumentCount && + left.argument.semanticEquals(right.argument) + } + + private def canReorderExpression(expr: Expression): Boolean = + expr.deterministic && (expr match { + case gpuExpr: GpuExpression => !gpuExpr.hasSideEffects + case _ => false + }) + + private def projectWithFusedGroups( + batch: ColumnarBatch, + boundExprs: Seq[Expression], + fusedGroups: Seq[HofGroup]): ColumnarBatch = { + val outputColumns = new Array[ColumnVector](boundExprs.length) + val groupsByStartIndex = fusedGroups.map(group => group.startIndex -> group).toMap + closeOnExcept(outputColumns) { _ => + boundExprs.indices.foreach { index => + if (outputColumns(index) == null) { + groupsByStartIndex.get(index) match { + case Some(group) => + val transformed = evaluateFusedGroup(batch, group).toArray[ColumnVector] + closeOnExcept(transformed) { _ => + group.hofs.zipWithIndex.foreach { + case (HofInProject(outputIndex, _), transformedIndex) => + outputColumns(outputIndex) = transformed(transformedIndex) + transformed(transformedIndex) = null + } + } + case None => + outputColumns(index) = boundExprs(index).columnarEval(batch) + } } } + new ColumnarBatch(outputColumns, batch.numRows()) + } + } + + private def evaluateFusedGroup( + batch: ColumnarBatch, + group: HofGroup): Seq[GpuColumnVector] = { + withResource(group.first.argument.columnarEval(batch)) { arg => + withResource(GpuArrayTransformBase.makeExplodedElementBatch( + batch, arg, group.unionIntermediate, group.elementTypes, + group.first.lambdaArgumentCount)) { sharedBatch => + val output = mutable.ArrayBuffer[GpuColumnVector]() + closeOnExcept(output) { _ => + group.transforms.foreach { transform => + val dataCol = withResource(makeTransformLambdaBatch(sharedBatch, group, transform)) { + lambdaBatch => + transform.function.columnarEval(lambdaBatch) + } + withResource(dataCol) { _ => + output += consumeElementResults(batch, transform, dataCol, arg) + } + } + output.toSeq + } + } + } + } + + private def consumeElementResults( + batch: ColumnarBatch, + transform: GpuArrayTransformBase, + dataCol: GpuColumnVector, + arg: GpuColumnVector): GpuColumnVector = transform match { + case elementWise: GpuArrayElementWiseTransform => + elementWise.transformElementResults(dataCol, arg) + case aggregate: GpuArrayAggregate => + aggregate.aggregateElementResults(batch, dataCol, arg) + case other => + throw new IllegalStateException( + s"Unsupported array transform fusion expression: ${other.getClass.getName}") + } + + private def collectUnionIntermediate( + transforms: Seq[GpuArrayTransformBase]): Seq[GpuExpression] = { + val unionIntermediate = mutable.ArrayBuffer[GpuExpression]() + transforms.foreach { transform => + transform.boundIntermediate.foreach { expr => + if (!unionIntermediate.exists(_.semanticEquals(expr))) { + unionIntermediate += expr + } + } + } + unionIntermediate.toSeq + } + + private def makeTransformLambdaBatch( + sharedBatch: ColumnarBatch, + group: HofGroup, + transform: GpuArrayTransformBase): ColumnarBatch = { + val lambdaArgStart = group.unionIntermediate.length + val intermediateIndexes = transform.boundIntermediate.map { expr => + val index = group.unionIntermediate.indexWhere(_.semanticEquals(expr)) + assert(index >= 0, s"Missing shared transform intermediate: $expr") + index } + val indexes = intermediateIndexes ++ + (lambdaArgStart until sharedBatch.numCols()) + val columns = indexes.safeMap { index => + sharedBatch.column(index).asInstanceOf[GpuColumnVector].incRefCount() + }.toArray[ColumnVector] + new ColumnarBatch(columns, sharedBatch.numRows()) } } @@ -1389,61 +1597,67 @@ case class GpuArrayAggregate( } } - override def columnarEval(batch: ColumnarBatch): GpuColumnVector = { + private[rapids] def aggregateElementResults( + batch: ColumnarBatch, + transformedData: GpuColumnVector, + arg: GpuColumnVector): GpuColumnVector = { val outDType = GpuColumnVector.getNonNestedRapidsType(dataType) - withResource(argument.asInstanceOf[GpuExpression].columnarEval(batch)) { arg => - // Step 1: g(x) over children + segmented reduce. - val reduced: cudf.ColumnVector = - withResource(makeElementProjectBatch(batch, arg)) { cb => - withResource(function.asInstanceOf[GpuExpression].columnarEval(cb)) { - transformedData => - withResource(GpuListUtils.replaceListDataColumnAsView( - arg.getBase, transformedData.getBase)) { listOfGView => - listOfGView.listReduce(op.cudfAgg, op.nullPolicy, outDType) - } + // Step 1: g(x) over children + segmented reduce. + val reduced: cudf.ColumnVector = withResource(GpuListUtils.replaceListDataColumnAsView( + arg.getBase, transformedData.getBase)) { listOfGView => + listOfGView.listReduce(op.cudfAgg, op.nullPolicy, outDType) + } + + // Step 2: substitute op's identity where needed. + val adjusted: cudf.ColumnVector = withResource(reduced) { reduced => + if (op.nullPolicy == cudf.NullPolicy.EXCLUDE) { + // MAX/MIN should keep no-contribution rows as null until combineWithZero. + // NULL_MAX/NULL_MIN then return zero when it is non-null, and preserve a null zero. + reduced.incRefCount() + } else { + withResource(emptyIncludeListMask(arg.getBase)) { mask => + withResource(op.identityScalar(dataType)) { idScalar => + mask.ifElse(idScalar, reduced) } } + } + } - // Step 2: substitute op's identity where needed. - val adjusted: cudf.ColumnVector = withResource(reduced) { reduced => - if (op.nullPolicy == cudf.NullPolicy.EXCLUDE) { - // MAX/MIN should keep no-contribution rows as null until combineWithZero. - // NULL_MAX/NULL_MIN then return zero when it is non-null, and preserve a null zero. - reduced.incRefCount() - } else { - withResource(emptyIncludeListMask(arg.getBase)) { mask => - withResource(op.identityScalar(dataType)) { idScalar => - mask.ifElse(idScalar, reduced) - } + // Step 3: combine with zero. When `zero` is a Literal (the common 4-arg + // `aggregate(arr, 0, ...)` shape) skip the per-batch column broadcast and pass a + // cudf.Scalar instead — `add/mul/and/or/binaryOp` all accept BinaryOperable. + val combined: cudf.ColumnVector = withResource(adjusted) { adjusted => + zero match { + case lit: GpuLiteral => + withResource(GpuScalar.from(lit.value, lit.dataType)) { zeroScalar => + op.combineWithZero(adjusted, zeroScalar, outDType) + } + case _ => + withResource(zero.asInstanceOf[GpuExpression].columnarEval(batch)) { zeroCv => + op.combineWithZero(adjusted, zeroCv.getBase, outDType) } - } } + } - // Step 3: combine with zero. When `zero` is a Literal (the common 4-arg - // `aggregate(arr, 0, ...)` shape) skip the per-batch column broadcast and pass a - // cudf.Scalar instead — `add/mul/and/or/binaryOp` all accept BinaryOperable. - val combined: cudf.ColumnVector = withResource(adjusted) { adjusted => - zero match { - case lit: GpuLiteral => - withResource(GpuScalar.from(lit.value, lit.dataType)) { zeroScalar => - op.combineWithZero(adjusted, zeroScalar, outDType) - } - case _ => - withResource(zero.asInstanceOf[GpuExpression].columnarEval(batch)) { zeroCv => - op.combineWithZero(adjusted, zeroCv.getBase, outDType) - } - } + // Step 4: restore null on rows where the input list itself was null. cuDF NULL_MAX / + // NULL_MIN / LOGICAL_AND / LOGICAL_OR don't propagate null the way Spark's 3VL would, + // so the combine step alone can't preserve it. Skip outright when the list has no nulls. + if (arg.getBase.getNullCount > 0) { + withResource(combined) { combined => + GpuColumnVector.from(NullUtilities.mergeNulls(combined, arg.getBase), dataType) } + } else { + GpuColumnVector.from(combined, dataType) + } + } - // Step 4: restore null on rows where the input list itself was null. cuDF NULL_MAX / - // NULL_MIN / LOGICAL_AND / LOGICAL_OR don't propagate null the way Spark's 3VL would, - // so the combine step alone can't preserve it. Skip outright when the list has no nulls. - if (arg.getBase.getNullCount > 0) { - withResource(combined) { combined => - GpuColumnVector.from(NullUtilities.mergeNulls(combined, arg.getBase), dataType) - } - } else { - GpuColumnVector.from(combined, dataType) + override def columnarEval(batch: ColumnarBatch): GpuColumnVector = { + withResource(argument.asInstanceOf[GpuExpression].columnarEval(batch)) { arg => + val transformedData = withResource(makeElementProjectBatch(batch, arg)) { cb => + function.asInstanceOf[GpuExpression].columnarEval(cb) + } + withResource(transformedData) { _ => + aggregateElementResults(batch, transformedData, arg) } } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/GpuArrayHofFusionSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/GpuArrayHofFusionSuite.scala new file mode 100644 index 00000000000..3fad64c3cc1 --- /dev/null +++ b/tests/src/test/scala/com/nvidia/spark/rapids/GpuArrayHofFusionSuite.scala @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId} +import org.apache.spark.sql.types.{ArrayType, BooleanType, DataType, IntegerType, LongType} +import org.apache.spark.sql.vectorized.ColumnarBatch + +class GpuArrayHofFusionSuite extends GpuUnitTests { + private val arrayArg = + GpuBoundReference(0, ArrayType(IntegerType, containsNull = true), + nullable = true)(ExprId(0), "a") + private val otherArrayArg = + GpuBoundReference(1, ArrayType(IntegerType, containsNull = true), + nullable = true)(ExprId(1), "other") + private val outerB = + GpuBoundReference(2, IntegerType, nullable = true)(ExprId(2), "b") + private val outerC = + GpuBoundReference(3, IntegerType, nullable = true)(ExprId(3), "c") + + private case class NonDeterministicExpr() extends GpuLeafExpression { + override lazy val deterministic: Boolean = false + override def dataType: DataType = IntegerType + override def nullable: Boolean = false + override def columnarEval(batch: ColumnarBatch): GpuColumnVector = + throw new UnsupportedOperationException("test-only expression") + } + + private def lambda(resultType: DataType, argCount: Int = 1): GpuLambdaFunction = { + val args = (0 until argCount).map { index => + GpuNamedLambdaVariable(s"x$index", IntegerType, nullable = true, ExprId(100 + index)) + } + val value = resultType match { + case BooleanType => true + case LongType => 1L + case _ => 1 + } + GpuLambdaFunction(GpuLiteral(value, resultType), args) + } + + private def transform( + argument: Expression = arrayArg, + argCount: Int = 1, + boundIntermediate: Seq[GpuExpression] = Seq.empty): GpuArrayTransform = + GpuArrayTransform(argument, lambda(IntegerType, argCount), isBound = true, boundIntermediate) + + private def filter( + argument: Expression = arrayArg, + argCount: Int = 1, + boundIntermediate: Seq[GpuExpression] = Seq.empty): GpuArrayFilter = + GpuArrayFilter(argument, lambda(BooleanType, argCount), isBound = true, boundIntermediate) + + private def aggregate( + argument: Expression = arrayArg, + boundIntermediate: Seq[GpuExpression]): GpuArrayAggregate = + GpuArrayAggregate(argument, GpuLiteral(0L, LongType), lambda(LongType), SumOp, + isBound = true, boundIntermediate) + + private def alias(expr: GpuExpression, name: String): GpuAlias = + GpuAlias(expr, name)() + + private def literalProject(name: String): GpuAlias = + alias(GpuLiteral(1, IntegerType), name) + + private def nonDeterministicProject(name: String): GpuAlias = + alias(NonDeterministicExpr(), name) + + test("finds heterogeneous array HOFs over the same argument") { + val exprs = Seq( + alias(transform(boundIntermediate = Seq(outerB)), "t"), + literalProject("safe"), + alias(filter(boundIntermediate = Seq(outerC)), "f"), + alias(aggregate(boundIntermediate = Seq(outerB, outerC)), "a")) + + assertResult(Seq(Seq(0, 2, 3))) { + GpuArrayHofFusion.findFusedGroupIndexes(exprs) + } + } + + test("does not fuse different arguments or lambda arities") { + val exprs = Seq( + alias(transform(), "one_arg"), + alias(filter(argCount = 2), "two_arg"), + alias(transform(otherArrayArg), "other_arg")) + + assertResult(Seq.empty[Seq[Int]]) { + GpuArrayHofFusion.findFusedGroupIndexes(exprs) + } + } + + test("splits fused groups around a non-deterministic barrier") { + val exprs = Seq( + alias(transform(), "before_left"), + alias(filter(), "before_right"), + nonDeterministicProject("barrier"), + alias(transform(boundIntermediate = Seq(outerB)), "after_left"), + literalProject("safe"), + alias(filter(boundIntermediate = Seq(outerC)), "after_right")) + + assertResult(Seq(Seq(0, 1), Seq(3, 5))) { + GpuArrayHofFusion.findFusedGroupIndexes(exprs) + } + } +}