Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion faststream/_internal/context/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(self, initial: dict[str, Any] | None = None, /) -> None:
_global_context : a dictionary representing the global context
_scope_context : a dictionary representing the scope context
"""
self._global_context: dict[str, Any] = {"context": self} | (initial or {})
self._global_context: dict[str, Any] = (initial or {}) | {"context": self}
self._scope_context: dict[str, ContextVar[Any]] = {}

@property
Expand Down Expand Up @@ -164,5 +164,21 @@ def resolve(self, argument: str) -> Any:
return v

def clear(self) -> None:
"""Reset global and scope contexts.

Returns:
None
"""
self._global_context = {"context": self}
self._scope_context.clear()

def merge_global(self, other: "ContextRepo") -> None:
"""Merge the global context of another repository into the current one.

Args:
other: Another ContextRepo instance.

Returns:
None
"""
self._global_context |= other._global_context | {"context": self}
1 change: 1 addition & 0 deletions faststream/_internal/di/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def _serializer(self) -> Optional["SerializerProto"]:

def __or__(self, value: "FastDependsConfig", /) -> "FastDependsConfig":
use_fd = False if not value.use_fastdepends else self.use_fastdepends
self.context.merge_global(value.context)

return FastDependsConfig(
use_fastdepends=use_fd,
Expand Down
31 changes: 31 additions & 0 deletions tests/application/test_delayed_broker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pytest

from faststream._internal.application import StartAbleApplication
from faststream._internal.context import ContextRepo
from faststream._internal.di import FastDependsConfig
from faststream.exceptions import SetupError
from faststream.rabbit import RabbitBroker

Expand Down Expand Up @@ -45,3 +47,32 @@ async def test_di_reconfigured() -> None:
app.set_broker(broker)

assert broker.context.get("app") is app


def test_broker_and_app_contexts_merge() -> None:
broker = RabbitBroker(
context=ContextRepo({
"broker_dependency": 1,
"override_dependency": 2,
})
)

config = FastDependsConfig(
context=ContextRepo({
"application_dependency": 3,
"override_dependency": 4,
})
)
app = StartAbleApplication(config=config)
application_context = app.context

# if the broker binds to the application,
# the broker modifies the application context and uses it as its own.
app.set_broker(broker)

assert app.context is application_context
assert broker.context is application_context
assert app.context.get("broker_dependency") == 1
assert app.context.get("application_dependency") == 3
# the broker context overwrites the application context
assert app.context.get("override_dependency") == 2
37 changes: 36 additions & 1 deletion tests/utils/context/test_main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,45 @@
from typing import Any

import pytest
from fast_depends import ValidationError

from faststream import Context, ContextRepo
from faststream import Context
from faststream._internal.context import ContextRepo
from faststream._internal.utils import apply_types


@pytest.mark.parametrize(
("initial", "expected_context"),
(
pytest.param(None, {}, id="without initial"),
pytest.param({"value": 42}, {"value": 42}, id="basic value"),
pytest.param({"context": "sus"}, {}, id="sus context"),
),
)
def test_context_repo_constructor(
initial: dict[str, Any] | None,
expected_context: dict[str, Any],
) -> None:
repo = ContextRepo(initial)
repo_context = repo.context

assert repo_context.get("context") is repo
repo_context.pop("context")
assert repo_context == expected_context


def test_context_repo_merge_global():
repo_1 = ContextRepo({"value_1": 1, "value_2": 2})
repo_2 = ContextRepo({"value_2": 3, "value_3": 4})

repo_1.merge_global(repo_2)

assert repo_1.get("value_1") == 1
assert repo_1.get("value_2") == 3
assert repo_1.get("value_3") == 4
assert repo_1.get("context") is repo_1


def test_context_getattr(context: ContextRepo) -> None:
a = 1000
context.set_global("key", a)
Expand Down