Skip to content

Commit 7169eb1

Browse files
committed
feat: add unit test
1 parent d18957e commit 7169eb1

File tree

2 files changed

+25
-8
lines changed

2 files changed

+25
-8
lines changed

langchain/vectorstores/redisearch.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22

33
import json
44
import uuid
5-
from typing import Any, Callable, Iterable, List, Optional, Tuple
5+
from typing import Any, Callable, Iterable, List, Mapping, Optional
6+
7+
import numpy as np
8+
from redis.commands.search.query import Query
69

710
from langchain.docstore.document import Document
811
from langchain.embeddings.base import Embeddings
9-
from langchain.vectorstores.base import VectorStore
1012
from langchain.utils import get_from_dict_or_env
11-
from redis.commands.search.query import Query
12-
import numpy as np
13+
from langchain.vectorstores.base import VectorStore
1314

1415

1516
class RediSearch(VectorStore):
@@ -80,7 +81,9 @@ def similarity_search(
8081
.paging(0, k)
8182
.dialect(2)
8283
)
83-
params_dict = {"vector": np.array(embedding).astype(dtype=np.float32).tobytes()}
84+
params_dict: Mapping[str, str] = {
85+
"vector": str(np.array(embedding).astype(dtype=np.float32).tobytes())
86+
}
8487

8588
# perform vector search
8689
results = self.client.ft(self.index_name).search(redis_query, params_dict)
@@ -115,15 +118,14 @@ def from_texts(
115118
redisearch = RediSearch.from_texts(
116119
texts,
117120
embeddings,
118-
redis_url="redis://username:password@localhost:6379"
121+
redisearch_url="redis://username:password@localhost:6379"
119122
)
120123
"""
121124
redisearch_url = get_from_dict_or_env(
122125
kwargs, "redisearch_url", "REDISEARCH_URL"
123126
)
124127
try:
125128
import redis
126-
from redis.commands.search.query import Query
127129
from redis.commands.search.field import TextField, VectorField
128130
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
129131
except ImportError:
@@ -165,7 +167,7 @@ def from_texts(
165167
try:
166168
client.ft(index_name).info()
167169
print("Index already exists")
168-
except:
170+
except: # noqa
169171
# Create RediSearch Index
170172
client.ft(index_name).create_index(
171173
fields=fields,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""Test RediSearch functionality."""
2+
3+
from langchain.docstore.document import Document
4+
from langchain.vectorstores.redisearch import RediSearch
5+
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
6+
7+
8+
def test_redisearch() -> None:
9+
"""Test end to end construction and search."""
10+
texts = ["foo", "bar", "baz"]
11+
docsearch = RediSearch.from_texts(
12+
texts, FakeEmbeddings(), redisearch_url="redis://localhost:6379"
13+
)
14+
output = docsearch.similarity_search("foo", k=1)
15+
assert output == [Document(page_content="foo")]

0 commit comments

Comments
 (0)