-
Notifications
You must be signed in to change notification settings - Fork 45
/
Copy pathutils.py
235 lines (186 loc) · 7.3 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
import asyncio
import inspect
import json
import logging
import warnings
from contextlib import contextmanager
from enum import Enum
from functools import wraps
from time import time
from typing import Any, Callable, Coroutine, Dict, Optional, Sequence
from warnings import warn
from pydantic import BaseModel
from redis import Redis
from ulid import ULID
def create_ulid() -> str:
    """Return a freshly generated ULID, as a string, used to group related Redis documents."""
    identifier = ULID()
    return str(identifier)
def current_timestamp() -> float:
    """Return the current Unix epoch time in seconds (via ``time.time``) for stamping Redis documents."""
    now = time()
    return now
def model_to_dict(model: BaseModel) -> Dict[str, Any]:
    """
    Custom serialization function that converts a Pydantic model to a dict,
    serializing Enum fields to their values, and handling nested models and lists.

    The model is dumped with ``model_dump(exclude_none=True)``; the resulting
    values are then walked recursively through dicts and lists, and any Enum
    member is replaced by its value (lower-cased when that value is a string).

    Args:
        model (BaseModel): The Pydantic model to serialize.

    Returns:
        Dict[str, Any]: The serialized model, without ``None`` fields.
    """

    def serialize_item(item: Any) -> Any:
        if isinstance(item, Enum):
            value = item.value
            # Only string values can be lower-cased; calling .lower() on a
            # non-string value (e.g. an IntEnum) raised AttributeError.
            return value.lower() if isinstance(value, str) else value
        if isinstance(item, dict):
            return {key: serialize_item(value) for key, value in item.items()}
        if isinstance(item, list):
            return [serialize_item(element) for element in item]
        return item

    serialized_data = model.model_dump(exclude_none=True)
    return {key: serialize_item(value) for key, value in serialized_data.items()}
def validate_vector_dims(v1: int, v2: int) -> None:
    """
    Check the equality of vector dimensions.

    Args:
        v1 (int): Number of dimensions of the vector.
        v2 (int): Number of dimensions defined on the vector field.

    Raises:
        ValueError: If ``v1`` and ``v2`` differ.
    """
    if v1 != v2:
        # Build ONE message string: passing several positional arguments to
        # ValueError made str(err) render as a tuple of fragments.
        raise ValueError(
            f"Invalid vector dimensions! Vector has dims defined as {v1}. "
            f"Vector field has dims defined as {v2}. "
            "Vector dims must be equal in order to perform similarity search."
        )
def serialize(data: Any) -> str:
    """Serialize ``data`` into a JSON string using ``json.dumps``."""
    encoded = json.dumps(data)
    return encoded
def deserialize(data: str) -> Any:
    """Parse the JSON string ``data`` back into Python objects using ``json.loads``."""
    decoded = json.loads(data)
    return decoded
def deprecated_argument(argument: str, replacement: Optional[str] = None) -> Callable:
    """
    Decorator to warn if a deprecated argument is passed.

    When the wrapped function is called, the decorator emits a
    ``DeprecationWarning`` if the deprecated argument is supplied either as a
    keyword argument or positionally.

    NOTE: The recommended placement is between the @classmethod decorator and
    the method definition. For example:

        class MyClass:
            @classmethod
            @deprecated_argument("old_arg", "new_arg")
            @other_decorator
            def test_method(cls, old_arg=None, new_arg=None):
                pass

    If it is instead applied on top of a ``classmethod`` or ``staticmethod``
    object, the underlying function is unwrapped, decorated, and re-wrapped
    in the same descriptor type.

    Args:
        argument (str): Name of the deprecated parameter.
        replacement (Optional[str]): Name of the parameter to use instead;
            mentioned in the warning message when given.

    Returns:
        Callable: The decorator.
    """
    message = f"Argument {argument} is deprecated and will be removed in the next major release."
    if replacement:
        message += f" Use {replacement} instead."

    def _wrap(target: Callable) -> Callable:
        """Wrap a plain function so it warns when ``argument`` is supplied."""
        # Resolve the signature once at decoration time instead of on every call.
        try:
            sig: Optional[inspect.Signature] = inspect.signature(target)
        except (TypeError, ValueError):
            # No introspectable signature: only keyword usage can be detected.
            sig = None

        @wraps(target)
        def wrapper(*args, **kwargs):
            if argument in kwargs:
                warn(message, DeprecationWarning, stacklevel=2)
            elif sig is not None:
                try:
                    bound_args = sig.bind(*args, **kwargs)
                except TypeError:
                    # Invalid call: skip the check and let the target raise
                    # its own TypeError below.
                    bound_args = None
                if bound_args is not None and argument in bound_args.arguments:
                    warn(message, DeprecationWarning, stacklevel=2)
            return target(*args, **kwargs)

        return wrapper

    def decorator(func):
        # classmethod/staticmethod objects wrap the real function in
        # __func__; decorate that and restore the same descriptor type.
        if isinstance(func, classmethod):
            return classmethod(_wrap(func.__func__))
        if isinstance(func, staticmethod):
            return staticmethod(_wrap(func.__func__))
        return _wrap(func)

    return decorator
@contextmanager
def assert_no_warnings():
    """
    Context manager that turns every warning raised inside its ``with`` body
    into an exception.

    The previous warning filters are restored when the block exits.
    """
    guard = warnings.catch_warnings()
    with guard:
        warnings.simplefilter("error")
        yield
def deprecated_function(name: Optional[str] = None, replacement: Optional[str] = None):
    """
    Decorator to mark a function as deprecated.

    Each call to the wrapped function emits a ``DeprecationWarning`` via
    ``warnings.warn`` and then delegates to the original function.

    Args:
        name (Optional[str]): Name shown in the warning; defaults to the
            decorated function's ``__name__``.
        replacement (Optional[str]): Text appended verbatim to the warning
            message, e.g. a hint about what to use instead.
    """

    def decorator(func):
        fn_name = name or func.__name__
        warning_message = (
            f"Function {fn_name} is deprecated and will be "
            "removed in the next major release. "
        )
        if replacement:
            warning_message += replacement

        @wraps(func)
        def wrapper(*args, **kwargs):
            # stacklevel=2 attributes the warning to the code calling the
            # deprecated function (consistent with deprecated_argument);
            # stacklevel=3 pointed one frame too far up the stack.
            warn(warning_message, category=DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)

        return wrapper

    return decorator
def sync_wrapper(fn: Callable[[], Coroutine[Any, Any, Any]]) -> Callable[[], None]:
    """
    Wrap a zero-argument coroutine function in a synchronous callable.

    Calling the returned ``wrapper()`` schedules ``fn()`` as a task and runs
    it with ``loop.run_until_complete``. When no event loop is currently
    running, a new event loop is created and installed as the current loop
    via ``asyncio.set_event_loop``; that loop is not closed afterwards.

    Any ``RuntimeError`` raised while scheduling or running the coroutine
    (including one raised by ``fn`` itself) is caught and reported with
    ``logging.info`` instead of being propagated.

    Args:
        fn: A callable taking no arguments that returns a coroutine.

    Returns:
        Callable[[], None]: A no-argument function that always returns ``None``.
    """

    def wrapper():
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # get_running_loop() raises RuntimeError when no loop is running.
            loop = None
        try:
            if loop is None or not loop.is_running():
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
            # NOTE(review): if a loop IS already running, create_task has
            # already scheduled fn() on that loop, and run_until_complete then
            # raises RuntimeError. That error is caught below, and the logged
            # text ("event loop is closed") may not describe this case. Confirm
            # this fire-and-forget behavior is intended.
            task = loop.create_task(fn())
            loop.run_until_complete(task)
        except RuntimeError:
            # This could happen if an object stored an event loop and now
            # that event loop is closed. There's nothing we can do other than
            # advise the user to use explicit cleanup methods.
            #
            # Uses logging module instead of get_logger() to avoid I/O errors
            # if the wrapped function is called as a finalizer.
            logging.info(
                f"Could not run the async function {fn.__name__} because the event loop is closed. "
                "This usually means the object was not properly cleaned up. Please use explicit "
                "cleanup methods (e.g., disconnect(), close()) or use the object as an async "
                "context manager.",
            )
            return

    return wrapper
def norm_cosine_distance(value: float) -> float:
    """
    Convert a cosine distance into a similarity score between 0 and 1.

    The score is ``(2 - value) / 2``, clamped from below at 0.
    """
    similarity = (2 - value) / 2
    return max(similarity, 0)
def denorm_cosine_distance(value: float) -> float:
    """
    Convert a similarity score between 0 and 1 back into a cosine distance
    between 0 and 2.

    The distance is ``2 - 2 * value``, clamped from below at 0.
    """
    distance = 2 - 2 * value
    return max(distance, 0)
def norm_l2_distance(value: float) -> float:
    """
    Normalize an L2 distance into a score computed as ``1 / (1 + value)``.
    """
    denominator = 1 + value
    return 1 / denominator
def scan_by_pattern(
    redis_client: Redis,
    pattern: str,
) -> Sequence[str]:
    """
    Scan the Redis database for keys matching a specific pattern.

    Iterates ``redis_client.scan_iter(match=pattern)`` to completion and
    collects every matching key.

    Args:
        redis_client (Redis): The Redis client instance.
        pattern (str): The pattern to match keys against, passed as ``match``
            to ``scan_iter``.

    Returns:
        Sequence[str]: The matching keys (not their values), passed through
        ``convert_bytes`` — presumably to decode byte keys to ``str``; confirm
        against ``redisvl.redis.utils``.
    """
    # NOTE(review): imported locally, presumably to avoid a circular import
    # between this module and redisvl.redis.utils — confirm before hoisting.
    from redisvl.redis.utils import convert_bytes

    return convert_bytes(list(redis_client.scan_iter(match=pattern)))