diff --git a/kafka/producer/kafka.py b/kafka/producer/kafka.py
index dd1cc508c..56ad2075a 100644
--- a/kafka/producer/kafka.py
+++ b/kafka/producer/kafka.py
@@ -538,7 +538,7 @@ def _estimate_size_in_bytes(self, key, value, headers=[]):
             return LegacyRecordBatchBuilder.estimate_size_in_bytes(
                 magic, self.config['compression_type'], key, value)
 
-    def send(self, topic, value=None, key=None, headers=None, partition=None, timestamp_ms=None):
+    def send(self, topic, value=None, key=None, headers=None, partition=None, timestamp_ms=None, chain_future=None):
         """Publish a message to a topic.
 
         Arguments:
@@ -563,6 +563,7 @@ def send(self, topic, value=None, key=None, headers=None, partition=None, timest
                 are tuples of str key and bytes value.
             timestamp_ms (int, optional): epoch milliseconds (from Jan 1 1970 UTC)
                 to use as the message timestamp. Defaults to current time.
+            chain_future (Future, optional): a Future to chain to the returned FutureRecordMetadata (resolved with the same success or failure result)
 
         Returns:
             FutureRecordMetadata: resolves to RecordMetadata
@@ -603,7 +604,8 @@ def send(self, topic, value=None, key=None, headers=None, partition=None, timest
             result = self._accumulator.append(tp, timestamp_ms,
                                               key_bytes, value_bytes, headers,
                                               self.config['max_block_ms'],
-                                              estimated_size=message_size)
+                                              estimated_size=message_size,
+                                              chain_future=chain_future)
             future, batch_is_full, new_batch_created = result
             if batch_is_full or new_batch_created:
                 log.debug("Waking up the sender since %s is either full or"
diff --git a/kafka/producer/record_accumulator.py b/kafka/producer/record_accumulator.py
index a2aa0e8ec..b5db887ff 100644
--- a/kafka/producer/record_accumulator.py
+++ b/kafka/producer/record_accumulator.py
@@ -7,6 +7,7 @@
 import time
 
 import kafka.errors as Errors
+from kafka.future import Future
 from kafka.producer.buffer import SimpleBufferPool
 from kafka.producer.future import FutureRecordMetadata, FutureProduceResult
 from kafka.record.memory_records import MemoryRecordsBuilder
@@ -198,7 +199,7 @@ def __init__(self, **configs):
         self._drain_index = 0
 
     def append(self, tp, timestamp_ms, key, value, headers, max_time_to_block_ms,
-               estimated_size=0):
+               estimated_size=0, chain_future=None):
         """Add a record to the accumulator, return the append result.
 
         The append result will contain the future metadata, and flag for
@@ -213,12 +214,14 @@ def append(self, tp, timestamp_ms, key, value, headers, max_time_to_block_ms,
             headers (List[Tuple[str, bytes]]): The header fields for the record
             max_time_to_block_ms (int): The maximum time in milliseconds to
                 block for buffer memory to be available
-
+            chain_future (Future, optional): a Future to chain to this record's produce future
         Returns:
             tuple: (future, batch_is_full, new_batch_created)
         """
         assert isinstance(tp, TopicPartition), 'not TopicPartition'
         assert not self._closed, 'RecordAccumulator is closed'
+        if chain_future is not None:
+            assert isinstance(chain_future, Future), 'not Future'
         # We keep track of the number of appending thread to make sure we do
         # not miss batches in abortIncompleteBatches().
         self._appends_in_progress.increment()
@@ -235,6 +238,8 @@ def append(self, tp, timestamp_ms, key, value, headers, max_time_to_block_ms,
                     last = dq[-1]
                     future = last.try_append(timestamp_ms, key, value, headers)
                     if future is not None:
+                        if chain_future:
+                            future.chain(chain_future)
                         batch_is_full = len(dq) > 1 or last.records.is_full()
                         return future, batch_is_full, False
 
@@ -253,6 +258,8 @@ def append(self, tp, timestamp_ms, key, value, headers, max_time_to_block_ms,
                         # Somebody else found us a batch, return the one we
                         # waited for! Hopefully this doesn't happen often...
                         self._free.deallocate(buf)
+                        if chain_future:
+                            future.chain(chain_future)
                         batch_is_full = len(dq) > 1 or last.records.is_full()
                         return future, batch_is_full, False
 
@@ -269,6 +276,8 @@ def append(self, tp, timestamp_ms, key, value, headers, max_time_to_block_ms,
                 dq.append(batch)
                 self._incomplete.add(batch)
 
+                if chain_future:
+                    future.chain(chain_future)
                 batch_is_full = len(dq) > 1 or batch.records.is_full()
                 return future, batch_is_full, True
         finally:
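
Usage sketch (not part of the patch): a minimal example of how the new chain_future argument could be used once this patch is applied. KafkaProducer, kafka.future.Future, its add_callback/add_errback methods, and Future.chain() are existing kafka-python APIs; the broker address, topic name, and callback bodies are illustrative placeholders.

    from kafka import KafkaProducer
    from kafka.future import Future

    producer = KafkaProducer(bootstrap_servers='localhost:9092')

    # Caller-owned future. Once the record lands in a batch,
    # RecordAccumulator.append() chains the internal FutureRecordMetadata
    # to it, so it receives the same RecordMetadata on success or the
    # same exception on failure.
    tracking = Future()
    tracking.add_callback(lambda metadata: print('delivered:', metadata))
    tracking.add_errback(lambda exc: print('failed:', exc))

    producer.send('my-topic', value=b'payload', chain_future=tracking)
    producer.flush()

Note that append() chains the future on all three of its return paths (record appended to an existing batch, batch created by another thread while waiting on the buffer pool, and newly created batch), so the caller's future is attached exactly once whichever path the record takes.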