Skip to content

Commit 44d1cac

Browse files
committed
Add Redis cache option
1 parent 2b5a10c commit 44d1cac

File tree

18 files changed

+500
-206
lines changed

18 files changed

+500
-206
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2525
- Deprecate "free" queries
2626
- Create simpler developer contract for databackend
2727
- Snowflake native data-backend implementation
28+
- Add redis cache inside `db.metadata` for quick multi-process loading
2829

2930
#### Bug Fixes
3031

plugins/redis/pyproject.toml

+82
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
[build-system]
2+
requires = ["setuptools>=61.0"]
3+
build-backend = "setuptools.build_meta"
4+
5+
[project]
6+
name = "superduper_redis"
7+
readme = "README.md"
8+
description = "superduper allows users to work with arbitrary sklearn estimators, with additional support for pre-, post-processing and input/ output data-types."
9+
license = {file = "LICENSE"}
10+
maintainers = [{name = "superduper.io, Inc.", email = "[email protected]"}]
11+
keywords = [
12+
"databases",
13+
"mongodb",
14+
"data-science",
15+
"machine-learning",
16+
"mlops",
17+
"vector-database",
18+
"ai",
19+
]
20+
requires-python = ">=3.10"
21+
dynamic = ["version"]
22+
dependencies = [
23+
"redis"
24+
]
25+
26+
[project.optional-dependencies]
27+
test = [
28+
# Annotation plugin dependencies will be installed in CI
29+
# :CI: plugins/mongodb
30+
]
31+
32+
[project.urls]
33+
homepage = "https://superduper.io"
34+
documentation = "https://docs.superduper.io/docs/intro"
35+
source = "https://github.com/superduper-io/superduper"
36+
37+
[tool.setuptools.packages.find]
38+
include = ["superduper_redis*"]
39+
40+
[tool.setuptools.dynamic]
41+
version = {attr = "superduper_redis.__version__"}
42+
43+
[tool.black]
44+
skip-string-normalization = true
45+
target-version = ["py38"]
46+
47+
[tool.mypy]
48+
ignore_missing_imports = true
49+
no_implicit_optional = true
50+
warn_unused_ignores = true
51+
disable_error_code = ["has-type", "attr-defined", "assignment", "misc", "override", "call-arg"]
52+
53+
[tool.pytest.ini_options]
54+
addopts = "-W ignore"
55+
56+
[tool.ruff.lint]
57+
extend-select = [
58+
"I", # Missing required import (auto-fixable)
59+
"F", # PyFlakes
60+
#"W", # PyCode Warning
61+
"E", # PyCode Error
62+
#"N", # pep8-naming
63+
"D", # pydocstyle
64+
]
65+
ignore = [
66+
"D100", # Missing docstring in public module
67+
"D104", # Missing docstring in public package
68+
"D107", # Missing docstring in __init__
69+
"D105", # Missing docstring in magic method
70+
"D203", # 1 blank line required before class docstring
71+
"D212", # Multi-line docstring summary should start at the first line
72+
"D213", # Multi-line docstring summary should start at the second line
73+
"D401",
74+
"E402",
75+
]
76+
77+
[tool.ruff.lint.isort]
78+
combine-as-imports = true
79+
80+
[tool.ruff.lint.per-file-ignores]
81+
"test/**" = ["D"]
82+
"plugin_test/**" = ["D"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .cache import RedisCache as Cache
2+
3+
__version__ = '0.6.0'
4+
5+
__all__ = ['Cache']
+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import json
2+
import re
3+
4+
import redis
5+
from superduper import logging
6+
from superduper.backends.base.cache import Cache
7+
8+
9+
class RedisCache(Cache):
10+
"""Local cache for caching components.
11+
12+
:param init_cache: Initialize cache
13+
"""
14+
15+
def __init__(self, uri: str = 'redis://localhost:6379/0'):
16+
logging.info('Using Redis cache')
17+
logging.info(f'Connecting to Redis cache at {uri}')
18+
self.redis = redis.Redis.from_url(uri, decode_responses=True)
19+
logging.info(f'Connecting to Redis cache at {uri}... DONE')
20+
21+
def __delitem__(self, item):
22+
self.redis.delete(':'.join(item))
23+
24+
def __setitem__(self, key, value):
25+
key = ':'.join(key)
26+
self.redis.set(key, json.dumps(value))
27+
28+
def keys(self, *pattern):
29+
"""Get keys from the cache.
30+
31+
:param pattern: The pattern to search for.
32+
"""
33+
pattern = ':'.join(pattern)
34+
strings = list(self.redis.keys(pattern))
35+
return [tuple(re.split(':', string)) for string in strings]
36+
37+
def __getitem__(self, item):
38+
out = self.redis.get(':'.join(item))
39+
if out is None:
40+
raise KeyError(item)
41+
return json.loads(out)
42+
43+
def initialize(self):
44+
"""Initialize the cache."""
45+
pass
46+
47+
def drop(self, force: bool = False):
48+
"""Drop component from the cache.
49+
50+
:param uuid: Component uuid.
51+
"""
52+
self.redis.flushdb()
53+
54+
@property
55+
def db(self):
56+
"""Get the ``db``."""
57+
return self._db
58+
59+
@db.setter
60+
def db(self, value):
61+
"""Set the ``db``.
62+
63+
:param value: The value to set the ``db`` to.
64+
"""
65+
self._db = value

superduper/backends/base/cache.py

+88-3
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,101 @@
1+
import typing as t
12
from abc import abstractmethod
23

3-
from superduper.backends.base.backends import BaseBackend
44
from superduper.components.component import Component
55

66

7-
class Cache(BaseBackend):
7+
class Cache:
88
"""Cache object for caching components.
99
1010
# noqa
1111
"""
1212

1313
@abstractmethod
14-
def __getitem__(self, *item) -> Component:
14+
def __getitem__(self, *item) -> t.Dict | t.List:
1515
"""Get a component from the cache."""
1616
pass
17+
18+
@abstractmethod
19+
def keys(self, *pattern) -> t.List[str]:
20+
"""Get the keys from the cache.
21+
22+
:param pattern: The pattern to match.
23+
24+
>>> cache.keys('*', '*', '*')
25+
>>> cache.keys('Model', '*', '*')
26+
>>> cache.keys('Model', 'my_model', '*')
27+
>>> cache.keys('*', '*', '1234567890')
28+
"""
29+
30+
def get_with_uuid(self, uuid: str):
31+
"""Get a component from the cache with a specific uuid.
32+
33+
:param uuid: The uuid of the component to get.
34+
"""
35+
key = self.keys('*', '*', uuid)
36+
if not key:
37+
return None
38+
else:
39+
key = key[0]
40+
41+
try:
42+
return self[key]
43+
except KeyError:
44+
return
45+
46+
def get_with_component(self, component: str):
47+
"""Get all components from the cache of a certain type.
48+
49+
:param component: The component to get.
50+
"""
51+
keys = self.keys(component, '*', '*')
52+
return [self[k] for k in keys]
53+
54+
def get_with_component_identifier(self, component: str, identifier: str):
55+
"""Get a component from the cache with a specific identifier.
56+
57+
:param component: The component to get.
58+
:param identifier: The identifier of the component to
59+
"""
60+
keys = self.keys(component, identifier, '*')
61+
out = [self[k] for k in keys]
62+
if not out:
63+
return None
64+
out = max(out, key=lambda x: x['version']) # type: ignore[arg-type, call-overload]
65+
return out
66+
67+
def get_with_component_identifier_version(
68+
self, component: str, identifier: str, version: int
69+
):
70+
"""Get a component from the cache with a specific version.
71+
72+
:param component: The component to get.
73+
:param identifier: The identifier of the component to get.
74+
:param version: The version of the component to get.
75+
"""
76+
keys = self.keys(component, identifier, '*')
77+
out = [self[k] for k in keys]
78+
try:
79+
return next(r for r in out if r['version'] == version) # type: ignore[call-overload]
80+
except StopIteration:
81+
return
82+
83+
def __contains__(self, key: str) -> bool:
84+
return key in self.keys()
85+
86+
@abstractmethod
87+
def __setitem__(self, key: t.Tuple[str, str, str], value: t.Dict) -> None:
88+
pass
89+
90+
def delete_uuid(self, uuid: str):
91+
"""Delete a component from the cache.
92+
93+
:param uuid: The uuid of the component to delete.
94+
"""
95+
keys = self.keys('*', '*', uuid)
96+
for key in keys:
97+
del self[key] # type: ignore[arg-type]
98+
99+
@abstractmethod
100+
def __delitem__(self, key: t.Tuple[str, str, str]):
101+
pass

superduper/backends/base/cluster.py

-10
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,3 @@ def initialize(self, with_compute: bool = False):
111111
self.cdc.initialize()
112112

113113
logging.info(f"Cluster initialized in {time.time() - start:.2f} seconds.")
114-
115-
def drop_component(self, uuid: str):
116-
"""Drop component and its services rom the cluster.
117-
118-
:param uuid: Component uuid.
119-
"""
120-
try:
121-
del self.cache[uuid]
122-
except KeyError:
123-
pass

0 commit comments

Comments
 (0)