Skip to content

Commit 4e91fa0

Browse files
authored
Merge pull request #8 from rishiraj/staging
Staging
2 parents 5dff898 + 6fc2e74 commit 4e91fa0

File tree

2 files changed

+37
-19
lines changed

2 files changed

+37
-19
lines changed

README.md

+24-6
Original file line numberDiff line numberDiff line change
@@ -39,22 +39,34 @@ vector_db.update_text(index, new_text)
3939
```
4040
This will replace the text at the specified index with the new text and recompute its embedding.
4141

42-
6. Iterate over the stored texts:
42+
6. Save the database to a file:
43+
```python
44+
vector_db.save('vector_db.pkl')
45+
```
46+
This will save the current state of the `VectorDB` instance to a file named 'vector_db.pkl'.
47+
48+
7. Load the database from a file:
49+
```python
50+
vector_db = VectorDB.load('vector_db.pkl')
51+
```
52+
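This will load the `VectorDB` instance from the file named 'vector_db.pkl' and return it. Only load files you trust: the database is serialized with `pickle`, and unpickling can execute arbitrary code.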
53+
54+
8. Iterate over the stored texts:
4355
```python
4456
for text in vector_db:
4557
print(text)
4658
```
4759
This will iterate over all the texts stored in the database.
4860

49-
7. Access individual texts by index:
61+
9. Access individual texts by index:
5062
```python
5163
index = 2
5264
text = vector_db[index]
5365
print(text)
5466
```
5567
This will retrieve the text at the specified index.
5668

57-
8. Get the number of texts in the database:
69+
10. Get the number of texts in the database:
5870
```python
5971
num_texts = len(vector_db)
6072
print(num_texts)
@@ -84,9 +96,15 @@ vector_db.update_text(1, "i enjoy playing chess")
8496
# Delete a text
8597
vector_db.delete_text(2)
8698

87-
# Iterate over the stored texts
88-
print("\nStored texts:")
89-
for text in vector_db:
99+
# Save the database
100+
vector_db.save('vector_db.pkl')
101+
102+
# Load the database
103+
loaded_vector_db = VectorDB.load('vector_db.pkl')
104+
105+
# Iterate over the stored texts in the loaded database
106+
print("\nStored texts in the loaded database:")
107+
for text in loaded_vector_db:
90108
print(text)
91109
```
92110

spanking/main.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -8,36 +8,36 @@ def __init__(self, model_name='BAAI/bge-base-en-v1.5'):
88
self.model = SentenceTransformer(model_name)
99
self.texts = []
1010
self.embeddings = None
11-
11+
1212
def add_texts(self, texts):
13-
new_embeddings = self.model.encode(texts, normalize_embeddings=True)
13+
new_embeddings = jnp.array(self.model.encode(texts, normalize_embeddings=True))
1414
if self.embeddings is None:
1515
self.embeddings = new_embeddings
1616
else:
1717
self.embeddings = jnp.concatenate((self.embeddings, new_embeddings), axis=0)
1818
self.texts.extend(texts)
19-
19+
2020
def delete_text(self, index):
2121
if 0 <= index < len(self.texts):
2222
self.texts.pop(index)
23-
self.embeddings = jnp.delete(self.embeddings, index, axis=0)
23+
self.embeddings = self.embeddings.at[index].delete()
2424
else:
2525
raise IndexError("Invalid index")
26-
26+
2727
def update_text(self, index, new_text):
2828
if 0 <= index < len(self.texts):
2929
self.texts[index] = new_text
30-
new_embedding = self.model.encode([new_text], normalize_embeddings=True).squeeze()
31-
self.embeddings = (self.embeddings).at[index].set(new_embedding)
30+
new_embedding = jnp.array(self.model.encode([new_text], normalize_embeddings=True)).squeeze()
31+
self.embeddings = self.embeddings.at[index].set(new_embedding)
3232
else:
3333
raise IndexError("Invalid index")
34-
34+
3535
def search(self, query, top_k=5):
36-
query_embedding = self.model.encode([query], normalize_embeddings=True)
36+
query_embedding = jnp.array(self.model.encode([query], normalize_embeddings=True))
3737
similarities = jnp.dot(self.embeddings, query_embedding.T).squeeze()
3838
top_indices = jnp.argsort(similarities)[-top_k:][::-1]
3939
return [(self.texts[i], similarities[i]) for i in top_indices]
40-
40+
4141
def save(self, file_path):
4242
with open(file_path, 'wb') as file:
4343
pickle.dump(self, file)
@@ -46,13 +46,13 @@ def save(self, file_path):
4646
def load(file_path):
4747
with open(file_path, 'rb') as file:
4848
return pickle.load(file)
49-
49+
5050
def __len__(self):
5151
return len(self.texts)
52-
52+
5353
def __getitem__(self, index):
5454
return self.texts[index]
55-
55+
5656
def __iter__(self):
5757
return iter(self.texts)
5858

0 commit comments

Comments
 (0)