diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoreDumpHandler.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoreDumpHandler.scala index 77c9a9987e8..0ade942599e 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoreDumpHandler.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoreDumpHandler.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,13 +29,32 @@ import org.apache.hadoop.fs.permission.{FsAction, FsPermission} import org.apache.spark.SparkContext import org.apache.spark.api.plugin.PluginContext -import org.apache.spark.internal.Logging import org.apache.spark.io.CompressionCodec import org.apache.spark.sql.SparkSession import org.apache.spark.sql.rapids.execution.TrampolineUtil import org.apache.spark.util.SerializableConfiguration -object GpuCoreDumpHandler extends Logging { +object GpuCoreDumpHandler { + private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$")) + + private def logInfo(msg: => String): Unit = { + if (log.isInfoEnabled) { + log.info(msg) + } + } + + private def logWarning(msg: => String, throwable: Throwable): Unit = { + log.warn(msg, throwable) + } + + private def logError(msg: => String): Unit = { + log.error(msg) + } + + private def logError(msg: => String, throwable: Throwable): Unit = { + log.error(msg, throwable) + } + private var executor: Option[ExecutorService] = None private var dumpedPath: Option[String] = None private var namedPipeFile: File = _ diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala index e2674b1b7f0..822b9c3fee2 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala @@ -27,7 +27,6 @@ import com.nvidia.spark.rapids.jni.RmmSpark import com.nvidia.spark.rapids.spill.SpillFramework import org.apache.spark.{SparkConf, SparkEnv, TaskContext} -import org.apache.spark.internal.Logging import org.apache.spark.network.util.ByteUnit import org.apache.spark.resource.ResourceInformation import org.apache.spark.sql.internal.SQLConf @@ -38,7 +37,29 @@ private case object Initialized extends MemoryState private case object Uninitialized extends MemoryState private case object Errored extends MemoryState -object GpuDeviceManager extends Logging { +object GpuDeviceManager { + private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$")) + + private def logDebug(msg: => String): Unit = { + if (log.isDebugEnabled) { + log.debug(msg) + } + } + + private def logInfo(msg: => String): Unit = { + if (log.isInfoEnabled) { + log.info(msg) + } + } + + private def logWarning(msg: => String): Unit = { + log.warn(msg) + } + + private def logError(msg: => String): Unit = { + log.error(msg) + } + // This config controls whether RMM/Pinned memory are initialized from the task // or from the executor side plugin. The default is to initialize from the // executor plugin. diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuGetJsonObject.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuGetJsonObject.scala index bd6760f6765..8287f37dce8 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuGetJsonObject.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuGetJsonObject.scala @@ -280,7 +280,7 @@ class GetJsonObjectCombiner(private val exp: GpuGetJsonObject) extends GpuExpres override def addExpression(e: Expression): Unit = { val localOutputLocation = outputLocation outputLocation += 1 - val key = GpuExpressionEquals(e) + val key = new GpuExpressionEquals(e) if (!toCombine.contains(key)) { toCombine.put(key, localOutputLocation) } @@ -329,7 +329,7 @@ class GetJsonObjectCombiner(private val exp: GpuGetJsonObject) extends GpuExpres } override def getReplacementExpression(e: Expression): Option[Expression] = { - toCombine.get(GpuExpressionEquals(e)).map { localId => + toCombine.get(new GpuExpressionEquals(e)).map { localId => GpuGetStructField(multiGet, localId, Some(fieldName(localId))) } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/InternalExclusiveModeGpuDiscoveryPlugin.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/InternalExclusiveModeGpuDiscoveryPlugin.scala index 8f5b5ee66b9..b43c6dee487 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/InternalExclusiveModeGpuDiscoveryPlugin.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/InternalExclusiveModeGpuDiscoveryPlugin.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,7 +24,6 @@ import ai.rapids.cudf.Cuda import org.apache.spark.SparkConf import org.apache.spark.api.resource.ResourceDiscoveryPlugin -import org.apache.spark.internal.Logging import org.apache.spark.resource.{ResourceInformation, ResourceRequest} /** @@ -32,7 +31,23 @@ import org.apache.spark.resource.{ResourceInformation, ResourceRequest} * It should be loaded by reflection using ShimLoader.newInstanceOf, see ./docs/dev/shims.md */ protected class InternalExclusiveModeGpuDiscoveryPlugin - extends ResourceDiscoveryPlugin with Logging { + extends ResourceDiscoveryPlugin { + + private val log = org.slf4j.LoggerFactory.getLogger( + classOf[InternalExclusiveModeGpuDiscoveryPlugin]) + + private def logInfo(msg: => String): Unit = { + if (log.isInfoEnabled) { + log.info(msg) + } + } + + private def logWarning(msg: => String): Unit = { + if (log.isWarnEnabled) { + log.warn(msg) + } + } + override def discoverResource( request: ResourceRequest, sparkconf: SparkConf diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala index fc6566dc222..e2d03816c77 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2025, NVIDIA CORPORATION. + * Copyright (c) 2019-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package com.nvidia.spark.rapids -import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Long => JLong, Short => JShort} +import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, + Long => JLong, Short => JShort} import java.math.BigInteger import java.time.{LocalDate, OffsetDateTime} import java.util @@ -31,7 +32,7 @@ import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableProducingArray import com.nvidia.spark.rapids.shims.{GpuTypeShims, SparkShimImpl} import org.apache.commons.codec.binary.{Hex => ApacheHex} -import org.json4s.JsonAST.{JField, JNull, JString} +import org.json4s.JsonAST.{JField, JNull, JString, JValue} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow @@ -685,7 +686,8 @@ case class GpuLiteral (value: Any, dataType: DataType) extends GpuLeafExpression case (l: Long, TimestampType) => JString(DateTimeUtils.toJavaTimestamp(l).toString) case (other, _) => JString(other.toString) } - ("value" -> jsonValue) :: ("dataType" -> TrampolineUtil.jsonValue(dataType)) :: Nil + ("value" -> jsonValue) :: + ("dataType" -> TrampolineUtil.jsonValue(dataType).asInstanceOf[JValue]) :: Nil } override def sql: String = (value, dataType) match { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/python/GpuPythonArguments.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/python/GpuPythonArguments.scala index 10ecb3fbece..2f33689d784 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/python/GpuPythonArguments.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/python/GpuPythonArguments.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2025, NVIDIA CORPORATION. + * Copyright (c) 2025-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,11 +27,11 @@ import org.apache.spark.sql.types.DataType * @param argOffsets The offsets of the original arguments in "flattenedArgs" * @param argNames The optional argument names */ -case class GpuPythonArguments( - flattenedArgs: Seq[Expression], - flattenedTypes: Seq[DataType], - argOffsets: Array[Array[Int]], - argNames: Option[Array[Array[Option[String]]]]) +class GpuPythonArguments( + val flattenedArgs: Seq[Expression], + val flattenedTypes: Seq[DataType], + val argOffsets: Array[Array[Int]], + val argNames: Option[Array[Array[Option[String]]]]) /** Gpu version of ArgumentMetadata */ -case class GpuArgumentMeta(offset: Int, name: Option[String]) +class GpuArgumentMeta(val offset: Int, val name: Option[String]) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/python/PythonWorkerSemaphore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/python/PythonWorkerSemaphore.scala index 521b3340154..844f5453d60 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/python/PythonWorkerSemaphore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/python/PythonWorkerSemaphore.scala @@ -26,7 +26,6 @@ import com.nvidia.spark.rapids.python.PythonConfEntries.CONCURRENT_PYTHON_WORKER import org.apache.commons.lang3.mutable.MutableInt import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.internal.Logging /* * PythonWorkerSemaphore is used to limit the number of Python workers(processes) to be started @@ -41,7 +40,15 @@ import org.apache.spark.internal.Logging * the inner semaphore when no longer needed. * */ -object PythonWorkerSemaphore extends Logging { +object PythonWorkerSemaphore { + private val log = org.slf4j.LoggerFactory.getLogger( + "com.nvidia.spark.rapids.python.PythonWorkerSemaphore") + + private def logDebug(msg: => String): Unit = { + if (log.isDebugEnabled) { + log.debug(msg) + } + } private lazy val rapidsConf = new RapidsConf(SparkEnv.get.conf) private lazy val workersPerGpu = rapidsConf.get(CONCURRENT_PYTHON_WORKERS) @@ -97,7 +104,15 @@ object PythonWorkerSemaphore extends Logging { } } -private final class PythonWorkerSemaphore(tasksPerGpu: Int) extends Logging { +private final class PythonWorkerSemaphore(tasksPerGpu: Int) { + private val log = org.slf4j.LoggerFactory.getLogger(classOf[PythonWorkerSemaphore]) + + private def logDebug(msg: => String): Unit = { + if (log.isDebugEnabled) { + log.debug(msg) + } + } + private val semaphore = new Semaphore(tasksPerGpu) // Map to track which tasks have acquired the semaphore. private val activeTasks = new ConcurrentHashMap[Long, MutableInt] diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/BucketingUtilsShim.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/BucketingUtilsShim.scala index d1844867760..7d55a011dcf 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/BucketingUtilsShim.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/BucketingUtilsShim.scala @@ -43,7 +43,7 @@ object BucketingUtilsShim { // table and a normal one. val bucketIdExpression = GpuHashPartitioning(bucketColumns, spec.numBuckets) .partitionIdExpression - GpuWriterBucketSpec(bucketIdExpression, (_: Int) => "") + new GpuWriterBucketSpec(bucketIdExpression, (_: Int) => "") } } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/parquet/ParquetSchemaClipShims.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/parquet/ParquetSchemaClipShims.scala index d91679a7ef4..87ccb7589b1 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/parquet/ParquetSchemaClipShims.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/parquet/ParquetSchemaClipShims.scala @@ -107,7 +107,7 @@ object ParquetSchemaClipShims { val scale = decimalLogicalTypeAnnotation.getScale if (!(maxPrecision == -1 || 1 <= precision && precision <= maxPrecision)) { - throw new RapidsAnalysisException(s"Invalid decimal precision: $typeName " + + throw RapidsAnalysisException(s"Invalid decimal precision: $typeName " + s"cannot store $precision digits (max $maxPrecision)") } @@ -166,14 +166,14 @@ object ParquetSchemaClipShims { ParquetTimestampAnnotationShims.timestampTypeForMillisOrMicros(timestamp) case timestamp: TimestampLogicalTypeAnnotation if timestamp.getUnit == TimeUnit.NANOS && ParquetLegacyNanoAsLongShims.legacyParquetNanosAsLong => - throw new RapidsAnalysisException( + throw RapidsAnalysisException( "GPU does not support spark.sql.legacy.parquet.nanosAsLong") case _ => illegalType() } case INT96 => if (!SQLConf.get.isParquetINT96AsTimestamp) { - throw new RapidsAnalysisException( + throw RapidsAnalysisException( "INT96 is not supported unless it's interpreted as timestamp. " + s"Please try to set ${SQLConf.PARQUET_INT96_AS_TIMESTAMP.key} to true.") } diff --git a/sql-plugin/src/main/scala/org/apache/spark/rapids/hybrid/CoalesceConvertIterator.scala b/sql-plugin/src/main/scala/org/apache/spark/rapids/hybrid/CoalesceConvertIterator.scala index adfb2f6c58b..7136fa4c368 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/rapids/hybrid/CoalesceConvertIterator.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/rapids/hybrid/CoalesceConvertIterator.scala @@ -22,7 +22,6 @@ import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.hybrid.{CoalesceBatchConverter => NativeConverter, HybridHostRetryAllocator, RapidsHostColumn} import org.apache.spark.TaskContext -import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -36,7 +35,13 @@ class CoalesceConvertIterator(cpuScanIter: Iterator[ColumnarBatch], targetBatchSizeInBytes: Long, schema: StructType, metrics: Map[String, GpuMetric]) - extends Iterator[Array[RapidsHostColumn]] with Logging { + extends Iterator[Array[RapidsHostColumn]] { + + @transient private lazy val log = org.slf4j.LoggerFactory.getLogger( + classOf[CoalesceConvertIterator]) + + private def logInfo(msg: => String): Unit = if (log.isInfoEnabled) log.info(msg) + private var converterImpl: NativeConverter = _ @@ -140,7 +145,7 @@ class CoalesceConvertIterator(cpuScanIter: Iterator[ColumnarBatch], } } -object CoalesceConvertIterator extends Logging { +object CoalesceConvertIterator { /** * Consumes the RapidsHostBatchProducer and converts the HostColumnVectors to Device ones. */ diff --git a/sql-plugin/src/main/scala/org/apache/spark/rapids/hybrid/HybridExecutionUtils.scala b/sql-plugin/src/main/scala/org/apache/spark/rapids/hybrid/HybridExecutionUtils.scala index c4bc1c73ff9..c653cf8e6b1 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/rapids/hybrid/HybridExecutionUtils.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/rapids/hybrid/HybridExecutionUtils.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024-2025, NVIDIA CORPORATION. + * Copyright (c) 2024-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,6 @@ import java.util.Locale import ai.rapids.cudf.DType import com.nvidia.spark.rapids.{RapidsConf, VersionUtils} -import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnresolvedHint} import org.apache.spark.sql.catalyst.trees.TreePattern @@ -33,7 +32,6 @@ import org.apache.spark.sql.rapids.execution.TrampolineUtil import org.apache.spark.sql.types._ object HybridExecutionUtils extends PredicateHelper { - private val HYBRID_JAR_PLUGIN_CLASS_NAME = "com.nvidia.spark.rapids.hybrid.HybridPluginWrapper" /** @@ -434,7 +432,7 @@ object HybridExecutionUtils extends PredicateHelper { } } -object HybridExecOverrides extends Logging { +object HybridExecOverrides { // The SQL hint enables HybridScan for specific tables even if HYBRID_PARQUET_READER is disabled val HYBRID_SCAN_HINT = "HYBRID_SCAN" diff --git a/sql-plugin/src/main/scala/org/apache/spark/rapids/hybrid/RapidsHostBatchProducer.scala b/sql-plugin/src/main/scala/org/apache/spark/rapids/hybrid/RapidsHostBatchProducer.scala index 3212d9bfe3d..0fb80cdb513 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/rapids/hybrid/RapidsHostBatchProducer.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/rapids/hybrid/RapidsHostBatchProducer.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2025, NVIDIA CORPORATION. + * Copyright (c) 2025-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,7 +25,6 @@ import com.nvidia.spark.rapids.hybrid.RapidsHostColumn import com.nvidia.spark.rapids.jni.RmmSpark import org.apache.spark.TaskContext -import org.apache.spark.internal.Logging import org.apache.spark.sql.rapids.execution.TrampolineUtil /** @@ -92,7 +91,15 @@ class PrefetchHostBatchProducer( taskAttId: Long, base: Iterator[Array[RapidsHostColumn]], capacity: Int, - waitTimeMetric: GpuMetric) extends RapidsHostBatchProducer with Logging { + waitTimeMetric: GpuMetric) extends RapidsHostBatchProducer { + + @transient private lazy val log = org.slf4j.LoggerFactory.getLogger( + classOf[PrefetchHostBatchProducer]) + + private def logInfo(msg: => String): Unit = if (log.isInfoEnabled) log.info(msg) + + private def logError(msg: => String): Unit = if (log.isErrorEnabled) log.error(msg) + @volatile private var isInit: Boolean = false // Mark if there is in-progress element being produced in producerThread diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuJsonScan.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuJsonScan.scala index 9e367679891..25bff7901ab 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuJsonScan.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuJsonScan.scala @@ -268,7 +268,7 @@ case class GpuJsonScan( val broadcastedConf = sparkSession.sparkContext.broadcast( new SerializableConfiguration(hadoopConf)) - GpuJsonPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf, + new GpuJsonPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf, dataSchema, readDataSchema, readPartitionSchema, parsedOptions, maxReaderBatchSizeRows, maxReaderBatchSizeBytes, maxGpuColumnSizeBytes, metrics, options.asScala.toMap) } @@ -276,7 +276,7 @@ case class GpuJsonScan( override def withInputFile(): GpuScan = this } -case class GpuJsonPartitionReaderFactory( +class GpuJsonPartitionReaderFactory( sqlConf: SQLConf, broadcastedConf: Broadcast[SerializableConfiguration], dataSchema: StructType, @@ -288,7 +288,8 @@ case class GpuJsonPartitionReaderFactory( maxReaderBatchSizeBytes: Long, maxGpuColumnSizeBytes: Long, metrics: Map[String, GpuMetric], - @transient params: Map[String, String]) extends ShimFilePartitionReaderFactory(params) { + @transient params: Map[String, String]) + extends ShimFilePartitionReaderFactory(params) with Serializable { override def buildReader(partitionedFile: PartitionedFile): PartitionReader[InternalRow] = { throw new IllegalStateException("ROW BASED PARSING IS NOT SUPPORTED ON THE GPU...") diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuReadJsonFileFormat.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuReadJsonFileFormat.scala index 3a80180b154..bb5fcee97a3 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuReadJsonFileFormat.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuReadJsonFileFormat.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2025, NVIDIA CORPORATION. + * Copyright (c) 2022-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -56,7 +56,7 @@ class GpuReadJsonFileFormat extends JsonFileFormat with GpuReadFileFormatWithMet sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) val rapidsConf = new RapidsConf(sqlConf) - val factory = GpuJsonPartitionReaderFactory( + val factory = new GpuJsonPartitionReaderFactory( sqlConf, broadcastedHadoopConf, dataSchema, @@ -81,7 +81,7 @@ class GpuReadJsonFileFormat extends JsonFileFormat with GpuReadFileFormatWithMet } } -object GpuReadJsonFileFormat { +object GpuReadJsonFileFormat extends Serializable { def tagSupport(meta: SparkPlanMeta[FileSourceScanExec]): Unit = { val fsse = meta.wrapped GpuJsonScan.tagSupport( diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/nvidia/LogicalPlanRules.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/nvidia/LogicalPlanRules.scala index 923d92572b5..e621b72b5e8 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/nvidia/LogicalPlanRules.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/nvidia/LogicalPlanRules.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024, NVIDIA CORPORATION. + * Copyright (c) 2024-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,26 +16,65 @@ package org.apache.spark.sql.nvidia -import com.nvidia.spark.rapids.RapidsConf - -import org.apache.spark.internal.Logging +import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf + +object LogicalPlanRules { + private val dfUDFEnabledKey = "spark.rapids.sql.dfudf.enabled" + + private def toBoolean(value: String, key: String): Boolean = { + try { + value.trim.toBoolean + } catch { + case _: IllegalArgumentException => + throw new IllegalArgumentException(s"$key should be boolean, but was $value") + } + } + + private def isDFUDFEnabled(conf: SQLConf): Boolean = { + val value = conf.getConfString(dfUDFEnabledKey, null) + if (value == null) { + true + } else { + toBoolean(value, dfUDFEnabledKey) + } + } + + @transient private[this] lazy val dfUDFShimsModule = { + Class.forName("org.apache.spark.sql.nvidia.DFUDFShims" + "$") + .getField("MODULE" + "$") + .get(null) + } + + @transient private[this] lazy val exprToColumnMethod = + dfUDFShimsModule.getClass.getMethod("exprToColumn", classOf[Expression]) + + @transient private[this] lazy val columnToExprMethod = + dfUDFShimsModule.getClass.getMethod("columnToExpr", classOf[Column]) + + private def exprToColumn(expr: Expression): Column = + exprToColumnMethod.invoke(dfUDFShimsModule, expr).asInstanceOf[Column] + + private def columnToExpr(column: Column): Expression = + columnToExprMethod.invoke(dfUDFShimsModule, column).asInstanceOf[Expression] +} -case class LogicalPlanRules() extends Rule[LogicalPlan] with Logging { +case class LogicalPlanRules() extends Rule[LogicalPlan] { val replacePartialFunc: PartialFunction[Expression, Expression] = { case f: ScalaUDF if DFUDF.getDFUDF(f.function).isDefined => DFUDF.getDFUDF(f.function).map { - dfudf => DFUDFShims.columnToExpr( - dfudf(f.children.map(DFUDFShims.exprToColumn(_)).toArray)) + dfudf => LogicalPlanRules.columnToExpr( + dfudf(f.children.map(LogicalPlanRules.exprToColumn(_)).toArray)) }.getOrElse{ throw new IllegalStateException("Inconsistent results when extracting df_udf") } } override def apply(plan: LogicalPlan): LogicalPlan = { - if (RapidsConf.DFUDF_ENABLED.get(plan.conf)) { + if (LogicalPlanRules.isDFUDFEnabled(plan.conf)) { plan.transformExpressions(replacePartialFunc) } else { plan diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/nvidia/dataframe_udfs.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/nvidia/dataframe_udfs.scala index 79f71ba4ca0..e5187b2a300 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/nvidia/dataframe_udfs.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/nvidia/dataframe_udfs.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024, NVIDIA CORPORATION. + * Copyright (c) 2024-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,7 +26,7 @@ trait DFUDF { def apply(input: Array[Column]): Column } -case class DFUDF0(f: Function0[Column]) +class DFUDF0(val f: Function0[Column]) extends UDF0[Any] with DFUDF { override def call(): Any = { throw new IllegalStateException("TODO better error message. This should have been replaced") @@ -38,7 +38,7 @@ case class DFUDF0(f: Function0[Column]) } } -case class DFUDF1(f: Function1[Column, Column]) +class DFUDF1(val f: Function1[Column, Column]) extends UDF1[Any, Any] with DFUDF { override def call(t1: Any): Any = { throw new IllegalStateException("TODO better error message. This should have been replaced") @@ -50,7 +50,7 @@ case class DFUDF1(f: Function1[Column, Column]) } } -case class DFUDF2(f: Function2[Column, Column, Column]) +class DFUDF2(val f: Function2[Column, Column, Column]) extends UDF2[Any, Any, Any] with DFUDF { override def call(t1: Any, t2: Any): Any = { throw new IllegalStateException("TODO better error message. This should have been replaced") @@ -62,7 +62,7 @@ case class DFUDF2(f: Function2[Column, Column, Column]) } } -case class DFUDF3(f: Function3[Column, Column, Column, Column]) +class DFUDF3(val f: Function3[Column, Column, Column, Column]) extends UDF3[Any, Any, Any, Any] with DFUDF { override def call(t1: Any, t2: Any, t3: Any): Any = { throw new IllegalStateException("TODO better error message. This should have been replaced") @@ -74,7 +74,7 @@ case class DFUDF3(f: Function3[Column, Column, Column, Column]) } } -case class DFUDF4(f: Function4[Column, Column, Column, Column, Column]) +class DFUDF4(val f: Function4[Column, Column, Column, Column, Column]) extends UDF4[Any, Any, Any, Any, Any] with DFUDF { override def call(t1: Any, t2: Any, t3: Any, t4: Any): Any = { throw new IllegalStateException("TODO better error message. This should have been replaced") @@ -86,7 +86,7 @@ case class DFUDF4(f: Function4[Column, Column, Column, Column, Column]) } } -case class DFUDF5(f: Function5[Column, Column, Column, Column, Column, Column]) +class DFUDF5(val f: Function5[Column, Column, Column, Column, Column, Column]) extends UDF5[Any, Any, Any, Any, Any, Any] with DFUDF { override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any): Any = { throw new IllegalStateException("TODO better error message. This should have been replaced") @@ -98,7 +98,7 @@ case class DFUDF5(f: Function5[Column, Column, Column, Column, Column, Column]) } } -case class DFUDF6(f: Function6[Column, Column, Column, Column, Column, Column, Column]) +class DFUDF6(val f: Function6[Column, Column, Column, Column, Column, Column, Column]) extends UDF6[Any, Any, Any, Any, Any, Any, Any] with DFUDF { override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any): Any = { throw new IllegalStateException("TODO better error message. This should have been replaced") @@ -110,7 +110,7 @@ case class DFUDF6(f: Function6[Column, Column, Column, Column, Column, Column, C } } -case class DFUDF7(f: Function7[Column, Column, Column, Column, Column, Column, Column, Column]) +class DFUDF7(val f: Function7[Column, Column, Column, Column, Column, Column, Column, Column]) extends UDF7[Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF { override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any): Any = { throw new IllegalStateException("TODO better error message. This should have been replaced") @@ -122,7 +122,7 @@ case class DFUDF7(f: Function7[Column, Column, Column, Column, Column, Column, C } } -case class DFUDF8(f: Function8[Column, Column, Column, Column, Column, Column, Column, Column, +class DFUDF8(val f: Function8[Column, Column, Column, Column, Column, Column, Column, Column, Column]) extends UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF { override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any, t8: Any): Any = { @@ -135,7 +135,7 @@ case class DFUDF8(f: Function8[Column, Column, Column, Column, Column, Column, C } } -case class DFUDF9(f: Function9[Column, Column, Column, Column, Column, Column, Column, Column, +class DFUDF9(val f: Function9[Column, Column, Column, Column, Column, Column, Column, Column, Column, Column]) extends UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF { override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any, t8: Any, @@ -149,7 +149,7 @@ case class DFUDF9(f: Function9[Column, Column, Column, Column, Column, Column, C } } -case class DFUDF10(f: Function10[Column, Column, Column, Column, Column, Column, Column, Column, +class DFUDF10(val f: Function10[Column, Column, Column, Column, Column, Column, Column, Column, Column, Column, Column]) extends UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF { override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any, t8: Any, @@ -164,7 +164,7 @@ case class DFUDF10(f: Function10[Column, Column, Column, Column, Column, Column, } } -case class JDFUDF0(f: UDF0[Column]) +class JDFUDF0(val f: UDF0[Column]) extends UDF0[Any] with DFUDF { override def call(): Any = { throw new IllegalStateException("TODO better error message. This should have been replaced") @@ -176,7 +176,7 @@ case class JDFUDF0(f: UDF0[Column]) } } -case class JDFUDF1(f: UDF1[Column, Column]) +class JDFUDF1(val f: UDF1[Column, Column]) extends UDF1[Any, Any] with DFUDF { override def call(t1: Any): Any = { throw new IllegalStateException("TODO better error message. This should have been replaced") @@ -188,7 +188,7 @@ case class JDFUDF1(f: UDF1[Column, Column]) } } -case class JDFUDF2(f: UDF2[Column, Column, Column]) +class JDFUDF2(val f: UDF2[Column, Column, Column]) extends UDF2[Any, Any, Any] with DFUDF { override def call(t1: Any, t2: Any): Any = { throw new IllegalStateException("TODO better error message. This should have been replaced") @@ -200,7 +200,7 @@ case class JDFUDF2(f: UDF2[Column, Column, Column]) } } -case class JDFUDF3(f: UDF3[Column, Column, Column, Column]) +class JDFUDF3(val f: UDF3[Column, Column, Column, Column]) extends UDF3[Any, Any, Any, Any] with DFUDF { override def call(t1: Any, t2: Any, t3: Any): Any = { throw new IllegalStateException("TODO better error message. This should have been replaced") @@ -212,7 +212,7 @@ case class JDFUDF3(f: UDF3[Column, Column, Column, Column]) } } -case class JDFUDF4(f: UDF4[Column, Column, Column, Column, Column]) +class JDFUDF4(val f: UDF4[Column, Column, Column, Column, Column]) extends UDF4[Any, Any, Any, Any, Any] with DFUDF { override def call(t1: Any, t2: Any, t3: Any, t4: Any): Any = { throw new IllegalStateException("TODO better error message. This should have been replaced") @@ -224,7 +224,7 @@ case class JDFUDF4(f: UDF4[Column, Column, Column, Column, Column]) } } -case class JDFUDF5(f: UDF5[Column, Column, Column, Column, Column, Column]) +class JDFUDF5(val f: UDF5[Column, Column, Column, Column, Column, Column]) extends UDF5[Any, Any, Any, Any, Any, Any] with DFUDF { override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any): Any = { throw new IllegalStateException("TODO better error message. This should have been replaced") @@ -236,7 +236,7 @@ case class JDFUDF5(f: UDF5[Column, Column, Column, Column, Column, Column]) } } -case class JDFUDF6(f: UDF6[Column, Column, Column, Column, Column, Column, Column]) +class JDFUDF6(val f: UDF6[Column, Column, Column, Column, Column, Column, Column]) extends UDF6[Any, Any, Any, Any, Any, Any, Any] with DFUDF { override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any): Any = { throw new IllegalStateException("TODO better error message. This should have been replaced") @@ -248,7 +248,7 @@ case class JDFUDF6(f: UDF6[Column, Column, Column, Column, Column, Column, Colum } } -case class JDFUDF7(f: UDF7[Column, Column, Column, Column, Column, Column, Column, Column]) +class JDFUDF7(val f: UDF7[Column, Column, Column, Column, Column, Column, Column, Column]) extends UDF7[Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF { override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any): Any = { throw new IllegalStateException("TODO better error message. This should have been replaced") @@ -260,7 +260,7 @@ case class JDFUDF7(f: UDF7[Column, Column, Column, Column, Column, Column, Colum } } -case class JDFUDF8(f: UDF8[Column, Column, Column, Column, Column, Column, Column, Column, +class JDFUDF8(val f: UDF8[Column, Column, Column, Column, Column, Column, Column, Column, Column]) extends UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF { override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any, t8: Any): Any = { @@ -273,7 +273,7 @@ case class JDFUDF8(f: UDF8[Column, Column, Column, Column, Column, Column, Colum } } -case class JDFUDF9(f: UDF9[Column, Column, Column, Column, Column, Column, Column, Column, +class JDFUDF9(val f: UDF9[Column, Column, Column, Column, Column, Column, Column, Column, Column, Column]) extends UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF { override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any, t8: Any, @@ -287,7 +287,7 @@ case class JDFUDF9(f: UDF9[Column, Column, Column, Column, Column, Column, Colum } } -case class JDFUDF10(f: UDF10[Column, Column, Column, Column, Column, Column, Column, Column, +class JDFUDF10(val f: UDF10[Column, Column, Column, Column, Column, Column, Column, Column, Column, Column, Column]) extends UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF { override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any, t8: Any, diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/BasicColumnarWriteStatsTracker.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/BasicColumnarWriteStatsTracker.scala index 85707aad5fc..142d1df9e7a 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/BasicColumnarWriteStatsTracker.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/BasicColumnarWriteStatsTracker.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2025, NVIDIA CORPORATION. + * Copyright (c) 2019-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,7 +26,6 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.{SparkContext, TaskContext} -import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.WriteTaskStats @@ -35,19 +34,6 @@ import org.apache.spark.sql.rapids.BasicColumnarWriteJobStatsTracker._ import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.SerializableConfiguration -/** - * Simple metrics collected during an instance of [[GpuFileFormatDataWriter]]. - * These were first introduced in https://github.com/apache/spark/pull/18159 (SPARK-20703). - */ -case class BasicColumnarWriteTaskStats( - partitions: Seq[InternalRow], - numFiles: Int, - numWriters: Int, - numBytes: Long, - numRows: Long) - extends WriteTaskStats - - /** * Simple metrics collected during an instance of [[GpuFileFormatDataWriter]]. * This is the columnar version of @@ -56,7 +42,20 @@ case class BasicColumnarWriteTaskStats( class BasicColumnarWriteTaskStatsTracker( hadoopConf: Configuration, taskCommitTimeMetric: Option[GpuMetric]) - extends ColumnarWriteTaskStatsTracker with Logging { + extends ColumnarWriteTaskStatsTracker { + + private val log = org.slf4j.LoggerFactory.getLogger(classOf[BasicColumnarWriteTaskStatsTracker]) + + private def logInfo(msg: => String): Unit = if (log.isInfoEnabled) log.info(msg) + + private def logWarning(msg: => String): Unit = if (log.isWarnEnabled) log.warn(msg) + + private def logDebug(msg: => String): Unit = if (log.isDebugEnabled) log.debug(msg) + + private def logDebug(msg: => String, throwable: Throwable): Unit = { + if (log.isDebugEnabled) log.debug(msg, throwable) + } + private[this] val partitions: mutable.ArrayBuffer[InternalRow] = mutable.ArrayBuffer.empty private[this] var numFiles: Int = 0 private[this] var numSubmittedFiles: Int = 0 @@ -186,7 +185,7 @@ class BasicColumnarWriteTaskStatsTracker( "or files being not immediately visible in the filesystem.") } taskCommitTimeMetric.foreach(_ += taskCommitTime) - BasicColumnarWriteTaskStats(partitions.toSeq, numFiles, maxNumWriters, numBytes, numRows) + new BasicColumnarWriteTaskStats(partitions.toSeq, numFiles, maxNumWriters, numBytes, numRows) } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/BridgeGenerateUnsafeProjection.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/BridgeGenerateUnsafeProjection.scala index fbdc197ce80..ea64ccb4b71 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/BridgeGenerateUnsafeProjection.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/BridgeGenerateUnsafeProjection.scala @@ -16,6 +16,8 @@ package org.apache.spark.sql.rapids +import scala.util.control.NonFatal + import com.nvidia.spark.rapids.RapidsHostColumnBuilder import org.apache.spark.sql.catalyst.InternalRow @@ -65,17 +67,48 @@ class InterpretedBridgeUnsafeProjection(expressions: Seq[Expression]) /** * The factory object for `UnsafeProjection`. */ -object BridgeUnsafeProjection - extends CodeGeneratorWithInterpretedFallback[Seq[Expression], BridgeUnsafeProjection] { +object BridgeUnsafeProjection { + + def createOptimizedAppendFunction(dataType: DataType, + nullable: Boolean): (Any, RapidsHostColumnBuilder) => Unit = { + BridgeUnsafeProjectionCodegen.createOptimizedAppendFunction(dataType, nullable) + } + + def create(schema: StructType): BridgeUnsafeProjection = create(schema.fields.map(_.dataType)) + + def create(fields: Array[DataType]): BridgeUnsafeProjection = { + create(fields.zipWithIndex.map(x => BoundReference(x._2, x._1, true))) + } + + def create(exprs: Seq[Expression]): BridgeUnsafeProjection = { + BridgeUnsafeProjectionCodegen.create(exprs) + } + + def create(expr: Expression): BridgeUnsafeProjection = create(Seq(expr)) + + def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): BridgeUnsafeProjection = { + create(bindReferences(exprs, inputSchema)) + } +} + +private object BridgeUnsafeProjectionCodegen { + private[this] val log = org.slf4j.LoggerFactory.getLogger(getClass) + + private def createObject(in: Seq[Expression]): BridgeUnsafeProjection = { + try { + createCodeGeneratedObject(in) + } catch { + case NonFatal(e) => + log.warn("Expr codegen error and falling back to interpreter mode", e) + createInterpretedObject(in) + } + } - override protected def createCodeGeneratedObject(in: Seq[Expression]): BridgeUnsafeProjection = { - // Just call generate directly - let any exceptions propagate naturally - // The CodeGeneratorWithInterpretedFallback base class will catch exceptions - // and fall back to createInterpretedObject + private def createCodeGeneratedObject(in: Seq[Expression]): BridgeUnsafeProjection = { BridgeGenerateUnsafeProjection.generate(in, SQLConf.get.subexpressionEliminationEnabled) } - override protected def createInterpretedObject(in: Seq[Expression]): BridgeUnsafeProjection = { + private def createInterpretedObject(in: Seq[Expression]): BridgeUnsafeProjection = { new InterpretedBridgeUnsafeProjection(in) } @@ -238,8 +271,8 @@ object BridgeUnsafeProjection * * @note The returned UnsafeRow will be pointed to a scratch buffer inside the projection. */ -object BridgeGenerateUnsafeProjection extends - CodeGenerator[Seq[Expression], BridgeUnsafeProjection] { +object BridgeGenerateUnsafeProjection { + private val codegenLog = org.slf4j.LoggerFactory.getLogger(getClass) case class Schema(dataType: DataType, nullable: Boolean) @@ -547,6 +580,18 @@ object BridgeGenerateUnsafeProjection extends protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = bindReferences(in, inputSchema) + def newCodeGenContext(): CodegenContext = new CodegenContext + + def generate(expressions: Seq[Expression]): BridgeUnsafeProjection = { + create(canonicalize(expressions)) + } + + def generate( + expressions: Seq[Expression], + inputSchema: Seq[Attribute]): BridgeUnsafeProjection = { + generate(bind(expressions, inputSchema)) + } + def generate( expressions: Seq[Expression], subexpressionEliminationEnabled: Boolean): BridgeUnsafeProjection = { @@ -640,7 +685,9 @@ object BridgeGenerateUnsafeProjection extends val code = CodeFormatter.stripOverlappingComments( new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) - logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") + if (codegenLog.isDebugEnabled) { + codegenLog.debug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") + } val (clazz, _) = CodeGenerator.compile(code) clazz.generate(ctx.references.toArray).asInstanceOf[BridgeUnsafeProjection] diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala index cae457272e8..11d6baa6691 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala @@ -23,7 +23,6 @@ import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.delta.DeltaProvider import com.nvidia.spark.rapids.iceberg.IcebergProvider -import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.connector.catalog.SupportsWrite @@ -40,7 +39,11 @@ import org.apache.spark.util.Utils * spark-avro classes because `class not found` exception may throw if spark-avro does not * exist at runtime. Details see: https://github.com/NVIDIA/spark-rapids/issues/5648 */ -trait ExternalSourceBase extends Logging { +trait ExternalSourceBase { + @transient private lazy val log = org.slf4j.LoggerFactory.getLogger(classOf[ExternalSourceBase]) + + private def logWarning(msg: => String): Unit = if (log.isWarnEnabled) log.warn(msg) + val avroScanClassName = "org.apache.spark.sql.v2.avro.AvroScan" lazy val hasSparkAvroJar = { /** spark-avro is an optional package for Spark, so the RAPIDS Accelerator diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala index 7d21ec91452..f0c0a4de66b 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala @@ -27,11 +27,11 @@ import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.RmmRapidsRetryIterator.withRetryNoSplit import com.nvidia.spark.rapids.fileio.hadoop.HadoopFileIO +import com.nvidia.spark.rapids.fileio.hadoop.PerfIOHadoopInputFileFactory import com.nvidia.spark.rapids.shims.GpuFileFormatDataWriterShim import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.TaskAttemptContext -import org.apache.spark.internal.Logging import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils} @@ -348,12 +348,12 @@ class GpuDynamicPartitionDataSingleWriter( } /** - * A case class to hold the batch, the optional partition values and the optional bucket + * A class to hold the batch, the optional partition values and the optional bucket * ID for a split group. All the rows in the batch belong to the group defined by the * partition values and the bucket ID. */ - private case class SplitPack(split: SpillableColumnarBatch, partValues: Option[InternalRow], - bucketId: Option[Int]) extends AutoCloseable { + private class SplitPack(val split: SpillableColumnarBatch, val partValues: Option[InternalRow], + val bucketId: Option[Int]) extends AutoCloseable { override def close(): Unit = { split.safeClose() } @@ -546,7 +546,7 @@ class GpuDynamicPartitionDataSingleWriter( val split = splits(idx) splits(idx) = null closeOnExcept(split) { _ => - SplitPack( + new SplitPack( SpillableColumnarBatch(split, outDataTypes, SpillPriorities.ACTIVE_BATCHING_PRIORITY), getNextPartValues(idx), getBucketId(idx)) @@ -673,7 +673,10 @@ class GpuDynamicPartitionDataSingleWriter( // The input batch that is entirely sorted, so split it up by partitions and (or) // bucket ids, and write the split batches one by one. withResource(splitBatchByKeyAndClose(batch)) { splitPacks => - splitPacks.zipWithIndex.foreach { case (SplitPack(sp, partVals, bucketId), i) => + splitPacks.zipWithIndex.foreach { case (pack, i) => + val sp = pack.split + val partVals = pack.partValues + val bucketId = pack.bucketId val hasDiffPart = partVals != currentWriterId.partitionValues val hasDiffBucket = bucketId != currentWriterId.bucketId if (hasDiffPart || hasDiffBucket) { @@ -719,7 +722,7 @@ class GpuDynamicPartitionDataConcurrentWriter( debugOutputBasePath: Option[String]) extends GpuDynamicPartitionDataSingleWriter(description, taskAttemptContext, committer, debugOutputBasePath) - with Logging { + with RapidsLocalLog { /** Wrapper class for status and caches of a unique concurrent output writer. */ private class WriterStatusWithBatches extends WriterAndStatus with AutoCloseable { @@ -975,9 +978,9 @@ class GpuDynamicPartitionDataConcurrentWriter( * @param bucketIdExpression Expression to calculate bucket id based on bucket column(s). * @param bucketFileNamePrefix Prefix of output file name based on bucket id. */ -case class GpuWriterBucketSpec( - bucketIdExpression: GpuExpression, - bucketFileNamePrefix: Int => String) +class GpuWriterBucketSpec( + val bucketIdExpression: GpuExpression, + val bucketFileNamePrefix: Int => String) extends Serializable /** * A shared job description for all the GPU write tasks. @@ -999,7 +1002,9 @@ class GpuWriteJobDescription( val concurrentWriterPartitionFlushSize: Long) extends Serializable { - lazy val fileIO: HadoopFileIO = new HadoopFileIO(serializableHadoopConf.value) + lazy val fileIO: HadoopFileIO = new HadoopFileIO( + serializableHadoopConf.value, + PerfIOHadoopInputFileFactory.INSTANCE) assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns), s""" @@ -1053,6 +1058,6 @@ object BucketIdMetaUtils { // The bucket file name prefix is following Hive, Presto and Trino conversion, then // Hive bucketed tables written by Plugin can be read by other SQL engines. val fileNamePrefix = (bucketId: Int) => f"$bucketId%05d_0_" - GpuWriterBucketSpec(bucketIdExpression, fileNamePrefix) + new GpuWriterBucketSpec(bucketIdExpression, fileNamePrefix) } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala index ecc5a1bb7f5..f3359696d63 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala @@ -22,9 +22,22 @@ import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.shims.ShuffleManagerShimUtils import org.apache.spark.{SparkConf, SparkEnv} -import org.apache.spark.internal.Logging -class GpuShuffleEnv(rapidsConf: RapidsConf) extends Logging { +class GpuShuffleEnv(rapidsConf: RapidsConf) { + private val log = org.slf4j.LoggerFactory.getLogger(classOf[GpuShuffleEnv]) + + private def logInfo(msg: => String): Unit = { + if (log.isInfoEnabled) { + log.info(msg) + } + } + + private def logWarning(msg: => String): Unit = { + if (log.isWarnEnabled) { + log.warn(msg) + } + } + private var shuffleCatalog: ShuffleBufferCatalog = _ private var shuffleReceivedBufferCatalog: ShuffleReceivedBufferCatalog = _ private var multithreadedCatalog: MultithreadedShuffleBufferCatalog = _ @@ -89,7 +102,7 @@ class GpuShuffleEnv(rapidsConf: RapidsConf) extends Logging { } } -object GpuShuffleEnv extends Logging { +object GpuShuffleEnv { def isUCXShuffleAndEarlyStart(conf: RapidsConf): Boolean = { conf.isUCXShuffleManagerMode && conf.shuffleTransportEarlyStart diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuTaskMetrics.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuTaskMetrics.scala index b5c761a01c9..c82329a52b2 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuTaskMetrics.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuTaskMetrics.scala @@ -18,57 +18,18 @@ package org.apache.spark.sql.rapids import java.{lang => jl} import java.io.ObjectInputStream -import java.util.Locale import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicLong import ai.rapids.cudf.{NvtxColor, NvtxRange} -import com.nvidia.spark.rapids.{NvtxId, NvtxRegistry, PerfIO} +import com.nvidia.spark.rapids.{NvtxId, NvtxRegistry} import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion import com.nvidia.spark.rapids.jni.RmmSpark import org.apache.spark.{SparkContext, TaskContext} -import org.apache.spark.internal.Logging import org.apache.spark.util.{AccumulatorV2, LongAccumulator, Utils} -case class NanoTime(value: java.lang.Long) { - override def toString: String = { - val hours = TimeUnit.NANOSECONDS.toHours(value) - var remaining = value - TimeUnit.HOURS.toNanos(hours) - val minutes = TimeUnit.NANOSECONDS.toMinutes(remaining) - remaining = remaining - TimeUnit.MINUTES.toNanos(minutes) - val seconds = remaining.toDouble / TimeUnit.SECONDS.toNanos(1) - val locale = Locale.US - "%02d:%02d:%06.3f".formatLocal(locale, hours, minutes, seconds) - } -} - -// Format example: -// 10.74GB (11534336000 bytes) -// 1.23MB (1289750 bytes) -// 1020.10KB (1044585 bytes) -case class SizeInBytes(value: jl.Long) { - override def toString: String = { - var unitVal = value - var remainVal = 0L - var unitIndex = 0 - while (unitIndex < SizeInBytes.SizeUnitNames.length && unitVal >= 1024) { - val nextUnitVal = unitVal >> 10 - remainVal = unitVal - (nextUnitVal << 10) - unitVal = nextUnitVal - unitIndex += 1 - } - val finalVal = "%.2f".format(unitVal + (remainVal.toDouble / 1024)) - s"$finalVal${SizeInBytes.SizeUnitNames(unitIndex)} ($value bytes)" - } -} - -private object SizeInBytes { - private val SizeUnitNames: Array[String] = Array("B", "KB", "MB", "GB", "TB", "PB", "EB") -} - class NanoSecondAccumulator extends AccumulatorV2[jl.Long, NanoTime] { private var _sum = 0L override def isZero: Boolean = _sum == 0 @@ -100,7 +61,7 @@ class NanoSecondAccumulator extends AccumulatorV2[jl.Long, NanoTime] { s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") } - override def value: NanoTime = NanoTime(_sum) + override def value: NanoTime = new NanoTime(_sum) } /** @@ -133,7 +94,7 @@ class SizeInBytesAccumulator extends AccumulatorV2[jl.Long, SizeInBytes] { s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") } - override def value: SizeInBytes = SizeInBytes(_sum) + override def value: SizeInBytes = new SizeInBytes(_sum) private[spark] def setValue(newValue: Long): Unit = _sum = newValue } @@ -164,7 +125,7 @@ class HighWatermarkAccumulator extends AccumulatorV2[jl.Long, SizeInBytes] { s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") } - override def value: SizeInBytes = SizeInBytes(_value) + override def value: SizeInBytes = new SizeInBytes(_value) } class MaxLongAccumulator extends AccumulatorV2[jl.Long, jl.Long] { @@ -242,7 +203,7 @@ class AvgLongAccumulator extends AccumulatorV2[jl.Long, jl.Double] { } else 0; } -class GpuTaskMetrics extends Serializable with Logging { +class GpuTaskMetrics extends Serializable { private val semaphoreHoldingTime = new NanoSecondAccumulator private val semWaitTimeNs = new NanoSecondAccumulator private val retryCount = new LongAccumulator @@ -467,7 +428,8 @@ class GpuTaskMetrics extends Serializable with Logging { // allocations lives in the JNI. Therefore, we can stick the convention here of calling the // add method instead of adding a dedicated max method to the accumulator. if (maxDeviceMemoryBytes.value.value > 0) { - logError(s"updateMaxMemory called twice for task $taskAttemptId with maxMem $maxMem") + GpuTaskMetrics.log.error(s"updateMaxMemory called twice for task $taskAttemptId " + + s"with maxMem $maxMem") } maxDeviceMemoryBytes.add(maxMem) } @@ -515,13 +477,13 @@ class GpuTaskMetrics extends Serializable with Logging { * to prevent double-counting — each new stage creates fresh accumulators with new IDs. */ def recordPerfioS3BackendOnce(): Unit = { - val acc = PerfIO.s3BackendName match { + val acc = GpuTaskMetrics.perfIOS3BackendName match { case "netty" => perfioS3NettyExecutors case "crt" => perfioS3CrtExecutors case _ => perfioS3S3aExecutors } try { - if (PerfIO.reportedBackendAccIds.add(acc.id)) { + if (GpuTaskMetrics.perfIOReportedBackendAccIds.add(acc.id)) { acc.add(1L) } } catch { @@ -537,7 +499,26 @@ class GpuTaskMetrics extends Serializable with Logging { /** * Provides task level metrics */ -object GpuTaskMetrics extends Logging { +object GpuTaskMetrics { + @transient private[this] lazy val perfIOModule = { + Class.forName("com.nvidia.spark.rapids.PerfIO" + "$") + .getField("MODULE" + "$") + .get(null) + } + + @transient private[this] lazy val perfIOS3BackendNameMethod = + perfIOModule.getClass.getMethod("s3BackendName") + @transient private[this] lazy val perfIOReportedBackendAccIdsMethod = + perfIOModule.getClass.getMethod("reportedBackendAccIds") + + private def perfIOS3BackendName: String = + perfIOS3BackendNameMethod.invoke(perfIOModule).asInstanceOf[String] + + private def perfIOReportedBackendAccIds: java.util.Set[java.lang.Long] = + perfIOReportedBackendAccIdsMethod.invoke(perfIOModule) + .asInstanceOf[java.util.Set[java.lang.Long]] + + private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$")) private val taskLevelMetrics = new ConcurrentHashMap[Long, GpuTaskMetrics]() private val hostBytesAllocated = new AtomicLong(0) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala index 69eff4ca387..96a0beba1a7 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala @@ -37,7 +37,7 @@ import com.nvidia.spark.rapids.spill.SpillablePartialFileHandle import org.apache.spark.{InterruptibleIterator, MapOutputTracker, ShuffleDependency, SparkConf, SparkEnv, TaskContext} import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.internal.{config, Logging} +import org.apache.spark.internal.config import org.apache.spark.io.CompressionCodec import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.serializer.SerializerManager @@ -71,7 +71,7 @@ class ShuffleHandleWithMetrics[K, V, C]( abstract class GpuShuffleBlockResolverBase( val wrapped: IndexShuffleBlockResolver, catalog: ShuffleBufferCatalog) - extends ShuffleBlockResolver with Logging { + extends ShuffleBlockResolver with RapidsLocalLog { override def getBlockData(blockId: BlockId, dirs: Option[Array[String]]): ManagedBuffer = { // Get MultithreadedShuffleBufferCatalog dynamically since it may not be // initialized when the resolver is created @@ -147,7 +147,7 @@ class ThreadSafeShuffleWriteMetricsReporter(val wrapped: ShuffleWriteMetricsRepo } } -object RapidsShuffleInternalManagerBase extends Logging { +object RapidsShuffleInternalManagerBase extends RapidsLocalLog { def unwrapHandle(handle: ShuffleHandle): ShuffleHandle = handle match { case gh: GpuShuffleHandle[_, _] => gh.wrapped case other => other @@ -322,11 +322,11 @@ abstract class RapidsShuffleThreadedWriterBase[K, V]( private var shuffleWriteRange: NvtxId = NvtxRegistry.THREADED_WRITER_WRITE.push() - // Case class for tracking partial sorted files in multi-batch scenario - private case class PartialFile( - handle: SpillablePartialFileHandle, - partitionLengths: Array[Long], - mapOutputWriter: ShuffleMapOutputWriter) + // Class for tracking partial sorted files in multi-batch scenario + private class PartialFile( + val handle: SpillablePartialFileHandle, + val partitionLengths: Array[Long], + val mapOutputWriter: ShuffleMapOutputWriter) /** * Represents a single compressed record ready to be written to disk. @@ -337,10 +337,10 @@ abstract class RapidsShuffleThreadedWriterBase[K, V]( * @param compressedSize The actual size of compressed data in buffer * @param remainingQuota The quota to release after writing to disk */ - private case class CompressedRecord( - buffer: OpenByteArrayOutputStream, - compressedSize: Long, - remainingQuota: Long) + private class CompressedRecord( + val buffer: OpenByteArrayOutputStream, + val compressedSize: Long, + val remainingQuota: Long) /** * Encapsulates all state for processing one GPU batch in the multi-batch shuffle write. @@ -370,19 +370,19 @@ abstract class RapidsShuffleThreadedWriterBase[K, V]( * @param mergerSlotNum The merger thread pool slot assigned to this batch. * @param mergerFuture Future representing the merger task, used to wait for completion. */ - private case class BatchState( - batchId: Int, - mapOutputWriter: ShuffleMapOutputWriter, - partitionRecords: ConcurrentHashMap[Int, + private class BatchState( + val batchId: Int, + val mapOutputWriter: ShuffleMapOutputWriter, + val partitionRecords: ConcurrentHashMap[Int, ConcurrentLinkedQueue[Future[CompressedRecord]]], - maxPartitionIdQueued: AtomicInteger, - mergerCondition: Object, + val maxPartitionIdQueued: AtomicInteger, + val mergerCondition: Object, // Flag for classic wait/notify pattern: set to true when new work is available, // reset to false after merger thread wakes up and checks actual data state. // This avoids busy-loop polling and provides clear signal for debugging. - hasNewWork: AtomicBoolean, - mergerSlotNum: Int, - mergerFuture: Future[_]) + val hasNewWork: AtomicBoolean, + val mergerSlotNum: Int, + val mergerFuture: Future[_]) /** * Increment the reference count and get the memory size for a value. @@ -568,7 +568,7 @@ abstract class RapidsShuffleThreadedWriterBase[K, V]( null }) - BatchState( + new BatchState( batchId, writer, partitionRecords, @@ -763,7 +763,7 @@ abstract class RapidsShuffleThreadedWriterBase[K, V]( // Return CompressedRecord with buffer and remaining quota for Merger // Total released = excessQuota + remainingQuota should equal recordSize val remainingQuota = recordSize - excessQuota - CompressedRecord(buffer, compressedSize, remainingQuota) + new CompressedRecord(buffer, compressedSize, remainingQuota) } } catch { case e: Exception => @@ -823,7 +823,7 @@ abstract class RapidsShuffleThreadedWriterBase[K, V]( // For multi-batch or when using catalog mode, extract handle val (handle, partLengths) = extractHandleAndLengthsFromWriter( batch.mapOutputWriter) - partialFiles += PartialFile(handle, partLengths, batch.mapOutputWriter) + partialFiles += new PartialFile(handle, partLengths, batch.mapOutputWriter) } else { // Single batch without catalog: commit normally commitAllPartitions(batch.mapOutputWriter, true) @@ -1121,7 +1121,7 @@ abstract class RapidsShuffleThreadedReaderBase[K, C]( mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker, canUseBatchFetch: Boolean = false, numReaderThreads: Int = 0) - extends ShuffleReader[K, C] with Logging { + extends ShuffleReader[K, C] with RapidsLocalLog { case class GetMapSizesResult( blocksByAddress: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])], @@ -1712,8 +1712,8 @@ class RapidsCachingWriter[K, V]( * Apache Spark to use the RAPIDS shuffle manager, */ class RapidsShuffleInternalManagerBase(conf: SparkConf, val isDriver: Boolean) - extends ShuffleManager with RapidsShuffleHeartbeatHandler with Logging - with RapidsShuffleReaderShim with ProxyShuffleReaderDelegate { + extends ShuffleManager with RapidsShuffleHeartbeatHandler + with RapidsLocalLog with RapidsShuffleReaderShim with ProxyShuffleReaderDelegate { def getServerId: BlockManagerId = server.fold(blockManager.blockManagerId)(_.getId) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ShuffleCleanupListener.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ShuffleCleanupListener.scala index 6fdae81c68a..f91cfbf4829 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ShuffleCleanupListener.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ShuffleCleanupListener.scala @@ -23,7 +23,6 @@ import scala.collection.mutable import com.nvidia.spark.rapids.ShuffleCleanupManager -import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd @@ -48,7 +47,22 @@ import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd * Note: This file is placed in org.apache.spark.sql.rapids package to access * the private[spark] shuffleDepId field in StageInfo. */ -class ShuffleCleanupListener extends SparkListener with Logging { +class ShuffleCleanupListener extends SparkListener { + + private val log = org.slf4j.LoggerFactory.getLogger(classOf[ShuffleCleanupListener]) + + private def logInfo(msg: => String): Unit = { + if (log.isInfoEnabled) { + log.info(msg) + } + } + + private def logDebug(msg: => String): Unit = { + if (log.isDebugEnabled) { + log.debug(msg) + } + } + /** * Maps SQL execution ID to the set of shuffle IDs associated with that execution. diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/aggregate/aggregateFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/aggregate/aggregateFunctions.scala index 2b9c5ea2b5e..dc71c038f95 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/aggregate/aggregateFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/aggregate/aggregateFunctions.scala @@ -338,7 +338,7 @@ abstract class GpuMin(child: Expression) extends GpuAggregateFunction override def groupByScanAggregation( isRunningBatched: Boolean): Seq[AggAndReplace[GroupByScanAggregation]] = - Seq(AggAndReplace(GroupByScanAggregation.min(), Some(ReplacePolicy.PRECEDING))) + Seq(new AggAndReplace(GroupByScanAggregation.min(), Some(ReplacePolicy.PRECEDING))) override def isGroupByScanSupported: Boolean = child.dataType match { case StringType | TimestampType | DateType => false @@ -347,7 +347,7 @@ abstract class GpuMin(child: Expression) extends GpuAggregateFunction override def scanInputProjection(isRunningBatched: Boolean): Seq[Expression] = inputProjection override def scanAggregation(isRunningBatched: Boolean): Seq[AggAndReplace[ScanAggregation]] = - Seq(AggAndReplace(ScanAggregation.min(), Some(ReplacePolicy.PRECEDING))) + Seq(new AggAndReplace(ScanAggregation.min(), Some(ReplacePolicy.PRECEDING))) override def isScanSupported: Boolean = child.dataType match { case TimestampType | DateType => false @@ -522,7 +522,7 @@ abstract class GpuMax(child: Expression) extends GpuAggregateFunction override def groupByScanAggregation( isRunningBatched: Boolean): Seq[AggAndReplace[GroupByScanAggregation]] = - Seq(AggAndReplace(GroupByScanAggregation.max(), Some(ReplacePolicy.PRECEDING))) + Seq(new AggAndReplace(GroupByScanAggregation.max(), Some(ReplacePolicy.PRECEDING))) override def isGroupByScanSupported: Boolean = child.dataType match { case StringType | TimestampType | DateType => false @@ -531,7 +531,7 @@ abstract class GpuMax(child: Expression) extends GpuAggregateFunction override def scanInputProjection(isRunningBatched: Boolean): Seq[Expression] = inputProjection override def scanAggregation(isRunningBatched: Boolean): Seq[AggAndReplace[ScanAggregation]] = - Seq(AggAndReplace(ScanAggregation.max(), Some(ReplacePolicy.PRECEDING))) + Seq(new AggAndReplace(ScanAggregation.max(), Some(ReplacePolicy.PRECEDING))) override def isScanSupported: Boolean = child.dataType match { case TimestampType | DateType => false @@ -1044,13 +1044,13 @@ abstract class GpuSum( override def groupByScanAggregation( isRunningBatched: Boolean): Seq[AggAndReplace[GroupByScanAggregation]] = - Seq(AggAndReplace(GroupByScanAggregation.sum(), Some(ReplacePolicy.PRECEDING))) + Seq(new AggAndReplace(GroupByScanAggregation.sum(), Some(ReplacePolicy.PRECEDING))) override def scanInputProjection(isRunningBatched: Boolean): Seq[Expression] = windowInputProjection override def scanAggregation(isRunningBatched: Boolean): Seq[AggAndReplace[ScanAggregation]] = - Seq(AggAndReplace(ScanAggregation.sum(), Some(ReplacePolicy.PRECEDING))) + Seq(new AggAndReplace(ScanAggregation.sum(), Some(ReplacePolicy.PRECEDING))) override def scanCombine(isRunningBatched: Boolean, cols: Seq[ColumnVector]): ColumnVector = { if (internalSumForWindowDataType != resultType) { @@ -1498,13 +1498,13 @@ case class GpuCount(children: Seq[Expression], override def groupByScanAggregation( isRunningBatched: Boolean): Seq[AggAndReplace[GroupByScanAggregation]] = - Seq(AggAndReplace(GroupByScanAggregation.sum(), None)) + Seq(new AggAndReplace(GroupByScanAggregation.sum(), None)) override def scanInputProjection(isRunningBatched: Boolean): Seq[Expression] = groupByScanInputProjection(isRunningBatched) override def scanAggregation(isRunningBatched: Boolean): Seq[AggAndReplace[ScanAggregation]] = - Seq(AggAndReplace(ScanAggregation.sum(), None)) + Seq(new AggAndReplace(ScanAggregation.sum(), None)) override def scanCombine(isRunningBatched: Boolean, cols: Seq[ColumnVector]): ColumnVector = cols.head.castTo(DType.INT64) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/catalyst/expressions/GpuEquivalentExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/catalyst/expressions/GpuEquivalentExpressions.scala index 916bb2335da..314b2747176 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/catalyst/expressions/GpuEquivalentExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/catalyst/expressions/GpuEquivalentExpressions.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2025, NVIDIA CORPORATION. + * Copyright (c) 2022-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -50,7 +50,7 @@ class GpuEquivalentExpressions { private def addExprToMap( expr: Expression, map: mutable.HashMap[GpuExpressionEquals, GpuExpressionStats]): Boolean = { if (expr.deterministic) { - val wrapper = GpuExpressionEquals(expr) + val wrapper = new GpuExpressionEquals(expr) map.get(wrapper) match { case Some(stats) => stats.useCount += 1 @@ -242,7 +242,7 @@ class GpuEquivalentExpressions { * Exposed for testing. */ private[sql] def getExprState(e: Expression): Option[GpuExpressionStats] = { - equivalenceMap.get(GpuExpressionEquals(e)) + equivalenceMap.get(new GpuExpressionEquals(e)) } // Exposed for testing. @@ -281,7 +281,7 @@ object GpuEquivalentExpressions { expr match { case e: AttributeReference => e case _ => - substitutionMap.get(GpuExpressionEquals(expr)) match { + substitutionMap.get(new GpuExpressionEquals(expr)) match { case Some(attr) => attr case None => expr.mapChildren(replaceWithSemanticCommonRef(_, substitutionMap)) } @@ -510,7 +510,7 @@ trait GpuCombinable extends GpuExpression { /** * Wrapper around an Expression that provides semantic equality. */ -case class GpuExpressionEquals(e: Expression) { +class GpuExpressionEquals(val e: Expression) { override def equals(o: Any): Boolean = o match { case other: GpuExpressionEquals => e.semanticEquals(other.e) case _ => false diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/catalyst/expressions/GpuRandomExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/catalyst/expressions/GpuRandomExpressions.scala index 3f3ae2cf11e..b51fef4f83a 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/catalyst/expressions/GpuRandomExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/catalyst/expressions/GpuRandomExpressions.scala @@ -100,7 +100,7 @@ case class GpuRand(child: Expression, doContextCheck: Boolean) extends ShimUnary private lazy val seed: Long = child match { case GpuLiteral(s, IntegerType) => s.asInstanceOf[Int] case GpuLiteral(s, LongType) => s.asInstanceOf[Long] - case _ => throw new RapidsAnalysisException( + case _ => throw RapidsAnalysisException( s"Input argument to $prettyName must be an integer, long or null literal.") } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala index 5b86ae9bf96..ac1a582de43 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala @@ -29,6 +29,7 @@ import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.jni.{GpuListSliceUtils, MapUtils} import com.nvidia.spark.rapids.shims.{GetSequenceSize, NullIntolerantShim, ShimExpression} +import org.apache.spark.internal.config.ConfigEntry import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.{ElementAt, ExpectsInputTypes, Expression, ImplicitCastInputTypes, NamedExpression, RowOrdering, Sequence, TimeZoneAwareExpression} import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} @@ -39,6 +40,14 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH +object GpuMapDedupPolicy { + private val confEntry = SQLConf.MAP_KEY_DEDUP_POLICY.asInstanceOf[ConfigEntry[AnyRef]] + + def current: String = SQLConf.get.getConf(confEntry).toString + + def isException: Boolean = current.toUpperCase == "EXCEPTION" +} + case class GpuConcat(children: Seq[Expression]) extends GpuComplexTypeMergingExpression { @transient override lazy val dataType: DataType = { @@ -740,7 +749,7 @@ case class GpuMapEntries(child: Expression) extends GpuUnaryExpression with Expe case class GpuMapFromEntries(child: Expression) extends GpuUnaryExpression with ExpectsInputTypes { - private val mapKeyDedupPolicy = SQLConf.get.getConf(SQLConf.MAP_KEY_DEDUP_POLICY) + private val mapKeyDedupPolicy = GpuMapDedupPolicy.current override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) @@ -1534,7 +1543,7 @@ case class GpuArraysOverlap(left: Expression, right: Expression) case class GpuMapFromArrays(left: Expression, right: Expression) extends GpuBinaryExpression { - private val mapKeyDedupPolicy = SQLConf.get.getConf(SQLConf.MAP_KEY_DEDUP_POLICY) + private val mapKeyDedupPolicy = GpuMapDedupPolicy.current override def dataType: MapType = { MapType( diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeCreator.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeCreator.scala index 4e21fa07ab1..e7f60b1eb8e 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeCreator.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeCreator.scala @@ -151,9 +151,7 @@ object GpuCreateMap { SQLConf.get.getConf(SQLConf.LEGACY_CREATE_EMPTY_COLLECTION_USING_STRING_TYPE)) } - // Spark 4.1+ returns an enum value instead of String, so use toString first - def exceptionOnDupKeys: Boolean = - SQLConf.get.getConf(SQLConf.MAP_KEY_DEDUP_POLICY).toString.toUpperCase == "EXCEPTION" + def exceptionOnDupKeys: Boolean = GpuMapDedupPolicy.isException def createMapFromKeysValuesAsStructs(dataType: MapType, listsOfKeyValueStructs : ColumnView): GpuColumnVector = { diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala index c1c4d664fe9..676687d86e8 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala @@ -36,7 +36,6 @@ import org.apache.spark.sql.vectorized.ColumnarBatch case class GpuGetStructField(child: Expression, ordinal: Int, name: Option[String] = None) extends ShimUnaryExpression with GpuExpression - with ShimGetStructField with NullIntolerantShim { lazy val childSchema: StructType = child.dataType.asInstanceOf[StructType] diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala index ee0988d7f90..9915c49a706 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala @@ -37,7 +37,6 @@ import com.nvidia.spark.rapids.shims.{BroadcastExchangeShims, ShimBroadcastExcha import org.apache.spark.SparkException import org.apache.spark.broadcast.Broadcast -import org.apache.spark.internal.Logging import org.apache.spark.launcher.SparkLauncher import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -75,7 +74,7 @@ class SerializeConcatHostBuffersDeserializeBatch( output: Seq[Attribute], var numRows: Int, var dataLen: Long) - extends Serializable with Logging { + extends Serializable { @transient private var dataTypes = output.map(_.dataType).toArray // used for memoization of deserialization to GPU on Executor @@ -323,7 +322,7 @@ class GpuBroadcastMeta( conf: RapidsConf, parent: Option[RapidsMeta[_, _, _]], rule: DataFromReplacementRule) extends - SparkPlanMeta[BroadcastExchangeExec](exchange, conf, parent, rule) with Logging { + SparkPlanMeta[BroadcastExchangeExec](exchange, conf, parent, rule) { override def tagPlanForGpu(): Unit = { if (!TrampolineUtil.isSupportedRelation(exchange.mode)) { @@ -643,7 +642,7 @@ case class GpuBroadcastExchangeExec( } /** Caches the mappings from canonical CPU exchanges to the GPU exchanges that replaced them */ -object ExchangeMappingCache extends Logging { +object ExchangeMappingCache { // Cache is a mapping from CPU broadcast plan to GPU broadcast plan. The cache should not // artificially hold onto unused plans, so we make both the keys and values weak. The values // point to their corresponding keys, so the keys will not be collected unless the value diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastToRowExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastToRowExec.scala index 2b2c11ada2a..2f6d51aba84 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastToRowExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastToRowExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2025, NVIDIA CORPORATION. + * Copyright (c) 2022-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,7 +28,6 @@ import com.nvidia.spark.rapids.shims.{ShimBroadcastExchangeLike, ShimUnaryExecNo import org.apache.spark.SparkException import org.apache.spark.broadcast.Broadcast -import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, UnsafeProjection} @@ -46,7 +45,7 @@ case class GpuBroadcastToRowExec( buildKeys: Seq[Expression], broadcastMode: BroadcastMode, child: SparkPlan) - extends ShimBroadcastExchangeLike with ShimUnaryExecNode with GpuExec with Logging { + extends ShimBroadcastExchangeLike with ShimUnaryExecNode with GpuExec { @transient private val timeout: Long = conf.broadcastTimeout diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala index b578197cf64..5d4087852c4 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala @@ -78,9 +78,9 @@ object JoinTypeChecks { joinRideAlongTypes, TypeSig.all, Map( - LEFT_KEYS -> new InputCheck(cudfSupportedKeyTypes, sparkSupportedJoinKeyTypes, Nil), - RIGHT_KEYS -> new InputCheck(cudfSupportedKeyTypes, sparkSupportedJoinKeyTypes, Nil), - CONDITION -> new InputCheck(TypeSig.BOOLEAN, TypeSig.BOOLEAN, Nil))) + LEFT_KEYS -> new InputCheck(cudfSupportedKeyTypes, sparkSupportedJoinKeyTypes, List.empty), + RIGHT_KEYS -> new InputCheck(cudfSupportedKeyTypes, sparkSupportedJoinKeyTypes, List.empty), + CONDITION -> new InputCheck(TypeSig.BOOLEAN, TypeSig.BOOLEAN, List.empty))) def equiJoinMeta(leftKeys: Seq[BaseExprMeta[_]], rightKeys: Seq[BaseExprMeta[_]], @@ -1063,44 +1063,6 @@ object JoinBuildSideSelection extends Enumeration { } } -/** - * Options to control join behavior. - * @param strategy the join strategy to use (AUTO, INNER_HASH_WITH_POST, INNER_SORT_WITH_POST, - * or HASH_ONLY) - * @param buildSideSelection the build side selection strategy (AUTO, FIXED, or SMALLEST) - * @param targetSize the target batch size in bytes for the join operation - * @param logCardinalityEnabled whether to log cardinality statistics for debugging - * @param sizeEstimateThreshold the threshold used to decide when to skip the expensive join - * output size estimation (defaults to 0.75) - */ -case class JoinOptions( - strategy: JoinStrategy.JoinStrategy, - buildSideSelection: JoinBuildSideSelection.JoinBuildSideSelection, - targetSize: Long, - logCardinalityEnabled: Boolean, - sizeEstimateThreshold: Double) - -/** - * Statistics for join cardinality logging to help diagnose performance issues. - * @param leftRowCount number of rows on the left side - * @param rightRowCount number of rows on the right side - * @param leftDistinctCount distinct count of left join keys - * @param rightDistinctCount distinct count of right join keys - * @param leftNullCounts null counts for each left key column - * @param rightNullCounts null counts for each right key column - * @param leftKeyTypes data types of the left join keys - * @param rightKeyTypes data types of the right join keys - */ -case class JoinCardinalityStats( - leftRowCount: Long, - rightRowCount: Long, - leftDistinctCount: Long, - rightDistinctCount: Long, - leftNullCounts: Seq[Long], - rightNullCounts: Seq[Long], - leftKeyTypes: Seq[DataType], - rightKeyTypes: Seq[DataType]) - /** * Class to hold statistics on the build-side batch of a hash join. * @param streamMagnificationFactor estimated magnification of a stream batch during join @@ -1188,6 +1150,12 @@ abstract class BaseHashJoinIterator( * Compute cardinality statistics for both sides of the join. * This is used for diagnostic logging when logJoinCardinality is enabled. */ + protected def joinStrategy: JoinStrategy.JoinStrategy = + joinOptions.strategy.asInstanceOf[JoinStrategy.JoinStrategy] + + protected def buildSideSelection: JoinBuildSideSelection.JoinBuildSideSelection = + joinOptions.buildSideSelection.asInstanceOf[JoinBuildSideSelection.JoinBuildSideSelection] + protected def computeCardinalityStats( leftKeys: Table, rightKeys: Table): JoinCardinalityStats = { @@ -1207,7 +1175,7 @@ abstract class BaseHashJoinIterator( val leftKeyTypes = boundBuiltKeys.map(_.dataType) val rightKeyTypes = boundStreamKeys.map(_.dataType) - JoinCardinalityStats( + new JoinCardinalityStats( leftRowCount, rightRowCount, leftDistinctCount, @@ -1550,10 +1518,10 @@ class HashJoinIterator( rightData: LazySpillableColumnarBatch): GatherMapsResult = { // Apply heuristics to select the effective strategy val effectiveStrategy = JoinStrategy.selectStrategy( - joinOptions.strategy, + joinStrategy, joinType, hasCondition = false, // This is called for unconditional joins - joinOptions.buildSideSelection, + buildSideSelection, leftKeys.getRowCount, rightKeys.getRowCount) @@ -1594,7 +1562,7 @@ class HashJoinIterator( logJoinCardinality(leftKeys, rightKeys, implName) val innerMaps = JoinImpl.innerHashJoin(leftKeys, rightKeys, compareNullsEqual, - joinOptions.buildSideSelection, buildSide) + buildSideSelection, buildSide) val leftRowCount = leftKeys.getRowCount val rightRowCount = rightKeys.getRowCount @@ -1610,7 +1578,7 @@ class HashJoinIterator( logJoinCardinality(leftKeys, rightKeys, "INNER_SORT_WITH_POST") val innerMaps = JoinImpl.innerSortJoin(leftKeys, rightKeys, compareNullsEqual, - joinOptions.buildSideSelection, buildSide) + buildSideSelection, buildSide) val leftRowCount = leftKeys.getRowCount val rightRowCount = rightKeys.getRowCount @@ -1632,7 +1600,7 @@ class HashJoinIterator( JoinImpl.rightOuterHashJoinBuildLeft(leftKeys, rightKeys, compareNullsEqual) case _: InnerLike => JoinImpl.innerHashJoin(leftKeys, rightKeys, compareNullsEqual, - joinOptions.buildSideSelection, buildSide) + buildSideSelection, buildSide) case LeftSemi => JoinImpl.leftSemiHashJoinBuildRight(leftKeys, rightKeys, compareNullsEqual) case LeftAnti => @@ -1696,10 +1664,10 @@ class ConditionalHashJoinIterator( withResource(GpuColumnVector.from(rightData.getBatch)) { rightTable => // Apply heuristics to select the effective strategy for conditional joins val effectiveStrategy = JoinStrategy.selectStrategy( - joinOptions.strategy, + joinStrategy, joinType, hasCondition = true, // This is a conditional join - joinOptions.buildSideSelection, + buildSideSelection, leftKeys.getRowCount, rightKeys.getRowCount) @@ -1750,7 +1718,7 @@ class ConditionalHashJoinIterator( val rightRowCount = rightKeys.getRowCount val innerMaps = JoinImpl.innerHashJoin(leftKeys, rightKeys, - nullEquality == NullEquality.EQUAL, joinOptions.buildSideSelection, buildSide) + nullEquality == NullEquality.EQUAL, buildSideSelection, buildSide) val compiledCondition = lazyCompiledCondition.getForBuildSide(buildSide) @@ -1777,7 +1745,7 @@ class ConditionalHashJoinIterator( val rightRowCount = rightKeys.getRowCount val innerMaps = JoinImpl.innerSortJoin(leftKeys, rightKeys, - nullEquality == NullEquality.EQUAL, joinOptions.buildSideSelection, buildSide) + nullEquality == NullEquality.EQUAL, buildSideSelection, buildSide) val compiledCondition = lazyCompiledCondition.getForBuildSide(buildSide) @@ -1804,7 +1772,7 @@ class ConditionalHashJoinIterator( case _: InnerLike => // For inner joins, use dynamic build side selection val selectedBuildSide = JoinBuildSideSelection.selectPhysicalBuildSide( - joinOptions.buildSideSelection, buildSide, + buildSideSelection, buildSide, leftKeys.getRowCount, rightKeys.getRowCount) selectedBuildSide match { case GpuBuildLeft => @@ -1933,10 +1901,10 @@ class HashJoinStreamSideIterator( // Apply heuristics to select the effective strategy for unconditional joins // Note: subJoinType is used for strategy selection since that's what we're actually executing val effectiveStrategy = JoinStrategy.selectStrategy( - joinOptions.strategy, + joinStrategy, subJoinType, hasCondition = false, - joinOptions.buildSideSelection, + buildSideSelection, leftKeys.getRowCount, rightKeys.getRowCount) @@ -1977,7 +1945,7 @@ class HashJoinStreamSideIterator( logJoinCardinality(leftKeys, rightKeys, implName, originalJoinType) val innerMaps = JoinImpl.innerHashJoin(leftKeys, rightKeys, compareNullsEqual, - joinOptions.buildSideSelection, cudfBuildSide) + buildSideSelection, cudfBuildSide) val leftRowCount = leftKeys.getRowCount val rightRowCount = rightKeys.getRowCount @@ -1996,7 +1964,7 @@ class HashJoinStreamSideIterator( originalJoinType) val innerMaps = JoinImpl.innerSortJoin(leftKeys, rightKeys, compareNullsEqual, - joinOptions.buildSideSelection, cudfBuildSide) + buildSideSelection, cudfBuildSide) val leftRowCount = leftKeys.getRowCount val rightRowCount = rightKeys.getRowCount @@ -2020,7 +1988,7 @@ class HashJoinStreamSideIterator( JoinImpl.rightOuterHashJoinBuildLeft(leftKeys, rightKeys, compareNullsEqual) case Inner => JoinImpl.innerHashJoin(leftKeys, rightKeys, compareNullsEqual, - joinOptions.buildSideSelection, cudfBuildSide) + buildSideSelection, cudfBuildSide) case t => throw new IllegalStateException(s"unsupported join type: $t") } @@ -2041,10 +2009,10 @@ class HashJoinStreamSideIterator( withResource(GpuColumnVector.from(rightData.getBatch)) { rightTable => // Apply heuristics to select the effective strategy for conditional joins val effectiveStrategy = JoinStrategy.selectStrategy( - joinOptions.strategy, + joinStrategy, subJoinType, hasCondition = true, - joinOptions.buildSideSelection, + buildSideSelection, leftKeys.getRowCount, rightKeys.getRowCount) @@ -2093,7 +2061,7 @@ class HashJoinStreamSideIterator( logJoinCardinality(leftKeys, rightKeys, implName, originalJoinType) val innerMaps = JoinImpl.innerHashJoin(leftKeys, rightKeys, compareNullsEqual, - joinOptions.buildSideSelection, cudfBuildSide) + buildSideSelection, cudfBuildSide) val compiledCondition = lazyCondition.getForBuildSide(cudfBuildSide) @@ -2122,7 +2090,7 @@ class HashJoinStreamSideIterator( originalJoinType) val innerMaps = JoinImpl.innerSortJoin(leftKeys, rightKeys, compareNullsEqual, - joinOptions.buildSideSelection, cudfBuildSide) + buildSideSelection, cudfBuildSide) val compiledCondition = lazyCondition.getForBuildSide(cudfBuildSide) @@ -2161,7 +2129,7 @@ class HashJoinStreamSideIterator( // For inner sub-joins, use dynamic build side selection // For sub-joins, the plan build side is cudfBuildSide (GpuBuildRight for Inner) val selectedBuildSide = JoinBuildSideSelection.selectPhysicalBuildSide( - joinOptions.buildSideSelection, cudfBuildSide, + buildSideSelection, cudfBuildSide, leftKeys.getRowCount, rightKeys.getRowCount) selectedBuildSide match { case GpuBuildLeft => diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubPartitionHashJoin.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubPartitionHashJoin.scala index 12faa151a4b..4eebeb813db 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubPartitionHashJoin.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubPartitionHashJoin.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2025, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2023-2026, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,12 +18,13 @@ package org.apache.spark.sql.rapids.execution import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import com.nvidia.spark.rapids.{GpuBatchUtils, GpuColumnVector, GpuExpression, GpuHashPartitioner, GpuMetric, NvtxRegistry, RmmRapidsRetryIterator, SpillableColumnarBatch, SpillPriorities, TaskAutoCloseableResource} +import com.nvidia.spark.rapids.{GpuBatchUtils, GpuColumnVector, GpuExpression, GpuHashPartitioner, + GpuMetric, NvtxRegistry, RapidsLocalLog, RmmRapidsRetryIterator, SpillableColumnarBatch, + SpillPriorities, TaskAutoCloseableResource} import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.RapidsPluginImplicits._ import org.apache.spark.TaskContext -import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.plans.InnerLike import org.apache.spark.sql.rapids.{GpuHashExpression, GpuMurmur3Hash} import org.apache.spark.sql.rapids.shims.DataTypeUtilsShim @@ -89,7 +90,7 @@ class GpuBatchSubPartitioner( numPartitions: Int, hashSeed: Int, name: String = "GpuBatchSubPartitioner") extends GpuHashPartitioner - with AutoCloseable with Logging { + with AutoCloseable with RapidsLocalLog { private var isNotInited = true private var numCurBatches = 0 @@ -228,7 +229,7 @@ class GpuBatchSubPartitioner( class GpuBatchSubPartitionIterator( batchSubPartitioner: GpuBatchSubPartitioner, targetBatchSize: Long) - extends Iterator[(Seq[Int], Seq[SpillableColumnarBatch])] with Logging { + extends Iterator[(Seq[Int], Seq[SpillableColumnarBatch])] with RapidsLocalLog { // The partitions to be read. Initially it is all the partitions. private val remainingPartIds: ArrayBuffer[Int] = @@ -558,7 +559,7 @@ abstract class BaseSubHashJoinIterator( protected def setupJoinIterator(pair: PartitionPair): Option[Iterator[ColumnarBatch]] } -trait GpuSubPartitionHashJoin extends Logging { self: GpuHashJoin => +trait GpuSubPartitionHashJoin { self: GpuHashJoin => protected lazy val buildSchema: StructType = DataTypeUtilsShim.fromAttributes(buildPlan.output) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/InternalColumnarRddConverter.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/InternalColumnarRddConverter.scala index 02eaafb5ed4..3707023e834 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/InternalColumnarRddConverter.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/InternalColumnarRddConverter.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2019-2026, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,6 @@ import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.GpuColumnVector.GpuColumnarBatchBuilder import org.apache.spark.TaskContext -import org.apache.spark.internal.Logging import org.apache.spark.rdd.{MapPartitionsRDD, RDD} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.InternalRow @@ -102,22 +101,22 @@ private object GpuExternalRowToColumnConverter { // NOT SUPPORTED YET // case CalendarIntervalType => CalendarConverter case (at: ArrayType, true) => - ArrayConverter(getConverterForType(at.elementType, at.containsNull)) + new ArrayConverter(getConverterForType(at.elementType, at.containsNull)) case (at: ArrayType, false) => - NotNullArrayConverter(getConverterForType(at.elementType, at.containsNull)) + new NotNullArrayConverter(getConverterForType(at.elementType, at.containsNull)) case (st: StructType, true) => - StructConverter(st.fields.map(getConverterFor)) + new StructConverter(st.fields.map(getConverterFor)) case (st: StructType, false) => - NotNullStructConverter(st.fields.map(getConverterFor)) + new NotNullStructConverter(st.fields.map(getConverterFor)) case (dt: DecimalType, true) => new DecimalConverter(dt.precision, dt.scale) case (dt: DecimalType, false) => new NotNullDecimalConverter(dt.precision, dt.scale) case (MapType(k, v, vcn), true) => - MapConverter(getConverterForType(k, nullable = false), + new MapConverter(getConverterForType(k, nullable = false), getConverterForType(v, vcn)) case (MapType(k, v, vcn), false) => - NotNullMapConverter(getConverterForType(k, nullable = false), + new NotNullMapConverter(getConverterForType(k, nullable = false), getConverterForType(v, vcn)) case (NullType, true) => NullConverter @@ -394,7 +393,7 @@ private object GpuExternalRowToColumnConverter { ret + OFFSET } - private case class MapConverter( + private class MapConverter( keyConverter: TypeConverter, valueConverter: TypeConverter) extends TypeConverter { override def append(row: Row, @@ -410,7 +409,7 @@ private object GpuExternalRowToColumnConverter { override def getNullSize: Double = VALIDITY_N_OFFSET } - private case class NotNullMapConverter( + private class NotNullMapConverter( keyConverter: TypeConverter, valueConverter: TypeConverter) extends TypeConverter { override def append(row: Row, @@ -453,7 +452,7 @@ private object GpuExternalRowToColumnConverter { ret + OFFSET } - private case class ArrayConverter(childConverter: TypeConverter) + private class ArrayConverter(childConverter: TypeConverter) extends TypeConverter { override def append(row: Row, column: Int, builder: RapidsHostColumnBuilder): Double = { @@ -468,7 +467,7 @@ private object GpuExternalRowToColumnConverter { override def getNullSize: Double = VALIDITY_N_OFFSET } - private case class NotNullArrayConverter(childConverter: TypeConverter) + private class NotNullArrayConverter(childConverter: TypeConverter) extends TypeConverter { override def append(row: Row, column: Int, builder: RapidsHostColumnBuilder): Double = { @@ -492,7 +491,7 @@ private object GpuExternalRowToColumnConverter { ret } - private case class StructConverter( + private class StructConverter( childConverters: Array[TypeConverter]) extends TypeConverter { override def append(row: Row, column: Int, @@ -509,7 +508,7 @@ private object GpuExternalRowToColumnConverter { override def getNullSize: Double = childConverters.map(_.getNullSize).sum + VALIDITY } - private case class NotNullStructConverter( + private class NotNullStructConverter( childConverters: Array[TypeConverter]) extends TypeConverter { override def append(row: Row, column: Int, @@ -645,7 +644,11 @@ private class ExternalRowToColumnarIterator( * of GPU memory. By convention it is the responsibility of the one consuming the data to close it * when they no longer need it. */ -object InternalColumnarRddConverter extends Logging { +object InternalColumnarRddConverter { + private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$")) + + private def logDebug(msg: => String): Unit = if (log.isDebugEnabled) log.debug(msg) + def apply(df: DataFrame): RDD[Table] = { convert(df) } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/ShuffledBatchRDD.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/ShuffledBatchRDD.scala index 34c99f40dd9..c1256631c86 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/ShuffledBatchRDD.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/ShuffledBatchRDD.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,7 +32,8 @@ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleReadMetricsRe import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.vectorized.ColumnarBatch -case class ShuffledBatchRDDPartition(index: Int, spec: ShufflePartitionSpec) extends Partition +class ShuffledBatchRDDPartition(override val index: Int, val spec: ShufflePartitionSpec) + extends Partition /** * A dummy partitioner for use with records whose partition ids have been pre-computed (i.e. for @@ -135,7 +136,7 @@ class ShuffledBatchRDD( override def getPartitions: Array[Partition] = { Array.tabulate[Partition](partitionSpecs.length) { i => - ShuffledBatchRDDPartition(i, partitionSpecs(i)) + new ShuffledBatchRDDPartition(i, partitionSpecs(i)) } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/TrampolineUtil.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/TrampolineUtil.scala index 3969aa9024a..71943859784 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/TrampolineUtil.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/TrampolineUtil.scala @@ -20,7 +20,6 @@ import java.util.concurrent.{ExecutorService, ScheduledExecutorService, ThreadPo import org.apache.avro.Schema import org.apache.hadoop.conf.Configuration -import org.json4s.JsonAST import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkMasterRegex, SparkUpgradeException, TaskContext} import org.apache.spark.broadcast.Broadcast @@ -41,7 +40,6 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.rapids.GpuTaskMetrics import org.apache.spark.sql.rapids.shims.DataTypeUtilsShim import org.apache.spark.sql.rapids.shims.SparkUpgradeExceptionShims -import org.apache.spark.sql.rapids.shims.TrampolineConnectShims import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.{ShutdownHookManager, ThreadUtils, Utils} @@ -60,9 +58,23 @@ object TrampolineUtil { def toAttributes(structType: StructType): Seq[Attribute] = DataTypeUtilsShim.toAttributes(structType) - def jsonValue(dataType: DataType): JsonAST.JValue = dataType.jsonValue + private[this] lazy val dataTypeJsonValue = classOf[DataType].getMethod("jsonValue") - def createSchemaParser(): Schema.Parser = TrampolineConnectShims.createSchemaParser() + def jsonValue(dataType: DataType): AnyRef = dataTypeJsonValue.invoke(dataType) + + private[this] lazy val trampolineConnectShims = Class + .forName("org.apache.spark.sql.rapids.shims.TrampolineConnectShims$") + .getField("MODULE$") + .get(null) + + private[this] lazy val cleanupAnyExistingSessionMethod = + trampolineConnectShims.getClass.getMethod("cleanupAnyExistingSession") + + private[this] lazy val createSchemaParserMethod = + trampolineConnectShims.getClass.getMethod("createSchemaParser") + + def createSchemaParser(): Schema.Parser = + createSchemaParserMethod.invoke(trampolineConnectShims).asInstanceOf[Schema.Parser] /** Get a human-readable string, e.g.: "4.0 MiB", for a value in bytes. */ def bytesToString(size: Long): String = Utils.bytesToString(size) @@ -106,7 +118,8 @@ object TrampolineUtil { } /** Shuts down and cleans up any existing Spark session */ - def cleanupAnyExistingSession(): Unit = TrampolineConnectShims.cleanupAnyExistingSession() + def cleanupAnyExistingSession(): Unit = + cleanupAnyExistingSessionMethod.invoke(trampolineConnectShims) def asNullable(dt: DataType): DataType = dt.asNullable @@ -266,9 +279,72 @@ object TrampolineUtil { } /** - * This class is to only be used to throw errors specific to the - * RAPIDS Accelerator or errors mirroring Spark where a raw - * AnalysisException is thrown directly rather than via an error - * utility class (this should be rare). + * Factory for raw-message AnalysisExceptions where Spark has no error utility. */ -class RapidsAnalysisException(msg: String) extends AnalysisException(msg) +object RapidsAnalysisException { + private type CtorAndArgs = (java.lang.reflect.Constructor[AnyRef], Array[AnyRef]) + + private val none = None + private val emptyMessageParameters = Map.empty[String, String] + private val emptyStringArray = Array.empty[String] + + def apply(msg: String): AnalysisException = { + val maybeCtorAndArgs = rawMessageCtor7(msg) + .orElse(rawMessageCtor8(msg)) + .orElse(rawMessageCtorWithStringParameters(msg)) + + val (ctor, args) = maybeCtorAndArgs.getOrElse { + throw new IllegalStateException("Unsupported AnalysisException constructor shape") + } + ctor.newInstance(args: _*).asInstanceOf[AnalysisException] + } + + private def rawMessageCtor7(msg: String): Option[CtorAndArgs] = { + classOf[AnalysisException].getConstructors.find { ctor => + val params = ctor.getParameterTypes + params.length == 7 && + params(0) == classOf[String] && + params(5).getName == "scala.collection.immutable.Map" && + isQueryContextArray(params(6)) + }.map { ctor => + val typedCtor = ctor.asInstanceOf[java.lang.reflect.Constructor[AnyRef]] + typedCtor -> Array[AnyRef](msg, none, none, none, none, + emptyMessageParameters, emptyQueryContextArray) + } + } + + private def rawMessageCtor8(msg: String): Option[CtorAndArgs] = { + classOf[AnalysisException].getConstructors.find { ctor => + val params = ctor.getParameterTypes + params.length == 8 && + params(0) == classOf[String] && + params(6).getName == "scala.collection.immutable.Map" && + isQueryContextArray(params(7)) + }.map { ctor => + val typedCtor = ctor.asInstanceOf[java.lang.reflect.Constructor[AnyRef]] + typedCtor -> Array[AnyRef](msg, none, none, none, none, none, + emptyMessageParameters, emptyQueryContextArray) + } + } + + private def rawMessageCtorWithStringParameters(msg: String): Option[CtorAndArgs] = { + classOf[AnalysisException].getConstructors.find { ctor => + val params = ctor.getParameterTypes + params.length == 7 && + params(0) == classOf[String] && + params(6).isArray && + params(6).getComponentType == classOf[String] + }.map { ctor => + val typedCtor = ctor.asInstanceOf[java.lang.reflect.Constructor[AnyRef]] + typedCtor -> Array[AnyRef](msg, none, none, none, none, none, emptyStringArray) + } + } + + private def isQueryContextArray(cls: Class[_]): Boolean = { + cls.isArray && cls.getComponentType.getName == "org.apache.spark.QueryContext" + } + + private def emptyQueryContextArray: AnyRef = { + java.lang.reflect.Array.newInstance(Class.forName("org.apache.spark.QueryContext"), 0) + } +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/BatchGroupUtils.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/BatchGroupUtils.scala index 5baa2974f53..97ab778ce50 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/BatchGroupUtils.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/BatchGroupUtils.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2025, NVIDIA CORPORATION. + * Copyright (c) 2021-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -41,10 +41,10 @@ import org.apache.spark.sql.vectorized.ColumnarBatch * and data columns by the Python workers. * @param groupingOffsets the grouping offsets(aka column indices) in the deduplicated attributes. */ -case class GroupArgs( - dedupAttrs: Seq[Attribute], - argOffsets: Array[Int], - groupingOffsets: Seq[Int]) +class GroupArgs( + val dedupAttrs: Seq[Attribute], + val argOffsets: Array[Int], + val groupingOffsets: Seq[Int]) /** * Basic functionality to deal with groups in a batch. @@ -138,7 +138,7 @@ private[python] object BatchGroupUtils { val argOffsets = Array(argOffsetLen, groupingAttrs.length) ++ groupingArgOffsets ++ dataAttrs.indices - GroupArgs(dedupAttrs, argOffsets, groupingArgOffsets) + new GroupArgs(dedupAttrs, argOffsets, groupingArgOffsets) } /** diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuAggregateInPandasExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuAggregateInPandasExec.scala index c791ce3292a..3cb30cb1ce1 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuAggregateInPandasExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuAggregateInPandasExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2025, NVIDIA CORPORATION. + * Copyright (c) 2020-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -116,6 +116,9 @@ case class GpuAggregateInPandasExec( // Filter child output attributes down to only those that are UDF inputs. // Also eliminate duplicate UDF inputs. val udfArgs = PythonArgumentUtils.flatten(inputs) + val udfFlattenedArgs = udfArgs.flattenedArgs + val udfArgOffsets = udfArgs.argOffsets + val udfArgNames = udfArgs.argNames // Schema of input rows to the python runner val aggInputSchema = StructType(udfArgs.flattenedTypes.zipWithIndex.map { case (dt, i) => @@ -136,7 +139,7 @@ case class GpuAggregateInPandasExec( // Doing this can reduce the data size to be split, probably getting a better performance. val groupingRefs = GpuBindReferences.bindGpuReferences(gpuGroupingExpressions, childOutput, allMetrics) - val pyInputRefs = GpuBindReferences.bindGpuReferences(udfArgs.flattenedArgs, + val pyInputRefs = GpuBindReferences.bindGpuReferences(udfFlattenedArgs, childOutput, allMetrics) val miniIter = inputIter.map { batch => mNumInputBatches += 1 @@ -148,7 +151,7 @@ case class GpuAggregateInPandasExec( // Second splits into separate group batches. val miniAttrs = - (gpuGroupingExpressions ++ udfArgs.flattenedArgs).asInstanceOf[Seq[Attribute]] + (gpuGroupingExpressions ++ udfFlattenedArgs).asInstanceOf[Seq[Attribute]] val keyConverter = (groupedBatch: ColumnarBatch) => { // No `safeMap` because here does not increase the ref count. // (`Seq.indices.map()` is NOT lazy, so it is safe to be used to slice the columns.) @@ -191,9 +194,9 @@ case class GpuAggregateInPandasExec( } } - val runnerFactory = GpuGroupedPythonRunnerFactory(conf, pyFuncs, udfArgs.argOffsets, + val runnerFactory = new GpuGroupedPythonRunnerFactory(conf, pyFuncs, udfArgOffsets, aggInputSchema, DataTypeUtilsShim.fromAttributes(pyOutAttributes), - PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, udfArgs.argNames) + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, udfArgNames) // Third, sends to Python to execute the aggregate and returns the result. if (pyInputIter.hasNext) { diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuFlatMapCoGroupsInPandasExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuFlatMapCoGroupsInPandasExec.scala index e4515b40662..968df9dc747 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuFlatMapCoGroupsInPandasExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuFlatMapCoGroupsInPandasExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2025, NVIDIA CORPORATION. + * Copyright (c) 2020-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -138,10 +138,14 @@ case class GpuFlatMapCoGroupsInPandasExec( StructField("out_struct", DataTypeUtilsShim.fromAttributes(output)) :: Nil) // Resolve the argument offsets and related attributes. - val GroupArgs(leftDedupAttrs, leftArgOffsets, leftGroupingOffsets) = - resolveArgOffsets(left, leftGroup) - val GroupArgs(rightDedupAttrs, rightArgOffsets, rightGroupingOffsets) = - resolveArgOffsets(right, rightGroup) + val leftGroupArgs = resolveArgOffsets(left, leftGroup) + val leftDedupAttrs = leftGroupArgs.dedupAttrs + val leftArgOffsets = leftGroupArgs.argOffsets + val leftGroupingOffsets = leftGroupArgs.groupingOffsets + val rightGroupArgs = resolveArgOffsets(right, rightGroup) + val rightDedupAttrs = rightGroupArgs.dedupAttrs + val rightArgOffsets = rightGroupArgs.argOffsets + val rightGroupingOffsets = rightGroupArgs.groupingOffsets left.executeColumnar().zipPartitions(right.executeColumnar()) { (leftIter, rightIter) => if (isPythonOnGpuEnabled) { diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuFlatMapGroupsInPandasExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuFlatMapGroupsInPandasExec.scala index 27dd4ae65ba..2d69892d988 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuFlatMapGroupsInPandasExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuFlatMapGroupsInPandasExec.scala @@ -119,12 +119,14 @@ case class GpuFlatMapGroupsInPandasExec( StructField("out_struct", DataTypeUtilsShim.fromAttributes(localOutput)) :: Nil) // Resolve the argument offsets and related attributes. - val GroupArgs(dedupAttrs, argOffsets, groupingOffsets) = - resolveArgOffsets(child, groupingAttributes) + val groupArgs = resolveArgOffsets(child, groupingAttributes) + val dedupAttrs = groupArgs.dedupAttrs + val argOffsets = groupArgs.argOffsets + val groupingOffsets = groupArgs.groupingOffsets - val runnerFactory = GpuGroupedPythonRunnerFactory(conf, chainedFunc, Array(argOffsets), + val runnerFactory = new GpuGroupedPythonRunnerFactory(conf, chainedFunc, Array(argOffsets), DataTypeUtilsShim.fromAttributes(dedupAttrs), pythonOutputSchema, - udf.evalType) + udf.evalType, None) // Start processing. Map grouped batches to ArrowPythonRunner results. child.executeColumnar().mapPartitionsInternal { inputIter => diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuPythonHelper.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuPythonHelper.scala index 56da226495a..7f6c780d33a 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuPythonHelper.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuPythonHelper.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2025, NVIDIA CORPORATION. + * Copyright (c) 2020-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,12 +22,26 @@ import com.nvidia.spark.rapids.python.PythonConfEntries._ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.api.python.ChainedPythonFunctions -import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{CPUS_PER_TASK, EXECUTOR_CORES} import org.apache.spark.internal.config.Python._ import org.apache.spark.sql.internal.SQLConf -object GpuPythonHelper extends Logging { +object GpuPythonHelper { + + private val log = org.slf4j.LoggerFactory.getLogger(GpuPythonHelper.getClass) + + private def logWarning(msg: => String): Unit = { + if (log.isWarnEnabled) { + log.warn(msg) + } + } + + private def logDebug(msg: => String): Unit = { + if (log.isDebugEnabled) { + log.debug(msg) + } + } + private val sparkConf = SparkEnv.get.conf private lazy val rapidsConf = new RapidsConf(sparkConf) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuPythonUDF.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuPythonUDF.scala index 04367d9f29f..d0cbc26e26d 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuPythonUDF.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuPythonUDF.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * Copyright (c) 2021-2026, NVIDIA CORPORATION. * * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/predicates.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/predicates.scala index fb956738e18..f6f961f540b 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/predicates.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/predicates.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2024, NVIDIA CORPORATION. + * Copyright (c) 2019-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,7 +20,7 @@ import ai.rapids.cudf._ import ai.rapids.cudf.ast.BinaryOperator import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.Arm.withResource -import com.nvidia.spark.rapids.shims.NullIntolerantShim +import com.nvidia.spark.rapids.shims.{NullIntolerantShim, ShimPredicate} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ @@ -30,7 +30,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch case class GpuNot(child: Expression) extends CudfUnaryExpression - with Predicate with ImplicitCastInputTypes with NullIntolerantShim { + with ShimPredicate with ImplicitCastInputTypes with NullIntolerantShim { override def toString: String = s"NOT $child" override def inputTypes: Seq[DataType] = Seq(BooleanType) @@ -51,7 +51,7 @@ case class GpuNot(child: Expression) extends CudfUnaryExpression } } -abstract class CudfBinaryPredicateWithSideEffect extends CudfBinaryOperator with Predicate { +abstract class CudfBinaryPredicateWithSideEffect extends CudfBinaryOperator with ShimPredicate { override def inputType: AbstractDataType = BooleanType @@ -152,7 +152,7 @@ case class GpuOr(left: Expression, right: Expression) extends CudfBinaryPredicat GpuExpressionWithSideEffectUtils.boolInverted(col) } -abstract class CudfBinaryComparison extends CudfBinaryOperator with Predicate { +abstract class CudfBinaryComparison extends CudfBinaryOperator with ShimPredicate { // Note that we need to give a superset of allowable input types since orderable types are not // finitely enumerable. The allowable types are checked below by checkInputDataTypes. override def inputType: AbstractDataType = AnyDataType diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/shims/RapidsQueryErrorUtils.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/shims/RapidsQueryErrorUtils.scala index 5ac79327380..cc5121655ca 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/shims/RapidsQueryErrorUtils.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/shims/RapidsQueryErrorUtils.scala @@ -96,6 +96,6 @@ trait RapidsQueryErrorUtils { } def dynamicPartitionParentError: Throwable = { - throw new RapidsAnalysisException(ErrorMsg.PARTITION_DYN_STA_ORDER.getMsg) + throw RapidsAnalysisException(ErrorMsg.PARTITION_DYN_STA_ORDER.getMsg) } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index 2b146a07cee..09d2f5db33b 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -34,7 +34,7 @@ import com.nvidia.spark.rapids.jni.CharsetDecode import com.nvidia.spark.rapids.jni.GpuSubstringIndexUtils import com.nvidia.spark.rapids.jni.NumberConverter import com.nvidia.spark.rapids.jni.RegexRewriteUtils -import com.nvidia.spark.rapids.shims.{NullIntolerantShim, ShimExpression, SparkShimImpl} +import com.nvidia.spark.rapids.shims.{NullIntolerantShim, ShimExpression, ShimPredicate, SparkShimImpl} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.errors.ConvUtils @@ -163,7 +163,7 @@ case class GpuStringLocate(substr: Expression, col: Expression, start: Expressio case class GpuStartsWith(left: Expression, right: Expression) extends GpuBinaryExpressionArgsAnyScalar - with Predicate + with ShimPredicate with ImplicitCastInputTypes with NullIntolerantShim { @@ -189,7 +189,7 @@ case class GpuStartsWith(left: Expression, right: Expression) case class GpuEndsWith(left: Expression, right: Expression) extends GpuBinaryExpressionArgsAnyScalar - with Predicate + with ShimPredicate with ImplicitCastInputTypes with NullIntolerantShim { @@ -396,7 +396,7 @@ case class GpuConcatWs(children: Seq[Expression]) case class GpuContains(left: Expression, right: Expression) extends GpuBinaryExpression - with Predicate + with ShimPredicate with ImplicitCastInputTypes with NullIntolerantShim with GpuCombinable { @@ -496,7 +496,7 @@ class ContainsCombiner(private val exp: GpuContains) extends GpuExpressionCombin override def addExpression(e: Expression): Unit = { val localOutputLocation = outputLocation outputLocation += 1 - val key = GpuExpressionEquals(e) + val key = new GpuExpressionEquals(e) if (!toCombine.contains(key)) { toCombine.put(key, localOutputLocation) } @@ -530,7 +530,7 @@ class ContainsCombiner(private val exp: GpuContains) extends GpuExpressionCombin } override def getReplacementExpression(e: Expression): Option[Expression] = { - toCombine.get(GpuExpressionEquals(e)).map { localId => + toCombine.get(new GpuExpressionEquals(e)).map { localId => GpuGetStructField(multiContains, localId, Some(fieldName(localId))) } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/test/cpuJsonExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/test/cpuJsonExpressions.scala index 0dd048967a8..60e738e4c6f 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/test/cpuJsonExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/test/cpuJsonExpressions.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024, NVIDIA CORPORATION. + * Copyright (c) 2024-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,7 +34,8 @@ import org.apache.spark.sql.catalyst.expressions.{GetJsonObject, Literal} import org.apache.spark.sql.types.StringType import org.apache.spark.unsafe.types.UTF8String -case class CsvWriterWrapper(filePath: String, conf: Configuration) extends AutoCloseable { +class CsvWriterWrapper(val filePath: String, val conf: Configuration) extends AutoCloseable + with Serializable { // This is implemented as a method to make it easier to subclass // ColumnarOutputWriter in the tests, and override this behavior. @@ -262,7 +263,7 @@ object CpuGetJsonObject { val date = DateTimeFormatter.ofPattern("yyyyMMdd").format(LocalDate.now()) val uuid = UUID.randomUUID() val savePath = s"$savePathForVerify/${date}_${tcId}_${uuid}.csv" - withResource(CsvWriterWrapper(savePath, conf)) { csvWriter => + withResource(new CsvWriterWrapper(savePath, conf)) { csvWriter => val pathStr = if (path == null) "null" else path.toString var currRow = 0 var diffRowsNum = 0 diff --git a/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/ShimPredicate.scala b/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/ShimPredicate.scala new file mode 100644 index 00000000000..013aac1d0cd --- /dev/null +++ b/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/ShimPredicate.scala @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "330"} +{"spark": "330db"} +{"spark": "331"} +{"spark": "332"} +{"spark": "332db"} +{"spark": "333"} +{"spark": "334"} +{"spark": "340"} +{"spark": "341"} +{"spark": "341db"} +{"spark": "342"} +{"spark": "343"} +{"spark": "344"} +{"spark": "350"} +{"spark": "350db143"} +{"spark": "351"} +{"spark": "352"} +{"spark": "353"} +{"spark": "354"} +{"spark": "355"} +{"spark": "356"} +{"spark": "357"} +{"spark": "358"} +spark-rapids-shim-json-lines ***/ +package com.nvidia.spark.rapids.shims + +import org.apache.spark.sql.catalyst.expressions.Predicate + +trait ShimPredicate extends Predicate { + def contextIndependentFoldable: Boolean = children.forall(_.foldable) +} + diff --git a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/execution/python/shims/PythonArgumentUtils.scala b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/execution/python/shims/PythonArgumentUtils.scala index ce90605c035..3eca24135de 100644 --- a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/execution/python/shims/PythonArgumentUtils.scala +++ b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/execution/python/shims/PythonArgumentUtils.scala @@ -64,6 +64,6 @@ object PythonArgumentUtils { } }.toArray }.toArray - GpuPythonArguments(allInputs.toSeq, dataTypes.toSeq, argOffsets, None) + new GpuPythonArguments(allInputs.toSeq, dataTypes.toSeq, argOffsets, None) } } diff --git a/sql-plugin/src/main/spark332db/scala/org/apache/spark/sql/rapids/GpuFileFormatWriter.scala b/sql-plugin/src/main/spark332db/scala/org/apache/spark/sql/rapids/GpuFileFormatWriter.scala index 51d0067bf1c..13dcd2c69fb 100644 --- a/sql-plugin/src/main/spark332db/scala/org/apache/spark/sql/rapids/GpuFileFormatWriter.scala +++ b/sql-plugin/src/main/spark332db/scala/org/apache/spark/sql/rapids/GpuFileFormatWriter.scala @@ -533,7 +533,7 @@ trait GpuFileFormatWriterBase extends Serializable with Logging { private def verifySchema(format: ColumnarFileFormat, schema: StructType): Unit = { schema.foreach { field => if (!format.supportDataType(field.dataType)) { - throw new RapidsAnalysisException( + throw RapidsAnalysisException( s"$format data source does not support ${field.dataType.catalogString} data type.") } } diff --git a/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/execution/python/shims/PythonArgumentsUtils.scala b/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/execution/python/shims/PythonArgumentsUtils.scala index f67afdd0015..d735fbf14f7 100644 --- a/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/execution/python/shims/PythonArgumentsUtils.scala +++ b/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/execution/python/shims/PythonArgumentsUtils.scala @@ -48,15 +48,15 @@ object PythonArgumentUtils { (None, e) } if (allInputs.exists(_.semanticEquals(value))) { - GpuArgumentMeta(allInputs.indexWhere(_.semanticEquals(value)), key) + new GpuArgumentMeta(allInputs.indexWhere(_.semanticEquals(value)), key) } else { allInputs += value dataTypes += value.dataType - GpuArgumentMeta(allInputs.length - 1, key) + new GpuArgumentMeta(allInputs.length - 1, key) } }.toArray }.toArray - GpuPythonArguments(allInputs.toSeq, dataTypes.toSeq, + new GpuPythonArguments(allInputs.toSeq, dataTypes.toSeq, argMetas.map(_.map(_.offset)), Some(argMetas.map(_.map(_.name)))) } } diff --git a/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/ShimPredicate.scala b/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/ShimPredicate.scala new file mode 100644 index 00000000000..085c88655b4 --- /dev/null +++ b/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/ShimPredicate.scala @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "400"} +{"spark": "400db173"} +{"spark": "401"} +{"spark": "402"} +spark-rapids-shim-json-lines ***/ +package com.nvidia.spark.rapids.shims + +import org.apache.spark.sql.catalyst.expressions.Predicate + +trait ShimPredicate extends Predicate { + def contextIndependentFoldable: Boolean = children.forall(_.foldable) +} + diff --git a/sql-plugin/src/main/spark411/scala/com/nvidia/spark/rapids/shims/ShimPredicate.scala b/sql-plugin/src/main/spark411/scala/com/nvidia/spark/rapids/shims/ShimPredicate.scala new file mode 100644 index 00000000000..6fa0dffab88 --- /dev/null +++ b/sql-plugin/src/main/spark411/scala/com/nvidia/spark/rapids/shims/ShimPredicate.scala @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "411"} +spark-rapids-shim-json-lines ***/ +package com.nvidia.spark.rapids.shims + +import org.apache.spark.sql.catalyst.expressions.Predicate + +trait ShimPredicate extends Predicate { + override def contextIndependentFoldable: Boolean = + children.forall(_.contextIndependentFoldable) +} + diff --git a/tests/src/test/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriterSuite.scala b/tests/src/test/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriterSuite.scala index 4b39c53697f..a2dceda1420 100644 --- a/tests/src/test/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriterSuite.scala +++ b/tests/src/test/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriterSuite.scala @@ -156,7 +156,7 @@ class GpuFileFormatDataWriterSuite extends AnyFunSuite with BeforeAndAfterEach { dataSpec = allCols } if (numBuckets != 0) { - bucketSpec = Some(GpuWriterBucketSpec( + bucketSpec = Some(new GpuWriterBucketSpec( GpuPmod(GpuMurmur3Hash(Seq(allCols.last), 42), GpuLiteral(Math.abs(numBuckets))), _ => "")) }