diff --git a/faststream/confluent/__init__.py b/faststream/confluent/__init__.py index 3644704434..2e97958d43 100644 --- a/faststream/confluent/__init__.py +++ b/faststream/confluent/__init__.py @@ -3,7 +3,7 @@ try: from .annotations import KafkaMessage from .broker import KafkaBroker, KafkaPublisher, KafkaRoute, KafkaRouter - from .response import KafkaPublishCommand, KafkaResponse + from .response import KafkaPublishCommand, KafkaPublishMessage, KafkaResponse from .schemas import TopicPartition from .testing import TestKafkaBroker @@ -19,6 +19,7 @@ "KafkaBroker", "KafkaMessage", "KafkaPublishCommand", + "KafkaPublishMessage", "KafkaPublisher", "KafkaResponse", "KafkaRoute", diff --git a/faststream/confluent/publisher/factory.py b/faststream/confluent/publisher/factory.py index fc22c78e05..06e83ca3c4 100644 --- a/faststream/confluent/publisher/factory.py +++ b/faststream/confluent/publisher/factory.py @@ -2,8 +2,6 @@ from functools import wraps from typing import TYPE_CHECKING, Any -from faststream.exceptions import SetupError - from .config import KafkaPublisherConfig, KafkaPublisherSpecificationConfig from .specification import KafkaPublisherSpecification from .usecase import BatchPublisher, DefaultPublisher @@ -54,10 +52,6 @@ def create_publisher( publisher: BatchPublisher | DefaultPublisher if batch: - if key: - msg = "You can't setup `key` with batch publisher" - raise SetupError(msg) - publisher = BatchPublisher(publisher_config, specification) publish_method = "_basic_publish_batch" diff --git a/faststream/confluent/publisher/producer.py b/faststream/confluent/publisher/producer.py index c3af99518d..6f0b61f58b 100644 --- a/faststream/confluent/publisher/producer.py +++ b/faststream/confluent/publisher/producer.py @@ -157,7 +157,7 @@ async def publish_batch(self, cmd: "KafkaPublishCommand") -> None: headers_to_send = cmd.headers_to_publish() - for msg in cmd.batch_bodies: + for message_position, msg in enumerate(cmd.batch_bodies): message, content_type = encode_message(msg, serializer=self.serializer) if content_type: @@ -169,7 +169,7 @@ async def publish_batch(self, cmd: "KafkaPublishCommand") -> None: final_headers = headers_to_send.copy() batch.append( - key=None, + key=cmd.key_for(message_position), value=message, timestamp=cmd.timestamp_ms, headers=[(i, j.encode()) for i, j in final_headers.items()], diff --git a/faststream/confluent/publisher/usecase.py b/faststream/confluent/publisher/usecase.py index 601fa2895a..e60c980f0a 100644 --- a/faststream/confluent/publisher/usecase.py +++ b/faststream/confluent/publisher/usecase.py @@ -215,11 +215,20 @@ async def request( class BatchPublisher(LogicPublisher): + def __init__( + self, + config: "KafkaPublisherConfig", + specification: "PublisherSpecification[Any, Any]", + ) -> None: + super().__init__(config, specification) + self.key = config.key + @override async def publish( self, *messages: "SendableMessage", topic: str = "", + key: bytes | str | None = None, partition: int | None = None, timestamp_ms: int | None = None, headers: dict[str, str] | None = None, @@ -229,7 +238,7 @@ async def publish( ) -> None: cmd = KafkaPublishCommand( *messages, - key=None, + key=key or self.key, topic=topic or self.topic, partition=partition or self.partition, reply_to=reply_to or self.reply_to, @@ -261,6 +270,7 @@ async def _publish( cmd.reply_to = cmd.reply_to or self.reply_to cmd.partition = cmd.partition or self.partition + cmd.key = cmd.key or self.key await self._basic_publish_batch( cmd, diff --git a/faststream/confluent/response.py b/faststream/confluent/response.py index 32f0e6462d..c15fd3d8d0 100644 --- a/faststream/confluent/response.py +++ b/faststream/confluent/response.py @@ -3,13 +3,28 @@ from typing_extensions import override from faststream.response.publish_type import PublishType -from faststream.response.response import BatchPublishCommand, PublishCommand, Response +from faststream.response.response import ( + BatchPublishCommand, + PublishCommand, + Response, + extract_per_message_keys_and_bodies, + key_for_index, +) if TYPE_CHECKING: from faststream._internal.basic_types import SendableMessage class KafkaResponse(Response): + """Kafka-specific response object for outgoing messages. + + Can be used in two ways: + 1. As a return value from handler to send a response message + 2. Directly in publish_batch() to set per-message attributes (key, headers, etc.) + + For publish operations, consider using the more semantic alias `KafkaPublishMessage`. + """ + def __init__( self, body: "SendableMessage", @@ -17,7 +32,7 @@ def __init__( headers: dict[str, Any] | None = None, correlation_id: str | None = None, timestamp_ms: int | None = None, - key: bytes | str | None = None, + key: bytes | Any | None = None, ) -> None: super().__init__( body=body, @@ -28,6 +43,11 @@ def __init__( self.timestamp_ms = timestamp_ms self.key = key + @override + def get_publish_key(self) -> bytes | Any | None: + """Return the Kafka message key for publishing.""" + return self.key + @override def as_publish_command(self) -> "KafkaPublishCommand": return KafkaPublishCommand( @@ -50,7 +70,7 @@ def __init__( *messages: "SendableMessage", topic: str, _publish_type: PublishType, - key: bytes | str | None = None, + key: bytes | Any | None = None, partition: int | None = None, timestamp_ms: int | None = None, headers: dict[str, str] | None = None, @@ -77,6 +97,12 @@ def __init__( # request option self.timeout = timeout + # per-message keys support + keys, normalized = extract_per_message_keys_and_bodies(self.batch_bodies) + if normalized is not None: + self.batch_bodies = normalized + self._per_message_keys = keys + @classmethod def from_cmd( cls, @@ -100,6 +126,9 @@ def from_cmd( _publish_type=cmd.publish_type, ) + def key_for(self, index: int) -> Any | None: + return key_for_index(self._per_message_keys, self.key, index) + def headers_to_publish(self) -> dict[str, str]: headers = {} @@ -110,3 +139,8 @@ def headers_to_publish(self) -> dict[str, str]: headers["reply_to"] = self.reply_to return headers | self.headers + + +# Semantic alias for publish operations +# More intuitive name when using in publish_batch() rather than as handler return value +KafkaPublishMessage = KafkaResponse diff --git a/faststream/confluent/testing.py b/faststream/confluent/testing.py index 005fc5df47..c1623ad1ed 100644 --- a/faststream/confluent/testing.py +++ b/faststream/confluent/testing.py @@ -154,12 +154,13 @@ async def publish_batch(self, cmd: "KafkaPublishCommand") -> None: topic=cmd.destination, partition=cmd.partition, timestamp_ms=cmd.timestamp_ms, + key=cmd.key_for(message_position), headers=cmd.headers, correlation_id=cmd.correlation_id, reply_to=cmd.reply_to, serializer=self.broker.config.fd_config._serializer, ) - for message in cmd.batch_bodies + for message_position, message in enumerate(cmd.batch_bodies) ) if isinstance(handler, BatchSubscriber): diff --git a/faststream/kafka/__init__.py b/faststream/kafka/__init__.py index 23a3501d4d..277257b139 100644 --- a/faststream/kafka/__init__.py +++ b/faststream/kafka/__init__.py @@ -6,7 +6,7 @@ from .annotations import KafkaMessage from .broker import KafkaBroker, KafkaPublisher, KafkaRoute, KafkaRouter - from .response import KafkaPublishCommand, KafkaResponse + from .response import KafkaPublishCommand, KafkaPublishMessage, KafkaResponse from .testing import TestKafkaBroker except ImportError as e: @@ -22,6 +22,7 @@ "KafkaBroker", "KafkaMessage", "KafkaPublishCommand", + "KafkaPublishMessage", "KafkaPublisher", "KafkaResponse", "KafkaRoute", diff --git a/faststream/kafka/publisher/factory.py b/faststream/kafka/publisher/factory.py index 4b02f5c91c..f8f310db28 100644 --- a/faststream/kafka/publisher/factory.py +++ b/faststream/kafka/publisher/factory.py @@ -5,8 +5,6 @@ Any, ) -from faststream.exceptions import SetupError - from .config import KafkaPublisherConfig, KafkaPublisherSpecificationConfig from .specification import KafkaPublisherSpecification from .usecase import BatchPublisher, DefaultPublisher @@ -56,10 +54,6 @@ def create_publisher( ) if batch: - if key: - msg = "You can't setup `key` with batch publisher" - raise SetupError(msg) - publisher: BatchPublisher | DefaultPublisher = BatchPublisher( publisher_config, specification, diff --git a/faststream/kafka/publisher/producer.py b/faststream/kafka/publisher/producer.py index 8b1b57317a..5799008c58 100644 --- a/faststream/kafka/publisher/producer.py +++ b/faststream/kafka/publisher/producer.py @@ -147,7 +147,7 @@ async def publish_batch( final_headers = headers_to_send.copy() metadata = batch.append( - key=None, + key=cmd.key_for(message_position), value=message, timestamp=cmd.timestamp_ms, headers=[(i, j.encode()) for i, j in final_headers.items()], diff --git a/faststream/kafka/publisher/usecase.py b/faststream/kafka/publisher/usecase.py index 47e1986b7d..c9ea7b5c8d 100644 --- a/faststream/kafka/publisher/usecase.py +++ b/faststream/kafka/publisher/usecase.py @@ -298,11 +298,20 @@ async def request( class BatchPublisher(LogicPublisher): + def __init__( + self, + config: "KafkaPublisherConfig", + specification: "PublisherSpecification[Any, Any]", + ) -> None: + super().__init__(config, specification) + self.key = config.key + @overload async def publish( self, *messages: "SendableMessage", topic: str = "", + key: bytes | Any | None = None, partition: int | None = None, timestamp_ms: int | None = None, headers: dict[str, str] | None = None, @@ -316,6 +325,7 @@ async def publish( self, *messages: "SendableMessage", topic: str = "", + key: bytes | Any | None = None, partition: int | None = None, timestamp_ms: int | None = None, headers: dict[str, str] | None = None, @@ -329,6 +339,7 @@ async def publish( self, *messages: "SendableMessage", topic: str = "", + key: bytes | Any | None = None, partition: int | None = None, timestamp_ms: int | None = None, headers: dict[str, str] | None = None, @@ -342,6 +353,7 @@ async def publish( self, *messages: "SendableMessage", topic: str = "", + key: bytes | Any | None = None, partition: int | None = None, timestamp_ms: int | None = None, headers: dict[str, str] | None = None, @@ -356,6 +368,13 @@ async def publish( Messages bodies to send. topic: Topic where the message will be published. + key: + A single key to associate with every message in this batch. If a + partition is not specified and the producer uses the default + partitioner, messages with the same key will be routed to the + same partition. Must be bytes or serializable to bytes via the + configured key serializer. If omitted, falls back to the + publisher's default key (if configured). partition: Specify a partition. If not set, the partition will be selected using the configured `partitioner` @@ -378,7 +397,7 @@ async def publish( """ cmd = KafkaPublishCommand( *messages, - key=None, + key=key or self.key, topic=topic or self.topic, partition=partition or self.partition, reply_to=reply_to or self.reply_to, @@ -410,6 +429,7 @@ async def _publish( cmd.reply_to = cmd.reply_to or self.reply_to cmd.partition = cmd.partition or self.partition + cmd.key = cmd.key or self.key await self._basic_publish_batch( cmd, diff --git a/faststream/kafka/response.py b/faststream/kafka/response.py index f8ea6db0d3..18da6b404e 100644 --- a/faststream/kafka/response.py +++ b/faststream/kafka/response.py @@ -3,13 +3,28 @@ from typing_extensions import override from faststream.response.publish_type import PublishType -from faststream.response.response import BatchPublishCommand, PublishCommand, Response +from faststream.response.response import ( + BatchPublishCommand, + PublishCommand, + Response, + extract_per_message_keys_and_bodies, + key_for_index, +) if TYPE_CHECKING: from faststream._internal.basic_types import SendableMessage class KafkaResponse(Response): + """Kafka-specific response object for outgoing messages. + + Can be used in two ways: + 1. As a return value from handler to send a response message + 2. Directly in publish_batch() to set per-message attributes (key, headers, etc.) + + For publish operations, consider using the more semantic alias `KafkaPublishMessage`. + """ + def __init__( self, body: "SendableMessage", @@ -28,6 +43,11 @@ def __init__( self.timestamp_ms = timestamp_ms self.key = key + @override + def get_publish_key(self) -> bytes | None: + """Return the Kafka message key for publishing.""" + return self.key + @override def as_publish_command(self) -> "KafkaPublishCommand": return KafkaPublishCommand( @@ -77,6 +97,12 @@ def __init__( # request option self.timeout = timeout + # per-message keys support + keys, normalized = extract_per_message_keys_and_bodies(self.batch_bodies) + if normalized is not None: + self.batch_bodies = normalized + self._per_message_keys = keys + @classmethod def from_cmd( cls, @@ -100,6 +126,9 @@ def from_cmd( _publish_type=cmd.publish_type, ) + def key_for(self, index: int) -> Any | None: + return key_for_index(self._per_message_keys, self.key, index) + def headers_to_publish(self) -> dict[str, str]: headers = {} @@ -110,3 +139,8 @@ def headers_to_publish(self) -> dict[str, str]: headers["reply_to"] = self.reply_to return headers | self.headers + + +# Semantic alias for publish operations +# More intuitive name when using in publish_batch() rather than as handler return value +KafkaPublishMessage = KafkaResponse diff --git a/faststream/kafka/testing.py b/faststream/kafka/testing.py index cdd88cb20c..d85fda9d57 100755 --- a/faststream/kafka/testing.py +++ b/faststream/kafka/testing.py @@ -202,12 +202,13 @@ async def publish_batch( topic=cmd.destination, partition=cmd.partition, timestamp_ms=cmd.timestamp_ms, + key=cmd.key_for(message_position), headers=cmd.headers, correlation_id=cmd.correlation_id, reply_to=cmd.reply_to, serializer=self.broker.config.fd_config._serializer, ) - for message in cmd.batch_bodies + for message_position, message in enumerate(cmd.batch_bodies) ) if isinstance(handler, BatchSubscriber): diff --git a/faststream/response/response.py b/faststream/response/response.py index 18a6ec7a88..0d69e3066d 100644 --- a/faststream/response/response.py +++ b/faststream/response/response.py @@ -1,4 +1,5 @@ from collections.abc import Sequence +from functools import singledispatch from typing import Any from typing_extensions import Self @@ -28,6 +29,17 @@ def as_publish_command(self) -> "PublishCommand": _publish_type=PublishType.PUBLISH, ) + def get_publish_key(self) -> Any | None: + """Get the key for publishing this message. + + Override this method in subclasses to provide broker-specific keys. + Default implementation returns None (no key). + + Returns: + The key for publishing, or None if this Response type doesn't use keys. + """ + return None + class PublishCommand(Response): def __init__( @@ -126,3 +138,82 @@ def _parse_bodies(body: Any, *, batch: bool = False) -> tuple[Any, tuple[Any, .. else: body = None return body, tuple(extra_bodies) + + +@singledispatch +def _extract_body_and_key(item: Any) -> tuple[Any, Any | None]: + """Extract body and key from a plain message. + + Default implementation for non-Response objects. + Returns the item as-is for body and None for key. + """ + return item, None + + +@_extract_body_and_key.register +def _(item: Response) -> tuple[Any, Any | None]: + """Extract body and key from a Response object. + + Uses polymorphic get_publish_key() method to retrieve the key. + """ + return item.body, item.get_publish_key() + + +def extract_per_message_keys_and_bodies( + batch_bodies: Sequence[Any], +) -> tuple[tuple[Any | None, ...], tuple[Any, ...] | None]: + """Extract per-message keys and optionally normalized bodies from a batch. + + Returns a pair (keys, normalized_bodies_or_None): + - If no Response objects are present, returns ((), None) + so callers can reuse the original bodies without extra allocations. + - Otherwise returns (keys_tuple, normalized_bodies_tuple), where normalized bodies + contain the extracted 'body' values from Response objects (or the original item). + + Supports passing Response objects (e.g., KafkaResponse) to set per-message keys: + await broker.publish_batch( + KafkaResponse("body1", key=b"key1"), + KafkaResponse("body2", key=b"key2"), + "plain message" # uses default key + ) + + Uses singledispatch for type-based polymorphism without isinstance checks. + """ + if not batch_bodies: + return (), None + + bodies: list[Any] = [] + keys: list[Any | None] = [] + has_key: bool = False + + for item in batch_bodies: + body, key = _extract_body_and_key(item) + bodies.append(body) + keys.append(key) + if key is not None: + has_key = True + + if not has_key: + return (), None + + return tuple(keys), tuple(bodies) + + +def key_for_index( + keys: Sequence[Any | None], default_key: Any | None, index: int +) -> Any | None: + """Return the effective key for a given message index. + + Prefers a per-message key at the given index when it is not None; + otherwise falls back to ``default_key``. If the index is out of bounds + or negative, ``default_key`` is returned. + """ + if index < 0: + return default_key + + try: + k = keys[index] + except IndexError: + return default_key + + return k if k is not None else default_key diff --git a/tests/brokers/confluent/test_publish.py b/tests/brokers/confluent/test_publish.py index 8b561d0325..27b2c02a2b 100644 --- a/tests/brokers/confluent/test_publish.py +++ b/tests/brokers/confluent/test_publish.py @@ -1,10 +1,11 @@ import asyncio +from typing import Any from unittest.mock import MagicMock import pytest from faststream import Context -from faststream.confluent import KafkaResponse +from faststream.confluent import KafkaPublishMessage, KafkaResponse from tests.brokers.base.publish import BrokerPublishTestcase from .basic import ConfluentTestcaseConfig @@ -140,3 +141,131 @@ async def handle_next(msg=Context("message")) -> None: assert event.is_set() mock.assert_called_once_with(body=b"1") + + @pytest.mark.asyncio() + async def test_batch_publisher_manual_with_key(self, queue: str) -> None: + pub_broker = self.get_broker(apply_types=True) + + keys_queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=2) + + args, kwargs = self.get_subscriber_params(queue) + + @pub_broker.subscriber(*args, **kwargs) + async def handler(msg=Context("message")) -> None: + await keys_queue.put(msg.raw_message.key()) + + publisher = pub_broker.publisher(queue, batch=True) + + async with self.patch_broker(pub_broker) as br: + await br.start() + + await publisher.publish(1, "hi", key=b"my_key") + + k1 = await asyncio.wait_for(keys_queue.get(), timeout=self.timeout) + k2 = await asyncio.wait_for(keys_queue.get(), timeout=self.timeout) + + assert k1 == b"my_key" + assert k2 == b"my_key" + + @pytest.mark.asyncio() + async def test_batch_publisher_default_key_from_factory(self, queue: str) -> None: + pub_broker = self.get_broker(apply_types=True) + + keys_queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=2) + + args, kwargs = self.get_subscriber_params(queue) + + @pub_broker.subscriber(*args, **kwargs) + async def handler(msg=Context("message")) -> None: + await keys_queue.put(msg.raw_message.key()) + + publisher = pub_broker.publisher(queue, batch=True, key=b"default_key") + + async with self.patch_broker(pub_broker) as br: + await br.start() + + await publisher.publish(1, "hi") + + k1 = await asyncio.wait_for(keys_queue.get(), timeout=self.timeout) + k2 = await asyncio.wait_for(keys_queue.get(), timeout=self.timeout) + + assert k1 == b"default_key" + assert k2 == b"default_key" + + @pytest.mark.asyncio() + async def test_batch_publisher_with_kafka_response_per_message_keys( + self, queue: str + ) -> None: + """Test using KafkaResponse to set different keys for each message.""" + pub_broker = self.get_broker(apply_types=True) + + messages_queue: asyncio.Queue[tuple[Any, bytes | None]] = asyncio.Queue(maxsize=3) + + args, kwargs = self.get_subscriber_params(queue) + + @pub_broker.subscriber(*args, **kwargs) + async def handler(msg) -> None: + await messages_queue.put((msg, msg.raw_message.key())) + + publisher = pub_broker.publisher(queue, batch=True) + + async with self.patch_broker(pub_broker) as br: + await br.start() + + # Publish batch with different keys per message + # Using KafkaPublishMessage alias (more semantic for publish operations) + await publisher.publish( + KafkaPublishMessage("message1", key=b"key1"), + KafkaPublishMessage("message2", key=b"key2"), + "message3", # No key, will use default (None) + ) + + msg1, k1 = await asyncio.wait_for(messages_queue.get(), timeout=self.timeout) + msg2, k2 = await asyncio.wait_for(messages_queue.get(), timeout=self.timeout) + msg3, k3 = await asyncio.wait_for(messages_queue.get(), timeout=self.timeout) + + assert msg1 == "message1" + assert k1 == b"key1" + assert msg2 == "message2" + assert k2 == b"key2" + assert msg3 == "message3" + assert k3 is None + + @pytest.mark.asyncio() + async def test_batch_publisher_kafka_response_with_default_key( + self, queue: str + ) -> None: + """Test KafkaResponse with publisher default key fallback.""" + pub_broker = self.get_broker(apply_types=True) + + messages_queue: asyncio.Queue[tuple[Any, bytes | None]] = asyncio.Queue(maxsize=3) + + args, kwargs = self.get_subscriber_params(queue) + + @pub_broker.subscriber(*args, **kwargs) + async def handler(msg) -> None: + await messages_queue.put((msg, msg.raw_message.key())) + + # Publisher with default key + publisher = pub_broker.publisher(queue, batch=True, key=b"default_key") + + async with self.patch_broker(pub_broker) as br: + await br.start() + + # Mix KafkaResponse with explicit keys and plain messages + await publisher.publish( + KafkaResponse("message1", key=b"explicit_key"), + "message2", # Uses default_key + KafkaResponse("message3", key=b"another_key"), + ) + + msg1, k1 = await asyncio.wait_for(messages_queue.get(), timeout=self.timeout) + msg2, k2 = await asyncio.wait_for(messages_queue.get(), timeout=self.timeout) + msg3, k3 = await asyncio.wait_for(messages_queue.get(), timeout=self.timeout) + + assert msg1 == "message1" + assert k1 == b"explicit_key" + assert msg2 == "message2" + assert k2 == b"default_key" + assert msg3 == "message3" + assert k3 == b"another_key" diff --git a/tests/brokers/kafka/test_publish.py b/tests/brokers/kafka/test_publish.py index 5cbdde6d69..3cda3435dc 100644 --- a/tests/brokers/kafka/test_publish.py +++ b/tests/brokers/kafka/test_publish.py @@ -1,11 +1,12 @@ import asyncio +from typing import Any from unittest.mock import MagicMock import pytest from aiokafka.structs import RecordMetadata from faststream import Context -from faststream.kafka import KafkaResponse +from faststream.kafka import KafkaPublishMessage, KafkaResponse from faststream.kafka.exceptions import BatchBufferOverflowException from tests.brokers.base.publish import BrokerPublishTestcase @@ -181,3 +182,123 @@ async def handler(m) -> None: with pytest.raises(BatchBufferOverflowException) as e: await br.publish_batch(1, "Hello, world!", topic=queue, no_confirm=True) assert e.value.message_position == 1 + + @pytest.mark.asyncio() + async def test_batch_publisher_manual_with_key(self, queue: str) -> None: + pub_broker = self.get_broker(apply_types=True) + + keys_queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=2) + + @pub_broker.subscriber(queue) + async def handler(msg=Context("message")) -> None: + await keys_queue.put(msg.raw_message.key) + + publisher = pub_broker.publisher(queue, batch=True) + + async with self.patch_broker(pub_broker) as br: + await br.start() + + await publisher.publish(1, "hi", key=b"my_key") + + k1 = await asyncio.wait_for(keys_queue.get(), timeout=3) + k2 = await asyncio.wait_for(keys_queue.get(), timeout=3) + + assert k1 == b"my_key" + assert k2 == b"my_key" + + @pytest.mark.asyncio() + async def test_batch_publisher_default_key_from_factory(self, queue: str) -> None: + pub_broker = self.get_broker(apply_types=True) + + keys_queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=2) + + @pub_broker.subscriber(queue) + async def handler(msg=Context("message")) -> None: + await keys_queue.put(msg.raw_message.key) + + publisher = pub_broker.publisher(queue, batch=True, key=b"default_key") + + async with self.patch_broker(pub_broker) as br: + await br.start() + + await publisher.publish(1, "hi") + + k1 = await asyncio.wait_for(keys_queue.get(), timeout=3) + k2 = await asyncio.wait_for(keys_queue.get(), timeout=3) + + assert k1 == b"default_key" + assert k2 == b"default_key" + + @pytest.mark.asyncio() + async def test_batch_publisher_with_kafka_response_per_message_keys( + self, queue: str + ) -> None: + """Test using KafkaResponse to set different keys for each message.""" + pub_broker = self.get_broker(apply_types=True) + + messages_queue: asyncio.Queue[tuple[Any, bytes | None]] = asyncio.Queue(maxsize=3) + + @pub_broker.subscriber(queue) + async def handler(msg) -> None: + await messages_queue.put((msg, msg.raw_message.key)) + + publisher = pub_broker.publisher(queue, batch=True) + + async with self.patch_broker(pub_broker) as br: + await br.start() + + # Publish batch with different keys per message + # Using KafkaPublishMessage alias (more semantic for publish operations) + await publisher.publish( + KafkaPublishMessage("message1", key=b"key1"), + KafkaPublishMessage("message2", key=b"key2"), + "message3", # No key, will use default (None) + ) + + msg1, k1 = await asyncio.wait_for(messages_queue.get(), timeout=3) + msg2, k2 = await asyncio.wait_for(messages_queue.get(), timeout=3) + msg3, k3 = await asyncio.wait_for(messages_queue.get(), timeout=3) + + assert msg1 == "message1" + assert k1 == b"key1" + assert msg2 == "message2" + assert k2 == b"key2" + assert msg3 == "message3" + assert k3 is None + + @pytest.mark.asyncio() + async def test_batch_publisher_kafka_response_with_default_key( + self, queue: str + ) -> None: + """Test KafkaResponse with publisher default key fallback.""" + pub_broker = self.get_broker(apply_types=True) + + messages_queue: asyncio.Queue[tuple[Any, bytes | None]] = asyncio.Queue(maxsize=3) + + @pub_broker.subscriber(queue) + async def handler(msg) -> None: + await messages_queue.put((msg, msg.raw_message.key)) + + # Publisher with default key + publisher = pub_broker.publisher(queue, batch=True, key=b"default_key") + + async with self.patch_broker(pub_broker) as br: + await br.start() + + # Mix KafkaResponse with explicit keys and plain messages + await publisher.publish( + KafkaResponse("message1", key=b"explicit_key"), + "message2", # Uses default_key + KafkaResponse("message3", key=b"another_key"), + ) + + msg1, k1 = await asyncio.wait_for(messages_queue.get(), timeout=3) + msg2, k2 = await asyncio.wait_for(messages_queue.get(), timeout=3) + msg3, k3 = await asyncio.wait_for(messages_queue.get(), timeout=3) + + assert msg1 == "message1" + assert k1 == b"explicit_key" + assert msg2 == "message2" + assert k2 == b"default_key" + assert msg3 == "message3" + assert k3 == b"another_key"