Skip to content

Commit 61c13fa

Browse files
committed
Add mechanism to override providers
1 parent 612ecd6 commit 61c13fa

File tree

2 files changed

+157
-6
lines changed

2 files changed

+157
-6
lines changed

pif/providers.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,62 @@
88
#
99
# https://github.com/scottzach1/python-injector-framework
1010

11+
from __future__ import annotations
12+
1113
import abc
1214
import functools
13-
from typing import Callable
15+
from typing import Callable, Self
1416

1517

1618
class Provider[T](abc.ABC):
1719
"""
1820
Signposts something that can be injected.
1921
"""
2022

21-
@abc.abstractmethod
23+
_override: Provider | None = None
24+
2225
def __call__(self, *args, **kwargs) -> T:
2326
"""
24-
Evaluate the value to provide.
27+
Evaluate the provider, will select override if present.
28+
"""
29+
if self._override:
30+
return self._override()
31+
32+
return self._evaluate()
33+
34+
@abc.abstractmethod
35+
def _evaluate(self) -> T:
36+
"""
37+
Define the behavior to evaluate the provided value.
38+
"""
39+
return self()
40+
41+
def override[U: Provider | None](self, provider: U) -> Override[U]:
2542
"""
43+
Override the current providers value with another provider.
44+
"""
45+
return Override(self, provider)
46+
47+
48+
class Override[ProviderT: Provider]:
49+
"""
50+
A context manager to implement overrides for providers.
51+
"""
52+
53+
__slots__ = ("_base", "_override", "_before")
54+
55+
def __init__(self, base: Provider, override: ProviderT | None = None):
56+
# noinspection PyProtectedMember
57+
self._before = base._override
58+
self._base = base
59+
self._override = override
60+
base._override = override
61+
62+
def __enter__(self) -> Self:
63+
yield self
64+
65+
def __exit__(self, exc_type, exc_val, exc_tb):
66+
self._base._override = self._before
2667

2768

2869
class ExistingSingleton[T](Provider):
@@ -35,7 +76,7 @@ class ExistingSingleton[T](Provider):
3576
def __init__(self, t: T):
3677
self.t = t
3778

38-
def __call__(self) -> T:
79+
def _evaluate(self) -> T:
3980
return self.t
4081

4182

@@ -53,7 +94,7 @@ def __init__(self, func: Callable[[...], T], *args, **kwargs):
5394
self._func = functools.partial(func, *args, **kwargs)
5495
self._result = UNSET
5596

56-
def __call__(self) -> T:
97+
def _evaluate(self) -> T:
5798
if self._result is UNSET:
5899
self._result = self._func()
59100
return self._result
@@ -69,5 +110,5 @@ class Factory[T](Provider):
69110
def __init__(self, func: Callable[[...], T], *args, **kwargs):
70111
self._func = functools.partial(func, *args, **kwargs)
71112

72-
def __call__(self) -> T:
113+
def _evaluate(self) -> T:
73114
return self._func()

tests/test_providers.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
from pif import providers
2+
3+
4+
def test_override_standard_shallow():
5+
"""
6+
Testing basic override logic for providers.
7+
"""
8+
provide_a = providers.Factory[str](lambda: "a")
9+
provide_b = providers.Factory[str](lambda: "b")
10+
assert provide_a() == "a"
11+
assert provide_b() == "b"
12+
13+
provide_a.override(provide_b)
14+
assert provide_a() == "b"
15+
assert provide_b() == "b"
16+
17+
provide_a.override(None)
18+
assert provide_a() == "a"
19+
assert provide_b() == "b"
20+
21+
provide_b.override(provide_a)
22+
assert provide_a() == "a"
23+
assert provide_b() == "a"
24+
25+
provide_b.override(None)
26+
assert provide_a() == "a"
27+
assert provide_b() == "b"
28+
29+
30+
def test_override_contextmanager_shallow():
31+
"""
32+
Testing basic override logic for providers with contextmanager.
33+
"""
34+
provide_a = providers.Factory[str](lambda: "a")
35+
provide_b = providers.Factory[str](lambda: "b")
36+
37+
assert provide_a() == "a"
38+
assert provide_b() == "b"
39+
40+
with provide_a.override(provide_b):
41+
assert provide_a() == "b"
42+
assert provide_b() == "b"
43+
44+
assert provide_a() == "a"
45+
assert provide_b() == "b"
46+
47+
with provide_b.override(provide_a):
48+
assert provide_a() == "a"
49+
assert provide_b() == "a"
50+
51+
52+
def test_override_standard_nested():
53+
"""
54+
Testing nested override logic for providers.
55+
"""
56+
provide_a = providers.Factory[str](lambda: "a")
57+
provide_b = providers.Factory[str](lambda: "b")
58+
provide_c = providers.Factory[str](lambda: "c")
59+
60+
assert provide_a() == "a"
61+
assert provide_b() == "b"
62+
assert provide_c() == "c"
63+
64+
provide_a.override(provide_b)
65+
provide_b.override(provide_c)
66+
67+
assert provide_a() == "c"
68+
assert provide_b() == "c"
69+
assert provide_c() == "c"
70+
71+
provide_b.override(None)
72+
assert provide_a() == "b"
73+
assert provide_b() == "b"
74+
assert provide_c() == "c"
75+
76+
provide_a.override(None)
77+
assert provide_a() == "a"
78+
assert provide_b() == "b"
79+
assert provide_c() == "c"
80+
81+
82+
def test_override_contextmanager_nested():
83+
"""
84+
Testing nested override logic for providers with contextmanager.
85+
"""
86+
provide_a = providers.Factory[str](lambda: "a")
87+
provide_b = providers.Factory[str](lambda: "b")
88+
provide_c = providers.Factory[str](lambda: "c")
89+
90+
assert provide_a() == "a"
91+
assert provide_b() == "b"
92+
assert provide_c() == "c"
93+
94+
with provide_a.override(provide_b):
95+
assert provide_a() == "b"
96+
assert provide_b() == "b"
97+
assert provide_c() == "c"
98+
99+
with provide_b.override(provide_c):
100+
assert provide_a() == "c"
101+
assert provide_b() == "c"
102+
assert provide_c() == "c"
103+
104+
assert provide_a() == "b"
105+
assert provide_b() == "b"
106+
assert provide_c() == "c"
107+
108+
assert provide_a() == "a"
109+
assert provide_b() == "b"
110+
assert provide_c() == "c"

0 commit comments

Comments
 (0)