Skip to content

Commit 97b9b68

Browse files
Add a default for the dtype field of BaseVectorizer (#261)
In 0.3.8, we introduced a new field, `dtype`, on `BaseVectorizer` but did not give it a default value. This broke downstream software written against earlier versions of redisvl, which did not have this field. Fixes [langchain-ai/langchain-redis#48](langchain-ai/langchain-redis#48). --------- Co-authored-by: Tyler Hutcherson <[email protected]>
1 parent 9f910e4 commit 97b9b68

File tree

3 files changed

+36
-4
lines changed

3 files changed

+36
-4
lines changed

redisvl/schema/schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,12 +170,12 @@ def _make_field(storage_type, **field_inputs) -> BaseField:
170170

171171
@root_validator(pre=True)
172172
@classmethod
173-
def validate_and_create_fields(cls, values):
173+
def validate_and_create_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
174174
"""
175175
Validate uniqueness of field names and create valid field instances.
176176
"""
177177
# Ensure index is a dictionary for validation
178-
index = values.get("index")
178+
index = values.get("index", {})
179179
if not isinstance(index, IndexInfo):
180180
index = IndexInfo(**index)
181181

redisvl/utils/vectorize/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from enum import Enum
33
from typing import Callable, List, Optional
44

5-
from pydantic.v1 import BaseModel, validator
5+
from pydantic.v1 import BaseModel, Field, validator
66

77
from redisvl.redis.utils import array_to_buffer
88
from redisvl.schema.fields import VectorDataType
@@ -21,7 +21,7 @@ class Vectorizers(Enum):
2121
class BaseVectorizer(BaseModel, ABC):
2222
model: str
2323
dims: int
24-
dtype: str
24+
dtype: str = Field(default="float32")
2525

2626
@property
2727
def type(self) -> str:

tests/unit/test_base_vectorizer.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from typing import List
2+
from redisvl.utils.vectorize.base import BaseVectorizer
3+
4+
5+
6+
def test_base_vectorizer_defaults():
    """
    Check that a minimal ``BaseVectorizer`` subclass can be instantiated
    without passing ``dtype`` and that ``dtype`` then reads ``"float32"``.

    Versions before 0.3.8 did not have the ``dtype`` field, so code written
    against them never supplies it.

    Regression test for langchain-ai/langchain-redis#48.
    """

    class SimpleVectorizer(BaseVectorizer):
        # Only model and dims get defaults here; dtype is deliberately left
        # unset so the final assertion exercises the inherited default.
        model: str = "simple"
        dims: int = 10

        def embed(self, text: str, **kwargs) -> List[float]:
            return [0.0] * self.dims

        async def aembed(self, text: str, **kwargs) -> List[float]:
            return [0.0] * self.dims

        async def aembed_many(self, texts: List[str], **kwargs) -> List[List[float]]:
            # One fresh list per text; `[[...]] * n` would repeat a single
            # shared inner list n times.
            return [[0.0] * self.dims for _ in texts]

        def embed_many(self, texts: List[str], **kwargs) -> List[List[float]]:
            # Same as above: avoid aliasing the inner list across results.
            return [[0.0] * self.dims for _ in texts]

    # No arguments at all: every field must be satisfiable from defaults.
    vectorizer = SimpleVectorizer()
    assert vectorizer.model == "simple"
    assert vectorizer.dims == 10
    assert vectorizer.dtype == "float32"

0 commit comments

Comments
 (0)