Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@ package akka.stream

import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit

import scala.concurrent.Await
import scala.concurrent.{ Await, Future }
import scala.concurrent.duration._

import com.typesafe.config.ConfigFactory
import org.openjdk.jmh.annotations._

Expand Down Expand Up @@ -60,31 +58,91 @@ class FlatMapConcatBenchmark {
@OperationsPerInvocation(OperationsPerInvocation)
def sourceDotSingle(): Unit = {
val latch = new CountDownLatch(1)

testSource.flatMapConcat(Source.single).runWith(new LatchSink(OperationsPerInvocation, latch))
awaitLatch(latch)
}

@Benchmark
@OperationsPerInvocation(OperationsPerInvocation)
def sourceDotSingleP1(): Unit = {
val latch = new CountDownLatch(1)
testSource.flatMapConcat(1, Source.single).runWith(new LatchSink(OperationsPerInvocation, latch))
awaitLatch(latch)
}

@Benchmark
@OperationsPerInvocation(OperationsPerInvocation)
def internalSingleSource(): Unit = {
val latch = new CountDownLatch(1)

testSource
.flatMapConcat(elem => new GraphStages.SingleSource(elem))
.runWith(new LatchSink(OperationsPerInvocation, latch))
awaitLatch(latch)
}

@Benchmark
@OperationsPerInvocation(OperationsPerInvocation)
def internalSingleSourceP1(): Unit = {
val latch = new CountDownLatch(1)
testSource
.flatMapConcat(1, elem => new GraphStages.SingleSource(elem))
.runWith(new LatchSink(OperationsPerInvocation, latch))
awaitLatch(latch)
}

@Benchmark
@OperationsPerInvocation(OperationsPerInvocation)
def oneElementList(): Unit = {
val latch = new CountDownLatch(1)

testSource.flatMapConcat(n => Source(n :: Nil)).runWith(new LatchSink(OperationsPerInvocation, latch))
awaitLatch(latch)
}

@Benchmark
@OperationsPerInvocation(OperationsPerInvocation)
def oneElementListP1(): Unit = {
val latch = new CountDownLatch(1)
testSource.flatMapConcat(1, n => Source(n :: Nil)).runWith(new LatchSink(OperationsPerInvocation, latch))
awaitLatch(latch)
}

@Benchmark
@OperationsPerInvocation(OperationsPerInvocation)
def completedFuture(): Unit = {
val latch = new CountDownLatch(1)
testSource
.flatMapConcat(n => Source.future(Future.successful(n)))
.runWith(new LatchSink(OperationsPerInvocation, latch))
awaitLatch(latch)
}

@Benchmark
@OperationsPerInvocation(OperationsPerInvocation)
def completedFutureP1(): Unit = {
val latch = new CountDownLatch(1)
testSource
.flatMapConcat(1, n => Source.future(Future.successful(n)))
.runWith(new LatchSink(OperationsPerInvocation, latch))
awaitLatch(latch)
}

@Benchmark
@OperationsPerInvocation(OperationsPerInvocation)
def normalFuture(): Unit = {
val latch = new CountDownLatch(1)
testSource
.flatMapConcat(n => Source.future(Future(n)(system.dispatcher)))
.runWith(new LatchSink(OperationsPerInvocation, latch))
awaitLatch(latch)
}

@Benchmark
@OperationsPerInvocation(OperationsPerInvocation)
def normalFutureP1(): Unit = {
val latch = new CountDownLatch(1)
testSource
.flatMapConcat(1, n => Source.future(Future(n)(system.dispatcher)))
.runWith(new LatchSink(OperationsPerInvocation, latch))
awaitLatch(latch)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
/*
* Copyright (C) 2014-2024 Lightbend Inc. <https://www.lightbend.com>
*/

package akka.stream.scaladsl

import akka.NotUsed
import akka.pattern.FutureTimeoutSupport
import akka.stream.OverflowStrategy
import akka.stream.testkit._
import akka.stream.testkit.scaladsl.TestSink

import java.util.concurrent.ThreadLocalRandom
import java.util.concurrent.atomic.AtomicInteger
import scala.annotation.switch
import scala.concurrent.{ ExecutionContext, Future }
import scala.concurrent.duration.DurationInt
import scala.util.control.NoStackTrace

class FlowFlatMapConcatParallelismSpec extends StreamSpec("""
akka.stream.materializer.initial-input-buffer-size = 2
""") with ScriptedTest with FutureTimeoutSupport {
val toSeq = Flow[Int].grouped(1000).toMat(Sink.head)(Keep.right)

class BoomException extends RuntimeException("BOOM~~") with NoStackTrace
"A flatMapConcat" must {

for (i <- 1 until 129) {
s"work with value presented sources with parallelism: $i" in {
Source(
List(
Source.empty[Int],
Source.single(1),
Source.empty[Int],
Source(List(2, 3, 4)),
Source.future(Future.successful(5)),
Source.lazyFuture(() => Future.successful(6)),
Source.future(after(1.millis)(Future.successful(7)))))
.flatMapConcat(i, identity)
.runWith(toSeq)
.futureValue should ===(1 to 7)
}
}

def generateRandomValuePresentedSources(nums: Int): (Int, Seq[Source[Int, NotUsed]]) = {
val seq = Seq.tabulate(nums) { _ =>
val random = ThreadLocalRandom.current().nextInt(1, 10)
(random: @switch) match {
case 1 => Source.single(1)
case 2 => Source(List(1))
case 3 => Source.fromJavaStream(() => java.util.stream.Stream.of(1))
case 4 => Source.future(Future.successful(1))
case 5 => Source.future(after(1.millis)(Future.successful(1)))
case _ => Source.empty[Int]
}
}
val sum = seq.filterNot(_.eq(Source.empty[Int])).size
(sum, seq)
}

def generateSequencedValuePresentedSources(nums: Int): (Int, Seq[Source[Int, NotUsed]]) = {
val seq = Seq.tabulate(nums) { index =>
val random = ThreadLocalRandom.current().nextInt(1, 6)
(random: @switch) match {
case 1 => Source.single(index)
case 2 => Source(List(index))
case 3 => Source.fromJavaStream(() => java.util.stream.Stream.of(index))
case 4 => Source.future(Future.successful(index))
case 5 => Source.future(after(1.millis)(Future.successful(index)))
case _ => throw new IllegalStateException("unexpected")
}
}
val sum = (0 until nums).sum
(sum, seq)
}

for (i <- 1 until 129) {
s"work with generated value presented sources with parallelism: $i " in {
val (sum, sources) = generateRandomValuePresentedSources(100000)
Source(sources)
.flatMapConcat(i, identity)
.runWith(Sink.seq)
.map(_.sum)(ExecutionContext.parasitic)
.futureValue shouldBe sum
}
}

for (i <- 1 until 129) {
s"work with generated value sequenced sources with parallelism: $i " in {
val (sum, sources) = generateSequencedValuePresentedSources(100000)
Source(sources)
.flatMapConcat(i, identity)
//check the order
.statefulMap(() => -1)((pre, current) => {
if (pre + 1 != current) {
throw new IllegalStateException(s"expected $pre + 1 == $current")
}
(current, current)
}, _ => None)
Comment on lines +94 to +99
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check the ordering

.runWith(Sink.seq)
.map(_.sum)(ExecutionContext.parasitic)
.futureValue shouldBe sum
}
}

"work with value presented failed sources" in {
val ex = new BoomException
Source(
List(
Source.empty[Int],
Source.single(1),
Source.empty[Int],
Source(List(2, 3, 4)),
Source.future(Future.failed(ex)),
Source.lazyFuture(() => Future.successful(5))))
.flatMapConcat(ThreadLocalRandom.current().nextInt(1, 129), identity)
.onErrorComplete[BoomException]()
.runWith(toSeq)
.futureValue should ===(1 to 4)
}

"work with value presented sources when demands slow" in {
val prob = Source(
List(Source.empty[Int], Source.single(1), Source(List(2, 3, 4)), Source.lazyFuture(() => Future.successful(5))))
.flatMapConcat(ThreadLocalRandom.current().nextInt(1, 129), identity)
.runWith(TestSink())

prob.request(1)
prob.expectNext(1)
prob.expectNoMessage(1.seconds)
prob.request(2)
prob.expectNext(2, 3)
prob.expectNoMessage(1.seconds)
prob.request(2)
prob.expectNext(4, 5)
prob.expectComplete()
}

"can do pre materialization when parallelism > 1" in {
val materializationCounter = new AtomicInteger(0)
val randomParallelism = ThreadLocalRandom.current().nextInt(4, 65)
val prob = Source(1 to (randomParallelism * 3))
.flatMapConcat(
randomParallelism,
value => {
Source
.lazySingle(() => {
materializationCounter.incrementAndGet()
value
})
.buffer(1, overflowStrategy = OverflowStrategy.backpressure)
})
.runWith(TestSink())

expectNoMessage(1.seconds)
materializationCounter.get() shouldBe 0

prob.request(1)
prob.expectNext(1.seconds, 1)
expectNoMessage(1.seconds)
materializationCounter.get() shouldBe (randomParallelism + 1)
materializationCounter.set(0)

prob.request(2)
prob.expectNextN(List(2, 3))
expectNoMessage(1.seconds)
materializationCounter.get() shouldBe 2
materializationCounter.set(0)

prob.request(randomParallelism - 3)
prob.expectNextN(4 to randomParallelism)
expectNoMessage(1.seconds)
materializationCounter.get() shouldBe (randomParallelism - 3)
materializationCounter.set(0)

prob.request(randomParallelism)
prob.expectNextN(randomParallelism + 1 to randomParallelism * 2)
expectNoMessage(1.seconds)
materializationCounter.get() shouldBe randomParallelism
materializationCounter.set(0)

prob.request(randomParallelism)
prob.expectNextN(randomParallelism * 2 + 1 to randomParallelism * 3)
expectNoMessage(1.seconds)
materializationCounter.get() shouldBe 0
prob.expectComplete()
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import akka.stream.stage.{ GraphStage, GraphStageLogic, OutHandler }
/**
* INTERNAL API
*/
@InternalApi private[akka] final class FailedSource[T](failure: Throwable) extends GraphStage[SourceShape[T]] {
@InternalApi private[akka] final class FailedSource[T](val failure: Throwable) extends GraphStage[SourceShape[T]] {
val out = Outlet[T]("FailedSource.out")
override val shape = SourceShape(out)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import java.util.function.Consumer

/** INTERNAL API */
@InternalApi private[stream] final class JavaStreamSource[T, S <: java.util.stream.BaseStream[T, S]](
open: () => java.util.stream.BaseStream[T, S])
val open: () => java.util.stream.BaseStream[T, S])
extends GraphStage[SourceShape[T]] {

val out: Outlet[T] = Outlet("JavaStreamSource")
Expand Down
1 change: 1 addition & 0 deletions akka-stream/src/main/scala/akka/stream/impl/Stages.scala
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ import akka.stream.Attributes._
val mergePreferred = name("mergePreferred")
val mergePrioritized = name("mergePrioritized")
val flattenMerge = name("flattenMerge")
val flattenConcat = name("flattenConcat")
val recoverWith = name("recoverWith")
val onErrorComplete = name("onErrorComplete")
val broadcast = name("broadcast")
Expand Down
42 changes: 40 additions & 2 deletions akka-stream/src/main/scala/akka/stream/impl/TraversalBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@ package akka.stream.impl

import scala.collection.immutable.Map.Map1
import scala.language.existentials

import akka.annotation.{ DoNotInherit, InternalApi }
import akka.stream._
import akka.stream.impl.StreamLayout.AtomicModule
import akka.stream.impl.TraversalBuilder.{ AnyFunction1, AnyFunction2 }
import akka.stream.impl.fusing.GraphStageModule
import akka.stream.impl.fusing.GraphStages.SingleSource
import akka.stream.impl.fusing.GraphStages.{ FutureSource, IterableSource, SingleSource }
import akka.stream.scaladsl.Keep
import akka.util.OptionVal

Expand Down Expand Up @@ -369,12 +368,51 @@ import akka.util.OptionVal
}
}

/**
* Try to find `SingleSource` or wrapped such. This is used as a
* performance optimization in FlattenConcat and possibly other places.
*/
def getValuePresentedSource[A >: Null](graph: Graph[SourceShape[A], _]): OptionVal[Graph[SourceShape[A], _]] = {
def isValuePresentedSource(graph: Graph[SourceShape[_ <: A], _]): Boolean = graph match {
case _: SingleSource[_] | _: FutureSource[_] | _: IterableSource[_] | _: JavaStreamSource[_, _] |
_: FailedSource[_] =>
true
case maybeEmpty if isEmptySource(maybeEmpty) => true
case _ => false
}
graph match {
case _ if isValuePresentedSource(graph) => OptionVal.Some(graph)
case _ =>
graph.traversalBuilder match {
case l: LinearTraversalBuilder =>
l.pendingBuilder match {
case OptionVal.Some(a: AtomicTraversalBuilder) =>
a.module match {
case m: GraphStageModule[_, _] =>
m.stage match {
case _ if isValuePresentedSource(m.stage.asInstanceOf[Graph[SourceShape[A], _]]) =>
// It would be != EmptyTraversal if mapMaterializedValue was used and then we can't optimize.
if ((l.traversalSoFar eq EmptyTraversal) && !l.attributes.isAsync)
OptionVal.Some(m.stage.asInstanceOf[Graph[SourceShape[A], _]])
else OptionVal.None
case _ => OptionVal.None
}
case _ => OptionVal.None
}
case _ => OptionVal.None
}
case _ => OptionVal.None
}
}
}

/**
* Test if a Graph is an empty Source.
* */
def isEmptySource(graph: Graph[SourceShape[_], _]): Boolean = graph match {
case source: scaladsl.Source[_, _] if source eq scaladsl.Source.empty => true
case source: javadsl.Source[_, _] if source eq javadsl.Source.empty() => true
case EmptySource => true
case _ => false
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1268,6 +1268,7 @@ private[stream] object Collect {
*/
@InternalApi private[akka] final case class MapAsync[In, Out](parallelism: Int, f: In => Future[Out])
extends GraphStage[FlowShape[In, Out]] {
require(parallelism >= 1, "parallelism should >= 1")

import MapAsync._

Expand Down
Loading