Skip to content

Commit fe8ae8c

Browse files
authored
Update main.py
1 parent 50d2ac0 commit fe8ae8c

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

spanking/main.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
import pandas as pd
99

1010
class VectorDB:
11-
def __init__(self, model_name='BAAI/bge-base-en-v1.5'):
12-
self.model = SentenceTransformer(model_name)
11+
def __init__(self, model_name='dunzhang/stella_en_400M_v5'):
12+
self.model = SentenceTransformer(model_name, trust_remote_code=True)
1313
self.image_classifier = pipeline(task="zero-shot-image-classification", model="google/siglip-so400m-patch14-384")
1414
self.texts = []
1515
self.embeddings = None
@@ -45,7 +45,7 @@ def search(self, query, top_k=5, type='text'):
4545
if isinstance(query, str):
4646
query = Image.open(requests.get(query, stream=True).raw)
4747
outputs = self.image_classifier(query, candidate_labels=self.texts)
48-
similarities = jnp.array([output['score'] for output in outputs])
48+
similarities = jnp.array([round(output["score"], 4) for output in outputs])
4949
else:
5050
raise ValueError("Invalid search type. Supported types are 'text' and 'image'.")
5151
top_indices = jnp.argsort(similarities)[-top_k:][::-1]

0 commit comments

Comments
 (0)