Skip to content

Commit 9c2dfab

Browse files
authored
Expand Sender test coverage (#2586)
1 parent 3962d67 commit 9c2dfab

File tree

1 file changed

+183
-9
lines changed

1 file changed

+183
-9
lines changed

Diff for: test/test_sender.py

+183-9
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,23 @@
11
# pylint: skip-file
22
from __future__ import absolute_import
33

4-
import pytest
4+
import collections
55
import io
6+
import time
7+
8+
import pytest
9+
from unittest.mock import call
10+
11+
from kafka.vendor import six
612

713
from kafka.client_async import KafkaClient
14+
import kafka.errors as Errors
815
from kafka.protocol.broker_api_versions import BROKER_API_VERSIONS
916
from kafka.producer.kafka import KafkaProducer
1017
from kafka.protocol.produce import ProduceRequest
1118
from kafka.producer.record_accumulator import RecordAccumulator, ProducerBatch
1219
from kafka.producer.sender import Sender
20+
from kafka.producer.transaction_state import TransactionState
1321
from kafka.record.memory_records import MemoryRecordsBuilder
1422
from kafka.structs import TopicPartition
1523

@@ -20,8 +28,18 @@ def accumulator():
2028

2129

2230
@pytest.fixture
23-
def sender(client, accumulator, metrics, mocker):
24-
return Sender(client, client.cluster, accumulator, metrics=metrics)
31+
def sender(client, accumulator):
32+
return Sender(client, client.cluster, accumulator)
33+
34+
35+
def producer_batch(topic='foo', partition=0, magic=2):
36+
tp = TopicPartition(topic, partition)
37+
records = MemoryRecordsBuilder(
38+
magic=magic, compression_type=0, batch_size=100000)
39+
batch = ProducerBatch(tp, records)
40+
batch.try_append(0, None, b'msg', [])
41+
batch.records.close()
42+
return batch
2543

2644

2745
@pytest.mark.parametrize(("api_version", "produce_version"), [
@@ -30,13 +48,169 @@ def sender(client, accumulator, metrics, mocker):
3048
((0, 9), 1),
3149
((0, 8, 0), 0)
3250
])
33-
def test_produce_request(sender, mocker, api_version, produce_version):
51+
def test_produce_request(sender, api_version, produce_version):
3452
sender._client._api_versions = BROKER_API_VERSIONS[api_version]
35-
tp = TopicPartition('foo', 0)
3653
magic = KafkaProducer.max_usable_produce_magic(api_version)
37-
records = MemoryRecordsBuilder(
38-
magic=1, compression_type=0, batch_size=100000)
39-
batch = ProducerBatch(tp, records)
40-
records.close()
54+
batch = producer_batch(magic=magic)
4155
produce_request = sender._produce_request(0, 0, 0, [batch])
4256
assert isinstance(produce_request, ProduceRequest[produce_version])
57+
58+
59+
@pytest.mark.parametrize(("api_version", "produce_version"), [
60+
((2, 1), 7),
61+
])
62+
def test_create_produce_requests(sender, api_version, produce_version):
63+
sender._client._api_versions = BROKER_API_VERSIONS[api_version]
64+
tp = TopicPartition('foo', 0)
65+
magic = KafkaProducer.max_usable_produce_magic(api_version)
66+
batches_by_node = collections.defaultdict(list)
67+
for node in range(3):
68+
for _ in range(5):
69+
batches_by_node[node].append(producer_batch(magic=magic))
70+
produce_requests_by_node = sender._create_produce_requests(batches_by_node)
71+
assert len(produce_requests_by_node) == 3
72+
for node in range(3):
73+
assert isinstance(produce_requests_by_node[node], ProduceRequest[produce_version])
74+
75+
76+
def test_complete_batch_success(sender):
77+
batch = producer_batch()
78+
assert not batch.produce_future.is_done
79+
80+
# No error, base_offset 0
81+
sender._complete_batch(batch, None, 0, timestamp_ms=123, log_start_offset=456)
82+
assert batch.is_done
83+
assert batch.produce_future.is_done
84+
assert batch.produce_future.succeeded()
85+
assert batch.produce_future.value == (0, 123, 456)
86+
87+
88+
def test_complete_batch_transaction(sender):
89+
sender._transaction_state = TransactionState()
90+
batch = producer_batch()
91+
assert sender._transaction_state.sequence_number(batch.topic_partition) == 0
92+
assert sender._transaction_state.producer_id_and_epoch.producer_id == batch.producer_id
93+
94+
# No error, base_offset 0
95+
sender._complete_batch(batch, None, 0)
96+
assert batch.is_done
97+
assert sender._transaction_state.sequence_number(batch.topic_partition) == batch.record_count
98+
99+
100+
@pytest.mark.parametrize(("error", "refresh_metadata"), [
101+
(Errors.KafkaConnectionError, True),
102+
(Errors.CorruptRecordError, False),
103+
(Errors.UnknownTopicOrPartitionError, True),
104+
(Errors.NotLeaderForPartitionError, True),
105+
(Errors.MessageSizeTooLargeError, False),
106+
(Errors.InvalidTopicError, False),
107+
(Errors.RecordListTooLargeError, False),
108+
(Errors.NotEnoughReplicasError, False),
109+
(Errors.NotEnoughReplicasAfterAppendError, False),
110+
(Errors.InvalidRequiredAcksError, False),
111+
(Errors.TopicAuthorizationFailedError, False),
112+
(Errors.UnsupportedForMessageFormatError, False),
113+
(Errors.InvalidProducerEpochError, False),
114+
(Errors.ClusterAuthorizationFailedError, False),
115+
(Errors.TransactionalIdAuthorizationFailedError, False),
116+
])
117+
def test_complete_batch_error(sender, error, refresh_metadata):
118+
sender._client.cluster._last_successful_refresh_ms = (time.time() - 10) * 1000
119+
sender._client.cluster._need_update = False
120+
assert sender._client.cluster.ttl() > 0
121+
batch = producer_batch()
122+
sender._complete_batch(batch, error, -1)
123+
if refresh_metadata:
124+
assert sender._client.cluster.ttl() == 0
125+
else:
126+
assert sender._client.cluster.ttl() > 0
127+
assert batch.is_done
128+
assert batch.produce_future.failed()
129+
assert isinstance(batch.produce_future.exception, error)
130+
131+
132+
@pytest.mark.parametrize(("error", "retry"), [
133+
(Errors.KafkaConnectionError, True),
134+
(Errors.CorruptRecordError, False),
135+
(Errors.UnknownTopicOrPartitionError, True),
136+
(Errors.NotLeaderForPartitionError, True),
137+
(Errors.MessageSizeTooLargeError, False),
138+
(Errors.InvalidTopicError, False),
139+
(Errors.RecordListTooLargeError, False),
140+
(Errors.NotEnoughReplicasError, True),
141+
(Errors.NotEnoughReplicasAfterAppendError, True),
142+
(Errors.InvalidRequiredAcksError, False),
143+
(Errors.TopicAuthorizationFailedError, False),
144+
(Errors.UnsupportedForMessageFormatError, False),
145+
(Errors.InvalidProducerEpochError, False),
146+
(Errors.ClusterAuthorizationFailedError, False),
147+
(Errors.TransactionalIdAuthorizationFailedError, False),
148+
])
149+
def test_complete_batch_retry(sender, accumulator, mocker, error, retry):
150+
sender.config['retries'] = 1
151+
mocker.spy(sender, '_fail_batch')
152+
mocker.patch.object(accumulator, 'reenqueue')
153+
batch = producer_batch()
154+
sender._complete_batch(batch, error, -1)
155+
if retry:
156+
assert not batch.is_done
157+
accumulator.reenqueue.assert_called_with(batch)
158+
batch.attempts += 1 # normally handled by accumulator.reenqueue, but it's mocked
159+
sender._complete_batch(batch, error, -1)
160+
assert batch.is_done
161+
assert isinstance(batch.produce_future.exception, error)
162+
else:
163+
assert batch.is_done
164+
assert isinstance(batch.produce_future.exception, error)
165+
166+
167+
def test_complete_batch_producer_id_changed_no_retry(sender, accumulator, mocker):
168+
sender._transaction_state = TransactionState()
169+
sender.config['retries'] = 1
170+
mocker.spy(sender, '_fail_batch')
171+
mocker.patch.object(accumulator, 'reenqueue')
172+
error = Errors.NotLeaderForPartitionError
173+
batch = producer_batch()
174+
sender._complete_batch(batch, error, -1)
175+
assert not batch.is_done
176+
accumulator.reenqueue.assert_called_with(batch)
177+
batch.records._producer_id = 123 # simulate different producer_id
178+
assert batch.producer_id != sender._transaction_state.producer_id_and_epoch.producer_id
179+
sender._complete_batch(batch, error, -1)
180+
assert batch.is_done
181+
assert isinstance(batch.produce_future.exception, error)
182+
183+
184+
def test_fail_batch(sender, accumulator, mocker):
185+
sender._transaction_state = TransactionState()
186+
mocker.patch.object(TransactionState, 'reset_producer_id')
187+
batch = producer_batch()
188+
mocker.patch.object(batch, 'done')
189+
assert sender._transaction_state.producer_id_and_epoch.producer_id == batch.producer_id
190+
error = Exception('error')
191+
sender._fail_batch(batch, base_offset=0, timestamp_ms=None, exception=error, log_start_offset=None)
192+
sender._transaction_state.reset_producer_id.assert_called_once()
193+
batch.done.assert_called_with(base_offset=0, timestamp_ms=None, exception=error, log_start_offset=None)
194+
195+
196+
def test_handle_produce_response():
197+
pass
198+
199+
200+
def test_failed_produce(sender, mocker):
201+
mocker.patch.object(sender, '_complete_batch')
202+
mock_batches = ['foo', 'bar', 'fizzbuzz']
203+
sender._failed_produce(mock_batches, 0, 'error')
204+
sender._complete_batch.assert_has_calls([
205+
call('foo', 'error', -1),
206+
call('bar', 'error', -1),
207+
call('fizzbuzz', 'error', -1),
208+
])
209+
210+
211+
def test_maybe_wait_for_producer_id():
212+
pass
213+
214+
215+
def test_run_once():
216+
pass

0 commit comments

Comments
 (0)