diff --git a/autotune.go b/autotune.go index 4b818d3..15e71ef 100644 --- a/autotune.go +++ b/autotune.go @@ -6,6 +6,7 @@ package faiss */ import "C" import ( + "fmt" "unsafe" ) @@ -30,11 +31,20 @@ func (p *ParameterSpace) SetIndexParameter(idx Index, name string, val float64) C.free(unsafe.Pointer(cname)) }() - c := C.faiss_ParameterSpace_set_index_parameter( - p.ps, idx.cPtr(), cname, C.double(val)) - if c != 0 { - return getLastError() + switch idx.(type) { + case FloatIndex: + idx := idx.(*IndexImpl) + c := C.faiss_ParameterSpace_set_index_parameter( + p.ps, idx.cPtrFloat(), cname, C.double(val)) + if c != 0 { + return getLastError() + } + case BinaryIndex: + return fmt.Errorf("binary indexes not supported for auto-tuning") + default: + return fmt.Errorf("unsupported index type") } + return nil } diff --git a/go.mod b/go.mod index afe720a..f2ef1aa 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/blevesearch/go-faiss -go 1.21 +go 1.22 + +toolchain go1.23.0 diff --git a/index.go b/index.go index 18177fc..9191f3c 100644 --- a/index.go +++ b/index.go @@ -2,13 +2,16 @@ package faiss /* #include +#include #include #include #include -#include -#include +#include +#include #include #include +#include +#include */ import "C" import ( @@ -17,35 +20,80 @@ import ( "unsafe" ) -// Index is a Faiss index. -// -// Note that some index implementations do not support all methods. -// Check the Faiss wiki to see what operations an index supports. +// Index is the common interface for both binary and float vector indexes type Index interface { - // D returns the dimension of the indexed vectors. + // Core index operations D() int - // IsTrained returns true if the index has been trained or does not require - // training. - IsTrained() bool - // Ntotal returns the number of indexed vectors. Ntotal() int64 // MetricType returns the metric type of the index. MetricType() int - // Train trains the index on a representative set of vectors. 
- Train(x []float32) error + Size() uint64 + + // IVF-specific operations, common to both float and binary IVF indexes + IsIVFIndex() bool + SetNProbe(nprobe int32) + GetNProbe() int32 + GetNlist() int + SetDirectMap(directMapType int) error + + Close() +} + +// BinaryIndex defines methods specific to binary FAISS indexes +type BinaryIndex interface { + Index - // Add adds vectors to the index. - Add(x []float32) error + cPtrBinary() *C.FaissIndexBinary + // Binary-specific operations + TrainBinary(vectors []uint8) error + AddBinary(vectors []uint8) error + AddBinaryWithIDs(vectors []uint8, ids []int64) error + SearchBinary(x []uint8, k int64) ([]int32, []int64, error) + SearchBinaryWithIDs(x []uint8, k int64, include []int64, params json.RawMessage) ([]int32, []int64, error) + SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64, params json.RawMessage) (distances []int32, + labels []int64, err error) + + ObtainClusterVectorCountsFromIVFIndex(vecIDs []int64) (map[int64]int64, error) + ObtainClustersWithDistancesFromIVFIndex(x []uint8, centroidIDs []int64) ( + []int64, []int32, error) + // Applicable only to IVF indexes: Search clusters whose IDs are in eligibleCentroidIDs + SearchClustersFromIVFIndex(selector Selector, eligibleCentroidIDs []int64, + minEligibleCentroids int, k int64, x []uint8, centroidDis []int32, + params json.RawMessage) ([]int32, []int64, error) + + BinaryQuantizer() BinaryIndex + SetIsTrained(isTrained bool) +} +// FloatIndex defines methods specific to float-based FAISS indexes +type FloatIndex interface { + Index + + cPtrFloat() *C.FaissIndex + // Float-specific operations + // Train trains the index on a representative set of vectors. + Train(vectors []float32) error + Add(vectors []float32) error // AddWithIDs is like Add, but stores xids instead of sequential IDs. - AddWithIDs(x []float32, xids []int64) error + AddWithIDs(vectors []float32, xids []int64) error + // Search queries the index with the vectors in x. 
+ // Returns the IDs of the k nearest neighbors for each query vector and the + // corresponding distances. + Search(x []float32, k int64) (distances []float32, labels []int64, err error) + // RangeSearch queries the index with the vectors in x. + // Returns all vectors with distance < radius. + RangeSearch(x []float32, radius float32) (*RangeSearchResult, error) + SearchWithIDs(x []float32, k int64, include []int64, params json.RawMessage) ([]float32, []int64, error) + // SearchWithoutIDs is like Search, but excludes the vectors with IDs in exclude. + SearchWithoutIDs(x []float32, k int64, exclude []int64, params json.RawMessage) ([]float32, []int64, error) + Reconstruct(key int64) (recons []float32, err error) + ReconstructBatch(ids []int64, vectors []float32) ([]float32, error) - // Returns true if the index is an IVF index. - IsIVFIndex() bool + GetCentroids() ([]float32, error) // Applicable only to IVF indexes: Returns a map where the keys // are cluster IDs and the values represent the count of input vectors that belong @@ -64,31 +112,14 @@ type Index interface { ObtainClustersWithDistancesFromIVFIndex(x []float32, centroidIDs []int64) ( []int64, []float32, error) - // Search queries the index with the vectors in x. - // Returns the IDs of the k nearest neighbors for each query vector and the - // corresponding distances. 
- Search(x []float32, k int64) (distances []float32, labels []int64, err error) - - SearchWithoutIDs(x []float32, k int64, exclude []int64, params json.RawMessage) (distances []float32, - labels []int64, err error) - - SearchWithIDs(x []float32, k int64, include []int64, params json.RawMessage) (distances []float32, - labels []int64, err error) + DistCompute(queryData []float32, ids []int64, k int, distances []float32) error // Applicable only to IVF indexes: Search clusters whose IDs are in eligibleCentroidIDs SearchClustersFromIVFIndex(selector Selector, eligibleCentroidIDs []int64, minEligibleCentroids int, k int64, x, centroidDis []float32, params json.RawMessage) ([]float32, []int64, error) - Reconstruct(key int64) ([]float32, error) - - ReconstructBatch(keys []int64, recons []float32) ([]float32, error) - - MergeFrom(other Index, add_id int64) error - - // RangeSearch queries the index with the vectors in x. - // Returns all vectors with distance < radius. - RangeSearch(x []float32, radius float32) (*RangeSearchResult, error) + MergeFrom(other IndexImpl, add_id int64) error // Reset removes all vectors from the index. Reset() error @@ -97,67 +128,393 @@ type Index interface { // Returns the number of elements removed and error. RemoveIDs(sel *IDSelector) (int, error) - // Close frees the memory used by the index. 
- Close() + Quantizer() *C.FaissIndex +} - // consults the C++ side to get the size of the index - Size() uint64 +// IndexImpl represents a float vector index +type IndexImpl struct { + indexPtr *C.FaissIndex + d int + metric int +} - cPtr() *C.FaissIndex +// BinaryIndexImpl represents a binary vector index +type BinaryIndexImpl struct { + indexPtr *C.FaissIndexBinary + d int + metric int } -type faissIndex struct { - idx *C.FaissIndex +// NewBinaryIndexImpl creates a new binary index implementation +func NewBinaryIndexImpl(d int, description string, metric int) (*BinaryIndexImpl, error) { + idx := &BinaryIndexImpl{ + d: d, + metric: metric, + } + var cDescription *C.char + if description != "" { + cDescription = C.CString(description) + defer C.free(unsafe.Pointer(cDescription)) + } + + var cIdx *C.FaissIndexBinary + if c := C.faiss_index_binary_factory(&cIdx, C.int(idx.d), cDescription); c != 0 { + return nil, getLastError() + } + idx.indexPtr = cIdx + return idx, nil } -func (idx *faissIndex) cPtr() *C.FaissIndex { - return idx.idx +// Core index operations +func (idx *BinaryIndexImpl) Close() { + if idx.indexPtr != nil { + C.faiss_IndexBinary_free(idx.indexPtr) + idx.indexPtr = nil + } } -func (idx *faissIndex) Size() uint64 { - size := C.faiss_Index_size(idx.idx) - return uint64(size) +func (idx *BinaryIndexImpl) ObtainClusterVectorCountsFromIVFIndex(vecIDs []int64) (map[int64]int64, error) { + if !idx.IsIVFIndex() { + return nil, fmt.Errorf("index is not an IVF index") + } + clusterIDs := make([]int64, len(vecIDs)) + if c := C.faiss_get_lists_for_keys_binary( + idx.indexPtr, + (*C.idx_t)(unsafe.Pointer(&vecIDs[0])), + (C.size_t)(len(vecIDs)), + (*C.idx_t)(unsafe.Pointer(&clusterIDs[0])), + ); c != 0 { + return nil, getLastError() + } + rv := make(map[int64]int64, len(vecIDs)) + for _, v := range clusterIDs { + rv[v]++ + } + return rv, nil } -func (idx *faissIndex) D() int { - return int(C.faiss_Index_d(idx.idx)) +func (idx *BinaryIndexImpl) 
ObtainClustersWithDistancesFromIVFIndex(x []uint8, centroidIDs []int64) ( + []int64, []int32, error) { + // Selector to include only the centroids whose IDs are part of 'centroidIDs'. + includeSelector, err := NewIDSelectorBatch(centroidIDs) + if err != nil { + return nil, nil, err + } + defer includeSelector.Delete() + + params, err := NewSearchParams(idx, json.RawMessage{}, includeSelector.Get(), nil) + if err != nil { + return nil, nil, err + } + defer params.Delete() + + // Populate these with the centroids and their distances. + centroidDistances := make([]int32, len(centroidIDs)) + + n := len(x) / idx.D() + + c := C.faiss_Search_closest_eligible_centroids_binary( + idx.indexPtr, + (C.idx_t)(n), + (*C.uint8_t)(&x[0]), + (C.idx_t)(len(centroidIDs)), + (*C.int32_t)(&centroidDistances[0]), + (*C.idx_t)(&centroidIDs[0]), + params.sp) + if c != 0 { + return nil, nil, getLastError() + } + + return centroidIDs, centroidDistances, nil } -func (idx *faissIndex) IsTrained() bool { - return C.faiss_Index_is_trained(idx.idx) != 0 +func (idx *IndexImpl) GetNlist() int { + if ivfIdx := C.faiss_IndexIVF_cast(idx.cPtrFloat()); ivfIdx != nil { + return int(C.faiss_IndexIVF_nlist(ivfIdx)) + } + return 0 } -func (idx *faissIndex) Ntotal() int64 { - return int64(C.faiss_Index_ntotal(idx.idx)) +func (idx *IndexImpl) GetCentroids() ([]float32, error) { + if ivfIdx := C.faiss_IndexIVF_cast(idx.cPtrFloat()); ivfIdx != nil { + ivfCentroids := make([]float32, idx.D()*idx.GetNlist()) + C.faiss_IndexIVF_get_centroids(ivfIdx, (*C.float)(&ivfCentroids[0])) + return ivfCentroids, nil + } + return nil, fmt.Errorf("index is not an IVF index") } -func (idx *faissIndex) MetricType() int { - return int(C.faiss_Index_metric_type(idx.idx)) +func (idx *IndexImpl) Quantizer() *C.FaissIndex { + if ivfIdx := C.faiss_IndexIVF_cast(idx.cPtrFloat()); ivfIdx != nil { + return C.faiss_IndexIVF_quantizer(ivfIdx) + } + return nil } -func (idx *faissIndex) Train(x []float32) error { - n := len(x) / idx.D() - if c := 
C.faiss_Index_train(idx.idx, C.idx_t(n), (*C.float)(&x[0])); c != 0 { +func (idx *BinaryIndexImpl) SetIsTrained(isTrained bool) { + if isTrained { + C.faiss_IndexBinaryIVF_set_is_trained((*C.FaissIndexBinaryIVF)(idx.cPtrBinary()), + C.int(1)) + } else { + C.faiss_IndexBinaryIVF_set_is_trained((*C.FaissIndexBinaryIVF)(idx.cPtrBinary()), + C.int(0)) + } +} + +func (idx *BinaryIndexImpl) BinaryQuantizer() BinaryIndex { + if bivfIdx := C.faiss_IndexBinaryIVF_cast(idx.cPtrBinary()); bivfIdx != nil { + return &BinaryIndexImpl{ + indexPtr: C.faiss_IndexBinaryIVF_quantizer(bivfIdx), + d: idx.d, + metric: idx.metric, + } + } + return nil +} + +func (idx *BinaryIndexImpl) Size() uint64 { + return 0 +} + +func (idx *BinaryIndexImpl) cPtrBinary() *C.FaissIndexBinary { + return idx.indexPtr +} + +func (idx *BinaryIndexImpl) D() int { + return idx.d +} + +func (idx *BinaryIndexImpl) MetricType() int { + return idx.metric +} + +func (idx *BinaryIndexImpl) Ntotal() int64 { + return int64(C.faiss_IndexBinary_ntotal(idx.indexPtr)) +} + +func (idx *BinaryIndexImpl) IsIVFIndex() bool { + return C.faiss_IndexBinaryIVF_cast(idx.indexPtr) != nil +} + +func (idx *BinaryIndexImpl) GetNlist() int { + if ivfIdx := C.faiss_IndexBinaryIVF_cast(idx.indexPtr); ivfIdx != nil { + return int(C.faiss_IndexBinaryIVF_nlist(ivfIdx)) + } + return 0 +} + +// Binary-specific operations +func (idx *BinaryIndexImpl) TrainBinary(vectors []uint8) error { + n := (len(vectors) * 8) / idx.d + if c := C.faiss_IndexBinary_train(idx.indexPtr, C.idx_t(n), (*C.uint8_t)(&vectors[0])); c != 0 { + return getLastError() + } + return nil +} + +func (idx *BinaryIndexImpl) AddBinary(vectors []uint8) error { + n := (len(vectors) * 8) / idx.d + if c := C.faiss_IndexBinary_add(idx.indexPtr, C.idx_t(n), (*C.uint8_t)(&vectors[0])); c != 0 { + return getLastError() + } + return nil +} + +func (idx *BinaryIndexImpl) AddBinaryWithIDs(vectors []uint8, ids []int64) error { + n := (len(vectors) * 8) / idx.d + if c := 
C.faiss_IndexBinary_add_with_ids(idx.indexPtr, C.idx_t(n), (*C.uint8_t)(&vectors[0]), (*C.idx_t)(&ids[0])); c != 0 { + return getLastError() + } + return nil +} + +func (idx *BinaryIndexImpl) SearchBinary(x []uint8, k int64) ([]int32, []int64, error) { + nq := (len(x) * 8) / idx.d + distances := make([]int32, int64(nq)*k) + labels := make([]int64, int64(nq)*k) + + if c := C.faiss_IndexBinary_search( + idx.indexPtr, + C.idx_t(nq), + (*C.uint8_t)(&x[0]), + C.idx_t(k), + (*C.int32_t)(&distances[0]), + (*C.idx_t)(&labels[0]), + ); c != 0 { + return nil, nil, getLastError() + } + return distances, labels, nil +} + +func (idx *BinaryIndexImpl) SearchBinaryWithIDs(x []uint8, k int64, include []int64, params json.RawMessage) ([]int32, []int64, error) { + nq := (len(x) * 8) / idx.d + distances := make([]int32, int64(nq)*k) + labels := make([]int64, int64(nq)*k) + + includeSelector, err := NewIDSelectorBatch(include) + if err != nil { + return nil, nil, err + } + defer includeSelector.Delete() + + searchParams, err := NewSearchParams(idx, params, includeSelector.Get(), nil) + if err != nil { + return nil, nil, err + } + defer searchParams.Delete() + + if c := C.faiss_IndexBinary_search_with_params( + idx.indexPtr, + C.idx_t(nq), + (*C.uint8_t)(&x[0]), + C.idx_t(k), + searchParams.sp, + (*C.int32_t)(&distances[0]), + (*C.idx_t)(&labels[0]), + ); c != 0 { + return nil, nil, getLastError() + } + return distances, labels, nil +} + +func (idx *BinaryIndexImpl) Train(vectors []uint8) error { + n := (len(vectors) * 8) / idx.d + if c := C.faiss_IndexBinary_train(idx.indexPtr, C.idx_t(n), (*C.uint8_t)(&vectors[0])); c != 0 { return getLastError() } return nil } -func (idx *faissIndex) Add(x []float32) error { +func (idx *BinaryIndexImpl) SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64, params json.RawMessage) (distances []int32, labels []int64, err error) { + if len(exclude) == 0 && len(params) == 0 { + return idx.SearchBinary(x, k) + } + + nq := (len(x) * 8) / idx.d + 
distances = make([]int32, int64(nq)*k) + labels = make([]int64, int64(nq)*k) + + var selector *C.FaissIDSelector + if len(exclude) > 0 { + excludeSelector, err := NewIDSelectorNot(exclude) + if err != nil { + return nil, nil, err + } + selector = excludeSelector.Get() + defer excludeSelector.Delete() + } + + searchParams, err := NewSearchParams(idx, params, selector, nil) + if err != nil { + return nil, nil, err + } + defer searchParams.Delete() + + if c := C.faiss_IndexBinary_search_with_params( + idx.indexPtr, + C.idx_t(nq), + (*C.uint8_t)(&x[0]), + C.idx_t(k), + searchParams.sp, + (*C.int32_t)(&distances[0]), + (*C.idx_t)(&labels[0]), + ); c != 0 { + err = getLastError() + } + + return distances, labels, err +} + +func (idx *BinaryIndexImpl) SearchClustersFromIVFIndex(selector Selector, + eligibleCentroidIDs []int64, minEligibleCentroids int, k int64, x []uint8, + centroidDis []int32, params json.RawMessage) ([]int32, []int64, error) { + tempParams := &defaultSearchParamsIVF{ + Nlist: len(eligibleCentroidIDs), + // Have to override nprobe so that more clusters will be searched for this + // query, if required. 
+ Nprobe: minEligibleCentroids, + } + + searchParams, err := NewSearchParams(idx, params, selector.Get(), tempParams) + if err != nil { + return nil, nil, err + } + defer searchParams.Delete() + + n := (len(x) * 8) / idx.D() + + distances := make([]int32, int64(n)*k) + labels := make([]int64, int64(n)*k) + + effectiveNprobe := getNProbeFromSearchParams(searchParams) + + eligibleCentroidIDs = eligibleCentroidIDs[:effectiveNprobe] + centroidDis = centroidDis[:effectiveNprobe] + + if c := C.faiss_IndexBinaryIVF_search_preassigned_with_params( + idx.indexPtr, + (C.idx_t)(n), + (*C.uint8_t)(&x[0]), + (C.idx_t)(k), + (*C.idx_t)(&eligibleCentroidIDs[0]), + (*C.int32_t)(&centroidDis[0]), + (*C.int32_t)(&distances[0]), + (*C.idx_t)(&labels[0]), + (C.int)(0), + searchParams.sp); c != 0 { + return nil, nil, getLastError() + } + + return distances, labels, nil +} + +// Factory functions +func IndexBinaryFactory(d int, description string, metric int) (BinaryIndex, error) { + return NewBinaryIndexImpl(d, description, metric) +} + +// Ensure BinaryIndexImpl implements BinaryIndex interface +var _ BinaryIndex = (*BinaryIndexImpl)(nil) + +func (idx *IndexImpl) searchWithParams(x []float32, k int64, params *C.FaissSearchParameters) (distances []float32, labels []int64, err error) { + n := len(x) / idx.D() + distances = make([]float32, int64(n)*k) + labels = make([]int64, int64(n)*k) + + if c := C.faiss_Index_search_with_params( + idx.indexPtr, + C.idx_t(n), + (*C.float)(&x[0]), + C.idx_t(k), + params, + (*C.float)(&distances[0]), + (*C.idx_t)(&labels[0]), + ); c != 0 { + err = getLastError() + } + + return +} + +func (idx *IndexImpl) Size() uint64 { + return uint64(C.faiss_Index_size(idx.cPtrFloat())) +} + +func (idx *IndexImpl) Train(x []float32) error { n := len(x) / idx.D() - if c := C.faiss_Index_add(idx.idx, C.idx_t(n), (*C.float)(&x[0])); c != 0 { + if c := C.faiss_Index_train(idx.indexPtr, C.idx_t(n), (*C.float)(&x[0])); c != 0 { return getLastError() } return nil } -func (idx 
*faissIndex) ObtainClusterVectorCountsFromIVFIndex(vecIDs []int64) (map[int64]int64, error) { +func (idx *IndexImpl) ObtainClusterVectorCountsFromIVFIndex(vecIDs []int64) (map[int64]int64, error) { if !idx.IsIVFIndex() { return nil, fmt.Errorf("index is not an IVF index") } clusterIDs := make([]int64, len(vecIDs)) if c := C.faiss_get_lists_for_keys( - idx.idx, + idx.indexPtr, (*C.idx_t)(unsafe.Pointer(&vecIDs[0])), (C.size_t)(len(vecIDs)), (*C.idx_t)(unsafe.Pointer(&clusterIDs[0])), @@ -171,14 +528,16 @@ func (idx *faissIndex) ObtainClusterVectorCountsFromIVFIndex(vecIDs []int64) (ma return rv, nil } -func (idx *faissIndex) IsIVFIndex() bool { - if ivfIdx := C.faiss_IndexIVF_cast(idx.cPtr()); ivfIdx == nil { - return false +func (idx *IndexImpl) DistCompute(queryData []float32, ids []int64, k int, distances []float32) error { + if c := C.faiss_Index_dist_compute(idx.indexPtr, (*C.float)(&queryData[0]), + (*C.idx_t)(&ids[0]), (C.size_t)(k), (*C.float)(&distances[0])); c != 0 { + return getLastError() } - return true + + return nil } -func (idx *faissIndex) ObtainClustersWithDistancesFromIVFIndex(x []float32, centroidIDs []int64) ( +func (idx *IndexImpl) ObtainClustersWithDistancesFromIVFIndex(x []float32, centroidIDs []int64) ( []int64, []float32, error) { // Selector to include only the centroids whose IDs are part of 'centroidIDs'. 
includeSelector, err := NewIDSelectorBatch(centroidIDs) @@ -200,7 +559,7 @@ func (idx *faissIndex) ObtainClustersWithDistancesFromIVFIndex(x []float32, cent n := len(x) / idx.D() c := C.faiss_Search_closest_eligible_centroids( - idx.idx, + idx.indexPtr, (C.idx_t)(n), (*C.float)(&x[0]), (C.idx_t)(len(centroidIDs)), @@ -214,7 +573,7 @@ func (idx *faissIndex) ObtainClustersWithDistancesFromIVFIndex(x []float32, cent return centroids, centroidDistances, nil } -func (idx *faissIndex) SearchClustersFromIVFIndex(selector Selector, +func (idx *IndexImpl) SearchClustersFromIVFIndex(selector Selector, eligibleCentroidIDs []int64, minEligibleCentroids int, k int64, x, centroidDis []float32, params json.RawMessage) ([]float32, []int64, error) { @@ -241,7 +600,7 @@ func (idx *faissIndex) SearchClustersFromIVFIndex(selector Selector, centroidDis = centroidDis[:effectiveNprobe] if c := C.faiss_IndexIVF_search_preassigned_with_params( - idx.idx, + idx.indexPtr, (C.idx_t)(n), (*C.float)(&x[0]), (C.idx_t)(k), @@ -257,27 +616,51 @@ func (idx *faissIndex) SearchClustersFromIVFIndex(selector Selector, return distances, labels, nil } -func (idx *faissIndex) AddWithIDs(x []float32, xids []int64) error { - n := len(x) / idx.D() - if c := C.faiss_Index_add_with_ids( - idx.idx, - C.idx_t(n), - (*C.float)(&x[0]), - (*C.idx_t)(&xids[0]), +func (idx *IndexImpl) IsIVFIndex() bool { + if ivfIdx := C.faiss_IndexIVF_cast(idx.cPtrFloat()); ivfIdx == nil { + return false + } + return true +} + +// SearchWithIDs performs a search with ID filtering and search parameters +func (idx *IndexImpl) SearchWithIDs(queries []float32, k int64, include []int64, params json.RawMessage) ([]float32, []int64, error) { + nq := len(queries) / idx.d + distances := make([]float32, int64(nq)*k) + labels := make([]int64, int64(nq)*k) + + includeSelector, err := NewIDSelectorBatch(include) + if err != nil { + return nil, nil, err + } + defer includeSelector.Delete() + + searchParams, err := NewSearchParams(idx, params, 
includeSelector.Get(), nil) + if err != nil { + return nil, nil, err + } + defer searchParams.Delete() + + if c := C.faiss_Index_search_with_params( + idx.indexPtr, + C.idx_t(nq), + (*C.float)(&queries[0]), + C.idx_t(k), + searchParams.sp, + (*C.float)(&distances[0]), + (*C.idx_t)(&labels[0]), ); c != 0 { - return getLastError() + return nil, nil, getLastError() } - return nil + return distances, labels, nil } -func (idx *faissIndex) Search(x []float32, k int64) ( - distances []float32, labels []int64, err error, -) { +func (idx *IndexImpl) Search(x []float32, k int64) (distances []float32, labels []int64, err error) { n := len(x) / idx.D() distances = make([]float32, int64(n)*k) labels = make([]int64, int64(n)*k) if c := C.faiss_Index_search( - idx.idx, + idx.indexPtr, C.idx_t(n), (*C.float)(&x[0]), C.idx_t(k), @@ -287,16 +670,24 @@ func (idx *faissIndex) Search(x []float32, k int64) ( err = getLastError() } - return + return distances, labels, err } -func (idx *faissIndex) SearchWithoutIDs(x []float32, k int64, exclude []int64, params json.RawMessage) ( - distances []float32, labels []int64, err error, -) { +func (idx *IndexImpl) Ntotal() int64 { + return int64(C.faiss_Index_ntotal(idx.indexPtr)) +} + +// SearchWithoutIDs performs a search without ID filtering +func (idx *IndexImpl) SearchWithoutIDs(x []float32, k int64, exclude []int64, params json.RawMessage) ( + []float32, []int64, error) { if params == nil && len(exclude) == 0 { return idx.Search(x, k) } + nq := len(x) / idx.d + distances := make([]float32, int64(nq)*k) + labels := make([]int64, int64(nq)*k) + var selector *C.FaissIDSelector if len(exclude) > 0 { excludeSelector, err := NewIDSelectorNot(exclude) @@ -315,81 +706,84 @@ func (idx *faissIndex) SearchWithoutIDs(x []float32, k int64, exclude []int64, p distances, labels, err = idx.searchWithParams(x, k, searchParams.sp) - return + return distances, labels, err } -func (idx *faissIndex) SearchWithIDs(x []float32, k int64, include []int64, - params 
json.RawMessage) (distances []float32, labels []int64, err error, -) { - includeSelector, err := NewIDSelectorBatch(include) - if err != nil { - return nil, nil, err - } - defer includeSelector.Delete() - - searchParams, err := NewSearchParams(idx, params, includeSelector.Get(), nil) - if err != nil { - return nil, nil, err - } - defer searchParams.Delete() +// ----------------------------------------------------------------------------- - distances, labels, err = idx.searchWithParams(x, k, searchParams.sp) - return +// RangeSearchResult is the result of a range search. +type RangeSearchResult struct { + rsr *C.FaissRangeSearchResult } -func (idx *faissIndex) Reconstruct(key int64) (recons []float32, err error) { - rv := make([]float32, idx.D()) - if c := C.faiss_Index_reconstruct( - idx.idx, - C.idx_t(key), - (*C.float)(&rv[0]), - ); c != 0 { - err = getLastError() - } +// Nq returns the number of queries. +func (r *RangeSearchResult) Nq() int { + return int(C.faiss_RangeSearchResult_nq(r.rsr)) +} - return rv, err +// Lims returns a slice containing start and end indices for queries in the +// distances and labels slices returned by Labels. +func (r *RangeSearchResult) Lims() []int { + var lims *C.size_t + C.faiss_RangeSearchResult_lims(r.rsr, &lims) + length := r.Nq() + 1 + return (*[1 << 30]int)(unsafe.Pointer(lims))[:length:length] } -func (idx *faissIndex) ReconstructBatch(keys []int64, recons []float32) ([]float32, error) { - var err error - n := int64(len(keys)) - if c := C.faiss_Index_reconstruct_batch( - idx.idx, - C.idx_t(n), - (*C.idx_t)(&keys[0]), - (*C.float)(&recons[0]), - ); c != 0 { - err = getLastError() - } +// Labels returns the unsorted IDs and respective distances for each query. +// The result for query i is labels[lims[i]:lims[i+1]]. 
+func (r *RangeSearchResult) Labels() (labels []int64, distances []float32) { + lims := r.Lims() + length := lims[len(lims)-1] + var clabels *C.idx_t + var cdist *C.float + C.faiss_RangeSearchResult_labels(r.rsr, &clabels, &cdist) + labels = (*[1 << 30]int64)(unsafe.Pointer(clabels))[:length:length] + distances = (*[1 << 30]float32)(unsafe.Pointer(cdist))[:length:length] + return +} - return recons, err +// Delete frees the memory associated with r. +func (r *RangeSearchResult) Delete() { + C.faiss_RangeSearchResult_free(r.rsr) } -func (i *IndexImpl) MergeFrom(other Index, add_id int64) error { - if impl, ok := other.(*IndexImpl); ok { - return i.Index.MergeFrom(impl.Index, add_id) +// IndexFactory creates a new index using the factory function +func IndexFactory(d int, description string, metric int) (FloatIndex, error) { + var cDescription *C.char + if description != "" { + cDescription = C.CString(description) + defer C.free(unsafe.Pointer(cDescription)) } - return fmt.Errorf("merge not support") -} -func (idx *faissIndex) MergeFrom(other Index, add_id int64) (err error) { - otherIdx, ok := other.(*faissIndex) - if !ok { - return fmt.Errorf("merge api not supported") + var idx *C.FaissIndex + if c := C.faiss_index_factory(&idx, C.int(d), cDescription, C.FaissMetricType(metric)); c != 0 { + return nil, getLastError() } - if c := C.faiss_Index_merge_from( - idx.idx, - otherIdx.idx, - (C.idx_t)(add_id), - ); c != 0 { - err = getLastError() + return &IndexImpl{ + indexPtr: idx, + d: d, + metric: metric, + }, nil +} + +func (idx *IndexImpl) Close() { + if idx.indexPtr != nil { + C.faiss_Index_free(idx.indexPtr) + idx.indexPtr = nil } +} + +func (idx *IndexImpl) D() int { + return idx.d +} - return err +func (idx *IndexImpl) MetricType() int { + return idx.metric } -func (idx *faissIndex) RangeSearch(x []float32, radius float32) ( +func (idx *IndexImpl) RangeSearch(x []float32, radius float32) ( *RangeSearchResult, error, ) { n := len(x) / idx.D() @@ -398,7 +792,7 @@ 
func (idx *faissIndex) RangeSearch(x []float32, radius float32) ( return nil, getLastError() } if c := C.faiss_Index_range_search( - idx.idx, + idx.indexPtr, C.idx_t(n), (*C.float)(&x[0]), C.float(radius), @@ -409,102 +803,75 @@ func (idx *faissIndex) RangeSearch(x []float32, radius float32) ( return &RangeSearchResult{rsr}, nil } -func (idx *faissIndex) Reset() error { - if c := C.faiss_Index_reset(idx.idx); c != 0 { +func (idx *IndexImpl) Reset() error { + if c := C.faiss_Index_reset(idx.indexPtr); c != 0 { return getLastError() } return nil } -func (idx *faissIndex) RemoveIDs(sel *IDSelector) (int, error) { +func (idx *IndexImpl) RemoveIDs(sel *IDSelector) (int, error) { var nRemoved C.size_t - if c := C.faiss_Index_remove_ids(idx.idx, sel.sel, &nRemoved); c != 0 { + if c := C.faiss_Index_remove_ids(idx.indexPtr, sel.sel, &nRemoved); c != 0 { return 0, getLastError() } return int(nRemoved), nil } -func (idx *faissIndex) Close() { - C.faiss_Index_free(idx.idx) -} - -func (idx *faissIndex) searchWithParams(x []float32, k int64, searchParams *C.FaissSearchParameters) ( - distances []float32, labels []int64, err error, -) { - n := len(x) / idx.D() - distances = make([]float32, int64(n)*k) - labels = make([]int64, int64(n)*k) - - if c := C.faiss_Index_search_with_params( - idx.idx, - C.idx_t(n), - (*C.float)(&x[0]), - C.idx_t(k), - searchParams, - (*C.float)(&distances[0]), - (*C.idx_t)(&labels[0]), - ); c != 0 { - err = getLastError() +func (idx *IndexImpl) MergeFrom(other IndexImpl, add_id int64) error { + if c := C.faiss_Index_merge_from(idx.indexPtr, other.cPtrFloat(), C.idx_t(add_id)); c != 0 { + return getLastError() } - - return -} - -// ----------------------------------------------------------------------------- - -// RangeSearchResult is the result of a range search. -type RangeSearchResult struct { - rsr *C.FaissRangeSearchResult + return nil } -// Nq returns the number of queries. 
-func (r *RangeSearchResult) Nq() int { - return int(C.faiss_RangeSearchResult_nq(r.rsr)) +// Float-specific operations +func (idx *IndexImpl) Add(vectors []float32) error { + n := len(vectors) / idx.d + if c := C.faiss_Index_add(idx.indexPtr, C.idx_t(n), (*C.float)(&vectors[0])); c != 0 { + return getLastError() + } + return nil } -// Lims returns a slice containing start and end indices for queries in the -// distances and labels slices returned by Labels. -func (r *RangeSearchResult) Lims() []int { - var lims *C.size_t - C.faiss_RangeSearchResult_lims(r.rsr, &lims) - length := r.Nq() + 1 - return (*[1 << 30]int)(unsafe.Pointer(lims))[:length:length] +func (idx *IndexImpl) cPtrFloat() *C.FaissIndex { + return idx.indexPtr } -// Labels returns the unsorted IDs and respective distances for each query. -// The result for query i is labels[lims[i]:lims[i+1]]. -func (r *RangeSearchResult) Labels() (labels []int64, distances []float32) { - lims := r.Lims() - length := lims[len(lims)-1] - var clabels *C.idx_t - var cdist *C.float - C.faiss_RangeSearchResult_labels(r.rsr, &clabels, &cdist) - labels = (*[1 << 30]int64)(unsafe.Pointer(clabels))[:length:length] - distances = (*[1 << 30]float32)(unsafe.Pointer(cdist))[:length:length] - return +func (idx *IndexImpl) AddWithIDs(vectors []float32, xids []int64) error { + n := len(vectors) / idx.d + if c := C.faiss_Index_add_with_ids(idx.indexPtr, C.idx_t(n), (*C.float)(&vectors[0]), (*C.idx_t)(&xids[0])); c != 0 { + return getLastError() + } + return nil } -// Delete frees the memory associated with r. -func (r *RangeSearchResult) Delete() { - C.faiss_RangeSearchResult_free(r.rsr) -} +func (idx *IndexImpl) Reconstruct(key int64) (recons []float32, err error) { + rv := make([]float32, idx.D()) + if c := C.faiss_Index_reconstruct( + idx.indexPtr, + C.idx_t(key), + (*C.float)(&rv[0]), + ); c != 0 { + err = getLastError() + } -// IndexImpl is an abstract structure for an index. 
-type IndexImpl struct { - Index + return rv, err } -// IndexFactory builds a composite index. -// description is a comma-separated list of components. -func IndexFactory(d int, description string, metric int) (*IndexImpl, error) { - cdesc := C.CString(description) - defer C.free(unsafe.Pointer(cdesc)) - var idx faissIndex - c := C.faiss_index_factory(&idx.idx, C.int(d), cdesc, C.FaissMetricType(metric)) - if c != 0 { - return nil, getLastError() +func (idx *IndexImpl) ReconstructBatch(keys []int64, recons []float32) ([]float32, error) { + var err error + n := int64(len(keys)) + if c := C.faiss_Index_reconstruct_batch( + idx.indexPtr, + C.idx_t(n), + (*C.idx_t)(&keys[0]), + (*C.float)(&recons[0]), + ); c != 0 { + err = getLastError() } - return &IndexImpl{&idx}, nil + + return recons, err } func SetOMPThreads(n uint) { diff --git a/index_flat.go b/index_flat.go index b8a3c03..27bf0b5 100644 --- a/index_flat.go +++ b/index_flat.go @@ -10,20 +10,20 @@ import "unsafe" // IndexFlat is an index that stores the full vectors and performs exhaustive // search. type IndexFlat struct { - Index + *IndexImpl } // NewIndexFlat creates a new flat index. func NewIndexFlat(d int, metric int) (*IndexFlat, error) { - var idx faissIndex + var idx *C.FaissIndex if c := C.faiss_IndexFlat_new_with( - &idx.idx, + &idx, C.idx_t(d), C.FaissMetricType(metric), ); c != 0 { return nil, getLastError() } - return &IndexFlat{&idx}, nil + return &IndexFlat{&IndexImpl{indexPtr: idx, d: d, metric: metric}}, nil } // NewIndexFlatIP creates a new flat index with the inner product metric type. @@ -41,16 +41,16 @@ func NewIndexFlatL2(d int) (*IndexFlat, error) { func (idx *IndexFlat) Xb() []float32 { var size C.size_t var ptr *C.float - C.faiss_IndexFlat_xb(idx.cPtr(), &ptr, &size) + C.faiss_IndexFlat_xb(idx.cPtrFloat(), &ptr, &size) return (*[1 << 30]float32)(unsafe.Pointer(ptr))[:size:size] } // AsFlat casts idx to a flat index. // AsFlat panics if idx is not a flat index. 
func (idx *IndexImpl) AsFlat() *IndexFlat { - ptr := C.faiss_IndexFlat_cast(idx.cPtr()) + ptr := C.faiss_IndexFlat_cast(idx.cPtrFloat()) if ptr == nil { panic("index is not a flat index") } - return &IndexFlat{&faissIndex{ptr}} + return &IndexFlat{&IndexImpl{indexPtr: ptr, d: idx.d, metric: idx.metric}} } diff --git a/index_io.go b/index_io.go index 608f4d7..3ac50c3 100644 --- a/index_io.go +++ b/index_io.go @@ -5,29 +5,147 @@ package faiss #include #include #include +#include +#include */ import "C" import ( + "fmt" "unsafe" ) -// WriteIndex writes an index to a file. -func WriteIndex(idx Index, filename string) error { +const ( + IOFlagMmap = C.FAISS_IO_FLAG_MMAP + IOFlagReadOnly = C.FAISS_IO_FLAG_READ_ONLY + IOFlagReadMmap = C.FAISS_IO_FLAG_READ_MMAP | C.FAISS_IO_FLAG_ONDISK_IVF + IOFlagSkipPrefetch = C.FAISS_IO_FLAG_SKIP_PREFETCH +) + +// WriteIndex writes a float index to a file +func WriteIndex(idx FloatIndex, filename string) error { + impl, ok := idx.(*IndexImpl) + if !ok { + return fmt.Errorf("invalid index type for float index serialization") + } + cfname := C.CString(filename) defer C.free(unsafe.Pointer(cfname)) - if c := C.faiss_write_index_fname(idx.cPtr(), cfname); c != 0 { + if c := C.faiss_write_index_fname(impl.cPtrFloat(), cfname); c != 0 { return getLastError() } return nil } -func WriteIndexIntoBuffer(idx Index) ([]byte, error) { +// WriteBinaryIndex writes a binary index to a file +func WriteBinaryIndex(idx BinaryIndex, filename string) error { + impl, ok := idx.(*BinaryIndexImpl) + if !ok { + return fmt.Errorf("invalid index type for binary index serialization") + } + + cfname := C.CString(filename) + defer C.free(unsafe.Pointer(cfname)) + if c := C.faiss_write_index_binary_fname(impl.cPtrBinary(), cfname); c != 0 { + return getLastError() + } + return nil +} + +// ReadIndex reads a float index from a file +func ReadIndex(filename string, ioflags int) (FloatIndex, error) { + cfname := C.CString(filename) + defer C.free(unsafe.Pointer(cfname)) 
+ var idx *C.FaissIndex + if c := C.faiss_read_index_fname(cfname, C.int(ioflags), &idx); c != 0 { + return nil, getLastError() + } + return &IndexImpl{ + indexPtr: idx, + d: int(C.faiss_Index_d(idx)), + metric: int(C.faiss_Index_metric_type(idx)), + }, nil +} + +// ReadBinaryIndex reads a binary index from a file +func ReadBinaryIndex(filename string, ioflags int) (BinaryIndex, error) { + cfname := C.CString(filename) + defer C.free(unsafe.Pointer(cfname)) + var idx *C.FaissIndexBinary + if c := C.faiss_read_index_binary_fname(cfname, C.int(ioflags), &idx); c != 0 { + return nil, getLastError() + } + return &BinaryIndexImpl{ + indexPtr: idx, + d: int(C.faiss_IndexBinary_d(idx)), + metric: int(C.faiss_IndexBinary_metric_type(idx)), + }, nil +} + +func WriteIndexIntoBuffer(idx FloatIndex) ([]byte, error) { // the values to be returned by the faiss APIs tempBuf := (*C.uchar)(nil) bufSize := C.size_t(0) if c := C.faiss_write_index_buf( - idx.cPtr(), + idx.cPtrFloat(), + &bufSize, + &tempBuf, + ); c != 0 { + C.faiss_free_buf(&tempBuf) + return nil, getLastError() + } + + // at this point, the idx has a valid ref count. furthermore, the index is + // something that's present on the C memory space, so not available to go's + // GC. needs to be freed when its of no more use. + + // todo: add checksum. + // the content populated in the tempBuf is converted from *C.uchar to unsafe.Pointer + // and then the pointer is casted into a large byte slice which is then sliced + // to a length and capacity equal to bufSize returned across the cgo interface. + // NOTE: it still points to the C memory though + // the bufSize is of type size_t which is equivalent to a uint in golang, so + // the conversion is safe. + val := unsafe.Slice((*byte)(unsafe.Pointer(tempBuf)), uint(bufSize)) + + // NOTE: This method is compatible with 64-bit systems but may encounter issues on 32-bit systems. + // leading to vector indexing being supported only for 64-bit systems. 
+ // This limitation arises because the maximum allowed length of a slice on 32-bit systems + // is math.MaxInt32 (2^31-1), whereas the maximum value of a size_t in C++ is math.MaxUint32 + // (2^32-1), exceeding the maximum allowed size of a slice in Go. + // Consequently, the bufSize returned by faiss_write_index_buf might exceed the + // maximum allowed size of a slice in Go, leading to a panic when attempting to + // create the following slice rv. + rv := make([]byte, uint(bufSize)) + // an explicit copy is necessary to free the memory on C heap and then return + // the rv back to the caller which is definitely on goruntime space (which will + // be GC'd later on). + // + // an optimization over here - create buffer pool which can be used to make the + // memory allocations cheaper. specifically two separate pools can be utilized, + // one for C pointers and another for goruntime. within the faiss_write_index_buf + // a cheaper calloc rather than malloc can be used to make any extra allocations + // cheaper. 
+ copy(rv, val) + + // safe to free the c memory allocated (tempBuf) while serializing the index (must be done + // within C runtime for it was allocated there); + // rv is from go runtime - so different address space altogether + C.faiss_free_buf(&tempBuf) + + // p.s: no need to free "val" since the underlying memory is same as tempBuf (deferred free) + val = nil + + return rv, nil +} + +func WriteBinaryIndexIntoBuffer(idx BinaryIndex) ([]byte, error) { + // the values to be returned by the faiss APIs + tempBuf := (*C.uchar)(nil) + bufSize := C.size_t(0) + + if c := C.faiss_write_index_binary_buf( + idx.cPtrBinary(), &bufSize, &tempBuf, ); c != 0 { @@ -79,42 +197,48 @@ func WriteIndexIntoBuffer(idx Index) ([]byte, error) { return rv, nil } -func ReadIndexFromBuffer(buf []byte, ioflags int) (*IndexImpl, error) { +// ReadIndexFromBuffer deserializes a float index from a byte buffer +func ReadIndexFromBuffer(buf []byte, ioFlags int) (FloatIndex, error) { ptr := (*C.uchar)(unsafe.Pointer(&buf[0])) size := C.size_t(len(buf)) // the idx var has C.FaissIndex within the struct which is nil as of now. - var idx faissIndex + var idx *C.FaissIndex if c := C.faiss_read_index_buf(ptr, size, - C.int(ioflags), - &idx.idx); c != 0 { + C.int(ioFlags), + &idx); c != 0 { return nil, getLastError() } ptr = nil - // after exiting the faiss_read_index_buf, the ref count to the memory allocated - // for the freshly created faiss::index becomes 1 (held by idx.idx of type C.FaissIndex) - // this is allocated on the C heap, so not available for golang's GC. hence needs - // to be cleaned up after the index is longer being used - to be done at zap layer. 
- return &IndexImpl{&idx}, nil + return &IndexImpl{ + indexPtr: idx, + d: int(C.faiss_Index_d(idx)), + metric: int(C.faiss_Index_metric_type(idx)), + }, nil } -const ( - IOFlagMmap = C.FAISS_IO_FLAG_MMAP - IOFlagReadOnly = C.FAISS_IO_FLAG_READ_ONLY - IOFlagReadMmap = C.FAISS_IO_FLAG_READ_MMAP | C.FAISS_IO_FLAG_ONDISK_IVF - IOFlagSkipPrefetch = C.FAISS_IO_FLAG_SKIP_PREFETCH -) +// ReadBinaryIndexFromBuffer deserializes a binary index from a byte buffer +func ReadBinaryIndexFromBuffer(buf []byte, ioFlags int) (BinaryIndex, error) { + ptr := (*C.uchar)(unsafe.Pointer(&buf[0])) + size := C.size_t(len(buf)) -// ReadIndex reads an index from a file. -func ReadIndex(filename string, ioflags int) (*IndexImpl, error) { - cfname := C.CString(filename) - defer C.free(unsafe.Pointer(cfname)) - var idx faissIndex - if c := C.faiss_read_index_fname(cfname, C.int(ioflags), &idx.idx); c != 0 { + // the idx var has C.FaissIndex within the struct which is nil as of now. + var idxBinary *C.FaissIndexBinary + if c := C.faiss_read_index_binary_buf(ptr, + size, + C.int(ioFlags), + &idxBinary); c != 0 { return nil, getLastError() } - return &IndexImpl{&idx}, nil + + ptr = nil + + return &BinaryIndexImpl{ + indexPtr: idxBinary, + d: int(C.faiss_IndexBinary_d(idxBinary)), + metric: int(C.faiss_IndexBinary_metric_type(idxBinary)), + }, nil } diff --git a/index_ivf.go b/index_ivf.go index 38f023a..4ecea58 100644 --- a/index_ivf.go +++ b/index_ivf.go @@ -5,6 +5,8 @@ package faiss #include #include #include +#include +#include #include */ import "C" @@ -12,24 +14,64 @@ import ( "fmt" ) -func (idx *IndexImpl) SetDirectMap(mapType int) (err error) { +// IndexIVF represents an IVF index +type IndexIVF struct { + *IndexImpl +} - ivfPtr := C.faiss_IndexIVF_cast(idx.cPtr()) +func (idx *IndexImpl) GetNProbe() int32 { + ivfPtr := C.faiss_IndexIVF_cast(idx.cPtrFloat()) if ivfPtr == nil { - return fmt.Errorf("index is not of ivf type") + return 0 } - if c := C.faiss_IndexIVF_set_direct_map( - ivfPtr, - 
C.int(mapType), - ); c != 0 { - err = getLastError() + return int32(C.faiss_IndexIVF_nprobe(ivfPtr)) +} + +func (idx *BinaryIndexImpl) GetNProbe() int32 { + ivfPtrBinary := C.faiss_IndexBinaryIVF_cast(idx.cPtrBinary()) + if ivfPtrBinary == nil { + return 0 } - return err + return int32(C.faiss_IndexBinaryIVF_nprobe(ivfPtrBinary)) } -func (idx *IndexImpl) GetSubIndex() (*IndexImpl, error) { +func (idx *BinaryIndexImpl) SetDirectMap(mapType int) (err error) { + ivfPtrBinary := C.faiss_IndexBinaryIVF_cast(idx.cPtrBinary()) + // If we have a binary IVF index + if ivfPtrBinary != nil { + if c := C.faiss_IndexBinaryIVF_set_direct_map( + ivfPtrBinary, + C.int(mapType), + ); c != 0 { + err = getLastError() + } + return err + } + + return fmt.Errorf("unable to set direct map") +} + +func (idx *IndexImpl) SetDirectMap(mapType int) (err error) { + // Try to get either regular or binary IVF pointer + ivfPtr := C.faiss_IndexIVF_cast(idx.cPtrFloat()) + + // If we have a regular IVF index + if ivfPtr != nil { + if c := C.faiss_IndexIVF_set_direct_map( + ivfPtr, + C.int(mapType), + ); c != 0 { + err = getLastError() + } + return err + } + + // Get index type for better error message + return fmt.Errorf("unable to set direct map") +} - ptr := C.faiss_IndexIDMap2_cast(idx.cPtr()) +func (idx *IndexImpl) GetSubIndex() (*IndexImpl, error) { + ptr := C.faiss_IndexIDMap2_cast(idx.indexPtr) if ptr == nil { return nil, fmt.Errorf("index is not a id map") } @@ -39,23 +81,23 @@ func (idx *IndexImpl) GetSubIndex() (*IndexImpl, error) { return nil, fmt.Errorf("couldn't retrieve the sub index") } - return &IndexImpl{&faissIndex{subIdx}}, nil + return &IndexImpl{indexPtr: subIdx}, nil } -// pass nprobe to be set as index time option for IVF indexes only. -// varying nprobe impacts recall but with an increase in latency. 
-func (idx *IndexImpl) SetNProbe(nprobe int32) { - ivfPtr := C.faiss_IndexIVF_cast(idx.cPtr()) - if ivfPtr == nil { +func (idx *BinaryIndexImpl) SetNProbe(nprobe int32) { + ivfPtrBinary := C.faiss_IndexBinaryIVF_cast(idx.cPtrBinary()) + if ivfPtrBinary == nil { return } - C.faiss_IndexIVF_set_nprobe(ivfPtr, C.size_t(nprobe)) + C.faiss_IndexBinaryIVF_set_nprobe(ivfPtrBinary, C.size_t(nprobe)) } -func (idx *IndexImpl) GetNProbe() int32 { - ivfPtr := C.faiss_IndexIVF_cast(idx.cPtr()) +// pass nprobe to be set as index time option for IVF/BIVF indexes only. +// varying nprobe impacts recall but with an increase in latency. +func (idx *IndexImpl) SetNProbe(nprobe int32) { + ivfPtr := C.faiss_IndexIVF_cast(idx.cPtrFloat()) if ivfPtr == nil { - return 0 + return } - return int32(C.faiss_IndexIVF_nprobe(ivfPtr)) + C.faiss_IndexIVF_set_nprobe(ivfPtr, C.size_t(nprobe)) } diff --git a/search_params.go b/search_params.go index 6086073..bb79888 100644 --- a/search_params.go +++ b/search_params.go @@ -3,6 +3,8 @@ package faiss /* #include #include +#include +#include #include */ import "C" @@ -11,16 +13,18 @@ import ( "fmt" ) +// SearchParams represents search parameters for both float and binary indexes type SearchParams struct { - sp *C.FaissSearchParameters + sp *C.FaissSearchParameters + idx Index } -// Delete frees the memory associated with s. 
-func (s *SearchParams) Delete() { - if s == nil || s.sp == nil { - return +// Delete frees the search parameters +func (sp *SearchParams) Delete() { + if sp.sp != nil { + C.faiss_SearchParameters_free(sp.sp) + sp.sp = nil } - C.faiss_SearchParameters_free(s.sp) } type searchParamsIVF struct { @@ -63,51 +67,75 @@ func NewSearchParams(idx Index, params json.RawMessage, sel *C.FaissIDSelector, if c := C.faiss_SearchParameters_new(&rv.sp, sel); c != 0 { return nil, fmt.Errorf("failed to create faiss search params") } - // check if the index is IVF and set the search params - if ivfIdx := C.faiss_IndexIVF_cast(idx.cPtr()); ivfIdx != nil { - rv.sp = C.faiss_SearchParametersIVF_cast(rv.sp) - if len(params) == 0 && sel == nil { - return rv, nil + + if len(params) == 0 && sel == nil { + return rv, nil + } + + if !idx.IsIVFIndex() { + c := C.faiss_SearchParameters_new_with_selector(&rv.sp, sel) + if c != 0 { + rv.Delete() + return nil, fmt.Errorf("failed to create faiss search params") } - var nlist, nprobe, nvecs, maxCodes int - nlist = int(C.faiss_IndexIVF_nlist(ivfIdx)) - nprobe = int(C.faiss_IndexIVF_nprobe(ivfIdx)) - nvecs = int(C.faiss_Index_ntotal(idx.cPtr())) - if defaultParams != nil { - if defaultParams.Nlist > 0 { - nlist = defaultParams.Nlist - } - if defaultParams.Nprobe > 0 { - nprobe = defaultParams.Nprobe - } + return rv, nil + } + + var nlist, nprobe, nvecs, maxCodes int + var ivfParams searchParamsIVF + + rv.sp = C.faiss_SearchParametersIVF_cast(rv.sp) + + switch idx.(type) { + case FloatIndex: + ivfIdx := idx.(*IndexImpl) + nlist = int(C.faiss_IndexIVF_nlist(ivfIdx.cPtrFloat())) + nprobe = int(C.faiss_IndexIVF_nprobe(ivfIdx.cPtrFloat())) + nvecs = int(C.faiss_Index_ntotal(ivfIdx.cPtrFloat())) + case BinaryIndex: + ivfIdx := idx.(*BinaryIndexImpl) + nlist = int(C.faiss_IndexBinaryIVF_nlist(ivfIdx.cPtrBinary())) + nprobe = int(C.faiss_IndexBinaryIVF_nprobe(ivfIdx.cPtrBinary())) + nvecs = int(C.faiss_IndexBinary_ntotal(ivfIdx.cPtrBinary())) + default: + 
return nil, fmt.Errorf("unsupported index type") + } + + if defaultParams != nil { + if defaultParams.Nlist > 0 { + nlist = defaultParams.Nlist } - var ivfParams searchParamsIVF - if len(params) > 0 { - if err := json.Unmarshal(params, &ivfParams); err != nil { - rv.Delete() - return nil, fmt.Errorf("failed to unmarshal IVF search params, "+ - "err:%v", err) - } - if err := ivfParams.Validate(); err != nil { - rv.Delete() - return nil, err - } + if defaultParams.Nprobe > 0 { + nprobe = defaultParams.Nprobe } - if ivfParams.NprobePct > 0 { - nprobe = max(int(float32(nlist)*(ivfParams.NprobePct/100)), 1) + } + + if len(params) > 0 { + if err := json.Unmarshal(params, &ivfParams); err != nil { + rv.Delete() + return nil, fmt.Errorf("failed to unmarshal IVF search params, "+ + "err:%v", err) } - if ivfParams.MaxCodesPct > 0 { - maxCodes = int(float32(nvecs) * (ivfParams.MaxCodesPct / 100)) - } // else, maxCodes will be set to the default value of 0, which means no limit - if c := C.faiss_SearchParametersIVF_new_with( - &rv.sp, - sel, - C.size_t(nprobe), - C.size_t(maxCodes), - ); c != 0 { + if err := ivfParams.Validate(); err != nil { rv.Delete() - return nil, fmt.Errorf("failed to create faiss IVF search params") + return nil, err } } + if ivfParams.NprobePct > 0 { + nprobe = max(int(float32(nlist)*(ivfParams.NprobePct/100)), 1) + } + if ivfParams.MaxCodesPct > 0 { + maxCodes = int(float32(nvecs) * (ivfParams.MaxCodesPct / 100)) + } // else, maxCodes will be set to the default value of 0, which means no limit + + if c := C.faiss_SearchParametersIVF_new_with( + &rv.sp, + sel, + C.size_t(nprobe), + C.size_t(maxCodes), + ); c != 0 { + rv.Delete() + return nil, fmt.Errorf("failed to create faiss IVF search params") + } return rv, nil }