From bc334e92974700b04ede7e51ad7403ec94671dde Mon Sep 17 00:00:00 2001 From: Dimitri Ho Date: Wed, 20 Nov 2024 19:55:12 +0100 Subject: [PATCH 1/3] Prevent deadlock when publisher is canceled --- .../main/scala/fs2/concurrent/Channel.scala | 2 +- .../src/main/scala/fs2/concurrent/Topic.scala | 2 +- .../scala/fs2/concurrent/ChannelSuite.scala | 15 +++++++++++++ .../scala/fs2/concurrent/TopicSuite.scala | 21 +++++++++++++++++++ 4 files changed, 38 insertions(+), 2 deletions(-) diff --git a/core/shared/src/main/scala/fs2/concurrent/Channel.scala b/core/shared/src/main/scala/fs2/concurrent/Channel.scala index df2f17b54a..0de4c6f56c 100644 --- a/core/shared/src/main/scala/fs2/concurrent/Channel.scala +++ b/core/shared/src/main/scala/fs2/concurrent/Channel.scala @@ -151,7 +151,7 @@ object Channel { new Channel[F, A] { def sendAll: Pipe[F, A, Nothing] = { in => - (in ++ Stream.exec(close.void)) + in.onFinalize(close.void) .evalMap(send) .takeWhile(_.isRight) .drain diff --git a/core/shared/src/main/scala/fs2/concurrent/Topic.scala b/core/shared/src/main/scala/fs2/concurrent/Topic.scala index a7d9bf12c7..2061070b16 100644 --- a/core/shared/src/main/scala/fs2/concurrent/Topic.scala +++ b/core/shared/src/main/scala/fs2/concurrent/Topic.scala @@ -208,7 +208,7 @@ object Topic { } def publish: Pipe[F, A, Nothing] = { in => - (in ++ Stream.exec(close.void)) + in.onFinalize(close.void) .evalMap(publish1) .takeWhile(_.isRight) .drain diff --git a/core/shared/src/test/scala/fs2/concurrent/ChannelSuite.scala b/core/shared/src/test/scala/fs2/concurrent/ChannelSuite.scala index 6e15fffa5e..8e8f8dbf52 100644 --- a/core/shared/src/test/scala/fs2/concurrent/ChannelSuite.scala +++ b/core/shared/src/test/scala/fs2/concurrent/ChannelSuite.scala @@ -323,4 +323,19 @@ class ChannelSuite extends Fs2Suite { racingSendOperations(channel) } + test("stream should terminate when sendAll is interrupted") { + val program = + Channel + .bounded[IO, Unit](1) + .flatMap { ch => + val producer = + Stream + .eval(IO.canceled) + .through(ch.sendAll) + + ch.stream.concurrently(producer).compile.drain + } + + TestControl.executeEmbed(program) // will fail if program is deadlocked + } } diff --git a/core/shared/src/test/scala/fs2/concurrent/TopicSuite.scala b/core/shared/src/test/scala/fs2/concurrent/TopicSuite.scala index 6f731d41eb..c26fd73dd7 100644 --- a/core/shared/src/test/scala/fs2/concurrent/TopicSuite.scala +++ b/core/shared/src/test/scala/fs2/concurrent/TopicSuite.scala @@ -185,4 +185,25 @@ class TopicSuite extends Fs2Suite { TestControl.executeEmbed(program) // will fail if program is deadlocked } + + test("publisher cancellation does not deadlock") { + val program = + Topic[IO, String] + .flatMap { topic => + val publisher = + Stream + .constant("1") + .covary[IO] + .evalTap(_ => IO.canceled) + .through(topic.publish) + + Stream + .resource(topic.subscribeAwait(1)) + .flatMap(subscriber => subscriber.concurrently(publisher)) + .compile + .drain + } + + TestControl.executeEmbed(program) // will fail if program is deadlocked + } } From 09f86e3ee1b4c5fa722bc986648d6608d2513730 Mon Sep 17 00:00:00 2001 From: Dimitri Ho Date: Sun, 1 Dec 2024 02:08:44 +0100 Subject: [PATCH 2/3] Handle error propagation / cancelation in Channel --- .../main/scala/fs2/concurrent/Channel.scala | 50 +++++++++++-------- .../scala/fs2/concurrent/ChannelSuite.scala | 6 ++- 2 files changed, 34 insertions(+), 22 deletions(-) diff --git a/core/shared/src/main/scala/fs2/concurrent/Channel.scala b/core/shared/src/main/scala/fs2/concurrent/Channel.scala index 0de4c6f56c..68e71bcaa9 100644 --- a/core/shared/src/main/scala/fs2/concurrent/Channel.scala +++ b/core/shared/src/main/scala/fs2/concurrent/Channel.scala @@ -24,6 +24,7 @@ package concurrent import cats.effect._ import cats.effect.implicits._ +import cats.effect.Resource.ExitCase import cats.syntax.all._ /** Stream aware, multiple producer, single consumer closeable channel. @@ -138,62 +139,62 @@ object Channel { size: Int, waiting: Option[Deferred[F, Unit]], producers: List[(A, Deferred[F, Unit])], - closed: Boolean + closed: Option[ExitCase] ) - val open = State(List.empty, 0, None, List.empty, closed = false) + val open = State(List.empty, 0, None, List.empty, closed = None) - def empty(isClosed: Boolean): State = - if (isClosed) State(List.empty, 0, None, List.empty, closed = true) + def empty(close: Option[ExitCase]): State = + if (close.nonEmpty) State(List.empty, 0, None, List.empty, closed = close) else open (F.ref(open), F.deferred[Unit]).mapN { (state, closedGate) => new Channel[F, A] { def sendAll: Pipe[F, A, Nothing] = { in => - in.onFinalize(close.void) + in.onFinalizeCase(closeWithExitCase(_).void) .evalMap(send) .takeWhile(_.isRight) .drain } - def sendImpl(a: A, close: Boolean) = + def sendImpl(a: A, close: Option[ExitCase]) = F.deferred[Unit].flatMap { producer => state.flatModifyFull { case (poll, state) => state match { - case s @ State(_, _, _, _, closed @ true) => + case s @ State(_, _, _, _, Some(_)) => (s, Channel.closed[Unit].pure[F]) - case State(values, size, waiting, producers, closed @ false) => + case State(values, size, waiting, producers, None) => if (size < capacity) ( State(a :: values, size + 1, None, producers, close), - signalClosure.whenA(close) *> notifyStream(waiting).as(rightUnit) + signalClosure.whenA(close.nonEmpty) *> notifyStream(waiting).as(rightUnit) ) else ( State(values, size, None, (a, producer) :: producers, close), - signalClosure.whenA(close) *> + signalClosure.whenA(close.nonEmpty) *> notifyStream(waiting).as(rightUnit) <* - waitOnBound(producer, poll).unlessA(close) + waitOnBound(producer, poll).unlessA(close.nonEmpty) ) } } } - def send(a: A) = sendImpl(a, false) + def send(a: A) = sendImpl(a, None) - def closeWithElement(a: A) = sendImpl(a, true) + def closeWithElement(a: A) = sendImpl(a, Some(ExitCase.Succeeded)) def trySend(a: A) = state.flatModify { - case s @ State(_, _, _, _, closed @ true) => + case s @ State(_, _, _, _, Some(_)) => (s, Channel.closed[Boolean].pure[F]) - case s @ State(values, size, waiting, producers, closed @ false) => + case s @ State(values, size, waiting, producers, None) => if (size < capacity) ( - State(a :: values, size + 1, None, producers, false), + State(a :: values, size + 1, None, producers, None), notifyStream(waiting).as(rightTrue) ) else @@ -201,13 +202,16 @@ object Channel { } def close = + closeWithExitCase(ExitCase.Succeeded) + + def closeWithExitCase(exitCase: ExitCase): F[Either[Closed, Unit]] = state.flatModify { - case s @ State(_, _, _, _, closed @ true) => + case s @ State(_, _, _, _, Some(_)) => (s, Channel.closed[Unit].pure[F]) - case State(values, size, waiting, producers, closed @ false) => + case State(values, size, waiting, producers, None) => ( - State(values, size, None, producers, true), + State(values, size, None, producers, Some(exitCase)), notifyStream(waiting).as(rightUnit) <* signalClosure ) } @@ -250,8 +254,12 @@ object Channel { unblock.as(Pull.output(toEmit) >> consumeLoop) } else { F.pure( - if (closed) Pull.done - else Pull.eval(waiting.get) >> consumeLoop + closed match { + case Some(ExitCase.Succeeded) => Pull.done + case Some(ExitCase.Errored(e)) => Pull.raiseError(e) + case Some(ExitCase.Canceled) => Pull.eval(F.canceled) + case None => Pull.eval(waiting.get) >> consumeLoop + } ) } } diff --git a/core/shared/src/test/scala/fs2/concurrent/ChannelSuite.scala b/core/shared/src/test/scala/fs2/concurrent/ChannelSuite.scala index 8e8f8dbf52..c19fdf99a7 100644 --- a/core/shared/src/test/scala/fs2/concurrent/ChannelSuite.scala +++ b/core/shared/src/test/scala/fs2/concurrent/ChannelSuite.scala @@ -29,6 +29,8 @@ import scala.concurrent.duration._ import org.scalacheck.effect.PropF.forAllF +import scala.concurrent.CancellationException + class ChannelSuite extends Fs2Suite { test("receives some simple elements above capacity and closes") { @@ -336,6 +338,8 @@ class ChannelSuite extends Fs2Suite { ch.stream.concurrently(producer).compile.drain } - TestControl.executeEmbed(program) // will fail if program is deadlocked + TestControl + .executeEmbed(program) + .intercept[CancellationException] } } From cf8799fa3ce0239ec95dcfd3c8424691baaf1749 Mon Sep 17 00:00:00 2001 From: Dimitri Ho Date: Thu, 19 Dec 2024 20:25:58 +0100 Subject: [PATCH 3/3] Handle error propagation / cancelation in Topic --- .../main/scala/fs2/concurrent/Channel.scala | 18 ++++++++++++++++++ .../src/main/scala/fs2/concurrent/Topic.scala | 18 ++++++++++++++++-- .../test/scala/fs2/concurrent/TopicSuite.scala | 6 +++++- 3 files changed, 39 insertions(+), 3 deletions(-) diff --git a/core/shared/src/main/scala/fs2/concurrent/Channel.scala b/core/shared/src/main/scala/fs2/concurrent/Channel.scala index 68e71bcaa9..9807a525c4 100644 --- a/core/shared/src/main/scala/fs2/concurrent/Channel.scala +++ b/core/shared/src/main/scala/fs2/concurrent/Channel.scala @@ -117,6 +117,18 @@ sealed trait Channel[F[_], A] { */ def closeWithElement(a: A): F[Either[Channel.Closed, Unit]] + /** Raises an error, closing the channel with an error state. + * + * No-op if the channel is closed, see [[close]] for further info. + */ + def raiseError(e: Throwable): F[Either[Channel.Closed, Unit]] + + /** Cancels the channel, closing it with a canceled state. + * + * No-op if the channel is closed, see [[close]] for further info. + */ + def cancel: F[Either[Channel.Closed, Unit]] + /** Returns true if this channel is closed */ def isClosed: F[Boolean] @@ -216,6 +228,12 @@ object Channel { ) } + def raiseError(e: Throwable): F[Either[Closed, Unit]] = + closeWithExitCase(ExitCase.Errored(e)) + + def cancel: F[Either[Closed, Unit]] = + closeWithExitCase(ExitCase.Canceled) + def isClosed = closedGate.tryGet.map(_.isDefined) def closed = closedGate.get diff --git a/core/shared/src/main/scala/fs2/concurrent/Topic.scala b/core/shared/src/main/scala/fs2/concurrent/Topic.scala index 2061070b16..21245be41d 100644 --- a/core/shared/src/main/scala/fs2/concurrent/Topic.scala +++ b/core/shared/src/main/scala/fs2/concurrent/Topic.scala @@ -23,6 +23,7 @@ package fs2 package concurrent import cats.effect._ +import cats.effect.Resource.ExitCase import cats.effect.implicits._ import cats.syntax.all._ import scala.collection.immutable.LongMap @@ -208,7 +209,8 @@ object Topic { } def publish: Pipe[F, A, Nothing] = { in => - in.onFinalize(close.void) + in + .onFinalizeCase(closeWithExitCase(_).void) .evalMap(publish1) .takeWhile(_.isRight) .drain @@ -223,13 +225,25 @@ object Topic { def subscribers: Stream[F, Int] = subscriberCount.discrete def close: F[Either[Topic.Closed, Unit]] = + closeWithExitCase(ExitCase.Succeeded) + + def closeWithExitCase(exitCase: ExitCase): F[Either[Closed, Unit]] = signalClosure .complete(()) .flatMap { completedNow => val result = if (completedNow) Topic.rightUnit else Topic.closed + val closeChannel = + (channel: Channel[F, A]) => + exitCase match { + case ExitCase.Succeeded => channel.close.void + case ExitCase.Errored(e) => channel.raiseError(e).void + case ExitCase.Canceled => channel.cancel.void + } state.get - .flatMap { case (subs, _) => foreach(subs)(_.close.void) } + .flatMap { case (subs, _) => + foreach(subs)(closeChannel) + } .as(result) } .uncancelable diff --git a/core/shared/src/test/scala/fs2/concurrent/TopicSuite.scala b/core/shared/src/test/scala/fs2/concurrent/TopicSuite.scala index c26fd73dd7..2784d23bba 100644 --- a/core/shared/src/test/scala/fs2/concurrent/TopicSuite.scala +++ b/core/shared/src/test/scala/fs2/concurrent/TopicSuite.scala @@ -25,6 +25,8 @@ package concurrent import cats.syntax.all._ import cats.effect.IO import scala.concurrent.duration._ +import scala.concurrent.CancellationException + import cats.effect.testkit.TestControl class TopicSuite extends Fs2Suite { @@ -204,6 +206,8 @@ class TopicSuite extends Fs2Suite { .drain } - TestControl.executeEmbed(program) // will fail if program is deadlocked + TestControl + .executeEmbed(program) + .intercept[CancellationException] } }