@@ -600,6 +600,180 @@ async def test_subscribe_async_generator(self):
600600
601601 await nc .close ()
602602
603+ @async_test
604+ async def test_subscribe_concurrent_async_generators (self ):
605+ """Test multiple concurrent async generators on the same subscription"""
606+ nc = NATS ()
607+ await nc .connect ()
608+
609+ sub = await nc .subscribe ("test.concurrent" )
610+
611+ # Publish messages
612+ num_msgs = 12
613+ for i in range (num_msgs ):
614+ await nc .publish ("test.concurrent" , f"msg-{ i } " .encode ())
615+ await nc .flush ()
616+
617+ # Track results from each consumer
618+ consumer_results = {}
619+
620+ async def consumer_task (consumer_id : str , max_messages : int = None ):
621+ """Consumer task that processes messages"""
622+ import random
623+
624+ received = []
625+ try :
626+ async for msg in sub .messages :
627+ received .append (msg .data .decode ())
628+ # Add random processing delay to simulate real work.
629+ await asyncio .sleep (random .uniform (0.01 , 0.05 ))
630+ if max_messages and len (received ) >= max_messages :
631+ break
632+ except Exception as e :
633+ # Store the exception for later inspection
634+ consumer_results [consumer_id ] = f"Error: { e } "
635+ return
636+ consumer_results [consumer_id ] = received
637+
638+ # Start multiple concurrent consumers.
639+ tasks = [
640+ asyncio .create_task (consumer_task ("consumer_A" , 3 )),
641+ asyncio .create_task (consumer_task ("consumer_B" , 5 )),
642+ asyncio .create_task (consumer_task ("consumer_C" , 4 )),
643+ ]
644+
645+ # Wait for all consumers to finish.
646+ await asyncio .gather (* tasks )
647+
648+ # Verify results
649+ consumer_A_msgs = consumer_results .get ("consumer_A" , [])
650+ consumer_B_msgs = consumer_results .get ("consumer_B" , [])
651+ consumer_C_msgs = consumer_results .get ("consumer_C" , [])
652+
653+ # Each consumer should get the expected number of messages
654+ self .assertEqual (len (consumer_A_msgs ), 3 )
655+ self .assertEqual (len (consumer_B_msgs ), 5 )
656+ self .assertEqual (len (consumer_C_msgs ), 4 )
657+
658+ # All messages should be unique (no duplicates across consumers)
659+ all_received = consumer_A_msgs + consumer_B_msgs + consumer_C_msgs
660+ self .assertEqual (len (all_received ), len (set (all_received )))
661+
662+ # All received messages should be from our published set
663+ expected_msgs = {f"msg-{ i } " for i in range (num_msgs )}
664+ received_msgs = set (all_received )
665+ self .assertTrue (received_msgs .issubset (expected_msgs ))
666+
667+ # Verify we got exactly 12 unique messages total
668+ self .assertEqual (len (received_msgs ), 12 )
669+
670+ await nc .close ()
671+
672+ @async_test
673+ async def test_subscribe_async_generator_with_unsubscribe_limit (self ):
674+ """Test async generator respects unsubscribe max_msgs limit automatically"""
675+ nc = NATS ()
676+ await nc .connect ()
677+
678+ sub = await nc .subscribe ("test.unsub.limit" )
679+ await sub .unsubscribe (limit = 5 )
680+
681+ # Publish more messages than the limit
682+ num_msgs = 10
683+ for i in range (num_msgs ):
684+ await nc .publish ("test.unsub.limit" , f"msg-{ i } " .encode ())
685+ await nc .flush ()
686+
687+ received_msgs = []
688+ async for msg in sub .messages :
689+ received_msgs .append (msg .data .decode ())
690+ # Add small delay to ensure we don't race with the unsubscribe.
691+ await asyncio .sleep (0.01 )
692+
693+ # Should have received exactly 5 messages due to unsubscribe limit.
694+ self .assertEqual (len (received_msgs ), 5 , f"Expected 5 messages, got { len (received_msgs )} : { received_msgs } " )
695+
696+ # Messages should be the first 5 published.
697+ for i in range (5 ):
698+ self .assertIn (f"msg-{ i } " , received_msgs )
699+
700+ # Verify the subscription received the expected number.
701+ self .assertEqual (sub ._received , 5 )
702+
703+ # The generator should have stopped due to max_msgs limit being reached.
704+ self .assertEqual (sub ._max_msgs , 5 )
705+
706+ await nc .close ()
707+
708+ @async_test
709+ async def test_subscribe_concurrent_async_generators_auto_unsubscribe (self ):
710+ """Test multiple concurrent async generators on the same subscription"""
711+ nc = NATS ()
712+ await nc .connect ()
713+
714+ sub = await nc .subscribe ("test.concurrent" )
715+ await sub .unsubscribe (5 )
716+
717+ # Publish messages over the max msgs limit.
718+ num_msgs = 12
719+ for i in range (num_msgs ):
720+ await nc .publish ("test.concurrent" , f"msg-{ i } " .encode ())
721+ await nc .flush ()
722+
723+ # Track results from each consumer
724+ consumer_results = {}
725+
726+ async def consumer_task (consumer_id : str , max_messages : int = None ):
727+ """Consumer task that processes messages"""
728+ import random
729+
730+ received = []
731+ try :
732+ async for msg in sub .messages :
733+ received .append (msg .data .decode ())
734+ # Add random processing delay to simulate real work
735+ await asyncio .sleep (random .uniform (0.01 , 0.05 ))
736+ if max_messages and len (received ) >= max_messages :
737+ break
738+
739+ # Once subscription reached max number of messages, it should unblock.
740+ except Exception as e :
741+ # Store the exception for later inspection
742+ consumer_results [consumer_id ] = f"Error: { e } "
743+ return
744+ consumer_results [consumer_id ] = received
745+
746+ # Start multiple concurrent consumers.
747+ tasks = [
748+ asyncio .create_task (consumer_task ("consumer_A" , 3 )),
749+ asyncio .create_task (consumer_task ("consumer_B" , 5 )),
750+ asyncio .create_task (consumer_task ("consumer_C" , 4 )),
751+ ]
752+
753+ # Wait for all consumers to finish.
754+ await asyncio .gather (* tasks )
755+
756+ # Verify results
757+ consumer_A_msgs = consumer_results .get ("consumer_A" , [])
758+ consumer_B_msgs = consumer_results .get ("consumer_B" , [])
759+ consumer_C_msgs = consumer_results .get ("consumer_C" , [])
760+
761+ # Each consumer should get the expected number of messages.
762+ total = len (consumer_A_msgs ) + len (consumer_B_msgs ) + len (consumer_C_msgs )
763+ self .assertEqual (total , 5 )
764+
765+ # All messages should be unique (no duplicates across consumers)
766+ all_received = consumer_A_msgs + consumer_B_msgs + consumer_C_msgs
767+ self .assertEqual (len (all_received ), len (set (all_received )))
768+
769+ # All received messages should be from our published set.
770+ expected_msgs = {f"msg-{ i } " for i in range (num_msgs )}
771+ received_msgs = set (all_received )
772+ self .assertTrue (received_msgs .issubset (expected_msgs ))
773+ self .assertEqual (len (received_msgs ), 5 )
774+
775+ await nc .close ()
776+
603777 @async_test
604778 async def test_subscribe_async_generator_with_drain (self ):
605779 """Test async generator with drain functionality"""
0 commit comments