Skip to content

Commit

Permalink
Merge pull request #4264 from armanbilge/bug/blocking-starvation
Browse files Browse the repository at this point in the history
Transfer tick state when replacing a blocked thread
  • Loading branch information
djspiewak authored Feb 5, 2025
2 parents 6766177 + 3ce8193 commit 65c5ba8
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ private[effect] final class WorkStealingThreadPool[P <: AnyRef](
val thread =
new WorkerThread(
index,
0,
queue,
parkedSignal,
externalQueue,
Expand All @@ -174,6 +175,7 @@ private[effect] final class WorkStealingThreadPool[P <: AnyRef](
system,
poller,
metrics,
new WorkerThread.TransferState,
this)

workerThreads.set(i, thread)
Expand Down
47 changes: 29 additions & 18 deletions core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import java.lang.Long.MIN_VALUE
import java.util.concurrent.{ArrayBlockingQueue, ThreadLocalRandom}
import java.util.concurrent.atomic.AtomicBoolean

import WorkerThread.Metrics
import WorkerThread.{Metrics, TransferState}

/**
* Implementation of the worker thread at the heart of the [[WorkStealingThreadPool]].
Expand All @@ -45,6 +45,7 @@ import WorkerThread.Metrics
*/
private[effect] final class WorkerThread[P <: AnyRef](
idx: Int,
private[this] var tick: Int,
// Local queue instance with exclusive write access.
private[this] var queue: LocalQueue,
// The state of the `WorkerThread` (parked/unparked).
Expand All @@ -58,6 +59,7 @@ private[effect] final class WorkerThread[P <: AnyRef](
private[this] val system: PollingSystem.WithPoller[P],
private[this] var _poller: P,
private[this] var metrics: Metrics,
private[this] var transferState: TransferState,
// Reference to the `WorkStealingThreadPool` in which this thread operates.
pool: WorkStealingThreadPool[P])
extends Thread
Expand Down Expand Up @@ -109,7 +111,7 @@ private[effect] final class WorkerThread[P <: AnyRef](
*/
private[this] var _active: Runnable = _

private val indexTransfer: ArrayBlockingQueue[Integer] = new ArrayBlockingQueue(1)
private val stateTransfer: ArrayBlockingQueue[TransferState] = new ArrayBlockingQueue(1)
private[this] val runtimeBlockingExpiration: Duration = pool.runtimeBlockingExpiration

private[effect] var currentIOFiber: IOFiber[?] = _
Expand Down Expand Up @@ -313,8 +315,6 @@ private[effect] final class WorkerThread[P <: AnyRef](
random = ThreadLocalRandom.current()
val rnd = random

var iteration = 0

val done = pool.done

/*
Expand Down Expand Up @@ -726,15 +726,16 @@ private[effect] final class WorkerThread[P <: AnyRef](
_active = null
_poller = null.asInstanceOf[P]
metrics = null
transferState = null

// Add this thread to the cached threads data structure, to be picked up
// by another thread in the future.
pool.cachedThreads.add(this)
try {
val len = runtimeBlockingExpiration.length
val unit = runtimeBlockingExpiration.unit
var newIdx: Integer = indexTransfer.poll(len, unit)
if (newIdx eq null) {
var newState = stateTransfer.poll(len, unit)
if (newState eq null) {
// The timeout elapsed and no one woke up this thread. Try to remove
// the thread from the cached threads data structure.
if (pool.cachedThreads.remove(this)) {
Expand All @@ -745,12 +746,12 @@ private[effect] final class WorkerThread[P <: AnyRef](
// Someone else concurrently stole this thread from the cached
// data structure and will transfer the data soon. Time to wait
// for it again.
newIdx = indexTransfer.take()
init(newIdx)
newState = stateTransfer.take()
init(newState)
}
} else {
// Some other thread woke up this thread. Time to take its place.
init(newIdx)
init(newState)
}
} catch {
case _: InterruptedException =>
Expand All @@ -759,13 +760,9 @@ private[effect] final class WorkerThread[P <: AnyRef](
// exit.
return
}

// Reset the state of the thread for resumption.
blocking = false
iteration = 1
}

((iteration & ExternalQueueTicksMask): @switch) match {
((tick & ExternalQueueTicksMask): @switch) match {
case 0 =>
if (pool.blockedThreadDetectionEnabled) {
// TODO prefetch pool.workerThread or Thread.State.BLOCKED ?
Expand Down Expand Up @@ -875,7 +872,7 @@ private[effect] final class WorkerThread[P <: AnyRef](
// Continue executing fibers from the local queue.
}

iteration += 1
tick += 1
}
}

Expand Down Expand Up @@ -948,7 +945,9 @@ private[effect] final class WorkerThread[P <: AnyRef](
val idx = index
pool.replaceWorker(idx, cached)
// Transfer the data structures to the cached thread and wake it up.
val _ = cached.indexTransfer.offer(idx)
transferState.index = idx
transferState.tick = tick + 1
val _ = cached.stateTransfer.offer(transferState)
} else {
// Spawn a new `WorkerThread`, a literal clone of this one. It is safe to
// transfer ownership of the local queue and the parked signal to the new
Expand All @@ -964,6 +963,7 @@ private[effect] final class WorkerThread[P <: AnyRef](
val clone =
new WorkerThread(
idx,
tick + 1,
queue,
parked,
external,
Expand All @@ -972,6 +972,7 @@ private[effect] final class WorkerThread[P <: AnyRef](
system,
_poller,
metrics,
transferState,
pool)
// Make sure the clone gets our old name:
val clonePrefix = pool.threadPrefix
Expand All @@ -995,18 +996,23 @@ private[effect] final class WorkerThread[P <: AnyRef](
thunk
}

private[this] def init(newIdx: Int): Unit = {
private[this] def init(newState: TransferState): Unit = {
val newIdx = newState.index
_index = newIdx
tick = newState.tick
queue = pool.localQueues(newIdx)
sleepers = pool.sleepers(newIdx)
parked = pool.parkedSignals(newIdx)
fiberBag = pool.fiberBags(newIdx)
_poller = pool.pollers(newIdx)
metrics = pool.metrices(newIdx)
transferState = newState

// Reset the name of the thread to the regular prefix.
val prefix = pool.threadPrefix
setName(s"$prefix-$newIdx")
setName(s"$prefix-${_index}")

blocking = false
}

/**
Expand All @@ -1026,6 +1032,11 @@ private[effect] final class WorkerThread[P <: AnyRef](

private[effect] object WorkerThread {

private[unsafe] final class TransferState {
var index: Int = _
var tick: Int = _
}

final class Metrics {
private[this] var idleTime: Long = 0
def getIdleTime(): Long = idleTime
Expand Down
23 changes: 23 additions & 0 deletions tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,29 @@ trait IOPlatformSpecification extends DetectPlatform { self: BaseSpec with Scala
}
}

"blocking work does not starve poll" in {

val (pool, poller, shutdown) = IORuntime.createWorkStealingComputeThreadPool(
threads = 1,
pollingSystem = DummySystem)

implicit val runtime: IORuntime =
IORuntime.builder().setCompute(pool, shutdown).addPoller(poller, () => ()).build()

try {
def mkBlockingWork: IO[Unit] = IO.defer(mkBlockingWork.start) *> IO.blocking(())

val test = mkBlockingWork *>
IO.pollers.map(_.head.asInstanceOf[DummyPoller]).flatMap { poller =>
poller.poll.replicateA_(100).as(true)
}

test.unsafeRunTimed(1.second) must beSome(beTrue)
} finally {
runtime.shutdown()
}
}

if (javaMajorVersion >= 21)
"block in-place on virtual threads" in real {
val loomExec = classOf[Executors]
Expand Down

0 comments on commit 65c5ba8

Please sign in to comment.