diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MultithreadedShuffleBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MultithreadedShuffleBufferCatalog.scala index 35fd66d93ac..ad7f24cdf6d 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MultithreadedShuffleBufferCatalog.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MultithreadedShuffleBufferCatalog.scala @@ -24,6 +24,7 @@ import java.util.HashSet import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable.ArrayBuffer +import scala.util.control.NonFatal import _root_.io.netty.handler.stream.ChunkedStream import com.nvidia.spark.rapids.spill.SpillablePartialFileHandle @@ -45,6 +46,72 @@ case class PartitionSegment( offset: Long, length: Long) +/** + * Owns a temporary read lease on one or more partial shuffle file handles. + * + * A lease is held while a buffer, stream, or file region may still read the handles. Closing the + * lease releases every handle exactly once; each handle defers its physical close until its last + * lease is released (see `SpillablePartialFileHandle.acquireRead`/`releaseRead`). + */ +private[rapids] final class ShuffleHandleLease(handles: Seq[SpillablePartialFileHandle]) + extends AutoCloseable { + private var released: Boolean = false + + override def close(): Unit = { + val handlesToRelease = synchronized { + if (released) { + Seq.empty + } else { + released = true + handles + } + } + ShuffleHandleLease.releaseAll(handlesToRelease.reverseIterator) + } +} + +private[rapids] object ShuffleHandleLease { + /** Acquire a read lease on every handle, rolling back the partial set if any acquire fails. */ + def acquire(handles: Seq[SpillablePartialFileHandle]): ShuffleHandleLease = { + val retained = new ArrayBuffer[SpillablePartialFileHandle](handles.size) + try { + handles.foreach { handle => + handle.acquireRead() + retained += handle + } + new ShuffleHandleLease(retained.toSeq) + } catch { + case t: Throwable => + try { + releaseAll(retained.reverseIterator) + } catch { + case releaseFailure: Throwable => + t.addSuppressed(releaseFailure) + } + throw t + } + } + + private def releaseAll(handles: Iterator[SpillablePartialFileHandle]): Unit = { + var firstFailure: Throwable = null + handles.foreach { handle => + try { + handle.releaseRead() + } catch { + case t: Throwable => + if (firstFailure == null) { + firstFailure = t + } else { + firstFailure.addSuppressed(t) + } + } + } + if (firstFailure != null) { + throw firstFailure + } + } +} + /** * Catalog for managing shuffle data in MULTITHREADED mode without merging. * @@ -217,19 +284,23 @@ class MultithreadedShuffleBufferCatalog extends Logging { numForcedFileOnly += 1 } + // Drop catalog ownership; retained buffers, streams, and file regions keep the handle + // alive through their read leases, so the physical close is deferred until the last + // lease is released. close() propagates failures, so catch here so one bad handle + // does not abort cleanup of the rest. try { handle.close() } catch { - case e: Exception => - logError(s"Failed to close handle for shuffle $shuffleId", e) + case NonFatal(e) => + logWarning(s"Failed to request close of handle for shuffle $shuffleId", e) } } } } } - logDebug(s"Unregistered shuffle $shuffleId: closed ${closedHandles.size()} handles, " + - s"bytesFromMemory=$bytesFromMemory, bytesFromDisk=$bytesFromDisk, " + + logDebug(s"Unregistered shuffle $shuffleId: cleanup requested for ${closedHandles.size()} " + + s"handles, bytesFromMemory=$bytesFromMemory, bytesFromDisk=$bytesFromDisk, " + s"numExpansions=$numExpansions, numSpills=$numSpills, numForcedFileOnly=$numForcedFileOnly") // Return statistics if we had any data @@ -252,50 +323,79 @@ class MultithreadedShuffleBufferCatalog extends Logging { */ class MultiBatchManagedBuffer(segments: Seq[PartitionSegment]) extends ManagedBuffer { + private val handles: Seq[SpillablePartialFileHandle] = segments.map(_.handle).distinct + + /** Guards bufferLeases while retain()/release() can be called from different threads. */ + private val retainLock = new Object + + /** Leases that keep this buffer's partial shuffle file handles open after retain(). */ + private val bufferLeases = new ArrayBuffer[ShuffleHandleLease]() + override def size(): Long = segments.map(_.length).sum override def nioByteBuffer(): ByteBuffer = { - // This method loads all data into memory. It's required by the ManagedBuffer interface - // but is NOT used in the network transfer path - Spark's network layer uses - // convertToNetty() which returns our streaming MultiSegmentFileRegion. - // This method may be called by other code paths (e.g., local block reading). - val totalSize = size().toInt - val buffer = ByteBuffer.allocate(totalSize) - val bytes = new Array[Byte](8192) // Read buffer - - segments.foreach { segment => - var remaining = segment.length - var position = segment.offset - while (remaining > 0) { - val toRead = math.min(remaining, bytes.length).toInt - val bytesRead = segment.handle.readAt(position, bytes, 0, toRead) - if (bytesRead <= 0) { - throw new IOException( - s"Unexpected EOF reading segment at position $position, " + - s"expected ${segment.length} bytes") + val lease = ShuffleHandleLease.acquire(handles) + try { + // This method loads all data into memory. It's required by the ManagedBuffer interface + // but is NOT used in the network transfer path - Spark's network layer uses + // convertToNetty() which returns our streaming MultiSegmentFileRegion. + // This method may be called by other code paths (e.g., local block reading). + val totalSize = size().toInt + val buffer = ByteBuffer.allocate(totalSize) + val bytes = new Array[Byte](8192) // Read buffer + + segments.foreach { segment => + var remaining = segment.length + var position = segment.offset + while (remaining > 0) { + val toRead = math.min(remaining, bytes.length).toInt + val bytesRead = segment.handle.readAt(position, bytes, 0, toRead) + if (bytesRead <= 0) { + throw new IOException( + s"Unexpected EOF reading segment at position $position, " + + s"expected ${segment.length} bytes") + } + buffer.put(bytes, 0, bytesRead) + position += bytesRead + remaining -= bytesRead } - buffer.put(bytes, 0, bytesRead) - position += bytesRead - remaining -= bytesRead } - } - buffer.flip() - buffer + buffer.flip() + buffer + } finally { + lease.close() + } } override def createInputStream(): InputStream = { - new MultiSegmentInputStream(segments) + new MultiSegmentInputStream(segments, handles) } - override def retain(): ManagedBuffer = this + override def retain(): ManagedBuffer = { + val lease = ShuffleHandleLease.acquire(handles) + retainLock.synchronized { + bufferLeases += lease + } + this + } - override def release(): ManagedBuffer = this + override def release(): ManagedBuffer = { + val lease = retainLock.synchronized { + if (bufferLeases.nonEmpty) { + Some(bufferLeases.remove(bufferLeases.size - 1)) + } else { + None + } + } + lease.foreach(_.close()) + this + } override def convertToNetty(): AnyRef = { // Return a custom FileRegion that streams data in chunks, avoiding loading all // data into memory at once. This addresses concerns about large shuffle blocks. - new MultiSegmentFileRegion(segments) + new MultiSegmentFileRegion(segments, handles) } // Spark 4.0+ adds convertToNettyForSsl() abstract method. @@ -314,12 +414,20 @@ class MultiBatchManagedBuffer(segments: Seq[PartitionSegment]) extends ManagedBu /** * An InputStream that reads from multiple partition segments sequentially. + * + * This stream is not thread-safe; callers should create one stream per reading thread and close it + * from that owner thread. */ -class MultiSegmentInputStream(segments: Seq[PartitionSegment]) extends InputStream { +class MultiSegmentInputStream( + segments: Seq[PartitionSegment], + handles: Seq[SpillablePartialFileHandle]) extends InputStream { private var currentSegmentIndex: Int = 0 private var currentPosition: Long = if (segments.nonEmpty) segments.head.offset else 0 private var bytesReadInCurrentSegment: Long = 0 + // Keeps the partial shuffle file handles open until this stream is closed. + private val lease = ShuffleHandleLease.acquire(handles) + private var closed: Boolean = false override def read(): Int = { val buf = new Array[Byte](1) @@ -328,6 +436,10 @@ class MultiSegmentInputStream(segments: Seq[PartitionSegment]) extends InputStre } override def read(b: Array[Byte], off: Int, len: Int): Int = { + if (closed) { + throw new IOException("Stream is closed") + } + // Use loop instead of recursion to avoid StackOverflowError with many segments while (currentSegmentIndex < segments.size) { val segment = segments(currentSegmentIndex) @@ -358,7 +470,7 @@ class MultiSegmentInputStream(segments: Seq[PartitionSegment]) extends InputStre } override def available(): Int = { - if (currentSegmentIndex >= segments.size) { + if (closed || currentSegmentIndex >= segments.size) { 0 } else { val remaining = segments.drop(currentSegmentIndex).map { seg => @@ -372,13 +484,11 @@ class MultiSegmentInputStream(segments: Seq[PartitionSegment]) extends InputStre } } - /** - * Close is a no-op because the underlying SpillablePartialFileHandle resources - * are managed by MultithreadedShuffleBufferCatalog and will be closed when - * the shuffle is unregistered. - */ override def close(): Unit = { - // No-op: handles are managed by MultithreadedShuffleBufferCatalog + if (!closed) { + closed = true + lease.close() + } } } @@ -392,8 +502,12 @@ class MultiSegmentInputStream(segments: Seq[PartitionSegment]) extends InputStre * Spark's MessageWithHeader only accepts ByteBuf or FileRegion. By implementing * FileRegion, we can provide streaming transfer while remaining compatible with * Spark's network layer. + * + * Instances are not thread-safe; each transfer should use one FileRegion owned by one Netty write. */ -class MultiSegmentFileRegion(segments: Seq[PartitionSegment]) extends AbstractFileRegion { +class MultiSegmentFileRegion( + segments: Seq[PartitionSegment], + handles: Seq[SpillablePartialFileHandle]) extends AbstractFileRegion { private val totalSize: Long = segments.map(_.length).sum private var totalTransferred: Long = 0 @@ -407,6 +521,8 @@ class MultiSegmentFileRegion(segments: Seq[PartitionSegment]) extends AbstractFi // Track current position within the logical data stream private var currentSegmentIndex: Int = 0 private var bytesTransferredInCurrentSegment: Long = 0 + // Keeps the partial shuffle file handles open until this file region is deallocated. + private val lease = ShuffleHandleLease.acquire(handles) override def count(): Long = totalSize @@ -475,6 +591,6 @@ class MultiSegmentFileRegion(segments: Seq[PartitionSegment]) extends AbstractFi } override protected def deallocate(): Unit = { - // No resources to release - handles are managed by MultithreadedShuffleBufferCatalog + lease.close() } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandle.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandle.scala index 52dfd286773..f4584339cc5 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandle.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandle.scala @@ -20,6 +20,8 @@ import java.io.{BufferedInputStream, BufferedOutputStream, File, FileInputStream import java.nio.ByteBuffer import java.nio.channels.FileChannel +import scala.util.control.NonFatal + import ai.rapids.cudf.HostMemoryBuffer import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.HostAlloc @@ -95,6 +97,12 @@ class SpillablePartialFileHandle private ( @volatile private var expansionCount: Int = 0 @volatile private var spillCount: Int = 0 + // Shuffle read-lease reference counting. Active consumers (retained buffers, input streams, + // Netty file regions) hold leases so the physical close can be deferred until they finish. + // Guarded by this handle's monitor; `closed` (from StoreHandle) marks the physical close. + private var readRefCount: Int = 0 + private var closeRequested: Boolean = false + // Write state private var writePosition: Long = 0L private var fileOutputStream: Option[FileOutputStream] = None @@ -725,64 +733,169 @@ class SpillablePartialFileHandle private ( } } + // ----- Shuffle read-lease lifecycle ---------------------------------------------------------- + // Shuffle cleanup (`MultithreadedShuffleBufferCatalog.unregisterShuffle`) can request a close + // while retained buffers, input streams, or Netty file regions are still reading this handle. + // `acquireRead`/`releaseRead` reference-count those active consumers, and `close()` requests the + // physical close but defers it until the last lease is released. The physical close is + // `doClose()`, which coordinates with an in-progress spill via `spillInProgress` (this handle + // does not use the base `spilling` flag). + + /** + * Acquire a read lease. Throws if the handle has already been physically closed. + */ + private[rapids] def acquireRead(): Unit = synchronized { + if (closed) { + throw new IllegalStateException( + "Cannot acquire a read lease on a closed partial file handle") + } + readRefCount += 1 + } + + /** + * Release a read lease, performing the deferred physical close if this was the last one. This + * runs on consumer cleanup paths (Netty file-region deallocation, stream close, finally blocks), + * so a close failure is logged and swallowed rather than propagated. + */ + private[rapids] def releaseRead(): Unit = { + val performClose = synchronized { + if (readRefCount <= 0) { + throw new IllegalStateException("releaseRead() without a matching acquireRead()") + } + readRefCount -= 1 + markCloseIfReady() + } + if (performClose) { + closeQuietly() + } + } + + /** + * True once the handle has been physically closed (distinct from a pending close request). + */ + private[rapids] def isPhysicallyClosed: Boolean = synchronized { closed } + + /** + * Request close of this handle. If no read leases are active it closes immediately; otherwise + * the physical close is deferred until the last `releaseRead()`. A close failure here + * propagates to the caller, per the normal `AutoCloseable` convention (the catalog and + * writer/merge call sites wrap it). The quiet path is `releaseRead()`, used on consumer cleanup. + */ + override def close(): Unit = { + val performClose = synchronized { + closeRequested = true + markCloseIfReady() + } + if (performClose) { + doClose() + } + } + + /** + * Decide, under the handle monitor, whether the physical close should run now, setting `closed` + * atomically with the decision so a concurrent `acquireRead` cannot slip in after we commit to + * closing. Returns true iff the caller should run the physical close outside the lock. + */ + private def markCloseIfReady(): Boolean = { + if (closeRequested && readRefCount == 0 && !closed) { + closed = true + true + } else { + false + } + } + + /** + * Runs the physical close, swallowing failures so a failed close cannot break a consumer's + * release path (Netty file-region deallocation, stream close, finally blocks). Used by + * `releaseRead()`; `close()` propagates instead. + */ + private def closeQuietly(): Unit = { + try { + doClose() + } catch { + case e: InterruptedException => + // Log while the interrupt flag is clear, then restore it as the last action so the + // logging framework's own (possibly interruptible) work isn't disrupted. + logError("Interrupted while closing partial file handle", e) + Thread.currentThread().interrupt() + case NonFatal(e) => + logError("Failed to close partial file handle", e) + } + } + /** * Close and cleanup resources. * This is where we record disk write savings: only if data was never spilled to disk * throughout the entire lifecycle (write phase + read phase), we count it as saved. */ override private[spill] def doClose(): Unit = { - // Collect resources to close under lock, then close them outside lock - val (bos, fos, bis, fis, fc, raf) = synchronized { - // Wait for any in-progress spill to complete before closing buffer - while (spillInProgress) { - wait() - } + var interrupted = false + try { + // Collect resources to close under lock, then close them outside lock + val (bos, fos, bis, fis, fc, raf) = synchronized { + // Wait for any in-progress spill to complete before closing buffer. + // close() has already committed this handle's closed state, so cleanup must finish even + // if this thread is interrupted while waiting. Restore the interrupt flag after cleanup. + while (spillInProgress) { + try { + wait() + } catch { + case _: InterruptedException => + interrupted = true + } + } - // Record disk write savings for ESS + multi-batch merge scenario. - // When ESS is enabled with multiple batches, partial files are merged into - // a final file. If a partial file stayed in memory (not spilled), it avoided - // an intermediate disk write. The task is still running during merge, so - // GpuTaskMetrics.get is valid. - if (storageMode == PartialFileStorageMode.MEMORY_WITH_SPILL && - !spilledToDisk && totalBytesWritten > 0 && !diskWriteSavingsRecorded) { - GpuTaskMetrics.get.addDiskWriteSaved(totalBytesWritten) - diskWriteSavingsRecorded = true - logDebug(s"Recorded disk write savings in doClose: $totalBytesWritten bytes") - } + // Record disk write savings for ESS + multi-batch merge scenario. + // When ESS is enabled with multiple batches, partial files are merged into + // a final file. If a partial file stayed in memory (not spilled), it avoided + // an intermediate disk write. The task is still running during merge, so + // GpuTaskMetrics.get is valid. + if (storageMode == PartialFileStorageMode.MEMORY_WITH_SPILL && + !spilledToDisk && totalBytesWritten > 0 && !diskWriteSavingsRecorded) { + GpuTaskMetrics.get.addDiskWriteSaved(totalBytesWritten) + diskWriteSavingsRecorded = true + logDebug(s"Recorded disk write savings in doClose: $totalBytesWritten bytes") + } - // Collect streams/channels to close - val result = (bufferedOutputStream, fileOutputStream, - bufferedInputStream, fileInputStream, fileChannel, randomAccessFile) + // Collect streams/channels to close + val result = (bufferedOutputStream, fileOutputStream, + bufferedInputStream, fileInputStream, fileChannel, randomAccessFile) - // Clear references - bufferedOutputStream = None - fileOutputStream = None - bufferedInputStream = None - fileInputStream = None - fileChannel = None - randomAccessFile = None + // Clear references + bufferedOutputStream = None + fileOutputStream = None + bufferedInputStream = None + fileInputStream = None + fileChannel = None + randomAccessFile = None - // Release host buffer (removes from SpillFramework tracking and closes buffer) - releaseHostResource() + // Release host buffer (removes from SpillFramework tracking and closes buffer) + releaseHostResource() - result - } + result + } - // Close streams outside lock (IO operations can be slow) - tryClose(bos, "bufferedOutputStream") - tryClose(fos, "fileOutputStream") - tryClose(bis, "bufferedInputStream") - tryClose(fis, "fileInputStream") - tryClose(fc, "fileChannel") - tryClose(raf, "randomAccessFile") - - // Delete file if it exists - if (file != null && file.exists()) { - try { - file.delete() - } catch { - case e: Exception => - logWarning(s"Failed to delete file ${file.getAbsolutePath}", e) + // Close streams outside lock (IO operations can be slow) + tryClose(bos, "bufferedOutputStream") + tryClose(fos, "fileOutputStream") + tryClose(bis, "bufferedInputStream") + tryClose(fis, "fileInputStream") + tryClose(fc, "fileChannel") + tryClose(raf, "randomAccessFile") + + // Delete file if it exists + if (file != null && file.exists()) { + try { + file.delete() + } catch { + case e: Exception => + logWarning(s"Failed to delete file ${file.getAbsolutePath}", e) + } + } + } finally { + if (interrupted) { + Thread.currentThread().interrupt() } } } @@ -860,4 +973,3 @@ object SpillablePartialFileHandle extends Logging { capacityHintProvider = capacityHintProvider) } } - diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/MultithreadedShuffleBufferCatalogSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/MultithreadedShuffleBufferCatalogSuite.scala index 4266654c9f4..66e7bf6bd7a 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/MultithreadedShuffleBufferCatalogSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/MultithreadedShuffleBufferCatalogSuite.scala @@ -16,6 +16,15 @@ package com.nvidia.spark.rapids +import java.io.{File, InputStream} +import java.nio.channels.ClosedChannelException +import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch, TimeUnit} +import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} + +import scala.collection.mutable.ArrayBuffer +import scala.util.control.NonFatal + +import _root_.io.netty.channel.FileRegion import com.nvidia.spark.rapids.spill.SpillablePartialFileHandle import org.mockito.ArgumentMatchers._ import org.mockito.Mockito._ @@ -23,6 +32,7 @@ import org.scalatest.BeforeAndAfterEach import org.scalatest.funsuite.AnyFunSuite import org.scalatestplus.mockito.MockitoSugar +import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.storage.{ShuffleBlockBatchId, ShuffleBlockId} class MultithreadedShuffleBufferCatalogSuite @@ -188,6 +198,86 @@ class MultithreadedShuffleBufferCatalogSuite catalog.unregisterShuffle(1) } + // ------------------------------------------------------------------------------------------ + // Regression test: a retained file-backed (FILE_ONLY) skip-merge shuffle buffer must stay + // readable while shuffle cleanup unregisters the catalog entry. The test retains a buffer, + // starts concurrent stream readers, unregisters the shuffle while reads are in flight, and + // verifies that the active readers do not observe a closed backing channel. + // + // The race depends on cleanup landing while a reader is using the cached FileChannel, so the + // test oversubscribes readers and uses small, frequent reads with bounded retries. + + test("retained skip-merge buffer stays readable across concurrent unregisterShuffle") { + val race = MultithreadedShuffleBufferCatalogSuite.RetainedBufferReadRace + var bug: Option[Throwable] = None + var otherError: Option[Throwable] = None + var sawReadsInFlight = false + var sawReadsAfterUnregister = false + var leakedReaders = 0 + var iteration = 0 + while (bug.isEmpty && iteration < race.MaxIterations) { + iteration += 1 + val result = race.attempt(iteration) + bug = result.target + if (otherError.isEmpty) otherError = result.otherError + sawReadsInFlight ||= result.startedOk && result.readsBeforeUnregister > 0 + sawReadsAfterUnregister ||= result.readsAfterUnregister > 0 + leakedReaders += result.leakedReaders + } + + bug match { + case Some(closed) => + // BUG: an active read was closed underneath the reader. Surface the real exception so the + // failure is the ClosedChannelException from the read path, not a synthetic assertion. + // Expected to FAIL on unmodified `main`. + throw closed + case None => + // Post-fix expectation: confirm the scenario actually ran and that reads SURVIVED + // unregisterShuffle, so a future change cannot make this test pass for the wrong reason. + assert(sawReadsInFlight, + "reproducer never observed reads before unregisterShuffle; scenario did not run") + otherError.foreach(e => throw e) + assert(sawReadsAfterUnregister, + "no reads completed after unregisterShuffle; the retained buffer was not readable") + assert(leakedReaders == 0, s"$leakedReaders reader thread(s) did not stop after the test") + info(s"retained reads survived unregisterShuffle across $iteration iterations") + } + } + + test("convertToNetty release closes retained handle exactly once") { + // The handle owns the close deferral, so a mock can't reproduce it: use a real FILE_ONLY + // handle and observe the physical close via `isPhysicallyClosed`. + val catalog = new MultithreadedShuffleBufferCatalog() + val backingFile = File.createTempFile("skipmerge-region-", ".data") + val handle = SpillablePartialFileHandle.createFileOnly(backingFile) + handle.write(Array.fill[Byte](100)(7.toByte), 0, 100) + handle.finishWrite() + try { + catalog.registerShuffle(1) + catalog.addPartition(1, 0L, 0, handle, 0, 100) + + val buffer = catalog.getMergedBuffer(ShuffleBlockId(1, 0L, 0)) + val region = buffer.convertToNetty().asInstanceOf[FileRegion] + + // Cleanup requests close, but the file region still holds a read lease: close is deferred. + catalog.unregisterShuffle(1) + assert(!handle.isPhysicallyClosed, + "handle must stay open while the file region holds a lease") + + // Releasing the region drops the last lease and runs the deferred physical close. + assert(region.release()) + assert(handle.isPhysicallyClosed, "handle must close once the file region releases its lease") + + // No retained buffer lease here, so releasing the buffer is a no-op and must not re-close. + buffer.release() + assert(handle.isPhysicallyClosed) + } finally { + if (backingFile.exists()) { + backingFile.delete() + } + } + } + private def createMockHandle(): SpillablePartialFileHandle = { val handle = mock[SpillablePartialFileHandle] handle @@ -213,3 +303,201 @@ class MultithreadedShuffleBufferCatalogSuite } } +object MultithreadedShuffleBufferCatalogSuite { + private object RetainedBufferReadRace { + // Oversubscribe readers (bounded) so the close reliably lands in a reader's channel-read path. + private val ReaderThreads: Int = + math.min(128, math.max(64, Runtime.getRuntime.availableProcessors() * 4)) + private val ReadBufferBytes: Int = 4 * 1024 // small reads => very frequent readAt calls + private val BackingFileBytes: Int = 2 * 1024 * 1024 // 2 MB single-segment backing file + private val AwaitSeconds: Long = 30L + private val PostUnregisterMillis: Long = 200L + + val MaxIterations: Int = 20 + + case class Result( + target: Option[Throwable], + otherError: Option[Throwable], + readsBeforeUnregister: Long, + readsAfterUnregister: Long, + startedOk: Boolean, + leakedReaders: Int) + + /** Runs one race iteration and reports what the readers saw. */ + def attempt(iteration: Int): Result = { + val shuffleId = iteration + val mapId = 0L + val reduceId = 0 + val backingFile = File.createTempFile("skipmerge-catalog-repro-", ".data") + + val catalog = new MultithreadedShuffleBufferCatalog() + val handle = SpillablePartialFileHandle.createFileOnly(backingFile) + val stop = new AtomicBoolean(false) + val afterUnregister = new AtomicBoolean(false) + val started = new CountDownLatch(ReaderThreads) + val readsBefore = new AtomicLong(0L) + val readsAfter = new AtomicLong(0L) + val readerErrors = new ConcurrentLinkedQueue[Throwable]() + val readers = new ArrayBuffer[Thread](ReaderThreads) + var buffer: ManagedBuffer = null + var startedOk = false + var leakedReaders = 0 + var targetBeforeCleanup: Option[Throwable] = None + var otherErrorBeforeCleanup: Option[Throwable] = None + + try { + // Publish a multi-MB file-only handle as a single whole-file segment, like the writer does. + writeBackingData(handle, BackingFileBytes) + handle.finishWrite() + catalog.registerShuffle(shuffleId) + catalog.addPartition(shuffleId, mapId, reduceId, handle, 0L, BackingFileBytes.toLong) + + // A reducer retains the buffer before handing it to readers; this retained lease is the + // lifecycle guarantee the regression test protects. + buffer = catalog.getMergedBuffer(ShuffleBlockId(shuffleId, mapId, reduceId)) + buffer.retain() + val readBuffer = buffer + + (0 until ReaderThreads).foreach { idx => + // A Runnable (not a Thread subclass) so the local `stop` flag is not shadowed by the + // inherited Thread.stop() member. + val body = new Runnable { + override def run(): Unit = { + var in: InputStream = readBuffer.createInputStream() + val buf = new Array[Byte](ReadBufferBytes) + started.countDown() + try { + while (!stop.get()) { + val n = in.read(buf, 0, buf.length) + if (n < 0) { + in.close() + in = readBuffer.createInputStream() + } else if (afterUnregister.get()) { + readsAfter.incrementAndGet() + } else { + readsBefore.incrementAndGet() + } + } + } catch { + case NonFatal(t) => readerErrors.add(t) + } finally { + try { in.close() } catch { case NonFatal(_) => () } + } + } + } + val reader = new Thread(body, s"skipmerge-catalog-reader-$iteration-$idx") + reader.setDaemon(true) + readers += reader + reader.start() + } + + // Make sure reads are flowing before the close, so it lands in a reader's read path. + startedOk = started.await(AwaitSeconds, TimeUnit.SECONDS) && + awaitReadsInFlight(readsBefore, ReaderThreads.toLong) + + // Cleanup thread closes the handle underneath the active readers; then give readers a + // short window to read the retained buffer post-unregister (on `main` they hit the + // closed channel; once fixed they keep reading, counted in readsAfter). + catalog.unregisterShuffle(shuffleId) + afterUnregister.set(true) + awaitReadsAfterUnregister(readsAfter, readers, ReaderThreads.toLong) + } finally { + stop.set(true) + leakedReaders = joinAll(readers) + targetBeforeCleanup = firstMatching(readerErrors, isClosedChannelFromReadPath) + otherErrorBeforeCleanup = + firstMatching(readerErrors, t => !isClosedChannelFromReadPath(t)) + // Release the retained buffer before the last-resort handle.close() guard, giving the + // catalog's deferred close path the first chance to close the handle. + if (buffer != null) { + try { buffer.release() } catch { case NonFatal(_) => () } + } + try { handle.close() } catch { case NonFatal(_) => () } + if (backingFile.exists()) { + backingFile.delete() + } + } + + Result( + target = targetBeforeCleanup, + otherError = otherErrorBeforeCleanup, + readsBeforeUnregister = readsBefore.get(), + readsAfterUnregister = readsAfter.get(), + startedOk = startedOk, + leakedReaders = leakedReaders) + } + + private def writeBackingData(handle: SpillablePartialFileHandle, totalBytes: Int): Unit = { + val chunk = new Array[Byte](1024 * 1024) + var i = 0 + while (i < chunk.length) { + chunk(i) = (i & 0xFF).toByte + i += 1 + } + var written = 0 + while (written < totalBytes) { + val toWrite = math.min(chunk.length, totalBytes - written) + handle.write(chunk, 0, toWrite) + written += toWrite + } + } + + /** Waits until reads are flowing; returns true if the target read count was reached in time. */ + private def awaitReadsInFlight(reads: AtomicLong, target: Long): Boolean = { + val deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(AwaitSeconds) + while (reads.get() < target && System.nanoTime() < deadline) { + Thread.sleep(1L) + } + reads.get() >= target + } + + /** Lets readers attempt reads after unregister; stops early once enough succeed or all exit. */ + private def awaitReadsAfterUnregister( + readsAfter: AtomicLong, readers: ArrayBuffer[Thread], target: Long): Unit = { + val deadline = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(PostUnregisterMillis) + while (readsAfter.get() < target && + System.nanoTime() < deadline && readers.exists(_.isAlive)) { + Thread.sleep(1L) + } + } + + /** Stops/joins all readers under one shared deadline; returns the count still alive. */ + private def joinAll(readers: ArrayBuffer[Thread]): Int = { + val deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(AwaitSeconds) + readers.foreach { reader => + val remainingMs = TimeUnit.NANOSECONDS.toMillis(deadline - System.nanoTime()) + if (remainingMs > 0) { + try { + reader.join(remainingMs) + } catch { + case _: InterruptedException => Thread.currentThread().interrupt() + } + } + } + readers.count(_.isAlive) + } + + private def firstMatching( + errors: ConcurrentLinkedQueue[Throwable], p: Throwable => Boolean): Option[Throwable] = { + var found: Option[Throwable] = None + val it = errors.iterator() + while (found.isEmpty && it.hasNext) { + val t = it.next() + if (p(t)) { + found = Some(t) + } + } + found + } + + private def isClosedChannelFromReadPath(t: Throwable): Boolean = { + // AsynchronousCloseException (channel closed while a read is in flight) is a subclass of + // ClosedChannelException, so this covers both. + t.isInstanceOf[ClosedChannelException] && t.getStackTrace.exists { frame => + frame.getClassName.endsWith("SpillablePartialFileHandle") && + (frame.getMethodName.contains("readFromFileChannel") || + frame.getMethodName.contains("readAt")) + } + } + } +} diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandleSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandleSuite.scala index 7f0b59ca059..0edbb1b7a38 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandleSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandleSuite.scala @@ -18,6 +18,7 @@ package com.nvidia.spark.rapids.spill import java.io.File import java.util.Arrays +import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.RapidsConf @@ -28,6 +29,11 @@ class SpillablePartialFileHandleSuite extends AnyFunSuite with BeforeAndAfterEac // Use 1GB max buffer size for tests to avoid memory issues on test machines private val testMaxBufferSize = 1L * 1024 * 1024 * 1024 + private val spillInProgressField = { + val field = classOf[SpillablePartialFileHandle].getDeclaredField("spillInProgress") + field.setAccessible(true) + field + } override def beforeEach(): Unit = { super.beforeEach() @@ -42,6 +48,38 @@ class SpillablePartialFileHandleSuite extends AnyFunSuite with BeforeAndAfterEac super.afterEach() } + private def setSpillInProgress(handle: SpillablePartialFileHandle, value: Boolean): Unit = { + spillInProgressField.setBoolean(handle, value) + } + + private def waitUntil(condition: => Boolean, clue: String): Unit = { + val deadline = System.currentTimeMillis() + 5000 + while (!condition && System.currentTimeMillis() < deadline) { + Thread.sleep(10) + } + assert(condition, clue) + } + + /** + * Runs `body` against a finished FILE_ONLY handle backed by a fresh temp file, then closes the + * handle and deletes the file. The temp file is exposed so tests can assert on its deletion. + */ + private def withFinishedFileOnlyHandle(data: Array[Byte] = "lease".getBytes("UTF-8"))( + body: (SpillablePartialFileHandle, File) => Unit): Unit = { + val tempFile = File.createTempFile("test-lease-", ".tmp") + val handle = SpillablePartialFileHandle.createFileOnly(tempFile) + try { + handle.write(data, 0, data.length) + handle.finishWrite() + body(handle, tempFile) + } finally { + handle.close() + if (tempFile.exists()) { + tempFile.delete() + } + } + } + test("FILE_ONLY mode: write and read") { val tempFile = File.createTempFile("test-file-only-", ".tmp") @@ -222,6 +260,68 @@ class SpillablePartialFileHandleSuite extends AnyFunSuite with BeforeAndAfterEac } } + test("MEMORY_WITH_SPILL mode: interrupted close still completes cleanup") { + val tempFile = File.createTempFile("test-interrupted-close-", ".tmp") + val handle = SpillablePartialFileHandle.createMemoryWithSpill( + initialCapacity = 1024, + maxBufferSize = testMaxBufferSize, + memoryThreshold = 0.5, + spillFile = tempFile) + var closeThread: Thread = null + + try { + val testData = "data retained until close".getBytes("UTF-8") + handle.write(testData, 0, testData.length) + handle.finishWrite() + assert(handle.host.nonEmpty) + assert(tempFile.exists()) + + handle.synchronized { + setSpillInProgress(handle, true) + } + + val closeReturnedWithInterrupt = new AtomicBoolean(false) + val closeError = new AtomicReference[Throwable]() + closeThread = new Thread("test-interrupted-partial-file-close") { + override def run(): Unit = { + try { + handle.close() + closeReturnedWithInterrupt.set(Thread.currentThread().isInterrupted) + } catch { + case t: Throwable => closeError.set(t) + } + } + } + + closeThread.start() + waitUntil(closeThread.getState == Thread.State.WAITING, + "close thread did not wait for spillInProgress") + closeThread.interrupt() + + handle.synchronized { + setSpillInProgress(handle, false) + handle.notifyAll() + } + closeThread.join(5000) + + assert(!closeThread.isAlive) + assert(closeError.get() == null) + assert(closeReturnedWithInterrupt.get()) + assert(handle.host.isEmpty) + assert(!tempFile.exists()) + } finally { + handle.synchronized { + setSpillInProgress(handle, false) + handle.notifyAll() + } + if (closeThread != null && closeThread.isAlive) { + closeThread.join(5000) + } + handle.close() + tempFile.delete() + } + } + test("MEMORY_WITH_SPILL mode: sequential write with single bytes") { val tempFile = File.createTempFile("test-single-bytes-", ".tmp") @@ -572,5 +672,71 @@ class SpillablePartialFileHandleSuite extends AnyFunSuite with BeforeAndAfterEac assert(handle.isSpilled) } } -} + // ----- Shuffle read-lease lifecycle: acquireRead / releaseRead / deferred close -------------- + + test("read lease: close defers until the last lease is released") { + withFinishedFileOnlyHandle() { (handle, tempFile) => + handle.acquireRead() + handle.acquireRead() + + // Close requested while two leases are held: must defer. + handle.close() + assert(!handle.isPhysicallyClosed, "handle must stay open while leases are held") + assert(tempFile.exists(), "backing file must not be deleted while leases are held") + + handle.releaseRead() + assert(!handle.isPhysicallyClosed, "handle must stay open while one lease remains") + + handle.releaseRead() + assert(handle.isPhysicallyClosed, "handle must close once the last lease is released") + assert(!tempFile.exists(), "FILE_ONLY close must delete the backing file") + } + } + + test("read lease: repeated close() while a lease is held stays deferred") { + withFinishedFileOnlyHandle() { (handle, tempFile) => + handle.acquireRead() + + // Multiple owners can request close while a reader is active: the catalog's + // unregisterShuffle and the spill store's shutdown both call handle.close(). close() is a + // close-request, not a refcount decrement, so repeated calls stay idempotent and must not + // free the handle while the lease is outstanding. + handle.close() + handle.close() + assert(!handle.isPhysicallyClosed, + "repeated close() must not free the handle while a lease is held") + assert(tempFile.exists(), "backing file must survive repeated close() while a lease is held") + + handle.releaseRead() + assert(handle.isPhysicallyClosed, "handle closes once the lease is released") + assert(!tempFile.exists(), "FILE_ONLY close must delete the backing file") + } + } + + test("read lease: close with no active leases closes immediately") { + withFinishedFileOnlyHandle() { (handle, tempFile) => + handle.close() + assert(handle.isPhysicallyClosed, "close with no leases must close immediately") + assert(!tempFile.exists()) + } + } + + test("read lease: acquireRead after close throws") { + withFinishedFileOnlyHandle() { (handle, _) => + handle.close() + assert(handle.isPhysicallyClosed) + assertThrows[IllegalStateException] { + handle.acquireRead() + } + } + } + + test("read lease: releaseRead without acquireRead throws") { + withFinishedFileOnlyHandle() { (handle, _) => + assertThrows[IllegalStateException] { + handle.releaseRead() + } + } + } +}