Serialise queue offers across shard record processors.

Every IRecordProcessor created by the factory shares one fair, single-permit
Semaphore, so only one shard thread at a time has a queue.offer in flight.
The permit is released in a `finally` block so that an exception escaping the
offer cannot leave the other shard processors blocked on acquire.

The spec gains a second record processor and a test that feeds records from
two simulated shard threads concurrently.

diff --git a/src/main/scala/aserralle/akka/stream/kcl/scaladsl/KinesisWorkerSource.scala b/src/main/scala/aserralle/akka/stream/kcl/scaladsl/KinesisWorkerSource.scala
index 29dfcfd..cecf093 100644
--- a/src/main/scala/aserralle/akka/stream/kcl/scaladsl/KinesisWorkerSource.scala
+++ b/src/main/scala/aserralle/akka/stream/kcl/scaladsl/KinesisWorkerSource.scala
@@ -4,6 +4,8 @@
 
 package aserralle.akka.stream.kcl.scaladsl
 
+import java.util.concurrent.Semaphore
+
 import akka.stream.Supervision.{Resume, Stop}
 import akka.stream._
 import akka.stream.scaladsl.{Flow, GraphDSL, Keep, Sink, Source, Zip}
@@ -41,15 +43,27 @@ object KinesisWorkerSource {
       .watchTermination()(Keep.both)
       .mapMaterializedValue {
         case (queue, watch) =>
+          // One fair permit shared by every processor this factory creates, so
+          // only one shard thread at a time has a queue offer in flight.
+          val semaphore = new Semaphore(1, true)
           val worker = workerBuilder(
             new IRecordProcessorFactory {
               override def createProcessor(): IRecordProcessor =
                 new IRecordProcessor(
-                  record =>
-                    (Exception.nonFatalCatch either Await.result(
-                      queue.offer(record),
-                      settings.backpressureTimeout) left)
-                      .foreach(_ => queue.fail(BackpressureTimeout)),
+                  record => {
+                    semaphore.acquire(1)
+                    // Release in `finally`: nonFatalCatch lets fatal errors and
+                    // InterruptedException through, and a leaked permit would
+                    // block every other shard processor on acquire forever.
+                    try {
+                      (Exception.nonFatalCatch either Await.result(
+                        queue.offer(record),
+                        settings.backpressureTimeout) left)
+                        .foreach(_ => queue.fail(BackpressureTimeout))
+                    } finally {
+                      semaphore.release()
+                    }
+                  },
                   settings.terminateStreamGracePeriod
                 )
             }
diff --git a/src/test/scala/aserralle/akka/stream/kcl/KinesisWorkerSourceSourceSpec.scala b/src/test/scala/aserralle/akka/stream/kcl/KinesisWorkerSourceSourceSpec.scala
index 2f7dfcf..67443f0 100644
--- a/src/test/scala/aserralle/akka/stream/kcl/KinesisWorkerSourceSourceSpec.scala
+++ b/src/test/scala/aserralle/akka/stream/kcl/KinesisWorkerSourceSourceSpec.scala
@@ -6,9 +6,9 @@ package aserralle.akka.stream.kcl
 
 import java.nio.ByteBuffer
 import java.util.Date
-import java.util.concurrent.Semaphore
-import aserralle.akka.stream.kcl.Errors.BackpressureTimeout
+import java.util.concurrent.{CountDownLatch, Semaphore}
 
+import aserralle.akka.stream.kcl.Errors.BackpressureTimeout
 import akka.stream.KillSwitches
 import akka.stream.scaladsl.Keep
 import akka.stream.testkit.scaladsl.{TestSink, TestSource}
@@ -144,6 +144,51 @@ class KinesisWorkerSourceSourceSpec
       sinkProbe.expectComplete()
     }
 
+    "not drop messages in case of back-pressure with multiple shard workers" in new KinesisWorkerContext
+    with TestData {
+      recordProcessor.initialize(initializationInput)
+      recordProcessor2.initialize(initializationInput.withShardId("shard2"))
+
+      for (i <- 1 to 5) { // 5 per shard = 10 in total; 10 is the buffer size
+        val record = org.mockito.Mockito.mock(classOf[Record])
+        when(record.getSequenceNumber).thenReturn(i.toString)
+        recordProcessor.processRecords(
+          recordsInput.withRecords(List(record).asJava))
+        recordProcessor2.processRecords(
+          recordsInput.withRecords(List(record).asJava))
+      }
+
+      // expect to consume all 10 across both shards
+      for (_ <- 1 to 10) sinkProbe.requestNext()
+
+      // Each shard is assigned its own worker thread, so we get messages
+      // from each thread simultaneously.
+      def simulateWorkerThread(rp: v2.IRecordProcessor): Future[Unit] = {
+        Future {
+          for (i <- 1 to 25) { // 25 per shard = 50 in total, beyond the buffer size of 10
+            val record = org.mockito.Mockito.mock(classOf[Record])
+            when(record.getSequenceNumber).thenReturn(i.toString)
+            rp.processRecords(recordsInput.withRecords(List(record).asJava))
+          }
+        }
+      }
+
+      // send another batch to exceed the queue size - this is shard 1
+      simulateWorkerThread(recordProcessor)
+
+      // send another batch to exceed the queue size - this is shard 2
+      simulateWorkerThread(recordProcessor2)
+
+      // expect to consume all 50 with slow consumer
+      for (_ <- 1 to 50) {
+        sinkProbe.requestNext()
+        Thread.sleep(100)
+      }
+
+      killSwitch.shutdown()
+      sinkProbe.expectComplete()
+    }
+
     "stop the stream when back pressure timeout elapsed" in new KinesisWorkerContext(
       backpressureTimeout = 100.milliseconds) with TestData {
       recordProcessor.initialize(initializationInput)
@@ -177,9 +222,11 @@ class KinesisWorkerSourceSourceSpec
 
     var recordProcessorFactory: IRecordProcessorFactory = _
     var recordProcessor: v2.IRecordProcessor = _
+    var recordProcessor2: v2.IRecordProcessor = _
     val workerBuilder = { x: IRecordProcessorFactory =>
       recordProcessorFactory = x
       recordProcessor = x.createProcessor()
+      recordProcessor2 = x.createProcessor()
       semaphore.release()
       worker
     }