Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import java.util.HashSet
import java.util.concurrent.ConcurrentHashMap

import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal

import _root_.io.netty.handler.stream.ChunkedStream
import com.nvidia.spark.rapids.spill.SpillablePartialFileHandle
Expand All @@ -39,11 +40,143 @@ import org.apache.spark.storage.{ShuffleBlockBatchId, ShuffleBlockId}
* @param handle the partial file handle containing the data
* @param offset starting offset within the handle
* @param length number of bytes in this segment
* @param handleRef lifecycle reference for the partial file handle
*/
case class PartitionSegment(
handle: SpillablePartialFileHandle,
offset: Long,
length: Long)
length: Long,
handleRef: ShuffleHandleReference)

/**
 * Reference-counted lifecycle wrapper for a partial shuffle file handle.
 *
 * Shuffle cleanup can request a close while retained buffers, streams, or file regions are still
 * reading the handle. The physical close is deferred until all active leases release, allowing
 * catalog metadata to be removed without closing data under active consumers.
 *
 * All bookkeeping fields are read and written only while holding this object's monitor. The
 * call to `handle.close()` is made after the monitor has been released, and the `closed` flag
 * ensures that at most one caller is ever handed the close.
 *
 * @param handle the partial file handle whose close is being deferred
 */
private[rapids] final class ShuffleHandleReference(handle: SpillablePartialFileHandle)
    extends Logging {

  /** Number of retains that have not yet been matched by a release. */
  private var refCount: Int = 0
  /** Set once cleanup has asked for the handle to be closed. */
  private var closeRequested: Boolean = false
  /** Set once the close has been handed to a caller; no retain is accepted afterwards. */
  private var closed: Boolean = false
  /** Shuffle id recorded by the first requestClose(); used only in log messages. */
  private var closeShuffleId: Int = -1

  /**
   * Registers one more active user of the handle.
   *
   * @throws IllegalStateException if the handle has already been closed
   */
  def retain(): Unit = synchronized {
    if (closed) {
      throw new IllegalStateException("Cannot retain a closed shuffle handle")
    }
    refCount += 1
  }

  /**
   * Drops one active user. If a close was requested and this was the last user, the handle is
   * closed by the calling thread.
   *
   * @throws IllegalStateException if there is no outstanding retain to match
   */
  def release(): Unit = {
    val shuffleIdToClose = synchronized {
      if (refCount <= 0) {
        throw new IllegalStateException("release() without matching retain()")
      }
      refCount -= 1
      closeShuffleIdIfReady()
    }
    // Close outside the monitor so other threads are not blocked on retain()/release().
    shuffleIdToClose.foreach(closeHandle)
  }

  /**
   * Asks for the handle to be closed as soon as no users remain. Only the first request records
   * its shuffle id; later requests leave it unchanged. If there are no users right now, the
   * handle is closed by the calling thread.
   *
   * @param shuffleId id of the shuffle being cleaned up, used in log messages
   */
  def requestClose(shuffleId: Int): Unit = {
    val shuffleIdToClose = synchronized {
      if (!closeRequested) {
        closeRequested = true
        closeShuffleId = shuffleId
      }
      closeShuffleIdIfReady()
    }
    // Close outside the monitor so other threads are not blocked on retain()/release().
    shuffleIdToClose.foreach(closeHandle)
  }

  /**
   * Decides whether the caller should perform the physical close. Both call sites invoke this
   * while holding the monitor. Returns the recorded shuffle id exactly once: the first time a
   * close has been requested and the reference count is zero.
   */
  private def closeShuffleIdIfReady(): Option[Int] = {
    if (closeRequested && refCount == 0 && !closed) {
      closed = true
      Some(closeShuffleId)
    } else {
      None
    }
  }

  /**
   * Closes the underlying handle. Failures are logged rather than rethrown; an interrupt is
   * restored on the current thread before logging.
   */
  private def closeHandle(shuffleId: Int): Unit = {
    try {
      handle.close()
    } catch {
      case e: InterruptedException =>
        Thread.currentThread().interrupt()
        logError(s"Interrupted while closing handle for shuffle $shuffleId", e)
      case NonFatal(e) =>
        logError(s"Failed to close handle for shuffle $shuffleId", e)
    }
  }
}

/**
 * A one-shot hold on a group of partial shuffle file handle references.
 *
 * The lease is meant to live for as long as a buffer, stream, or file region may still read
 * from the handles. The first call to `close()` releases every reference in the reverse of the
 * order given; any later call releases nothing.
 *
 * @param handleRefs the references this lease releases when closed
 */
private[rapids] final class ShuffleHandleLease(handleRefs: Seq[ShuffleHandleReference])
    extends AutoCloseable {

  /** Flipped to true by the first close(); guarded by this object's monitor. */
  private var released: Boolean = false

  override def close(): Unit = {
    // Decide under the monitor which references (if any) this call is responsible for,
    // then do the actual releasing after the monitor has been dropped.
    val pending: Seq[ShuffleHandleReference] = synchronized {
      val firstClose = !released
      released = true
      if (firstClose) handleRefs else Seq.empty
    }
    ShuffleHandleLease.releaseAll(pending.reverseIterator)
  }
}

private[rapids] object ShuffleHandleLease {

  /**
   * Retains every reference in `handleRefs`, in order, and wraps them in a lease.
   *
   * If any retain fails, the references retained so far are released in reverse order and the
   * original failure is rethrown; a failure during that rollback is attached to it as a
   * suppressed exception.
   *
   * @param handleRefs references to retain
   * @return a lease that releases the retained references when closed
   */
  def acquire(handleRefs: Seq[ShuffleHandleReference]): ShuffleHandleLease = {
    val held = new ArrayBuffer[ShuffleHandleReference](handleRefs.size)
    try {
      val remaining = handleRefs.iterator
      while (remaining.hasNext) {
        val ref = remaining.next()
        ref.retain()
        // Record only after a successful retain so rollback never over-releases.
        held += ref
      }
      new ShuffleHandleLease(held.toSeq)
    } catch {
      case failure: Throwable =>
        try {
          releaseAll(held.reverseIterator)
        } catch {
          case rollbackFailure: Throwable =>
            failure.addSuppressed(rollbackFailure)
        }
        throw failure
    }
  }

  /**
   * Releases every reference produced by the iterator, continuing past failures.
   *
   * The first failure is rethrown after all references have been attempted; subsequent failures
   * are attached to it as suppressed exceptions.
   *
   * @param handleRefs references to release, in the order they should be released
   */
  private def releaseAll(handleRefs: Iterator[ShuffleHandleReference]): Unit = {
    var firstFailure: Option[Throwable] = None
    while (handleRefs.hasNext) {
      val ref = handleRefs.next()
      try {
        ref.release()
      } catch {
        case t: Throwable =>
          firstFailure match {
            case Some(earlier) => earlier.addSuppressed(t)
            case None => firstFailure = Some(t)
          }
      }
    }
    firstFailure.foreach(failure => throw failure)
  }
}

/**
* Catalog for managing shuffle data in MULTITHREADED mode without merging.
Expand All @@ -69,6 +202,10 @@ class MultithreadedShuffleBufferCatalog extends Logging {
/** Track active shuffles for cleanup */
private val activeShuffles = new ConcurrentHashMap[Int, JBoolean]()

/** Track each unique partial file handle so cleanup can be deferred while readers are active. */
private val handleRefs =
new ConcurrentHashMap[SpillablePartialFileHandle, ShuffleHandleReference]()

/**
* Register a shuffle as active.
* Must be called before adding any partitions for this shuffle.
Expand Down Expand Up @@ -99,7 +236,8 @@ class MultithreadedShuffleBufferCatalog extends Logging {
}

val blockId = ShuffleBlockId(shuffleId, mapId, partitionId)
val segment = PartitionSegment(handle, offset, length)
val handleRef = handleRefs.computeIfAbsent(handle, h => new ShuffleHandleReference(h))
val segment = PartitionSegment(handle, offset, length, handleRef)

partitionSegments.compute(blockId, (_, existing) => {
val segments = if (existing == null) new ArrayBuffer[PartitionSegment]() else existing
Expand Down Expand Up @@ -217,19 +355,16 @@ class MultithreadedShuffleBufferCatalog extends Logging {
numForcedFileOnly += 1
}

try {
handle.close()
} catch {
case e: Exception =>
logError(s"Failed to close handle for shuffle $shuffleId", e)
}
// Drop catalog ownership; retained buffers/streams keep the reference alive.
handleRefs.remove(handle, segment.handleRef)
segment.handleRef.requestClose(shuffleId)
}
}
}
}

logDebug(s"Unregistered shuffle $shuffleId: closed ${closedHandles.size()} handles, " +
s"bytesFromMemory=$bytesFromMemory, bytesFromDisk=$bytesFromDisk, " +
logDebug(s"Unregistered shuffle $shuffleId: cleanup requested for ${closedHandles.size()} " +
s"handles, bytesFromMemory=$bytesFromMemory, bytesFromDisk=$bytesFromDisk, " +
s"numExpansions=$numExpansions, numSpills=$numSpills, numForcedFileOnly=$numForcedFileOnly")

// Return statistics if we had any data
Expand All @@ -252,50 +387,79 @@ class MultithreadedShuffleBufferCatalog extends Logging {
*/
class MultiBatchManagedBuffer(segments: Seq[PartitionSegment]) extends ManagedBuffer {

private val handleRefs: Seq[ShuffleHandleReference] = segments.map(_.handleRef).distinct

/** Guards bufferLeases while retain()/release() can be called from different threads. */
private val retainLock = new Object

/** Leases that keep this buffer's partial shuffle file handles open after retain(). */
private val bufferLeases = new ArrayBuffer[ShuffleHandleLease]()

override def size(): Long = segments.map(_.length).sum

override def nioByteBuffer(): ByteBuffer = {
// This method loads all data into memory. It's required by the ManagedBuffer interface
// but is NOT used in the network transfer path - Spark's network layer uses
// convertToNetty() which returns our streaming MultiSegmentFileRegion.
// This method may be called by other code paths (e.g., local block reading).
val totalSize = size().toInt
val buffer = ByteBuffer.allocate(totalSize)
val bytes = new Array[Byte](8192) // Read buffer

segments.foreach { segment =>
var remaining = segment.length
var position = segment.offset
while (remaining > 0) {
val toRead = math.min(remaining, bytes.length).toInt
val bytesRead = segment.handle.readAt(position, bytes, 0, toRead)
if (bytesRead <= 0) {
throw new IOException(
s"Unexpected EOF reading segment at position $position, " +
s"expected ${segment.length} bytes")
val lease = ShuffleHandleLease.acquire(handleRefs)
try {
// This method loads all data into memory. It's required by the ManagedBuffer interface
// but is NOT used in the network transfer path - Spark's network layer uses
// convertToNetty() which returns our streaming MultiSegmentFileRegion.
// This method may be called by other code paths (e.g., local block reading).
val totalSize = size().toInt
val buffer = ByteBuffer.allocate(totalSize)
val bytes = new Array[Byte](8192) // Read buffer

segments.foreach { segment =>
var remaining = segment.length
var position = segment.offset
while (remaining > 0) {
val toRead = math.min(remaining, bytes.length).toInt
val bytesRead = segment.handle.readAt(position, bytes, 0, toRead)
if (bytesRead <= 0) {
throw new IOException(
s"Unexpected EOF reading segment at position $position, " +
s"expected ${segment.length} bytes")
}
buffer.put(bytes, 0, bytesRead)
position += bytesRead
remaining -= bytesRead
}
buffer.put(bytes, 0, bytesRead)
position += bytesRead
remaining -= bytesRead
}
}

buffer.flip()
buffer
buffer.flip()
buffer
} finally {
lease.close()
}
}

override def createInputStream(): InputStream = {
new MultiSegmentInputStream(segments)
new MultiSegmentInputStream(segments, handleRefs)
}

override def retain(): ManagedBuffer = this
/**
 * Acquires a new lease over this buffer's handle references and records it so that a later
 * release() can close it.
 *
 * @return this buffer
 */
override def retain(): ManagedBuffer = {
  // Acquire before taking retainLock; only the bookkeeping append happens under the lock.
  val acquired = ShuffleHandleLease.acquire(handleRefs)
  retainLock.synchronized(bufferLeases += acquired)
  this
}

override def release(): ManagedBuffer = this
/**
 * Closes the most recently recorded lease, if there is one. Calling this with no recorded
 * lease does nothing.
 *
 * @return this buffer
 */
override def release(): ManagedBuffer = {
  // Pop under retainLock, but close the lease only after the lock has been dropped.
  val popped: Option[ShuffleHandleLease] = retainLock.synchronized {
    if (bufferLeases.isEmpty) {
      None
    } else {
      val lastIndex = bufferLeases.size - 1
      Some(bufferLeases.remove(lastIndex))
    }
  }
  popped.foreach(lease => lease.close())
  this
}

override def convertToNetty(): AnyRef = {
// Return a custom FileRegion that streams data in chunks, avoiding loading all
// data into memory at once. This addresses concerns about large shuffle blocks.
new MultiSegmentFileRegion(segments)
new MultiSegmentFileRegion(segments, handleRefs)
}

// Spark 4.0+ adds convertToNettyForSsl() abstract method.
Expand All @@ -314,12 +478,20 @@ class MultiBatchManagedBuffer(segments: Seq[PartitionSegment]) extends ManagedBu

/**
* An InputStream that reads from multiple partition segments sequentially.
*
* This stream is not thread-safe; callers should create one stream per reading thread and close it
* from that owner thread.
*/
class MultiSegmentInputStream(segments: Seq[PartitionSegment]) extends InputStream {
class MultiSegmentInputStream(
segments: Seq[PartitionSegment],
handleRefs: Seq[ShuffleHandleReference]) extends InputStream {

private var currentSegmentIndex: Int = 0
private var currentPosition: Long = if (segments.nonEmpty) segments.head.offset else 0
private var bytesReadInCurrentSegment: Long = 0
// Keeps the partial shuffle file handles open until this stream is closed.
private val lease = ShuffleHandleLease.acquire(handleRefs)
private var closed: Boolean = false

override def read(): Int = {
val buf = new Array[Byte](1)
Expand All @@ -328,6 +500,10 @@ class MultiSegmentInputStream(segments: Seq[PartitionSegment]) extends InputStre
}

override def read(b: Array[Byte], off: Int, len: Int): Int = {
if (closed) {
throw new IOException("Stream is closed")
}

// Use loop instead of recursion to avoid StackOverflowError with many segments
while (currentSegmentIndex < segments.size) {
val segment = segments(currentSegmentIndex)
Expand Down Expand Up @@ -358,7 +534,7 @@ class MultiSegmentInputStream(segments: Seq[PartitionSegment]) extends InputStre
}

override def available(): Int = {
if (currentSegmentIndex >= segments.size) {
if (closed || currentSegmentIndex >= segments.size) {
0
} else {
val remaining = segments.drop(currentSegmentIndex).map { seg =>
Expand All @@ -372,13 +548,11 @@ class MultiSegmentInputStream(segments: Seq[PartitionSegment]) extends InputStre
}
}

/**
* Close is a no-op because the underlying SpillablePartialFileHandle resources
* are managed by MultithreadedShuffleBufferCatalog and will be closed when
* the shuffle is unregistered.
*/
override def close(): Unit = {
// No-op: handles are managed by MultithreadedShuffleBufferCatalog
if (!closed) {
closed = true
lease.close()
}
}
}

Expand All @@ -392,8 +566,12 @@ class MultiSegmentInputStream(segments: Seq[PartitionSegment]) extends InputStre
* Spark's MessageWithHeader only accepts ByteBuf or FileRegion. By implementing
* FileRegion, we can provide streaming transfer while remaining compatible with
* Spark's network layer.
*
* Instances are not thread-safe; each transfer should use one FileRegion owned by one Netty write.
*/
class MultiSegmentFileRegion(segments: Seq[PartitionSegment]) extends AbstractFileRegion {
class MultiSegmentFileRegion(
segments: Seq[PartitionSegment],
handleRefs: Seq[ShuffleHandleReference]) extends AbstractFileRegion {

private val totalSize: Long = segments.map(_.length).sum
private var totalTransferred: Long = 0
Expand All @@ -407,6 +585,8 @@ class MultiSegmentFileRegion(segments: Seq[PartitionSegment]) extends AbstractFi
// Track current position within the logical data stream
private var currentSegmentIndex: Int = 0
private var bytesTransferredInCurrentSegment: Long = 0
// Keeps the partial shuffle file handles open until this file region is deallocated.
private val lease = ShuffleHandleLease.acquire(handleRefs)

override def count(): Long = totalSize

Expand Down Expand Up @@ -475,6 +655,6 @@ class MultiSegmentFileRegion(segments: Seq[PartitionSegment]) extends AbstractFi
}

override protected def deallocate(): Unit = {
// No resources to release - handles are managed by MultithreadedShuffleBufferCatalog
lease.close()
}
}
Loading