@@ -37,6 +37,7 @@ type Index interface {
3737 IsIVFIndex () bool
3838 SetNProbe (nprobe int32 )
3939 GetNProbe () int32
40+ GetNlist () int
4041 SetDirectMap (directMapType int ) error
4142
4243 Close ()
@@ -63,6 +64,9 @@ type BinaryIndex interface {
6364 SearchClustersFromIVFIndex (selector Selector , eligibleCentroidIDs []int64 ,
6465 minEligibleCentroids int , k int64 , x []uint8 , centroidDis []int32 ,
6566 params json.RawMessage ) ([]int32 , []int64 , error )
67+
68+ BinaryQuantizer () BinaryIndex
69+ SetIsTrained (isTrained bool )
6670}
6771
6872// FloatIndex defines methods specific to float-based FAISS indexes
@@ -89,6 +93,8 @@ type FloatIndex interface {
8993 Reconstruct (key int64 ) (recons []float32 , err error )
9094 ReconstructBatch (ids []int64 , vectors []float32 ) ([]float32 , error )
9195
96+ GetCentroids () ([]float32 , error )
97+
9298 // Applicable only to IVF indexes: Returns a map where the keys
9399 // are cluster IDs and the values represent the count of input vectors that belong
94100 // to each cluster.
@@ -121,6 +127,8 @@ type FloatIndex interface {
121127 // RemoveIDs removes the vectors specified by sel from the index.
122128 // Returns the number of elements removed and error.
123129 RemoveIDs (sel * IDSelector ) (int , error )
130+
131+ Quantizer () * C.FaissIndex
124132}
125133
126134// IndexImpl represents a float vector index
@@ -220,6 +228,50 @@ func (idx *BinaryIndexImpl) ObtainClustersWithDistancesFromIVFIndex(x []uint8, c
220228 return centroidIDs , centroidDistances , nil
221229}
222230
231+ func (idx * IndexImpl ) GetNlist () int {
232+ if ivfIdx := C .faiss_IndexIVF_cast (idx .cPtrFloat ()); ivfIdx != nil {
233+ return int (C .faiss_IndexIVF_nlist (ivfIdx ))
234+ }
235+ return 0
236+ }
237+
238+ func (idx * IndexImpl ) GetCentroids () ([]float32 , error ) {
239+ if ivfIdx := C .faiss_IndexIVF_cast (idx .cPtrFloat ()); ivfIdx != nil {
240+ ivfCentroids := make ([]float32 , idx .D ()* idx .GetNlist ())
241+ C .faiss_IndexIVF_get_centroids (ivfIdx , (* C .float )(& ivfCentroids [0 ]))
242+ return ivfCentroids , nil
243+ }
244+ return nil , fmt .Errorf ("index is not an IVF index" )
245+ }
246+
247+ func (idx * IndexImpl ) Quantizer () * C.FaissIndex {
248+ if ivfIdx := C .faiss_IndexIVF_cast (idx .cPtrFloat ()); ivfIdx != nil {
249+ return C .faiss_IndexIVF_quantizer (ivfIdx )
250+ }
251+ return nil
252+ }
253+
254+ func (idx * BinaryIndexImpl ) SetIsTrained (isTrained bool ) {
255+ if isTrained {
256+ C .faiss_IndexBinaryIVF_set_is_trained ((* C .FaissIndexBinaryIVF )(idx .cPtrBinary ()),
257+ C .int (1 ))
258+ } else {
259+ C .faiss_IndexBinaryIVF_set_is_trained ((* C .FaissIndexBinaryIVF )(idx .cPtrBinary ()),
260+ C .int (0 ))
261+ }
262+ }
263+
264+ func (idx * BinaryIndexImpl ) BinaryQuantizer () BinaryIndex {
265+ if bivfIdx := C .faiss_IndexBinaryIVF_cast (idx .cPtrBinary ()); bivfIdx != nil {
266+ return & BinaryIndexImpl {
267+ indexPtr : C .faiss_IndexBinaryIVF_quantizer (bivfIdx ),
268+ d : idx .d ,
269+ metric : idx .metric ,
270+ }
271+ }
272+ return nil
273+ }
274+
223275func (idx * BinaryIndexImpl ) Size () uint64 {
224276 return 0
225277}
@@ -244,6 +296,13 @@ func (idx *BinaryIndexImpl) IsIVFIndex() bool {
244296 return C .faiss_IndexBinaryIVF_cast (idx .indexPtr ) != nil
245297}
246298
299+ func (idx * BinaryIndexImpl ) GetNlist () int {
300+ if ivfIdx := C .faiss_IndexBinaryIVF_cast (idx .indexPtr ); ivfIdx != nil {
301+ return int (C .faiss_IndexBinaryIVF_nlist (ivfIdx ))
302+ }
303+ return 0
304+ }
305+
247306// Binary-specific operations
248307func (idx * BinaryIndexImpl ) TrainBinary (vectors []uint8 ) error {
249308 n := (len (vectors ) * 8 ) / idx .d
0 commit comments