From 9ed93c56116b40262c799349583e86f5958d46dd Mon Sep 17 00:00:00 2001 From: kerr Date: Sun, 30 Jul 2023 19:51:33 +0800 Subject: [PATCH] feat: Add flatmapConcat with parallelism. --- .../akka/stream/FlatMapConcatBenchmark.scala | 70 +++- .../FlowFlatMapConcatParallelismSpec.scala | 191 ++++++++++ .../scala/akka/stream/impl/FailedSource.scala | 2 +- .../akka/stream/impl/JavaStreamSource.scala | 2 +- .../main/scala/akka/stream/impl/Stages.scala | 1 + .../akka/stream/impl/TraversalBuilder.scala | 42 ++- .../scala/akka/stream/impl/fusing/Ops.scala | 1 + .../stream/impl/fusing/StreamOfStreams.scala | 349 +++++++++++++++++- .../main/scala/akka/stream/javadsl/Flow.scala | 19 + .../scala/akka/stream/javadsl/Source.scala | 17 + .../scala/akka/stream/javadsl/SubFlow.scala | 20 + .../scala/akka/stream/javadsl/SubSource.scala | 20 + .../scala/akka/stream/scaladsl/Flow.scala | 17 + 13 files changed, 727 insertions(+), 24 deletions(-) create mode 100644 akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowFlatMapConcatParallelismSpec.scala diff --git a/akka-bench-jmh/src/main/scala/akka/stream/FlatMapConcatBenchmark.scala b/akka-bench-jmh/src/main/scala/akka/stream/FlatMapConcatBenchmark.scala index 34b6a7928a3..df31e2bc25d 100644 --- a/akka-bench-jmh/src/main/scala/akka/stream/FlatMapConcatBenchmark.scala +++ b/akka-bench-jmh/src/main/scala/akka/stream/FlatMapConcatBenchmark.scala @@ -6,10 +6,8 @@ package akka.stream import java.util.concurrent.CountDownLatch import java.util.concurrent.TimeUnit - -import scala.concurrent.Await +import scala.concurrent.{ Await, Future } import scala.concurrent.duration._ - import com.typesafe.config.ConfigFactory import org.openjdk.jmh.annotations._ @@ -60,9 +58,15 @@ class FlatMapConcatBenchmark { @OperationsPerInvocation(OperationsPerInvocation) def sourceDotSingle(): Unit = { val latch = new CountDownLatch(1) - testSource.flatMapConcat(Source.single).runWith(new LatchSink(OperationsPerInvocation, latch)) + awaitLatch(latch) + } + @Benchmark + @OperationsPerInvocation(OperationsPerInvocation) + def sourceDotSingleP1(): Unit = { + val latch = new CountDownLatch(1) + testSource.flatMapConcat(1, Source.single).runWith(new LatchSink(OperationsPerInvocation, latch)) awaitLatch(latch) } @@ -70,11 +74,19 @@ class FlatMapConcatBenchmark { @OperationsPerInvocation(OperationsPerInvocation) def internalSingleSource(): Unit = { val latch = new CountDownLatch(1) - testSource .flatMapConcat(elem => new GraphStages.SingleSource(elem)) .runWith(new LatchSink(OperationsPerInvocation, latch)) + awaitLatch(latch) + } + @Benchmark + @OperationsPerInvocation(OperationsPerInvocation) + def internalSingleSourceP1(): Unit = { + val latch = new CountDownLatch(1) + testSource + .flatMapConcat(1, elem => new GraphStages.SingleSource(elem)) + .runWith(new LatchSink(OperationsPerInvocation, latch)) awaitLatch(latch) } @@ -82,9 +94,55 @@ class FlatMapConcatBenchmark { @OperationsPerInvocation(OperationsPerInvocation) def oneElementList(): Unit = { val latch = new CountDownLatch(1) - testSource.flatMapConcat(n => Source(n :: Nil)).runWith(new LatchSink(OperationsPerInvocation, latch)) + awaitLatch(latch) + } + + @Benchmark + @OperationsPerInvocation(OperationsPerInvocation) + def oneElementListP1(): Unit = { + val latch = new CountDownLatch(1) + testSource.flatMapConcat(1, n => Source(n :: Nil)).runWith(new LatchSink(OperationsPerInvocation, latch)) + awaitLatch(latch) + } + + @Benchmark + @OperationsPerInvocation(OperationsPerInvocation) + def completedFuture(): Unit = { + val latch = new CountDownLatch(1) + testSource + .flatMapConcat(n => Source.future(Future.successful(n))) + .runWith(new LatchSink(OperationsPerInvocation, latch)) + awaitLatch(latch) + } + + @Benchmark + @OperationsPerInvocation(OperationsPerInvocation) + def completedFutureP1(): Unit = { + val latch = new CountDownLatch(1) + testSource + .flatMapConcat(1, n => Source.future(Future.successful(n))) + .runWith(new LatchSink(OperationsPerInvocation, latch)) + awaitLatch(latch) + } + @Benchmark + @OperationsPerInvocation(OperationsPerInvocation) + def normalFuture(): Unit = { + val latch = new CountDownLatch(1) + testSource + .flatMapConcat(n => Source.future(Future(n)(system.dispatcher))) + .runWith(new LatchSink(OperationsPerInvocation, latch)) + awaitLatch(latch) + } + + @Benchmark + @OperationsPerInvocation(OperationsPerInvocation) + def normalFutureP1(): Unit = { + val latch = new CountDownLatch(1) + testSource + .flatMapConcat(1, n => Source.future(Future(n)(system.dispatcher))) + .runWith(new LatchSink(OperationsPerInvocation, latch)) awaitLatch(latch) } diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowFlatMapConcatParallelismSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowFlatMapConcatParallelismSpec.scala new file mode 100644 index 00000000000..8f832a44499 --- /dev/null +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowFlatMapConcatParallelismSpec.scala @@ -0,0 +1,191 @@ +/* + * Copyright (C) 2014-2024 Lightbend Inc. + */ + +package akka.stream.scaladsl + +import akka.NotUsed +import akka.pattern.FutureTimeoutSupport +import akka.stream.OverflowStrategy +import akka.stream.testkit._ +import akka.stream.testkit.scaladsl.TestSink + +import java.util.concurrent.ThreadLocalRandom +import java.util.concurrent.atomic.AtomicInteger +import scala.annotation.switch +import scala.concurrent.{ ExecutionContext, Future } +import scala.concurrent.duration.DurationInt +import scala.util.control.NoStackTrace + +class FlowFlatMapConcatParallelismSpec extends StreamSpec(""" + akka.stream.materializer.initial-input-buffer-size = 2 + """) with ScriptedTest with FutureTimeoutSupport { + val toSeq = Flow[Int].grouped(1000).toMat(Sink.head)(Keep.right) + + class BoomException extends RuntimeException("BOOM~~") with NoStackTrace + "A flatMapConcat" must { + + for (i <- 1 until 129) { + s"work with value presented sources with parallelism: $i" in { + Source( + List( + Source.empty[Int], + Source.single(1), + Source.empty[Int], + Source(List(2, 3, 4)), + Source.future(Future.successful(5)), + Source.lazyFuture(() => Future.successful(6)), + Source.future(after(1.millis)(Future.successful(7))))) + .flatMapConcat(i, identity) + .runWith(toSeq) + .futureValue should ===(1 to 7) + } + } + + def generateRandomValuePresentedSources(nums: Int): (Int, Seq[Source[Int, NotUsed]]) = { + val seq = Seq.tabulate(nums) { _ => + val random = ThreadLocalRandom.current().nextInt(1, 10) + (random: @switch) match { + case 1 => Source.single(1) + case 2 => Source(List(1)) + case 3 => Source.fromJavaStream(() => java.util.stream.Stream.of(1)) + case 4 => Source.future(Future.successful(1)) + case 5 => Source.future(after(1.millis)(Future.successful(1))) + case _ => Source.empty[Int] + } + } + val sum = seq.filterNot(_.eq(Source.empty[Int])).size + (sum, seq) + } + + def generateSequencedValuePresentedSources(nums: Int): (Int, Seq[Source[Int, NotUsed]]) = { + val seq = Seq.tabulate(nums) { index => + val random = ThreadLocalRandom.current().nextInt(1, 6) + (random: @switch) match { + case 1 => Source.single(index) + case 2 => Source(List(index)) + case 3 => Source.fromJavaStream(() => java.util.stream.Stream.of(index)) + case 4 => Source.future(Future.successful(index)) + case 5 => Source.future(after(1.millis)(Future.successful(index))) + case _ => throw new IllegalStateException("unexpected") + } + } + val sum = (0 until nums).sum + (sum, seq) + } + + for (i <- 1 until 129) { + s"work with generated value presented sources with parallelism: $i " in { + val (sum, sources) = generateRandomValuePresentedSources(100000) + Source(sources) + .flatMapConcat(i, identity) + .runWith(Sink.seq) + .map(_.sum)(ExecutionContext.parasitic) + .futureValue shouldBe sum + } + } + + for (i <- 1 until 129) { + s"work with generated value sequenced sources with parallelism: $i " in { + val (sum, sources) = generateSequencedValuePresentedSources(100000) + Source(sources) + .flatMapConcat(i, identity) + //check the order + .statefulMap(() => -1)((pre, current) => { + if (pre + 1 != current) { + throw new IllegalStateException(s"expected $pre + 1 == $current") + } + (current, current) + }, _ => None) + .runWith(Sink.seq) + .map(_.sum)(ExecutionContext.parasitic) + .futureValue shouldBe sum + } + } + + "work with value presented failed sources" in { + val ex = new BoomException + Source( + List( + Source.empty[Int], + Source.single(1), + Source.empty[Int], + Source(List(2, 3, 4)), + Source.future(Future.failed(ex)), + Source.lazyFuture(() => Future.successful(5)))) + .flatMapConcat(ThreadLocalRandom.current().nextInt(1, 129), identity) + .onErrorComplete[BoomException]() + .runWith(toSeq) + .futureValue should ===(1 to 4) + } + + "work with value presented sources when demands slow" in { + val prob = Source( + List(Source.empty[Int], Source.single(1), Source(List(2, 3, 4)), Source.lazyFuture(() => Future.successful(5)))) + .flatMapConcat(ThreadLocalRandom.current().nextInt(1, 129), identity) + .runWith(TestSink()) + + prob.request(1) + prob.expectNext(1) + prob.expectNoMessage(1.seconds) + prob.request(2) + prob.expectNext(2, 3) + prob.expectNoMessage(1.seconds) + prob.request(2) + prob.expectNext(4, 5) + prob.expectComplete() + } + + "can do pre materialization when parallelism > 1" in { + val materializationCounter = new AtomicInteger(0) + val randomParallelism = ThreadLocalRandom.current().nextInt(4, 65) + val prob = Source(1 to (randomParallelism * 3)) + .flatMapConcat( + randomParallelism, + value => { + Source + .lazySingle(() => { + materializationCounter.incrementAndGet() + value + }) + .buffer(1, overflowStrategy = OverflowStrategy.backpressure) + }) + .runWith(TestSink()) + + expectNoMessage(1.seconds) + materializationCounter.get() shouldBe 0 + + prob.request(1) + prob.expectNext(1.seconds, 1) + expectNoMessage(1.seconds) + materializationCounter.get() shouldBe (randomParallelism + 1) + materializationCounter.set(0) + + prob.request(2) + prob.expectNextN(List(2, 3)) + expectNoMessage(1.seconds) + materializationCounter.get() shouldBe 2 + materializationCounter.set(0) + + prob.request(randomParallelism - 3) + prob.expectNextN(4 to randomParallelism) + expectNoMessage(1.seconds) + materializationCounter.get() shouldBe (randomParallelism - 3) + materializationCounter.set(0) + + prob.request(randomParallelism) + prob.expectNextN(randomParallelism + 1 to randomParallelism * 2) + expectNoMessage(1.seconds) + materializationCounter.get() shouldBe randomParallelism + materializationCounter.set(0) + + prob.request(randomParallelism) + prob.expectNextN(randomParallelism * 2 + 1 to randomParallelism * 3) + expectNoMessage(1.seconds) + materializationCounter.get() shouldBe 0 + prob.expectComplete() + } + + } + +} diff --git a/akka-stream/src/main/scala/akka/stream/impl/FailedSource.scala b/akka-stream/src/main/scala/akka/stream/impl/FailedSource.scala index 96a4a9fb361..898dad7bca5 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/FailedSource.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/FailedSource.scala @@ -12,7 +12,7 @@ import akka.stream.stage.{ GraphStage, GraphStageLogic, OutHandler } /** * INTERNAL API */ -@InternalApi private[akka] final class FailedSource[T](failure: Throwable) extends GraphStage[SourceShape[T]] { +@InternalApi private[akka] final class FailedSource[T](val failure: Throwable) extends GraphStage[SourceShape[T]] { val out = Outlet[T]("FailedSource.out") override val shape = SourceShape(out) diff --git a/akka-stream/src/main/scala/akka/stream/impl/JavaStreamSource.scala b/akka-stream/src/main/scala/akka/stream/impl/JavaStreamSource.scala index 59bc096c492..5327d6ba634 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/JavaStreamSource.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/JavaStreamSource.scala @@ -12,7 +12,7 @@ import java.util.function.Consumer /** INTERNAL API */ @InternalApi private[stream] final class JavaStreamSource[T, S <: java.util.stream.BaseStream[T, S]]( - open: () => java.util.stream.BaseStream[T, S]) + val open: () => java.util.stream.BaseStream[T, S]) extends GraphStage[SourceShape[T]] { val out: Outlet[T] = Outlet("JavaStreamSource") diff --git a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala index d5ab3c45553..c5dec4c784f 100755 --- a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala @@ -79,6 +79,7 @@ import akka.stream.Attributes._ val mergePreferred = name("mergePreferred") val mergePrioritized = name("mergePrioritized") val flattenMerge = name("flattenMerge") + val flattenConcat = name("flattenConcat") val recoverWith = name("recoverWith") val onErrorComplete = name("onErrorComplete") val broadcast = name("broadcast") diff --git a/akka-stream/src/main/scala/akka/stream/impl/TraversalBuilder.scala b/akka-stream/src/main/scala/akka/stream/impl/TraversalBuilder.scala index 9c1c4065e5e..a42ffbe1a11 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/TraversalBuilder.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/TraversalBuilder.scala @@ -6,13 +6,12 @@ package akka.stream.impl import scala.collection.immutable.Map.Map1 import scala.language.existentials - import akka.annotation.{ DoNotInherit, InternalApi } import akka.stream._ import akka.stream.impl.StreamLayout.AtomicModule import akka.stream.impl.TraversalBuilder.{ AnyFunction1, AnyFunction2 } import akka.stream.impl.fusing.GraphStageModule -import akka.stream.impl.fusing.GraphStages.SingleSource +import akka.stream.impl.fusing.GraphStages.{ FutureSource, IterableSource, SingleSource } import akka.stream.scaladsl.Keep import akka.util.OptionVal @@ -369,12 +368,51 @@ import akka.util.OptionVal } } + /** + * Try to find `SingleSource` or wrapped such. This is used as a + * performance optimization in FlattenConcat and possibly other places. + */ + def getValuePresentedSource[A >: Null](graph: Graph[SourceShape[A], _]): OptionVal[Graph[SourceShape[A], _]] = { + def isValuePresentedSource(graph: Graph[SourceShape[_ <: A], _]): Boolean = graph match { + case _: SingleSource[_] | _: FutureSource[_] | _: IterableSource[_] | _: JavaStreamSource[_, _] | + _: FailedSource[_] => + true + case maybeEmpty if isEmptySource(maybeEmpty) => true + case _ => false + } + graph match { + case _ if isValuePresentedSource(graph) => OptionVal.Some(graph) + case _ => + graph.traversalBuilder match { + case l: LinearTraversalBuilder => + l.pendingBuilder match { + case OptionVal.Some(a: AtomicTraversalBuilder) => + a.module match { + case m: GraphStageModule[_, _] => + m.stage match { + case _ if isValuePresentedSource(m.stage.asInstanceOf[Graph[SourceShape[A], _]]) => + // It would be != EmptyTraversal if mapMaterializedValue was used and then we can't optimize. + if ((l.traversalSoFar eq EmptyTraversal) && !l.attributes.isAsync) + OptionVal.Some(m.stage.asInstanceOf[Graph[SourceShape[A], _]]) + else OptionVal.None + case _ => OptionVal.None + } + case _ => OptionVal.None + } + case _ => OptionVal.None + } + case _ => OptionVal.None + } + } + } + /** * Test if a Graph is an empty Source. * */ def isEmptySource(graph: Graph[SourceShape[_], _]): Boolean = graph match { case source: scaladsl.Source[_, _] if source eq scaladsl.Source.empty => true case source: javadsl.Source[_, _] if source eq javadsl.Source.empty() => true + case EmptySource => true case _ => false } diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala index 9286d6d4778..d40a3e67080 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala @@ -1268,6 +1268,7 @@ private[stream] object Collect { */ @InternalApi private[akka] final case class MapAsync[In, Out](parallelism: Int, f: In => Future[Out]) extends GraphStage[FlowShape[In, Out]] { + require(parallelism >= 1, "parallelism should >= 1") import MapAsync._ diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/StreamOfStreams.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/StreamOfStreams.scala index bf0855af7e4..9c8d440af31 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/StreamOfStreams.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/StreamOfStreams.scala @@ -4,37 +4,43 @@ package akka.stream.impl.fusing -import java.util.Collections -import java.util.concurrent.atomic.AtomicReference - -import scala.annotation.tailrec -import scala.collection.immutable -import scala.concurrent.duration.FiniteDuration -import scala.util.control.NonFatal - import akka.NotUsed import akka.annotation.InternalApi -import akka.stream._ import akka.stream.ActorAttributes.StreamSubscriptionTimeout import akka.stream.ActorAttributes.SupervisionStrategy import akka.stream.Attributes.SourceLocation -import akka.stream.impl.{ Buffer => BufferImpl } -import akka.stream.impl.ActorSubscriberMessage +import akka.stream._ +import akka.stream.impl.{ + ActorSubscriberMessage, + FailedSource, + JavaStreamSource, + SubscriptionTimeoutException, + TraversalBuilder, + Buffer => BufferImpl +} import akka.stream.impl.ActorSubscriberMessage.OnError import akka.stream.impl.Stages.DefaultAttributes -import akka.stream.impl.SubscriptionTimeoutException -import akka.stream.impl.TraversalBuilder -import akka.stream.impl.fusing.GraphStages.SingleSource +import akka.stream.impl.fusing.GraphStages.{ FutureSource, IterableSource, SingleSource } import akka.stream.scaladsl._ import akka.stream.stage._ import akka.util.OptionVal + import scala.jdk.CollectionConverters._ +import java.util.Collections +import java.util.concurrent.atomic.AtomicReference +import scala.annotation.{ nowarn, tailrec } +import scala.collection.immutable +import scala.concurrent.{ ExecutionContext, Future } +import scala.concurrent.duration.FiniteDuration +import scala.util.{ Failure, Try } +import scala.util.control.NonFatal /** * INTERNAL API */ @InternalApi private[akka] final class FlattenMerge[T, M](val breadth: Int) extends GraphStage[FlowShape[Graph[SourceShape[T], M], T]] { + require(breadth >= 1, "breadth should >= 1") private val in = Inlet[Graph[SourceShape[T], M]]("flatten.in") private val out = Outlet[T]("flatten.out") @@ -137,6 +143,321 @@ import scala.jdk.CollectionConverters._ override def toString: String = s"FlattenMerge($breadth)" } +/** + * INTERNAL API + */ +@InternalApi +private[akka] object FlattenConcat { + private sealed abstract class InflightSource[T] { + def hasNext: Boolean + def next(): T + def tryPull(): Unit + def cancel(cause: Throwable): Unit + def isClosed: Boolean + def hasFailed: Boolean = failure.isDefined + def failure: Option[Throwable] = None + def materialize(): Unit = () + } + + private final class InflightIteratorSource[T](iterator: Iterator[T]) extends InflightSource[T] { + override def hasNext: Boolean = iterator.hasNext + override def next(): T = iterator.next() + override def tryPull(): Unit = () + override def cancel(cause: Throwable): Unit = () + override def isClosed: Boolean = !hasNext + } + + private final class InflightCompletedFutureSource[T](result: Try[T]) extends InflightSource[T] { + private var _hasNext = result.isSuccess + override def hasNext: Boolean = _hasNext + override def next(): T = { + if (_hasNext) { + _hasNext = false + result.get + } else throw new NoSuchElementException("next called after completion") + } + override def hasFailed: Boolean = result.isFailure + override def failure: Option[Throwable] = result.failed.toOption + override def tryPull(): Unit = () + override def cancel(cause: Throwable): Unit = () + override def isClosed: Boolean = true + } + + private final class InflightPendingFutureSource[T](cb: InflightSource[T] => Unit) + extends InflightSource[T] + with (Try[T] => Unit) { + private var result: Try[T] = MapAsync.NotYetThere + private var consumed = false + override def apply(result: Try[T]): Unit = { + this.result = result + cb(this) + } + override def hasNext: Boolean = (result ne MapAsync.NotYetThere) && !consumed && result.isSuccess + override def next(): T = { + if (!consumed) { + consumed = true + result.get + } else throw new NoSuchElementException("next called after completion") + } + override def hasFailed: Boolean = (result ne MapAsync.NotYetThere) && result.isFailure + override def failure: Option[Throwable] = if (result eq MapAsync.NotYetThere) None else result.failed.toOption + override def tryPull(): Unit = () + override def cancel(cause: Throwable): Unit = () + override def isClosed: Boolean = consumed || hasFailed + } +} + +/** + * INTERNAL API + */ +@InternalApi +private[akka] final class FlattenConcat[T, M](parallelism: Int) + extends GraphStage[FlowShape[Graph[SourceShape[T], M], T]] { + require(parallelism >= 1, "parallelism should >= 1") + private val in = Inlet[Graph[SourceShape[T], M]]("flattenConcat.in") + private val out = Outlet[T]("flattenConcat.out") + + override def initialAttributes: Attributes = DefaultAttributes.flattenConcat + override val shape: FlowShape[Graph[SourceShape[T], M], T] = FlowShape(in, out) + override def createLogic(enclosingAttributes: Attributes) = { + final object FlattenConcatLogic extends GraphStageLogic(shape) with InHandler with OutHandler { + import FlattenConcat._ + // InflightSource[T] or SingleSource[T] + // AnyRef here to avoid lift the SingleSource[T] to InflightSource[T] + private var queue: BufferImpl[AnyRef] = _ + private val invokeCb: InflightSource[T] => Unit = + getAsyncCallback[InflightSource[T]](futureSourceCompleted).invoke + + override def preStart(): Unit = queue = BufferImpl(parallelism, enclosingAttributes) + + private def futureSourceCompleted(futureSource: InflightSource[T]): Unit = { + if (queue.peek() eq futureSource) { + if (isAvailable(out) && futureSource.hasNext) { + push(out, futureSource.next()) //TODO should filter out the `null` here? + if (futureSource.isClosed) { + handleCurrentSourceClosed(futureSource) + } + } else if (futureSource.isClosed) { + handleCurrentSourceClosed(futureSource) + } + } // else just ignore, it will be picked up by onPull + } + + override def onPush(): Unit = { + addSource(grab(in)) + //must try pull after addSource to avoid queue overflow + if (!queue.isFull) { // try to keep the maximum parallelism + tryPull(in) + } + } + + override def onUpstreamFinish(): Unit = if (queue.isEmpty) completeStage() + + override def onUpstreamFailure(ex: Throwable): Unit = { + super.onUpstreamFailure(ex) + cancelInflightSources(SubscriptionWithCancelException.NoMoreElementsNeeded) + } + + override def onPull(): Unit = { + //purge if possible + queue.peek() match { + case src: SingleSource[T] @unchecked => + push(out, src.elem) + removeSource() + case src: InflightSource[T] @unchecked => pushOut(src) + case null => //queue is empty + if (!hasBeenPulled(in)) { + tryPull(in) + } else if (isClosed(in)) { + completeStage() + } + case _ => throw new IllegalStateException("Should not reach here.") + } + } + + private def pushOut(src: InflightSource[T]): Unit = { + if (src.hasNext) { + push(out, src.next()) + if (src.isClosed) { + handleCurrentSourceClosed(src) + } + } else if (src.isClosed) { + handleCurrentSourceClosed(src) + } else { + src.tryPull() + } + } + + private def handleCurrentSourceClosed(source: InflightSource[T]): Unit = { + source.failure match { + case Some(cause) => onUpstreamFailure(cause) + case None => removeSource(source) + } + } + + override def onDownstreamFinish(cause: Throwable): Unit = { + super.onDownstreamFinish(cause) + cancelInflightSources(cause) + } + + private def cancelInflightSources(cause: Throwable): Unit = { + if (queue.nonEmpty) { + var source = queue.dequeue() + while ((source ne null) && (source.isInstanceOf[InflightSource[T] @unchecked])) { + source.asInstanceOf[InflightSource[T]].cancel(cause) + source = queue.dequeue() + } + } + } + + private def addSource(singleSource: SingleSource[T]): Unit = { + if (isAvailable(out) && queue.isEmpty) { + push(out, singleSource.elem) + } else { + queue.enqueue(singleSource) + } + } + + private def addSourceElements(iterator: Iterator[T]): Unit = { + val inflightSource = new InflightIteratorSource[T](iterator) + if (isAvailable(out) && queue.isEmpty) { + if (inflightSource.hasNext) { + push(out, inflightSource.next()) + if (inflightSource.hasNext) { + queue.enqueue(inflightSource) + } + } + } else { + queue.enqueue(inflightSource) + } + } + + private def addCompletedFutureElem(elem: Try[T]): Unit = { + if (isAvailable(out) && queue.isEmpty) { + elem match { + case scala.util.Success(value) => push(out, value) + case scala.util.Failure(ex) => onUpstreamFailure(ex) + } + } else { + queue.enqueue(new InflightCompletedFutureSource(elem)) + } + } + + private def addPendingFutureElem(future: Future[T]): Unit = { + val inflightSource = new InflightPendingFutureSource[T](invokeCb) + future.onComplete(inflightSource)(ExecutionContext.parasitic) + queue.enqueue(inflightSource) + } + + private def attachAndMaterializeSource(source: Graph[SourceShape[T], M]): Unit = { + object inflightSource extends InflightSource[T] { self => + private val sinkIn = new SubSinkInlet[T]("FlattenConcatSink") + private var upstreamFailure = Option.empty[Throwable] + sinkIn.setHandler(new InHandler { + override def onPush(): Unit = { + if (isAvailable(out) && (queue.peek() eq self)) { + push(out, sinkIn.grab()) + } + } + override def onUpstreamFinish(): Unit = if (!sinkIn.isAvailable) removeSource(self) + override def onUpstreamFailure(ex: Throwable): Unit = { + upstreamFailure = Some(ex) + // if it's the current emitting source, fail the stage + if (queue.peek() eq self) { + super.onUpstreamFailure(ex) + } // else just mark the source as failed + } + }) + + final override def materialize(): Unit = { + val graph = Source.fromGraph(source).to(sinkIn.sink) + interpreter.subFusingMaterializer.materialize(graph, defaultAttributes = enclosingAttributes) + } + final override def cancel(cause: Throwable): Unit = sinkIn.cancel(cause) + final override def hasNext: Boolean = sinkIn.isAvailable + final override def isClosed: Boolean = sinkIn.isClosed + final override def failure: Option[Throwable] = upstreamFailure + final override def next(): T = sinkIn.grab() + final override def tryPull(): Unit = if (!sinkIn.isClosed && !sinkIn.hasBeenPulled) sinkIn.pull() + } + if (isAvailable(out) && queue.isEmpty) { + //this is the first one, pull + inflightSource.tryPull() + } + queue.enqueue(inflightSource) + inflightSource.materialize() + } + + private def addSource(source: Graph[SourceShape[T], M]): Unit = { + (TraversalBuilder.getValuePresentedSource(source): @nowarn("cat=lint-infer-any")) match { + case OptionVal.Some(graph) => + graph match { + case single: SingleSource[T] @unchecked => addSource(single) + case futureSource: FutureSource[T] @unchecked => + val future = futureSource.future + future.value match { + case Some(elem) => addCompletedFutureElem(elem) + case None => addPendingFutureElem(future) + } + case iterable: IterableSource[T] @unchecked => addSourceElements(iterable.elements.iterator) + case javaStream: JavaStreamSource[T, _] @unchecked => + addSourceElements(javaStream.open().iterator.asScala) + case failed: FailedSource[T] @unchecked => addCompletedFutureElem(Failure(failed.failure)) + case maybeEmpty if TraversalBuilder.isEmptySource(maybeEmpty) => //Empty source is discarded + case _ => attachAndMaterializeSource(source) + } + case _ => attachAndMaterializeSource(source) + } + + } + + private def removeSource(): Unit = { + queue.dequeue() + pullIfNeeded() + } + + private def removeSource(source: InflightSource[T]): Unit = { + if (source eq queue.peek()) { + //only dequeue if it's the current emitting source + queue.dequeue() + pullIfNeeded() + } //not the head source, just ignore + } + + private def pullIfNeeded(): Unit = { + if (isClosed(in)) { + if (queue.isEmpty) { + completeStage() + } else { + tryPullNextSourceInQueue() + } + } else { + if (queue.nonEmpty) { + tryPullNextSourceInQueue() + } + if (!hasBeenPulled(in)) { + tryPull(in) + } + } + } + + private def tryPullNextSourceInQueue(): Unit = { + //pull the new emitting source + val nextSource = queue.peek() + if (nextSource.isInstanceOf[InflightSource[T] @unchecked]) { + nextSource.asInstanceOf[InflightSource[T]].tryPull() + } + } + + setHandlers(in, out, this) + } + + FlattenConcatLogic + } + + override def toString: String = s"FlattenConcat($parallelism)" +} + /** * INTERNAL API */ diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala b/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala index 214b9b95721..f22db6f36b7 100755 --- a/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala @@ -2435,6 +2435,25 @@ final class Flow[In, Out, Mat](delegate: scaladsl.Flow[In, Out, Mat]) extends Gr def flatMapConcat[T, M](f: function.Function[Out, _ <: Graph[SourceShape[T], M]]): Flow[In, T, Mat] = new Flow(delegate.flatMapConcat[T, M](x => f(x))) + /** + * Transform each input element into a `Source` of output elements that is + * then flattened into the output stream by concatenation, + * fully consuming one Source after the other. + * `parallelism` can be used to config the max inflight sources, which will be materialized at the same time. + * + * '''Emits when''' a currently consumed substream has an element available + * + * '''Backpressures when''' downstream backpressures + * + * '''Completes when''' upstream completes and all consumed substreams complete + * + * '''Cancels when''' downstream cancels + */ + def flatMapConcat[T, M]( + parallelism: Int, + f: function.Function[Out, _ <: Graph[SourceShape[T], M]]): Flow[In, T, Mat] = + new Flow(delegate.flatMapConcat[T, M](parallelism, x => f(x))) + /** * Transform each input element into a `Source` of output elements that is * then flattened into the output stream by merging, where at most `breadth` diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala b/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala index 36a4485821a..d7137845c8b 100755 --- a/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala @@ -3885,6 +3885,23 @@ final class Source[Out, Mat](delegate: scaladsl.Source[Out, Mat]) extends Graph[ def flatMapConcat[T, M](f: function.Function[Out, _ <: Graph[SourceShape[T], M]]): Source[T, Mat] = new Source(delegate.flatMapConcat[T, M](x => f(x))) + /** + * Transform each input element into a `Source` of output elements that is + * then flattened into the output stream by concatenation, + * fully consuming one Source after the other. + * `parallelism` can be used to config the max inflight sources, which will be materialized at the same time. + * + * '''Emits when''' a currently consumed substream has an element available + * + * '''Backpressures when''' downstream backpressures + * + * '''Completes when''' upstream completes and all consumed substreams complete + * + * '''Cancels when''' downstream cancels + */ + def flatMapConcat[T, M](parallelism: Int, f: function.Function[Out, _ <: Graph[SourceShape[T], M]]): Source[T, Mat] = + new Source(delegate.flatMapConcat[T, M](parallelism, x => f(x))) + /** * Transform each input element into a `Source` of output elements that is * then flattened into the output stream by merging, where at most `breadth` diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/SubFlow.scala b/akka-stream/src/main/scala/akka/stream/javadsl/SubFlow.scala index fdc70e26e90..fa6e9ec5d19 100755 --- a/akka-stream/src/main/scala/akka/stream/javadsl/SubFlow.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/SubFlow.scala @@ -1495,6 +1495,26 @@ final class SubFlow[In, Out, Mat]( def flatMapConcat[T, M](f: function.Function[Out, _ <: Graph[SourceShape[T], M]]): SubFlow[In, T, Mat] = new SubFlow(delegate.flatMapConcat(x => f(x))) + /** + * Transform each input element into a `Source` of output elements that is + * then flattened into the output stream by concatenation, + * fully consuming one Source after the other. + * `parallelism` can be used to config the max inflight sources, which will be materialized at the same time. + * + * '''Emits when''' a currently consumed substream has an element available + * + * '''Backpressures when''' downstream backpressures + * + * '''Completes when''' upstream completes and all consumed substreams complete + * + * '''Cancels when''' downstream cancels + * + */ + def flatMapConcat[T, M]( + parallelism: Int, + f: function.Function[Out, _ <: Graph[SourceShape[T], M]]): SubFlow[In, T, Mat] = + new SubFlow(delegate.flatMapConcat(parallelism, x => f(x))) + /** * Transform each input element into a `Source` of output elements that is * then flattened into the output stream by merging, where at most `breadth` diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/SubSource.scala b/akka-stream/src/main/scala/akka/stream/javadsl/SubSource.scala index a0c42229f38..85ac9500bef 100755 --- a/akka-stream/src/main/scala/akka/stream/javadsl/SubSource.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/SubSource.scala @@ -1473,6 +1473,26 @@ final class SubSource[Out, Mat]( def flatMapConcat[T, M](f: function.Function[Out, _ <: Graph[SourceShape[T], M]]): SubSource[T, Mat] = new SubSource(delegate.flatMapConcat(x => f(x))) + /** + * Transform each input element into a `Source` of output elements that is + * then flattened into the output stream by concatenation, + * fully consuming one Source after the other. + * `parallelism` can be used to config the max inflight sources, which will be materialized at the same time. + * + * '''Emits when''' a currently consumed substream has an element available + * + * '''Backpressures when''' downstream backpressures + * + * '''Completes when''' upstream completes and all consumed substreams complete + * + * '''Cancels when''' downstream cancels + * + */ + def flatMapConcat[T, M]( + parallelism: Int, + f: function.Function[Out, _ <: Graph[SourceShape[T], M]]): SubSource[T, Mat] = + new SubSource(delegate.flatMapConcat(parallelism, x => f(x))) + /** * Transform each input element into a `Source` of output elements that is * then flattened into the output stream by merging, where at most `breadth` diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala index c1aa53aac7f..52f3956c58d 100755 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala @@ -2489,6 +2489,23 @@ trait FlowOps[+Out, +Mat] { */ def flatMapConcat[T, M](f: Out => Graph[SourceShape[T], M]): Repr[T] = map(f).via(new FlattenMerge[T, M](1)) + /** + * Transform each input element into a `Source` of output elements that is + * then flattened into the output stream by concatenation, + * fully consuming one Source after the other. + * `parallelism` can be used to config the max inflight sources, which will be materialized at the same time. + * + * '''Emits when''' a currently consumed substream has an element available + * + * '''Backpressures when''' downstream backpressures + * + * '''Completes when''' upstream completes and all consumed substreams complete + * + * '''Cancels when''' downstream cancels + */ + def flatMapConcat[T, M](parallelism: Int, f: Out => Graph[SourceShape[T], M]): Repr[T] = + map(f).via(new FlattenConcat[T, M](parallelism)) + /** * Transform each input element into a `Source` of output elements that is * then flattened into the output stream by merging, where at most `breadth`