@@ -14,6 +14,7 @@ import "C"
14
14
import (
15
15
"encoding/json"
16
16
"fmt"
17
+ "sort"
17
18
"unsafe"
18
19
)
19
20
@@ -64,6 +65,9 @@ type Index interface {
64
65
ObtainClustersWithDistancesFromIVFIndex (x []float32 , centroidIDs []int64 ) (
65
66
[]int64 , []float32 , error )
66
67
68
+ // Applicable only to IVF indexes: Returns the top k centroid cardinalities and their vectors
69
+ ObtainTopKCentroidCardinalitiesFromIVFIndex (limit int ) ([]uint64 , [][]float32 , error )
70
+
67
71
// Search queries the index with the vectors in x.
68
72
// Returns the IDs of the k nearest neighbors for each query vector and the
69
73
// corresponding distances.
@@ -214,6 +218,72 @@ func (idx *faissIndex) ObtainClustersWithDistancesFromIVFIndex(x []float32, cent
214
218
return centroids , centroidDistances , nil
215
219
}
216
220
221
+ func (idx * faissIndex ) ObtainTopKCentroidCardinalitiesFromIVFIndex (limit int ) ([]uint64 , [][]float32 , error ) {
222
+ nlist := int (C .faiss_IndexIVF_nlist (idx .idx ))
223
+ if nlist == 0 {
224
+ return nil , nil , nil
225
+ }
226
+
227
+ centroidCardinalities := make ([]C.size_t , nlist )
228
+
229
+ // Allocate a flat buffer for all centroids, then slice it per centroid
230
+ d := idx .D ()
231
+ flatCentroids := make ([]float32 , nlist * d )
232
+
233
+ // Call the C function to fill centroid vectors and cardinalities
234
+ c := C .faiss_IndexIVF_get_centroids_and_cardinality (
235
+ idx .idx ,
236
+ (* C .float )(& flatCentroids [0 ]),
237
+ (* C .size_t )(& centroidCardinalities [0 ]),
238
+ nil ,
239
+ )
240
+ if c != 0 {
241
+ return nil , nil , getLastError ()
242
+ }
243
+
244
+ topIndices := getTopIndicesOfTopKCardinalities (centroidCardinalities , limit )
245
+
246
+ rvCardinalities := make ([]uint64 , len (topIndices ))
247
+ rvCentroids := make ([][]float32 , len (topIndices ))
248
+
249
+ for i , idx := range topIndices {
250
+ rvCardinalities [i ] = uint64 (centroidCardinalities [idx ])
251
+ rvCentroids [i ] = flatCentroids [idx * d : (idx + 1 )* d ]
252
+ }
253
+
254
+ return rvCardinalities , rvCentroids , nil
255
+
256
+ }
257
+
258
+ func getTopIndicesOfTopKCardinalities (cardinalities []C.size_t , k int ) []int {
259
+ if k <= 0 || k > len (cardinalities ) {
260
+ return nil
261
+ }
262
+
263
+ // Store value and original index
264
+ type pair struct {
265
+ val C.size_t
266
+ idx int
267
+ }
268
+
269
+ pairs := make ([]pair , len (cardinalities ))
270
+ for i , v := range cardinalities {
271
+ pairs [i ] = pair {v , i }
272
+ }
273
+
274
+ // Sort pairs by value descending
275
+ sort .Slice (pairs , func (i , j int ) bool {
276
+ return pairs [i ].val > pairs [j ].val
277
+ })
278
+
279
+ // Collect top k indexes
280
+ result := make ([]int , k )
281
+ for i := 0 ; i < k ; i ++ {
282
+ result [i ] = pairs [i ].idx
283
+ }
284
+ return result
285
+ }
286
+
217
287
func (idx * faissIndex ) SearchClustersFromIVFIndex (selector Selector ,
218
288
eligibleCentroidIDs []int64 , minEligibleCentroids int , k int64 , x ,
219
289
centroidDis []float32 , params json.RawMessage ) ([]float32 , []int64 , error ) {
0 commit comments