Skip to content

Commit 54642f2

Browse files
committed
Handle cancellation of concurrent async generators
Signed-off-by: Waldemar Quevedo <[email protected]>
1 parent 2dc2031 commit 54642f2

File tree

3 files changed

+193
-1
lines changed

3 files changed

+193
-1
lines changed

nats/src/nats/aio/client.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1697,7 +1697,8 @@ async def _process_msg(
16971697
return
16981698

16991699
sub._received += 1
1700-
if sub._max_msgs > 0 and sub._received >= sub._max_msgs:
1700+
max_msgs_reached = sub._max_msgs > 0 and sub._received >= sub._max_msgs
1701+
if max_msgs_reached:
17011702
# Enough messages so can throwaway subscription now, the
17021703
# pending messages will still be in the subscription
17031704
# internal queue and the task will finish once the last
@@ -1800,6 +1801,16 @@ async def _process_msg(
18001801
if sub._jsi:
18011802
await sub._jsi.check_for_sequence_mismatch(msg)
18021803

1804+
# Send sentinel after reaching max messages for non-callback subscriptions.
1805+
if max_msgs_reached and not sub._cb and sub._active_generators > 0:
1806+
# Send one sentinel per active generator to unblock them all.
1807+
for _ in range(sub._active_generators):
1808+
try:
1809+
sub._pending_queue.put_nowait(None)
1810+
except Exception:
1811+
# Queue might be full or closed, that's ok
1812+
break
1813+
18031814
def _build_message(
18041815
self,
18051816
sid: int,

nats/src/nats/aio/subscription.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,13 @@ def delivered(self) -> int:
196196
"""
197197
return self._received
198198

199+
@property
200+
def is_closed(self) -> bool:
201+
"""
202+
Returns True if the subscription is closed, False otherwise.
203+
"""
204+
return self._closed
205+
199206
async def next_msg(self, timeout: Optional[float] = 1.0) -> Msg:
200207
"""
201208
:params timeout: Time in seconds to wait for next message before timing out.

nats/tests/test_client.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,180 @@ async def test_subscribe_async_generator(self):
600600

601601
await nc.close()
602602

603+
@async_test
604+
async def test_subscribe_concurrent_async_generators(self):
605+
"""Test multiple concurrent async generators on the same subscription"""
606+
nc = NATS()
607+
await nc.connect()
608+
609+
sub = await nc.subscribe("test.concurrent")
610+
611+
# Publish messages
612+
num_msgs = 12
613+
for i in range(num_msgs):
614+
await nc.publish("test.concurrent", f"msg-{i}".encode())
615+
await nc.flush()
616+
617+
# Track results from each consumer
618+
consumer_results = {}
619+
620+
async def consumer_task(consumer_id: str, max_messages: int = None):
621+
"""Consumer task that processes messages"""
622+
import random
623+
624+
received = []
625+
try:
626+
async for msg in sub.messages:
627+
received.append(msg.data.decode())
628+
# Add random processing delay to simulate real work.
629+
await asyncio.sleep(random.uniform(0.01, 0.05))
630+
if max_messages and len(received) >= max_messages:
631+
break
632+
except Exception as e:
633+
# Store the exception for later inspection
634+
consumer_results[consumer_id] = f"Error: {e}"
635+
return
636+
consumer_results[consumer_id] = received
637+
638+
# Start multiple concurrent consumers.
639+
tasks = [
640+
asyncio.create_task(consumer_task("consumer_A", 3)),
641+
asyncio.create_task(consumer_task("consumer_B", 5)),
642+
asyncio.create_task(consumer_task("consumer_C", 4)),
643+
]
644+
645+
# Wait for all consumers to finish.
646+
await asyncio.gather(*tasks)
647+
648+
# Verify results
649+
consumer_A_msgs = consumer_results.get("consumer_A", [])
650+
consumer_B_msgs = consumer_results.get("consumer_B", [])
651+
consumer_C_msgs = consumer_results.get("consumer_C", [])
652+
653+
# Each consumer should get the expected number of messages
654+
self.assertEqual(len(consumer_A_msgs), 3)
655+
self.assertEqual(len(consumer_B_msgs), 5)
656+
self.assertEqual(len(consumer_C_msgs), 4)
657+
658+
# All messages should be unique (no duplicates across consumers)
659+
all_received = consumer_A_msgs + consumer_B_msgs + consumer_C_msgs
660+
self.assertEqual(len(all_received), len(set(all_received)))
661+
662+
# All received messages should be from our published set
663+
expected_msgs = {f"msg-{i}" for i in range(num_msgs)}
664+
received_msgs = set(all_received)
665+
self.assertTrue(received_msgs.issubset(expected_msgs))
666+
667+
# Verify we got exactly 12 unique messages total
668+
self.assertEqual(len(received_msgs), 12)
669+
670+
await nc.close()
671+
672+
@async_test
673+
async def test_subscribe_async_generator_with_unsubscribe_limit(self):
674+
"""Test async generator respects unsubscribe max_msgs limit automatically"""
675+
nc = NATS()
676+
await nc.connect()
677+
678+
sub = await nc.subscribe("test.unsub.limit")
679+
await sub.unsubscribe(limit=5)
680+
681+
# Publish more messages than the limit
682+
num_msgs = 10
683+
for i in range(num_msgs):
684+
await nc.publish("test.unsub.limit", f"msg-{i}".encode())
685+
await nc.flush()
686+
687+
received_msgs = []
688+
async for msg in sub.messages:
689+
received_msgs.append(msg.data.decode())
690+
# Add small delay to ensure we don't race with the unsubscribe.
691+
await asyncio.sleep(0.01)
692+
693+
# Should have received exactly 5 messages due to unsubscribe limit.
694+
self.assertEqual(len(received_msgs), 5, f"Expected 5 messages, got {len(received_msgs)}: {received_msgs}")
695+
696+
# Messages should be the first 5 published.
697+
for i in range(5):
698+
self.assertIn(f"msg-{i}", received_msgs)
699+
700+
# Verify the subscription received the expected number.
701+
self.assertEqual(sub._received, 5)
702+
703+
# The generator should have stopped due to max_msgs limit being reached.
704+
self.assertEqual(sub._max_msgs, 5)
705+
706+
await nc.close()
707+
708+
@async_test
709+
async def test_subscribe_concurrent_async_generators_auto_unsubscribe(self):
710+
"""Test multiple concurrent async generators on the same subscription"""
711+
nc = NATS()
712+
await nc.connect()
713+
714+
sub = await nc.subscribe("test.concurrent")
715+
await sub.unsubscribe(5)
716+
717+
# Publish messages over the max msgs limit.
718+
num_msgs = 12
719+
for i in range(num_msgs):
720+
await nc.publish("test.concurrent", f"msg-{i}".encode())
721+
await nc.flush()
722+
723+
# Track results from each consumer
724+
consumer_results = {}
725+
726+
async def consumer_task(consumer_id: str, max_messages: int = None):
727+
"""Consumer task that processes messages"""
728+
import random
729+
730+
received = []
731+
try:
732+
async for msg in sub.messages:
733+
received.append(msg.data.decode())
734+
# Add random processing delay to simulate real work
735+
await asyncio.sleep(random.uniform(0.01, 0.05))
736+
if max_messages and len(received) >= max_messages:
737+
break
738+
739+
# Once subscription reached max number of messages, it should unblock.
740+
except Exception as e:
741+
# Store the exception for later inspection
742+
consumer_results[consumer_id] = f"Error: {e}"
743+
return
744+
consumer_results[consumer_id] = received
745+
746+
# Start multiple concurrent consumers.
747+
tasks = [
748+
asyncio.create_task(consumer_task("consumer_A", 3)),
749+
asyncio.create_task(consumer_task("consumer_B", 5)),
750+
asyncio.create_task(consumer_task("consumer_C", 4)),
751+
]
752+
753+
# Wait for all consumers to finish.
754+
await asyncio.gather(*tasks)
755+
756+
# Verify results
757+
consumer_A_msgs = consumer_results.get("consumer_A", [])
758+
consumer_B_msgs = consumer_results.get("consumer_B", [])
759+
consumer_C_msgs = consumer_results.get("consumer_C", [])
760+
761+
# Each consumer should get the expected number of messages.
762+
total = len(consumer_A_msgs) + len(consumer_B_msgs) + len(consumer_C_msgs)
763+
self.assertEqual(total, 5)
764+
765+
# All messages should be unique (no duplicates across consumers)
766+
all_received = consumer_A_msgs + consumer_B_msgs + consumer_C_msgs
767+
self.assertEqual(len(all_received), len(set(all_received)))
768+
769+
# All received messages should be from our published set.
770+
expected_msgs = {f"msg-{i}" for i in range(num_msgs)}
771+
received_msgs = set(all_received)
772+
self.assertTrue(received_msgs.issubset(expected_msgs))
773+
self.assertEqual(len(received_msgs), 5)
774+
775+
await nc.close()
776+
603777
@async_test
604778
async def test_subscribe_async_generator_with_drain(self):
605779
"""Test async generator with drain functionality"""

0 commit comments

Comments
 (0)