diff --git a/delta-lake/common/src/main/delta-io/scala/org/apache/spark/sql/delta/rapids/GpuCheckDeltaInvariant.scala b/delta-lake/common/src/main/delta-io/scala/org/apache/spark/sql/delta/rapids/GpuCheckDeltaInvariant.scala index 67164017bbe..bcbbef9b9a4 100644 --- a/delta-lake/common/src/main/delta-io/scala/org/apache/spark/sql/delta/rapids/GpuCheckDeltaInvariant.scala +++ b/delta-lake/common/src/main/delta-io/scala/org/apache/spark/sql/delta/rapids/GpuCheckDeltaInvariant.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2025, NVIDIA CORPORATION. + * Copyright (c) 2022-2026, NVIDIA CORPORATION. * * This file was derived from CheckDeltaInvariant.scala in the * Delta Lake project at https://github.com/delta-io/delta. @@ -132,8 +132,8 @@ object GpuCheckDeltaInvariant extends Logging { ExprChecks.projectOnly( TypeSig.all, TypeSig.all, - paramCheck = Seq(ParamCheck("input", TypeSig.all, TypeSig.all)), - repeatingParamCheck = Some(RepeatingParamCheck("extra", TypeSig.all, TypeSig.all)) + paramCheck = Seq(new ParamCheck("input", TypeSig.all, TypeSig.all)), + repeatingParamCheck = Some(new RepeatingParamCheck("extra", TypeSig.all, TypeSig.all)) ), (c, conf, p, r) => new GpuCheckDeltaInvariantMeta(c, conf, p, r)) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala index f2ce4d8a39f..90f935ac57b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -110,7 +110,7 @@ object AstUtil { val gpuExpr = expr.convertToGpu() // Check if we've already processed this expression (for deduplication) - processed.get(GpuExpressionEquals(gpuExpr)) match { + processed.get(new GpuExpressionEquals(gpuExpr)) match { case Some(replacement) => replacement case None => @@ -135,7 +135,7 @@ object AstUtil { // Create an AttributeReference explicitly to avoid issues with unresolved aliases val attributeRef = AttributeReference(alias.name, gpuExpr.dataType, gpuExpr.nullable, alias.metadata)(alias.exprId, alias.qualifier) - processed.put(GpuExpressionEquals(gpuExpr), attributeRef) + processed.put(new GpuExpressionEquals(gpuExpr), attributeRef) attributeRef } } else { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index f82d903fc8a..086db5183cf 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -30,7 +30,6 @@ import com.nvidia.spark.rapids.shims._ import com.nvidia.spark.rapids.window.{GpuDenseRank, GpuLag, GpuLead, GpuPercentRank, GpuRank, GpuRowNumber, GpuSpecialFrameBoundary, GpuWindowExecMeta, GpuWindowSpecDefinitionMeta} import org.apache.hadoop.fs.Path -import org.apache.spark.internal.Logging import org.apache.spark.rapids.hybrid.HybridExecutionUtils import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.expressions._ @@ -457,7 +456,157 @@ object WriteFileOp extends FileFormatOp { override def toString = "write" } -object GpuOverrides extends Logging { +object GpuOverrides { + private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$")) + + private def logInfo(msg: => String): Unit = { + if (log.isInfoEnabled) { + log.info(msg) + } + } + + private def logWarning(msg: => String): Unit = { + log.warn(msg) + } + + private def confValueToString(value: Any): String = value.toString + + private def dataTypeExistsRecursively( + dataType: DataType, + f: DataType => Boolean): Boolean = { + f(dataType) || (dataType match { + case ArrayType(elementType, _) => + dataTypeExistsRecursively(elementType, f) + case MapType(keyType, valueType, _) => + dataTypeExistsRecursively(keyType, f) || dataTypeExistsRecursively(valueType, f) + case StructType(fields) => + fields.exists(field => dataTypeExistsRecursively(field.dataType, f)) + case _ => false + }) + } + + // Keep version-specific shim rule registries out of this object's constant pool. + private def shimSingleton(name: String): AnyRef = { + Class.forName("com.nvidia.spark.rapids.shims." + name + "$") + .getField("MODULE" + "$") + .get(null) + .asInstanceOf[AnyRef] + } + + private def invokeShimSingleton(name: String, method: String): Any = { + val module = shimSingleton(name) + module.getClass.getMethod(method).invoke(module) + } + + private def shimExprs( + name: String): Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = { + invokeShimSingleton(name, "exprs") + .asInstanceOf[Map[Class[_ <: Expression], ExprRule[_ <: Expression]]] + } + + private def shimExprRules( + name: String, + method: String): Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = { + invokeShimSingleton(name, method) + .asInstanceOf[Map[Class[_ <: Expression], ExprRule[_ <: Expression]]] + } + + private def shimExprRule(name: String, method: String): ExprRule[_ <: Expression] = { + invokeShimSingleton(name, method).asInstanceOf[ExprRule[_ <: Expression]] + } + + private def shimScanRules( + name: String, + method: String): Map[Class[_ <: Scan], ScanRule[_ <: Scan]] = { + invokeShimSingleton(name, method) + .asInstanceOf[Map[Class[_ <: Scan], ScanRule[_ <: Scan]]] + } + + private def shimPartRules( + name: String, + method: String): Map[Class[_ <: Partitioning], PartRule[_ <: Partitioning]] = { + invokeShimSingleton(name, method) + .asInstanceOf[Map[Class[_ <: Partitioning], PartRule[_ <: Partitioning]]] + } + + private def shimDataWriteCmdRules( + name: String, + method: String): Map[Class[_ <: DataWritingCommand], + DataWritingCommandRule[_ <: DataWritingCommand]] = { + invokeShimSingleton(name, method) + .asInstanceOf[Map[Class[_ <: DataWritingCommand], + DataWritingCommandRule[_ <: DataWritingCommand]]] + } + + private def shimRunnableCmdRules( + name: String, + method: String): Map[Class[_ <: RunnableCommand], + RunnableCommandRule[_ <: RunnableCommand]] = { + invokeShimSingleton(name, method) + .asInstanceOf[Map[Class[_ <: RunnableCommand], + RunnableCommandRule[_ <: RunnableCommand]]] + } + + private def shimExecRules( + name: String, + method: String): Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = { + invokeShimSingleton(name, method) + .asInstanceOf[Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]]] + } + + private def shimExecRule(name: String, method: String): ExecRule[_ <: SparkPlan] = { + invokeShimSingleton(name, method).asInstanceOf[ExecRule[_ <: SparkPlan]] + } + + private def optionalShimExecRule(name: String, method: String): ExecRule[_ <: SparkPlan] = { + invokeShimSingleton(name, method) + .asInstanceOf[Option[ExecRule[_ <: SparkPlan]]] + .orNull + } + + @transient private[this] lazy val aggregateInPandasExecShimsModule = { + Class.forName("com.nvidia.spark.rapids.shims.AggregateInPandasExecShims" + "$") + .getField("MODULE" + "$") + .get(null) + } + + @transient private[this] lazy val aggregateInPandasExecRuleMethod = + aggregateInPandasExecShimsModule.getClass.getMethod("execRule") + + private def aggregateInPandasExecRule: ExecRule[_ <: SparkPlan] = { + aggregateInPandasExecRuleMethod.invoke(aggregateInPandasExecShimsModule) + .asInstanceOf[Option[ExecRule[_ <: SparkPlan]]] + .orNull + } + + @transient private[this] lazy val batchScanExecMetaConstructor = + Class.forName("com.nvidia.spark.rapids.shims.BatchScanExecMeta") + .getConstructor(classOf[BatchScanExec], classOf[RapidsConf], + classOf[Option[_]], classOf[DataFromReplacementRule]) + + private def newBatchScanExecMeta( + p: BatchScanExec, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]], + rule: DataFromReplacementRule): SparkPlanMeta[BatchScanExec] = { + batchScanExecMetaConstructor.newInstance(p, conf, parent, rule) + .asInstanceOf[SparkPlanMeta[BatchScanExec]] + } + + @transient private[this] lazy val gpuSubqueryBroadcastMetaConstructor = + Class.forName("org.apache.spark.sql.rapids.execution.GpuSubqueryBroadcastMeta") + .getConstructor(classOf[SubqueryBroadcastExec], classOf[RapidsConf], + classOf[Option[_]], classOf[DataFromReplacementRule]) + + private def newGpuSubqueryBroadcastMeta( + plan: SubqueryBroadcastExec, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]], + rule: DataFromReplacementRule): SparkPlanMeta[SubqueryBroadcastExec] = { + gpuSubqueryBroadcastMetaConstructor.newInstance(plan, conf, parent, rule) + .asInstanceOf[SparkPlanMeta[SubqueryBroadcastExec]] + } + val FLOAT_DIFFERS_GROUP_INCOMPAT = "when enabling these, there may be extra groups produced for floating point grouping " + "keys (e.g. -0.0, and 0.0)" @@ -718,13 +867,13 @@ object GpuOverrides extends Logging { expressions.exists(isStringLit) def isOrContainsFloatingPoint(dataType: DataType): Boolean = - TrampolineUtil.dataTypeExistsRecursively(dataType, dt => dt == FloatType || dt == DoubleType) + dataTypeExistsRecursively(dataType, dt => dt == FloatType || dt == DoubleType) def isOrContainsDateOrTimestamp(dataType: DataType): Boolean = - TrampolineUtil.dataTypeExistsRecursively(dataType, dt => dt == TimestampType || dt == DateType) + dataTypeExistsRecursively(dataType, dt => dt == TimestampType || dt == DateType) def isOrContainsTimestamp(dataType: DataType): Boolean = - TrampolineUtil.dataTypeExistsRecursively(dataType, dt => dt == TimestampType) + dataTypeExistsRecursively(dataType, dt => dt == TimestampType) /** Tries to predict whether an adaptive plan will end up with data on the GPU or not. */ def probablyGpuPlan(adaptivePlan: AdaptiveSparkPlanExec, conf: RapidsConf): Boolean = { @@ -849,6 +998,15 @@ object GpuOverrides extends Logging { new ExecRule[INPUT](doWrap, desc, Some(pluginChecks), tag) } + def execFromShim[INPUT <: SparkPlan]( + rule: ShimExecRule[INPUT], + pluginChecks: ExecChecks, + doWrap: (INPUT, RapidsConf, Option[RapidsMeta[_, _, _]], DataFromReplacementRule) + => SparkPlanMeta[INPUT]): ExecRule[INPUT] = { + assert(rule != null) + exec(rule.desc, pluginChecks, doWrap)(rule.tag) + } + def dataWriteCmd[INPUT <: DataWritingCommand]( desc: String, doWrap: (INPUT, RapidsConf, Option[RapidsMeta[_, _, _]], DataFromReplacementRule) @@ -859,6 +1017,14 @@ object GpuOverrides extends Logging { new DataWritingCommandRule[INPUT](doWrap, desc, tag) } + def dataWriteCmdFromShim[INPUT <: DataWritingCommand]( + rule: ShimDataWritingCommandRule[INPUT], + doWrap: (INPUT, RapidsConf, Option[RapidsMeta[_, _, _]], DataFromReplacementRule) + => DataWritingCommandMeta[INPUT]): DataWritingCommandRule[INPUT] = { + assert(rule != null) + dataWriteCmd(rule.desc, doWrap)(rule.tag) + } + def wrapExpr[INPUT <: Expression]( expr: INPUT, conf: RapidsConf, @@ -1025,8 +1191,8 @@ object GpuOverrides extends Logging { ExprChecks.windowOnly( TypeSig.all, TypeSig.all, - Seq(ParamCheck("windowFunction", TypeSig.all, TypeSig.all), - ParamCheck("windowSpec", + Seq(new ParamCheck("windowFunction", TypeSig.all, TypeSig.all), + new ParamCheck("windowSpec", TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral + TypeSig.DECIMAL_64, TypeSig.numericAndInterval))), (windowExpression, conf, p, r) => new GpuWindowExpressionMeta(windowExpression, conf, p, r)), @@ -1037,11 +1203,11 @@ object GpuOverrides extends Logging { TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral, TypeSig.numericAndInterval, Seq( - ParamCheck("lower", + new ParamCheck("lower", TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral + TypeSig.DECIMAL_128 + TypeSig.FLOAT + TypeSig.DOUBLE, TypeSig.numericAndInterval), - ParamCheck("upper", + new ParamCheck("upper", TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral + TypeSig.DECIMAL_128 + TypeSig.FLOAT + TypeSig.DOUBLE, TypeSig.numericAndInterval))), @@ -1077,7 +1243,7 @@ object GpuOverrides extends Logging { "Window function that returns the index for the row within the aggregation window", ExprChecks.windowOnly(TypeSig.INT, TypeSig.INT, repeatingParamCheck = - Some(RepeatingParamCheck("ordering", + Some(new RepeatingParamCheck("ordering", TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL, TypeSig.all))), (rowNumber, conf, p, r) => new ExprMeta[RowNumber](rowNumber, conf, p, r) { @@ -1087,7 +1253,7 @@ object GpuOverrides extends Logging { "Window function that returns the rank value within the aggregation window", ExprChecks.windowOnly(TypeSig.INT, TypeSig.INT, repeatingParamCheck = - Some(RepeatingParamCheck("ordering", + Some(new RepeatingParamCheck("ordering", TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL, TypeSig.all))), (rank, conf, p, r) => new ExprMeta[Rank](rank, conf, p, r) { @@ -1097,7 +1263,7 @@ object GpuOverrides extends Logging { "Window function that returns the dense rank value within the aggregation window", ExprChecks.windowOnly(TypeSig.INT, TypeSig.INT, repeatingParamCheck = - Some(RepeatingParamCheck("ordering", + Some(new RepeatingParamCheck("ordering", TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL, TypeSig.all))), (denseRank, conf, p, r) => new ExprMeta[DenseRank](denseRank, conf, p, r) { @@ -1108,7 +1274,7 @@ object GpuOverrides extends Logging { "Window function that returns the percent rank value within the aggregation window", ExprChecks.windowOnly(TypeSig.DOUBLE, TypeSig.DOUBLE, repeatingParamCheck = - Some(RepeatingParamCheck("ordering", + Some(new RepeatingParamCheck("ordering", TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL, TypeSig.all))), (percentRank, conf, p, r) => new ExprMeta[PercentRank](percentRank, conf, p, r) { @@ -1122,12 +1288,12 @@ object GpuOverrides extends Logging { TypeSig.ARRAY + TypeSig.STRUCT).nested(), TypeSig.all, Seq( - ParamCheck("input", + new ParamCheck("input", (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT).nested(), TypeSig.all), - ParamCheck("offset", TypeSig.INT, TypeSig.INT), - ParamCheck("default", + new ParamCheck("offset", TypeSig.INT, TypeSig.INT), + new ParamCheck("default", (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT).nested(), TypeSig.all) @@ -1144,12 +1310,12 @@ object GpuOverrides extends Logging { TypeSig.ARRAY + TypeSig.STRUCT).nested(), TypeSig.all, Seq( - ParamCheck("input", + new ParamCheck("input", (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT).nested(), TypeSig.all), - ParamCheck("offset", TypeSig.INT, TypeSig.INT), - ParamCheck("default", + new ParamCheck("offset", TypeSig.INT, TypeSig.INT), + new ParamCheck("default", (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT).nested(), TypeSig.all) @@ -1409,7 +1575,7 @@ object GpuOverrides extends Logging { expr[AtLeastNNonNulls]( "Checks if number of non null/Nan values is greater than a given value", ExprChecks.projectOnly(TypeSig.BOOLEAN, TypeSig.BOOLEAN, - repeatingParamCheck = Some(RepeatingParamCheck("input", + repeatingParamCheck = Some(new RepeatingParamCheck("input", (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.BINARY + TypeSig.MAP + TypeSig.ARRAY + TypeSig.STRUCT).nested(), TypeSig.all))), @@ -1524,7 +1690,7 @@ object GpuOverrides extends Logging { "Returns the bitwise AND of all non-null input values", ExprChecks.reductionAndGroupByAgg( TypeSig.integral, TypeSig.integral, - Seq(ParamCheck("input", TypeSig.integral, TypeSig.integral))), + Seq(new ParamCheck("input", TypeSig.integral, TypeSig.integral))), (a, conf, p, r) => new AggExprMeta[BitAndAgg](a, conf, p, r) { override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = GpuBitAndAgg(childExprs.head) @@ -1535,7 +1701,7 @@ object GpuOverrides extends Logging { "Returns the bitwise OR of all non-null input values", ExprChecks.reductionAndGroupByAgg( TypeSig.integral, TypeSig.integral, - Seq(ParamCheck("input", TypeSig.integral, TypeSig.integral))), + Seq(new ParamCheck("input", TypeSig.integral, TypeSig.integral))), (a, conf, p, r) => new AggExprMeta[BitOrAgg](a, conf, p, r) { override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = GpuBitOrAgg(childExprs.head) @@ -1546,7 +1712,7 @@ object GpuOverrides extends Logging { "Returns the bitwise XOR of all non-null input values", ExprChecks.reductionAndGroupByAgg( TypeSig.integral, TypeSig.integral, - Seq(ParamCheck("input", TypeSig.integral, TypeSig.integral))), + Seq(new ParamCheck("input", TypeSig.integral, TypeSig.integral))), (a, conf, p, r) => new AggExprMeta[BitXorAgg](a, conf, p, r) { override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = GpuBitXorAgg(childExprs.head) @@ -1559,7 +1725,7 @@ object GpuOverrides extends Logging { (gpuCommonTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.BINARY + TypeSig.MAP + GpuTypeShims.additionalArithmeticSupportedTypes).nested(), TypeSig.all, - repeatingParamCheck = Some(RepeatingParamCheck("param", + repeatingParamCheck = Some(new RepeatingParamCheck("param", (gpuCommonTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.BINARY + TypeSig.MAP + GpuTypeShims.additionalArithmeticSupportedTypes).nested(), TypeSig.all))), @@ -1571,7 +1737,7 @@ object GpuOverrides extends Logging { "Returns the least value of all parameters, skipping null values", ExprChecks.projectOnly( TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, TypeSig.orderable, - repeatingParamCheck = Some(RepeatingParamCheck("param", + repeatingParamCheck = Some(new RepeatingParamCheck("param", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, TypeSig.orderable))), (a, conf, p, r) => new ExprMeta[Least](a, conf, p, r) { @@ -1581,7 +1747,7 @@ object GpuOverrides extends Logging { "Returns the greatest value of all parameters, skipping null values", ExprChecks.projectOnly( TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, TypeSig.orderable, - repeatingParamCheck = Some(RepeatingParamCheck("param", + repeatingParamCheck = Some(new RepeatingParamCheck("param", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, TypeSig.orderable))), (a, conf, p, r) => new ExprMeta[Greatest](a, conf, p, r) { @@ -1885,9 +2051,9 @@ object GpuOverrides extends Logging { "are the last day of month, time of day will be ignored. Otherwise, the difference is " + "calculated based on 31 days per month, and rounded to 8 digits unless roundOff=false.", ExprChecks.projectOnly(TypeSig.DOUBLE, TypeSig.DOUBLE, - Seq(ParamCheck("timestamp1", TypeSig.TIMESTAMP, TypeSig.TIMESTAMP), - ParamCheck("timestamp2", TypeSig.TIMESTAMP, TypeSig.TIMESTAMP), - ParamCheck("round", TypeSig.lit(TypeEnum.BOOLEAN), TypeSig.BOOLEAN))), + Seq(new ParamCheck("timestamp1", TypeSig.TIMESTAMP, TypeSig.TIMESTAMP), + new ParamCheck("timestamp2", TypeSig.TIMESTAMP, TypeSig.TIMESTAMP), + new ParamCheck("round", TypeSig.lit(TypeEnum.BOOLEAN), TypeSig.BOOLEAN))), (a, conf, p, r) => new MonthsBetweenExprMeta(a, conf, p, r) ), expr[TruncDate]( @@ -2061,9 +2227,9 @@ object GpuOverrides extends Logging { expr[In]( "IN operator", ExprChecks.projectOnly(TypeSig.BOOLEAN, TypeSig.BOOLEAN, - Seq(ParamCheck("value", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, + Seq(new ParamCheck("value", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, TypeSig.comparable)), - Some(RepeatingParamCheck("list", + Some(new RepeatingParamCheck("list", (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128).withAllLit(), TypeSig.comparable))), (in, conf, p, r) => new ExprMeta[In](in, conf, p, r) { @@ -2131,12 +2297,12 @@ object GpuOverrides extends Logging { (gpuCommonTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP + TypeSig.BINARY + GpuTypeShims.additionalCommonOperatorSupportedTypes).nested(), TypeSig.all, - Seq(ParamCheck("predicate", TypeSig.BOOLEAN, TypeSig.BOOLEAN), - ParamCheck("trueValue", + Seq(new ParamCheck("predicate", TypeSig.BOOLEAN, TypeSig.BOOLEAN), + new ParamCheck("trueValue", (gpuCommonTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP + TypeSig.BINARY + GpuTypeShims.additionalCommonOperatorSupportedTypes).nested(), TypeSig.all), - ParamCheck("falseValue", + new ParamCheck("falseValue", (gpuCommonTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP + TypeSig.BINARY + GpuTypeShims.additionalCommonOperatorSupportedTypes).nested(), TypeSig.all))), @@ -2162,8 +2328,8 @@ object GpuOverrides extends Logging { ExprChecks.fullAgg( TypeSig.all, TypeSig.all, - Seq(ParamCheck("aggFunc", TypeSig.all, TypeSig.all)), - Some(RepeatingParamCheck("filter", TypeSig.BOOLEAN, TypeSig.BOOLEAN))), + Seq(new ParamCheck("aggFunc", TypeSig.all, TypeSig.all)), + Some(new RepeatingParamCheck("filter", TypeSig.BOOLEAN, TypeSig.BOOLEAN))), (a, conf, p, r) => new ExprMeta[AggregateExpression](a, conf, p, r) { private val filter: Option[BaseExprMeta[_]] = a.filter.map(GpuOverrides.wrapExpr(_, this.conf, Some(this))) @@ -2193,23 +2359,23 @@ object GpuOverrides extends Logging { pluginSupportedOrderableSig + TypeSig.ARRAY.nested(gpuCommonTypes) .withPsNote(TypeEnum.ARRAY, "STRUCT is not supported as a child type for ARRAY"), TypeSig.orderable, - Seq(ParamCheck( + Seq(new ParamCheck( "input", pluginSupportedOrderableSig + TypeSig.ARRAY.nested(gpuCommonTypes) .withPsNote(TypeEnum.ARRAY, "STRUCT is not supported as a child type for ARRAY"), TypeSig.orderable))), - GpuSortOrderMeta), + (sortOrder, conf, p, r) => new GpuSortOrderMeta(sortOrder, conf, p, r)), expr[PivotFirst]( "PivotFirst operator", ExprChecks.reductionAndGroupByAgg( TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128), TypeSig.all, - Seq(ParamCheck( + Seq(new ParamCheck( "pivotColumn", (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128), TypeSig.all), - ParamCheck("valueColumn", + new ParamCheck("valueColumn", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, TypeSig.all))), (pivot, conf, p, r) => new ImperativeAggExprMeta[PivotFirst](pivot, conf, p, r) { @@ -2232,7 +2398,7 @@ object GpuOverrides extends Logging { "Count aggregate operator", ExprChecks.fullAgg( TypeSig.LONG, TypeSig.LONG, - repeatingParamCheck = Some(RepeatingParamCheck( + repeatingParamCheck = Some(new RepeatingParamCheck( "input", TypeSig.all, TypeSig.all))), (count, conf, p, r) => new AggExprMeta[Count](count, conf, p, r) { @@ -2249,12 +2415,12 @@ object GpuOverrides extends Logging { }), expr[Max]( "Max aggregate operator", - ExprChecksImpl( + new ExprChecksImpl( ExprChecks.reductionAndGroupByAgg( (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT + TypeSig.ARRAY).nested(), TypeSig.orderable, - Seq(ParamCheck("input", + Seq(new ParamCheck("input", (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT + TypeSig.ARRAY).nested(), TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts @@ -2262,7 +2428,7 @@ object GpuOverrides extends Logging { ExprChecks.windowOnly( (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL), TypeSig.orderable, - Seq(ParamCheck("input", + Seq(new ParamCheck("input", (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL), TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts), (max, conf, p, r) => new AggExprMeta[Max](max, conf, p, r) { @@ -2274,12 +2440,12 @@ object GpuOverrides extends Logging { }), expr[Min]( "Min aggregate operator", - ExprChecksImpl( + new ExprChecksImpl( ExprChecks.reductionAndGroupByAgg( (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT + TypeSig.ARRAY).nested(), TypeSig.orderable, - Seq(ParamCheck("input", + Seq(new ParamCheck("input", (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT + TypeSig.ARRAY).nested(), TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts @@ -2287,7 +2453,7 @@ object GpuOverrides extends Logging { ExprChecks.windowOnly( (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL), TypeSig.orderable, - Seq(ParamCheck("input", + Seq(new ParamCheck("input", (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL), TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts), (a, conf, p, r) => new AggExprMeta[Min](a, conf, p, r) { @@ -2302,7 +2468,7 @@ object GpuOverrides extends Logging { ExprChecks.fullAgg( TypeSig.LONG + TypeSig.DOUBLE + TypeSig.DECIMAL_128, TypeSig.LONG + TypeSig.DOUBLE + TypeSig.DECIMAL_128, - Seq(ParamCheck("input", TypeSig.gpuNumeric, TypeSig.cpuNumeric))), + Seq(new ParamCheck("input", TypeSig.gpuNumeric, TypeSig.cpuNumeric))), (a, conf, p, r) => new AggExprMeta[Sum](a, conf, p, r) { override def tagAggForGpu(): Unit = { val inputDataType = a.child.dataType @@ -2325,11 +2491,11 @@ object GpuOverrides extends Logging { (TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP + TypeSig.BINARY + TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128).nested(), TypeSig.all, - Seq(ParamCheck("input", + Seq(new ParamCheck("input", (TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP + TypeSig.BINARY + TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128).nested(), TypeSig.all), - ParamCheck("offset", TypeSig.lit(TypeEnum.INT), TypeSig.lit(TypeEnum.INT))) + new ParamCheck("offset", TypeSig.lit(TypeEnum.INT), TypeSig.lit(TypeEnum.INT))) ), (a, conf, p, r) => new AggExprMeta[NthValue](a, conf, p, r) { override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = @@ -2344,7 +2510,7 @@ object GpuOverrides extends Logging { (TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP + TypeSig.BINARY + TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128).nested(), TypeSig.all, - Seq(ParamCheck("input", + Seq(new ParamCheck("input", (TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP + TypeSig.BINARY + TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128).nested(), TypeSig.all)) @@ -2362,7 +2528,7 @@ object GpuOverrides extends Logging { (TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP + TypeSig.BINARY + TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128).nested(), TypeSig.all, - Seq(ParamCheck("input", + Seq(new ParamCheck("input", (TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP + TypeSig.BINARY + TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128).nested(), TypeSig.all)) @@ -2383,10 +2549,10 @@ object GpuOverrides extends Logging { TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128).nested(), TypeSig.all, Seq( - ParamCheck("value", (TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP + TypeSig.BINARY + new ParamCheck("value", (TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP + TypeSig.BINARY + TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128).nested(), TypeSig.all), - ParamCheck("ordering", (TypeSig.commonCudfTypes - TypeSig.fp + TypeSig.DECIMAL_128 + + new ParamCheck("ordering", (TypeSig.commonCudfTypes - TypeSig.fp + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT + TypeSig.ARRAY).nested( TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT + TypeSig.ARRAY), @@ -2412,10 +2578,10 @@ object GpuOverrides extends Logging { TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128).nested(), TypeSig.all, Seq( - ParamCheck("value", (TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP + TypeSig.BINARY + new ParamCheck("value", (TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP + TypeSig.BINARY + TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128).nested(), TypeSig.all), - ParamCheck("ordering", (TypeSig.commonCudfTypes - TypeSig.fp + TypeSig.DECIMAL_128 + + new ParamCheck("ordering", (TypeSig.commonCudfTypes - TypeSig.fp + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT + TypeSig.ARRAY).nested( TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT + TypeSig.ARRAY), @@ -2483,7 +2649,7 @@ object GpuOverrides extends Logging { // plugin is also an union of all the types of Pandas UDF. (TypeSig.commonCudfTypes + TypeSig.ARRAY).nested() + TypeSig.STRUCT, TypeSig.unionOfPandasUdfOut, - repeatingParamCheck = Some(RepeatingParamCheck( + repeatingParamCheck = Some(new RepeatingParamCheck( "param", (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT).nested(), TypeSig.all))), @@ -2501,7 +2667,7 @@ object GpuOverrides extends Logging { expr[Rand]( "Generate a random column with i.i.d. uniformly distributed values in [0, 1)", ExprChecks.projectOnly(TypeSig.DOUBLE, TypeSig.DOUBLE, - Seq(ParamCheck("seed", + Seq(new ParamCheck("seed", (TypeSig.INT + TypeSig.LONG).withAllLit(), (TypeSig.INT + TypeSig.LONG).withAllLit()))), (a, conf, p, r) => new UnaryExprMeta[Rand](a, conf, p, r) { @@ -2578,9 +2744,9 @@ object GpuOverrides extends Logging { expr[StringLPad]( "Pad a string on the left", ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, - Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING), - ParamCheck("len", TypeSig.lit(TypeEnum.INT), TypeSig.INT), - ParamCheck("pad", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), + Seq(new ParamCheck("str", TypeSig.STRING, TypeSig.STRING), + new ParamCheck("len", TypeSig.lit(TypeEnum.INT), TypeSig.INT), + new ParamCheck("pad", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), (in, conf, p, r) => new TernaryExprMeta[StringLPad](in, conf, p, r) { override def tagExprForGpu(): Unit = { extractLit(in.pad).foreach { padLit => @@ -2599,9 +2765,9 @@ object GpuOverrides extends Logging { expr[StringRPad]( "Pad a string on the right", ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, - Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING), - ParamCheck("len", TypeSig.lit(TypeEnum.INT), TypeSig.INT), - ParamCheck("pad", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), + Seq(new ParamCheck("str", TypeSig.STRING, TypeSig.STRING), + new ParamCheck("len", TypeSig.lit(TypeEnum.INT), TypeSig.INT), + new ParamCheck("pad", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), (in, conf, p, r) => new TernaryExprMeta[StringRPad](in, conf, p, r) { override def tagExprForGpu(): Unit = { extractLit(in.pad).foreach { padLit => @@ -2622,11 +2788,11 @@ object GpuOverrides extends Logging { // Java's split API produces different behaviors than cudf when splitting with empty pattern ExprChecks.projectOnly(TypeSig.ARRAY.nested(TypeSig.STRING), TypeSig.ARRAY.nested(TypeSig.STRING), - Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING), - ParamCheck("regexp", TypeSig.lit(TypeEnum.STRING) + Seq(new ParamCheck("str", TypeSig.STRING, TypeSig.STRING), + new ParamCheck("regexp", TypeSig.lit(TypeEnum.STRING) .withPsNote(TypeEnum.STRING, "very limited subset of regex supported"), TypeSig.STRING), - ParamCheck("limit", TypeSig.lit(TypeEnum.INT), TypeSig.INT))), + new ParamCheck("limit", TypeSig.lit(TypeEnum.INT), TypeSig.INT))), (in, conf, p, r) => new GpuStringSplitMeta(in, conf, p, r)), expr[GetStructField]( "Gets the named field of the struct", @@ -2721,7 +2887,7 @@ object GpuOverrides extends Logging { (in, conf, p, r) => new UnaryExprMeta[MapFromEntries](in, conf, p, r) { override def tagExprForGpu(): Unit = { // Spark 4.1+ returns an enum value instead of String, so use toString first - SQLConf.get.getConf(SQLConf.MAP_KEY_DEDUP_POLICY).toString.toUpperCase match { + confValueToString(SQLConf.get.getConf(SQLConf.MAP_KEY_DEDUP_POLICY)).toUpperCase match { case "EXCEPTION" | "LAST_WIN" => // Good we can support this case other => willNotWorkOnGpu(s"$other is not supported for config setting" + @@ -2735,9 +2901,15 @@ object GpuOverrides extends Logging { "Creates a map after splitting the input string into pairs of key-value strings", // Java's split API produces different behaviors than cudf when splitting with empty pattern ExprChecks.projectOnly(TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.STRING), - Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING), - ParamCheck("pairDelim", TypeSig.lit(TypeEnum.STRING), TypeSig.lit(TypeEnum.STRING)), - ParamCheck("keyValueDelim", TypeSig.lit(TypeEnum.STRING), TypeSig.lit(TypeEnum.STRING)))), + Seq(new ParamCheck("str", TypeSig.STRING, TypeSig.STRING), + new ParamCheck( + "pairDelim", + TypeSig.lit(TypeEnum.STRING), + TypeSig.lit(TypeEnum.STRING)), + new ParamCheck( + "keyValueDelim", + TypeSig.lit(TypeEnum.STRING), + TypeSig.lit(TypeEnum.STRING)))), (in, conf, p, r) => new GpuStringToMapMeta(in, conf, p, r)), expr[ArrayMin]( "Returns the minimum value in the array", @@ -2830,7 +3002,7 @@ object GpuOverrides extends Logging { TypeSig.NULL + TypeSig.STRING + TypeSig.BOOLEAN + TypeSig.DATE + TypeSig.TIMESTAMP + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.BINARY), TypeSig.ARRAY.nested(TypeSig.all), - repeatingParamCheck = Some(RepeatingParamCheck("arg", + repeatingParamCheck = Some(new RepeatingParamCheck("arg", TypeSig.gpuNumeric + TypeSig.NULL + TypeSig.STRING + TypeSig.BOOLEAN + TypeSig.DATE + TypeSig.TIMESTAMP + TypeSig.STRUCT + TypeSig.BINARY + TypeSig.ARRAY.nested(TypeSig.gpuNumeric + TypeSig.NULL + TypeSig.STRING + @@ -2858,7 +3030,7 @@ object GpuOverrides extends Logging { TypeSig.ARRAY.nested(TypeSig.orderable), TypeSig.ARRAY.nested(TypeSig.orderable), TypeSig.ARRAY.nested(TypeSig.orderable)), - GpuArrayDistinctMeta), + (expr, conf, p, r) => new GpuArrayDistinctMeta(expr, conf, p, r)), expr[Flatten]( "Creates a single array from an array of arrays", ExprChecks.unaryProject( @@ -2876,11 +3048,11 @@ object GpuOverrides extends Logging { (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.BINARY + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP).nested(), TypeSig.all, - Seq(ParamCheck("function", + Seq(new ParamCheck("function", (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.BINARY + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP).nested(), TypeSig.all)), - Some(RepeatingParamCheck("arguments", + Some(new RepeatingParamCheck("arguments", (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.BINARY + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP).nested(), TypeSig.all))), @@ -2912,11 +3084,11 @@ object GpuOverrides extends Logging { TypeSig.MAP), TypeSig.ARRAY.nested(TypeSig.all), Seq( - ParamCheck("argument", + new ParamCheck("argument", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.BINARY + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), TypeSig.ARRAY.nested(TypeSig.all)), - ParamCheck("function", + new ParamCheck("function", (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.BINARY + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP).nested(), TypeSig.all))), @@ -2929,11 +3101,11 @@ object GpuOverrides extends Logging { "Return true if any element satisfies the predicate LambdaFunction", ExprChecks.projectOnly(TypeSig.BOOLEAN, TypeSig.BOOLEAN, Seq( - ParamCheck("argument", + new ParamCheck("argument", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.BINARY + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), TypeSig.ARRAY.nested(TypeSig.all)), - ParamCheck("function", TypeSig.BOOLEAN, TypeSig.BOOLEAN))), + new ParamCheck("function", TypeSig.BOOLEAN, TypeSig.BOOLEAN))), (in, conf, p, r) => new ExprMeta[ArrayExists](in, conf, p, r) { override def convertToGpuImpl(): GpuExpression = { GpuArrayExists( @@ -2950,11 +3122,11 @@ object GpuOverrides extends Logging { TypeSig.MAP), TypeSig.ARRAY.nested(TypeSig.all), Seq( - ParamCheck("argument", + new ParamCheck("argument", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.BINARY + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), TypeSig.ARRAY.nested(TypeSig.all)), - ParamCheck("function", TypeSig.BOOLEAN, TypeSig.BOOLEAN))), + new ParamCheck("function", TypeSig.BOOLEAN, TypeSig.BOOLEAN))), (in, conf, p, r) => new ExprMeta[ArrayFilter](in, conf, p, r) { override def convertToGpuImpl(): GpuExpression = { GpuArrayFilter( @@ -2978,17 +3150,17 @@ object GpuOverrides extends Logging { TypeSig.commonCudfTypes + TypeSig.DECIMAL_128, TypeSig.all, Seq( - ParamCheck("argument", + new ParamCheck("argument", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.BINARY + TypeSig.STRUCT), TypeSig.ARRAY.nested(TypeSig.all)), - ParamCheck("zero", + new ParamCheck("zero", TypeSig.commonCudfTypes + TypeSig.DECIMAL_128, TypeSig.all), - ParamCheck("merge", + new ParamCheck("merge", TypeSig.commonCudfTypes + TypeSig.DECIMAL_128, TypeSig.all), - ParamCheck("finish", + new ParamCheck("finish", TypeSig.commonCudfTypes + TypeSig.DECIMAL_128, TypeSig.all))), (in, conf, p, r) => new GpuArrayAggregateMeta(in, conf, p, r)), @@ -3000,7 +3172,7 @@ object GpuOverrides extends Logging { TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.BINARY + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), TypeSig.ARRAY.nested(TypeSig.all), - repeatingParamCheck = Some(RepeatingParamCheck("children", + repeatingParamCheck = Some(new RepeatingParamCheck("children", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.BINARY + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), TypeSig.ARRAY.nested(TypeSig.all)))), @@ -3131,7 +3303,7 @@ object GpuOverrides extends Logging { TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), TypeSig.ARRAY.nested(TypeSig.all))), - GpuMapFromArraysMeta + (expr, conf, p, r) => new GpuMapFromArraysMeta(expr, conf, p, r) ), expr[TransformKeys]( "Transform keys in a map using a transform function", @@ -3139,18 +3311,18 @@ object GpuOverrides extends Logging { TypeSig.NULL + TypeSig.BINARY + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), TypeSig.MAP.nested(TypeSig.all), Seq( - ParamCheck("argument", + new ParamCheck("argument", TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.BINARY + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), TypeSig.MAP.nested(TypeSig.all)), - ParamCheck("function", + new ParamCheck("function", // We need to be able to check for duplicate keys (equality) TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL, TypeSig.all - TypeSig.MAP.nested()))), (in, conf, p, r) => new ExprMeta[TransformKeys](in, conf, p, r) { override def tagExprForGpu(): Unit = { // Spark 4.1+ returns an enum value instead of String, so use toString first - SQLConf.get.getConf(SQLConf.MAP_KEY_DEDUP_POLICY).toString.toUpperCase match { + confValueToString(SQLConf.get.getConf(SQLConf.MAP_KEY_DEDUP_POLICY)).toUpperCase match { case "EXCEPTION"| "LAST_WIN" => // Good we can support this case other => willNotWorkOnGpu(s"$other is not supported for config setting" + @@ -3167,11 +3339,11 @@ object GpuOverrides extends Logging { TypeSig.NULL + TypeSig.BINARY + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), TypeSig.MAP.nested(TypeSig.all), Seq( - ParamCheck("argument", + new ParamCheck("argument", TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.BINARY + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), TypeSig.MAP.nested(TypeSig.all)), - ParamCheck("function", + new ParamCheck("function", (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.BINARY + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP).nested(), TypeSig.all))), @@ -3186,15 +3358,15 @@ object GpuOverrides extends Logging { TypeSig.NULL + TypeSig.BINARY + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), TypeSig.MAP.nested(TypeSig.all), Seq( - ParamCheck("argument1", + new ParamCheck("argument1", TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.BINARY + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), TypeSig.MAP.nested(TypeSig.all)), - ParamCheck("argument2", + new ParamCheck("argument2", TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.BINARY + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), TypeSig.MAP.nested(TypeSig.all)), - ParamCheck("function", + new ParamCheck("function", (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.BINARY + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP).nested(), TypeSig.all))), @@ -3210,11 +3382,11 @@ object GpuOverrides extends Logging { TypeSig.NULL + TypeSig.BINARY + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), TypeSig.MAP.nested(TypeSig.all), Seq( - ParamCheck("argument", + new ParamCheck("argument", TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.BINARY + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), TypeSig.MAP.nested(TypeSig.all)), - ParamCheck("function", TypeSig.BOOLEAN, TypeSig.BOOLEAN))), + new ParamCheck("function", TypeSig.BOOLEAN, TypeSig.BOOLEAN))), (in, conf, p, r) => new ExprMeta[MapFilter](in, conf, p, r) { override def convertToGpuImpl(): GpuExpression = { GpuMapFilter(childExprs.head.convertToGpu(), childExprs(1).convertToGpu()) @@ -3223,9 +3395,9 @@ object GpuOverrides extends Logging { expr[StringLocate]( "Substring search operator", ExprChecks.projectOnly(TypeSig.INT, TypeSig.INT, - Seq(ParamCheck("substr", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), - ParamCheck("str", TypeSig.STRING, TypeSig.STRING), - ParamCheck("start", TypeSig.lit(TypeEnum.INT), TypeSig.INT))), + Seq(new ParamCheck("substr", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), + new ParamCheck("str", TypeSig.STRING, TypeSig.STRING), + new ParamCheck("start", TypeSig.lit(TypeEnum.INT), TypeSig.INT))), (in, conf, p, r) => new TernaryExprMeta[StringLocate](in, conf, p, r) { override def convertToGpu( val0: Expression, @@ -3236,8 +3408,8 @@ object GpuOverrides extends Logging { expr[StringInstr]( "Instr string operator", ExprChecks.projectOnly(TypeSig.INT, TypeSig.INT, - Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING), - ParamCheck("substr", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), + Seq(new ParamCheck("str", TypeSig.STRING, TypeSig.STRING), + new ParamCheck("substr", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), (in, conf, p, r) => new BinaryExprMeta[StringInstr](in, conf, p, r) { override def convertToGpu( str: Expression, @@ -3247,9 +3419,9 @@ object GpuOverrides extends Logging { expr[Substring]( "Substring operator", ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING + TypeSig.BINARY, - Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING + TypeSig.BINARY), - ParamCheck("pos", TypeSig.INT, TypeSig.INT), - ParamCheck("len", TypeSig.INT, TypeSig.INT))), + Seq(new ParamCheck("str", TypeSig.STRING, TypeSig.STRING + TypeSig.BINARY), + new ParamCheck("pos", TypeSig.INT, TypeSig.INT), + new ParamCheck("len", TypeSig.INT, TypeSig.INT))), (in, conf, p, r) => new TernaryExprMeta[Substring](in, conf, p, r) { override def convertToGpu( column: Expression, @@ -3260,16 +3432,16 @@ object GpuOverrides extends Logging { expr[SubstringIndex]( "substring_index operator", ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, - Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING), - ParamCheck("delim", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), - ParamCheck("count", TypeSig.lit(TypeEnum.INT), TypeSig.INT))), + Seq(new ParamCheck("str", TypeSig.STRING, TypeSig.STRING), + new ParamCheck("delim", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), + new ParamCheck("count", TypeSig.lit(TypeEnum.INT), TypeSig.INT))), (in, conf, p, r) => new SubstringIndexMeta(in, conf, p, r)), expr[StringRepeat]( "StringRepeat operator that repeats the given strings with numbers of times " + "given by repeatTimes", ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, - Seq(ParamCheck("input", TypeSig.STRING, TypeSig.STRING), - ParamCheck("repeatTimes", TypeSig.INT, TypeSig.INT))), + Seq(new ParamCheck("input", TypeSig.STRING, TypeSig.STRING), + new ParamCheck("repeatTimes", TypeSig.INT, TypeSig.INT))), (in, conf, p, r) => new BinaryExprMeta[StringRepeat](in, conf, p, r) { override def convertToGpu( input: Expression, @@ -3278,9 +3450,9 @@ object GpuOverrides extends Logging { expr[StringReplace]( "StringReplace operator", ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, - Seq(ParamCheck("src", TypeSig.STRING, TypeSig.STRING), - ParamCheck("search", TypeSig.STRING, TypeSig.STRING), - ParamCheck("replace", TypeSig.STRING, TypeSig.STRING))), + Seq(new ParamCheck("src", TypeSig.STRING, TypeSig.STRING), + new ParamCheck("search", TypeSig.STRING, TypeSig.STRING), + new ParamCheck("replace", TypeSig.STRING, TypeSig.STRING))), (in, conf, p, r) => new TernaryExprMeta[StringReplace](in, conf, p, r) { override def convertToGpu( column: Expression, @@ -3291,9 +3463,9 @@ object GpuOverrides extends Logging { expr[StringTrim]( "StringTrim operator", ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, - Seq(ParamCheck("src", TypeSig.STRING, TypeSig.STRING)), + Seq(new ParamCheck("src", TypeSig.STRING, TypeSig.STRING)), // Should really be an OptionalParam - Some(RepeatingParamCheck("trimStr", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), + Some(new RepeatingParamCheck("trimStr", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), (in, conf, p, r) => new String2TrimExpressionMeta[StringTrim](in, conf, p, r) { override def convertToGpu( column: Expression, @@ -3303,9 +3475,9 @@ object GpuOverrides extends Logging { expr[StringTrimLeft]( "StringTrimLeft operator", ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, - Seq(ParamCheck("src", TypeSig.STRING, TypeSig.STRING)), + Seq(new ParamCheck("src", TypeSig.STRING, TypeSig.STRING)), // Should really be an OptionalParam - Some(RepeatingParamCheck("trimStr", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), + Some(new RepeatingParamCheck("trimStr", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), (in, conf, p, r) => new String2TrimExpressionMeta[StringTrimLeft](in, conf, p, r) { override def convertToGpu( @@ -3316,9 +3488,9 @@ object GpuOverrides extends Logging { expr[StringTrimRight]( "StringTrimRight operator", ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, - Seq(ParamCheck("src", TypeSig.STRING, TypeSig.STRING)), + Seq(new ParamCheck("src", TypeSig.STRING, TypeSig.STRING)), // Should really be an OptionalParam - Some(RepeatingParamCheck("trimStr", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), + Some(new RepeatingParamCheck("trimStr", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), (in, conf, p, r) => new String2TrimExpressionMeta[StringTrimRight](in, conf, p, r) { override def convertToGpu( @@ -3329,9 +3501,9 @@ object GpuOverrides extends Logging { expr[StringTranslate]( "StringTranslate operator", ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, - Seq(ParamCheck("input", TypeSig.STRING, TypeSig.STRING), - ParamCheck("from", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), - ParamCheck("to", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), + Seq(new ParamCheck("input", TypeSig.STRING, TypeSig.STRING), + new ParamCheck("from", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), + new ParamCheck("to", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), (in, conf, p, r) => new TernaryExprMeta[StringTranslate](in, conf, p, r) { override def convertToGpu( input: Expression, @@ -3365,7 +3537,7 @@ object GpuOverrides extends Logging { TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP + TypeSig.BINARY), (TypeSig.STRING + TypeSig.BINARY + TypeSig.ARRAY).nested(TypeSig.all), - repeatingParamCheck = Some(RepeatingParamCheck("input", + repeatingParamCheck = Some(new RepeatingParamCheck("input", (TypeSig.STRING + TypeSig.ARRAY).nested( TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP + TypeSig.BINARY), @@ -3378,15 +3550,15 @@ object GpuOverrides extends Logging { pluginChecks = ExprChecks.projectOnly( outputCheck = TypeSig.STRING, paramCheck = Seq( - ParamCheck( + new ParamCheck( name = "num", cudf = TypeSig.STRING, spark = TypeSig.STRING), - ParamCheck( + new ParamCheck( name = "from_base", cudf = TypeSig.INT, spark = TypeSig.INT), - ParamCheck( + new ParamCheck( name = "to_base", cudf = TypeSig.INT, spark = TypeSig.INT)), @@ -3419,7 +3591,7 @@ object GpuOverrides extends Logging { ExprChecks.projectOnly(TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), TypeSig.MAP.nested(TypeSig.all), - repeatingParamCheck = Some(RepeatingParamCheck("input", + repeatingParamCheck = Some(new RepeatingParamCheck("input", TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), TypeSig.MAP.nested(TypeSig.all)))), @@ -3433,12 +3605,12 @@ object GpuOverrides extends Logging { TypeSig.NULL + TypeSig.BINARY + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), TypeSig.ARRAY.nested(TypeSig.all), Seq( - ParamCheck("x", + new ParamCheck("x", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.BINARY + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), TypeSig.ARRAY.nested(TypeSig.all)), - ParamCheck("start", TypeSig.INT, TypeSig.INT), - ParamCheck("length", TypeSig.INT, TypeSig.INT))), + new ParamCheck("start", TypeSig.INT, TypeSig.INT), + new ParamCheck("length", TypeSig.INT, TypeSig.INT))), (in, conf, p, r) => new TernaryExprMeta[Slice](in, conf, p, r) { override def convertToGpu( x: Expression, @@ -3451,13 +3623,13 @@ object GpuOverrides extends Logging { "string to replace nulls. If no value is set for nullReplacement, any null value " + "is filtered.", ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, - Seq(ParamCheck("array", + Seq(new ParamCheck("array", TypeSig.ARRAY.nested(TypeSig.STRING), TypeSig.ARRAY.nested(TypeSig.STRING)), - ParamCheck("delimiter", + new ParamCheck("delimiter", TypeSig.STRING, TypeSig.STRING)), - repeatingParamCheck = Some(RepeatingParamCheck("nullReplacement", + repeatingParamCheck = Some(new RepeatingParamCheck("nullReplacement", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), (a, conf, p, r) => new ExprMeta[ArrayJoin](a, conf, p, r) { @@ -3475,7 +3647,7 @@ object GpuOverrides extends Logging { "Concatenates multiple input strings or array of strings into a single " + "string using a given separator", ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, - repeatingParamCheck = Some(RepeatingParamCheck("input", + repeatingParamCheck = Some(new RepeatingParamCheck("input", (TypeSig.STRING + TypeSig.ARRAY).nested(TypeSig.STRING), (TypeSig.STRING + TypeSig.ARRAY).nested(TypeSig.STRING)))), (a, conf, p, r) => new ExprMeta[ConcatWs](a, conf, p, r) { @@ -3493,15 +3665,15 @@ object GpuOverrides extends Logging { expr[Murmur3Hash]( "Murmur3 hash operator", HashExprChecks.murmur3ProjectChecks, - Murmur3HashExprMeta.apply), + ((expr, conf, parent, rule) => new Murmur3HashExprMeta(expr, conf, parent, rule))), expr[XxHash64]( "xxhash64 hash operator", HashExprChecks.xxhash64ProjectChecks, - XxHash64ExprMeta.apply), + ((expr, conf, parent, rule) => new XxHash64ExprMeta(expr, conf, parent, rule))), expr[HiveHash]( "hive hash operator", ExprChecks.projectOnly(TypeSig.INT, TypeSig.INT, - repeatingParamCheck = Some(RepeatingParamCheck("input", + repeatingParamCheck = Some(new RepeatingParamCheck("input", (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.ARRAY).nested() + TypeSig.psNote(TypeEnum.ARRAY, "The nesting depth has a certain limit") + TypeSig.psNote(TypeEnum.STRUCT, "The nesting depth has a certain limit"), @@ -3582,38 +3754,38 @@ object GpuOverrides extends Logging { expr[RegExpReplace]( "String replace using a regular expression pattern", ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, - Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING), - ParamCheck("regex", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), - ParamCheck("rep", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), - ParamCheck("pos", TypeSig.lit(TypeEnum.INT) + Seq(new ParamCheck("str", TypeSig.STRING, TypeSig.STRING), + new ParamCheck("regex", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), + new ParamCheck("rep", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), + new ParamCheck("pos", TypeSig.lit(TypeEnum.INT) .withPsNote(TypeEnum.INT, "only a value of 1 is supported"), TypeSig.lit(TypeEnum.INT)))), (a, conf, p, r) => new GpuRegExpReplaceMeta(a, conf, p, r)), expr[RegExpExtract]( "Extract a specific group identified by a regular expression", ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, - Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING), - ParamCheck("regexp", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), - ParamCheck("idx", TypeSig.lit(TypeEnum.INT), + Seq(new ParamCheck("str", TypeSig.STRING, TypeSig.STRING), + new ParamCheck("regexp", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), + new ParamCheck("idx", TypeSig.lit(TypeEnum.INT), TypeSig.lit(TypeEnum.INT)))), (a, conf, p, r) => new GpuRegExpExtractMeta(a, conf, p, r)), expr[RegExpExtractAll]( "Extract all strings matching a regular expression corresponding to the regex group index", ExprChecks.projectOnly(TypeSig.ARRAY.nested(TypeSig.STRING), TypeSig.ARRAY.nested(TypeSig.STRING), - Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING), - ParamCheck("regexp", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), - ParamCheck("idx", TypeSig.lit(TypeEnum.INT), TypeSig.INT))), + Seq(new ParamCheck("str", TypeSig.STRING, TypeSig.STRING), + new ParamCheck("regexp", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), + new ParamCheck("idx", TypeSig.lit(TypeEnum.INT), TypeSig.INT))), (a, conf, p, r) => new GpuRegExpExtractAllMeta(a, conf, p, r)), expr[ParseUrl]( "Extracts a part from a URL", ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, - Seq(ParamCheck("url", TypeSig.STRING, TypeSig.STRING), - ParamCheck("partToExtract", TypeSig.lit(TypeEnum.STRING).withPsNote( + Seq(new ParamCheck("url", TypeSig.STRING, TypeSig.STRING), + new ParamCheck("partToExtract", TypeSig.lit(TypeEnum.STRING).withPsNote( TypeEnum.STRING, "only support partToExtract = PROTOCOL | HOST | QUERY | PATH"), TypeSig.STRING)), // Should really be an OptionalParam - Some(RepeatingParamCheck("key", TypeSig.STRING, TypeSig.STRING))), + Some(new RepeatingParamCheck("key", TypeSig.STRING, TypeSig.STRING))), (a, conf, p, r) => new ExprMeta[ParseUrl](a, conf, p, r) { override def tagExprForGpu(): Unit = { @@ -3726,8 +3898,8 @@ object GpuOverrides extends Logging { TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), TypeSig.ARRAY.nested(TypeSig.all), - Seq(ParamCheck("n", TypeSig.lit(TypeEnum.INT), TypeSig.INT)), - Some(RepeatingParamCheck("expr", + Seq(new ParamCheck("n", TypeSig.lit(TypeEnum.INT), TypeSig.INT)), + Some(new RepeatingParamCheck("expr", (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP).nested(), TypeSig.all))), @@ -3739,7 +3911,7 @@ object GpuOverrides extends Logging { TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.ARRAY + TypeSig.STRUCT), TypeSig.ARRAY.nested(TypeSig.all), - repeatingParamCheck = Some(RepeatingParamCheck("input", + repeatingParamCheck = Some(new RepeatingParamCheck("input", (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.ARRAY + TypeSig.STRUCT).nested(), TypeSig.all))), @@ -3755,7 +3927,7 @@ object GpuOverrides extends Logging { .withPsNote(TypeEnum.ARRAY, "window operations are disabled by default due " + "to extreme memory usage"), TypeSig.ARRAY.nested(TypeSig.all), - Seq(ParamCheck("input", + Seq(new ParamCheck("input", (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.BINARY + TypeSig.NULL + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP).nested(), TypeSig.all))), @@ -3796,7 +3968,7 @@ object GpuOverrides extends Logging { .withPsNote(TypeEnum.ARRAY, "window operations are disabled by default due " + "to extreme memory usage"), TypeSig.ARRAY.nested(TypeSig.all), - Seq(ParamCheck("input", + Seq(new ParamCheck("input", (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT + @@ -3835,7 +4007,7 @@ object GpuOverrides extends Logging { "Aggregation computing population standard deviation", ExprChecks.groupByOnly( TypeSig.DOUBLE, TypeSig.DOUBLE, - Seq(ParamCheck("input", TypeSig.DOUBLE, TypeSig.DOUBLE))), + Seq(new ParamCheck("input", TypeSig.DOUBLE, TypeSig.DOUBLE))), (a, conf, p, r) => new AggExprMeta[StddevPop](a, conf, p, r) { override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = { val legacyStatisticalAggregate = SQLConf.get.legacyStatisticalAggregate @@ -3846,7 +4018,7 @@ object GpuOverrides extends Logging { "Aggregation computing sample standard deviation", ExprChecks.fullAgg( TypeSig.DOUBLE, TypeSig.DOUBLE, - Seq(ParamCheck("input", TypeSig.DOUBLE, + Seq(new ParamCheck("input", TypeSig.DOUBLE, TypeSig.DOUBLE))), (a, conf, p, r) => new AggExprMeta[StddevSamp](a, conf, p, r) { override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = { @@ -3858,7 +4030,7 @@ object GpuOverrides extends Logging { "Aggregation computing population variance", ExprChecks.groupByOnly( TypeSig.DOUBLE, TypeSig.DOUBLE, - Seq(ParamCheck("input", TypeSig.DOUBLE, TypeSig.DOUBLE))), + Seq(new ParamCheck("input", TypeSig.DOUBLE, TypeSig.DOUBLE))), (a, conf, p, r) => new AggExprMeta[VariancePop](a, conf, p, r) { override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = { val legacyStatisticalAggregate = SQLConf.get.legacyStatisticalAggregate @@ -3869,7 +4041,7 @@ object GpuOverrides extends Logging { "Aggregation computing sample variance", ExprChecks.groupByOnly( TypeSig.DOUBLE, TypeSig.DOUBLE, - Seq(ParamCheck("input", TypeSig.DOUBLE, TypeSig.DOUBLE))), + Seq(new ParamCheck("input", TypeSig.DOUBLE, TypeSig.DOUBLE))), (a, conf, p, r) => new AggExprMeta[VarianceSamp](a, conf, p, r) { override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = { val legacyStatisticalAggregate = SQLConf.get.legacyStatisticalAggregate @@ -3886,11 +4058,11 @@ object GpuOverrides extends Logging { Seq( // ANSI interval types are new in Spark 3.2.0 and are not yet supported by the // current GPU implementation. - ParamCheck("input", TypeSig.integral + TypeSig.fp, TypeSig.integral + TypeSig.fp), - ParamCheck("percentage", + new ParamCheck("input", TypeSig.integral + TypeSig.fp, TypeSig.integral + TypeSig.fp), + new ParamCheck("percentage", TypeSig.lit(TypeEnum.DOUBLE) + TypeSig.ARRAY.nested(TypeSig.lit(TypeEnum.DOUBLE)), TypeSig.DOUBLE + TypeSig.ARRAY.nested(TypeSig.DOUBLE)), - ParamCheck("frequency", + new ParamCheck("frequency", TypeSig.LONG + TypeSig.ARRAY.nested(TypeSig.LONG), TypeSig.LONG + TypeSig.ARRAY.nested(TypeSig.LONG)))), (c, conf, p, r) => new TypedImperativeAggExprMeta[Percentile](c, conf, p, r) { @@ -3953,13 +4125,13 @@ object GpuOverrides extends Logging { TypeSig.cpuNumeric + TypeSig.DATE + TypeSig.TIMESTAMP + TypeSig.ARRAY.nested( TypeSig.cpuNumeric + TypeSig.DATE + TypeSig.TIMESTAMP), Seq( - ParamCheck("input", + new ParamCheck("input", TypeSig.gpuNumeric, TypeSig.cpuNumeric + TypeSig.DATE + TypeSig.TIMESTAMP), - ParamCheck("percentage", + new ParamCheck("percentage", TypeSig.DOUBLE + TypeSig.ARRAY.nested(TypeSig.DOUBLE), TypeSig.DOUBLE + TypeSig.ARRAY.nested(TypeSig.DOUBLE)), - ParamCheck("accuracy", TypeSig.INT, TypeSig.INT))), + new ParamCheck("accuracy", TypeSig.INT, TypeSig.INT))), (c, conf, p, r) => new TypedImperativeAggExprMeta[ApproximatePercentile](c, conf, p, r) { override def tagAggForGpu(): Unit = { @@ -4000,8 +4172,8 @@ object GpuOverrides extends Logging { expr[GetJsonObject]( "Extracts a json object from path", ExprChecks.projectOnly( - TypeSig.STRING, TypeSig.STRING, Seq(ParamCheck("json", TypeSig.STRING, TypeSig.STRING), - ParamCheck("path", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), + TypeSig.STRING, TypeSig.STRING, Seq(new ParamCheck("json", TypeSig.STRING, TypeSig.STRING), + new ParamCheck("path", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), (a, conf, p, r) => new GpuGetJsonObjectMeta(a, conf, p, r)), expr[JsonToStructs]( "Returns a struct value with the given `jsonStr` and `schema`", @@ -4011,10 +4183,10 @@ object GpuOverrides extends Logging { "MAP only supports keys and values that are of STRING type " + "and is only supported at the top level"), (TypeSig.STRUCT + TypeSig.MAP + TypeSig.ARRAY).nested(TypeSig.all), - Seq(ParamCheck("jsonStr", TypeSig.STRING, TypeSig.STRING))), + Seq(new ParamCheck("jsonStr", TypeSig.STRING, TypeSig.STRING))), (a, conf, p, r) => new UnaryExprMeta[JsonToStructs](a, conf, p, r) { def hasDuplicateFieldNames(dt: DataType): Boolean = - TrampolineUtil.dataTypeExistsRecursively(dt, { + dataTypeExistsRecursively(dt, { case st: StructType => val fn = st.fieldNames fn.length != fn.distinct.length @@ -4022,7 +4194,7 @@ object GpuOverrides extends Logging { }) def hasDateTimeType(dt: DataType): Boolean = - TrampolineUtil.dataTypeExistsRecursively(dt, t => + dataTypeExistsRecursively(dt, t => t.isInstanceOf[DateType] || t.isInstanceOf[TimestampType] ) @@ -4056,7 +4228,7 @@ object GpuOverrides extends Logging { ExprChecks.projectOnly( TypeSig.STRING, TypeSig.STRING, - Seq(ParamCheck("struct", + Seq(new ParamCheck("struct", (TypeSig.BOOLEAN + TypeSig.STRING + TypeSig.integral + TypeSig.FLOAT + TypeSig.DOUBLE + TypeSig.DATE + TypeSig.TIMESTAMP + TypeSig.DECIMAL_128 + @@ -4077,8 +4249,8 @@ object GpuOverrides extends Logging { ExprChecks.projectOnly( TypeSig.ARRAY.nested(TypeSig.STRUCT + TypeSig.STRING), TypeSig.ARRAY.nested(TypeSig.STRUCT + TypeSig.STRING), - Seq(ParamCheck("json", TypeSig.STRING, TypeSig.STRING)), - Some(RepeatingParamCheck("field", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), + Seq(new ParamCheck("json", TypeSig.STRING, TypeSig.STRING)), + Some(new RepeatingParamCheck("field", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), (a, conf, p, r) => new GeneratorExprMeta[JsonTuple](a, conf, p, r) { override def tagExprForGpu(): Unit = { if (childExprs.length >= 50) { @@ -4116,11 +4288,14 @@ object GpuOverrides extends Logging { ExprChecks.projectOnly( TypeSig.ARRAY.nested(TypeSig.integral), TypeSig.ARRAY.nested(TypeSig.integral + TypeSig.TIMESTAMP + TypeSig.DATE), - Seq(ParamCheck("start", TypeSig.integral, TypeSig.integral + TypeSig.TIMESTAMP + + Seq(new ParamCheck("start", TypeSig.integral, TypeSig.integral + TypeSig.TIMESTAMP + TypeSig.DATE), - ParamCheck("stop", TypeSig.integral, TypeSig.integral + TypeSig.TIMESTAMP + + new ParamCheck("stop", TypeSig.integral, TypeSig.integral + TypeSig.TIMESTAMP + TypeSig.DATE)), - Some(RepeatingParamCheck("step", TypeSig.integral, TypeSig.integral + TypeSig.CALENDAR))), + Some(new RepeatingParamCheck( + "step", + TypeSig.integral, + TypeSig.integral + TypeSig.CALENDAR))), (a, conf, p, r) => new GpuSequenceMeta(a, conf, p, r) ), expr[BitLength]( @@ -4174,7 +4349,7 @@ object GpuOverrides extends Logging { ExprChecks.reductionAndGroupByAgg(TypeSig.LONG, TypeSig.LONG, // HyperLogLogPlusPlus depends on Xxhash64 // HyperLogLogPlusPlus supports all the types that Xxhash 64 supports - Seq(ParamCheck("input",XxHash64Shims.supportedTypes, TypeSig.all))), + Seq(new ParamCheck("input",XxHash64Shims.supportedTypes, TypeSig.all))), (a, conf, p, r) => new UnaryExprMeta[HyperLogLogPlusPlus](a, conf, p, r) { // It's the same as Xxhash64 @@ -4218,16 +4393,16 @@ object GpuOverrides extends Logging { StaticInvokeCheck, (a, conf, p, r) => new StaticInvokeMeta(a, conf, p, r) ).note("The supported types are not deterministic since it's a dynamic expression"), - SparkShimImpl.ansiCastRule + shimExprRule("SparkShimImpl", "ansiCastRule") ).collect { case r if r != null => (r.getClassFor.asSubclass(classOf[Expression]), r)}.toMap // Shim expressions should be last to allow overrides with shim-specific versions val expressions: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = commonExpressions ++ TimeStamp.getExprs ++ GpuHiveOverrides.exprs ++ ZOrderRules.exprs ++ DecimalArithmeticOverrides.exprs ++ - BloomFilterShims.exprs ++ StringDecodeShims.exprs ++ - InSubqueryShims.exprs ++ RaiseErrorShim.exprs ++ - ExternalSource.exprRules ++ SparkShimImpl.getExprs + BloomFilterShims.exprs ++ shimExprs("StringDecodeShims") ++ + shimExprs("InSubqueryShims") ++ shimExprs("RaiseErrorShim") ++ + ExternalSource.exprRules ++ shimExprRules("SparkShimImpl", "getExprs") def wrapScan[INPUT <: Scan]( scan: INPUT, @@ -4276,7 +4451,7 @@ object GpuOverrides extends Logging { })).map(r => (r.getClassFor.asSubclass(classOf[Scan]), r)).toMap val scans: Map[Class[_ <: Scan], ScanRule[_ <: Scan]] = - commonScans ++ SparkShimImpl.getScans ++ ExternalSource.getScans + commonScans ++ shimScanRules("SparkShimImpl", "getScans") ++ ExternalSource.getScans def wrapPart[INPUT <: Partitioning]( part: INPUT, @@ -4290,7 +4465,7 @@ object GpuOverrides extends Logging { part[HashPartitioning]( "Hash based partitioning", // This needs to match what murmur3 supports. - PartChecks(RepeatingParamCheck("hash_key", + PartChecks(new RepeatingParamCheck("hash_key", (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.STRUCT + TypeSig.ARRAY).nested() + TypeSig.psNote(TypeEnum.ARRAY, "Arrays of structs are not supported"), @@ -4314,7 +4489,7 @@ object GpuOverrides extends Logging { } case Murmur3Mode => val arrayWithStructsHashing = hp.expressions.exists(e => - TrampolineUtil.dataTypeExistsRecursively(e.dataType, + dataTypeExistsRecursively(e.dataType, { case ArrayType(_: StructType, _) => true case _ => false @@ -4334,7 +4509,7 @@ object GpuOverrides extends Logging { }), part[RangePartitioning]( "Range partitioning", - PartChecks(RepeatingParamCheck("order_key", + PartChecks(new RepeatingParamCheck("order_key", pluginSupportedOrderableSig + TypeSig.ARRAY.nested(gpuCommonTypes) .withPsNote(TypeEnum.ARRAY, "STRUCT is not supported as a child type for ARRAY"), TypeSig.orderable)), @@ -4368,7 +4543,7 @@ object GpuOverrides extends Logging { ).map(r => (r.getClassFor.asSubclass(classOf[Partitioning]), r)).toMap val parts : Map[Class[_ <: Partitioning], PartRule[_ <: Partitioning]] = - commonParts ++ SparkShimImpl.getPartitionings + commonParts ++ shimPartRules("SparkShimImpl", "getPartitionings") def wrapDataWriteCmds[INPUT <: DataWritingCommand]( writeCmd: INPUT, @@ -4387,7 +4562,8 @@ object GpuOverrides extends Logging { val dataWriteCmds: Map[Class[_ <: DataWritingCommand], DataWritingCommandRule[_ <: DataWritingCommand]] = - commonDataWriteCmds ++ GpuHiveOverrides.dataWriteCmds ++ SparkShimImpl.getDataWriteCmds + commonDataWriteCmds ++ GpuHiveOverrides.dataWriteCmds ++ + shimDataWriteCmdRules("SparkShimImpl", "getDataWriteCmds") def runnableCmd[INPUT <: RunnableCommand]( desc: String, @@ -4399,6 +4575,14 @@ object GpuOverrides extends Logging { new RunnableCommandRule[INPUT](doWrap, desc, tag) } + def runnableCmdFromShim[INPUT <: RunnableCommand]( + rule: ShimRunnableCommandRule[INPUT], + doWrap: (INPUT, RapidsConf, Option[RapidsMeta[_, _, _]], DataFromReplacementRule) + => RunnableCommandMeta[INPUT]): RunnableCommandRule[INPUT] = { + require(rule != null) + runnableCmd(rule.desc, doWrap)(rule.tag) + } + def wrapRunnableCmd[INPUT <: RunnableCommand]( cmd: INPUT, conf: RapidsConf, @@ -4418,7 +4602,7 @@ object GpuOverrides extends Logging { val runnableCmds = commonRunnableCmds ++ GpuHiveOverrides.runnableCmds ++ ExternalSource.runnableCmds ++ - SparkShimImpl.getRunnableCmds + shimRunnableCmdRules("SparkShimImpl", "getRunnableCmds") def wrapPlan[INPUT <: SparkPlan]( plan: INPUT, @@ -4460,7 +4644,7 @@ object GpuOverrides extends Logging { (TypeSig.commonCudfTypes + TypeSig.STRUCT + TypeSig.MAP + TypeSig.ARRAY + TypeSig.DECIMAL_128 + TypeSig.BINARY).nested(), TypeSig.all), - (p, conf, parent, r) => new BatchScanExecMeta(p, conf, parent, r)), + (p, conf, parent, r) => newBatchScanExecMeta(p, conf, parent, r)), exec[CoalesceExec]( "The backend for the dataframe coalesce method", ExecChecks((gpuCommonTypes + TypeSig.STRUCT + TypeSig.ARRAY + @@ -4495,7 +4679,7 @@ object GpuOverrides extends Logging { // The types below are allowed as inputs and outputs. ExecChecks((pluginSupportedOrderableSig + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP).nested(), TypeSig.all), - GpuTakeOrderedAndProjectExecMeta), + (exec, conf, p, r) => new GpuTakeOrderedAndProjectExecMeta(exec, conf, p, r)), exec[LocalLimitExec]( "Per-partition limiting of results", ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + @@ -4530,7 +4714,7 @@ object GpuOverrides extends Logging { ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP + TypeSig.ARRAY + TypeSig.DECIMAL_128 + TypeSig.BINARY + GpuTypeShims.additionalCommonOperatorSupportedTypes).nested(), TypeSig.all), - GpuFilterExecMeta), + (exec, conf, p, r) => new GpuFilterExecMeta(exec, conf, p, r)), exec[ShuffleExchangeExec]( "The backend for most data being exchanged between processes", ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.BINARY + @@ -4668,10 +4852,10 @@ object GpuOverrides extends Logging { TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP + TypeSig.BINARY).nested(), TypeSig.all, Map("partitionSpec" -> - InputCheck( + new InputCheck( TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.STRUCT.nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128), - TypeSig.all))), + TypeSig.all, List.empty))), (windowOp, conf, p, r) => new GpuWindowExecMeta(windowOp, conf, p, r) ), @@ -4685,11 +4869,11 @@ object GpuOverrides extends Logging { exec[SubqueryBroadcastExec]( "Plan to collect and transform the broadcast key values", ExecChecks(TypeSig.all, TypeSig.all), - (s, conf, p, r) => new GpuSubqueryBroadcastMeta(s, conf, p, r) + (s, conf, p, r) => newGpuSubqueryBroadcastMeta(s, conf, p, r) ), - SparkShimImpl.aqeShuffleReaderExec, + shimExecRule("SparkShimImpl", "aqeShuffleReaderExec"), // AggregateInPandasExec renamed to ArrowAggregatePythonExec in Spark 4.1.0 - AggregateInPandasExecShims.execRule.orNull, + aggregateInPandasExecRule, exec[ArrowEvalPythonExec]( "The backend of the Scalar Pandas UDFs. Accelerates the data transfer between the" + " Java process and the Python process. It also supports scheduling GPU resources" + @@ -4745,8 +4929,8 @@ object GpuOverrides extends Logging { neverReplaceExec[DescribeNamespaceExec]("Namespace metadata operation"), neverReplaceExec[DropNamespaceExec]("Namespace metadata operation"), neverReplaceExec[SetCatalogAndNamespaceExec]("Namespace metadata operation"), - SparkShimImpl.neverReplaceShowCurrentNamespaceCommand, - ShowNamespacesExecShims.neverReplaceExec.orNull, + shimExecRule("SparkShimImpl", "neverReplaceShowCurrentNamespaceCommand"), + optionalShimExecRule("ShowNamespacesExecShims", "neverReplaceExec"), neverReplaceExec[AlterTableExec]("Table metadata operation"), neverReplaceExec[CreateTableExec]("Table metadata operation"), neverReplaceExec[DeleteFromTableExec]("Table metadata operation"), @@ -4763,9 +4947,10 @@ object GpuOverrides extends Logging { neverReplaceExec[ShuffleQueryStageExec]("Shuffle query stage") ).collect { case r if r != null => (r.getClassFor.asSubclass(classOf[SparkPlan]), r) }.toMap + // Shim execs at the end; shims get the last word in substitutions. lazy val execs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = commonExecs ++ GpuHiveOverrides.execs ++ ExternalSource.execRules ++ - SparkShimImpl.getExecs // Shim execs at the end; shims get the last word in substitutions. + shimExecRules("SparkShimImpl", "getExecs") def getTimeParserPolicy: TimeParserPolicy = { val policy = SQLConf.get.getConfString(SQLConf.LEGACY_TIME_PARSER_POLICY.key, "EXCEPTION") @@ -4947,7 +5132,17 @@ protected class ExplainPlanImpl extends ExplainPlanBase { } // work around any GpuOverride failures -object GpuOverrideUtil extends Logging { +object GpuOverrideUtil { + private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$")) + + private def logWarning(msg: => String, throwable: Throwable): Unit = { + log.warn(msg, throwable) + } + + private def logError(msg: => String, throwable: Throwable): Unit = { + log.error(msg, throwable) + } + def tryOverride(fn: SparkPlan => SparkPlan): SparkPlan => SparkPlan = { plan => val planOriginal = plan.clone() val failOnError = TEST_CONF.get(plan.conf) || !SUPPRESS_PLANNING_FAILURE.get(plan.conf) @@ -4965,7 +5160,7 @@ object GpuOverrideUtil extends Logging { } /** Tag the initial plan when AQE is enabled */ -case class GpuQueryStagePrepOverrides() extends Rule[SparkPlan] with Logging { +class GpuQueryStagePrepOverrides extends Rule[SparkPlan] with Serializable { override def apply(sparkPlan: SparkPlan): SparkPlan = GpuOverrideUtil.tryOverride { plan => // Exposing a bare exchange at the root is only valid while AQE is preparing a // query stage. Tag the exchanges seen in this rule so transition cleanup can @@ -4979,7 +5174,7 @@ case class GpuQueryStagePrepOverrides() extends Rule[SparkPlan] with Logging { }(sparkPlan) } -case class GpuOverrides() extends Rule[SparkPlan] with Logging { +case class GpuOverrides() extends Rule[SparkPlan] { // Spark calls this method once for the whole plan when AQE is off. When AQE is on, it // gets called once for each query stage (where a query stage is an `Exchange`). diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuPostHocResolutionOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuPostHocResolutionOverrides.scala index eda79ca3e96..458dfa2acf1 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuPostHocResolutionOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuPostHocResolutionOverrides.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2025, NVIDIA CORPORATION. + * Copyright (c) 2025-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,7 +26,8 @@ import org.apache.spark.sql.catalyst.rules.Rule * phase by `SparkSessionExtensions.injectPostHocResolutionRule`. As its name suggests, it will * be applied after the logical plan has been resolved. */ -case class GpuPostHocResolutionOverrides(spark: SparkSession) extends Rule[LogicalPlan] { +class GpuPostHocResolutionOverrides(val spark: SparkSession) + extends Rule[LogicalPlan] with Serializable { @transient private val rapidsConf = new RapidsConf(spark.sessionState.conf) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuUserDefinedFunction.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuUserDefinedFunction.scala index eae23b86dd5..b03d082f1c9 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuUserDefinedFunction.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuUserDefinedFunction.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * Copyright (c) 2021-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,7 +23,6 @@ import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.shims.ShimExpression import org.apache.spark.SparkException -import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UserDefinedExpression} import org.apache.spark.sql.rapids.execution.TrampolineUtil @@ -91,7 +90,13 @@ object GpuUserDefinedFunction { * and do the processing on CPU. */ trait GpuRowBasedUserDefinedFunction extends GpuExpression - with ShimExpression with UserDefinedExpression with Serializable with Logging { + with ShimExpression with UserDefinedExpression with Serializable { + + @transient private lazy val log = org.slf4j.LoggerFactory.getLogger( + classOf[GpuRowBasedUserDefinedFunction]) + + private def logDebug(msg: => String): Unit = if (log.isDebugEnabled) log.debug(msg) + /** name of the UDF function */ val name: String diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HashExprMetas.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HashExprMetas.scala index cd8bc87f032..dc59a8dd3c9 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HashExprMetas.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HashExprMetas.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2025, NVIDIA CORPORATION. + * Copyright (c) 2025-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -83,14 +83,14 @@ object HashExprChecks { val murmur3ProjectChecks: ExprChecks = ExprChecks.projectOnly( TypeSig.INT, TypeSig.INT, - repeatingParamCheck = Some(RepeatingParamCheck( + repeatingParamCheck = Some(new RepeatingParamCheck( "input", murmur3InputTypes, TypeSig.all))) val xxhash64ProjectChecks: ExprChecks = ExprChecks.projectOnly( TypeSig.LONG, TypeSig.LONG, - repeatingParamCheck = Some(RepeatingParamCheck( + repeatingParamCheck = Some(new RepeatingParamCheck( "input", XxHash64Shims.supportedTypes, TypeSig.all))) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala index f1561e2c251..675f355b92b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. + * Copyright (c) 2020-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids -import java.nio.{ByteBuffer, ByteOrder} +import java.nio.{Buffer, ByteBuffer, ByteOrder} import scala.collection.mutable.ArrayBuffer @@ -25,7 +25,6 @@ import com.google.flatbuffers.FlatBufferBuilder import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.format._ -import org.apache.spark.internal.Logging import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.storage.ShuffleBlockBatchId @@ -117,9 +116,9 @@ object MetaUtils { packedMeta: ByteBuffer, numRows: Long): TableMeta = { val vectorBuffer = fbb.createUnintializedVector(1, packedMeta.remaining(), 1) - packedMeta.mark() + packedMeta.asInstanceOf[Buffer].mark() vectorBuffer.put(packedMeta) - packedMeta.reset() + packedMeta.asInstanceOf[Buffer].reset() val packedMetaOffset = fbb.endVector() TableMeta.startTableMeta(fbb) @@ -262,7 +261,7 @@ class DirectByteBufferFactory extends FlatBufferBuilder.ByteBufferFactory { } } -object ShuffleMetadata extends Logging{ +object ShuffleMetadata { val bbFactory = new DirectByteBufferFactory diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala index 7aad76138b9..15706589dd7 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala @@ -33,6 +33,8 @@ import com.nvidia.spark.rapids.RapidsConf.AllowMultipleJars import com.nvidia.spark.rapids.RapidsPluginUtils.buildInfoEvent import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion import com.nvidia.spark.rapids.filecache.{FileCache, FileCacheLocalityManager, FileCacheLocalityMsg} +import com.nvidia.spark.rapids.fileio.RapidsInputFiles +import com.nvidia.spark.rapids.fileio.hadoop.PerfIOS3Reader import com.nvidia.spark.rapids.io.async.TrafficController import com.nvidia.spark.rapids.jni.{GpuTimeZoneDB, Hash, JSONUtils, RmmSpark, TaskPriority} import com.nvidia.spark.rapids.python.PythonWorkerSemaphore @@ -40,7 +42,6 @@ import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.spark.{ExceptionFailure, SparkConf, SparkContext, TaskContext, TaskFailedReason} import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext, SparkPlugin} -import org.apache.spark.internal.Logging import org.apache.spark.rapids.hybrid.HybridExecutionUtils import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.catalyst.rules.Rule @@ -51,9 +52,10 @@ import org.apache.spark.sql.rapids.execution.TrampolineUtil class PluginException(msg: String) extends RuntimeException(msg) -case class CudfVersionMismatchException(errorMsg: String) extends PluginException(errorMsg) +class CudfVersionMismatchException(val errorMsg: String) + extends PluginException(errorMsg) with Serializable -case class ColumnarOverrideRules() extends ColumnarRule with Logging { +class ColumnarOverrideRules extends ColumnarRule { lazy val overrides: Rule[SparkPlan] = GpuOverrides() lazy val overrideTransitions: Rule[SparkPlan] = new GpuTransitionOverrides() @@ -62,7 +64,33 @@ case class ColumnarOverrideRules() extends ColumnarRule with Logging { override def postColumnarTransitions: Rule[SparkPlan] = overrideTransitions } -object RapidsPluginUtils extends Logging { +object RapidsPluginUtils { + private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$")) + + private def logInfo(msg: => String): Unit = { + if (log.isInfoEnabled) { + log.info(msg) + } + } + + private def logWarning(msg: => String): Unit = { + if (log.isWarnEnabled) { + log.warn(msg) + } + } + + private def logDebug(msg: => String): Unit = { + if (log.isDebugEnabled) { + log.debug(msg) + } + } + + private def logDebug(msg: => String, throwable: Throwable): Unit = { + if (log.isDebugEnabled) { + log.debug(msg, throwable) + } + } + val CUDF_PROPS_FILENAME = "cudf-java-version-info.properties" val JNI_PROPS_FILENAME = "spark-rapids-jni-version-info.properties" val PLUGIN_PROPS_FILENAME = "rapids4spark-version-info.properties" @@ -83,7 +111,7 @@ object RapidsPluginUtils extends Logging { private val SPARK_MASTER = "spark.master" private val SPARK_RAPIDS_REPO_URL = "https://github.com/NVIDIA/spark-rapids" - lazy val buildInfoEvent = SparkRapidsBuildInfoEvent( + lazy val buildInfoEvent = new SparkRapidsBuildInfoEvent( sparkRapidsBuildInfo = loadProps(PLUGIN_PROPS_FILENAME), sparkRapidsJniBuildInfo = loadProps(JNI_PROPS_FILENAME), cudfBuildInfo = loadProps(CUDF_PROPS_FILENAME), @@ -441,12 +469,32 @@ object RapidsPluginUtils extends Logging { /** * The Spark driver plugin provided by the RAPIDS Spark plugin. */ -class RapidsDriverPlugin extends DriverPlugin with Logging { +class RapidsDriverPlugin extends DriverPlugin { var rapidsShuffleHeartbeatManager: RapidsShuffleHeartbeatManager = null var shuffleCleanupListener: ShuffleCleanupListener = null private lazy val extraDriverPlugins = RapidsPluginUtils.extraPlugins.map(_.driverPlugin()).filterNot(_ == null) + private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$")) + + private def logInfo(msg: => String): Unit = { + if (log.isInfoEnabled) { + log.info(msg) + } + } + + private def logWarning(msg: => String): Unit = { + if (log.isWarnEnabled) { + log.warn(msg) + } + } + + private def logDebug(msg: => String): Unit = { + if (log.isDebugEnabled) { + log.debug(msg) + } + } + override def receive(msg: Any): AnyRef = { msg match { case m: FileCacheLocalityMsg => @@ -487,6 +535,7 @@ class RapidsDriverPlugin extends DriverPlugin with Logging { override def init( sc: SparkContext, pluginContext: PluginContext): java.util.Map[String, String] = { val sparkConf = pluginContext.conf + RapidsInputFiles.setS3PerfReader(PerfIOS3Reader.INSTANCE) RapidsPluginUtils.fixupConfigsOnDriver(sparkConf) val conf = new RapidsConf(sparkConf) RapidsPluginUtils.detectMultipleJars(conf) @@ -564,10 +613,10 @@ class RapidsDriverPlugin extends DriverPlugin with Logging { * We store the object in concurrent map where the key is the executor task thread. * It is `AutoCloseable`, so the caller must close it on task success or failure. */ -case class ActiveTaskMetrics( - stageId: Int, - taskAttemptId: Long, - attemptNumber: Int) extends AutoCloseable { +class ActiveTaskMetrics( + val stageId: Int, + val taskAttemptId: Long, + val attemptNumber: Int) extends AutoCloseable with Serializable { private var nvtx = new NvtxRange( s"Stage $stageId Task $taskAttemptId-$attemptNumber", NvtxColor.DARK_GREEN) private var closed = false @@ -586,12 +635,50 @@ case class ActiveTaskMetrics( /** * The Spark executor plugin provided by the RAPIDS Spark plugin. */ -class RapidsExecutorPlugin extends ExecutorPlugin with Logging { +class RapidsExecutorPlugin extends ExecutorPlugin { var rapidsShuffleHeartbeatEndpoint: RapidsShuffleHeartbeatEndpoint = null var shuffleCleanupEndpoint: ShuffleCleanupEndpoint = null private lazy val extraExecutorPlugins = RapidsPluginUtils.extraPlugins.map(_.executorPlugin()).filterNot(_ == null) + private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$")) + + private def logInfo(msg: => String): Unit = { + if (log.isInfoEnabled) { + log.info(msg) + } + } + + private def logWarning(msg: => String): Unit = { + if (log.isWarnEnabled) { + log.warn(msg) + } + } + + private def logWarning(msg: => String, throwable: Throwable): Unit = { + if (log.isWarnEnabled) { + log.warn(msg, throwable) + } + } + + private def logDebug(msg: => String): Unit = { + if (log.isDebugEnabled) { + log.debug(msg) + } + } + + private def logError(msg: => String): Unit = { + if (log.isErrorEnabled) { + log.error(msg) + } + } + + private def logError(msg: => String, throwable: Throwable): Unit = { + if (log.isErrorEnabled) { + log.error(msg, throwable) + } + } + private val activeTaskInfo = new ConcurrentHashMap[Thread, ActiveTaskMetrics]() private var isAsyncProfilerEnabled = false @@ -602,6 +689,7 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging { try { // if configured, re-register checking leaks hook. reRegisterCheckLeakHook() + RapidsInputFiles.setS3PerfReader(PerfIOS3Reader.INSTANCE) val sparkConf = pluginContext.conf() val numCores = RapidsPluginUtils.estimateCoresOnExec(sparkConf) @@ -738,16 +826,17 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging { private def checkCudfVersion(conf: RapidsConf): Unit = { try { val expectedCudfVersion = buildInfoEvent.sparkRapidsBuildInfo.getOrElse("cudf_version", - throw CudfVersionMismatchException("Could not find cudf version in " + + throw new CudfVersionMismatchException("Could not find cudf version in " + RapidsPluginUtils.PLUGIN_PROPS_FILENAME)) val cudfVersion = buildInfoEvent.cudfBuildInfo.getOrElse("version", - throw CudfVersionMismatchException("Could not find cudf version in " + + throw new CudfVersionMismatchException("Could not find cudf version in " + RapidsPluginUtils.CUDF_PROPS_FILENAME)) // compare cudf version in the classpath with the cudf version expected by plugin if (!RapidsExecutorPlugin.cudfVersionSatisfied(expectedCudfVersion, cudfVersion)) { - throw CudfVersionMismatchException(s"Found cudf version $cudfVersion, RAPIDS Accelerator " + + throw new CudfVersionMismatchException( + s"Found cudf version $cudfVersion, RAPIDS Accelerator " + s"expects $expectedCudfVersion") } } catch { @@ -877,7 +966,7 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging { val attemptNumber = taskCtx.attemptNumber() activeTaskInfo.put( Thread.currentThread(), - ActiveTaskMetrics(stageId, taskAttemptId, attemptNumber)) + new ActiveTaskMetrics(stageId, taskAttemptId, attemptNumber)) } private def endTaskNvtx(): Unit = { @@ -888,7 +977,27 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging { } } -object RapidsExecutorPlugin extends Logging { +object RapidsExecutorPlugin { + private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$")) + + private def logInfo(msg: => String): Unit = { + if (log.isInfoEnabled) { + log.info(msg) + } + } + + private def logWarning(msg: => String): Unit = { + if (log.isWarnEnabled) { + log.warn(msg) + } + } + + private def logWarning(msg: => String, throwable: Throwable): Unit = { + if (log.isWarnEnabled) { + log.warn(msg, throwable) + } + } + /** * Calling System.exit will trigger shutdown hooks to run. * This code is intended to let them run, but then force diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index 7021d62b18c..3bd3a920f09 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -27,7 +27,6 @@ import com.nvidia.spark.rapids.jni.kudo.DumpOption import com.nvidia.spark.rapids.lore.{LoreId, OutputLoreId} import org.apache.spark.SparkConf -import org.apache.spark.internal.Logging import org.apache.spark.network.util.{ByteUnit, JavaUtils} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.internal.SQLConf @@ -319,7 +318,19 @@ object RapidsReaderType extends Enumeration { val AUTO, COALESCING, MULTITHREADED, PERFILE = Value } -object RapidsConf extends Logging { +object RapidsConf { + private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$")) + + private def logDebug(msg: => String): Unit = { + if (log.isDebugEnabled) { + log.debug(msg) + } + } + + private def logWarning(msg: => String): Unit = { + log.warn(msg) + } + val MULTITHREAD_READ_NUM_THREADS_DEFAULT = 20 private val registeredConfs = new ListBuffer[ConfEntry[_]]() @@ -3015,7 +3026,7 @@ val SHUFFLE_COMPRESSION_LZ4_CHUNK_SIZE = conf("spark.rapids.shuffle.compression. .createWithDefault(-1) // default value for the OOM injection logic (no injection, for regular operation) - private val noInjection = OomInjectionConf( + private val noInjection = new OomInjectionConf( numOoms = 0, skipCount = 0, oomInjectionFilter = OomInjectionType.CPU_OR_GPU, @@ -3064,7 +3075,7 @@ val SHUFFLE_COMPRESSION_LZ4_CHUNK_SIZE = conf("spark.rapids.shuffle.compression. TEST_RETRY_OOM_INJECTION_MODE.get(SQLConf.get).toLowerCase match { case "false" => noInjection case "true" => - OomInjectionConf(numOoms = 1, skipCount = 0, + new OomInjectionConf(numOoms = 1, skipCount = 0, oomInjectionFilter = OomInjectionType.CPU_OR_GPU, withSplit = false) case injectConfStr => val injectConfMap = injectConfStr.split(',').map(_.split('=')).collect { @@ -3077,7 +3088,7 @@ val SHUFFLE_COMPRESSION_LZ4_CHUNK_SIZE = conf("spark.rapids.shuffle.compression. .toUpperCase() val oomFilter = OomInjectionType.valueOf(oomFilterStr) val withSplit = injectConfMap.getOrElse("split", false.toString) - val ret = OomInjectionConf( + val ret = new OomInjectionConf( numOoms = numOoms.toInt, skipCount = skipCount.toInt, oomInjectionFilter = oomFilter, @@ -3269,15 +3280,19 @@ val SHUFFLE_COMPRESSION_LZ4_CHUNK_SIZE = conf("spark.rapids.shuffle.compression. val buildSideSelection = JoinBuildSideSelection.withName(buildSideStr) val logCardinality = LOG_JOIN_CARDINALITY.get(conf) val sizeEstimateThreshold = JOIN_GATHERER_SIZE_ESTIMATE_THRESHOLD.get(conf) - JoinOptions(strategy, buildSideSelection, targetSize, logCardinality, sizeEstimateThreshold) + new JoinOptions(strategy, buildSideSelection, targetSize, logCardinality, sizeEstimateThreshold) } } -class RapidsConf(conf: Map[String, String]) extends Logging { +class RapidsConf(conf: Map[String, String]) { import ConfHelper._ import RapidsConf._ + private def logWarning(msg: => String): Unit = { + RapidsConf.logWarning(msg) + } + def this(sqlConf: SQLConf) = { this(sqlConf.getAllConfs) } @@ -3370,7 +3385,7 @@ class RapidsConf(conf: Map[String, String]) extends Logging { val buildSideSelection = JoinBuildSideSelection.withName(buildSideStr) val logCardinality = get(LOG_JOIN_CARDINALITY) val sizeEstimateThreshold = get(JOIN_GATHERER_SIZE_ESTIMATE_THRESHOLD) - JoinOptions(strategy, buildSideSelection, targetSize, logCardinality, sizeEstimateThreshold) + new JoinOptions(strategy, buildSideSelection, targetSize, logCardinality, sizeEstimateThreshold) } lazy val sizedJoinPartitionAmplification: Double = get(SIZED_JOIN_PARTITION_AMPLIFICATION) @@ -4130,9 +4145,12 @@ class RapidsConf(conf: Map[String, String]) extends Logging { } } -case class OomInjectionConf( - numOoms: Int, - skipCount: Int, - withSplit: Boolean, - oomInjectionFilter: OomInjectionType -) +class OomInjectionConf( + val numOoms: Int, + val skipCount: Int, + val withSplit: Boolean, + val oomInjectionFilter: OomInjectionType) extends Serializable { + override def toString: String = + "OomInjectionConf(" + numOoms + "," + skipCount + "," + withSplit + "," + + oomInjectionFilter + ")" +} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala index bc27d2af657..a654ab16658 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import com.nvidia.spark.rapids.GpuTypedImperativeSupportedAggregateExecMeta.{preRowToColProjection, readBufferConverter} import com.nvidia.spark.rapids.RapidsMeta.noNeedToReplaceReason -import com.nvidia.spark.rapids.shims.{AggregateInPandasExecShims, DistributionUtil, SparkShimImpl} +import com.nvidia.spark.rapids.shims.{DistributionUtil, SparkShimImpl} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryExpression, BoundReference, Cast, ComplexTypeMergingExpression, Expression, Literal, QuaternaryExpression, RuntimeReplaceable, String2TrimExpression, TernaryExpression, TimeZoneAwareExpression, UnaryExpression, UTCTimestamp, WindowExpression, WindowFunction} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, ImperativeAggregate, TypedImperativeAggregate} @@ -1029,15 +1029,49 @@ object ExpressionContext { case _ => None } + @transient private[this] lazy val sparkShimImplModule = { + Class.forName("com.nvidia.spark.rapids.shims.SparkShimImpl" + "$") + .getField("MODULE" + "$") + .get(null) + } + + @transient private[this] lazy val isWindowFunctionExecMethod = + sparkShimImplModule.getClass.getMethod("isWindowFunctionExec", classOf[SparkPlan]) + + @transient private[this] lazy val aggregateInPandasExecShimsModule = { + Class.forName("com.nvidia.spark.rapids.shims.AggregateInPandasExecShims" + "$") + .getField("MODULE" + "$") + .get(null) + } + + @transient private[this] lazy val isAggregateInPandasExecMethod = + aggregateInPandasExecShimsModule.getClass.getMethod("isAggregateInPandasExec", + classOf[SparkPlan]) + + @transient private[this] lazy val aggregateInPandasGroupingExpressionsMethod = + aggregateInPandasExecShimsModule.getClass.getMethod("getGroupingExpressions", + classOf[SparkPlan]) + + private def isWindowFunctionExec(plan: SparkPlan): Boolean = + isWindowFunctionExecMethod.invoke(sparkShimImplModule, plan).asInstanceOf[Boolean] + + private def isAggregateInPandasExec(plan: SparkPlan): Boolean = + isAggregateInPandasExecMethod.invoke(aggregateInPandasExecShimsModule, plan) + .asInstanceOf[Boolean] + + private def aggregateInPandasGroupingExpressions(plan: SparkPlan): Seq[_] = + aggregateInPandasGroupingExpressionsMethod.invoke(aggregateInPandasExecShimsModule, plan) + .asInstanceOf[Seq[_]] + def getAggregateFunctionContext(meta: BaseExprMeta[_]): ExpressionContext = { val parent = findParentPlanMeta(meta) assert(parent.isDefined, "It is expected that an aggregate function is a child of a SparkPlan") parent.get.wrapped match { - case agg: SparkPlan if SparkShimImpl.isWindowFunctionExec(agg) => + case agg: SparkPlan if isWindowFunctionExec(agg) => WindowAggExprContext // AggregateInPandasExec renamed to ArrowAggregatePythonExec in Spark 4.1.0 - case agg: SparkPlan if AggregateInPandasExecShims.isAggregateInPandasExec(agg) => - if (AggregateInPandasExecShims.getGroupingExpressions(agg).isEmpty) { + case agg: SparkPlan if isAggregateInPandasExec(agg) => + if (aggregateInPandasGroupingExpressions(agg).isEmpty) { ReductionAggExprContext } else { GroupByAggExprContext @@ -1435,7 +1469,7 @@ abstract class BaseExprMeta[INPUT <: Expression]( val inputMapping = scala.collection.mutable.Map[Int, Int]() gpuInputsWithIndex.foreach { case (gpuExpr, originalIndex) => - val exprWrapper = GpuExpressionEquals(gpuExpr) + val exprWrapper = new GpuExpressionEquals(gpuExpr) seenExpressions.get(exprWrapper) match { case Some(existingIndex) => // This expression is a duplicate - map to existing index diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala index 18055111ee4..13622d07c44 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala @@ -746,30 +746,33 @@ abstract class TypeChecks[RET] { /** * Checks a set of named inputs to an SparkPlan node against a TypeSig */ -case class InputCheck(cudf: TypeSig, spark: TypeSig, notes: List[String] = List.empty) +class InputCheck(val cudf: TypeSig, val spark: TypeSig, val notes: List[String]) + extends Serializable /** * Checks a single parameter by position against a TypeSig */ -case class ParamCheck(name: String, cudf: TypeSig, spark: TypeSig) +class ParamCheck(val name: String, val cudf: TypeSig, val spark: TypeSig) + extends Serializable /** * Checks the type signature for a parameter that repeats (Can only be used at the end of a list * of position parameters) */ -case class RepeatingParamCheck(name: String, cudf: TypeSig, spark: TypeSig) +class RepeatingParamCheck(val name: String, val cudf: TypeSig, val spark: TypeSig) + extends Serializable /** * Checks an expression that have input parameters and a single output. This is intended to be * given for a specific ExpressionContext. If your expression does not meet this pattern you may * need to create a custom ExprChecks instance. */ -case class ContextChecks( - outputCheck: TypeSig, - sparkOutputSig: TypeSig, - paramCheck: Seq[ParamCheck] = Seq.empty, - repeatingParamCheck: Option[RepeatingParamCheck] = None) - extends TypeChecks[Map[String, SupportLevel]] { +class ContextChecks( + val outputCheck: TypeSig, + val sparkOutputSig: TypeSig, + val paramCheck: Seq[ParamCheck], + val repeatingParamCheck: Option[RepeatingParamCheck]) + extends TypeChecks[Map[String, SupportLevel]] with Serializable { def tagAst(exprMeta: BaseExprMeta[_]): Unit = { tagBase(exprMeta, exprMeta.willNotWorkInAst) @@ -965,10 +968,10 @@ object ExecChecks { */ abstract class PartChecks extends TypeChecks[Map[String, SupportLevel]] -case class PartChecksImpl( - paramCheck: Seq[ParamCheck] = Seq.empty, - repeatingParamCheck: Option[RepeatingParamCheck] = None) - extends PartChecks { +class PartChecksImpl( + val paramCheck: Seq[ParamCheck], + val repeatingParamCheck: Option[RepeatingParamCheck]) + extends PartChecks with Serializable { override def tag(meta: RapidsMeta[_, _, _]): Unit = { val part = meta.wrapped @@ -1005,9 +1008,9 @@ case class PartChecksImpl( object PartChecks { def apply(repeatingParamCheck: RepeatingParamCheck): PartChecks = - PartChecksImpl(Seq.empty, Some(repeatingParamCheck)) + new PartChecksImpl(Seq.empty, Some(repeatingParamCheck)) - def apply(): PartChecks = PartChecksImpl() + def apply(): PartChecks = new PartChecksImpl(Seq.empty, None) } /** @@ -1020,7 +1023,7 @@ abstract class ExprChecks extends TypeChecks[Map[ExpressionContext, Map[String, def tagAst(meta: BaseExprMeta[_]): Unit } -case class ExprChecksImpl(contexts: Map[ExpressionContext, ContextChecks]) +class ExprChecksImpl(val contexts: Map[ExpressionContext, ContextChecks]) extends ExprChecks { override def tag(meta: RapidsMeta[_, _, _]): Unit = { val exprMeta = meta.asInstanceOf[BaseExprMeta[_]] @@ -1499,9 +1502,9 @@ object ExprChecks { sparkOutputSig: TypeSig, paramCheck: Seq[ParamCheck] = Seq.empty, repeatingParamCheck: Option[RepeatingParamCheck] = None): ExprChecks = - ExprChecksImpl(Map( + new ExprChecksImpl(Map( (ProjectExprContext, - ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)))) + new ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)))) /** * A check for an expression that supports project and as much of AST as it can. @@ -1514,16 +1517,16 @@ object ExprChecks { repeatingParamCheck: Option[RepeatingParamCheck] = None): ExprChecks = { val astOutputCheck = outputCheck.intersect(allowedAstTypes) val astParamCheck = paramCheck.map { pc => - ParamCheck(pc.name, pc.cudf.intersect(allowedAstTypes), pc.spark) + new ParamCheck(pc.name, pc.cudf.intersect(allowedAstTypes), pc.spark) } val astRepeatingParamCheck = repeatingParamCheck.map { rpc => - RepeatingParamCheck(rpc.name, rpc.cudf.intersect(allowedAstTypes), rpc.spark) + new RepeatingParamCheck(rpc.name, rpc.cudf.intersect(allowedAstTypes), rpc.spark) } - ExprChecksImpl(Map( + new ExprChecksImpl(Map( ProjectExprContext -> - ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck), + new ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck), AstExprContext -> - ContextChecks(astOutputCheck, sparkOutputSig, astParamCheck, astRepeatingParamCheck) + new ContextChecks(astOutputCheck, sparkOutputSig, astParamCheck, astRepeatingParamCheck) )) } @@ -1536,7 +1539,7 @@ object ExprChecks { inputCheck: TypeSig, sparkInputSig: TypeSig): ExprChecks = projectOnly(outputCheck, sparkOutputSig, - Seq(ParamCheck("input", inputCheck, sparkInputSig))) + Seq(new ParamCheck("input", inputCheck, sparkInputSig))) /** * A check for a unary expression that supports project and as much AST as it can. @@ -1548,7 +1551,7 @@ object ExprChecks { inputCheck: TypeSig, sparkInputSig: TypeSig): ExprChecks = projectAndAst(allowedAstTypes, outputCheck, sparkOutputSig, - Seq(ParamCheck("input", inputCheck, sparkInputSig))) + Seq(new ParamCheck("input", inputCheck, sparkInputSig))) /** * Unary expression checks for project where the input matches the output. @@ -1587,8 +1590,8 @@ object ExprChecks { param1: (String, TypeSig, TypeSig), param2: (String, TypeSig, TypeSig)): ExprChecks = projectOnly(outputCheck, sparkOutputSig, - Seq(ParamCheck(param1._1, param1._2, param1._3), - ParamCheck(param2._1, param2._2, param2._3))) + Seq(new ParamCheck(param1._1, param1._2, param1._3), + new ParamCheck(param2._1, param2._2, param2._3))) /** * Helper function for a binary expression where the plugin supports project and AST. @@ -1600,8 +1603,8 @@ object ExprChecks { param1: (String, TypeSig, TypeSig), param2: (String, TypeSig, TypeSig)): ExprChecks = projectAndAst(allowedAstTypes, outputCheck, sparkOutputSig, - Seq(ParamCheck(param1._1, param1._2, param1._3), - ParamCheck(param2._1, param2._2, param2._3))) + Seq(new ParamCheck(param1._1, param1._2, param1._3), + new ParamCheck(param2._1, param2._2, param2._3))) /** * Aggregate operation where only group by agg and reduction is supported in the plugin and in @@ -1612,11 +1615,11 @@ object ExprChecks { sparkOutputSig: TypeSig, paramCheck: Seq[ParamCheck] = Seq.empty, repeatingParamCheck: Option[RepeatingParamCheck] = None): ExprChecks = - ExprChecksImpl(Map( + new ExprChecksImpl(Map( (GroupByAggExprContext, - ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)), + new ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)), (ReductionAggExprContext, - ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)))) + new ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)))) /** * Aggregate operation where window, reduction, and group by agg are all supported the same. @@ -1626,13 +1629,13 @@ object ExprChecks { sparkOutputSig: TypeSig, paramCheck: Seq[ParamCheck] = Seq.empty, repeatingParamCheck: Option[RepeatingParamCheck] = None): ExprChecks = - ExprChecksImpl(Map( + new ExprChecksImpl(Map( (GroupByAggExprContext, - ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)), + new ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)), (ReductionAggExprContext, - ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)), + new ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)), (WindowAggExprContext, - ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)))) + new ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)))) /** * For a generic expression that can work as both an aggregation and in the project context. @@ -1643,15 +1646,15 @@ object ExprChecks { sparkOutputSig: TypeSig, paramCheck: Seq[ParamCheck] = Seq.empty, repeatingParamCheck: Option[RepeatingParamCheck] = None): ExprChecks = - ExprChecksImpl(Map( + new ExprChecksImpl(Map( (GroupByAggExprContext, - ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)), + new ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)), (ReductionAggExprContext, - ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)), + new ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)), (WindowAggExprContext, - ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)), + new ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)), (ProjectExprContext, - ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)))) + new ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)))) /** * An aggregation check where group by and reduction are supported by the plugin, but Spark @@ -1663,18 +1666,18 @@ object ExprChecks { paramCheck: Seq[ParamCheck] = Seq.empty, repeatingParamCheck: Option[RepeatingParamCheck] = None): ExprChecks = { val windowParamCheck = paramCheck.map { pc => - ParamCheck(pc.name, TypeSig.none, pc.spark) + new ParamCheck(pc.name, TypeSig.none, pc.spark) } val windowRepeat = repeatingParamCheck.map { pc => - RepeatingParamCheck(pc.name, TypeSig.none, pc.spark) + new RepeatingParamCheck(pc.name, TypeSig.none, pc.spark) } - ExprChecksImpl(Map( + new ExprChecksImpl(Map( (GroupByAggExprContext, - ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)), + new ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)), (ReductionAggExprContext, - ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)), + new ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)), (WindowAggExprContext, - ContextChecks(TypeSig.none, sparkOutputSig, windowParamCheck, windowRepeat)))) + new ContextChecks(TypeSig.none, sparkOutputSig, windowParamCheck, windowRepeat)))) } /** @@ -1686,9 +1689,9 @@ object ExprChecks { sparkOutputSig: TypeSig, paramCheck: Seq[ParamCheck] = Seq.empty, repeatingParamCheck: Option[RepeatingParamCheck] = None): ExprChecks = - ExprChecksImpl(Map( + new ExprChecksImpl(Map( (WindowAggExprContext, - ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)))) + new ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)))) /** @@ -1701,18 +1704,18 @@ object ExprChecks { paramCheck: Seq[ParamCheck] = Seq.empty, repeatingParamCheck: Option[RepeatingParamCheck] = None): ExprChecks = { val noneParamCheck = paramCheck.map { pc => - ParamCheck(pc.name, TypeSig.none, pc.spark) + new ParamCheck(pc.name, TypeSig.none, pc.spark) } val noneRepeatCheck = repeatingParamCheck.map { pc => - RepeatingParamCheck(pc.name, TypeSig.none, pc.spark) + new RepeatingParamCheck(pc.name, TypeSig.none, pc.spark) } - ExprChecksImpl(Map( + new ExprChecksImpl(Map( (ReductionAggExprContext, - ContextChecks(TypeSig.none, sparkOutputSig, noneParamCheck, noneRepeatCheck)), + new ContextChecks(TypeSig.none, sparkOutputSig, noneParamCheck, noneRepeatCheck)), (GroupByAggExprContext, - ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)), + new ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)), (WindowAggExprContext, - ContextChecks(TypeSig.none, sparkOutputSig, noneParamCheck, noneRepeatCheck)))) + new ContextChecks(TypeSig.none, sparkOutputSig, noneParamCheck, noneRepeatCheck)))) } /** @@ -1725,18 +1728,18 @@ object ExprChecks { paramCheck: Seq[ParamCheck] = Seq.empty, repeatingParamCheck: Option[RepeatingParamCheck] = None): ExprChecks = { val noneParamCheck = paramCheck.map { pc => - ParamCheck(pc.name, TypeSig.none, pc.spark) + new ParamCheck(pc.name, TypeSig.none, pc.spark) } val noneRepeatCheck = repeatingParamCheck.map { pc => - RepeatingParamCheck(pc.name, TypeSig.none, pc.spark) + new RepeatingParamCheck(pc.name, TypeSig.none, pc.spark) } - ExprChecksImpl(Map( + new ExprChecksImpl(Map( (ReductionAggExprContext, - ContextChecks(TypeSig.none, sparkOutputSig, noneParamCheck, noneRepeatCheck)), + new ContextChecks(TypeSig.none, sparkOutputSig, noneParamCheck, noneRepeatCheck)), (GroupByAggExprContext, - ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)), + new ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)), (WindowAggExprContext, - ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)))) + new ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)))) } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/VersionUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/VersionUtils.scala index 929abfed832..3f36ca48fc8 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/VersionUtils.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/VersionUtils.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2025, NVIDIA CORPORATION. + * Copyright (c) 2021-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,9 +18,14 @@ package com.nvidia.spark.rapids import com.nvidia.spark.rapids.jni.{SparkPlatformType => PlatformForJni, Version => VersionForJni} -import org.apache.spark.internal.Logging -object VersionUtils extends Logging { +object VersionUtils { + private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$")) + + private def logWarning(msg: => String): Unit = { + log.warn(msg) + } + lazy val isSpark320OrLater: Boolean = cmpSparkVersion(3, 2, 0) >= 0 diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/BloomFilterShims.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/BloomFilterShims.scala index fe6306ec47e..e1e6da471dd 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/BloomFilterShims.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/BloomFilterShims.scala @@ -41,14 +41,15 @@ object BloomFilterShims { }), GpuOverrides.expr[BloomFilterAggregate]( "Bloom filter build", - ExprChecksImpl(Map( + new ExprChecksImpl(Map( (ReductionAggExprContext, - ContextChecks(TypeSig.BINARY, TypeSig.BINARY, - Seq(ParamCheck("child", TypeSig.LONG, TypeSig.LONG), - ParamCheck("estimatedItems", + new ContextChecks(TypeSig.BINARY, TypeSig.BINARY, + Seq(new ParamCheck("child", TypeSig.LONG, TypeSig.LONG), + new ParamCheck("estimatedItems", TypeSig.lit(TypeEnum.LONG), TypeSig.lit(TypeEnum.LONG)), - ParamCheck("numBits", - TypeSig.lit(TypeEnum.LONG), TypeSig.lit(TypeEnum.LONG))))))), + new ParamCheck("numBits", + TypeSig.lit(TypeEnum.LONG), TypeSig.lit(TypeEnum.LONG))), + None)))), (a, conf, p, r) => new TypedImperativeAggExprMeta[BloomFilterAggregate](a, conf, p, r) { private lazy val estimatedNumItems = GpuBloomFilterAggregate.clampEstimatedNumItems( diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala index d3d0a8fd88f..1dcee177cb8 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala @@ -136,7 +136,7 @@ trait Spark320PlusShims extends SparkShims with RebaseShims TypeSig.DOUBLE + TypeSig.DECIMAL_128, // NullType is not technically allowed by Spark, but in practice in 3.2.0 // it can show up - Seq(ParamCheck("input", + Seq(new ParamCheck("input", TypeSig.integral + TypeSig.fp + TypeSig.DECIMAL_128 + TypeSig.NULL, TypeSig.numericAndInterval + TypeSig.NULL))), (a, conf, p, r) => new AggExprMeta[Average](a, conf, p, r) { @@ -184,11 +184,11 @@ trait Spark320PlusShims extends SparkShims with RebaseShims TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral + TypeSig.DAYTIME, TypeSig.numericAndInterval, Seq( - ParamCheck("lower", + new ParamCheck("lower", TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral + TypeSig.DAYTIME + TypeSig.DECIMAL_128 + TypeSig.FLOAT + TypeSig.DOUBLE, TypeSig.numericAndInterval), - ParamCheck("upper", + new ParamCheck("upper", TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral + TypeSig.DAYTIME + TypeSig.DECIMAL_128 + TypeSig.FLOAT + TypeSig.DOUBLE, TypeSig.numericAndInterval))), @@ -199,8 +199,8 @@ trait Spark320PlusShims extends SparkShims with RebaseShims ExprChecks.windowOnly( TypeSig.all, TypeSig.all, - Seq(ParamCheck("windowFunction", TypeSig.all, TypeSig.all), - ParamCheck("windowSpec", + Seq(new ParamCheck("windowFunction", TypeSig.all, TypeSig.all), + new ParamCheck("windowSpec", TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral + TypeSig.DECIMAL_64 + TypeSig.DAYTIME, TypeSig.numericAndInterval))), (windowExpression, conf, p, r) => new GpuWindowExpressionMeta(windowExpression, conf, p, r)) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/HiveProviderImpl.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/HiveProviderImpl.scala index 50f731c5fa0..f468fb7dcc7 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/HiveProviderImpl.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/HiveProviderImpl.scala @@ -48,7 +48,7 @@ class HiveProviderImpl extends HiveProviderCmdShims { ExprChecks.projectOnly( udfTypeSig, TypeSig.all, - repeatingParamCheck = Some(RepeatingParamCheck("param", udfTypeSig, TypeSig.all))), + repeatingParamCheck = Some(new RepeatingParamCheck("param", udfTypeSig, TypeSig.all))), (a, conf, p, r) => new ExprMeta[HiveSimpleUDF](a, conf, p, r) { val function = createFunction(a) @@ -91,7 +91,7 @@ class HiveProviderImpl extends HiveProviderCmdShims { ExprChecks.projectOnly( udfTypeSig, TypeSig.all, - repeatingParamCheck = Some(RepeatingParamCheck("param", udfTypeSig, TypeSig.all))), + repeatingParamCheck = Some(new RepeatingParamCheck("param", udfTypeSig, TypeSig.all))), (a, conf, p, r) => new ExprMeta[HiveGenericUDF](a, conf, p, r) { val function = createFunction(a) private val opRapidsFunc = function match { diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuScalaUDF.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuScalaUDF.scala index 11a8d2e9409..62f46684980 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuScalaUDF.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuScalaUDF.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2025, NVIDIA CORPORATION. + * Copyright (c) 2021-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -49,7 +49,7 @@ object GpuScalaUDFMeta { GpuUserDefinedFunction.udfTypeSig, TypeSig.all, repeatingParamCheck = - Some(RepeatingParamCheck("param", GpuUserDefinedFunction.udfTypeSig, TypeSig.all))), + Some(new RepeatingParamCheck("param", GpuUserDefinedFunction.udfTypeSig, TypeSig.all))), (expr, conf, p, r) => new ExprMeta(expr, conf, p, r) { lazy val opRapidsFunc = GpuScalaUDF.getRapidsUDFInstance(expr.function) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala index 78b6154f52b..b578197cf64 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala @@ -78,9 +78,9 @@ object JoinTypeChecks { joinRideAlongTypes, TypeSig.all, Map( - LEFT_KEYS -> InputCheck(cudfSupportedKeyTypes, sparkSupportedJoinKeyTypes), - RIGHT_KEYS -> InputCheck(cudfSupportedKeyTypes, sparkSupportedJoinKeyTypes), - CONDITION -> InputCheck(TypeSig.BOOLEAN, TypeSig.BOOLEAN))) + LEFT_KEYS -> new InputCheck(cudfSupportedKeyTypes, sparkSupportedJoinKeyTypes, Nil), + RIGHT_KEYS -> new InputCheck(cudfSupportedKeyTypes, sparkSupportedJoinKeyTypes, Nil), + CONDITION -> new InputCheck(TypeSig.BOOLEAN, TypeSig.BOOLEAN, Nil))) def equiJoinMeta(leftKeys: Seq[BaseExprMeta[_]], rightKeys: Seq[BaseExprMeta[_]], @@ -94,7 +94,7 @@ object JoinTypeChecks { val nonEquiJoinChecks: ExecChecks = ExecChecks( joinRideAlongTypes, TypeSig.all, - Map(CONDITION -> InputCheck(TypeSig.BOOLEAN, TypeSig.BOOLEAN, + Map(CONDITION -> new InputCheck(TypeSig.BOOLEAN, TypeSig.BOOLEAN, notes = List("A non-inner join only is supported if the condition expression can be " + "converted to a GPU AST expression")))) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/zorder/ZOrderRules.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/zorder/ZOrderRules.scala index bb56e8a7602..d4e342188f7 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/zorder/ZOrderRules.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/zorder/ZOrderRules.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2025, NVIDIA CORPORATION. + * Copyright (c) 2022-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -97,7 +97,7 @@ object ZOrderRules { TypeSig.BINARY, TypeSig.BINARY, repeatingParamCheck = - Some(RepeatingParamCheck("input", + Some(new RepeatingParamCheck("input", TypeSig.INT, TypeSig.INT))), (a, conf, p, r) => new ExprMeta[Expression](a, conf, p, r) { @@ -129,7 +129,7 @@ object ZOrderRules { TypeSig.LONG, TypeSig.LONG, repeatingParamCheck = - Some(RepeatingParamCheck("input", + Some(new RepeatingParamCheck("input", TypeSig.INT, TypeSig.INT))), (a, conf, p, r) => new ExprMeta[Expression](a, conf, p, r) { diff --git a/sql-plugin/src/main/spark330db/scala/com/nvidia/spark/rapids/shims/Spark321PlusDBShims.scala b/sql-plugin/src/main/spark330db/scala/com/nvidia/spark/rapids/shims/Spark321PlusDBShims.scala index f86b131dcbd..14e7e8ff610 100644 --- a/sql-plugin/src/main/spark330db/scala/com/nvidia/spark/rapids/shims/Spark321PlusDBShims.scala +++ b/sql-plugin/src/main/spark330db/scala/com/nvidia/spark/rapids/shims/Spark321PlusDBShims.scala @@ -123,8 +123,8 @@ trait Spark321PlusDBShims extends SparkShims TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP).nested(), TypeSig.all, Map("partitionSpec" -> - InputCheck(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, - TypeSig.all))), + new InputCheck(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, + TypeSig.all, Nil))), (runningWindowFunctionExec, conf, p, r) => new GpuRunningWindowExecMeta(runningWindowFunctionExec, conf, p, r) ) @@ -139,9 +139,9 @@ trait Spark321PlusDBShims extends SparkShims GpuOverrides.expr[EphemeralSubstring]( "Ephemeral version of substring operator", ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING + TypeSig.BINARY, - Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING + TypeSig.BINARY), - ParamCheck("pos", TypeSig.INT, TypeSig.INT), - ParamCheck("len", TypeSig.INT, TypeSig.INT))), + Seq(new ParamCheck("str", TypeSig.STRING, TypeSig.STRING + TypeSig.BINARY), + new ParamCheck("pos", TypeSig.INT, TypeSig.INT), + new ParamCheck("len", TypeSig.INT, TypeSig.INT))), (in, conf, p, r) => new TernaryExprMeta[EphemeralSubstring](in, conf, p, r) { override def convertToGpu( column: Expression, diff --git a/sql-plugin/src/main/spark341db/scala/com/nvidia/spark/rapids/shims/Spark341PlusDBShims.scala b/sql-plugin/src/main/spark341db/scala/com/nvidia/spark/rapids/shims/Spark341PlusDBShims.scala index 501e4f3c6c7..224d7c6b0a2 100644 --- a/sql-plugin/src/main/spark341db/scala/com/nvidia/spark/rapids/shims/Spark341PlusDBShims.scala +++ b/sql-plugin/src/main/spark341db/scala/com/nvidia/spark/rapids/shims/Spark341PlusDBShims.scala @@ -79,7 +79,7 @@ trait Spark341PlusDBShims extends Spark332PlusDBShims { // plugin is also an union of all the types of Pandas UDF. (TypeSig.commonCudfTypes + TypeSig.ARRAY).nested() + TypeSig.STRUCT, TypeSig.unionOfPandasUdfOut, - repeatingParamCheck = Some(RepeatingParamCheck( + repeatingParamCheck = Some(new RepeatingParamCheck( "param", (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT).nested(), TypeSig.all))), diff --git a/sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/Spark350PlusNonDBShims.scala b/sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/Spark350PlusNonDBShims.scala index 9af31e67b91..bb6beb5ef1c 100644 --- a/sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/Spark350PlusNonDBShims.scala +++ b/sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/Spark350PlusNonDBShims.scala @@ -130,7 +130,7 @@ trait Spark350PlusNonDBShims extends Spark340PlusNonDBShims { // plugin is also an union of all the types of Pandas UDF. (TypeSig.commonCudfTypes + TypeSig.ARRAY).nested() + TypeSig.STRUCT, TypeSig.unionOfPandasUdfOut, - repeatingParamCheck = Some(RepeatingParamCheck( + repeatingParamCheck = Some(new RepeatingParamCheck( "param", (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT).nested(), TypeSig.all))), @@ -150,8 +150,8 @@ trait Spark350PlusNonDBShims extends Spark340PlusNonDBShims { ExprChecks.projectOnly( TypeSig.all, TypeSig.all, - Seq(ParamCheck("condition", TypeSig.all, TypeSig.all)), - Some(RepeatingParamCheck("outputs", TypeSig.all, TypeSig.all)) + Seq(new ParamCheck("condition", TypeSig.all, TypeSig.all)), + Some(new RepeatingParamCheck("outputs", TypeSig.all, TypeSig.all)) ), (keep, conf, p, r) => new GpuKeepInstructionMeta(keep, conf, p, r)), GpuOverrides.expr[Discard]( @@ -159,15 +159,15 @@ trait Spark350PlusNonDBShims extends Spark340PlusNonDBShims { ExprChecks.projectOnly( TypeSig.all, TypeSig.all, - Seq(ParamCheck("condition", TypeSig.all, TypeSig.all))), + Seq(new ParamCheck("condition", TypeSig.all, TypeSig.all))), (discard, conf, p, r) => new GpuDiscardInstructionMeta(discard, conf, p, r)), GpuOverrides.expr[Split]( "Split instruction for MERGE operations - splits rows into multiple outputs", ExprChecks.projectOnly( TypeSig.all, TypeSig.all, - Seq(ParamCheck("condition", TypeSig.all, TypeSig.all)), - Some(RepeatingParamCheck("outputs", TypeSig.all, TypeSig.all))), + Seq(new ParamCheck("condition", TypeSig.all, TypeSig.all)), + Some(new RepeatingParamCheck("outputs", TypeSig.all, TypeSig.all))), (split, conf, p, r) => new GpuSplitInstructionMeta(split, conf, p, r)) ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap super.getExprs ++ shimExprs