From d01f9496f012ec341a96c57801791ced6590b2b7 Mon Sep 17 00:00:00 2001
From: Taras <voyn1991@gmail.com>
Date: Mon, 3 Jan 2022 12:06:21 +0200
Subject: [PATCH 1/3] Add Producer waitlist for pending send() items

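Previously every pending send() awaited wait_drain() on the last batch
of its partition, so draining one batch woke every waiter at once and
each of them retried the append, which degrades badly with a large
backlog (issue #528). This patch instead parks pending items on a
per-partition FIFO waitlist and resolves them one at a time as batches
are drained. A minimal sketch of the idea, not the actual code (all
names besides those in the diff below are illustrative):

    import asyncio
    import collections

    class Waitlist:
        def __init__(self):
            self._items = collections.defaultdict(collections.deque)

        def park(self, tp, attrs):
            # The caller awaits this future instead of polling wait_drain().
            fut = asyncio.get_running_loop().create_future()
            self._items[tp].append((attrs, fut))
            return fut

        def process(self, tp, try_append):
            # Run once per drained batch: wake parked senders in FIFO order.
            while self._items[tp]:
                attrs, fut = self._items[tp].popleft()
                if fut.done():  # the sender timed out or was cancelled
                    continue
                msg_future = try_append(*attrs)
                if msg_future is None:  # the next batch is already full
                    self._items[tp].appendleft((attrs, fut))
                    break
                fut.set_result(msg_future)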
---
 aiokafka/producer/message_accumulator.py | 140 ++++++++++++++++-------
 1 file changed, 96 insertions(+), 44 deletions(-)

diff --git a/aiokafka/producer/message_accumulator.py b/aiokafka/producer/message_accumulator.py
index a20be471..babd5c83 100644
--- a/aiokafka/producer/message_accumulator.py
+++ b/aiokafka/producer/message_accumulator.py
@@ -2,7 +2,10 @@
 import collections
 import copy
 import time
+from dataclasses import dataclass
+from typing import List, Any
 
+import async_timeout
 from aiokafka.errors import (KafkaTimeoutError,
                              NotLeaderForPartitionError,
                              LeaderNotAvailableError,
@@ -241,6 +244,17 @@ def retry_count(self):
         return self._retry_count
 
 
+@dataclass
+class WaitlistHandle:
+
+    attrs: List[Any]
+    # Waitlist items are either pending batches or pending messages
+    is_message: bool
+    # Future exposed to add_message. Is not shielded, so can be cancelled
+    # before resolving.
+    future: asyncio.Future
+
+
 class MessageAccumulator:
     """Accumulator of messages batched by topic-partition
 
@@ -254,6 +268,7 @@ def __init__(
             loop = get_running_loop()
         self._loop = loop
         self._batches = collections.defaultdict(collections.deque)
+        self._waitlist = collections.defaultdict(collections.deque)
         self._pending_batches = set()
         self._cluster = cluster
         self._batch_size = batch_size
@@ -316,31 +331,67 @@ async def add_message(
         If batch is already full this method waits (`timeout` seconds maximum)
         until batch is drained by send task
         """
-        while True:
-            if self._closed:
-                # this can happen when producer is closing but try to send some
-                # messages in async task
-                raise ProducerClosed()
-            if self._exception is not None:
-                raise copy.copy(self._exception)
-
-            pending_batches = self._batches.get(tp)
-            if not pending_batches:
-                builder = self.create_builder()
-                batch = self._append_batch(builder, tp)
-            else:
-                batch = pending_batches[-1]
+        self._check_errors()
 
-            future = batch.append(key, value, timestamp_ms, headers=headers)
+        if not self._waitlist[tp]:
+            future = self._try_add_message(
+                tp, key, value, timestamp_ms, headers)
             if future is not None:
                 return future
-            # Batch is full, can't append data atm,
-            # waiting until batch per topic-partition is drained
-            start = time.monotonic()
-            await batch.wait_drain(timeout)
-            timeout -= time.monotonic() - start
-            if timeout <= 0:
-                raise KafkaTimeoutError()
+
+        # Batch is full, can't append data at the moment; enqueue the data
+        # to be sent after the batch for this partition is drained.
+        handle = self._add_to_waitlist(
+            tp, True, key, value, timestamp_ms, headers)
+        try:
+            async with async_timeout.timeout(timeout):
+                return await handle.future
+        except asyncio.TimeoutError:
+            raise KafkaTimeoutError()
+
+    def _check_errors(self):
+        if self._closed:
+            # this can happen when the producer is closing while an async
+            # task still tries to send messages
+            raise ProducerClosed()
+        if self._exception is not None:
+            raise copy.copy(self._exception)
+
+    def _try_add_message(self, tp, key, value, timestamp_ms, headers):
+        pending_batches = self._batches.get(tp)
+        if not pending_batches:
+            builder = self.create_builder()
+            batch = self._append_batch(builder, tp)
+        else:
+            batch = pending_batches[-1]
+        return batch.append(key, value, timestamp_ms, headers=headers)
+
+    def _add_to_waitlist(self, tp, is_message, *attrs):
+        handle = WaitlistHandle(attrs, is_message, self._loop.create_future())
+        self._waitlist[tp].append(handle)
+        return handle
+
+    def _process_waitlist(self, tp):
+        while self._waitlist.get(tp):
+            handle = self._waitlist[tp].popleft()
+            # We do not send messages that are no longer waited for, just clean
+            # them up.
+            if handle.future.done():
+                continue
+
+            if handle.is_message:
+                future = self._try_add_message(tp, *handle.attrs)
+                if future is not None:
+                    handle.future.set_result(future)
+            else:
+                if not self._batches.get(tp):
+                    builder = handle.attrs[0]
+                    batch = self._append_batch(builder, tp)
+                    handle.future.set_result(batch.future)
+
+            # Return item to waitlist if it was not processed
+            if not handle.future.done():
+                self._waitlist.appendleft(handle)
 
     def data_waiter(self):
         """ Return waiter future that will be resolved when accumulator contain
@@ -370,6 +421,9 @@ def _pop_batch(self, tp):
             def cb(fut, batch=batch, self=self):
                 self._pending_batches.remove(batch)
             batch.future.add_done_callback(cb)
+
+        # Populate next batch based on waitlist items (if any)
+        self._process_waitlist(tp)
         return batch
 
     def reenqueue(self, batch):
@@ -382,6 +436,13 @@ def drain_by_nodes(self, ignore_nodes, muted_partitions=set()):
         """ Group batches by leader to partition nodes. """
         nodes = collections.defaultdict(dict)
         unknown_leaders_exist = False
+
+        # Renew the data waiter before draining, so that waitlist processing
+        # (which runs while batches are popped) can resolve the new waiter.
+        if not self._wait_data_future.done():
+            self._wait_data_future.set_result(None)
+        self._wait_data_future = self._loop.create_future()
+
         for tp in list(self._batches.keys()):
             # Just ignoring by node is not enough, as leader can change during
             # the cycle
@@ -413,13 +474,6 @@ def drain_by_nodes(self, ignore_nodes, muted_partitions=set()):
                 # delivery future here, no message futures.
                 batch.done_noack()
 
-        # all batches are drained from accumulator
-        # so create "wait data" future again for waiting new data in send
-        # task
-        if not self._wait_data_future.done():
-            self._wait_data_future.set_result(None)
-        self._wait_data_future = self._loop.create_future()
-
         return nodes, unknown_leaders_exist
 
     def create_builder(self):
@@ -467,18 +521,16 @@ async def add_batch(self, builder, tp, timeout):
             aiokafka.errors.KafkaTimeoutError: the batch could not be added
                 within the specified timeout.
         """
-        if self._closed:
-            raise ProducerClosed()
-        if self._exception is not None:
-            raise copy.copy(self._exception)
-
-        start = time.monotonic()
-        while timeout > 0:
-            pending = self._batches.get(tp)
-            if pending:
-                await pending[-1].wait_drain(timeout=timeout)
-                timeout -= time.monotonic() - start
-            else:
-                batch = self._append_batch(builder, tp)
-                return asyncio.shield(batch.future)
-        raise KafkaTimeoutError()
+        self._check_errors()
+
+        pending = self._batches.get(tp)
+        if not pending:
+            batch = self._append_batch(builder, tp)
+            return asyncio.shield(batch.future)
+
+        handle = self._add_to_waitlist(tp, False, builder)
+        try:
+            async with async_timeout.timeout(timeout):
+                return asyncio.shield(await handle.future)
+        except asyncio.TimeoutError:
+            raise KafkaTimeoutError()

From b3b2a0c4ab2188a488a0ff8c1f0abe49b0f1091b Mon Sep 17 00:00:00 2001
From: Taras <voyn1991@gmail.com>
Date: Mon, 3 Jan 2022 12:14:39 +0200
Subject: [PATCH 2/3] Add CHANGES entry

---
 CHANGES/528.feature | 2 ++
 1 file changed, 2 insertions(+)
 create mode 100644 CHANGES/528.feature

diff --git a/CHANGES/528.feature b/CHANGES/528.feature
new file mode 100644
index 00000000..99818538
--- /dev/null
+++ b/CHANGES/528.feature
@@ -0,0 +1,2 @@
+Changed the internal structure used for waiters in the Producer.send() call. There
+should no longer be a performance degradation when a large backlog of messages is pending (issue #528)
\ No newline at end of file

From 0dcb78cfde9b92a355e28b36a6cbbb569687bd20 Mon Sep 17 00:00:00 2001
From: Taras <voyn1991@gmail.com>
Date: Mon, 3 Jan 2022 19:02:09 +0200
Subject: [PATCH 3/3] Changed future names to be more readable, added a test
 for waitlist

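With the waitlist in place the producer hands out two futures per
message: send() itself completes once the message has been appended to
a batch (the handle's send_future), while the future it returns
(deliver_future) completes on broker acknowledgement. A hedged usage
sketch of the caller-side contract (the bootstrap address and topic are
placeholders; assumes the default acks=1 so metadata is returned):

    import asyncio
    from aiokafka import AIOKafkaProducer

    async def main():
        producer = AIOKafkaProducer(bootstrap_servers="localhost:9092")
        await producer.start()
        try:
            # send() returns after the message is accepted into a batch,
            # possibly after a stay on the waitlist.
            deliver_future = await producer.send("my-topic", b"payload")
            # deliver_future resolves once the broker acknowledges.
            metadata = await deliver_future
            print(metadata.offset)
        finally:
            await producer.stop()

    asyncio.run(main())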
---
 aiokafka/producer/message_accumulator.py | 123 ++++++++++-------------
 aiokafka/producer/producer.py            |   8 +-
 tests/test_message_accumulator.py        |  67 +++++++++++-
 3 files changed, 123 insertions(+), 75 deletions(-)

diff --git a/aiokafka/producer/message_accumulator.py b/aiokafka/producer/message_accumulator.py
index babd5c83..38c88b55 100644
--- a/aiokafka/producer/message_accumulator.py
+++ b/aiokafka/producer/message_accumulator.py
@@ -3,7 +3,7 @@
 import copy
 import time
 from dataclasses import dataclass
-from typing import List, Any
+from typing import Any, List, Optional, Tuple
 
 import async_timeout
 from aiokafka.errors import (KafkaTimeoutError,
@@ -114,10 +114,8 @@ def __init__(self, tp, builder, ttl):
 
         # Waiters
         # Set when messages are delivered to Kafka based on ACK setting
-        self.future = create_future()
+        self.deliver_future = create_future()
         self._msg_futures = []
-        # Set when sender takes this batch
-        self._drain_waiter = create_future()
         self._retry_count = 0
 
     @property
@@ -158,8 +156,8 @@ def done(self, base_offset, timestamp=None, log_start_offset=None,
             timestamp_type = 1
 
         # Set main batch future
-        if not self.future.done():
-            self.future.set_result(_record_metadata_class(
+        if not self.deliver_future.done():
+            self.deliver_future.set_result(_record_metadata_class(
                 topic, partition, tp, base_offset, timestamp, timestamp_type,
                 log_start_offset))
 
@@ -179,16 +177,16 @@ def done(self, base_offset, timestamp=None, log_start_offset=None,
     def done_noack(self):
         """ Resolve all pending futures to None """
         # Faster resolve for base_offset=None case.
-        if not self.future.done():
-            self.future.set_result(None)
+        if not self.deliver_future.done():
+            self.deliver_future.set_result(None)
         for future, _ in self._msg_futures:
             if future.done():
                 continue
             future.set_result(None)
 
     def failure(self, exception):
-        if not self.future.done():
-            self.future.set_exception(exception)
+        if not self.deliver_future.done():
+            self.deliver_future.set_exception(exception)
         for future, _ in self._msg_futures:
             if future.done():
                 continue
@@ -199,37 +197,13 @@ def failure(self, exception):
         # Consume exception to avoid warnings. We delegate this consumption
         # to user only in case of explicit batch API.
         if self._msg_futures:
-            self.future.exception()
-
-        # In case where sender fails and closes batches all waiters have to be
-        # reset also.
-        if not self._drain_waiter.done():
-            self._drain_waiter.set_exception(exception)
-
-    async def wait_drain(self, timeout=None):
-        """Wait until all message from this batch is processed"""
-        waiter = self._drain_waiter
-        await asyncio.wait([waiter], timeout=timeout)
-        if waiter.done():
-            waiter.result()  # Check for exception
+            self.deliver_future.exception()
 
     def expired(self):
         """Check that batch is expired or not"""
         return (time.monotonic() - self._ctime) > self._ttl
 
-    def drain_ready(self):
-        """Compress batch to be ready for send"""
-        if not self._drain_waiter.done():
-            self._drain_waiter.set_result(None)
-        self._retry_count += 1
-
-    def reset_drain(self):
-        """Reset drain waiter, until we will do another retry"""
-        assert self._drain_waiter.done()
-        self._drain_waiter = create_future()
-
     def set_producer_state(self, producer_id, producer_epoch, base_sequence):
-        assert not self._drain_waiter.done()
         self._builder._set_producer_state(
             producer_id, producer_epoch, base_sequence)
 
@@ -243,16 +217,23 @@ def is_empty(self):
     def retry_count(self):
         return self._retry_count
 
+    def inc_retry_count(self):
+        self._retry_count += 1
+
+
+HeadersType = List[Tuple[str, Any]]
+
 
 @dataclass
 class WaitlistHandle:
 
-    attrs: List[Any]
     # Waitlist items are either pending batches or pending messages
-    is_message: bool
-    # Future exposed to add_message. Is not shielded, so can be cancelled
+    message_attrs: Optional[Tuple[Any, Any, int, HeadersType]]
+    batch_builder: Optional[BatchBuilder]
+
+    # Future exposed to Producer.send(). Is not shielded, so can be cancelled
     # before resolving.
-    future: asyncio.Future
+    send_future: "asyncio.Future[asyncio.Future[RecordMetadata]]"
 
 
 class MessageAccumulator:
@@ -288,9 +269,9 @@ async def flush(self):
         waiters = []
         for batches in self._batches.values():
             for batch in list(batches):
-                waiters.append(batch.future)
+                waiters.append(batch.deliver_future)
         for batch in list(self._pending_batches):
-            waiters.append(batch.future)
+            waiters.append(batch.deliver_future)
         if waiters:
             await asyncio.wait(waiters)
 
@@ -301,9 +282,9 @@ async def flush_for_commit(self):
                 # We force all buffers to close to finalize the transaction
                 # scope. We should not add anything to this transaction.
                 batch._builder.close()
-                waiters.append(batch.future)
+                waiters.append(batch.deliver_future)
         for batch in self._pending_batches:
-            waiters.append(batch.future)
+            waiters.append(batch.deliver_future)
         # Wait for all waiters to finish. We only wait for the scope we defined
         # above, other batches should not be delivered as part of this
         # transaction
@@ -324,8 +305,7 @@ async def close(self):
         await self.flush()
 
     async def add_message(
-        self, tp, key, value, timeout, timestamp_ms=None,
-        headers=[]
+        self, tp, key, value, timeout, timestamp_ms=None, headers=[]
     ):
         """ Add message to batch by topic-partition
         If batch is already full this method waits (`timeout` seconds maximum)
@@ -333,7 +313,7 @@ async def add_message(
         """
         self._check_errors()
 
-        if not self._waitlist[tp]:
+        if not self._waitlist.get(tp):
             future = self._try_add_message(
                 tp, key, value, timestamp_ms, headers)
             if future is not None:
@@ -341,11 +321,15 @@ async def add_message(
 
         # Batch is full, can't append data at the moment; enqueue the data
         # to be sent after the batch for this partition is drained.
-        handle = self._add_to_waitlist(
-            tp, True, key, value, timestamp_ms, headers)
+        handle = WaitlistHandle(
+            message_attrs=(key, value, timestamp_ms, headers),
+            batch_builder=None,
+            send_future=self._loop.create_future())
+        self._waitlist[tp].append(handle)
+
         try:
             async with async_timeout.timeout(timeout):
-                return await handle.future
+                return await handle.send_future
         except asyncio.TimeoutError:
             raise KafkaTimeoutError()
 
@@ -366,32 +350,28 @@ def _try_add_message(self, tp, key, value, timestamp_ms, headers):
             batch = pending_batches[-1]
         return batch.append(key, value, timestamp_ms, headers=headers)
 
-    def _add_to_waitlist(self, tp, is_message, *attrs):
-        handle = WaitlistHandle(attrs, is_message, self._loop.create_future())
-        self._waitlist[tp].append(handle)
-        return handle
-
     def _process_waitlist(self, tp):
         while self._waitlist.get(tp):
             handle = self._waitlist[tp].popleft()
             # We do not send messages that are no longer waited for, just clean
             # them up.
-            if handle.future.done():
+            if handle.send_future.done():
                 continue
 
-            if handle.is_message:
-                future = self._try_add_message(tp, *handle.attrs)
-                if future is not None:
-                    handle.future.set_result(future)
+            if handle.batch_builder is None:
+                msg_future = self._try_add_message(tp, *handle.message_attrs)
+                if msg_future is not None:
+                    handle.send_future.set_result(msg_future)
             else:
                 if not self._batches.get(tp):
-                    builder = handle.attrs[0]
+                    builder = handle.batch_builder
                     batch = self._append_batch(builder, tp)
-                    handle.future.set_result(batch.future)
+                    handle.send_future.set_result(batch.deliver_future)
 
             # Return item to waitlist if it was not processed
-            if not handle.future.done():
-                self._waitlist.appendleft(handle)
+            if not handle.send_future.done():
+                self._waitlist[tp].appendleft(handle)
+                break
 
     def data_waiter(self):
         """ Return waiter future that will be resolved when accumulator contain
@@ -412,7 +392,6 @@ def _pop_batch(self, tp):
                 producer_id=self._txn_manager.producer_id,
                 producer_epoch=self._txn_manager.producer_epoch,
                 base_sequence=seq)
-        batch.drain_ready()
         if len(self._batches[tp]) == 0:
             del self._batches[tp]
         self._pending_batches.add(batch)
@@ -420,8 +399,9 @@ def _pop_batch(self, tp):
         if not_retry:
             def cb(fut, batch=batch, self=self):
                 self._pending_batches.remove(batch)
-            batch.future.add_done_callback(cb)
+            batch.deliver_future.add_done_callback(cb)
 
+        batch.inc_retry_count()
         # Populate next batch based on waitlist items (if any)
         self._process_waitlist(tp)
         return batch
@@ -430,7 +410,6 @@ def reenqueue(self, batch):
         tp = batch.tp
         self._batches[tp].appendleft(batch)
         self._pending_batches.remove(batch)
-        batch.reset_drain()
 
     def drain_by_nodes(self, ignore_nodes, muted_partitions=set()):
         """ Group batches by leader to partition nodes. """
@@ -526,11 +505,17 @@ async def add_batch(self, builder, tp, timeout):
         pending = self._batches.get(tp)
         if not pending:
             batch = self._append_batch(builder, tp)
-            return asyncio.shield(batch.future)
+            return asyncio.shield(batch.deliver_future)
 
-        handle = self._add_to_waitlist(tp, False, builder)
+        # Delay the send until there are no pending batches
+        handle = WaitlistHandle(
+            message_attrs=None,
+            batch_builder=builder,
+            send_future=self._loop.create_future())
+        self._waitlist[tp].append(handle)
         try:
             async with async_timeout.timeout(timeout):
-                return asyncio.shield(await handle.future)
+                batch_deliver_future = await handle.send_future
+                return asyncio.shield(batch_deliver_future)
         except asyncio.TimeoutError:
             raise KafkaTimeoutError()
diff --git a/aiokafka/producer/producer.py b/aiokafka/producer/producer.py
index 98b10fb7..031dc10b 100644
--- a/aiokafka/producer/producer.py
+++ b/aiokafka/producer/producer.py
@@ -464,10 +464,10 @@ async def send(
         tp = TopicPartition(topic, partition)
         log.debug("Sending (key=%s value=%s) to %s", key, value, tp)
 
-        fut = await self._message_accumulator.add_message(
+        deliver_future = await self._message_accumulator.add_message(
             tp, key_bytes, value_bytes, self._request_timeout_ms / 1000,
             timestamp_ms=timestamp_ms, headers=headers)
-        return fut
+        return deliver_future
 
     async def send_and_wait(
         self, topic, value=None, key=None, partition=None,
@@ -515,9 +515,9 @@ async def send_batch(self, batch, topic, *, partition):
 
         tp = TopicPartition(topic, partition)
         log.debug("Sending batch to %s", tp)
-        future = await self._message_accumulator.add_batch(
+        deliver_future = await self._message_accumulator.add_batch(
             batch, tp, self._request_timeout_ms / 1000)
-        return future
+        return deliver_future
 
     def _ensure_transactional(self):
         if self._txn_manager is None or \
diff --git a/tests/test_message_accumulator.py b/tests/test_message_accumulator.py
index 406526f1..5cd79d17 100644
--- a/tests/test_message_accumulator.py
+++ b/tests/test_message_accumulator.py
@@ -316,11 +316,74 @@ def mocked_leader_for_partition(tp):
         self.assertFalse(ma._batches)
         self.assertFalse(fut1.done())
 
-        if hasattr(batch.future, "_callbacks"):  # Vanilla asyncio
-            self.assertEqual(len(batch.future._callbacks), 1)
+        if hasattr(batch.deliver_future, "_callbacks"):  # Vanilla asyncio
+            self.assertEqual(len(batch.deliver_future._callbacks), 1)
 
         batch.done_noack()
         await asyncio.sleep(0.01)
         self.assertEqual(batch.retry_count, 3)
         self.assertFalse(ma._pending_batches)
         self.assertFalse(ma._batches)
+
+    @run_until_complete
+    async def test_waitlist_message(self):
+        cluster = ClusterMetadata(metadata_max_age_ms=10000)
+        # Use small batch_size to force waitlist for messages
+        ma = MessageAccumulator(
+            cluster, batch_size=100, compression_type=0, batch_ttl=30)
+
+        tp0 = TopicPartition("test-topic", 0)
+        tp1 = TopicPartition("test-topic", 1)
+        # 1st message will be added as size limit allows it
+        await ma.add_message(tp0, b'key', b'm'*35, timeout=2)
+        await ma.add_message(tp1, b'key_tp1', b'y'*20, timeout=2)
+
+        # 2nd message will be waitlisted
+        task = create_task(ma.add_message(tp0, b'key1', b'm'*100, timeout=2))
+        done, _ = await asyncio.wait([task], timeout=0.2)
+        self.assertFalse(bool(done))
+        # 3rd message also waitlisted
+        task2 = create_task(ma.add_message(tp0, b'key2', b'm'*100, timeout=2))
+
+        data_waiter = asyncio.ensure_future(ma.data_waiter())
+        done, _ = await asyncio.wait([data_waiter], timeout=0.2)
+        self.assertTrue(bool(done))  # data available for drain
+
+        def mocked_leader_for_partition(tp):
+            if tp == tp0:
+                return 0
+            if tp == tp1:
+                return 1
+            return -1
+
+        cluster.leader_for_partition = mock.MagicMock()
+        cluster.leader_for_partition.side_effect = mocked_leader_for_partition
+        batches, unknown_leaders_exist = ma.drain_by_nodes(ignore_nodes=[])
+
+        self.assertEqual(batches[0][tp0].record_count, 1)
+        self.assertEqual(batches[1][tp1].record_count, 1)
+
+        data_waiter = asyncio.ensure_future(ma.data_waiter())
+        done, _ = await asyncio.wait([data_waiter], timeout=0.2)
+        # data still available as waitlist items were processed
+        self.assertTrue(bool(done))
+
+        # Waitlist was processed, so send() should also finish
+        self.assertTrue(task.done())
+        deliver_fut = task.result()
+        self.assertTrue(isinstance(deliver_fut, asyncio.Future))
+        self.assertFalse(deliver_fut.done())
+
+        # 3rd message should be retained in the waitlist
+        self.assertFalse(task2.done())
+
+        batches, unknown_leaders_exist = ma.drain_by_nodes(ignore_nodes=[])
+        self.assertEqual(len(batches), 1)
+        self.assertEqual(batches[0][tp0].record_count, 1)
+        await asyncio.wait([task2], timeout=0.2)
+        # 3rd message is now also submitted for execution
+        self.assertTrue(task2.done())
+
+        batches[0][tp0].done_noack()
+        done, _ = await asyncio.wait([deliver_fut], timeout=0.1)
+        self.assertTrue(bool(done))  # waitlisted message delivered