diff --git a/faststream/_internal/broker/registrator.py b/faststream/_internal/broker/registrator.py index 9fcace078e..b32ca15cc3 100644 --- a/faststream/_internal/broker/registrator.py +++ b/faststream/_internal/broker/registrator.py @@ -37,6 +37,8 @@ def __init__( self.__persistent_subscribers: list[SubscriberUsecase[MsgType]] = [] self.__persistent_publishers: list[PublisherUsecase] = [] + self.__parent: Registrator[MsgType, Any] | None = None + self.include_routers(*routers) @property @@ -93,6 +95,10 @@ def include_router( include_in_schema: bool | None = None, ) -> None: """Includes a router in the current object.""" + if router.parent is self: + return + router.parent = self + if options_config := BrokerConfig( prefix=prefix, include_in_schema=include_in_schema, @@ -104,6 +110,17 @@ def include_router( router.config.add_config(self.config) self.routers.append(router) + @property + def parent(self) -> "Registrator[MsgType, Any] | None": + return self.__parent + + @parent.setter + def parent(self, parent: "Registrator[MsgType, Any]") -> None: + if self.__parent is not None and parent is not self.__parent: + self.__parent.routers.remove(self) + self.config.reset() + self.__parent = parent + def include_routers( self, *routers: "Registrator[MsgType, Any]", diff --git a/faststream/_internal/configs/broker.py b/faststream/_internal/configs/broker.py index 650fd88822..c3f1e9b705 100644 --- a/faststream/_internal/configs/broker.py +++ b/faststream/_internal/configs/broker.py @@ -74,6 +74,9 @@ def __repr__(self) -> str: def add_config(self, config: "ConfigType") -> None: self.configs = (config, *self.configs) + def reset(self) -> None: + self.configs = (self.configs[-1],) + # broker priority options @property def producer(self) -> "ProducerProto[Any]": diff --git a/faststream/_internal/endpoint/subscriber/call_item.py b/faststream/_internal/endpoint/subscriber/call_item.py index 7deb6373c8..1720c959ef 100644 --- a/faststream/_internal/endpoint/subscriber/call_item.py +++ b/faststream/_internal/endpoint/subscriber/call_item.py @@ -81,11 +81,11 @@ def _setup( self.item_parser = parser self.item_decoder = decoder - self.dependant = self.handler.set_wrapped( - dependencies=(*broker_dependencies, *self.dependencies), - _call_decorators=_call_decorators, - config=config, - ) + self.dependant = self.handler.set_wrapped( + dependencies=(*broker_dependencies, *self.dependencies), + _call_decorators=_call_decorators, + config=config, + ) @property def name(self) -> str: diff --git a/tests/brokers/base/include_router.py b/tests/brokers/base/include_router.py index e1108d18dd..cf68cef620 100644 --- a/tests/brokers/base/include_router.py +++ b/tests/brokers/base/include_router.py @@ -145,6 +145,27 @@ def test_complex_router_prefix(self) -> None: assert sub2._outer_config.prefix == "1." assert sub3._outer_config.prefix == "1.4.5." + def test_idempotent_include_twice_on_same_broker(self) -> None: + router = self.get_router() + broker = self.get_broker() + + broker.include_router(router) + broker.include_router(router) + assert len(router.config.configs) == 2 + assert router.parent is broker + + def test_reregister_on_include_in_different_brokers(self) -> None: + router = self.get_router() + broker1 = self.get_broker() + broker2 = self.get_broker() + + broker1.include_router(router) + broker2.include_router(router) + + assert len(router.config.configs) == 2 + assert router.parent is broker2 + assert router not in broker1.routers + class IncludePublisherTestcase(IncludeTestcase): def get_object(self, router: BrokerRouter[Any] | BrokerUsecase[Any, Any]) -> Any: diff --git a/tests/brokers/base/router.py b/tests/brokers/base/router.py index 7eece76306..920007e90e 100644 --- a/tests/brokers/base/router.py +++ b/tests/brokers/base/router.py @@ -4,6 +4,7 @@ import pytest +from faststream import Context from faststream._internal.broker.router import ( ArgsContainer, BrokerRouter, @@ -633,3 +634,20 @@ async def m(m) -> None: await br.start() await br.publish("hello", queue) publisher.mock.assert_called_with("response") + + async def test_func_wrapped_correctly_on_include_in_different_broker(self) -> None: + router = self.get_router() + broker1 = self.get_broker() + broker2 = self.get_broker() + + @router.subscriber("in-queue") + async def handle_msg(broker=Context()) -> str: + return "test" + + broker1.include_router(router) + async with self.patch_broker(broker1) as br: + await br.publish({}, "in-queue") + + broker2.include_router(router) + async with self.patch_broker(broker2) as br: + await br.publish({}, "in-queue") diff --git a/tests/brokers/redis/test_router.py b/tests/brokers/redis/test_router.py index 74fca492e5..970a5991a9 100644 --- a/tests/brokers/redis/test_router.py +++ b/tests/brokers/redis/test_router.py @@ -148,7 +148,6 @@ def response(m) -> None: ), timeout=3, ) - assert event.is_set() async def test_delayed_stream_handlers(