Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import java.util.HashSet
import java.util.concurrent.ConcurrentHashMap

import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal

import _root_.io.netty.handler.stream.ChunkedStream
import com.nvidia.spark.rapids.spill.SpillablePartialFileHandle
Expand All @@ -45,6 +46,72 @@ case class PartitionSegment(
offset: Long,
length: Long)

/**
 * A temporary read lease covering a group of partial shuffle file handles.
 *
 * Whoever may still read the handles (a buffer, a stream, a file region) keeps the lease open
 * for that long. Closing it calls `releaseRead` on each handle once, in reverse sequence order;
 * repeated `close()` calls after the first release nothing. Per the contract referenced on
 * `SpillablePartialFileHandle.acquireRead`/`releaseRead`, a handle postpones its physical close
 * until its last lease has been released.
 *
 * @param handles handles whose read leases are owned by this instance
 */
private[rapids] final class ShuffleHandleLease(handles: Seq[SpillablePartialFileHandle])
    extends AutoCloseable {
  // Set by the first close(); only read or written while holding this instance's monitor.
  private var released: Boolean = false

  /**
   * Releases every handle owned by this lease. Only the first call has any effect; the
   * release itself happens outside the monitor so handle callbacks never run under the lock.
   */
  override def close(): Unit = {
    val firstClose = synchronized {
      val stillHeld = !released
      released = true
      stillHeld
    }
    val pending =
      if (firstClose) handles else Seq.empty[SpillablePartialFileHandle]
    ShuffleHandleLease.releaseAll(pending.reverseIterator)
  }
}

private[rapids] object ShuffleHandleLease {
  /**
   * Takes a read lease on each handle in order and wraps them in a single lease object.
   *
   * If any `acquireRead` throws, the handles acquired so far are released again (newest
   * first) and the original failure is rethrown; a failure during that rollback is attached
   * to it as a suppressed exception.
   *
   * @param handles handles to acquire, in acquisition order
   * @return a lease owning one read acquisition per handle
   */
  def acquire(handles: Seq[SpillablePartialFileHandle]): ShuffleHandleLease = {
    val acquired = new ArrayBuffer[SpillablePartialFileHandle](handles.size)
    try {
      val remaining = handles.iterator
      while (remaining.hasNext) {
        val next = remaining.next()
        next.acquireRead()
        // Record only after a successful acquire so rollback never over-releases.
        acquired += next
      }
      new ShuffleHandleLease(acquired.toSeq)
    } catch {
      case cause: Throwable =>
        try {
          releaseAll(acquired.reverseIterator)
        } catch {
          case rollbackFailure: Throwable =>
            cause.addSuppressed(rollbackFailure)
        }
        throw cause
    }
  }

  /**
   * Calls `releaseRead` on every handle the iterator yields, continuing past failures.
   * After the iterator is exhausted, the first failure (if any) is thrown with all later
   * failures attached as suppressed exceptions.
   */
  private def releaseAll(handles: Iterator[SpillablePartialFileHandle]): Unit = {
    var primary: Option[Throwable] = None
    while (handles.hasNext) {
      val current = handles.next()
      try {
        current.releaseRead()
      } catch {
        case t: Throwable =>
          primary match {
            case Some(first) => first.addSuppressed(t)
            case None => primary = Some(t)
          }
      }
    }
    primary match {
      case Some(first) => throw first
      case None => ()
    }
  }
}

/**
* Catalog for managing shuffle data in MULTITHREADED mode without merging.
*
Expand Down Expand Up @@ -217,19 +284,23 @@ class MultithreadedShuffleBufferCatalog extends Logging {
numForcedFileOnly += 1
}

// Drop catalog ownership; retained buffers, streams, and file regions keep the handle
// alive through their read leases, so the physical close is deferred until the last
// lease is released. close() propagates failures, so catch here so one bad handle
// does not abort cleanup of the rest.
try {
handle.close()
} catch {
case e: Exception =>
logError(s"Failed to close handle for shuffle $shuffleId", e)
case NonFatal(e) =>
logWarning(s"Failed to request close of handle for shuffle $shuffleId", e)
}
}
}
}
}

logDebug(s"Unregistered shuffle $shuffleId: closed ${closedHandles.size()} handles, " +
s"bytesFromMemory=$bytesFromMemory, bytesFromDisk=$bytesFromDisk, " +
logDebug(s"Unregistered shuffle $shuffleId: cleanup requested for ${closedHandles.size()} " +
s"handles, bytesFromMemory=$bytesFromMemory, bytesFromDisk=$bytesFromDisk, " +
s"numExpansions=$numExpansions, numSpills=$numSpills, numForcedFileOnly=$numForcedFileOnly")

// Return statistics if we had any data
Expand All @@ -252,50 +323,79 @@ class MultithreadedShuffleBufferCatalog extends Logging {
*/
class MultiBatchManagedBuffer(segments: Seq[PartitionSegment]) extends ManagedBuffer {

private val handles: Seq[SpillablePartialFileHandle] = segments.map(_.handle).distinct

/** Guards bufferLeases while retain()/release() can be called from different threads. */
private val retainLock = new Object

/** Leases that keep this buffer's partial shuffle file handles open after retain(). */
private val bufferLeases = new ArrayBuffer[ShuffleHandleLease]()

override def size(): Long = segments.map(_.length).sum

override def nioByteBuffer(): ByteBuffer = {
// This method loads all data into memory. It's required by the ManagedBuffer interface
// but is NOT used in the network transfer path - Spark's network layer uses
// convertToNetty() which returns our streaming MultiSegmentFileRegion.
// This method may be called by other code paths (e.g., local block reading).
val totalSize = size().toInt
val buffer = ByteBuffer.allocate(totalSize)
val bytes = new Array[Byte](8192) // Read buffer

segments.foreach { segment =>
var remaining = segment.length
var position = segment.offset
while (remaining > 0) {
val toRead = math.min(remaining, bytes.length).toInt
val bytesRead = segment.handle.readAt(position, bytes, 0, toRead)
if (bytesRead <= 0) {
throw new IOException(
s"Unexpected EOF reading segment at position $position, " +
s"expected ${segment.length} bytes")
val lease = ShuffleHandleLease.acquire(handles)
try {
// This method loads all data into memory. It's required by the ManagedBuffer interface
// but is NOT used in the network transfer path - Spark's network layer uses
// convertToNetty() which returns our streaming MultiSegmentFileRegion.
// This method may be called by other code paths (e.g., local block reading).
val totalSize = size().toInt
val buffer = ByteBuffer.allocate(totalSize)
val bytes = new Array[Byte](8192) // Read buffer

segments.foreach { segment =>
var remaining = segment.length
var position = segment.offset
while (remaining > 0) {
val toRead = math.min(remaining, bytes.length).toInt
val bytesRead = segment.handle.readAt(position, bytes, 0, toRead)
if (bytesRead <= 0) {
throw new IOException(
s"Unexpected EOF reading segment at position $position, " +
s"expected ${segment.length} bytes")
}
buffer.put(bytes, 0, bytesRead)
position += bytesRead
remaining -= bytesRead
}
buffer.put(bytes, 0, bytesRead)
position += bytesRead
remaining -= bytesRead
}
}

buffer.flip()
buffer
buffer.flip()
buffer
} finally {
lease.close()
}
}

override def createInputStream(): InputStream = {
new MultiSegmentInputStream(segments)
new MultiSegmentInputStream(segments, handles)
}

override def retain(): ManagedBuffer = this
override def retain(): ManagedBuffer = {
val lease = ShuffleHandleLease.acquire(handles)
retainLock.synchronized {
bufferLeases += lease
}
this
}

override def release(): ManagedBuffer = this
override def release(): ManagedBuffer = {
val lease = retainLock.synchronized {
if (bufferLeases.nonEmpty) {
Some(bufferLeases.remove(bufferLeases.size - 1))
} else {
None
}
}
lease.foreach(_.close())
this
}

override def convertToNetty(): AnyRef = {
// Return a custom FileRegion that streams data in chunks, avoiding loading all
// data into memory at once. This addresses concerns about large shuffle blocks.
new MultiSegmentFileRegion(segments)
new MultiSegmentFileRegion(segments, handles)
}

// Spark 4.0+ adds convertToNettyForSsl() abstract method.
Expand All @@ -314,12 +414,20 @@ class MultiBatchManagedBuffer(segments: Seq[PartitionSegment]) extends ManagedBu

/**
* An InputStream that reads from multiple partition segments sequentially.
*
* This stream is not thread-safe; callers should create one stream per reading thread and close it
* from that owner thread.
*/
class MultiSegmentInputStream(segments: Seq[PartitionSegment]) extends InputStream {
class MultiSegmentInputStream(
segments: Seq[PartitionSegment],
handles: Seq[SpillablePartialFileHandle]) extends InputStream {

private var currentSegmentIndex: Int = 0
private var currentPosition: Long = if (segments.nonEmpty) segments.head.offset else 0
private var bytesReadInCurrentSegment: Long = 0
// Keeps the partial shuffle file handles open until this stream is closed.
private val lease = ShuffleHandleLease.acquire(handles)
private var closed: Boolean = false

override def read(): Int = {
val buf = new Array[Byte](1)
Expand All @@ -328,6 +436,10 @@ class MultiSegmentInputStream(segments: Seq[PartitionSegment]) extends InputStre
}

override def read(b: Array[Byte], off: Int, len: Int): Int = {
if (closed) {
throw new IOException("Stream is closed")
}

// Use loop instead of recursion to avoid StackOverflowError with many segments
while (currentSegmentIndex < segments.size) {
val segment = segments(currentSegmentIndex)
Expand Down Expand Up @@ -358,7 +470,7 @@ class MultiSegmentInputStream(segments: Seq[PartitionSegment]) extends InputStre
}

override def available(): Int = {
if (currentSegmentIndex >= segments.size) {
if (closed || currentSegmentIndex >= segments.size) {
0
} else {
val remaining = segments.drop(currentSegmentIndex).map { seg =>
Expand All @@ -372,13 +484,11 @@ class MultiSegmentInputStream(segments: Seq[PartitionSegment]) extends InputStre
}
}

/**
* Close is a no-op because the underlying SpillablePartialFileHandle resources
* are managed by MultithreadedShuffleBufferCatalog and will be closed when
* the shuffle is unregistered.
*/
override def close(): Unit = {
// No-op: handles are managed by MultithreadedShuffleBufferCatalog
if (!closed) {
closed = true
lease.close()
}
}
}

Expand All @@ -392,8 +502,12 @@ class MultiSegmentInputStream(segments: Seq[PartitionSegment]) extends InputStre
* Spark's MessageWithHeader only accepts ByteBuf or FileRegion. By implementing
* FileRegion, we can provide streaming transfer while remaining compatible with
* Spark's network layer.
*
* Instances are not thread-safe; each transfer should use one FileRegion owned by one Netty write.
*/
class MultiSegmentFileRegion(segments: Seq[PartitionSegment]) extends AbstractFileRegion {
class MultiSegmentFileRegion(
segments: Seq[PartitionSegment],
handles: Seq[SpillablePartialFileHandle]) extends AbstractFileRegion {

private val totalSize: Long = segments.map(_.length).sum
private var totalTransferred: Long = 0
Expand All @@ -407,6 +521,8 @@ class MultiSegmentFileRegion(segments: Seq[PartitionSegment]) extends AbstractFi
// Track current position within the logical data stream
private var currentSegmentIndex: Int = 0
private var bytesTransferredInCurrentSegment: Long = 0
// Keeps the partial shuffle file handles open until this file region is deallocated.
private val lease = ShuffleHandleLease.acquire(handles)

override def count(): Long = totalSize

Expand Down Expand Up @@ -475,6 +591,6 @@ class MultiSegmentFileRegion(segments: Seq[PartitionSegment]) extends AbstractFi
}

override protected def deallocate(): Unit = {
// No resources to release - handles are managed by MultithreadedShuffleBufferCatalog
lease.close()
}
}
Loading
Loading