Skip to content

Commit cb630fa

Browse files
authored
Merge pull request #1 from rishiraj/development
Development
2 parents 11786e3 + 6674a2b commit cb630fa

File tree

3 files changed

+143
-2
lines changed

3 files changed

+143
-2
lines changed

Diff for: README.md

+92
Original file line numberDiff line numberDiff line change
@@ -1 +1,93 @@
11
# spanking 🍑👋
2+
3+
To use the 🍑👋 `VectorDB` class, you can follow these steps:
4+
5+
1. Create an instance of the 🍑👋 `VectorDB` class:
6+
```python
7+
from spanking import VectorDB
8+
vector_db = VectorDB(model_name='BAAI/bge-base-en-v1.5')
9+
```
10+
You can optionally specify a different pre-trained sentence transformer model by passing its name to the constructor.
11+
12+
2. Add texts to the database:
13+
```python
14+
texts = ["i eat pizza", "i play chess", "i drive bus"]
15+
vector_db.add_texts(texts)
16+
```
17+
This will encode the texts into embeddings and store them in the database.
18+
19+
3. Search for similar texts:
20+
```python
21+
query = "we play football"
22+
top_results = vector_db.search(query, top_k=3)
23+
print(top_results)
24+
```
25+
This will retrieve the top-3 most similar texts to the query based on cosine similarity. The `search` method returns a list of tuples, where each tuple contains the text and its similarity score.
26+
27+
4. Delete a text from the database:
28+
```python
29+
index = 1
30+
vector_db.delete_text(index)
31+
```
32+
This will remove the text and its corresponding embedding at the specified index.
33+
34+
5. Update a text in the database:
35+
```python
36+
index = 0
37+
new_text = "i enjoy eating pizza"
38+
vector_db.update_text(index, new_text)
39+
```
40+
This will update the text and its corresponding embedding at the specified index with the new text.
41+
42+
6. Iterate over the stored texts:
43+
```python
44+
for text in vector_db:
45+
print(text)
46+
```
47+
This will iterate over all the texts stored in the database.
48+
49+
7. Access individual texts by index:
50+
```python
51+
index = 2
52+
text = vector_db[index]
53+
print(text)
54+
```
55+
This will retrieve the text at the specified index.
56+
57+
8. Get the number of texts in the database:
58+
```python
59+
num_texts = len(vector_db)
60+
print(num_texts)
61+
```
62+
This will return the number of texts currently stored in the database.
63+
64+
Here's an example usage of the 🍑👋 `VectorDB` class:
65+
66+
```python
67+
from spanking import VectorDB
68+
vector_db = VectorDB()
69+
70+
# Add texts to the database
71+
texts = ["i eat pizza", "i play chess", "i drive bus"]
72+
vector_db.add_texts(texts)
73+
74+
# Search for similar texts
75+
query = "we play football"
76+
top_results = vector_db.search(query, top_k=2)
77+
print("Top results:")
78+
for text, similarity in top_results:
79+
print(f"Text: {text}, Similarity: {similarity}")
80+
81+
# Update a text
82+
vector_db.update_text(1, "i enjoy playing chess")
83+
84+
# Delete a text
85+
vector_db.delete_text(2)
86+
87+
# Iterate over the stored texts
88+
print("\nStored texts:")
89+
for text in vector_db:
90+
print(text)
91+
```
92+
93+
This example demonstrates how to create a 🍑👋 `VectorDB` instance, add texts, search for similar texts, update and delete texts, and iterate over the stored texts.

Diff for: spanking/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
from .main import main
1+
from .main import VectorDB
2+
from .main import main

Diff for: spanking/main.py

+49-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,50 @@
1+
import jax
2+
import jax.numpy as jnp
3+
from sentence_transformers import SentenceTransformer
4+
5+
class VectorDB:
6+
def __init__(self, model_name='BAAI/bge-base-en-v1.5'):
7+
self.model = SentenceTransformer(model_name)
8+
self.texts = []
9+
self.embeddings = None
10+
11+
def add_texts(self, texts):
12+
new_embeddings = self.model.encode(texts, normalize_embeddings=True)
13+
if self.embeddings is None:
14+
self.embeddings = new_embeddings
15+
else:
16+
self.embeddings = jnp.concatenate((self.embeddings, new_embeddings), axis=0)
17+
self.texts.extend(texts)
18+
19+
def delete_text(self, index):
20+
if 0 <= index < len(self.texts):
21+
self.texts.pop(index)
22+
self.embeddings = jnp.delete(self.embeddings, index, axis=0)
23+
else:
24+
raise IndexError("Invalid index")
25+
26+
def update_text(self, index, new_text):
27+
if 0 <= index < len(self.texts):
28+
self.texts[index] = new_text
29+
new_embedding = self.model.encode([new_text], normalize_embeddings=True)
30+
self.embeddings = jax.ops.index_update(self.embeddings, index, new_embedding)
31+
else:
32+
raise IndexError("Invalid index")
33+
34+
def search(self, query, top_k=5):
35+
query_embedding = self.model.encode([query], normalize_embeddings=True)
36+
similarities = jnp.dot(self.embeddings, query_embedding.T).squeeze()
37+
top_indices = jnp.argsort(similarities)[-top_k:][::-1]
38+
return [(self.texts[i], similarities[i]) for i in top_indices]
39+
40+
def __len__(self):
41+
return len(self.texts)
42+
43+
def __getitem__(self, index):
44+
return self.texts[index]
45+
46+
def __iter__(self):
47+
return iter(self.texts)
48+
149
def main():
2-
print("🍑👋")
50+
print("🍑👋")

0 commit comments

Comments
 (0)