Skip to content

Commit 5774cac

Browse files
utils to re-use centroids from ivf index
1 parent b3fff7e commit 5774cac

File tree

2 files changed

+60
-0
lines changed

2 files changed

+60
-0
lines changed

index.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ type Index interface {
3737
IsIVFIndex() bool
3838
SetNProbe(nprobe int32)
3939
GetNProbe() int32
40+
GetNlist() int
4041
SetDirectMap(directMapType int) error
4142

4243
Close()
@@ -63,6 +64,9 @@ type BinaryIndex interface {
6364
SearchClustersFromIVFIndex(selector Selector, eligibleCentroidIDs []int64,
6465
minEligibleCentroids int, k int64, x []uint8, centroidDis []int32,
6566
params json.RawMessage) ([]int32, []int64, error)
67+
68+
BinaryQuantizer() BinaryIndex
69+
SetIsTrained(isTrained bool)
6670
}
6771

6872
// FloatIndex defines methods specific to float-based FAISS indexes
@@ -89,6 +93,8 @@ type FloatIndex interface {
8993
Reconstruct(key int64) (recons []float32, err error)
9094
ReconstructBatch(ids []int64, vectors []float32) ([]float32, error)
9195

96+
GetCentroids() ([]float32, error)
97+
9298
// Applicable only to IVF indexes: Returns a map where the keys
9399
// are cluster IDs and the values represent the count of input vectors that belong
94100
// to each cluster.
@@ -121,6 +127,8 @@ type FloatIndex interface {
121127
// RemoveIDs removes the vectors specified by sel from the index.
122128
// Returns the number of elements removed and error.
123129
RemoveIDs(sel *IDSelector) (int, error)
130+
131+
Quantizer() *C.FaissIndex
124132
}
125133

126134
// IndexImpl represents a float vector index
@@ -220,6 +228,50 @@ func (idx *BinaryIndexImpl) ObtainClustersWithDistancesFromIVFIndex(x []uint8, c
220228
return centroidIDs, centroidDistances, nil
221229
}
222230

231+
func (idx *IndexImpl) GetNlist() int {
232+
if ivfIdx := C.faiss_IndexIVF_cast(idx.cPtrFloat()); ivfIdx != nil {
233+
return int(C.faiss_IndexIVF_nlist(ivfIdx))
234+
}
235+
return 0
236+
}
237+
238+
func (idx *IndexImpl) GetCentroids() ([]float32, error) {
239+
if ivfIdx := C.faiss_IndexIVF_cast(idx.cPtrFloat()); ivfIdx != nil {
240+
ivfCentroids := make([]float32, idx.D()*idx.GetNlist())
241+
C.faiss_IndexIVF_get_centroids(ivfIdx, (*C.float)(&ivfCentroids[0]))
242+
return ivfCentroids, nil
243+
}
244+
return nil, fmt.Errorf("index is not an IVF index")
245+
}
246+
247+
func (idx *IndexImpl) Quantizer() *C.FaissIndex {
248+
if ivfIdx := C.faiss_IndexIVF_cast(idx.cPtrFloat()); ivfIdx != nil {
249+
return C.faiss_IndexIVF_quantizer(ivfIdx)
250+
}
251+
return nil
252+
}
253+
254+
func (idx *BinaryIndexImpl) SetIsTrained(isTrained bool) {
255+
if isTrained {
256+
C.faiss_IndexBinaryIVF_set_is_trained((*C.FaissIndexBinaryIVF)(idx.cPtrBinary()),
257+
C.int(1))
258+
} else {
259+
C.faiss_IndexBinaryIVF_set_is_trained((*C.FaissIndexBinaryIVF)(idx.cPtrBinary()),
260+
C.int(0))
261+
}
262+
}
263+
264+
func (idx *BinaryIndexImpl) BinaryQuantizer() BinaryIndex {
265+
if bivfIdx := C.faiss_IndexBinaryIVF_cast(idx.cPtrBinary()); bivfIdx != nil {
266+
return &BinaryIndexImpl{
267+
indexPtr: C.faiss_IndexBinaryIVF_quantizer(bivfIdx),
268+
d: idx.d,
269+
metric: idx.metric,
270+
}
271+
}
272+
return nil
273+
}
274+
223275
func (idx *BinaryIndexImpl) Size() uint64 {
224276
return 0
225277
}
@@ -244,6 +296,13 @@ func (idx *BinaryIndexImpl) IsIVFIndex() bool {
244296
return C.faiss_IndexBinaryIVF_cast(idx.indexPtr) != nil
245297
}
246298

299+
func (idx *BinaryIndexImpl) GetNlist() int {
300+
if ivfIdx := C.faiss_IndexBinaryIVF_cast(idx.indexPtr); ivfIdx != nil {
301+
return int(C.faiss_IndexBinaryIVF_nlist(ivfIdx))
302+
}
303+
return 0
304+
}
305+
247306
// Binary-specific operations
248307
func (idx *BinaryIndexImpl) TrainBinary(vectors []uint8) error {
249308
n := (len(vectors) * 8) / idx.d

search_params.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package faiss
44
#include <faiss/c_api/Index_c.h>
55
#include <faiss/c_api/IndexIVF_c.h>
66
#include <faiss/c_api/IndexBinary_c.h>
7+
#include <faiss/c_api/IndexBinaryIVF_c.h>
78
#include <faiss/c_api/impl/AuxIndexStructures_c.h>
89
*/
910
import "C"

0 commit comments

Comments
 (0)