Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions index.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import "C"
import (
"encoding/json"
"fmt"
"sort"
"unsafe"
)

Expand Down Expand Up @@ -64,6 +65,10 @@ type Index interface {
ObtainClustersWithDistancesFromIVFIndex(x []float32, centroidIDs []int64) (
[]int64, []float32, error)

// Applicable only to IVF indexes: Returns the cardinalities and the vectors of
// the `limit` centroids with the highest (descending) or lowest (ascending)
// cardinality, ordered by cardinality in that direction.
ObtainKCentroidCardinalitiesFromIVFIndex(limit int, descending bool) ([]uint64, [][]float32, error)

// Search queries the index with the vectors in x.
// Returns the IDs of the k nearest neighbors for each query vector and the
// corresponding distances.
Expand Down Expand Up @@ -214,6 +219,79 @@ func (idx *faissIndex) ObtainClustersWithDistancesFromIVFIndex(x []float32, cent
return centroids, centroidDistances, nil
}

// ObtainKCentroidCardinalitiesFromIVFIndex returns the cardinalities and the
// vectors of the limit centroids with the largest (descending == true) or
// smallest (descending == false) cardinality, in that order. Applicable only
// to IVF indexes.
//
// The two returned slices are parallel: element i of each describes the same
// centroid. Both are nil when the index reports no centroids, when its
// dimension is not positive, or when limit <= 0. A limit larger than the
// number of centroids is clamped to that number.
//
// A non-nil error is returned when the underlying faiss call reports failure.
func (idx *faissIndex) ObtainKCentroidCardinalitiesFromIVFIndex(limit int, descending bool) (
	[]uint64, [][]float32, error) {
	nlist := int(C.faiss_IndexIVF_nlist(idx.idx))
	d := idx.D()
	// Nothing to report. This also guarantees both buffers below are
	// non-empty, so taking the address of their first element cannot panic.
	if nlist == 0 || d <= 0 || limit <= 0 {
		return nil, nil, nil
	}
	if limit > nlist {
		limit = nlist
	}

	centroidCardinalities := make([]C.size_t, nlist)

	// One flat buffer holds every centroid back to back; it is sliced per
	// centroid once the C side has filled it.
	flatCentroids := make([]float32, nlist*d)

	// Call the C function to fill centroid vectors and cardinalities.
	c := C.faiss_IndexIVF_get_centroids_and_cardinality(
		idx.idx,
		(*C.float)(&flatCentroids[0]),
		(*C.size_t)(&centroidCardinalities[0]),
		nil,
	)
	if c != 0 {
		return nil, nil, getLastError()
	}

	topIndices := getIndicesOfKCentroidCardinalities(centroidCardinalities, limit, descending)

	rvCardinalities := make([]uint64, len(topIndices))
	rvCentroids := make([][]float32, len(topIndices))

	for i, centroid := range topIndices {
		rvCardinalities[i] = uint64(centroidCardinalities[centroid])
		start, end := centroid*d, (centroid+1)*d
		// Cap each view at its own end so that an append by the caller
		// reallocates instead of overwriting the neighbouring centroid.
		rvCentroids[i] = flatCentroids[start:end:end]
	}

	return rvCardinalities, rvCentroids, nil
}

// getIndicesOfKCentroidCardinalities returns the positions within
// cardinalities of the k largest (descending == true) or k smallest
// (descending == false) values, ordered in that direction. Entries with equal
// values keep their original relative order, so the result is fully
// determined by the input.
//
// It returns nil when k <= 0 or cardinalities is empty. A k greater than
// len(cardinalities) is clamped, in which case every index is returned.
// The input slice is not modified.
func getIndicesOfKCentroidCardinalities(cardinalities []C.size_t, k int, descending bool) []int {
	if k <= 0 || len(cardinalities) == 0 {
		return nil
	}
	if k > len(cardinalities) {
		k = len(cardinalities)
	}

	// Sort a permutation of the indices rather than the values themselves so
	// the original positions are what we end up with.
	order := make([]int, len(cardinalities))
	for i := range order {
		order[i] = i
	}

	// Stable sort: ties (e.g. several clusters with the same size) stay in
	// ascending index order instead of an unspecified one.
	sort.SliceStable(order, func(a, b int) bool {
		x, y := cardinalities[order[a]], cardinalities[order[b]]
		if descending {
			return x > y
		}
		return x < y
	})

	// Copy the leading k entries so the result does not pin the full
	// permutation buffer in memory.
	result := make([]int, k)
	copy(result, order)
	return result
}

func (idx *faissIndex) SearchClustersFromIVFIndex(selector Selector,
eligibleCentroidIDs []int64, minEligibleCentroids int, k int64, x,
centroidDis []float32, params json.RawMessage) ([]float32, []int64, error) {
Expand Down