Skip to content

Commit fc02479

Browse files
authored
Improvements to cached_result (#46)
* Improvements to cached_result - Omit 'self' argument from cache key generation - Raise warnings for potentially unsafe cache key arguments - Raise error when used on anything else than a coroutine - Raise error when no arguments are available for cache key generation * Bump version to 1.10.0
1 parent 3036a32 commit fc02479

File tree

4 files changed

+195
-12
lines changed

4 files changed

+195
-12
lines changed

.bumpversion.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[bumpversion]
2-
current_version = 1.9.2
2+
current_version = 1.10.0
33
commit = False
44
tag = False
55
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?

nwastdlib/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
#
1414
"""The NWA-stdlib module."""
1515

16-
__version__ = "1.9.2"
16+
__version__ = "1.10.0"
1717

1818
from nwastdlib.f import const, identity
1919

nwastdlib/asyncio_cache.py

Lines changed: 91 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,19 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313
#
14+
import datetime
1415
import hashlib
1516
import hmac
17+
import inspect
1618
import pickle # noqa: S403
1719
import sys
20+
import types
21+
import typing
22+
import warnings
1823
from collections.abc import Callable
1924
from functools import wraps
20-
from typing import Any, Protocol, runtime_checkable
25+
from typing import Any, Protocol, get_args, get_origin, runtime_checkable
26+
from uuid import UUID
2127

2228
import structlog
2329
from redis.asyncio import Redis as AIORedis
@@ -111,6 +117,80 @@ async def get_signed_cache_value(pool: AIORedis, secret: str, cache_key: str, se
111117
return _deserialize(pickled_value, serializer)
112118

113119

120+
def _generate_cache_key_suffix(*, skip_first: bool, args: tuple, kwargs: dict) -> str:
121+
# Auto generate cache key suffix based on the arguments
122+
# Note: this makes no attempt to handle non-hashable values like lists and sets or other complex objects
123+
filtered_args = args[int(skip_first) :]
124+
filtered_kwargs = frozenset(kwargs.items())
125+
if not filtered_args and not filtered_kwargs:
126+
raise ValueError("Cannot generate cache key without args/kwargs")
127+
args_and_kwargs_string = (filtered_args, filtered_kwargs)
128+
return str(args_and_kwargs_string)
129+
130+
131+
SAFE_CACHED_RESULT_TYPES = (
132+
int,
133+
str,
134+
float,
135+
datetime.datetime,
136+
UUID,
137+
)
138+
139+
140+
def _unwrap_type(type_: Any) -> Any:
141+
origin, args = get_origin(type_), get_args(type_)
142+
# 'str'
143+
if not origin:
144+
return type_
145+
146+
# 'str | None' or 'Optional[str]'
147+
if origin in (types.UnionType, typing.Union) and types.NoneType in args:
148+
return args[0]
149+
150+
# For more advanced type handling, see https://github.com/workfloworchestrator/nwa-stdlib/issues/45
151+
return type_
152+
153+
154+
def _format_warning(func: Callable, name: str, type_: Any) -> str:
155+
safe_types = (t.__name__ for t in SAFE_CACHED_RESULT_TYPES)
156+
return (
157+
f"{cached_result.__name__}() applied to function {func.__qualname__} which has parameter '{name}' "
158+
f"of unsafe type '{type_.__name__}'. "
159+
f"This can lead to duplicate keys and thus cache misses. "
160+
f"To resolve this, either set a static keyname or only use parameters of the type {safe_types}. "
161+
f"If you understand the risks you can suppress/ignore this warning. "
162+
f"For background and feedback see https://github.com/workfloworchestrator/nwa-stdlib/issues/45"
163+
)
164+
165+
166+
def _validate_signature(func: Callable) -> bool:
167+
"""Validate the function's signature and return a bool whether to skip the first argument.
168+
169+
Raises warnings for potentially unsafe cache key arguments.
170+
"""
171+
func_params = inspect.signature(func).parameters
172+
is_nested_function = "." in func.__qualname__
173+
174+
skip_first_arg = False
175+
for idx, (name, param) in enumerate(func_params.items()):
176+
if idx == 0 and name == "self" and is_nested_function:
177+
# This will falsely recognize a closure function with 'self'
178+
# as first arg as a method. Nothing we can do about that..
179+
skip_first_arg = True
180+
continue
181+
182+
param_type = _unwrap_type(param.annotation)
183+
if param_type not in SAFE_CACHED_RESULT_TYPES:
184+
warnings.warn(_format_warning(func, name, param.annotation), stacklevel=2)
185+
return skip_first_arg
186+
187+
188+
def _validate_coroutine(func: Callable) -> None:
189+
"""Validate that the callable is a coroutine."""
190+
if not inspect.iscoroutinefunction(func):
191+
raise TypeError(f"Can't apply {cached_result.__name__}() to {func.__name__}: not a coroutine")
192+
193+
114194
def cached_result(
115195
pool: AIORedis,
116196
prefix: str,
@@ -157,21 +237,23 @@ def my_other_function...
157237
decorator function
158238
159239
"""
240+
python_major, python_minor = sys.version_info[:2]
241+
prefix_version = f"{prefix}:{python_major}.{python_minor}"
242+
static_cache_key: str | None = f"{prefix_version}:{key_name}" if key_name else None
160243

161244
def cache_decorator(func: Callable) -> Callable:
245+
_validate_coroutine(func)
246+
skip_first = _validate_signature(func)
247+
162248
@wraps(func)
163249
async def func_wrapper(*args: tuple[Any], **kwargs: dict[str, Any]) -> Any:
164250
from_cache = (not revalidate_fn(*args, **kwargs)) if revalidate_fn else True
165251

166-
python_major, python_minor = sys.version_info[:2]
167-
if key_name:
168-
cache_key = f"{prefix}:{python_major}.{python_minor}:{key_name}"
252+
if static_cache_key:
253+
cache_key = static_cache_key
169254
else:
170-
# Auto generate a cache key name based on function_name and a hash of the arguments
171-
# Note: this makes no attempt to handle non-hashable values like lists and sets or other complex objects
172-
args_and_kwargs_string = (args, frozenset(kwargs.items()))
173-
cache_key = f"{prefix}:{python_major}.{python_minor}:{func.__name__}{args_and_kwargs_string}"
174-
logger.debug("Autogenerated a cache key", cache_key=cache_key)
255+
suffix = _generate_cache_key_suffix(skip_first=skip_first, args=args, kwargs=kwargs)
256+
cache_key = f"{prefix_version}:{func.__name__}:{suffix}"
175257

176258
if from_cache:
177259
logger.debug("Cache called with wrapper func", func_name=func.__name__, cache_key=cache_key)

tests/test_asyncio_cache.py

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import json
22
import sys
33
from copy import copy
4+
from datetime import datetime
5+
from typing import Any, Optional, Union
6+
from uuid import UUID
47

58
import pytest
69
from fakeredis.aioredis import FakeRedis
710

8-
from nwastdlib.asyncio_cache import cached_result
11+
from nwastdlib.asyncio_cache import _generate_cache_key_suffix, cached_result
912

1013

1114
@pytest.fixture(autouse=True)
@@ -220,3 +223,101 @@ async def slow_function(revalidate_cache: bool):
220223
# A new call should serve 1: as it is not cached now
221224
result = await slow_function(revalidate_cache=True)
222225
assert result == 1
226+
227+
228+
# Test the validation
229+
230+
231+
@pytest.mark.parametrize(
232+
"type_",
233+
[
234+
Any,
235+
tuple,
236+
Union[str, int],
237+
],
238+
)
239+
def test_validate_signature_warn_unsafe(type_):
240+
with pytest.warns(UserWarning, match="unsafe type"):
241+
242+
@cached_result(FakeRedis(), "test-suite", "SECRETNAME")
243+
async def foo(param: type_):
244+
return f"{param}-{param}"
245+
246+
247+
@pytest.mark.parametrize(
248+
"type_",
249+
[
250+
int,
251+
int | None,
252+
Optional[int],
253+
],
254+
)
255+
def test_validate_signature_safe(recwarn, type_):
256+
@cached_result(FakeRedis(), "test-suite", "SECRETNAME")
257+
async def foo(param: type_):
258+
return f"{param}-{param}"
259+
260+
assert not [w.message for w in recwarn]
261+
262+
263+
def test_type_error_on_function():
264+
with pytest.raises(TypeError, match="foo: not a coroutine"):
265+
266+
@cached_result(FakeRedis(), "test-suite", "SECRETNAME")
267+
def foo(param: str):
268+
return f"{param}-{param}"
269+
270+
271+
def test_type_error_on_generatorfunction():
272+
with pytest.raises(TypeError, match="foo: not a coroutine"):
273+
274+
@cached_result(FakeRedis(), "test-suite", "SECRETNAME")
275+
def foo(param: str):
276+
yield f"{param}-{param}"
277+
278+
279+
def test_type_error_on_asyncgeneratorfunction():
280+
with pytest.raises(TypeError, match="foo: not a coroutine"):
281+
282+
@cached_result(FakeRedis(), "test-suite", "SECRETNAME")
283+
async def foo(param: str):
284+
yield f"{param}-{param}"
285+
286+
287+
# Test key generation
288+
289+
version = f"{sys.version_info.major}.{sys.version_info.minor}"
290+
cache_prefix = "test"
291+
cache_key_start = f"{cache_prefix}:{version}"
292+
293+
294+
@pytest.mark.parametrize(
295+
("skip_first", "args", "kwargs", "expected_key"),
296+
[
297+
(True, (1, 2), {}, "((2,), frozenset())"),
298+
(False, (1, 2), {}, "((1, 2), frozenset())"),
299+
(False, (1, "a"), {}, "((1, 'a'), frozenset())"),
300+
(False, (), {"foo": "bar"}, "((), frozenset({('foo', 'bar')}))"),
301+
(False, (1.234567,), {}, "((1.234567,), frozenset())"),
302+
(False, (datetime(year=2025, month=4, day=14),), {}, "((datetime.datetime(2025, 4, 14, 0, 0),), frozenset())"),
303+
(
304+
False,
305+
(UUID("12345678-0000-1111-2222-0123456789ab"),),
306+
{},
307+
"((UUID('12345678-0000-1111-2222-0123456789ab'),), frozenset())",
308+
),
309+
],
310+
)
311+
def test_generate_cache_key_suffix(skip_first, args, kwargs, expected_key):
312+
assert _generate_cache_key_suffix(skip_first=skip_first, args=args, kwargs=kwargs) == expected_key
313+
314+
315+
@pytest.mark.parametrize(
316+
("skip_first", "args", "kwargs", "expected_exception"),
317+
[
318+
(False, (), {}, ValueError),
319+
],
320+
)
321+
def test_generate_cache_key_errors(skip_first, args, kwargs, expected_exception):
322+
with pytest.raises(expected_exception):
323+
_generate_cache_key_suffix(skip_first=skip_first, args=args, kwargs=kwargs)

0 commit comments

Comments
 (0)