Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION.
* Copyright (c) 2020-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -51,7 +51,7 @@ class BatchedCopyCompressor(maxBatchMemory: Long, stream: Cuda.Stream)
ct,
CodecType.COPY,
outBuffer.getLength)
CompressedTable(outBuffer.getLength, meta, outBuffer)
new CompressedTable(outBuffer.getLength, meta, outBuffer)
}
}
closeOnExcept(result) { _ => stream.sync() }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package com.nvidia.spark.rapids

import java.io.{InputStream, IOException}
import java.lang.{Boolean => JBoolean}
import java.nio.ByteBuffer
import java.nio.{Buffer, ByteBuffer}
import java.nio.channels.WritableByteChannel
import java.util.HashSet
import java.util.concurrent.ConcurrentHashMap
Expand All @@ -28,7 +28,6 @@ import scala.collection.mutable.ArrayBuffer
import _root_.io.netty.handler.stream.ChunkedStream
import com.nvidia.spark.rapids.spill.SpillablePartialFileHandle

import org.apache.spark.internal.Logging
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.util.AbstractFileRegion
import org.apache.spark.storage.{ShuffleBlockBatchId, ShuffleBlockId}
Expand All @@ -40,10 +39,10 @@ import org.apache.spark.storage.{ShuffleBlockBatchId, ShuffleBlockId}
* @param offset starting offset within the handle
* @param length number of bytes in this segment
*/
case class PartitionSegment(
handle: SpillablePartialFileHandle, // handle that `offset` is relative to
offset: Long, // starting offset within the handle
length: Long) // number of bytes in this segment
// Plain class rather than a case class: there is no generated apply/equals/hashCode/copy,
// so instances are created with `new` and compared by reference.
class PartitionSegment(
val handle: SpillablePartialFileHandle, // handle that `offset` is relative to
val offset: Long, // starting offset within the handle
val length: Long) // number of bytes in this segment

/**
* Catalog for managing shuffle data in MULTITHREADED mode without merging.
Expand All @@ -57,7 +56,19 @@ case class PartitionSegment(
* (MEMORY_WITH_SPILL mode) or stored directly on disk (ONLY_FILE mode) depending
* on memory pressure - both modes work with this skip-merge design.
*/
class MultithreadedShuffleBufferCatalog extends Logging {
class MultithreadedShuffleBufferCatalog {
// SLF4J logger named after the runtime class, with any trailing "$" removed from the name.
private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$"))

/**
 * Logs a message at DEBUG level.
 *
 * @param msg by-name message; it is only evaluated when DEBUG logging is enabled
 */
private def logDebug(msg: => String): Unit = {
  if (log.isDebugEnabled) {
    log.debug(msg)
  }
}

/**
 * Logs a message together with a throwable at ERROR level.
 *
 * @param msg by-name message; it is only evaluated when ERROR logging is enabled
 * @param throwable the failure to attach to the log record
 */
private def logError(msg: => String, throwable: Throwable): Unit = {
  // Guard on the level like logDebug does, so the by-name message is never built
  // when ERROR logging is turned off.
  if (log.isErrorEnabled) {
    log.error(msg, throwable)
  }
}


/**
* Map from ShuffleBlockId to list of segments.
Expand Down Expand Up @@ -99,7 +110,7 @@ class MultithreadedShuffleBufferCatalog extends Logging {
}

val blockId = ShuffleBlockId(shuffleId, mapId, partitionId)
val segment = PartitionSegment(handle, offset, length)
val segment = new PartitionSegment(handle, offset, length)

partitionSegments.compute(blockId, (_, existing) => {
val segments = if (existing == null) new ArrayBuffer[PartitionSegment]() else existing
Expand Down Expand Up @@ -280,7 +291,7 @@ class MultiBatchManagedBuffer(segments: Seq[PartitionSegment]) extends ManagedBu
}
}

buffer.flip()
buffer.asInstanceOf[Buffer].flip()
buffer
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2025, NVIDIA CORPORATION.
* Copyright (c) 2020-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -68,7 +68,7 @@ class BatchedNvcompLZ4Compressor(maxBatchMemorySize: Long,
table,
CodecType.NVCOMP_LZ4,
compressedSize)
CompressedTable(compressedSize, meta, buffer)
new CompressedTable(compressedSize, meta, buffer)
}.toArray
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
* Copyright (c) 2024-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,8 +18,7 @@ package com.nvidia.spark.rapids

import ai.rapids.cudf.{BaseDeviceMemoryBuffer, ContiguousTable, Cuda, DeviceMemoryBuffer}
import ai.rapids.cudf.nvcomp.{BatchedZstdCompressor, BatchedZstdDecompressor}
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableProducingArray
import com.nvidia.spark.rapids.Arm.closeOnExcept
import com.nvidia.spark.rapids.format.{BufferMeta, CodecType}

/** A table compression codec that uses nvcomp's ZSTD-GPU codec */
Expand Down Expand Up @@ -59,7 +58,7 @@ class BatchedNvcompZSTDCompressor(maxBatchMemorySize: Long,
compressedBufs.zip(tables).map { case (buffer, table) =>
val compressedLen = buffer.getLength
val meta = MetaUtils.buildTableMeta(None, table, CodecType.NVCOMP_ZSTD, compressedLen)
CompressedTable(compressedLen, meta, buffer)
new CompressedTable(compressedLen, meta, buffer)
}.toArray
}
}
Expand Down Expand Up @@ -90,23 +89,3 @@ class BatchedNvcompZSTDDecompressor(maxBatchMemory: Long,
outputBufs
}
}

/** Helpers for working with arrays of device memory buffers. */
object DeviceBuffersUtils {
/**
 * Calls `incRefCount()` on every buffer in `bufs`.
 *
 * @param bufs buffers whose reference counts are incremented
 * @return an array holding the same buffer instances, in the order `safeMap` produces them
 */
def incRefCount(bufs: Array[BaseDeviceMemoryBuffer]): Array[BaseDeviceMemoryBuffer] = {
// NOTE(review): safeMap is a project helper; presumably it closes already-produced
// elements if a later one throws -- confirm against RapidsPluginImplicits.
bufs.safeMap { b =>
b.incRefCount()
b
}
}

/**
 * Allocates one device buffer of `bufSizes.sum` bytes and returns one slice per
 * requested size, laid out back to back starting at offset 0.
 *
 * @param bufSizes length of each slice to produce
 * @return one slice of the single backing allocation per entry of `bufSizes`
 */
def allocateBuffers(bufSizes: Array[Long]): Array[DeviceMemoryBuffer] = {
// running offset of the next slice within the single backing buffer
var curPos = 0L
// NOTE(review): withResource presumably closes singleBuf when the block exits, leaving
// the slices to keep the allocation alive through their own references -- confirm.
withResource(DeviceMemoryBuffer.allocate(bufSizes.sum)) { singleBuf =>
bufSizes.safeMap { len =>
val ret = singleBuf.slice(curPos, len)
curPos += len
ret
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,19 @@ import org.apache.commons.lang3.mutable.MutableLong

import org.apache.spark.SparkEnv
import org.apache.spark.api.plugin.PluginContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.rapids.{ProxyRapidsShuffleInternalManagerBase, RapidsShuffleInternalManagerBase}
import org.apache.spark.storage.BlockManagerId

class RapidsShuffleHeartbeatManager(heartbeatIntervalMillis: Long,
heartbeatTimeoutMillis: Long) extends Logging {
heartbeatTimeoutMillis: Long) {
// SLF4J logger bound to RapidsShuffleHeartbeatManager.
private val log =
  org.slf4j.LoggerFactory.getLogger(classOf[RapidsShuffleHeartbeatManager])

/** Emits `msg` at DEBUG level; the by-name message is only built when DEBUG is enabled. */
private def logDebug(msg: => String): Unit =
  if (log.isDebugEnabled) log.debug(msg)

require(heartbeatIntervalMillis > 0,
s"The interval value: $heartbeatIntervalMillis ms is not > 0")

Expand All @@ -45,14 +52,18 @@ class RapidsShuffleHeartbeatManager(heartbeatIntervalMillis: Long,
// Returns the current wall-clock time from System.currentTimeMillis().
// exposed so that it can be mocked in the tests
def getCurrentTimeMillis: Long = System.currentTimeMillis()

private case class ExecutorRegistration(
id: BlockManagerId,
private class ExecutorRegistration(
val id: BlockManagerId,
// this is this executor's registration order, as given by this manager
registrationOrder: Long,
val registrationOrder: Long,
// this is the last registration order this executor is aware of overall
lastRegistrationOrderSeen: MutableLong,
val lastRegistrationOrderSeen: MutableLong,
// last heartbeat received from this executor in millis
lastHeartbeatMillis: MutableLong)
val lastHeartbeatMillis: MutableLong) {
override def toString: String =
s"ExecutorRegistration($id,$registrationOrder,$lastRegistrationOrderSeen," +
s"$lastHeartbeatMillis)"
}

// a counter used to mark each new executor registration with an order
var registrationOrder = 0L
Expand Down Expand Up @@ -82,7 +93,7 @@ class RapidsShuffleHeartbeatManager(heartbeatIntervalMillis: Long,
require(!executorRegistrations.containsKey(id), s"Executor $id already registered")
removeDeadExecutors(getCurrentTimeMillis)
val allExecutors = executors.map(e => e.id).toArray
val newReg = ExecutorRegistration(id,
val newReg = new ExecutorRegistration(id,
registrationOrder,
new MutableLong(registrationOrder),
new MutableLong(getCurrentTimeMillis))
Expand Down Expand Up @@ -167,7 +178,40 @@ class RapidsShuffleHeartbeatManager(heartbeatIntervalMillis: Long,
}

class RapidsShuffleHeartbeatEndpoint(pluginContext: PluginContext, conf: RapidsConf)
extends Logging with AutoCloseable {
extends AutoCloseable {

// SLF4J logger bound to RapidsShuffleHeartbeatEndpoint. Every helper below takes its
// message by name and only evaluates it when the matching level is enabled.
private val log =
  org.slf4j.LoggerFactory.getLogger(classOf[RapidsShuffleHeartbeatEndpoint])

/** Emits `msg` at INFO level. */
private def logInfo(msg: => String): Unit =
  if (log.isInfoEnabled) log.info(msg)

/** Emits `msg` at WARN level. */
private def logWarning(msg: => String): Unit =
  if (log.isWarnEnabled) log.warn(msg)

/** Emits `msg` at DEBUG level. */
private def logDebug(msg: => String): Unit =
  if (log.isDebugEnabled) log.debug(msg)

/** Emits `msg` at TRACE level. */
private def logTrace(msg: => String): Unit =
  if (log.isTraceEnabled) log.trace(msg)

/** Emits `msg` at ERROR level with `throwable` attached. */
private def logError(msg: => String, throwable: Throwable): Unit =
  if (log.isErrorEnabled) log.error(msg, throwable)

// Number of milliseconds between heartbeats to driver
private[this] val heartbeatIntervalMillis =
conf.shuffleTransportEarlyStartHeartbeatInterval
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2025, NVIDIA CORPORATION.
* Copyright (c) 2020-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -28,22 +28,21 @@ import com.nvidia.spark.rapids.format.TableMeta
import com.nvidia.spark.rapids.spill.{SpillableDeviceBufferHandle, SpillableHandle}

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.rapids.execution.TrampolineUtil
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.storage.ShuffleBlockId

/**
 * Identifier for a shuffle buffer that holds the data for a table
 *
 * @param blockId shuffle block this buffer belongs to
 * @param tableId ID of the table within the catalog
 */
case class ShuffleBufferId(
blockId: ShuffleBlockId,
tableId: Int) {
// Convenience copies of the block's identifiers, captured once at construction.
val shuffleId: Int = blockId.shuffleId
val mapId: Long = blockId.mapId
}

/** Catalog for lookup of shuffle buffers by block ID */
class ShuffleBufferCatalog extends Logging {
class ShuffleBufferCatalog {
// SLF4J logger named after the runtime class, with any trailing "$" removed from the name.
private val log =
  org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$"))

/** Emits `msg` at WARN level; the by-name message is only built when WARN is enabled. */
private def logWarning(msg: => String): Unit =
  if (log.isWarnEnabled) log.warn(msg)

/**
* Information stored for each active shuffle.
* A shuffle block can be comprised of multiple batches. Each batch
Expand Down Expand Up @@ -259,7 +258,7 @@ class ShuffleBufferCatalog extends Logging {
}

val tableId = tableIdCounter.getAndUpdate(ShuffleBufferCatalog.TABLE_ID_UPDATER)
val id = ShuffleBufferId(blockId, tableId)
val id = new ShuffleBufferId(blockId, tableId)
val prev = tableMap.put(tableId, id)
if (prev != null) {
throw new IllegalStateException(s"table ID $tableId is already in use")
Expand All @@ -283,7 +282,7 @@ class ShuffleBufferCatalog extends Logging {
val (maybeHandle, meta) = bufferIdToHandle.get(shuffleBufferId)
maybeHandle match {
case Some(spillable) =>
RapidsShuffleHandle(spillable, meta)
new RapidsShuffleHandle(spillable, meta)
case None =>
throw new IllegalStateException(
"a buffer handle could not be obtained for a degenerate buffer")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import scala.collection.mutable.ArrayBuffer
import com.nvidia.spark.rapids.jni.RmmSpark

import org.apache.spark.api.plugin.PluginContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.rapids.GpuShuffleEnv

/**
Expand All @@ -38,7 +37,40 @@ import org.apache.spark.sql.rapids.GpuShuffleEnv
*/
class ShuffleCleanupEndpoint(
pluginContext: PluginContext,
pollIntervalMs: Long = 1000) extends Logging with AutoCloseable {
pollIntervalMs: Long = 1000) extends AutoCloseable {

// SLF4J logger bound to ShuffleCleanupEndpoint. Every helper below takes its message
// by name and only evaluates it when the matching level is enabled.
private val log =
  org.slf4j.LoggerFactory.getLogger(classOf[ShuffleCleanupEndpoint])

/** Emits `msg` at INFO level. */
private def logInfo(msg: => String): Unit =
  if (log.isInfoEnabled) log.info(msg)

/** Emits `msg` at WARN level. */
private def logWarning(msg: => String): Unit =
  if (log.isWarnEnabled) log.warn(msg)

/** Emits `msg` at WARN level with `throwable` attached. */
private def logWarning(msg: => String, throwable: Throwable): Unit =
  if (log.isWarnEnabled) log.warn(msg, throwable)

/** Emits `msg` at DEBUG level. */
private def logDebug(msg: => String): Unit =
  if (log.isDebugEnabled) log.debug(msg)

/** Emits `msg` at TRACE level. */
private def logTrace(msg: => String): Unit =
  if (log.isTraceEnabled) log.trace(msg)


private val executorId: String = pluginContext.executorID()

Expand Down
Loading
Loading