From 0331aa9299d209d6d4d938b2ce7a2f883bf896d9 Mon Sep 17 00:00:00 2001
From: "Ahmed Hussein (amahussein)"
Date: Thu, 11 Jun 2026 14:10:57 -0500
Subject: [PATCH 1/2] Fix skip-merge shuffle handle lifetime
Signed-off-by: Ahmed Hussein (amahussein)
Fixes #15018
Retain partial shuffle handles while managed buffers, input streams, and Netty file regions are active so `unregisterShuffle` cannot close data under in-flight readers.
Harden lease cleanup and partial-file close handling so releases are exception-safe and interrupted close waits still finish resource cleanup before restoring the interrupt flag.
Add regression coverage for retained-buffer reads, Netty file-region release, and interrupted partial-file cleanup.
---
.../MultithreadedShuffleBufferCatalog.scala | 276 +++++++++++++++---
.../spill/SpillablePartialFileHandle.scala | 107 ++++---
...ltithreadedShuffleBufferCatalogSuite.scala | 274 +++++++++++++++++
.../SpillablePartialFileHandleSuite.scala | 81 ++++-
4 files changed, 642 insertions(+), 96 deletions(-)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MultithreadedShuffleBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MultithreadedShuffleBufferCatalog.scala
index 35fd66d93ac..b5fc3e100ac 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MultithreadedShuffleBufferCatalog.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MultithreadedShuffleBufferCatalog.scala
@@ -24,6 +24,7 @@ import java.util.HashSet
import java.util.concurrent.ConcurrentHashMap
import scala.collection.mutable.ArrayBuffer
+import scala.util.control.NonFatal
import _root_.io.netty.handler.stream.ChunkedStream
import com.nvidia.spark.rapids.spill.SpillablePartialFileHandle
@@ -39,11 +40,143 @@ import org.apache.spark.storage.{ShuffleBlockBatchId, ShuffleBlockId}
* @param handle the partial file handle containing the data
* @param offset starting offset within the handle
* @param length number of bytes in this segment
+ * @param handleRef lifecycle reference for the partial file handle
*/
case class PartitionSegment(
handle: SpillablePartialFileHandle,
offset: Long,
- length: Long)
+ length: Long,
+ handleRef: ShuffleHandleReference)
+
+/**
+ * Reference-counted lifecycle wrapper for a partial shuffle file handle.
+ *
+ * Shuffle cleanup can request a close while retained buffers, streams, or file regions are still
+ * reading the handle. The physical close is deferred until all active leases release, allowing
+ * catalog metadata to be removed without closing data under active consumers.
+ */
+private[rapids] final class ShuffleHandleReference(handle: SpillablePartialFileHandle)
+ extends Logging {
+
+ private var refCount: Int = 0
+ private var closeRequested: Boolean = false
+ private var closed: Boolean = false
+ private var closeShuffleId: Int = -1
+
+ def retain(): Unit = synchronized {
+ if (closed) {
+ throw new IllegalStateException("Cannot retain a closed shuffle handle")
+ }
+ refCount += 1
+ }
+
+ def release(): Unit = {
+ val shuffleIdToClose = synchronized {
+ if (refCount <= 0) {
+ throw new IllegalStateException("release() without matching retain()")
+ }
+ refCount -= 1
+ closeShuffleIdIfReady()
+ }
+ shuffleIdToClose.foreach(closeHandle)
+ }
+
+ def requestClose(shuffleId: Int): Unit = {
+ val shuffleIdToClose = synchronized {
+ if (!closeRequested) {
+ closeRequested = true
+ closeShuffleId = shuffleId
+ }
+ closeShuffleIdIfReady()
+ }
+ shuffleIdToClose.foreach(closeHandle)
+ }
+
+ private def closeShuffleIdIfReady(): Option[Int] = {
+ if (closeRequested && refCount == 0 && !closed) {
+ closed = true
+ Some(closeShuffleId)
+ } else {
+ None
+ }
+ }
+
+ private def closeHandle(shuffleId: Int): Unit = {
+ try {
+ handle.close()
+ } catch {
+ case e: InterruptedException =>
+ Thread.currentThread().interrupt()
+ logError(s"Interrupted while closing handle for shuffle $shuffleId", e)
+ case NonFatal(e) =>
+ logError(s"Failed to close handle for shuffle $shuffleId", e)
+ }
+ }
+}
+
+/**
+ * Owns a temporary retain on one or more partial shuffle file handles.
+ *
+ * A lease is used while a buffer, stream, or file region may still read the handles. Closing the
+ * lease releases all retained handles exactly once.
+ */
+private[rapids] final class ShuffleHandleLease(handleRefs: Seq[ShuffleHandleReference])
+ extends AutoCloseable {
+ private var released: Boolean = false
+
+ override def close(): Unit = {
+ val refsToRelease = synchronized {
+ if (released) {
+ Seq.empty
+ } else {
+ released = true
+ handleRefs
+ }
+ }
+ ShuffleHandleLease.releaseAll(refsToRelease.reverseIterator)
+ }
+}
+
+private[rapids] object ShuffleHandleLease {
+ def acquire(handleRefs: Seq[ShuffleHandleReference]): ShuffleHandleLease = {
+ val retained = new ArrayBuffer[ShuffleHandleReference](handleRefs.size)
+ try {
+ handleRefs.foreach { ref =>
+ ref.retain()
+ retained += ref
+ }
+ new ShuffleHandleLease(retained.toSeq)
+ } catch {
+ case t: Throwable =>
+ try {
+ releaseAll(retained.reverseIterator)
+ } catch {
+ case releaseFailure: Throwable =>
+ t.addSuppressed(releaseFailure)
+ }
+ throw t
+ }
+ }
+
+ private def releaseAll(handleRefs: Iterator[ShuffleHandleReference]): Unit = {
+ var firstFailure: Throwable = null
+ handleRefs.foreach { ref =>
+ try {
+ ref.release()
+ } catch {
+ case t: Throwable =>
+ if (firstFailure == null) {
+ firstFailure = t
+ } else {
+ firstFailure.addSuppressed(t)
+ }
+ }
+ }
+ if (firstFailure != null) {
+ throw firstFailure
+ }
+ }
+}
/**
* Catalog for managing shuffle data in MULTITHREADED mode without merging.
@@ -69,6 +202,10 @@ class MultithreadedShuffleBufferCatalog extends Logging {
/** Track active shuffles for cleanup */
private val activeShuffles = new ConcurrentHashMap[Int, JBoolean]()
+ /** Track each unique partial file handle so cleanup can be deferred while readers are active. */
+ private val handleRefs =
+ new ConcurrentHashMap[SpillablePartialFileHandle, ShuffleHandleReference]()
+
/**
* Register a shuffle as active.
* Must be called before adding any partitions for this shuffle.
@@ -99,7 +236,8 @@ class MultithreadedShuffleBufferCatalog extends Logging {
}
val blockId = ShuffleBlockId(shuffleId, mapId, partitionId)
- val segment = PartitionSegment(handle, offset, length)
+ val handleRef = handleRefs.computeIfAbsent(handle, h => new ShuffleHandleReference(h))
+ val segment = PartitionSegment(handle, offset, length, handleRef)
partitionSegments.compute(blockId, (_, existing) => {
val segments = if (existing == null) new ArrayBuffer[PartitionSegment]() else existing
@@ -217,19 +355,16 @@ class MultithreadedShuffleBufferCatalog extends Logging {
numForcedFileOnly += 1
}
- try {
- handle.close()
- } catch {
- case e: Exception =>
- logError(s"Failed to close handle for shuffle $shuffleId", e)
- }
+ // Drop catalog ownership; retained buffers/streams keep the reference alive.
+ handleRefs.remove(handle, segment.handleRef)
+ segment.handleRef.requestClose(shuffleId)
}
}
}
}
- logDebug(s"Unregistered shuffle $shuffleId: closed ${closedHandles.size()} handles, " +
- s"bytesFromMemory=$bytesFromMemory, bytesFromDisk=$bytesFromDisk, " +
+ logDebug(s"Unregistered shuffle $shuffleId: cleanup requested for ${closedHandles.size()} " +
+ s"handles, bytesFromMemory=$bytesFromMemory, bytesFromDisk=$bytesFromDisk, " +
s"numExpansions=$numExpansions, numSpills=$numSpills, numForcedFileOnly=$numForcedFileOnly")
// Return statistics if we had any data
@@ -252,50 +387,79 @@ class MultithreadedShuffleBufferCatalog extends Logging {
*/
class MultiBatchManagedBuffer(segments: Seq[PartitionSegment]) extends ManagedBuffer {
+ private val handleRefs: Seq[ShuffleHandleReference] = segments.map(_.handleRef).distinct
+
+ /** Guards bufferLeases while retain()/release() can be called from different threads. */
+ private val retainLock = new Object
+
+ /** Leases that keep this buffer's partial shuffle file handles open after retain(). */
+ private val bufferLeases = new ArrayBuffer[ShuffleHandleLease]()
+
override def size(): Long = segments.map(_.length).sum
override def nioByteBuffer(): ByteBuffer = {
- // This method loads all data into memory. It's required by the ManagedBuffer interface
- // but is NOT used in the network transfer path - Spark's network layer uses
- // convertToNetty() which returns our streaming MultiSegmentFileRegion.
- // This method may be called by other code paths (e.g., local block reading).
- val totalSize = size().toInt
- val buffer = ByteBuffer.allocate(totalSize)
- val bytes = new Array[Byte](8192) // Read buffer
-
- segments.foreach { segment =>
- var remaining = segment.length
- var position = segment.offset
- while (remaining > 0) {
- val toRead = math.min(remaining, bytes.length).toInt
- val bytesRead = segment.handle.readAt(position, bytes, 0, toRead)
- if (bytesRead <= 0) {
- throw new IOException(
- s"Unexpected EOF reading segment at position $position, " +
- s"expected ${segment.length} bytes")
+ val lease = ShuffleHandleLease.acquire(handleRefs)
+ try {
+ // This method loads all data into memory. It's required by the ManagedBuffer interface
+ // but is NOT used in the network transfer path - Spark's network layer uses
+ // convertToNetty() which returns our streaming MultiSegmentFileRegion.
+ // This method may be called by other code paths (e.g., local block reading).
+ val totalSize = size().toInt
+ val buffer = ByteBuffer.allocate(totalSize)
+ val bytes = new Array[Byte](8192) // Read buffer
+
+ segments.foreach { segment =>
+ var remaining = segment.length
+ var position = segment.offset
+ while (remaining > 0) {
+ val toRead = math.min(remaining, bytes.length).toInt
+ val bytesRead = segment.handle.readAt(position, bytes, 0, toRead)
+ if (bytesRead <= 0) {
+ throw new IOException(
+ s"Unexpected EOF reading segment at position $position, " +
+ s"expected ${segment.length} bytes")
+ }
+ buffer.put(bytes, 0, bytesRead)
+ position += bytesRead
+ remaining -= bytesRead
}
- buffer.put(bytes, 0, bytesRead)
- position += bytesRead
- remaining -= bytesRead
}
- }
- buffer.flip()
- buffer
+ buffer.flip()
+ buffer
+ } finally {
+ lease.close()
+ }
}
override def createInputStream(): InputStream = {
- new MultiSegmentInputStream(segments)
+ new MultiSegmentInputStream(segments, handleRefs)
}
- override def retain(): ManagedBuffer = this
+ override def retain(): ManagedBuffer = {
+ val lease = ShuffleHandleLease.acquire(handleRefs)
+ retainLock.synchronized {
+ bufferLeases += lease
+ }
+ this
+ }
- override def release(): ManagedBuffer = this
+ override def release(): ManagedBuffer = {
+ val lease = retainLock.synchronized {
+ if (bufferLeases.nonEmpty) {
+ Some(bufferLeases.remove(bufferLeases.size - 1))
+ } else {
+ None
+ }
+ }
+ lease.foreach(_.close())
+ this
+ }
override def convertToNetty(): AnyRef = {
// Return a custom FileRegion that streams data in chunks, avoiding loading all
// data into memory at once. This addresses concerns about large shuffle blocks.
- new MultiSegmentFileRegion(segments)
+ new MultiSegmentFileRegion(segments, handleRefs)
}
// Spark 4.0+ adds convertToNettyForSsl() abstract method.
@@ -314,12 +478,20 @@ class MultiBatchManagedBuffer(segments: Seq[PartitionSegment]) extends ManagedBu
/**
* An InputStream that reads from multiple partition segments sequentially.
+ *
+ * This stream is not thread-safe; callers should create one stream per reading thread and close it
+ * from that owner thread.
*/
-class MultiSegmentInputStream(segments: Seq[PartitionSegment]) extends InputStream {
+class MultiSegmentInputStream(
+ segments: Seq[PartitionSegment],
+ handleRefs: Seq[ShuffleHandleReference]) extends InputStream {
private var currentSegmentIndex: Int = 0
private var currentPosition: Long = if (segments.nonEmpty) segments.head.offset else 0
private var bytesReadInCurrentSegment: Long = 0
+ // Keeps the partial shuffle file handles open until this stream is closed.
+ private val lease = ShuffleHandleLease.acquire(handleRefs)
+ private var closed: Boolean = false
override def read(): Int = {
val buf = new Array[Byte](1)
@@ -328,6 +500,10 @@ class MultiSegmentInputStream(segments: Seq[PartitionSegment]) extends InputStre
}
override def read(b: Array[Byte], off: Int, len: Int): Int = {
+ if (closed) {
+ throw new IOException("Stream is closed")
+ }
+
// Use loop instead of recursion to avoid StackOverflowError with many segments
while (currentSegmentIndex < segments.size) {
val segment = segments(currentSegmentIndex)
@@ -358,7 +534,7 @@ class MultiSegmentInputStream(segments: Seq[PartitionSegment]) extends InputStre
}
override def available(): Int = {
- if (currentSegmentIndex >= segments.size) {
+ if (closed || currentSegmentIndex >= segments.size) {
0
} else {
val remaining = segments.drop(currentSegmentIndex).map { seg =>
@@ -372,13 +548,11 @@ class MultiSegmentInputStream(segments: Seq[PartitionSegment]) extends InputStre
}
}
- /**
- * Close is a no-op because the underlying SpillablePartialFileHandle resources
- * are managed by MultithreadedShuffleBufferCatalog and will be closed when
- * the shuffle is unregistered.
- */
override def close(): Unit = {
- // No-op: handles are managed by MultithreadedShuffleBufferCatalog
+ if (!closed) {
+ closed = true
+ lease.close()
+ }
}
}
@@ -392,8 +566,12 @@ class MultiSegmentInputStream(segments: Seq[PartitionSegment]) extends InputStre
* Spark's MessageWithHeader only accepts ByteBuf or FileRegion. By implementing
* FileRegion, we can provide streaming transfer while remaining compatible with
* Spark's network layer.
+ *
+ * Instances are not thread-safe; each transfer should use one FileRegion owned by one Netty write.
*/
-class MultiSegmentFileRegion(segments: Seq[PartitionSegment]) extends AbstractFileRegion {
+class MultiSegmentFileRegion(
+ segments: Seq[PartitionSegment],
+ handleRefs: Seq[ShuffleHandleReference]) extends AbstractFileRegion {
private val totalSize: Long = segments.map(_.length).sum
private var totalTransferred: Long = 0
@@ -407,6 +585,8 @@ class MultiSegmentFileRegion(segments: Seq[PartitionSegment]) extends AbstractFi
// Track current position within the logical data stream
private var currentSegmentIndex: Int = 0
private var bytesTransferredInCurrentSegment: Long = 0
+ // Keeps the partial shuffle file handles open until this file region is deallocated.
+ private val lease = ShuffleHandleLease.acquire(handleRefs)
override def count(): Long = totalSize
@@ -475,6 +655,6 @@ class MultiSegmentFileRegion(segments: Seq[PartitionSegment]) extends AbstractFi
}
override protected def deallocate(): Unit = {
- // No resources to release - handles are managed by MultithreadedShuffleBufferCatalog
+ lease.close()
}
}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandle.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandle.scala
index 52dfd286773..99c76cd930a 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandle.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandle.scala
@@ -731,58 +731,72 @@ class SpillablePartialFileHandle private (
* throughout the entire lifecycle (write phase + read phase), we count it as saved.
*/
override private[spill] def doClose(): Unit = {
- // Collect resources to close under lock, then close them outside lock
- val (bos, fos, bis, fis, fc, raf) = synchronized {
- // Wait for any in-progress spill to complete before closing buffer
- while (spillInProgress) {
- wait()
- }
+ var interrupted = false
+ try {
+ // Collect resources to close under lock, then close them outside lock
+ val (bos, fos, bis, fis, fc, raf) = synchronized {
+ // Wait for any in-progress spill to complete before closing buffer.
+ // close() has already committed this handle's closed state, so cleanup must finish even
+ // if this thread is interrupted while waiting. Restore the interrupt flag after cleanup.
+ while (spillInProgress) {
+ try {
+ wait()
+ } catch {
+ case _: InterruptedException =>
+ interrupted = true
+ }
+ }
- // Record disk write savings for ESS + multi-batch merge scenario.
- // When ESS is enabled with multiple batches, partial files are merged into
- // a final file. If a partial file stayed in memory (not spilled), it avoided
- // an intermediate disk write. The task is still running during merge, so
- // GpuTaskMetrics.get is valid.
- if (storageMode == PartialFileStorageMode.MEMORY_WITH_SPILL &&
- !spilledToDisk && totalBytesWritten > 0 && !diskWriteSavingsRecorded) {
- GpuTaskMetrics.get.addDiskWriteSaved(totalBytesWritten)
- diskWriteSavingsRecorded = true
- logDebug(s"Recorded disk write savings in doClose: $totalBytesWritten bytes")
- }
+ // Record disk write savings for ESS + multi-batch merge scenario.
+ // When ESS is enabled with multiple batches, partial files are merged into
+ // a final file. If a partial file stayed in memory (not spilled), it avoided
+ // an intermediate disk write. The task is still running during merge, so
+ // GpuTaskMetrics.get is valid.
+ if (storageMode == PartialFileStorageMode.MEMORY_WITH_SPILL &&
+ !spilledToDisk && totalBytesWritten > 0 && !diskWriteSavingsRecorded) {
+ GpuTaskMetrics.get.addDiskWriteSaved(totalBytesWritten)
+ diskWriteSavingsRecorded = true
+ logDebug(s"Recorded disk write savings in doClose: $totalBytesWritten bytes")
+ }
- // Collect streams/channels to close
- val result = (bufferedOutputStream, fileOutputStream,
- bufferedInputStream, fileInputStream, fileChannel, randomAccessFile)
+ // Collect streams/channels to close
+ val result = (bufferedOutputStream, fileOutputStream,
+ bufferedInputStream, fileInputStream, fileChannel, randomAccessFile)
- // Clear references
- bufferedOutputStream = None
- fileOutputStream = None
- bufferedInputStream = None
- fileInputStream = None
- fileChannel = None
- randomAccessFile = None
+ // Clear references
+ bufferedOutputStream = None
+ fileOutputStream = None
+ bufferedInputStream = None
+ fileInputStream = None
+ fileChannel = None
+ randomAccessFile = None
- // Release host buffer (removes from SpillFramework tracking and closes buffer)
- releaseHostResource()
+ // Release host buffer (removes from SpillFramework tracking and closes buffer)
+ releaseHostResource()
- result
- }
+ result
+ }
- // Close streams outside lock (IO operations can be slow)
- tryClose(bos, "bufferedOutputStream")
- tryClose(fos, "fileOutputStream")
- tryClose(bis, "bufferedInputStream")
- tryClose(fis, "fileInputStream")
- tryClose(fc, "fileChannel")
- tryClose(raf, "randomAccessFile")
-
- // Delete file if it exists
- if (file != null && file.exists()) {
- try {
- file.delete()
- } catch {
- case e: Exception =>
- logWarning(s"Failed to delete file ${file.getAbsolutePath}", e)
+ // Close streams outside lock (IO operations can be slow)
+ tryClose(bos, "bufferedOutputStream")
+ tryClose(fos, "fileOutputStream")
+ tryClose(bis, "bufferedInputStream")
+ tryClose(fis, "fileInputStream")
+ tryClose(fc, "fileChannel")
+ tryClose(raf, "randomAccessFile")
+
+ // Delete file if it exists
+ if (file != null && file.exists()) {
+ try {
+ file.delete()
+ } catch {
+ case e: Exception =>
+ logWarning(s"Failed to delete file ${file.getAbsolutePath}", e)
+ }
+ }
+ } finally {
+ if (interrupted) {
+ Thread.currentThread().interrupt()
}
}
}
@@ -860,4 +874,3 @@ object SpillablePartialFileHandle extends Logging {
capacityHintProvider = capacityHintProvider)
}
}
-
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/MultithreadedShuffleBufferCatalogSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/MultithreadedShuffleBufferCatalogSuite.scala
index 4266654c9f4..7ff036563e7 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/MultithreadedShuffleBufferCatalogSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/MultithreadedShuffleBufferCatalogSuite.scala
@@ -16,6 +16,15 @@
package com.nvidia.spark.rapids
+import java.io.{File, InputStream}
+import java.nio.channels.ClosedChannelException
+import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch, TimeUnit}
+import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.util.control.NonFatal
+
+import _root_.io.netty.channel.FileRegion
import com.nvidia.spark.rapids.spill.SpillablePartialFileHandle
import org.mockito.ArgumentMatchers._
import org.mockito.Mockito._
@@ -23,6 +32,7 @@ import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite
import org.scalatestplus.mockito.MockitoSugar
+import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.storage.{ShuffleBlockBatchId, ShuffleBlockId}
class MultithreadedShuffleBufferCatalogSuite
@@ -188,6 +198,72 @@ class MultithreadedShuffleBufferCatalogSuite
catalog.unregisterShuffle(1)
}
+ // ------------------------------------------------------------------------------------------
+ // Regression test: a retained file-backed (FILE_ONLY) skip-merge shuffle buffer must stay
+ // readable while shuffle cleanup unregisters the catalog entry. The test retains a buffer,
+ // starts concurrent stream readers, unregisters the shuffle while reads are in flight, and
+ // verifies that the active readers do not observe a closed backing channel.
+ //
+ // The race depends on cleanup landing while a reader is using the cached FileChannel, so the
+ // test oversubscribes readers and uses small, frequent reads with bounded retries.
+
+ test("retained skip-merge buffer stays readable across concurrent unregisterShuffle") {
+ val race = MultithreadedShuffleBufferCatalogSuite.RetainedBufferReadRace
+ var bug: Option[Throwable] = None
+ var otherError: Option[Throwable] = None
+ var sawReadsInFlight = false
+ var sawReadsAfterUnregister = false
+ var leakedReaders = 0
+ var iteration = 0
+ while (bug.isEmpty && iteration < race.MaxIterations) {
+ iteration += 1
+ val result = race.attempt(iteration)
+ bug = result.target
+ if (otherError.isEmpty) otherError = result.otherError
+ sawReadsInFlight ||= result.startedOk && result.readsBeforeUnregister > 0
+ sawReadsAfterUnregister ||= result.readsAfterUnregister > 0
+ leakedReaders += result.leakedReaders
+ }
+
+ bug match {
+ case Some(closed) =>
+ // BUG: an active read was closed underneath the reader. Surface the real exception so the
+ // failure is the ClosedChannelException from the read path, not a synthetic assertion.
+ // Expected to FAIL on unmodified `main`.
+ throw closed
+ case None =>
+ // Post-fix expectation: confirm the scenario actually ran and that reads SURVIVED
+ // unregisterShuffle, so a future change cannot make this test pass for the wrong reason.
+ assert(sawReadsInFlight,
+ "reproducer never observed reads before unregisterShuffle; scenario did not run")
+ otherError.foreach(e => throw e)
+ assert(sawReadsAfterUnregister,
+ "no reads completed after unregisterShuffle; the retained buffer was not readable")
+ assert(leakedReaders == 0, s"$leakedReaders reader thread(s) did not stop after the test")
+ info(s"retained reads survived unregisterShuffle across $iteration iterations")
+ }
+ }
+
+ test("convertToNetty release closes retained handle once") {
+ val catalog = new MultithreadedShuffleBufferCatalog()
+ val handle = createMockHandle()
+
+ catalog.registerShuffle(1)
+ catalog.addPartition(1, 0L, 0, handle, 0, 100)
+
+ val buffer = catalog.getMergedBuffer(ShuffleBlockId(1, 0L, 0))
+ val region = buffer.convertToNetty().asInstanceOf[FileRegion]
+
+ catalog.unregisterShuffle(1)
+ verify(handle, never()).close()
+
+ assert(region.release())
+ verify(handle, times(1)).close()
+
+ buffer.release()
+ verify(handle, times(1)).close()
+ }
+
private def createMockHandle(): SpillablePartialFileHandle = {
val handle = mock[SpillablePartialFileHandle]
handle
@@ -213,3 +289,201 @@ class MultithreadedShuffleBufferCatalogSuite
}
}
+object MultithreadedShuffleBufferCatalogSuite {
+ private object RetainedBufferReadRace {
+ // Oversubscribe readers (bounded) so the close reliably lands in a reader's channel-read path.
+ private val ReaderThreads: Int =
+ math.min(128, math.max(64, Runtime.getRuntime.availableProcessors() * 4))
+ private val ReadBufferBytes: Int = 4 * 1024 // small reads => very frequent readAt calls
+ private val BackingFileBytes: Int = 2 * 1024 * 1024 // 2 MB single-segment backing file
+ private val AwaitSeconds: Long = 30L
+ private val PostUnregisterMillis: Long = 200L
+
+ val MaxIterations: Int = 20
+
+ case class Result(
+ target: Option[Throwable],
+ otherError: Option[Throwable],
+ readsBeforeUnregister: Long,
+ readsAfterUnregister: Long,
+ startedOk: Boolean,
+ leakedReaders: Int)
+
+ /** Runs one race iteration and reports what the readers saw. */
+ def attempt(iteration: Int): Result = {
+ val shuffleId = iteration
+ val mapId = 0L
+ val reduceId = 0
+ val backingFile = File.createTempFile("skipmerge-catalog-repro-", ".data")
+
+ val catalog = new MultithreadedShuffleBufferCatalog()
+ val handle = SpillablePartialFileHandle.createFileOnly(backingFile)
+ val stop = new AtomicBoolean(false)
+ val afterUnregister = new AtomicBoolean(false)
+ val started = new CountDownLatch(ReaderThreads)
+ val readsBefore = new AtomicLong(0L)
+ val readsAfter = new AtomicLong(0L)
+ val readerErrors = new ConcurrentLinkedQueue[Throwable]()
+ val readers = new ArrayBuffer[Thread](ReaderThreads)
+ var buffer: ManagedBuffer = null
+ var startedOk = false
+ var leakedReaders = 0
+ var targetBeforeCleanup: Option[Throwable] = None
+ var otherErrorBeforeCleanup: Option[Throwable] = None
+
+ try {
+ // Publish a multi-MB file-only handle as a single whole-file segment, like the writer does.
+ writeBackingData(handle, BackingFileBytes)
+ handle.finishWrite()
+ catalog.registerShuffle(shuffleId)
+ catalog.addPartition(shuffleId, mapId, reduceId, handle, 0L, BackingFileBytes.toLong)
+
+ // A reducer retains the buffer before handing it to readers; this retained lease is the
+ // lifecycle guarantee the regression test protects.
+ buffer = catalog.getMergedBuffer(ShuffleBlockId(shuffleId, mapId, reduceId))
+ buffer.retain()
+ val readBuffer = buffer
+
+ (0 until ReaderThreads).foreach { idx =>
+ // A Runnable (not a Thread subclass) so the local `stop` flag is not shadowed by the
+ // inherited Thread.stop() member.
+ val body = new Runnable {
+ override def run(): Unit = {
+ var in: InputStream = readBuffer.createInputStream()
+ val buf = new Array[Byte](ReadBufferBytes)
+ started.countDown()
+ try {
+ while (!stop.get()) {
+ val n = in.read(buf, 0, buf.length)
+ if (n < 0) {
+ in.close()
+ in = readBuffer.createInputStream()
+ } else if (afterUnregister.get()) {
+ readsAfter.incrementAndGet()
+ } else {
+ readsBefore.incrementAndGet()
+ }
+ }
+ } catch {
+ case NonFatal(t) => readerErrors.add(t)
+ } finally {
+ try { in.close() } catch { case NonFatal(_) => () }
+ }
+ }
+ }
+ val reader = new Thread(body, s"skipmerge-catalog-reader-$iteration-$idx")
+ reader.setDaemon(true)
+ readers += reader
+ reader.start()
+ }
+
+ // Make sure reads are flowing before the close, so it lands in a reader's read path.
+ startedOk = started.await(AwaitSeconds, TimeUnit.SECONDS) &&
+ awaitReadsInFlight(readsBefore, ReaderThreads.toLong)
+
+ // Unregister the shuffle while readers are active, then give them a short window to keep
+ // reading the retained buffer (on `main` unregister closes the handle and they hit the
+ // closed channel; once fixed the close is deferred and reads are counted in readsAfter).
+ catalog.unregisterShuffle(shuffleId)
+ afterUnregister.set(true)
+ awaitReadsAfterUnregister(readsAfter, readers, ReaderThreads.toLong)
+ } finally {
+ stop.set(true)
+ leakedReaders = joinAll(readers)
+ targetBeforeCleanup = firstMatching(readerErrors, isClosedChannelFromReadPath)
+ otherErrorBeforeCleanup =
+ firstMatching(readerErrors, t => !isClosedChannelFromReadPath(t))
+ // Release the retained buffer before the last-resort handle.close() guard, giving the
+ // catalog's deferred close path the first chance to close the handle.
+ if (buffer != null) {
+ try { buffer.release() } catch { case NonFatal(_) => () }
+ }
+ try { handle.close() } catch { case NonFatal(_) => () }
+ if (backingFile.exists()) {
+ backingFile.delete()
+ }
+ }
+
+ Result(
+ target = targetBeforeCleanup,
+ otherError = otherErrorBeforeCleanup,
+ readsBeforeUnregister = readsBefore.get(),
+ readsAfterUnregister = readsAfter.get(),
+ startedOk = startedOk,
+ leakedReaders = leakedReaders)
+ }
+
+ private def writeBackingData(handle: SpillablePartialFileHandle, totalBytes: Int): Unit = {
+ val chunk = new Array[Byte](1024 * 1024)
+ var i = 0
+ while (i < chunk.length) {
+ chunk(i) = (i & 0xFF).toByte
+ i += 1
+ }
+ var written = 0
+ while (written < totalBytes) {
+ val toWrite = math.min(chunk.length, totalBytes - written)
+ handle.write(chunk, 0, toWrite)
+ written += toWrite
+ }
+ }
+
+ /** Waits until reads are flowing; returns true if the target read count was reached in time. */
+ private def awaitReadsInFlight(reads: AtomicLong, target: Long): Boolean = {
+ val deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(AwaitSeconds)
+ while (reads.get() < target && System.nanoTime() < deadline) {
+ Thread.sleep(1L)
+ }
+ reads.get() >= target
+ }
+
+ /** Lets readers attempt reads after unregister; stops early once enough succeed or all exit. */
+ private def awaitReadsAfterUnregister(
+ readsAfter: AtomicLong, readers: ArrayBuffer[Thread], target: Long): Unit = {
+ val deadline = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(PostUnregisterMillis)
+ while (readsAfter.get() < target &&
+ System.nanoTime() < deadline && readers.exists(_.isAlive)) {
+ Thread.sleep(1L)
+ }
+ }
+
+ /** Joins all readers under one shared deadline; returns the count still alive. */
+ private def joinAll(readers: ArrayBuffer[Thread]): Int = {
+ val deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(AwaitSeconds)
+ readers.foreach { reader =>
+ val remainingMs = TimeUnit.NANOSECONDS.toMillis(deadline - System.nanoTime())
+ if (remainingMs > 0) {
+ try {
+ reader.join(remainingMs)
+ } catch {
+ case _: InterruptedException => Thread.currentThread().interrupt()
+ }
+ }
+ }
+ readers.count(_.isAlive)
+ }
+
+ private def firstMatching(
+ errors: ConcurrentLinkedQueue[Throwable], p: Throwable => Boolean): Option[Throwable] = {
+ var found: Option[Throwable] = None
+ val it = errors.iterator()
+ while (found.isEmpty && it.hasNext) {
+ val t = it.next()
+ if (p(t)) {
+ found = Some(t)
+ }
+ }
+ found
+ }
+
+ private def isClosedChannelFromReadPath(t: Throwable): Boolean = {
+ // AsynchronousCloseException (channel closed while a read is in flight) is a subclass of
+ // ClosedChannelException, so this covers both.
+ t.isInstanceOf[ClosedChannelException] && t.getStackTrace.exists { frame =>
+ frame.getClassName.endsWith("SpillablePartialFileHandle") &&
+ (frame.getMethodName.contains("readFromFileChannel") ||
+ frame.getMethodName.contains("readAt"))
+ }
+ }
+ }
+}
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandleSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandleSuite.scala
index 7f0b59ca059..0c082a70bac 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandleSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandleSuite.scala
@@ -18,6 +18,7 @@ package com.nvidia.spark.rapids.spill
import java.io.File
import java.util.Arrays
+import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}
import com.nvidia.spark.rapids.Arm.withResource
import com.nvidia.spark.rapids.RapidsConf
@@ -28,6 +29,11 @@ class SpillablePartialFileHandleSuite extends AnyFunSuite with BeforeAndAfterEac
// Use 1GB max buffer size for tests to avoid memory issues on test machines
private val testMaxBufferSize = 1L * 1024 * 1024 * 1024
+ // Reflective handle on the `spillInProgress` field declared by SpillablePartialFileHandle.
+ // It is made accessible once here so tests can set the flag directly.
+ private val spillInProgressField = {
+ val field = classOf[SpillablePartialFileHandle].getDeclaredField("spillInProgress")
+ field.setAccessible(true)
+ field
+ }
override def beforeEach(): Unit = {
super.beforeEach()
@@ -42,6 +48,18 @@ class SpillablePartialFileHandleSuite extends AnyFunSuite with BeforeAndAfterEac
super.afterEach()
}
+ /**
+  * Sets the `spillInProgress` flag of `handle` to `value` through reflection.
+  * Does no locking of its own; callers in this suite wrap it in `handle.synchronized`.
+  */
+ private def setSpillInProgress(handle: SpillablePartialFileHandle, value: Boolean): Unit = {
+ spillInProgressField.setBoolean(handle, value)
+ }
+
+  /**
+   * Polls `condition` every 10 ms until it holds or roughly 5 seconds have passed, then asserts
+   * it one final time so that a timeout fails the test with `clue`.
+   *
+   * The deadline is measured with `System.nanoTime()`: unlike `System.currentTimeMillis()` it is
+   * not affected by wall-clock adjustments, so the timeout cannot shrink or stretch mid-wait.
+   *
+   * @param condition by-name predicate, re-evaluated on every poll and once more for the assert
+   * @param clue message attached to the assertion failure
+   */
+  private def waitUntil(condition: => Boolean, clue: String): Unit = {
+    val deadlineNanos = System.nanoTime() + java.util.concurrent.TimeUnit.SECONDS.toNanos(5)
+    // Compare by subtraction: nanoTime values are only meaningful as differences.
+    while (!condition && System.nanoTime() - deadlineNanos < 0) {
+      Thread.sleep(10)
+    }
+    assert(condition, clue)
+  }
+
test("FILE_ONLY mode: write and read") {
val tempFile = File.createTempFile("test-file-only-", ".tmp")
@@ -222,6 +240,68 @@ class SpillablePartialFileHandleSuite extends AnyFunSuite with BeforeAndAfterEac
}
}
+ // Regression test: a close() that is interrupted while it waits on `spillInProgress` must
+ // still finish cleanup and return with the thread's interrupt flag set.
+ test("MEMORY_WITH_SPILL mode: interrupted close still completes cleanup") {
+ val tempFile = File.createTempFile("test-interrupted-close-", ".tmp")
+ val handle = SpillablePartialFileHandle.createMemoryWithSpill(
+ initialCapacity = 1024,
+ maxBufferSize = testMaxBufferSize,
+ memoryThreshold = 0.5,
+ spillFile = tempFile)
+ var closeThread: Thread = null
+
+ try {
+ val testData = "data retained until close".getBytes("UTF-8")
+ handle.write(testData, 0, testData.length)
+ handle.finishWrite()
+ // Precondition: the data is still held in a host buffer and the spill file exists.
+ assert(handle.host.nonEmpty)
+ assert(tempFile.exists())
+
+ // Force the flag on via reflection so the close() below has something to wait for.
+ handle.synchronized {
+ setSpillInProgress(handle, true)
+ }
+
+ val closeReturnedWithInterrupt = new AtomicBoolean(false)
+ val closeError = new AtomicReference[Throwable]()
+ closeThread = new Thread("test-interrupted-partial-file-close") {
+ override def run(): Unit = {
+ try {
+ handle.close()
+ // Record whether close() came back with the interrupt flag set on this thread.
+ closeReturnedWithInterrupt.set(Thread.currentThread().isInterrupted)
+ } catch {
+ // Anything close() throws is captured here and asserted on by the main thread.
+ case t: Throwable => closeError.set(t)
+ }
+ }
+ }
+
+ closeThread.start()
+ // Only interrupt once the close thread is observed in the WAITING state.
+ waitUntil(closeThread.getState == Thread.State.WAITING,
+ "close thread did not wait for spillInProgress")
+ closeThread.interrupt()
+
+ // Clear the flag and wake waiters so close() can proceed.
+ handle.synchronized {
+ setSpillInProgress(handle, false)
+ handle.notifyAll()
+ }
+ closeThread.join(5000)
+
+ // close() returned without throwing, with the interrupt flag set, and cleanup ran: the host
+ // buffer is gone and the spill file was deleted.
+ assert(!closeThread.isAlive)
+ assert(closeError.get() == null)
+ assert(closeReturnedWithInterrupt.get())
+ assert(handle.host.isEmpty)
+ assert(!tempFile.exists())
+ } finally {
+ // Clear the flag and reap the close thread even if an assertion failed above, so the
+ // thread is not left waiting on it.
+ handle.synchronized {
+ setSpillInProgress(handle, false)
+ handle.notifyAll()
+ }
+ if (closeThread != null && closeThread.isAlive) {
+ closeThread.join(5000)
+ }
+ handle.close()
+ tempFile.delete()
+ }
+ }
+
test("MEMORY_WITH_SPILL mode: sequential write with single bytes") {
val tempFile = File.createTempFile("test-single-bytes-", ".tmp")
@@ -573,4 +653,3 @@ class SpillablePartialFileHandleSuite extends AnyFunSuite with BeforeAndAfterEac
}
}
}
-
From 6b755648376c18df55051a6828acd757fcac33af Mon Sep 17 00:00:00 2001
From: "Ahmed Hussein (amahussein)"
Date: Wed, 24 Jun 2026 13:22:31 -0500
Subject: [PATCH 2/2] Move read-lease reference counting into SpillablePartialFileHandle
Signed-off-by: Ahmed Hussein (amahussein)
Fixes #15018
Reference-count partial shuffle handles on the handle itself so that while managed buffers,
input streams, and Netty file regions are active, unregisterShuffle (and any other close
caller) defers the physical close instead of freeing data under in-flight readers. close()
is an idempotent close-request; acquireRead/releaseRead track active readers and the last
release performs the deferred close.
Release is exception-safe, and an interrupted close while waiting on an in-progress spill
still finishes resource cleanup before restoring the interrupt flag.
Add regression coverage for the handle read-lease lifecycle (deferred close, repeated close
while a lease is held, immediate close with no leases), retained-buffer reads across
concurrent unregisterShuffle, Netty file-region release, and interrupted partial-file cleanup.
---
.../MultithreadedShuffleBufferCatalog.scala | 140 +++++-------------
.../spill/SpillablePartialFileHandle.scala | 99 +++++++++++++
...ltithreadedShuffleBufferCatalogSuite.scala | 48 +++---
.../SpillablePartialFileHandleSuite.scala | 87 +++++++++++
4 files changed, 255 insertions(+), 119 deletions(-)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MultithreadedShuffleBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MultithreadedShuffleBufferCatalog.scala
index b5fc3e100ac..ad7f24cdf6d 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MultithreadedShuffleBufferCatalog.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MultithreadedShuffleBufferCatalog.scala
@@ -40,110 +40,44 @@ import org.apache.spark.storage.{ShuffleBlockBatchId, ShuffleBlockId}
* @param handle the partial file handle containing the data
* @param offset starting offset within the handle
* @param length number of bytes in this segment
- * @param handleRef lifecycle reference for the partial file handle
*/
case class PartitionSegment(
handle: SpillablePartialFileHandle,
offset: Long,
- length: Long,
- handleRef: ShuffleHandleReference)
+ length: Long)
/**
- * Reference-counted lifecycle wrapper for a partial shuffle file handle.
+ * Owns a temporary read lease on one or more partial shuffle file handles.
*
- * Shuffle cleanup can request a close while retained buffers, streams, or file regions are still
- * reading the handle. The physical close is deferred until all active leases release, allowing
- * catalog metadata to be removed without closing data under active consumers.
+ * A lease is held while a buffer, stream, or file region may still read the handles. Closing the
+ * lease releases every handle exactly once; each handle defers its physical close until its last
+ * lease is released (see `SpillablePartialFileHandle.acquireRead`/`releaseRead`).
*/
-private[rapids] final class ShuffleHandleReference(handle: SpillablePartialFileHandle)
- extends Logging {
-
- private var refCount: Int = 0
- private var closeRequested: Boolean = false
- private var closed: Boolean = false
- private var closeShuffleId: Int = -1
-
- def retain(): Unit = synchronized {
- if (closed) {
- throw new IllegalStateException("Cannot retain a closed shuffle handle")
- }
- refCount += 1
- }
-
- def release(): Unit = {
- val shuffleIdToClose = synchronized {
- if (refCount <= 0) {
- throw new IllegalStateException("release() without matching retain()")
- }
- refCount -= 1
- closeShuffleIdIfReady()
- }
- shuffleIdToClose.foreach(closeHandle)
- }
-
- def requestClose(shuffleId: Int): Unit = {
- val shuffleIdToClose = synchronized {
- if (!closeRequested) {
- closeRequested = true
- closeShuffleId = shuffleId
- }
- closeShuffleIdIfReady()
- }
- shuffleIdToClose.foreach(closeHandle)
- }
-
- private def closeShuffleIdIfReady(): Option[Int] = {
- if (closeRequested && refCount == 0 && !closed) {
- closed = true
- Some(closeShuffleId)
- } else {
- None
- }
- }
-
- private def closeHandle(shuffleId: Int): Unit = {
- try {
- handle.close()
- } catch {
- case e: InterruptedException =>
- Thread.currentThread().interrupt()
- logError(s"Interrupted while closing handle for shuffle $shuffleId", e)
- case NonFatal(e) =>
- logError(s"Failed to close handle for shuffle $shuffleId", e)
- }
- }
-}
-
-/**
- * Owns a temporary retain on one or more partial shuffle file handles.
- *
- * A lease is used while a buffer, stream, or file region may still read the handles. Closing the
- * lease releases all retained handles exactly once.
- */
-private[rapids] final class ShuffleHandleLease(handleRefs: Seq[ShuffleHandleReference])
+private[rapids] final class ShuffleHandleLease(handles: Seq[SpillablePartialFileHandle])
extends AutoCloseable {
private var released: Boolean = false
override def close(): Unit = {
- val refsToRelease = synchronized {
+ val handlesToRelease = synchronized {
if (released) {
Seq.empty
} else {
released = true
- handleRefs
+ handles
}
}
- ShuffleHandleLease.releaseAll(refsToRelease.reverseIterator)
+ ShuffleHandleLease.releaseAll(handlesToRelease.reverseIterator)
}
}
private[rapids] object ShuffleHandleLease {
- def acquire(handleRefs: Seq[ShuffleHandleReference]): ShuffleHandleLease = {
- val retained = new ArrayBuffer[ShuffleHandleReference](handleRefs.size)
+ /** Acquire a read lease on every handle, rolling back the partial set if any acquire fails. */
+ def acquire(handles: Seq[SpillablePartialFileHandle]): ShuffleHandleLease = {
+ val retained = new ArrayBuffer[SpillablePartialFileHandle](handles.size)
try {
- handleRefs.foreach { ref =>
- ref.retain()
- retained += ref
+ handles.foreach { handle =>
+ handle.acquireRead()
+ retained += handle
}
new ShuffleHandleLease(retained.toSeq)
} catch {
@@ -158,11 +92,11 @@ private[rapids] object ShuffleHandleLease {
}
}
- private def releaseAll(handleRefs: Iterator[ShuffleHandleReference]): Unit = {
+ private def releaseAll(handles: Iterator[SpillablePartialFileHandle]): Unit = {
var firstFailure: Throwable = null
- handleRefs.foreach { ref =>
+ handles.foreach { handle =>
try {
- ref.release()
+ handle.releaseRead()
} catch {
case t: Throwable =>
if (firstFailure == null) {
@@ -202,10 +136,6 @@ class MultithreadedShuffleBufferCatalog extends Logging {
/** Track active shuffles for cleanup */
private val activeShuffles = new ConcurrentHashMap[Int, JBoolean]()
- /** Track each unique partial file handle so cleanup can be deferred while readers are active. */
- private val handleRefs =
- new ConcurrentHashMap[SpillablePartialFileHandle, ShuffleHandleReference]()
-
/**
* Register a shuffle as active.
* Must be called before adding any partitions for this shuffle.
@@ -236,8 +166,7 @@ class MultithreadedShuffleBufferCatalog extends Logging {
}
val blockId = ShuffleBlockId(shuffleId, mapId, partitionId)
- val handleRef = handleRefs.computeIfAbsent(handle, h => new ShuffleHandleReference(h))
- val segment = PartitionSegment(handle, offset, length, handleRef)
+ val segment = PartitionSegment(handle, offset, length)
partitionSegments.compute(blockId, (_, existing) => {
val segments = if (existing == null) new ArrayBuffer[PartitionSegment]() else existing
@@ -355,9 +284,16 @@ class MultithreadedShuffleBufferCatalog extends Logging {
numForcedFileOnly += 1
}
- // Drop catalog ownership; retained buffers/streams keep the reference alive.
- handleRefs.remove(handle, segment.handleRef)
- segment.handleRef.requestClose(shuffleId)
+ // Drop catalog ownership; retained buffers, streams, and file regions keep the handle
+ // alive through their read leases, so the physical close is deferred until the last
+ // lease is released. close() propagates failures, so catch here so one bad handle
+ // does not abort cleanup of the rest.
+ try {
+ handle.close()
+ } catch {
+ case NonFatal(e) =>
+ logWarning(s"Failed to request close of handle for shuffle $shuffleId", e)
+ }
}
}
}
@@ -387,7 +323,7 @@ class MultithreadedShuffleBufferCatalog extends Logging {
*/
class MultiBatchManagedBuffer(segments: Seq[PartitionSegment]) extends ManagedBuffer {
- private val handleRefs: Seq[ShuffleHandleReference] = segments.map(_.handleRef).distinct
+ private val handles: Seq[SpillablePartialFileHandle] = segments.map(_.handle).distinct
/** Guards bufferLeases while retain()/release() can be called from different threads. */
private val retainLock = new Object
@@ -398,7 +334,7 @@ class MultiBatchManagedBuffer(segments: Seq[PartitionSegment]) extends ManagedBu
override def size(): Long = segments.map(_.length).sum
override def nioByteBuffer(): ByteBuffer = {
- val lease = ShuffleHandleLease.acquire(handleRefs)
+ val lease = ShuffleHandleLease.acquire(handles)
try {
// This method loads all data into memory. It's required by the ManagedBuffer interface
// but is NOT used in the network transfer path - Spark's network layer uses
@@ -433,11 +369,11 @@ class MultiBatchManagedBuffer(segments: Seq[PartitionSegment]) extends ManagedBu
}
override def createInputStream(): InputStream = {
- new MultiSegmentInputStream(segments, handleRefs)
+ new MultiSegmentInputStream(segments, handles)
}
override def retain(): ManagedBuffer = {
- val lease = ShuffleHandleLease.acquire(handleRefs)
+ val lease = ShuffleHandleLease.acquire(handles)
retainLock.synchronized {
bufferLeases += lease
}
@@ -459,7 +395,7 @@ class MultiBatchManagedBuffer(segments: Seq[PartitionSegment]) extends ManagedBu
override def convertToNetty(): AnyRef = {
// Return a custom FileRegion that streams data in chunks, avoiding loading all
// data into memory at once. This addresses concerns about large shuffle blocks.
- new MultiSegmentFileRegion(segments, handleRefs)
+ new MultiSegmentFileRegion(segments, handles)
}
// Spark 4.0+ adds convertToNettyForSsl() abstract method.
@@ -484,13 +420,13 @@ class MultiBatchManagedBuffer(segments: Seq[PartitionSegment]) extends ManagedBu
*/
class MultiSegmentInputStream(
segments: Seq[PartitionSegment],
- handleRefs: Seq[ShuffleHandleReference]) extends InputStream {
+ handles: Seq[SpillablePartialFileHandle]) extends InputStream {
private var currentSegmentIndex: Int = 0
private var currentPosition: Long = if (segments.nonEmpty) segments.head.offset else 0
private var bytesReadInCurrentSegment: Long = 0
// Keeps the partial shuffle file handles open until this stream is closed.
- private val lease = ShuffleHandleLease.acquire(handleRefs)
+ private val lease = ShuffleHandleLease.acquire(handles)
private var closed: Boolean = false
override def read(): Int = {
@@ -571,7 +507,7 @@ class MultiSegmentInputStream(
*/
class MultiSegmentFileRegion(
segments: Seq[PartitionSegment],
- handleRefs: Seq[ShuffleHandleReference]) extends AbstractFileRegion {
+ handles: Seq[SpillablePartialFileHandle]) extends AbstractFileRegion {
private val totalSize: Long = segments.map(_.length).sum
private var totalTransferred: Long = 0
@@ -586,7 +522,7 @@ class MultiSegmentFileRegion(
private var currentSegmentIndex: Int = 0
private var bytesTransferredInCurrentSegment: Long = 0
// Keeps the partial shuffle file handles open until this file region is deallocated.
- private val lease = ShuffleHandleLease.acquire(handleRefs)
+ private val lease = ShuffleHandleLease.acquire(handles)
override def count(): Long = totalSize
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandle.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandle.scala
index 99c76cd930a..f4584339cc5 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandle.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandle.scala
@@ -20,6 +20,8 @@ import java.io.{BufferedInputStream, BufferedOutputStream, File, FileInputStream
import java.nio.ByteBuffer
import java.nio.channels.FileChannel
+import scala.util.control.NonFatal
+
import ai.rapids.cudf.HostMemoryBuffer
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.HostAlloc
@@ -95,6 +97,12 @@ class SpillablePartialFileHandle private (
@volatile private var expansionCount: Int = 0
@volatile private var spillCount: Int = 0
+ // Shuffle read-lease reference counting. Active consumers (retained buffers, input streams,
+ // Netty file regions) hold leases so the physical close can be deferred until they finish.
+ // Guarded by this handle's monitor; `closed` (from StoreHandle) marks the physical close.
+ private var readRefCount: Int = 0
+ private var closeRequested: Boolean = false
+
// Write state
private var writePosition: Long = 0L
private var fileOutputStream: Option[FileOutputStream] = None
@@ -725,6 +733,97 @@ class SpillablePartialFileHandle private (
}
}
+ // ----- Shuffle read-lease lifecycle ----------------------------------------------------------
+ // Shuffle cleanup (`MultithreadedShuffleBufferCatalog.unregisterShuffle`) can request a close
+ // while retained buffers, input streams, or Netty file regions are still reading this handle.
+ // `acquireRead`/`releaseRead` reference-count those active consumers, and `close()` requests the
+ // physical close but defers it until the last lease is released. The physical close is
+ // `doClose()`, which coordinates with an in-progress spill via `spillInProgress` (this handle
+ // does not use the base `spilling` flag).
+
+  /**
+   * Takes a read lease on this handle by incrementing the lease count under the handle monitor.
+   *
+   * @throws IllegalStateException if the handle has already been physically closed
+   */
+  private[rapids] def acquireRead(): Unit = synchronized {
+    if (!closed) {
+      readRefCount += 1
+    } else {
+      throw new IllegalStateException(
+        "Cannot acquire a read lease on a closed partial file handle")
+    }
+  }
+
+  /**
+   * Gives back one read lease. When this was the last lease and a close has been requested, the
+   * deferred physical close runs here, outside the monitor, through `closeQuietly()`, so a close
+   * failure is logged rather than thrown into the consumer's cleanup path.
+   *
+   * @throws IllegalStateException if no lease is currently held
+   */
+  private[rapids] def releaseRead(): Unit = {
+    val lastLeaseTriggersClose = synchronized {
+      if (readRefCount <= 0) {
+        throw new IllegalStateException("releaseRead() without a matching acquireRead()")
+      }
+      readRefCount -= 1
+      markCloseIfReady()
+    }
+    if (lastLeaseTriggersClose) closeQuietly()
+  }
+
+  /**
+   * Reports whether the physical close has been committed. A close that was only requested and
+   * is still waiting on outstanding leases does not count.
+   */
+  private[rapids] def isPhysicallyClosed: Boolean = {
+    this.synchronized {
+      closed
+    }
+  }
+
+  /**
+   * Records a close request for this handle. With no read leases outstanding the physical close
+   * runs right away on the calling thread; otherwise it is left for the final `releaseRead()`.
+   *
+   * Unlike the release path, a failure of the physical close is not caught here and reaches the
+   * caller. Calling this again after the close has been committed does nothing further.
+   */
+  override def close(): Unit = {
+    val closeNow = synchronized {
+      closeRequested = true
+      markCloseIfReady()
+    }
+    // Run the physical close outside the monitor.
+    if (closeNow) doClose()
+  }
+
+  /**
+   * Decides whether the physical close should run now. Must be called with the handle monitor
+   * held. When a close was requested, no leases remain and the close has not been committed yet,
+   * `closed` is set in the same step, so a later `acquireRead()` is rejected.
+   *
+   * @return true iff the caller must now run the physical close (outside the lock)
+   */
+  private def markCloseIfReady(): Boolean = {
+    val ready = closeRequested && readRefCount == 0 && !closed
+    if (ready) {
+      closed = true
+    }
+    ready
+  }
+
+  /**
+   * Performs the physical close for the release path. Interrupts and non-fatal failures are
+   * logged instead of rethrown, so a failing close cannot break the consumer that is releasing
+   * its lease. `close()` does not use this and lets failures propagate.
+   */
+  private def closeQuietly(): Unit = {
+    try doClose()
+    catch {
+      case interrupted: InterruptedException =>
+        // Log first, while the flag is still clear, and re-assert the interrupt last so the
+        // logging itself is not affected by a pending interrupt.
+        logError("Interrupted while closing partial file handle", interrupted)
+        Thread.currentThread().interrupt()
+      case NonFatal(cause) =>
+        logError("Failed to close partial file handle", cause)
+    }
+  }
+
/**
* Close and cleanup resources.
* This is where we record disk write savings: only if data was never spilled to disk
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/MultithreadedShuffleBufferCatalogSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/MultithreadedShuffleBufferCatalogSuite.scala
index 7ff036563e7..66e7bf6bd7a 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/MultithreadedShuffleBufferCatalogSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/MultithreadedShuffleBufferCatalogSuite.scala
@@ -244,24 +244,38 @@ class MultithreadedShuffleBufferCatalogSuite
}
}
- test("convertToNetty release closes retained handle once") {
+ test("convertToNetty release closes retained handle exactly once") {
+ // The handle owns the close deferral, so a mock can't reproduce it: use a real FILE_ONLY
+ // handle and observe the physical close via `isPhysicallyClosed`.
val catalog = new MultithreadedShuffleBufferCatalog()
- val handle = createMockHandle()
-
- catalog.registerShuffle(1)
- catalog.addPartition(1, 0L, 0, handle, 0, 100)
-
- val buffer = catalog.getMergedBuffer(ShuffleBlockId(1, 0L, 0))
- val region = buffer.convertToNetty().asInstanceOf[FileRegion]
-
- catalog.unregisterShuffle(1)
- verify(handle, never()).close()
-
- assert(region.release())
- verify(handle, times(1)).close()
-
- buffer.release()
- verify(handle, times(1)).close()
+ val backingFile = File.createTempFile("skipmerge-region-", ".data")
+ val handle = SpillablePartialFileHandle.createFileOnly(backingFile)
+ handle.write(Array.fill[Byte](100)(7.toByte), 0, 100)
+ handle.finishWrite()
+ try {
+ catalog.registerShuffle(1)
+ catalog.addPartition(1, 0L, 0, handle, 0, 100)
+
+ val buffer = catalog.getMergedBuffer(ShuffleBlockId(1, 0L, 0))
+ val region = buffer.convertToNetty().asInstanceOf[FileRegion]
+
+ // Cleanup requests close, but the file region still holds a read lease: close is deferred.
+ catalog.unregisterShuffle(1)
+ assert(!handle.isPhysicallyClosed,
+ "handle must stay open while the file region holds a lease")
+
+ // Releasing the region drops the last lease and runs the deferred physical close.
+ assert(region.release())
+ assert(handle.isPhysicallyClosed, "handle must close once the file region releases its lease")
+
+ // No retained buffer lease here, so releasing the buffer is a no-op and must not re-close.
+ buffer.release()
+ assert(handle.isPhysicallyClosed)
+ } finally {
+ if (backingFile.exists()) {
+ backingFile.delete()
+ }
+ }
}
private def createMockHandle(): SpillablePartialFileHandle = {
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandleSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandleSuite.scala
index 0c082a70bac..0edbb1b7a38 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandleSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandleSuite.scala
@@ -60,6 +60,26 @@ class SpillablePartialFileHandleSuite extends AnyFunSuite with BeforeAndAfterEac
assert(condition, clue)
}
+  /**
+   * Runs `body` against a finished FILE_ONLY handle backed by a fresh temp file, then closes the
+   * handle and deletes the file. The temp file is exposed so tests can assert on its deletion.
+   *
+   * File cleanup sits in an outer `finally`, so the temp file is removed even when creating the
+   * handle, writing to it, or closing it throws.
+   *
+   * @param data bytes written to the handle before `finishWrite()` is called
+   * @param body test logic, given the finished handle and its backing file
+   */
+  private def withFinishedFileOnlyHandle(data: Array[Byte] = "lease".getBytes("UTF-8"))(
+      body: (SpillablePartialFileHandle, File) => Unit): Unit = {
+    val tempFile = File.createTempFile("test-lease-", ".tmp")
+    try {
+      val handle = SpillablePartialFileHandle.createFileOnly(tempFile)
+      try {
+        handle.write(data, 0, data.length)
+        handle.finishWrite()
+        body(handle, tempFile)
+      } finally {
+        handle.close()
+      }
+    } finally {
+      // Reached on every exit path, including a failed handle creation or a throwing close().
+      if (tempFile.exists()) {
+        tempFile.delete()
+      }
+    }
+  }
+
test("FILE_ONLY mode: write and read") {
val tempFile = File.createTempFile("test-file-only-", ".tmp")
@@ -652,4 +672,71 @@ class SpillablePartialFileHandleSuite extends AnyFunSuite with BeforeAndAfterEac
assert(handle.isSpilled)
}
}
+
+ // ----- Shuffle read-lease lifecycle: acquireRead / releaseRead / deferred close --------------
+
+ // Two leases, one close request: the physical close must wait for the second release.
+ test("read lease: close defers until the last lease is released") {
+ withFinishedFileOnlyHandle() { (handle, tempFile) =>
+ handle.acquireRead()
+ handle.acquireRead()
+
+ // Close requested while two leases are held: must defer.
+ handle.close()
+ assert(!handle.isPhysicallyClosed, "handle must stay open while leases are held")
+ assert(tempFile.exists(), "backing file must not be deleted while leases are held")
+
+ // Dropping one of the two leases is not enough to trigger the deferred close.
+ handle.releaseRead()
+ assert(!handle.isPhysicallyClosed, "handle must stay open while one lease remains")
+
+ // The final release runs the deferred physical close, which removes the backing file.
+ handle.releaseRead()
+ assert(handle.isPhysicallyClosed, "handle must close once the last lease is released")
+ assert(!tempFile.exists(), "FILE_ONLY close must delete the backing file")
+ }
+ }
+
+ // close() may be called more than once while a lease is outstanding; none of those calls
+ // may free the handle.
+ test("read lease: repeated close() while a lease is held stays deferred") {
+ withFinishedFileOnlyHandle() { (handle, tempFile) =>
+ handle.acquireRead()
+
+ // Multiple owners can request close while a reader is active: the catalog's
+ // unregisterShuffle and the spill store's shutdown both call handle.close(). close() is a
+ // close-request, not a refcount decrement, so repeated calls stay idempotent and must not
+ // free the handle while the lease is outstanding.
+ handle.close()
+ handle.close()
+ assert(!handle.isPhysicallyClosed,
+ "repeated close() must not free the handle while a lease is held")
+ assert(tempFile.exists(), "backing file must survive repeated close() while a lease is held")
+
+ // Releasing the only lease performs the single deferred close.
+ handle.releaseRead()
+ assert(handle.isPhysicallyClosed, "handle closes once the lease is released")
+ assert(!tempFile.exists(), "FILE_ONLY close must delete the backing file")
+ }
+ }
+
+ // Without any lease, close() performs the physical close on the spot.
+ test("read lease: close with no active leases closes immediately") {
+ withFinishedFileOnlyHandle() { (handle, tempFile) =>
+ handle.close()
+ assert(handle.isPhysicallyClosed, "close with no leases must close immediately")
+ assert(!tempFile.exists())
+ // The helper calls handle.close() again afterwards, so this test also runs a second
+ // close() on an already-closed handle.
+ }
+ }
+
+ // Once the handle is physically closed, taking a new lease must be rejected.
+ test("read lease: acquireRead after close throws") {
+ withFinishedFileOnlyHandle() { (handle, _) =>
+ // No leases are held, so this close is immediate rather than deferred.
+ handle.close()
+ assert(handle.isPhysicallyClosed)
+ assertThrows[IllegalStateException] {
+ handle.acquireRead()
+ }
+ }
+ }
+
+ // An unbalanced release must throw IllegalStateException.
+ test("read lease: releaseRead without acquireRead throws") {
+ withFinishedFileOnlyHandle() { (handle, _) =>
+ assertThrows[IllegalStateException] {
+ handle.releaseRead()
+ }
+ }
+ }
}