diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/CopyCompressionCodec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/CopyCompressionCodec.scala index e2e86495b16..e2d0409c68b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/CopyCompressionCodec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/CopyCompressionCodec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -51,7 +51,7 @@ class BatchedCopyCompressor(maxBatchMemory: Long, stream: Cuda.Stream) ct, CodecType.COPY, outBuffer.getLength) - CompressedTable(outBuffer.getLength, meta, outBuffer) + new CompressedTable(outBuffer.getLength, meta, outBuffer) } } closeOnExcept(result) { _ => stream.sync() } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MultithreadedShuffleBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MultithreadedShuffleBufferCatalog.scala index 35fd66d93ac..4fa7122f9e4 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MultithreadedShuffleBufferCatalog.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MultithreadedShuffleBufferCatalog.scala @@ -18,7 +18,7 @@ package com.nvidia.spark.rapids import java.io.{InputStream, IOException} import java.lang.{Boolean => JBoolean} -import java.nio.ByteBuffer +import java.nio.{Buffer, ByteBuffer} import java.nio.channels.WritableByteChannel import java.util.HashSet import java.util.concurrent.ConcurrentHashMap @@ -28,7 +28,6 @@ import scala.collection.mutable.ArrayBuffer import _root_.io.netty.handler.stream.ChunkedStream import com.nvidia.spark.rapids.spill.SpillablePartialFileHandle -import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.util.AbstractFileRegion import org.apache.spark.storage.{ShuffleBlockBatchId, ShuffleBlockId} @@ -40,10 +39,10 @@ import org.apache.spark.storage.{ShuffleBlockBatchId, ShuffleBlockId} * @param offset starting offset within the handle * @param length number of bytes in this segment */ -case class PartitionSegment( - handle: SpillablePartialFileHandle, - offset: Long, - length: Long) +class PartitionSegment( + val handle: SpillablePartialFileHandle, + val offset: Long, + val length: Long) /** * Catalog for managing shuffle data in MULTITHREADED mode without merging. @@ -57,7 +56,19 @@ case class PartitionSegment( * (MEMORY_WITH_SPILL mode) or stored directly on disk (ONLY_FILE mode) depending * on memory pressure - both modes work with this skip-merge design. */ -class MultithreadedShuffleBufferCatalog extends Logging { +class MultithreadedShuffleBufferCatalog { + private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$")) + + private def logDebug(msg: => String): Unit = { + if (log.isDebugEnabled) { + log.debug(msg) + } + } + + private def logError(msg: => String, throwable: Throwable): Unit = { + log.error(msg, throwable) + } + /** * Map from ShuffleBlockId to list of segments. @@ -99,7 +110,7 @@ class MultithreadedShuffleBufferCatalog extends Logging { } val blockId = ShuffleBlockId(shuffleId, mapId, partitionId) - val segment = PartitionSegment(handle, offset, length) + val segment = new PartitionSegment(handle, offset, length) partitionSegments.compute(blockId, (_, existing) => { val segments = if (existing == null) new ArrayBuffer[PartitionSegment]() else existing @@ -280,7 +291,7 @@ class MultiBatchManagedBuffer(segments: Seq[PartitionSegment]) extends ManagedBu } } - buffer.flip() + buffer.asInstanceOf[Buffer].flip() buffer } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvcompLZ4CompressionCodec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvcompLZ4CompressionCodec.scala index 1b7a2e2f285..136375996c2 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvcompLZ4CompressionCodec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvcompLZ4CompressionCodec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2025, NVIDIA CORPORATION. + * Copyright (c) 2020-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -68,7 +68,7 @@ class BatchedNvcompLZ4Compressor(maxBatchMemorySize: Long, table, CodecType.NVCOMP_LZ4, compressedSize) - CompressedTable(compressedSize, meta, buffer) + new CompressedTable(compressedSize, meta, buffer) }.toArray } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvcompZSTDCompressionCodec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvcompZSTDCompressionCodec.scala index 647a318d076..88f0be60e0f 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvcompZSTDCompressionCodec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvcompZSTDCompressionCodec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024-2025, NVIDIA CORPORATION. + * Copyright (c) 2024-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,8 +18,7 @@ package com.nvidia.spark.rapids import ai.rapids.cudf.{BaseDeviceMemoryBuffer, ContiguousTable, Cuda, DeviceMemoryBuffer} import ai.rapids.cudf.nvcomp.{BatchedZstdCompressor, BatchedZstdDecompressor} -import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} -import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableProducingArray +import com.nvidia.spark.rapids.Arm.closeOnExcept import com.nvidia.spark.rapids.format.{BufferMeta, CodecType} /** A table compression codec that uses nvcomp's ZSTD-GPU codec */ @@ -59,7 +58,7 @@ class BatchedNvcompZSTDCompressor(maxBatchMemorySize: Long, compressedBufs.zip(tables).map { case (buffer, table) => val compressedLen = buffer.getLength val meta = MetaUtils.buildTableMeta(None, table, CodecType.NVCOMP_ZSTD, compressedLen) - CompressedTable(compressedLen, meta, buffer) + new CompressedTable(compressedLen, meta, buffer) }.toArray } } @@ -90,23 +89,3 @@ class BatchedNvcompZSTDDecompressor(maxBatchMemory: Long, outputBufs } } - -object DeviceBuffersUtils { - def incRefCount(bufs: Array[BaseDeviceMemoryBuffer]): Array[BaseDeviceMemoryBuffer] = { - bufs.safeMap { b => - b.incRefCount() - b - } - } - - def allocateBuffers(bufSizes: Array[Long]): Array[DeviceMemoryBuffer] = { - var curPos = 0L - withResource(DeviceMemoryBuffer.allocate(bufSizes.sum)) { singleBuf => - bufSizes.safeMap { len => - val ret = singleBuf.slice(curPos, len) - curPos += len - ret - } - } - } -} \ No newline at end of file diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsShuffleHeartbeatManager.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsShuffleHeartbeatManager.scala index 598d9bd5447..d9098766085 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsShuffleHeartbeatManager.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsShuffleHeartbeatManager.scala @@ -26,12 +26,19 @@ import org.apache.commons.lang3.mutable.MutableLong import org.apache.spark.SparkEnv import org.apache.spark.api.plugin.PluginContext -import org.apache.spark.internal.Logging import org.apache.spark.sql.rapids.{ProxyRapidsShuffleInternalManagerBase, RapidsShuffleInternalManagerBase} import org.apache.spark.storage.BlockManagerId class RapidsShuffleHeartbeatManager(heartbeatIntervalMillis: Long, - heartbeatTimeoutMillis: Long) extends Logging { + heartbeatTimeoutMillis: Long) { + private val log = org.slf4j.LoggerFactory.getLogger(classOf[RapidsShuffleHeartbeatManager]) + + private def logDebug(msg: => String): Unit = { + if (log.isDebugEnabled) { + log.debug(msg) + } + } + require(heartbeatIntervalMillis > 0, s"The interval value: $heartbeatIntervalMillis ms is not > 0") @@ -45,14 +52,18 @@ class RapidsShuffleHeartbeatManager(heartbeatIntervalMillis: Long, // exposed so that it can be mocked in the tests def getCurrentTimeMillis: Long = System.currentTimeMillis() - private case class ExecutorRegistration( - id: BlockManagerId, + private class ExecutorRegistration( + val id: BlockManagerId, // this is this executor's registration order, as given by this manager - registrationOrder: Long, + val registrationOrder: Long, // this is the last registration order this executor is aware of overall - lastRegistrationOrderSeen: MutableLong, + val lastRegistrationOrderSeen: MutableLong, // last heartbeat received from this executor in millis - lastHeartbeatMillis: MutableLong) + val lastHeartbeatMillis: MutableLong) { + override def toString: String = + s"ExecutorRegistration($id,$registrationOrder,$lastRegistrationOrderSeen," + + s"$lastHeartbeatMillis)" + } // a counter used to mark each new executor registration with an order var registrationOrder = 0L @@ -82,7 +93,7 @@ class RapidsShuffleHeartbeatManager(heartbeatIntervalMillis: Long, require(!executorRegistrations.containsKey(id), s"Executor $id already registered") removeDeadExecutors(getCurrentTimeMillis) val allExecutors = executors.map(e => e.id).toArray - val newReg = ExecutorRegistration(id, + val newReg = new ExecutorRegistration(id, registrationOrder, new MutableLong(registrationOrder), new MutableLong(getCurrentTimeMillis)) @@ -167,7 +178,40 @@ class RapidsShuffleHeartbeatManager(heartbeatIntervalMillis: Long, } class RapidsShuffleHeartbeatEndpoint(pluginContext: PluginContext, conf: RapidsConf) - extends Logging with AutoCloseable { + extends AutoCloseable { + + private val log = org.slf4j.LoggerFactory.getLogger(classOf[RapidsShuffleHeartbeatEndpoint]) + + private def logInfo(msg: => String): Unit = { + if (log.isInfoEnabled) { + log.info(msg) + } + } + + private def logWarning(msg: => String): Unit = { + if (log.isWarnEnabled) { + log.warn(msg) + } + } + + private def logDebug(msg: => String): Unit = { + if (log.isDebugEnabled) { + log.debug(msg) + } + } + + private def logTrace(msg: => String): Unit = { + if (log.isTraceEnabled) { + log.trace(msg) + } + } + + private def logError(msg: => String, throwable: Throwable): Unit = { + if (log.isErrorEnabled) { + log.error(msg, throwable) + } + } + // Number of milliseconds between heartbeats to driver private[this] val heartbeatIntervalMillis = conf.shuffleTransportEarlyStartHeartbeatInterval diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala index 2dee040dad7..d7dab6f58c5 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2025, NVIDIA CORPORATION. + * Copyright (c) 2020-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,22 +28,21 @@ import com.nvidia.spark.rapids.format.TableMeta import com.nvidia.spark.rapids.spill.{SpillableDeviceBufferHandle, SpillableHandle} import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.internal.Logging import org.apache.spark.sql.rapids.execution.TrampolineUtil import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.storage.ShuffleBlockId -/** Identifier for a shuffle buffer that holds the data for a table */ -case class ShuffleBufferId( - blockId: ShuffleBlockId, - tableId: Int) { - val shuffleId: Int = blockId.shuffleId - val mapId: Long = blockId.mapId -} - /** Catalog for lookup of shuffle buffers by block ID */ -class ShuffleBufferCatalog extends Logging { +class ShuffleBufferCatalog { + private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$")) + + private def logWarning(msg: => String): Unit = { + if (log.isWarnEnabled) { + log.warn(msg) + } + } + /** * Information stored for each active shuffle. * A shuffle block can be comprised of multiple batches. Each batch @@ -259,7 +258,7 @@ class ShuffleBufferCatalog extends Logging { } val tableId = tableIdCounter.getAndUpdate(ShuffleBufferCatalog.TABLE_ID_UPDATER) - val id = ShuffleBufferId(blockId, tableId) + val id = new ShuffleBufferId(blockId, tableId) val prev = tableMap.put(tableId, id) if (prev != null) { throw new IllegalStateException(s"table ID $tableId is already in use") @@ -283,7 +282,7 @@ class ShuffleBufferCatalog extends Logging { val (maybeHandle, meta) = bufferIdToHandle.get(shuffleBufferId) maybeHandle match { case Some(spillable) => - RapidsShuffleHandle(spillable, meta) + new RapidsShuffleHandle(spillable, meta) case None => throw new IllegalStateException( "a buffer handle could not be obtained for a degenerate buffer") diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleCleanupEndpoint.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleCleanupEndpoint.scala index a7b6e065354..cb65d3764fa 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleCleanupEndpoint.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleCleanupEndpoint.scala @@ -23,7 +23,6 @@ import scala.collection.mutable.ArrayBuffer import com.nvidia.spark.rapids.jni.RmmSpark import org.apache.spark.api.plugin.PluginContext -import org.apache.spark.internal.Logging import org.apache.spark.sql.rapids.GpuShuffleEnv /** @@ -38,7 +37,40 @@ import org.apache.spark.sql.rapids.GpuShuffleEnv */ class ShuffleCleanupEndpoint( pluginContext: PluginContext, - pollIntervalMs: Long = 1000) extends Logging with AutoCloseable { + pollIntervalMs: Long = 1000) extends AutoCloseable { + + private val log = org.slf4j.LoggerFactory.getLogger(classOf[ShuffleCleanupEndpoint]) + + private def logInfo(msg: => String): Unit = { + if (log.isInfoEnabled) { + log.info(msg) + } + } + + private def logWarning(msg: => String): Unit = { + if (log.isWarnEnabled) { + log.warn(msg) + } + } + + private def logWarning(msg: => String, throwable: Throwable): Unit = { + if (log.isWarnEnabled) { + log.warn(msg, throwable) + } + } + + private def logDebug(msg: => String): Unit = { + if (log.isDebugEnabled) { + log.debug(msg) + } + } + + private def logTrace(msg: => String): Unit = { + if (log.isTraceEnabled) { + log.trace(msg) + } + } + private val executorId: String = pluginContext.executorID() diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleCleanupManager.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleCleanupManager.scala index ef90caee02d..839851a4cf5 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleCleanupManager.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleCleanupManager.scala @@ -22,7 +22,6 @@ import java.util.concurrent.{ConcurrentHashMap, Executors, ScheduledExecutorServ import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkContext -import org.apache.spark.internal.Logging import org.apache.spark.sql.rapids.execution.TrampolineUtil /** @@ -89,7 +88,33 @@ class ShuffleCleanupManager( sc: SparkContext, staleEntryMaxAgeMs: Long = 300000, // 5 minutes cleanupIntervalMs: Long = 60000 // 1 minute -) extends Logging { +) { + + private val log = org.slf4j.LoggerFactory.getLogger(classOf[ShuffleCleanupManager]) + + private def logInfo(msg: => String): Unit = { + if (log.isInfoEnabled) { + log.info(msg) + } + } + + private def logWarning(msg: => String): Unit = { + if (log.isWarnEnabled) { + log.warn(msg) + } + } + + private def logWarning(msg: => String, throwable: Throwable): Unit = { + if (log.isWarnEnabled) { + log.warn(msg, throwable) + } + } + + private def logDebug(msg: => String): Unit = { + if (log.isDebugEnabled) { + log.debug(msg) + } + } /** * Shuffles pending cleanup. Maps shuffleId -> timestamp when unregister was called. @@ -192,8 +217,13 @@ class ShuffleCleanupManager( try { TrampolineUtil.postEvent(sc, - SparkRapidsShuffleDiskSavingsEvent(shuffleId, stat.bytesFromMemory, stat.bytesFromDisk, - stat.numExpansions, stat.numSpills, stat.numForcedFileOnly)) + new SparkRapidsShuffleDiskSavingsEvent( + shuffleId, + stat.bytesFromMemory, + stat.bytesFromDisk, + stat.numExpansions, + stat.numSpills, + stat.numForcedFileOnly)) } catch { case e: Exception => logWarning(s"Failed to post shuffle disk savings event for shuffle $shuffleId", e) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala index 450622ef3ba..513e779098f 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. + * Copyright (c) 2020-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,19 +22,19 @@ import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableColumn import com.nvidia.spark.rapids.format.TableMeta import com.nvidia.spark.rapids.spill.SpillableDeviceBufferHandle -import org.apache.spark.internal.Logging import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.ColumnarBatch -case class RapidsShuffleHandle( - spillable: SpillableDeviceBufferHandle, tableMeta: TableMeta) extends AutoCloseable { +class RapidsShuffleHandle( + val spillable: SpillableDeviceBufferHandle, + val tableMeta: TableMeta) extends AutoCloseable with Serializable { override def close(): Unit = { spillable.safeClose() } } /** Catalog for lookup of shuffle buffers by block ID */ -class ShuffleReceivedBufferCatalog() extends Logging { +class ShuffleReceivedBufferCatalog() { /** * Adds a buffer to the device storage, taking ownership of the buffer. @@ -52,7 +52,7 @@ class ShuffleReceivedBufferCatalog() extends Logging { buffer: DeviceMemoryBuffer, tableMeta: TableMeta, initialSpillPriority: Long): RapidsShuffleHandle = { - RapidsShuffleHandle(SpillableDeviceBufferHandle(buffer), tableMeta) + new RapidsShuffleHandle(SpillableDeviceBufferHandle(buffer), tableMeta) } /** @@ -62,7 +62,7 @@ class ShuffleReceivedBufferCatalog() extends Logging { * @return RapidsShuffleHandle associated with this buffer */ def addDegenerateBatch(meta: TableMeta): RapidsShuffleHandle = { - RapidsShuffleHandle(null, meta) + new RapidsShuffleHandle(null, meta) } def getColumnarBatchAndRemove(handle: RapidsShuffleHandle, diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TableCompressionCodec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TableCompressionCodec.scala index 1c45b31b986..09b04d0e8f3 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TableCompressionCodec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TableCompressionCodec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2025, NVIDIA CORPORATION. + * Copyright (c) 2020-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,22 +21,7 @@ import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf.{BaseDeviceMemoryBuffer, ContiguousTable, Cuda, DeviceMemoryBuffer} import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.format.{BufferMeta, CodecType, TableMeta} - -import org.apache.spark.internal.Logging - -/** - * Compressed table descriptor - * @param compressedSize size of the compressed data in bytes - * @param meta metadata describing the table layout when uncompressed - * @param buffer buffer containing the compressed data - */ -case class CompressedTable( - compressedSize: Long, - meta: TableMeta, - buffer: DeviceMemoryBuffer) extends AutoCloseable { - override def close(): Unit = buffer.close() -} +import com.nvidia.spark.rapids.format.{BufferMeta, CodecType} /** An interface to a compression codec that can compress a contiguous Table on the GPU */ trait TableCompressionCodec { @@ -71,12 +56,15 @@ trait TableCompressionCodec { stream: Cuda.Stream): BatchedBufferDecompressor } -/** - * A small case class used to carry codec-specific settings. - */ -case class TableCompressionCodecConfig(lz4ChunkSize: Long, zstdChunkSize: Long) +object TableCompressionCodec { + private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$")) + + private def logDebug(msg: => String): Unit = { + if (log.isDebugEnabled) { + log.debug(msg) + } + } -object TableCompressionCodec extends Logging { private val codecNameToId = Map( "copy" -> CodecType.COPY, "zstd" -> CodecType.NVCOMP_ZSTD, @@ -84,7 +72,7 @@ object TableCompressionCodec extends Logging { /** Make a codec configuration object which can be serialized (can be used in tasks) */ def makeCodecConfig(rapidsConf: RapidsConf): TableCompressionCodecConfig = - TableCompressionCodecConfig( + new TableCompressionCodecConfig( rapidsConf.shuffleCompressionLz4ChunkSize, rapidsConf.shuffleCompressionZstdChunkSize) @@ -117,7 +105,15 @@ object TableCompressionCodec extends Logging { * @param stream CUDA stream to use */ abstract class BatchedTableCompressor(maxBatchMemorySize: Long, stream: Cuda.Stream) - extends AutoCloseable with Logging { + extends AutoCloseable { + private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$")) + + private def logDebug(msg: => String): Unit = { + if (log.isDebugEnabled) { + log.debug(msg) + } + } + // The tables that need to be compressed in the next batch private[this] val tables = new ArrayBuffer[ContiguousTable] @@ -237,7 +233,7 @@ abstract class BatchedTableCompressor(maxBatchMemorySize: Long, stream: Cuda.Str ct.buffer.incRefCount() ct.buffer } - CompressedTable(ct.compressedSize, ct.meta, newBuffer) + new CompressedTable(ct.compressedSize, ct.meta, newBuffer) } } } @@ -262,7 +258,15 @@ abstract class BatchedTableCompressor(maxBatchMemorySize: Long, stream: Cuda.Str * @param stream CUDA stream to use */ abstract class BatchedBufferDecompressor(maxBatchMemorySize: Long, stream: Cuda.Stream) - extends AutoCloseable with Logging { + extends AutoCloseable { + private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$")) + + private def logDebug(msg: => String): Unit = { + if (log.isDebugEnabled) { + log.debug(msg) + } + } + // The buffers of compressed data that will be decompressed in the next batch private[this] val inputBuffers = new ArrayBuffer[BaseDeviceMemoryBuffer] diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/AsyncRunners.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/AsyncRunners.scala index 29f8a4debcf..377ee3997de 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/AsyncRunners.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/AsyncRunners.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2025, NVIDIA CORPORATION. + * Copyright (c) 2025-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ sealed trait AsyncRunResource /** * HostResource represents host memory resource requirement for CPU-bound tasks. */ -case class HostResource(hostMemoryBytes: Long) extends AsyncRunResource +class HostResource(val hostMemoryBytes: Long) extends AsyncRunResource with Serializable /** * DeviceResource is a marker object for GPU resources, no additional fields needed. @@ -44,7 +44,7 @@ object DeviceResource extends AsyncRunResource object AsyncRunResource { def newCpuResource(hostMemoryBytes: Long): AsyncRunResource = { - HostResource(hostMemoryBytes) + new HostResource(hostMemoryBytes) } def newGpuResource(): AsyncRunResource = DeviceResource @@ -76,8 +76,6 @@ trait AsyncResult[T] extends AutoCloseable { } } -case class AsyncMetrics(scheduleTimeMs: Long, executionTimeMs: Long) - class AsyncMetricsBuilder { private var scheduleTimeMs: Long = 0L private var executionTimeMs: Long = 0L @@ -93,7 +91,7 @@ class AsyncMetricsBuilder { } def build(): AsyncMetrics = { - AsyncMetrics(scheduleTimeMs, executionTimeMs) + new AsyncMetrics(scheduleTimeMs, executionTimeMs) } } @@ -136,19 +134,19 @@ class DecayReleaseResult[T](override val data: T, */ sealed trait AsyncRunnerState -case class Init(firstTime: Boolean) extends AsyncRunnerState +class Init(val firstTime: Boolean) extends AsyncRunnerState with Serializable case object Pending extends AsyncRunnerState -case class ScheduleFailed(exception: Throwable) extends AsyncRunnerState +class ScheduleFailed(val exception: Throwable) extends AsyncRunnerState with Serializable case object Running extends AsyncRunnerState case object Completed extends AsyncRunnerState -case class ExecFailed(exception: Throwable) extends AsyncRunnerState +class ExecFailed(val exception: Throwable) extends AsyncRunnerState with Serializable -case class Closed(exception: Option[Throwable]) extends AsyncRunnerState +class Closed(val exception: Option[Throwable]) extends AsyncRunnerState with Serializable case object Cancelled extends AsyncRunnerState @@ -271,7 +269,7 @@ trait AsyncRunner[T] extends Callable[AsyncResult[T]] { } } - @volatile private var state: AsyncRunnerState = Init(firstTime = true) + @volatile private var state: AsyncRunnerState = new Init(firstTime = true) def isHoldingStateLock: Boolean = stateLock.isHeldByCurrentThread diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/ResourceBoundedThreadExecutor.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/ResourceBoundedThreadExecutor.scala index eee1eaceb4c..48b0994b9b7 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/ResourceBoundedThreadExecutor.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/ResourceBoundedThreadExecutor.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2025, NVIDIA CORPORATION. + * Copyright (c) 2025-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,6 @@ import java.util.concurrent.{BlockingQueue, Callable, Future, FutureTask, Priori import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.spark.TaskContext -import org.apache.spark.internal.Logging import org.apache.spark.sql.rapids.execution.TrampolineUtil import org.apache.spark.util.TaskCompletionListener @@ -37,7 +36,11 @@ import org.apache.spark.util.TaskCompletionListener * @tparam T the result type returned by the AsyncRunner */ class RapidsFutureTask[T](val runner: AsyncRunner[T]) - extends FutureTask[AsyncResult[T]](runner) with Logging { + extends FutureTask[AsyncResult[T]](runner) { + + private val log = org.slf4j.LoggerFactory.getLogger(classOf[RapidsFutureTask[_]]) + + private def logWarning(msg: => String): Unit = if (log.isWarnEnabled) log.warn(msg) override def run(): Unit = runner.withStateLock { rr => rr.getState match { @@ -61,14 +64,14 @@ class RapidsFutureTask[T](val runner: AsyncRunner[T]) } else { // Failed due to unexpected exceptions val ex = new IllegalStateException("runner failed unexpectedly") - rr.setState(ExecFailed(ex)) + rr.setState(new ExecFailed(ex)) } // Throw the ScheduleFailed exception within the scope of `FutureTask.run`, so that // the exception can be properly recorded and propagated to the caller of `get()`. - case ScheduleFailed(ex: Throwable) => + case failed: ScheduleFailed => // Trick: register a pre-hook to let `AsyncRunner.call` throw the exception - rr.addPreHook(() => throw ex) + rr.addPreHook(() => throw failed.exception) super.run() // Handle the cancelled case as a special kind of ScheduleFailed @@ -86,7 +89,7 @@ class RapidsFutureTask[T](val runner: AsyncRunner[T]) } override def setException(e: Throwable): Unit = { - runner.setState(ExecFailed(e)) + runner.setState(new ExecFailed(e)) super.setException(e) } @@ -144,7 +147,21 @@ class ResourceBoundedThreadExecutor(mgr: ResourcePool, workQueue: BlockingQueue[Runnable], threadFactory: ThreadFactory, keepAliveTime: Long = 100L) extends ThreadPoolExecutor(corePoolSize, - maximumPoolSize, keepAliveTime, TimeUnit.SECONDS, workQueue, threadFactory) with Logging { + maximumPoolSize, keepAliveTime, TimeUnit.SECONDS, workQueue, threadFactory) { + + private val log = org.slf4j.LoggerFactory.getLogger(classOf[ResourceBoundedThreadExecutor]) + + private def logInfo(msg: => String): Unit = if (log.isInfoEnabled) log.info(msg) + + private def logWarning(msg: => String): Unit = if (log.isWarnEnabled) log.warn(msg) + + private def logDebug(msg: => String): Unit = if (log.isDebugEnabled) log.debug(msg) + + private def logError(msg: => String): Unit = if (log.isErrorEnabled) log.error(msg) + + private def logError(msg: => String, throwable: Throwable): Unit = { + if (log.isErrorEnabled) log.error(msg, throwable) + } logInfo(s"Creating ResourceBoundedThreadExecutor with resourcePool: ${mgr.toString}, " + s"corePoolSize: $corePoolSize, maximumPoolSize: $maximumPoolSize, " + @@ -209,7 +226,7 @@ class ResourceBoundedThreadExecutor(mgr: ResourcePool, rr.getState match { // Cancelled case: Cancelled -> ScheduleFailed case Cancelled => - rr.setState(ScheduleFailed(new IllegalStateException("cancelled"))) + rr.setState(new ScheduleFailed(new IllegalStateException("cancelled"))) logWarning(s"Runner being cancelled ahead of execution: $rr") // The main path: Init -> Pending -> Running | Pending | ScheduleFailed @@ -229,9 +246,9 @@ class ResourceBoundedThreadExecutor(mgr: ResourcePool, rr.setState(Running) futTask.scheduleTime += s.elapsedTime // Fail the scheduling: Pending -> ScheduleFailed - case AcquireExcepted(ex) => - rr.setState(ScheduleFailed(ex)) - logError(s"$ex [$rr]") + case excepted: AcquireExcepted => + rr.setState(new ScheduleFailed(excepted.exception)) + logError(s"${excepted.exception} [$rr]") // Bypass the execution: Pending -> Pending case AcquireFailed => } @@ -241,7 +258,7 @@ class ResourceBoundedThreadExecutor(mgr: ResourcePool, // If we throw an exception here, it will crash the ThreadWorker without signaling // the caller. So we just mark the state as ScheduleFailed to pass the exception to // the caller via FutureTask.get(). - rr.setState(ScheduleFailed(new IllegalStateException("Unexpected state"))) + rr.setState(new ScheduleFailed(new IllegalStateException("Unexpected state"))) logError(s"Unexpected state before schedule: $rr") } } @@ -255,7 +272,7 @@ class ResourceBoundedThreadExecutor(mgr: ResourcePool, // and recorded the exception internally. if (t != null) { if (!rr.getState.isInstanceOf[ExecFailed]) { - rr.setState(ExecFailed(t)) + rr.setState(new ExecFailed(t)) } // Also try to fail the Spark task which launched this runner. rr.sparkTaskContext.foreach { ctx => @@ -268,7 +285,7 @@ class ResourceBoundedThreadExecutor(mgr: ResourcePool, // Post execution state handling rr.getState match { case Cancelled | // very rare case: cancelled between execution and afterExecute - ExecFailed(_) => // failed execution (ScheduleFailed should be cast to ExecFailed) + _: ExecFailed => // failed execution (ScheduleFailed should be cast to ExecFailed) // release holding resource immediately on exception if (rr.isHoldingResource) { rr.releaseResourceCallback() @@ -290,7 +307,7 @@ class ResourceBoundedThreadExecutor(mgr: ResourcePool, require(!rr.isHoldingResource, s"Pending state should NOT hold Resource: $rr") // Requeue runners which failed to acquire resource and bypassed the execution. futTask.scheduleTime += timeoutMs * 1000000L - rr.setState(Init(firstTime = false)) // reset to Init state for re-scheduling + rr.setState(new Init(firstTime = false)) // reset to Init state for re-scheduling // Re-add the task to the work queue for re-execution if (!workQueue.add(futTask)) { // Fatal error @@ -334,11 +351,11 @@ class ResourceBoundedThreadExecutor(mgr: ResourcePool, // Finalize the runner state state match { case Completed => // Completed -> Closed - rr.setState(Closed(None)) - case ExecFailed(ex) => // ExecFailed -> Closed - rr.setState(Closed(Some(ex))) + rr.setState(new Closed(None)) + case failed: ExecFailed => // ExecFailed -> Closed + rr.setState(new Closed(Some(failed.exception))) case Cancelled => // Cancelled -> Closed - rr.setState(Closed(Some(new IllegalStateException("cancelled")))) + rr.setState(new Closed(Some(new IllegalStateException("cancelled")))) case _ => throw new IllegalStateException(s"Should NOT reach here: $rr") } @@ -364,12 +381,12 @@ class ResourceBoundedThreadExecutor(mgr: ResourcePool, // 2. Mark the runner as Cancelled fut.runner.withStateLock { rr => rr.getState match { - case Init(_) => rr.setState(Cancelled) // Init -> Cancelled + case _: Init => rr.setState(Cancelled) // Init -> Cancelled case Pending => rr.setState(Cancelled) // Pending -> Cancelled case Running => rr.setState(Cancelled) // Running -> Cancelled - case ScheduleFailed(_) => rr.setState(Cancelled) // ScheduleFailed -> Cancelled + case _: ScheduleFailed => rr.setState(Cancelled) // ScheduleFailed -> Cancelled case Completed => rr.setState(Cancelled) // Completed -> Cancelled - case Cancelled | ExecFailed(_) | Closed(_) => // do nothing + case Cancelled | _: ExecFailed | _: Closed => // do nothing } } // 3. If the runner is still holding resource, we release it diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/ResourcePools.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/ResourcePools.scala index 5023d53f551..efa195183a6 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/ResourcePools.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/ResourcePools.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2025, NVIDIA CORPORATION. + * Copyright (c) 2025-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,7 +22,6 @@ import java.util.concurrent.locks.ReentrantLock import scala.collection.mutable -import org.apache.spark.internal.Logging import org.apache.spark.sql.rapids.execution.TrampolineUtil.bytesToString // Being thrown when a task requests resources that are not valid or exceed the limits @@ -32,13 +31,13 @@ class InvalidResourceRequest(msg: String) extends RuntimeException( // Represents the status of acquiring resources for a task sealed trait AcquireStatus -case class AcquireSuccessful(elapsedTime: Long) extends AcquireStatus +class AcquireSuccessful(val elapsedTime: Long) extends AcquireStatus with Serializable // AcquireFailed indicates that the task could not be scheduled due to resource constraints case object AcquireFailed extends AcquireStatus // AcquireExcepted indicates that an exception occurred while trying to acquire resources -case class AcquireExcepted(exception: Throwable) extends AcquireStatus +class AcquireExcepted(val exception: Throwable) extends AcquireStatus with Serializable /** * ResourceManager interface to be implemented for AsyncRunners requiring different kinds of @@ -68,7 +67,13 @@ trait ResourcePool { * The implementation uses condition variables to efficiently block and wake up waiting * tasks when resources become available through task completion and resource release. */ -class HostMemoryPool(val maxHostMemoryBytes: Long) extends ResourcePool with Logging { +class HostMemoryPool(val maxHostMemoryBytes: Long) extends ResourcePool { + + private val log = org.slf4j.LoggerFactory.getLogger(classOf[HostMemoryPool]) + + private def logWarning(msg: => String): Unit = if (log.isWarnEnabled) log.warn(msg) + + private def logDebug(msg: => String): Unit = if (log.isDebugEnabled) log.debug(msg) private val lock = new ReentrantLock() @@ -94,7 +99,7 @@ class HostMemoryPool(val maxHostMemoryBytes: Long) extends ResourcePool with Log // step 2: try to acquire the resource with blocking and timeout // 2.1 If no resource needed, acquire immediately if (memoryRequire == 0L) { - AcquireSuccessful(elapsedTime = 0L) + new AcquireSuccessful(elapsedTime = 0L) } // 2.2 The main path for acquiring resource with blocking and timeout else { @@ -160,10 +165,10 @@ class HostMemoryPool(val maxHostMemoryBytes: Long) extends ResourcePool with Log s"Over-committed HostMemoryPool: exceeded_amount=${bytesToString(-remaining)}, " + s"AsyncRunners=$numRunnerInPool, SparkTasks=${tasksInPool.size}") } - AcquireSuccessful(elapsedTime = timeoutNs - waitTimeNs) + new AcquireSuccessful(elapsedTime = timeoutNs - waitTimeNs) } } catch { - case ex: Throwable => AcquireExcepted(ex) + case ex: Throwable => new AcquireExcepted(ex) } finally { lock.unlock() } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/ThrottlingExecutor.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/ThrottlingExecutor.scala index 2c799f1cd39..54460afd9bb 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/ThrottlingExecutor.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/ThrottlingExecutor.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024-2025, NVIDIA CORPORATION. + * Copyright (c) 2024-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,15 +21,6 @@ import java.util.concurrent.{Callable, ExecutorService, Future, TimeUnit} import org.apache.spark.sql.rapids.{ColumnarWriteTaskStatsTracker, GpuWriteTaskStatsTracker} -/** - * Stats related classes used by ThrottlingExecutor - */ -case class ThrottlingExecutorStats ( - var numTasksScheduled: Int, - var accumulatedThrottleTimeNs: Long, - var minThrottleTimeNs: Long, - var maxThrottleTimeNs: Long) - /** * Only for GpuWriteTaskStatsTracker cases */ @@ -53,7 +44,7 @@ class StatsUpdaterForWriteFunc(val statsTrackers: Seq[ColumnarWriteTaskStatsTrac class ThrottlingExecutor(executor: ExecutorService, throttler: TrafficController, updateStats : ThrottlingExecutorStats => Unit) { - val stats: ThrottlingExecutorStats = ThrottlingExecutorStats(0, 0L, Long.MaxValue, 0L) + val stats: ThrottlingExecutorStats = new ThrottlingExecutorStats(0, 0L, Long.MaxValue, 0L) private def blockUntilTaskRunnable(task: Task[_]): Unit = { val blockStart = System.nanoTime() diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BounceBufferManager.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BounceBufferManager.scala index 0475aa6b6db..98c2f206de1 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BounceBufferManager.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BounceBufferManager.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,7 +20,6 @@ import java.util import ai.rapids.cudf.MemoryBuffer -import org.apache.spark.internal.Logging /** * Class to hold a bounce buffer reference in `buffer`. @@ -53,9 +52,9 @@ abstract class BounceBuffer(val buffer: MemoryBuffer) extends AutoCloseable { * @param deviceBounceBuffer - device buffer to use for sends * @param hostBounceBuffer - optional host buffer to use for sends */ -case class SendBounceBuffers( - deviceBounceBuffer: BounceBuffer, - hostBounceBuffer: Option[BounceBuffer]) extends AutoCloseable { +class SendBounceBuffers( + val deviceBounceBuffer: BounceBuffer, + val hostBounceBuffer: Option[BounceBuffer]) extends AutoCloseable with Serializable { def bounceBufferSize: Long = { deviceBounceBuffer.buffer.getLength @@ -82,8 +81,22 @@ class BounceBufferManager[T <: MemoryBuffer]( val bufferSize: Long, val numBuffers: Int, allocator: Long => T) - extends AutoCloseable - with Logging { + extends AutoCloseable { + + private val log = org.slf4j.LoggerFactory.getLogger( + "com.nvidia.spark.rapids.shuffle.BounceBufferManager") + + private def logDebug(msg: => String): Unit = { + if (log.isDebugEnabled) { + log.debug(msg) + } + } + + private def logTrace(msg: => String): Unit = { + if (log.isTraceEnabled) { + log.trace(msg) + } + } class BounceBufferImpl(buff: MemoryBuffer) extends BounceBuffer(buff) { override def free(bb: BounceBuffer): Unit = { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferReceiveState.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferReceiveState.scala index 60d3b13de37..d05da04a589 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferReceiveState.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferReceiveState.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2025, NVIDIA CORPORATION. + * Copyright (c) 2020-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,12 +26,11 @@ import com.nvidia.spark.rapids.NvtxRegistry import com.nvidia.spark.rapids.format.TableMeta import com.nvidia.spark.rapids.jni.RmmSpark -import org.apache.spark.internal.Logging -case class ConsumedBatchFromBounceBuffer( - contigBuffer: DeviceMemoryBuffer, - meta: TableMeta, - handler: RapidsShuffleFetchHandler) +class ConsumedBatchFromBounceBuffer( + val contigBuffer: DeviceMemoryBuffer, + val meta: TableMeta, + val handler: RapidsShuffleFetchHandler) extends Serializable /** * A helper case class to maintain the state associated with a transfer request to a peer. @@ -59,7 +58,21 @@ class BufferReceiveState( requests: Seq[PendingTransferRequest], transportOnClose: () => Unit, stream: Cuda.Stream = Cuda.DEFAULT_STREAM) - extends AutoCloseable with Logging { + extends AutoCloseable { + + private val log = org.slf4j.LoggerFactory.getLogger(classOf[BufferReceiveState]) + + private def logWarning(msg: => String): Unit = { + if (log.isWarnEnabled) { + log.warn(msg) + } + } + + private def logDebug(msg: => String): Unit = { + if (log.isDebugEnabled) { + log.debug(msg) + } + } val transportBuffer = new CudfTransportBuffer(bounceBuffer.buffer) // we use this to keep a list (should be depth 1) of "requests for receives" @@ -223,7 +236,7 @@ class BufferReceiveState( } if (contigBuffer != null) { - Some(ConsumedBatchFromBounceBuffer( + Some(new ConsumedBatchFromBounceBuffer( contigBuffer, pendingTransferRequest.tableMeta, pendingTransferRequest.handler)) } else { None diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferSendState.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferSendState.scala index 0a7942bd581..22f5655680e 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferSendState.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferSendState.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. + * Copyright (c) 2020-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,7 +22,6 @@ import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.format.{BufferMeta, BufferTransferRequest} -import org.apache.spark.internal.Logging import org.apache.spark.shuffle.rapids.RapidsShuffleSendPrepareException /** @@ -56,7 +55,21 @@ class BufferSendState( sendBounceBuffers: SendBounceBuffers, requestHandler: RapidsShuffleRequestHandler, serverStream: Cuda.Stream = Cuda.DEFAULT_STREAM) - extends AutoCloseable with Logging { + extends AutoCloseable { + + private val log = org.slf4j.LoggerFactory.getLogger(classOf[BufferSendState]) + + private def logWarning(msg: => String): Unit = { + if (log.isWarnEnabled) { + log.warn(msg) + } + } + + private def logDebug(msg: => String): Unit = { + if (log.isDebugEnabled) { + log.debug(msg) + } + } class SendBlock(val bufferHandle: RapidsShuffleHandle) extends BlockWithSize { // we assume that the size of the buffer won't change as it goes to host/disk @@ -148,8 +161,8 @@ class BufferSendState( } } - case class RangeBuffer( - range: BlockRange[SendBlock], rapidsBuffer: MemoryBuffer) + private class RangeBuffer( + val range: BlockRange[SendBlock], val rapidsBuffer: MemoryBuffer) extends AutoCloseable { override def close(): Unit = { rapidsBuffer.close() @@ -189,7 +202,7 @@ class BufferSendState( case _ => hostBuffs += blockRange.rangeSize() } - RangeBuffer(blockRange, buff) + new RangeBuffer(blockRange, buff) } logDebug(s"Occupancy for bounce buffer is " + @@ -201,7 +214,9 @@ class BufferSendState( hostBounceBuffer.buffer } - acquiredBuffs.foreach { case RangeBuffer(blockRange, memoryBuffer) => + acquiredBuffs.foreach { rangeBuffer => + val blockRange = rangeBuffer.range + val memoryBuffer = rangeBuffer.rapidsBuffer needsCleanup = true require(blockRange.rangeSize() <= bounceBuffToUse.getLength - buffOffset) bounceBuffToUse.copyFromMemoryBufferAsync( diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala index 5498dd890c4..7bc24b6c343 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2025, NVIDIA CORPORATION. + * Copyright (c) 2020-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,7 +25,6 @@ import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.format.{MetadataResponse, TableMeta, TransferState} -import org.apache.spark.internal.Logging import org.apache.spark.sql.rapids.GpuShuffleEnv import org.apache.spark.storage.ShuffleBlockBatchId @@ -70,9 +69,9 @@ trait RapidsShuffleFetchHandler { * @param tableMeta shuffle metadata describing the table * @param handler a specific handler that is waiting for this batch */ -case class PendingTransferRequest(client: RapidsShuffleClient, - tableMeta: TableMeta, - handler: RapidsShuffleFetchHandler) { +class PendingTransferRequest(val client: RapidsShuffleClient, + val tableMeta: TableMeta, + val handler: RapidsShuffleFetchHandler) extends Serializable { val getLength: Long = tableMeta.bufferMeta.size() } @@ -98,7 +97,28 @@ class RapidsShuffleClient( exec: Executor, clientCopyExecutor: Executor, catalog: ShuffleReceivedBufferCatalog = GpuShuffleEnv.getReceivedCatalog) - extends Logging with AutoCloseable { + extends AutoCloseable { + + private val log = org.slf4j.LoggerFactory.getLogger(classOf[RapidsShuffleClient]) + + private def logInfo(msg: => String): Unit = { + if (log.isInfoEnabled) { + log.info(msg) + } + } + + private def logWarning(msg: => String): Unit = { + if (log.isWarnEnabled) { + log.warn(msg) + } + } + + private def logDebug(msg: => String): Unit = { + if (log.isDebugEnabled) { + log.debug(msg) + } + } + // these are handlers that are interested (live spark tasks) in peer failure handling private val liveHandlers = @@ -338,7 +358,7 @@ class RapidsShuffleClient( // We check the uncompressedSize to make sure we don't request a 0-sized buffer // from a peer. We treat such a corner case as a degenerate batch if (tableMeta.bufferMeta() != null && tableMeta.bufferMeta().uncompressedSize() > 0) { - ptrs += PendingTransferRequest( + ptrs += new PendingTransferRequest( this, tableMeta, handler) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServer.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServer.scala index 8d7817da595..430a59b6888 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServer.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServer.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2025, NVIDIA CORPORATION. + * Copyright (c) 2020-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,7 +25,6 @@ import com.nvidia.spark.rapids.{NvtxRegistry, RapidsConf, RapidsShuffleHandle, S import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.format.TableMeta -import org.apache.spark.internal.Logging import org.apache.spark.shuffle.rapids.RapidsShuffleSendPrepareException import org.apache.spark.sql.rapids.execution.TrampolineUtil import org.apache.spark.storage.{BlockManagerId, ShuffleBlockBatchId} @@ -74,7 +73,40 @@ class RapidsShuffleServer(transport: RapidsShuffleTransport, requestHandler: RapidsShuffleRequestHandler, exec: Executor, bssExec: Executor, - rapidsConf: RapidsConf) extends AutoCloseable with Logging { + rapidsConf: RapidsConf) extends AutoCloseable { + + private val log = org.slf4j.LoggerFactory.getLogger(classOf[RapidsShuffleServer]) + + private def logWarning(msg: => String): Unit = { + if (log.isWarnEnabled) { + log.warn(msg) + } + } + + private def logDebug(msg: => String): Unit = { + if (log.isDebugEnabled) { + log.debug(msg) + } + } + + private def logTrace(msg: => String): Unit = { + if (log.isTraceEnabled) { + log.trace(msg) + } + } + + private def logError(msg: => String): Unit = { + if (log.isErrorEnabled) { + log.error(msg) + } + } + + private def logError(msg: => String, throwable: Throwable): Unit = { + if (log.isErrorEnabled) { + log.error(msg, throwable) + } + } + def getId: BlockManagerId = { // upon seeing this port, the other side will try to connect to the port diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTransport.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTransport.scala index ef45bf5059d..5c468f75d88 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTransport.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTransport.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2025, NVIDIA CORPORATION. + * Copyright (c) 2020-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,14 +16,13 @@ package com.nvidia.spark.rapids.shuffle -import java.nio.{ByteBuffer, ByteOrder} +import java.nio.{Buffer, ByteBuffer, ByteOrder} import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger import ai.rapids.cudf.MemoryBuffer import com.nvidia.spark.rapids.{NvtxRegistry, RapidsConf, ShimReflectionUtils} -import org.apache.spark.internal.Logging import org.apache.spark.sql.rapids.storage.RapidsStorageUtils import org.apache.spark.storage.BlockManagerId @@ -156,20 +155,6 @@ object TransactionStatus extends Enumeration { val NotStarted, InProgress, Complete, Success, Error, Cancelled = Value } -/** - * Case class representing stats for the a transaction - * @param txTimeMs amount of time this [[Transaction]] took - * @param sendSize amount of bytes sent - * @param receiveSize amount of bytes received - * @param sendThroughput send throughput in GB/sec - * @param recvThroughput receive throughput in GB/sec - */ -case class TransactionStats(txTimeMs: Double, - sendSize: Long, - receiveSize: Long, - sendThroughput: Double, - recvThroughput: Double) - /** * TransportBuffer represents a buffer with an address and length. * @@ -194,7 +179,7 @@ class MetadataTransportBuffer(val dbb: RefCountedDirectByteBuffer) extends Trans def copy(in: ByteBuffer): Unit = { val bb = dbb.getBuffer() bb.put(in) - bb.rewind() + bb.asInstanceOf[Buffer].rewind() } override def getAddress(): Long = @@ -400,7 +385,15 @@ trait RapidsShuffleTransport extends AutoCloseable { * * @param bufferSize the size of direct `ByteBuffer` to allocate. */ -class DirectByteBufferPool(bufferSize: Long) extends Logging { +class DirectByteBufferPool(bufferSize: Long) { + private val log = org.slf4j.LoggerFactory.getLogger(classOf[DirectByteBufferPool]) + + private def logDebug(msg: => String): Unit = { + if (log.isDebugEnabled) { + log.debug(msg) + } + } + val buffers = new ConcurrentLinkedQueue[ByteBuffer]() val high = new AtomicInteger(0) @@ -415,7 +408,7 @@ class DirectByteBufferPool(bufferSize: Long) extends Logging { logDebug(s"Allocating new direct buffer, high watermark = $high") new RefCountedDirectByteBuffer(ByteBuffer.allocateDirect(bufferSize.toInt), Option(this)) } else { - buff.clear() + buff.asInstanceOf[Buffer].clear() // Reset endianness to BIG_ENDIAN, as it could have changed depending on the consumer // (i.e. flat buffers force byte order to be LITTLE_ENDIAN, but pool consumers could be // things like handshake messages that don't use flat buffers). @@ -430,7 +423,7 @@ class DirectByteBufferPool(bufferSize: Long) extends Logging { def releaseBuffer(buff: RefCountedDirectByteBuffer): Boolean = { logDebug(s"Free direct buffers ${buffers.size()}") - buff.getBuffer().clear() + buff.getBuffer().asInstanceOf[Buffer].clear() buffers.offer(buff.getBuffer()) } } @@ -531,7 +524,7 @@ object TransportUtils { NvtxRegistry.TRANSPORT_COPY_BUFFER.push() try { val ro = src.asReadOnlyBuffer() - ro.limit(ro.position() + size) // make sure we only copy size bytes + ro.asInstanceOf[Buffer].limit(ro.position() + size) // make sure we only copy size bytes // copy from position to remaining = (limit - position) dst.put(ro) // bulk put } finally { @@ -550,7 +543,16 @@ object TransportUtils { } } -object RapidsShuffleTransport extends Logging { +object RapidsShuffleTransport { + private val log = org.slf4j.LoggerFactory.getLogger( + "com.nvidia.spark.rapids.shuffle.RapidsShuffleTransport") + + private def logError(msg: => String, throwable: Throwable): Unit = { + if (log.isErrorEnabled) { + log.error(msg, throwable) + } + } + /** * Used in `BlockManagerId`s when returning a map status after a shuffle write to * let the readers know what TCP port to use to establish a transport connection. diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/WindowedBlockIterator.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/WindowedBlockIterator.scala index 2fdce9862ad..9ba96c2bf5c 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/WindowedBlockIterator.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/WindowedBlockIterator.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,37 +18,6 @@ package com.nvidia.spark.rapids.shuffle import scala.collection.mutable.ArrayBuffer -// Helper trait that callers can use to add blocks to the iterator -// as long as they can provide a size -trait BlockWithSize { - /** - * Abstract method to return the size in bytes of this block - * @return Long - size in bytes - */ - def size: Long -} - -/** - * Specifies a start and end range of bytes for a block. - * @param block - a BlockWithSize instance - * @param rangeStart - byte offset for the start of the range (inclusive) - * @param rangeEnd - byte offset for the end of the range (exclusive) - * @tparam T - the specific type of `BlockWithSize` - */ -case class BlockRange[T <: BlockWithSize]( - block: T, rangeStart: Long, rangeEnd: Long) { - require(rangeStart < rangeEnd, - s"Instantiated a BlockRange with invalid boundaries: $rangeStart to $rangeEnd") - - /** - * Returns the size of this range in bytes - * @return - Long - size in bytes - */ - def rangeSize(): Long = rangeEnd - rangeStart - - def isComplete(): Boolean = rangeEnd == block.size -} - /** * Given a set of blocks, this iterator returns BlockRanges * of such blocks that fit `windowSize`. The ranges are just logical @@ -90,21 +59,21 @@ class WindowedBlockIterator[T <: BlockWithSize](blocks: Seq[T], windowSize: Long require(windowSize > 0, s"Invalid window size specified $windowSize") - private case class BlockWindow(start: Long, size: Long) { - val end = start + size // exclusive end offset + private class BlockWindow(val start: Long, val size: Long) { + val end: Long = start + size // exclusive end offset def move(): BlockWindow = { - BlockWindow(start + size, size) + new BlockWindow(start + size, size) } } // start the window at byte 0 - private[this] var window = BlockWindow(0, windowSize) + private[this] var window = new BlockWindow(0, windowSize) private[this] var done = false // helper class that captures the start/end byte offset // for `block` on creation - private case class BlockWithOffset[T <: BlockWithSize]( - block: T, startOffset: Long, endOffset: Long) + private class BlockWithOffset[T <: BlockWithSize]( + val block: T, val startOffset: Long, val endOffset: Long) private[this] val blocksWithOffsets = { var lastOffset = 0L @@ -113,7 +82,7 @@ class WindowedBlockIterator[T <: BlockWithSize](blocks: Seq[T], windowSize: Long val startOffset = lastOffset val endOffset = startOffset + block.size lastOffset = endOffset // for next block - BlockWithOffset(block, startOffset, endOffset) + new BlockWithOffset(block, startOffset, endOffset) } } @@ -121,9 +90,10 @@ class WindowedBlockIterator[T <: BlockWithSize](blocks: Seq[T], windowSize: Long // is an index into the `blocksWithOffsets` sequence private[this] var lastSeenBlock = 0 - case class BlocksForWindow(lastBlockIndex: Option[Int], - blockRanges: Seq[BlockRange[T]], - hasMoreBlocks: Boolean) + private class BlocksForWindow( + val lastBlockIndex: Option[Int], + val blockRanges: Seq[BlockRange[T]], + val hasMoreBlocks: Boolean) private def getBlocksForWindow( window: BlockWindow, @@ -144,7 +114,7 @@ class WindowedBlockIterator[T <: BlockWithSize](blocks: Seq[T], windowSize: Long if (window.end >= b.endOffset) { rangeEnd = b.endOffset - b.startOffset } - blockRangesInWindow.append(BlockRange[T](b.block, rangeStart, rangeEnd)) + blockRangesInWindow.append(new BlockRange[T](b.block, rangeStart, rangeEnd)) lastBlockIndex = Some(thisBlock) } else { // skip this block, unless it's before our window starts @@ -153,7 +123,7 @@ class WindowedBlockIterator[T <: BlockWithSize](blocks: Seq[T], windowSize: Long thisBlock = thisBlock + 1 } val lastBlock = blockRangesInWindow.last - BlocksForWindow(lastBlockIndex, + new BlocksForWindow(lastBlockIndex, blockRangesInWindow.toSeq, !continue || !lastBlock.isComplete()) } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala index 2a05e486e9a..dbafe1c4f65 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala @@ -17,7 +17,7 @@ package com.nvidia.spark.rapids.spill import java.io._ -import java.nio.ByteBuffer +import java.nio.{Buffer, ByteBuffer} import java.nio.channels.{Channels, FileChannel, WritableByteChannel} import java.nio.file.StandardOpenOption import java.util @@ -27,16 +27,14 @@ import java.util.concurrent.ArrayBlockingQueue import scala.collection.mutable import ai.rapids.cudf._ -import com.nvidia.spark.rapids.{GpuColumnVector, GpuColumnVectorFromBuffer, GpuCompressedColumnVector, GpuDeviceManager, HashedPriorityQueue, HostAlloc, HostMemoryOutputStream, MemoryBufferToHostByteBufferIterator, NvtxId, NvtxRegistry, RapidsConf, RapidsHostColumnVector} +import com.nvidia.spark.rapids.{GpuColumnVector, GpuColumnVectorFromBuffer, GpuCompressedColumnVector, GpuDeviceManager, HashedPriorityQueue, HostAlloc, HostByteBufferIterator, HostMemoryOutputStream, MemoryBufferToHostByteBufferIterator, NvtxId, NvtxRegistry, RapidsConf, RapidsHostColumnVector} import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableSeq import com.nvidia.spark.rapids.format.TableMeta -import com.nvidia.spark.rapids.internal.HostByteBufferIterator import com.nvidia.spark.rapids.jni.TaskPriority import org.apache.commons.io.IOUtils import org.apache.spark.{SparkConf, SparkEnv, TaskContext} -import org.apache.spark.internal.Logging import org.apache.spark.sql.rapids.{GpuTaskMetrics, RapidsDiskBlockManager} import org.apache.spark.sql.rapids.execution.{SerializedHostTableUtils, TrampolineUtil} import org.apache.spark.sql.rapids.storage.RapidsStorageUtils @@ -172,7 +170,13 @@ trait StoreHandle extends AutoCloseable { var taskPriority: Long = taskId.map(TaskPriority.getTaskPriority).getOrElse(Long.MaxValue) } -trait SpillableHandle extends StoreHandle with Logging { +trait SpillableHandle extends StoreHandle { + private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$")) + + private def logWarning(msg: => String, throwable: Throwable): Unit = { + log.warn(msg, throwable) + } + /** * used to gate when a spill is actively being done so that a second thread won't * also begin spilling, and a handle won't release the underlying buffer if it's @@ -349,7 +353,7 @@ trait HostSpillableHandle[T <: AutoCloseable] extends SpillableHandle { } } -object SpillableHostBufferHandle extends Logging { +object SpillableHostBufferHandle { def apply(hmb: HostMemoryBuffer): SpillableHostBufferHandle = { val handle = new SpillableHostBufferHandle(hmb.getLength, host = Some(hmb)) SpillFramework.stores.hostStore.trackNoSpill(handle) @@ -675,7 +679,7 @@ class SpillableColumnarBatchHandle private ( override val approxSizeInBytes: Long, private[spill] override var dev: Option[ColumnarBatch], private[spill] var host: Option[SpillableHostBufferHandle] = None) - extends DeviceSpillableHandle[ColumnarBatch] with Logging { + extends DeviceSpillableHandle[ColumnarBatch] { override def spillable: Boolean = synchronized { if (super.spillable) { @@ -1305,7 +1309,7 @@ object HandleComparator extends util.Comparator[StoreHandle] { } } -trait HandleStore[T <: StoreHandle] extends AutoCloseable with Logging { +trait HandleStore[T <: StoreHandle] extends AutoCloseable { protected lazy val handles = new HashedPriorityQueue[T](HandleComparator) def numHandles: Int = synchronized { @@ -1358,7 +1362,7 @@ trait HandleStore[T <: StoreHandle] extends AutoCloseable with Logging { } trait SpillableStore[T <: SpillableHandle] - extends HandleStore[T] with Logging { + extends HandleStore[T] { protected def spillNvtxRange: NvtxId /** @@ -1480,8 +1484,16 @@ trait SpillableStore[T <: SpillableHandle] } class SpillableHostStore(val maxSize: Option[Long] = None) - extends SpillableStore[HostSpillableHandle[_]] - with Logging { + extends SpillableStore[HostSpillableHandle[_]] { + + private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$")) + + private def logInfo(msg: => String): Unit = { + if (log.isInfoEnabled) { + log.info(msg) + } + } + private[spill] var totalSize: Long = 0L @@ -1648,7 +1660,7 @@ class SpillableHostStore(val maxSize: Option[Long] = None) private class SpillableHostBufferHandleBuilderForHost( var handle: SpillableHostBufferHandle, var singleShotBuffer: HostMemoryBuffer) - extends SpillableHostBufferHandleBuilder with Logging { + extends SpillableHostBufferHandleBuilder { private var copied = 0L override def copyNext(mb: DeviceMemoryBuffer, len: Long, stream: Cuda.Stream): Unit = { @@ -1752,7 +1764,13 @@ class SpillableDeviceStore extends SpillableStore[DeviceSpillableHandle[_]] { } class DiskHandleStore(conf: SparkConf) - extends HandleStore[DiskHandle] with Logging { + extends HandleStore[DiskHandle] { + private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$")) + + private def logWarning(msg: => String): Unit = { + log.warn(msg) + } + val diskBlockManager: RapidsDiskBlockManager = new RapidsDiskBlockManager(conf) def getFile(blockId: BlockId): File = { @@ -1915,7 +1933,7 @@ class SpillableTableHandle private ( override val approxSizeInBytes: Long, private[spill] override var dev: Option[Table], private[spill] var host: Option[SpillableHostBufferHandle] = None) - extends DeviceSpillableHandle[Table] with Logging { + extends DeviceSpillableHandle[Table] { override def spillable: Boolean = synchronized { if (super.spillable) { @@ -2047,7 +2065,19 @@ object SpillableTableHandle { } } -object SpillFramework extends Logging { +object SpillFramework { + private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$")) + + private def logInfo(msg: => String): Unit = { + if (log.isInfoEnabled) { + log.info(msg) + } + } + + private def logWarning(msg: => String): Unit = { + log.warn(msg) + } + // public for tests. Some tests not in the `spill` package require setting this // because they need fine control over allocations. var storesInternal: SpillableStores = _ @@ -2211,7 +2241,13 @@ private[spill] class BounceBuffer[T <: AutoCloseable]( class BounceBufferPool[T <: AutoCloseable](private val bufSize: Long, private val bbCount: Int, private val allocator: Long => T) - extends AutoCloseable with Logging { + extends AutoCloseable { + private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$")) + + private def logError(msg: => String): Unit = { + log.error(msg) + } + private val pool = new ArrayBlockingQueue[BounceBuffer[T]](bbCount) for (_ <- 1 to bbCount) { @@ -2277,7 +2313,13 @@ class BounceBufferPool[T <: AutoCloseable](private val bufSize: Long, */ class ChunkedPacker(table: Table, bounceBufferPool: BounceBufferPool[DeviceMemoryBuffer]) - extends Iterator[(BounceBuffer[DeviceMemoryBuffer], Long)] with Logging with AutoCloseable { + extends Iterator[(BounceBuffer[DeviceMemoryBuffer], Long)] with AutoCloseable { + private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$")) + + private def logWarning(msg: => String): Unit = { + log.warn(msg) + } + private var closed: Boolean = false @@ -2310,7 +2352,7 @@ class ChunkedPacker(table: Table, val tmpBB = packedMeta.getMetadataDirectBuffer val metaCopy = ByteBuffer.allocateDirect(tmpBB.capacity()) metaCopy.put(tmpBB) - metaCopy.flip() + metaCopy.asInstanceOf[Buffer].flip() metaCopy } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandle.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandle.scala index 52dfd286773..9f47cd41f0f 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandle.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandle.scala @@ -17,14 +17,13 @@ package com.nvidia.spark.rapids.spill import java.io.{BufferedInputStream, BufferedOutputStream, File, FileInputStream, FileOutputStream, IOException, RandomAccessFile} -import java.nio.ByteBuffer +import java.nio.{Buffer, ByteBuffer} import java.nio.channels.FileChannel import ai.rapids.cudf.HostMemoryBuffer import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.HostAlloc -import org.apache.spark.internal.Logging import org.apache.spark.sql.rapids.GpuTaskMetrics import org.apache.spark.sql.rapids.execution.TrampolineUtil @@ -74,7 +73,25 @@ class SpillablePartialFileHandle private ( priority: Long, syncWrites: Boolean, capacityHintProvider: Option[(Long, Long) => Long]) - extends HostSpillableHandle[ai.rapids.cudf.HostMemoryBuffer] with Logging { + extends HostSpillableHandle[ai.rapids.cudf.HostMemoryBuffer] { + private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$")) + + private def logDebug(msg: => String): Unit = { + if (log.isDebugEnabled) { + log.debug(msg) + } + } + + private def logDebug(msg: => String, throwable: Throwable): Unit = { + if (log.isDebugEnabled) { + log.debug(msg, throwable) + } + } + + private def logWarning(msg: => String, throwable: Throwable): Unit = { + log.warn(msg, throwable) + } + // State management @volatile private var spilledToDisk: Boolean = false @@ -257,7 +274,7 @@ class SpillablePartialFileHandle private ( withResource(new FileOutputStream(file)) { fos => val channel = fos.getChannel val bb = buffer.asByteBuffer() - bb.limit(writePosition.toInt) + bb.asInstanceOf[Buffer].limit(writePosition.toInt) while (bb.hasRemaining) { channel.write(bb) } @@ -669,7 +686,7 @@ class SpillablePartialFileHandle private ( try { val channel = fos.getChannel val bb = bufferToSpill.asByteBuffer() - bb.limit(totalBytesWritten.toInt) + bb.asInstanceOf[Buffer].limit(totalBytesWritten.toInt) while (bb.hasRemaining) { channel.write(bb) } @@ -798,7 +815,7 @@ class SpillablePartialFileHandle private ( } } -object SpillablePartialFileHandle extends Logging { +object SpillablePartialFileHandle { /** * Create a file-only handle. diff --git a/sql-plugin/src/main/scala/org/apache/spark/shuffle/sort/io/RapidsLocalDiskShuffleExecutorComponents.scala b/sql-plugin/src/main/scala/org/apache/spark/shuffle/sort/io/RapidsLocalDiskShuffleExecutorComponents.scala index e923d57ab83..af7f212bbd6 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/shuffle/sort/io/RapidsLocalDiskShuffleExecutorComponents.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/shuffle/sort/io/RapidsLocalDiskShuffleExecutorComponents.scala @@ -21,7 +21,6 @@ import java.util.{Map => JMap, Optional} import com.google.common.annotations.VisibleForTesting import org.apache.spark.{SparkConf, SparkEnv} -import org.apache.spark.internal.Logging import org.apache.spark.shuffle.IndexShuffleBlockResolver import org.apache.spark.shuffle.api.{ShuffleExecutorComponents, ShuffleMapOutputWriter, SingleSpillShuffleMapOutputWriter} import org.apache.spark.storage.BlockManager @@ -31,7 +30,7 @@ import org.apache.spark.storage.BlockManager * instances with host memory buffer support. */ class RapidsLocalDiskShuffleExecutorComponents(sparkConf: SparkConf) - extends ShuffleExecutorComponents with Logging { + extends ShuffleExecutorComponents { private var blockManager: BlockManager = null private var blockResolver: IndexShuffleBlockResolver = null diff --git a/sql-plugin/src/main/scala/org/apache/spark/shuffle/sort/io/RapidsLocalDiskShuffleMapOutputWriter.scala b/sql-plugin/src/main/scala/org/apache/spark/shuffle/sort/io/RapidsLocalDiskShuffleMapOutputWriter.scala index 3a0360d143c..1578d39b43d 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/shuffle/sort/io/RapidsLocalDiskShuffleMapOutputWriter.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/shuffle/sort/io/RapidsLocalDiskShuffleMapOutputWriter.scala @@ -25,7 +25,6 @@ import com.nvidia.spark.rapids.{HostAlloc, RapidsConf} import com.nvidia.spark.rapids.spill.SpillablePartialFileHandle import org.apache.spark.SparkConf -import org.apache.spark.internal.Logging import org.apache.spark.shuffle.IndexShuffleBlockResolver import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter, WritableByteChannelWrapper} import org.apache.spark.shuffle.api.metadata.MapOutputCommitMessage @@ -42,7 +41,17 @@ class RapidsLocalDiskShuffleMapOutputWriter( numPartitions: Int, blockResolver: IndexShuffleBlockResolver, sparkConf: SparkConf) - extends ShuffleMapOutputWriter with Logging { + extends ShuffleMapOutputWriter { + @transient private lazy val log = org.slf4j.LoggerFactory.getLogger( + classOf[RapidsLocalDiskShuffleMapOutputWriter]) + + private def logDebug(msg: => String): Unit = if (log.isDebugEnabled) log.debug(msg) + + private def logWarning(msg: => String): Unit = if (log.isWarnEnabled) log.warn(msg) + + private def logWarning(msg: => String, throwable: Throwable): Unit = { + if (log.isWarnEnabled) log.warn(msg, throwable) + } private val partitionLengths = new Array[Long](numPartitions) private var lastPartitionId = -1 diff --git a/sql-plugin/src/main/scala/org/apache/spark/storage/RapidsPushBasedFetchHelper.scala b/sql-plugin/src/main/scala/org/apache/spark/storage/RapidsPushBasedFetchHelper.scala index 18f53e74060..7945d28dbc2 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/storage/RapidsPushBasedFetchHelper.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/storage/RapidsPushBasedFetchHelper.scala @@ -27,7 +27,6 @@ import org.roaringbitmap.RoaringBitmap import org.apache.spark.MapOutputTracker import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID -import org.apache.spark.internal.Logging import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener} import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER import org.apache.spark.storage.RapidsShuffleBlockFetcherIterator._ @@ -52,7 +51,39 @@ private class RapidsPushBasedFetchHelper( private val iterator: RapidsShuffleBlockFetcherIterator, private val shuffleClient: BlockStoreClient, private val blockManager: BlockManager, - private val mapOutputTracker: MapOutputTracker) extends Logging { + private val mapOutputTracker: MapOutputTracker) { + + private val log = org.slf4j.LoggerFactory.getLogger(classOf[RapidsPushBasedFetchHelper]) + + private def logInfo(msg: => String): Unit = { + if (log.isInfoEnabled) { + log.info(msg) + } + } + + private def logWarning(msg: => String): Unit = { + if (log.isWarnEnabled) { + log.warn(msg) + } + } + + private def logWarning(msg: => String, throwable: Throwable): Unit = { + if (log.isWarnEnabled) { + log.warn(msg, throwable) + } + } + + private def logDebug(msg: => String): Unit = { + if (log.isDebugEnabled) { + log.debug(msg) + } + } + + private def logError(msg: => String, throwable: Throwable): Unit = { + if (log.isErrorEnabled) { + log.error(msg, throwable) + } + } private[this] val startTimeNs = System.nanoTime() diff --git a/sql-plugin/src/main/scala/org/apache/spark/storage/RapidsShuffleBlockFetcherIterator.scala b/sql-plugin/src/main/scala/org/apache/spark/storage/RapidsShuffleBlockFetcherIterator.scala index eb48fc3afdb..a33094fdb63 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/storage/RapidsShuffleBlockFetcherIterator.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/storage/RapidsShuffleBlockFetcherIterator.scala @@ -33,7 +33,7 @@ import org.roaringbitmap.RoaringBitmap import org.apache.spark.{MapOutputTracker, SparkEnv, TaskContext} import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID import org.apache.spark.SparkException -import org.apache.spark.internal.{config, Logging} +import org.apache.spark.internal.config import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle._ import org.apache.spark.network.shuffle.checksum.{Cause, ShuffleChecksumHelper} @@ -115,7 +115,51 @@ final class RapidsShuffleBlockFetcherIterator( checksumAlgorithm: String, shuffleMetrics: ShuffleReadMetricsReporter, doBatchFetch: Boolean) - extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging { + extends Iterator[(BlockId, InputStream)] with DownloadFileManager { + + private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$")) + + private def logInfo(msg: => String): Unit = { + if (log.isInfoEnabled) { + log.info(msg) + } + } + + private def logWarning(msg: => String): Unit = { + if (log.isWarnEnabled) { + log.warn(msg) + } + } + + private def logWarning(msg: => String, throwable: Throwable): Unit = { + if (log.isWarnEnabled) { + log.warn(msg, throwable) + } + } + + private def logDebug(msg: => String): Unit = { + if (log.isDebugEnabled) { + log.debug(msg) + } + } + + private def logTrace(msg: => String): Unit = { + if (log.isTraceEnabled) { + log.trace(msg) + } + } + + private def logError(msg: => String): Unit = { + if (log.isErrorEnabled) { + log.error(msg) + } + } + + private def logError(msg: => String, throwable: Throwable): Unit = { + if (log.isErrorEnabled) { + log.error(msg, throwable) + } + } import RapidsShuffleBlockFetcherIterator._ diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClientSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClientSuite.scala index 515ff75ffaa..2eb34347edf 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClientSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClientSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. + * Copyright (c) 2020-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -33,7 +33,7 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper { def prepareBufferReceiveState( tableMeta: TableMeta, bounceBuffer: BounceBuffer): BufferReceiveState = { - val ptr = PendingTransferRequest(client, tableMeta, mockHandler) + val ptr = new PendingTransferRequest(client, tableMeta, mockHandler) spy(new BufferReceiveState(123L, bounceBuffer, Seq(ptr), () => {})) } @@ -42,7 +42,7 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper { bounceBuffer: BounceBuffer): BufferReceiveState = { val ptrs = tableMetas.map { tm => - PendingTransferRequest(client, tm, mockHandler) + new PendingTransferRequest(client, tm, mockHandler) } spy(new BufferReceiveState(123L, bounceBuffer, ptrs, () => {})) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala index af24b332c83..c9c67fb5185 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. + * Copyright (c) 2020-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -172,7 +172,7 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper { val ac = ArgumentCaptor.forClass(classOf[RapidsShuffleFetchHandler]) when(mockTransport.makeClient(any())).thenReturn(client) doNothing().when(client).doFetch(any(), ac.capture()) - val mockBuffer = RapidsShuffleHandle(mock[SpillableDeviceBufferHandle], null) + val mockBuffer = new RapidsShuffleHandle(mock[SpillableDeviceBufferHandle], null) when(mockBuffer.spillable.sizeInBytes).thenReturn(123L) val cb = new ColumnarBatch(Array.empty, 10) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServerSuite.scala index 8d7415fba04..9b0ca83b692 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServerSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. + * Copyright (c) 2020-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -57,7 +57,7 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { fillBuffer(hostBuff) deviceBuffer.copyFromHostBuffer(hostBuff) val mockMeta = RapidsShuffleTestHelper.mockTableMeta(100000) - RapidsShuffleHandle(SpillableDeviceBufferHandle(deviceBuffer), mockMeta) + new RapidsShuffleHandle(SpillableDeviceBufferHandle(deviceBuffer), mockMeta) } } new MockRapidsShuffleRequestHandler(mockBuffers) @@ -208,7 +208,7 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { withResource(new RefCountedDirectByteBuffer(bb)) { _ => val tableMeta = MetaUtils.buildTableMeta(1, 456, bb, 100) val testHandle = SpillableDeviceBufferHandle(DeviceMemoryBuffer.allocate(456)) - val rapidsBuffer = RapidsShuffleHandle(testHandle, tableMeta) + val rapidsBuffer = new RapidsShuffleHandle(testHandle, tableMeta) when(mockRequestHandler.getShuffleHandle(ArgumentMatchers.eq(1))) .thenReturn(rapidsBuffer) @@ -277,8 +277,8 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { val ex = new IllegalStateException("something happened") when(mockHandleThatThrows.materialize()).thenThrow(ex) - val rapidsBuffer = RapidsShuffleHandle(mockHandle, tableMeta) - val rapidsBufferThatThrows = RapidsShuffleHandle(mockHandleThatThrows, tableMeta) + val rapidsBuffer = new RapidsShuffleHandle(mockHandle, tableMeta) + val rapidsBufferThatThrows = new RapidsShuffleHandle(mockHandleThatThrows, tableMeta) when(mockRequestHandler.getShuffleHandle(ArgumentMatchers.eq(1))) .thenReturn(rapidsBuffer) @@ -359,7 +359,7 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { val tableMeta = MetaUtils.buildTableMeta(tableId, 456, bb, 100) val rapidsBuffer = if (error) { val mockHandle = mock[SpillableDeviceBufferHandle] - val rapidsBuffer = RapidsShuffleHandle(mockHandle, tableMeta) + val rapidsBuffer = new RapidsShuffleHandle(mockHandle, tableMeta) when(mockHandle.sizeInBytes).thenReturn(tableMeta.bufferMeta().size()) // mock an error with the copy when(rapidsBuffer.spillable.materialize()) @@ -369,7 +369,7 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { rapidsBuffer } else { val testHandle = spy(SpillableDeviceBufferHandle(spy(DeviceMemoryBuffer.allocate(456)))) - RapidsShuffleHandle(testHandle, tableMeta) + new RapidsShuffleHandle(testHandle, tableMeta) } when(mockRequestHandler.getShuffleHandle(ArgumentMatchers.eq(tableId))) .thenAnswer(_ => rapidsBuffer) diff --git a/tests/src/test/spark330/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala b/tests/src/test/spark330/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala index 0e2f4aa1f74..14ef01fe928 100644 --- a/tests/src/test/spark330/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala +++ b/tests/src/test/spark330/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala @@ -117,7 +117,7 @@ abstract class RapidsShuffleTestHelper def getSendBounceBuffer(size: Long): SendBounceBuffers = { val db = DeviceMemoryBuffer.allocate(size) - SendBounceBuffers(new BounceBuffer(db) { + new SendBounceBuffers(new BounceBuffer(db) { override def free(bb: BounceBuffer): Unit = { db.close() } diff --git a/tests/src/test/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala b/tests/src/test/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala index 66c0f4925de..73ad14721c1 100644 --- a/tests/src/test/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala +++ b/tests/src/test/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala @@ -131,7 +131,7 @@ abstract class RapidsShuffleTestHelper def getSendBounceBuffer(size: Long): SendBounceBuffers = { val db = DeviceMemoryBuffer.allocate(size) - SendBounceBuffers(new BounceBuffer(db) { + new SendBounceBuffers(new BounceBuffer(db) { override def free(bb: BounceBuffer): Unit = { db.close() }