@@ -14,6 +14,7 @@ import "C"
1414import  (
1515	"encoding/json" 
1616	"fmt" 
17+ 	"sort" 
1718	"unsafe" 
1819)
1920
@@ -64,6 +65,9 @@ type Index interface {
6465	ObtainClustersWithDistancesFromIVFIndex (x  []float32 , centroidIDs  []int64 ) (
6566		[]int64 , []float32 , error )
6667
68+ 	// Applicable only to IVF indexes: Returns the top k centroid cardinalities and their vectors 
69+ 	ObtainTopKCentroidCardinalitiesFromIVFIndex (limit  int ) ([]int64 , [][]float32 , error )
70+ 
6771	// Search queries the index with the vectors in x. 
6872	// Returns the IDs of the k nearest neighbors for each query vector and the 
6973	// corresponding distances. 
@@ -214,6 +218,72 @@ func (idx *faissIndex) ObtainClustersWithDistancesFromIVFIndex(x []float32, cent
214218	return  centroids , centroidDistances , nil 
215219}
216220
221+ func  (idx  * faissIndex ) ObtainTopKCentroidCardinalitiesFromIVFIndex (limit  int ) ([]uint64 , [][]float32 , error ) {
222+ 	nlist  :=  int (C .faiss_IndexIVF_nlist (idx .idx ))
223+ 	if  nlist  ==  0  {
224+ 		return  nil , nil , nil 
225+ 	}
226+ 
227+ 	centroidCardinalities  :=  make ([]C.size_t , nlist )
228+ 
229+ 	// Allocate a flat buffer for all centroids, then slice it per centroid 
230+ 	d  :=  idx .D ()
231+ 	flatCentroids  :=  make ([]float32 , nlist * d )
232+ 
233+ 	// Call the C function to fill centroid vectors and cardinalities 
234+ 	c  :=  C .faiss_IndexIVF_get_centroids_and_cardinality (
235+ 		idx .idx ,
236+ 		(* C .float )(& flatCentroids [0 ]),
237+ 		(* C .size_t )(& centroidCardinalities [0 ]),
238+ 		nil ,
239+ 	)
240+ 	if  c  !=  0  {
241+ 		return  nil , nil , getLastError ()
242+ 	}
243+ 
244+ 	topIndices  :=  getTopIndicesOfTopKCardinalities (centroidCardinalities , limit )
245+ 
246+ 	rvCardinalities  :=  make ([]uint64 , len (topIndices ))
247+ 	rvCentroids  :=  make ([][]float32 , len (topIndices ))
248+ 
249+ 	for  i , idx  :=  range  topIndices  {
250+ 		rvCardinalities [i ] =  uint64 (centroidCardinalities [idx ])
251+ 		rvCentroids [i ] =  flatCentroids [idx * d  : (idx + 1 )* d ]
252+ 	}
253+ 
254+ 	return  rvCardinalities , rvCentroids , nil 
255+ 
256+ }
257+ 
258+ func  getTopIndicesOfTopKCardinalities (cardinalities  []C.size_t , k  int ) []int  {
259+ 	if  k  <=  0  ||  k  >  len (cardinalities ) {
260+ 		return  nil 
261+ 	}
262+ 
263+ 	// Store value and original index 
264+ 	type  pair  struct  {
265+ 		val  C.size_t 
266+ 		idx  int 
267+ 	}
268+ 
269+ 	pairs  :=  make ([]pair , len (cardinalities ))
270+ 	for  i , v  :=  range  cardinalities  {
271+ 		pairs [i ] =  pair {v , i }
272+ 	}
273+ 
274+ 	// Sort pairs by value descending 
275+ 	sort .Slice (pairs , func (i , j  int ) bool  {
276+ 		return  pairs [i ].val  >  pairs [j ].val 
277+ 	})
278+ 
279+ 	// Collect top k indexes 
280+ 	result  :=  make ([]int , k )
281+ 	for  i  :=  0 ; i  <  k ; i ++  {
282+ 		result [i ] =  pairs [i ].idx 
283+ 	}
284+ 	return  result 
285+ }
286+ 
217287func  (idx  * faissIndex ) SearchClustersFromIVFIndex (selector  Selector ,
218288	eligibleCentroidIDs  []int64 , minEligibleCentroids  int , k  int64 , x ,
219289	centroidDis  []float32 , params  json.RawMessage ) ([]float32 , []int64 , error ) {
0 commit comments