diff --git a/faststream/_internal/context/repository.py b/faststream/_internal/context/repository.py index 8a3a57516f..c0418bdeca 100644 --- a/faststream/_internal/context/repository.py +++ b/faststream/_internal/context/repository.py @@ -17,7 +17,7 @@ def __init__(self, initial: dict[str, Any] | None = None, /) -> None: _global_context : a dictionary representing the global context _scope_context : a dictionary representing the scope context """ - self._global_context: dict[str, Any] = {"context": self} | (initial or {}) + self._global_context: dict[str, Any] = (initial or {}) | {"context": self} self._scope_context: dict[str, ContextVar[Any]] = {} @property @@ -164,5 +164,21 @@ def resolve(self, argument: str) -> Any: return v def clear(self) -> None: + """Reset global and scope contexts. + + Returns: + None + """ self._global_context = {"context": self} self._scope_context.clear() + + def merge_global(self, other: "ContextRepo") -> None: + """Merge the global context of another repository into the current one. + + Args: + other: Another ContextRepo instance. + + Returns: + None + """ + self._global_context |= other._global_context | {"context": self} diff --git a/faststream/_internal/di/config.py b/faststream/_internal/di/config.py index 6a841fc023..1e7f49cd6c 100644 --- a/faststream/_internal/di/config.py +++ b/faststream/_internal/di/config.py @@ -50,6 +50,7 @@ def _serializer(self) -> Optional["SerializerProto"]: def __or__(self, value: "FastDependsConfig", /) -> "FastDependsConfig": use_fd = False if not value.use_fastdepends else self.use_fastdepends + self.context.merge_global(value.context) return FastDependsConfig( use_fastdepends=use_fd, diff --git a/tests/application/test_delayed_broker.py b/tests/application/test_delayed_broker.py index 2f5596f317..b507368324 100644 --- a/tests/application/test_delayed_broker.py +++ b/tests/application/test_delayed_broker.py @@ -1,6 +1,8 @@ import pytest from faststream._internal.application import StartAbleApplication +from faststream._internal.context import ContextRepo +from faststream._internal.di import FastDependsConfig from faststream.exceptions import SetupError from faststream.rabbit import RabbitBroker @@ -45,3 +47,32 @@ async def test_di_reconfigured() -> None: app.set_broker(broker) assert broker.context.get("app") is app + + +def test_broker_and_app_contexts_merge() -> None: + broker = RabbitBroker( + context=ContextRepo({ + "broker_dependency": 1, + "override_dependency": 2, + }) + ) + + config = FastDependsConfig( + context=ContextRepo({ + "application_dependency": 3, + "override_dependency": 4, + }) + ) + app = StartAbleApplication(config=config) + application_context = app.context + + # if the broker binds to the application, + # the broker modifies the application context and uses it as its own. + app.set_broker(broker) + + assert app.context is application_context + assert broker.context is application_context + assert app.context.get("broker_dependency") == 1 + assert app.context.get("application_dependency") == 3 + # the broker context overwrites the application context + assert app.context.get("override_dependency") == 2 diff --git a/tests/utils/context/test_main.py b/tests/utils/context/test_main.py index 6f0f783738..206d17241b 100644 --- a/tests/utils/context/test_main.py +++ b/tests/utils/context/test_main.py @@ -1,10 +1,45 @@ +from typing import Any + import pytest from fast_depends import ValidationError -from faststream import Context, ContextRepo +from faststream import Context +from faststream._internal.context import ContextRepo from faststream._internal.utils import apply_types +@pytest.mark.parametrize( + ("initial", "expected_context"), + ( + pytest.param(None, {}, id="without initial"), + pytest.param({"value": 42}, {"value": 42}, id="basic value"), + pytest.param({"context": "sus"}, {}, id="sus context"), + ), +) +def test_context_repo_constructor( + initial: dict[str, Any] | None, + expected_context: dict[str, Any], +) -> None: + repo = ContextRepo(initial) + repo_context = repo.context + + assert repo_context.get("context") is repo + repo_context.pop("context") + assert repo_context == expected_context + + +def test_context_repo_merge_global(): + repo_1 = ContextRepo({"value_1": 1, "value_2": 2}) + repo_2 = ContextRepo({"value_2": 3, "value_3": 4}) + + repo_1.merge_global(repo_2) + + assert repo_1.get("value_1") == 1 + assert repo_1.get("value_2") == 3 + assert repo_1.get("value_3") == 4 + assert repo_1.get("context") is repo_1 + + def test_context_getattr(context: ContextRepo) -> None: a = 1000 context.set_global("key", a)