Skip to content

Commit 57c6d65

Browse files
MB-66396: New IndexIVF API: ObtainTopKCentroidCardinalitiesFromIVFIndex
1 parent 371fb38 commit 57c6d65

File tree

1 file changed

+70
-0
lines changed

1 file changed

+70
-0
lines changed

index.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import "C"
1414
import (
1515
"encoding/json"
1616
"fmt"
17+
"sort"
1718
"unsafe"
1819
)
1920

@@ -64,6 +65,9 @@ type Index interface {
6465
ObtainClustersWithDistancesFromIVFIndex(x []float32, centroidIDs []int64) (
6566
[]int64, []float32, error)
6667

68+
// Applicable only to IVF indexes: Returns the top k centroid cardinalities and their vectors
69+
ObtainTopKCentroidCardinalitiesFromIVFIndex(limit int) ([]uint64, [][]float32, error)
70+
6771
// Search queries the index with the vectors in x.
6872
// Returns the IDs of the k nearest neighbors for each query vector and the
6973
// corresponding distances.
@@ -214,6 +218,72 @@ func (idx *faissIndex) ObtainClustersWithDistancesFromIVFIndex(x []float32, cent
214218
return centroids, centroidDistances, nil
215219
}
216220

221+
func (idx *faissIndex) ObtainTopKCentroidCardinalitiesFromIVFIndex(limit int) ([]uint64, [][]float32, error) {
222+
nlist := int(C.faiss_IndexIVF_nlist(idx.idx))
223+
if nlist == 0 {
224+
return nil, nil, nil
225+
}
226+
227+
centroidCardinalities := make([]C.size_t, nlist)
228+
229+
// Allocate a flat buffer for all centroids, then slice it per centroid
230+
d := idx.D()
231+
flatCentroids := make([]float32, nlist*d)
232+
233+
// Call the C function to fill centroid vectors and cardinalities
234+
c := C.faiss_IndexIVF_get_centroids_and_cardinality(
235+
idx.idx,
236+
(*C.float)(&flatCentroids[0]),
237+
(*C.size_t)(&centroidCardinalities[0]),
238+
nil,
239+
)
240+
if c != 0 {
241+
return nil, nil, getLastError()
242+
}
243+
244+
topIndices := getTopIndicesOfTopKCardinalities(centroidCardinalities, limit)
245+
246+
rvCardinalities := make([]uint64, len(topIndices))
247+
rvCentroids := make([][]float32, len(topIndices))
248+
249+
for i, idx := range topIndices {
250+
rvCardinalities[i] = uint64(centroidCardinalities[idx])
251+
rvCentroids[i] = flatCentroids[idx*d : (idx+1)*d]
252+
}
253+
254+
return rvCardinalities, rvCentroids, nil
255+
256+
}
257+
258+
func getTopIndicesOfTopKCardinalities(cardinalities []C.size_t, k int) []int {
259+
if k <= 0 || k > len(cardinalities) {
260+
return nil
261+
}
262+
263+
// Store value and original index
264+
type pair struct {
265+
val C.size_t
266+
idx int
267+
}
268+
269+
pairs := make([]pair, len(cardinalities))
270+
for i, v := range cardinalities {
271+
pairs[i] = pair{v, i}
272+
}
273+
274+
// Sort pairs by value descending
275+
sort.Slice(pairs, func(i, j int) bool {
276+
return pairs[i].val > pairs[j].val
277+
})
278+
279+
// Collect top k indexes
280+
result := make([]int, k)
281+
for i := 0; i < k; i++ {
282+
result[i] = pairs[i].idx
283+
}
284+
return result
285+
}
286+
217287
func (idx *faissIndex) SearchClustersFromIVFIndex(selector Selector,
218288
eligibleCentroidIDs []int64, minEligibleCentroids int, k int64, x,
219289
centroidDis []float32, params json.RawMessage) ([]float32, []int64, error) {

0 commit comments

Comments
 (0)