@@ -8,36 +8,36 @@ def __init__(self, model_name='BAAI/bge-base-en-v1.5'):
8
8
self .model = SentenceTransformer (model_name )
9
9
self .texts = []
10
10
self .embeddings = None
11
-
11
+
12
12
def add_texts (self , texts ):
13
- new_embeddings = self .model .encode (texts , normalize_embeddings = True )
13
+ new_embeddings = jnp . array ( self .model .encode (texts , normalize_embeddings = True ) )
14
14
if self .embeddings is None :
15
15
self .embeddings = new_embeddings
16
16
else :
17
17
self .embeddings = jnp .concatenate ((self .embeddings , new_embeddings ), axis = 0 )
18
18
self .texts .extend (texts )
19
-
19
+
20
20
def delete_text (self , index ):
21
21
if 0 <= index < len (self .texts ):
22
22
self .texts .pop (index )
23
- self .embeddings = jnp . delete ( self .embeddings , index , axis = 0 )
23
+ self .embeddings = self .embeddings . at [ index ]. delete ( )
24
24
else :
25
25
raise IndexError ("Invalid index" )
26
-
26
+
27
27
def update_text (self , index , new_text ):
28
28
if 0 <= index < len (self .texts ):
29
29
self .texts [index ] = new_text
30
- new_embedding = self .model .encode ([new_text ], normalize_embeddings = True ).squeeze ()
31
- self .embeddings = ( self .embeddings ) .at [index ].set (new_embedding )
30
+ new_embedding = jnp . array ( self .model .encode ([new_text ], normalize_embeddings = True ) ).squeeze ()
31
+ self .embeddings = self .embeddings .at [index ].set (new_embedding )
32
32
else :
33
33
raise IndexError ("Invalid index" )
34
-
34
+
35
35
def search (self , query , top_k = 5 ):
36
- query_embedding = self .model .encode ([query ], normalize_embeddings = True )
36
+ query_embedding = jnp . array ( self .model .encode ([query ], normalize_embeddings = True ) )
37
37
similarities = jnp .dot (self .embeddings , query_embedding .T ).squeeze ()
38
38
top_indices = jnp .argsort (similarities )[- top_k :][::- 1 ]
39
39
return [(self .texts [i ], similarities [i ]) for i in top_indices ]
40
-
40
+
41
41
def save (self , file_path ):
42
42
with open (file_path , 'wb' ) as file :
43
43
pickle .dump (self , file )
@@ -46,13 +46,13 @@ def save(self, file_path):
46
46
def load (file_path ):
47
47
with open (file_path , 'rb' ) as file :
48
48
return pickle .load (file )
49
-
49
+
50
50
def __len__ (self ):
51
51
return len (self .texts )
52
-
52
+
53
53
def __getitem__ (self , index ):
54
54
return self .texts [index ]
55
-
55
+
56
56
def __iter__ (self ):
57
57
return iter (self .texts )
58
58
0 commit comments