diff --git a/index.go b/index.go
index 18177fc..91b6681 100644
--- a/index.go
+++ b/index.go
@@ -14,6 +14,7 @@ import "C"
 import (
 	"encoding/json"
 	"fmt"
+	"sort"
 	"unsafe"
 )
 
@@ -64,6 +65,11 @@ type Index interface {
 	ObtainClustersWithDistancesFromIVFIndex(x []float32, centroidIDs []int64) (
 		[]int64, []float32, error)
 
+	// ObtainKCentroidCardinalitiesFromIVFIndex is applicable only to IVF
+	// indexes. It returns the cardinalities and vectors of up to limit
+	// centroids, ordered by cardinality (descending or ascending).
+	ObtainKCentroidCardinalitiesFromIVFIndex(limit int, descending bool) ([]uint64, [][]float32, error)
+
 	// Search queries the index with the vectors in x.
 	// Returns the IDs of the k nearest neighbors for each query vector and the
 	// corresponding distances.
@@ -214,6 +220,98 @@ func (idx *faissIndex) ObtainClustersWithDistancesFromIVFIndex(x []float32, cent
 	return centroids, centroidDistances, nil
 }
 
+// ObtainKCentroidCardinalitiesFromIVFIndex returns the cardinalities and
+// vectors of up to limit centroids of an IVF index, ordered by cardinality
+// (largest first when descending is true, smallest first otherwise).
+//
+// A limit larger than the number of centroids is clamped to that number. A
+// non-positive limit, an index without centroids, or a non-positive
+// dimension yields nil results and a nil error.
+//
+// Each returned vector is a window into one shared buffer, with its capacity
+// capped at the vector's end so that appending to one vector cannot
+// overwrite the next.
+func (idx *faissIndex) ObtainKCentroidCardinalitiesFromIVFIndex(limit int, descending bool) (
+	[]uint64, [][]float32, error) {
+	if limit <= 0 {
+		return nil, nil, nil
+	}
+
+	nlist := int(C.faiss_IndexIVF_nlist(idx.idx))
+	if nlist <= 0 {
+		return nil, nil, nil
+	}
+
+	// Both buffers below are handed to C through the address of their first
+	// element, so neither may be empty.
+	d := idx.D()
+	if d <= 0 {
+		return nil, nil, nil
+	}
+
+	centroidCardinalities := make([]C.size_t, nlist)
+
+	// Allocate a flat buffer for all centroids, then slice it per centroid
+	flatCentroids := make([]float32, nlist*d)
+
+	// Call the C function to fill centroid vectors and cardinalities
+	c := C.faiss_IndexIVF_get_centroids_and_cardinality(
+		idx.idx,
+		(*C.float)(&flatCentroids[0]),
+		(*C.size_t)(&centroidCardinalities[0]),
+		nil,
+	)
+	if c != 0 {
+		return nil, nil, getLastError()
+	}
+
+	topIndices := getIndicesOfKCentroidCardinalities(centroidCardinalities, limit, descending)
+
+	rvCardinalities := make([]uint64, len(topIndices))
+	rvCentroids := make([][]float32, len(topIndices))
+
+	// The loop variable is not named idx so that it does not shadow the receiver.
+	for i, centroid := range topIndices {
+		start, end := centroid*d, (centroid+1)*d
+		rvCardinalities[i] = uint64(centroidCardinalities[centroid])
+		rvCentroids[i] = flatCentroids[start:end:end]
+	}
+
+	return rvCardinalities, rvCentroids, nil
+}
+
+// getIndicesOfKCentroidCardinalities returns the positions within
+// cardinalities of the k largest values (descending) or the k smallest
+// values (ascending), in that sorted order. Equal values keep their original
+// relative order.
+//
+// k is clamped to len(cardinalities); a non-positive k or an empty input
+// yields nil.
+func getIndicesOfKCentroidCardinalities(cardinalities []C.size_t, k int, descending bool) []int {
+	// Asking for more entries than exist returns all of them rather than none.
+	if k > len(cardinalities) {
+		k = len(cardinalities)
+	}
+	if k <= 0 {
+		return nil
+	}
+
+	// Sort positions instead of values so the original indices are kept.
+	order := make([]int, len(cardinalities))
+	for i := range order {
+		order[i] = i
+	}
+	sort.SliceStable(order, func(a, b int) bool {
+		if descending {
+			return cardinalities[order[a]] > cardinalities[order[b]]
+		}
+		return cardinalities[order[a]] < cardinalities[order[b]]
+	})
+
+	// Cap the capacity so callers cannot reach the discarded tail.
+	return order[:k:k]
+}
+
 func (idx *faissIndex) SearchClustersFromIVFIndex(selector Selector, eligibleCentroidIDs []int64,
 	minEligibleCentroids int, k int64, x, centroidDis []float32, params json.RawMessage) (
 	[]float32, []int64, error) {