From f7c89085295916a9cc7179fa367f139363d07026 Mon Sep 17 00:00:00 2001 From: Aditi Ahuja Date: Fri, 9 May 2025 17:53:24 +0530 Subject: [PATCH 01/13] support for binary indexes - wip --- go.mod | 4 +- index.go | 195 ++++++++++++++++++++++++++++++++++++++++++++------ index_flat.go | 2 +- index_io.go | 80 +++++++++++++++++++++ index_ivf.go | 68 +++++++++++++----- 5 files changed, 307 insertions(+), 42 deletions(-) diff --git a/go.mod b/go.mod index afe720a..f2ef1aa 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/blevesearch/go-faiss -go 1.21 +go 1.22 + +toolchain go1.23.0 diff --git a/index.go b/index.go index 18177fc..f0a7999 100644 --- a/index.go +++ b/index.go @@ -2,13 +2,16 @@ package faiss /* #include +#include #include #include +#include #include #include #include #include #include +#include */ import "C" import ( @@ -36,13 +39,13 @@ type Index interface { MetricType() int // Train trains the index on a representative set of vectors. - Train(x []float32) error + Train(x interface{}) error // Add adds vectors to the index. - Add(x []float32) error + Add(x interface{}) error // AddWithIDs is like Add, but stores xids instead of sequential IDs. - AddWithIDs(x []float32, xids []int64) error + AddWithIDs(x interface{}, xids []int64) error // Returns true if the index is an IVF index. IsIVFIndex() bool @@ -75,6 +78,12 @@ type Index interface { SearchWithIDs(x []float32, k int64, include []int64, params json.RawMessage) (distances []float32, labels []int64, err error) + SearchBinaryWithIDs(x []uint8, k int64, params json.RawMessage) (distances []int32, + labels []int64, err error) + + SearchBinary(x []uint8, k int64) (distances []int32, + labels []int64, err error) + // Applicable only to IVF indexes: Search clusters whose IDs are in eligibleCentroidIDs SearchClustersFromIVFIndex(selector Selector, eligibleCentroidIDs []int64, minEligibleCentroids int, k int64, x, centroidDis []float32, @@ -104,23 +113,40 @@ type Index interface { Size() uint64 cPtr() *C.FaissIndex + + cPtrBinary() *C.FaissIndexBinary + + IVFDistCompute(queryData []float32, ids []int64, k int, distances []float32) } type faissIndex struct { - idx *C.FaissIndex + idx *C.FaissIndex + idxBinary *C.FaissIndexBinary } func (idx *faissIndex) cPtr() *C.FaissIndex { return idx.idx } +func (idx *faissIndex) IVFDistCompute(queryData []float32, ids []int64, k int, distances []float32) { + C.faiss_IndexIVF_dist_compute(idx.idx, (*C.float)(&queryData[0]), + (*C.idx_t)(&ids[0]), (C.size_t)(k), (*C.float)(&distances[0])) +} + +func (idx *faissIndex) cPtrBinary() *C.FaissIndexBinary { + return idx.idxBinary +} + func (idx *faissIndex) Size() uint64 { size := C.faiss_Index_size(idx.idx) return uint64(size) } func (idx *faissIndex) D() int { - return int(C.faiss_Index_d(idx.idx)) + if idx.idx != nil { + return int(C.faiss_Index_d(idx.idx)) + } + return int(C.faiss_IndexBinary_d(idx.idxBinary)) } func (idx *faissIndex) IsTrained() bool { @@ -128,6 +154,9 @@ func (idx *faissIndex) IsTrained() bool { } func (idx *faissIndex) Ntotal() int64 { + if idx.idxBinary != nil { + return int64(C.faiss_IndexBinary_ntotal(idx.idxBinary)) + } return int64(C.faiss_Index_ntotal(idx.idx)) } @@ -135,19 +164,50 @@ func (idx *faissIndex) MetricType() int { return int(C.faiss_Index_metric_type(idx.idx)) } -func (idx *faissIndex) Train(x []float32) error { - n := len(x) / idx.D() - if c := C.faiss_Index_train(idx.idx, C.idx_t(n), (*C.float)(&x[0])); c != 0 { - return getLastError() +func (idx *faissIndex) Train(x interface{}) error { + floatVec, ok := x.([]float32) + if ok { + n := len(floatVec) / idx.D() + if c := C.faiss_Index_train(idx.idx, C.idx_t(n), (*C.float)(&floatVec[0])); c != 0 { + return getLastError() + } + } else { + c, ok := x.([]uint8) + if ok { + n := (len(c) * 8) / idx.D() + if c := C.faiss_IndexBinary_train(idx.idxBinary, C.idx_t(n), (*C.uint8_t)(&c[0])); c != 0 { + return getLastError() + } + } } return nil } -func (idx *faissIndex) Add(x []float32) error { - n := len(x) / idx.D() - if c := C.faiss_Index_add(idx.idx, C.idx_t(n), (*C.float)(&x[0])); c != 0 { - return getLastError() +func (idx *faissIndex) Add(x interface{}) error { + floatVec, ok := x.([]float32) + if ok { + n := len(floatVec) / idx.D() + if c := C.faiss_Index_add( + idx.idx, + C.idx_t(n), + (*C.float)(&floatVec[0]), + ); c != 0 { + return getLastError() + } + } else { + c, ok := x.([]uint8) + if ok { + n := (len(c) * 8) / idx.D() + if c := C.faiss_IndexBinary_add( + idx.idxBinary, + C.idx_t(n), + (*C.uint8_t)(&c[0]), + ); c != 0 { + return getLastError() + } + } } + return nil } @@ -257,16 +317,50 @@ func (idx *faissIndex) SearchClustersFromIVFIndex(selector Selector, return distances, labels, nil } -func (idx *faissIndex) AddWithIDs(x []float32, xids []int64) error { - n := len(x) / idx.D() - if c := C.faiss_Index_add_with_ids( - idx.idx, - C.idx_t(n), - (*C.float)(&x[0]), - (*C.idx_t)(&xids[0]), - ); c != 0 { - return getLastError() +func packBits(bits []uint8) []uint8 { + n := (len(bits) + 7) / 8 + result := make([]uint8, n) + for i := 0; i < len(bits); i++ { + // Determine the index in the result slice + byteIndex := i / 8 + // Determine the bit position in the byte + bitPosition := uint(7 - (i % 8)) + // If the bit is 1, set the corresponding bit in the uint8 value + if bits[i] == 1 { + result[byteIndex] |= (1 << bitPosition) + } + } + + return result +} + +func (idx *faissIndex) AddWithIDs(x interface{}, xids []int64) error { + floatVec, ok := x.([]float32) + if ok { + n := len(floatVec) / idx.D() + if c := C.faiss_Index_add_with_ids( + idx.idx, + C.idx_t(n), + (*C.float)(&floatVec[0]), + (*C.idx_t)(&xids[0]), + ); c != 0 { + return getLastError() + } + } else { + c, ok := x.([]uint8) + if ok { + n := (len(c) * 8) / idx.D() + if c := C.faiss_IndexBinary_add_with_ids( + idx.idxBinary, + C.idx_t(n), + (*C.uint8_t)(&c[0]), + (*C.idx_t)(&xids[0]), + ); c != 0 { + return getLastError() + } + } } + return nil } @@ -318,6 +412,51 @@ func (idx *faissIndex) SearchWithoutIDs(x []float32, k int64, exclude []int64, p return } +func (idx *faissIndex) SearchBinaryWithIDs(x []uint8, k int64, + params json.RawMessage) (distances []int32, labels []int64, err error, +) { + d := idx.D() + nq := (len(x) * 8) / d + + distances = make([]int32, int64(nq)*k) + labels = make([]int64, int64(nq)*k) + + if c := C.faiss_IndexBinary_search( + idx.idxBinary, + C.idx_t(nq), + (*C.uint8_t)(&x[0]), + C.idx_t(k), + (*C.int32_t)(&distances[0]), + (*C.idx_t)(&labels[0]), + ); c != 0 { + err = getLastError() + } + + return distances, labels, nil +} + +func (idx *faissIndex) SearchBinary(x []uint8, k int64) (distances []int32, labels []int64, err error, +) { + d := idx.D() + nq := (len(x) * 8) / d + + distances = make([]int32, int64(nq)*k) + labels = make([]int64, int64(nq)*k) + + if c := C.faiss_IndexBinary_search( + idx.idxBinary, + C.idx_t(nq), + (*C.uint8_t)(&x[0]), + C.idx_t(k), + (*C.int32_t)(&distances[0]), + (*C.idx_t)(&labels[0]), + ); c != 0 { + err = getLastError() + } + + return distances, labels, nil +} + func (idx *faissIndex) SearchWithIDs(x []float32, k int64, include []int64, params json.RawMessage) (distances []float32, labels []int64, err error, ) { @@ -426,6 +565,7 @@ func (idx *faissIndex) RemoveIDs(sel *IDSelector) (int, error) { func (idx *faissIndex) Close() { C.faiss_Index_free(idx.idx) + C.faiss_IndexBinary_free(idx.idxBinary) } func (idx *faissIndex) searchWithParams(x []float32, k int64, searchParams *C.FaissSearchParameters) ( @@ -507,6 +647,17 @@ func IndexFactory(d int, description string, metric int) (*IndexImpl, error) { return &IndexImpl{&idx}, nil } +func IndexBinaryFactory(d int, description string, metric int) (*IndexImpl, error) { + cdesc := C.CString(description) + defer C.free(unsafe.Pointer(cdesc)) + var idx faissIndex + c := C.faiss_index_binary_factory(&idx.idxBinary, C.int(d), cdesc) + if c != 0 { + return nil, getLastError() + } + return &IndexImpl{&idx}, nil +} + func SetOMPThreads(n uint) { C.faiss_set_omp_threads(C.uint(n)) } diff --git a/index_flat.go b/index_flat.go index b8a3c03..a97d6f8 100644 --- a/index_flat.go +++ b/index_flat.go @@ -52,5 +52,5 @@ func (idx *IndexImpl) AsFlat() *IndexFlat { if ptr == nil { panic("index is not a flat index") } - return &IndexFlat{&faissIndex{ptr}} + return &IndexFlat{&faissIndex{idx: ptr}} } diff --git a/index_io.go b/index_io.go index 608f4d7..88425c1 100644 --- a/index_io.go +++ b/index_io.go @@ -79,6 +79,64 @@ func WriteIndexIntoBuffer(idx Index) ([]byte, error) { return rv, nil } +func WriteBinaryIndexIntoBuffer(idx Index) ([]byte, error) { + // the values to be returned by the faiss APIs + tempBuf := (*C.uchar)(nil) + bufSize := C.size_t(0) + + if c := C.faiss_write_index_binary_buf( + idx.cPtrBinary(), + &bufSize, + &tempBuf, + ); c != 0 { + C.faiss_free_buf(&tempBuf) + return nil, getLastError() + } + + // at this point, the idx has a valid ref count. furthermore, the index is + // something that's present on the C memory space, so not available to go's + // GC. needs to be freed when its of no more use. + + // todo: add checksum. + // the content populated in the tempBuf is converted from *C.uchar to unsafe.Pointer + // and then the pointer is casted into a large byte slice which is then sliced + // to a length and capacity equal to bufSize returned across the cgo interface. + // NOTE: it still points to the C memory though + // the bufSize is of type size_t which is equivalent to a uint in golang, so + // the conversion is safe. + val := unsafe.Slice((*byte)(unsafe.Pointer(tempBuf)), uint(bufSize)) + + // NOTE: This method is compatible with 64-bit systems but may encounter issues on 32-bit systems. + // leading to vector indexing being supported only for 64-bit systems. + // This limitation arises because the maximum allowed length of a slice on 32-bit systems + // is math.MaxInt32 (2^31-1), whereas the maximum value of a size_t in C++ is math.MaxUInt32 + // (4^31-1), exceeding the maximum allowed size of a slice in Go. + // Consequently, the bufSize returned by faiss_write_index_buf might exceed the + // maximum allowed size of a slice in Go, leading to a panic when attempting to + // create the following slice rv. + rv := make([]byte, uint(bufSize)) + // an explicit copy is necessary to free the memory on C heap and then return + // the rv back to the caller which is definitely on goruntime space (which will + // GC'd later on). + // + // an optimization over here - create buffer pool which can be used to make the + // memory allocations cheaper. specifically two separate pools can be utilized, + // one for C pointers and another for goruntime. within the faiss_write_index_buf + // a cheaper calloc rather than malloc can be used to make any extra allocations + // cheaper. + copy(rv, val) + + // safe to free the c memory allocated (tempBuf) while serializing the index (must be done + // within C runtime for it was allocated there); + // rv is from go runtime - so different address space altogether + C.faiss_free_buf(&tempBuf) + + // p.s: no need to free "val" since the underlying memory is same as tempBuf (deferred free) + val = nil + + return rv, nil +} + func ReadIndexFromBuffer(buf []byte, ioflags int) (*IndexImpl, error) { ptr := (*C.uchar)(unsafe.Pointer(&buf[0])) size := C.size_t(len(buf)) @@ -101,6 +159,28 @@ func ReadIndexFromBuffer(buf []byte, ioflags int) (*IndexImpl, error) { return &IndexImpl{&idx}, nil } +func ReadBinaryIndexFromBuffer(buf []byte, ioflags int) (*IndexImpl, error) { + ptr := (*C.uchar)(unsafe.Pointer(&buf[0])) + size := C.size_t(len(buf)) + + // the idx var has C.FaissIndex within the struct which is nil as of now. + var idxBinary faissIndex + if c := C.faiss_read_index_binary_buf(ptr, + size, + C.int(ioflags), + &idxBinary.idxBinary); c != 0 { + return nil, getLastError() + } + + ptr = nil + + // after exiting the faiss_read_index_buf, the ref count to the memory allocated + // for the freshly created faiss::index becomes 1 (held by idx.idx of type C.FaissIndex) + // this is allocated on the C heap, so not available for golang's GC. hence needs + // to be cleaned up after the index is longer being used - to be done at zap layer. + return &IndexImpl{&idxBinary}, nil +} + const ( IOFlagMmap = C.FAISS_IO_FLAG_MMAP IOFlagReadOnly = C.FAISS_IO_FLAG_READ_ONLY diff --git a/index_ivf.go b/index_ivf.go index 38f023a..552e98d 100644 --- a/index_ivf.go +++ b/index_ivf.go @@ -5,6 +5,7 @@ package faiss #include #include #include +#include #include */ import "C" @@ -13,18 +14,34 @@ import ( ) func (idx *IndexImpl) SetDirectMap(mapType int) (err error) { - + // Try to get either regular or binary IVF pointer ivfPtr := C.faiss_IndexIVF_cast(idx.cPtr()) - if ivfPtr == nil { - return fmt.Errorf("index is not of ivf type") + ivfPtrBinary := C.faiss_IndexBinaryIVF_cast(idx.cPtrBinary()) + + // If we have a regular IVF index + if ivfPtr != nil { + if c := C.faiss_IndexIVF_set_direct_map( + ivfPtr, + C.int(mapType), + ); c != 0 { + err = getLastError() + } + return err } - if c := C.faiss_IndexIVF_set_direct_map( - ivfPtr, - C.int(mapType), - ); c != 0 { - err = getLastError() + + // If we have a binary IVF index + if ivfPtrBinary != nil { + if c := C.faiss_IndexBinaryIVF_set_direct_map( + ivfPtrBinary, + C.int(mapType), + ); c != 0 { + err = getLastError() + } + return err } - return err + + // Get index type for better error message + return fmt.Errorf("index is not of ivf type 2") } func (idx *IndexImpl) GetSubIndex() (*IndexImpl, error) { @@ -39,23 +56,38 @@ func (idx *IndexImpl) GetSubIndex() (*IndexImpl, error) { return nil, fmt.Errorf("couldn't retrieve the sub index") } - return &IndexImpl{&faissIndex{subIdx}}, nil + return &IndexImpl{&faissIndex{idx: subIdx}}, nil } -// pass nprobe to be set as index time option for IVF indexes only. +// pass nprobe to be set as index time option for IVF/BIVF indexes only. // varying nprobe impacts recall but with an increase in latency. -func (idx *IndexImpl) SetNProbe(nprobe int32) { +func (idx *IndexImpl) SetNProbe(nprobe int32) error { ivfPtr := C.faiss_IndexIVF_cast(idx.cPtr()) - if ivfPtr == nil { - return + if ivfPtr != nil { + C.faiss_IndexIVF_set_nprobe(ivfPtr, C.size_t(nprobe)) + return nil + } + + ivfPtrBinary := C.faiss_IndexBinaryIVF_cast(idx.cPtrBinary()) + if ivfPtrBinary != nil { + C.faiss_IndexBinaryIVF_set_nprobe(ivfPtrBinary, C.size_t(nprobe)) + return nil } - C.faiss_IndexIVF_set_nprobe(ivfPtr, C.size_t(nprobe)) + + // Get index type for better error message + return fmt.Errorf("index is not of ivf type 3") } func (idx *IndexImpl) GetNProbe() int32 { ivfPtr := C.faiss_IndexIVF_cast(idx.cPtr()) - if ivfPtr == nil { - return 0 + if ivfPtr != nil { + return int32(C.faiss_IndexIVF_nprobe(ivfPtr)) } - return int32(C.faiss_IndexIVF_nprobe(ivfPtr)) + + ivfPtrBinary := C.faiss_IndexBinaryIVF_cast(idx.cPtrBinary()) + if ivfPtrBinary != nil { + return int32(C.faiss_IndexBinaryIVF_nprobe(ivfPtrBinary)) + } + + return 0 } From e577f27e6bda884af44f33468fa53980d58b9d23 Mon Sep 17 00:00:00 2001 From: Aditi Ahuja Date: Fri, 9 May 2025 17:55:32 +0530 Subject: [PATCH 02/13] cleanup --- index.go | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/index.go b/index.go index f0a7999..4d057e5 100644 --- a/index.go +++ b/index.go @@ -317,23 +317,6 @@ func (idx *faissIndex) SearchClustersFromIVFIndex(selector Selector, return distances, labels, nil } -func packBits(bits []uint8) []uint8 { - n := (len(bits) + 7) / 8 - result := make([]uint8, n) - for i := 0; i < len(bits); i++ { - // Determine the index in the result slice - byteIndex := i / 8 - // Determine the bit position in the byte - bitPosition := uint(7 - (i % 8)) - // If the bit is 1, set the corresponding bit in the uint8 value - if bits[i] == 1 { - result[byteIndex] |= (1 << bitPosition) - } - } - - return result -} - func (idx *faissIndex) AddWithIDs(x interface{}, xids []int64) error { floatVec, ok := x.([]float32) if ok { From 73e9e4b6596960bf3ae6e15b2dd321d29a2fc628 Mon Sep 17 00:00:00 2001 From: Aditi Ahuja Date: Sun, 11 May 2025 09:34:58 +0530 Subject: [PATCH 03/13] hacky - search binary without ids --- index.go | 99 +++++++++++++++++++++++++++++++++++++++--------- search_params.go | 45 ++++++++++++++++++++++ 2 files changed, 127 insertions(+), 17 deletions(-) diff --git a/index.go b/index.go index 4d057e5..df20792 100644 --- a/index.go +++ b/index.go @@ -81,6 +81,9 @@ type Index interface { SearchBinaryWithIDs(x []uint8, k int64, params json.RawMessage) (distances []int32, labels []int64, err error) + SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64, + params json.RawMessage) (distances []int32, labels []int64, err error) + SearchBinary(x []uint8, k int64) (distances []int32, labels []int64, err error) @@ -390,11 +393,46 @@ func (idx *faissIndex) SearchWithoutIDs(x []float32, k int64, exclude []int64, p } defer searchParams.Delete() - distances, labels, err = idx.searchWithParams(x, k, searchParams.sp) + d, labels, err := idx.searchWithParams(x, k, searchParams.sp) + distances = d.([]float32) return } +func (idx *faissIndex) SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64, + params json.RawMessage) (distances []int32, labels []int64, err error, +) { + if params == nil && len(exclude) == 0 { + return idx.SearchBinary(x, k) + } + + var selector *C.FaissIDSelector + if len(exclude) > 0 { + excludeSelector, err := NewIDSelectorNot(exclude) + if err != nil { + return nil, nil, err + } + selector = excludeSelector.Get() + defer excludeSelector.Delete() + } + + searchParams, err := NewSearchParams(idx, params, selector, nil) + if err != nil { + return nil, nil, err + } + defer searchParams.Delete() + + nq := (len(x) * 8) / idx.D() + + distances = make([]int32, int64(nq)*k) + labels = make([]int64, int64(nq)*k) + + d, labels, err := idx.searchWithParams(x, k, searchParams.sp) + distances = d.([]int32) + + return distances, labels, nil +} + func (idx *faissIndex) SearchBinaryWithIDs(x []uint8, k int64, params json.RawMessage) (distances []int32, labels []int64, err error, ) { @@ -455,7 +493,8 @@ func (idx *faissIndex) SearchWithIDs(x []float32, k int64, include []int64, } defer searchParams.Delete() - distances, labels, err = idx.searchWithParams(x, k, searchParams.sp) + d, labels, err := idx.searchWithParams(x, k, searchParams.sp) + distances = d.([]float32) return } @@ -551,23 +590,49 @@ func (idx *faissIndex) Close() { C.faiss_IndexBinary_free(idx.idxBinary) } -func (idx *faissIndex) searchWithParams(x []float32, k int64, searchParams *C.FaissSearchParameters) ( - distances []float32, labels []int64, err error, +func (idx *faissIndex) searchWithParams(x interface{}, k int64, searchParams *C.FaissSearchParameters) ( + distances interface{}, labels []int64, err error, ) { - n := len(x) / idx.D() - distances = make([]float32, int64(n)*k) - labels = make([]int64, int64(n)*k) + floatVec, ok := x.([]float32) + if ok { + n := len(floatVec) / idx.D() + distancesFloat := make([]float32, int64(n)*k) + labels = make([]int64, int64(n)*k) - if c := C.faiss_Index_search_with_params( - idx.idx, - C.idx_t(n), - (*C.float)(&x[0]), - C.idx_t(k), - searchParams, - (*C.float)(&distances[0]), - (*C.idx_t)(&labels[0]), - ); c != 0 { - err = getLastError() + if c := C.faiss_Index_search_with_params( + idx.idx, + C.idx_t(n), + (*C.float)(&floatVec[0]), + C.idx_t(k), + searchParams, + (*C.float)(&distancesFloat[0]), + (*C.idx_t)(&labels[0]), + ); c != 0 { + err = getLastError() + } + + distances = distancesFloat + } else { + c, ok := x.([]uint8) + if ok { + n := (len(c) * 8) / idx.D() + distancesBinary := make([]int32, int64(n)*k) + labels = make([]int64, int64(n)*k) + + if c := C.faiss_IndexBinary_search_with_params( + idx.idxBinary, + C.idx_t(n), + (*C.uint8_t)(&c[0]), + C.idx_t(k), + searchParams, + (*C.int32_t)(&distancesBinary[0]), + (*C.idx_t)(&labels[0]), + ); c != 0 { + err = getLastError() + } + + distances = distancesBinary + } } return diff --git a/search_params.go b/search_params.go index 6086073..d194eca 100644 --- a/search_params.go +++ b/search_params.go @@ -2,6 +2,7 @@ package faiss /* #include +#include #include #include */ @@ -108,6 +109,50 @@ func NewSearchParams(idx Index, params json.RawMessage, sel *C.FaissIDSelector, rv.Delete() return nil, fmt.Errorf("failed to create faiss IVF search params") } + } else if bivfIdx := C.faiss_IndexBinaryIVF_cast(idx.cPtrBinary()); bivfIdx != nil { + rv.sp = C.faiss_SearchParametersIVF_cast(rv.sp) + if len(params) == 0 && sel == nil { + return rv, nil + } + var nlist, nprobe, nvecs, maxCodes int + nlist = int(C.faiss_IndexBinaryIVF_nlist(bivfIdx)) + nprobe = int(C.faiss_IndexBinaryIVF_nprobe(bivfIdx)) + nvecs = int(C.faiss_IndexBinaryIVF_ntotal(bivfIdx)) + if defaultParams != nil { + if defaultParams.Nlist > 0 { + nlist = defaultParams.Nlist + } + if defaultParams.Nprobe > 0 { + nprobe = defaultParams.Nprobe + } + } + var ivfParams searchParamsIVF + if len(params) > 0 { + if err := json.Unmarshal(params, &ivfParams); err != nil { + rv.Delete() + return nil, fmt.Errorf("failed to unmarshal IVF search params, "+ + "err:%v", err) + } + if err := ivfParams.Validate(); err != nil { + rv.Delete() + return nil, err + } + } + if ivfParams.NprobePct > 0 { + nprobe = max(int(float32(nlist)*(ivfParams.NprobePct/100)), 1) + } + if ivfParams.MaxCodesPct > 0 { + maxCodes = int(float32(nvecs) * (ivfParams.MaxCodesPct / 100)) + } // else, maxCodes will be set to the default value of 0, which means no limit + if c := C.faiss_SearchParametersIVF_new_with( + &rv.sp, + sel, + C.size_t(nprobe), + C.size_t(maxCodes), + ); c != 0 { + rv.Delete() + return nil, fmt.Errorf("failed to create faiss IVF search params") + } } return rv, nil } From 68a53a2f44068b858328557e52b45b84f146900e Mon Sep 17 00:00:00 2001 From: Aditi Ahuja Date: Mon, 12 May 2025 14:23:42 +0530 Subject: [PATCH 04/13] cleanup --- search_params.go | 105 +++++++++++++++++++---------------------------- 1 file changed, 42 insertions(+), 63 deletions(-) diff --git a/search_params.go b/search_params.go index d194eca..4fe700b 100644 --- a/search_params.go +++ b/search_params.go @@ -64,42 +64,55 @@ func NewSearchParams(idx Index, params json.RawMessage, sel *C.FaissIDSelector, if c := C.faiss_SearchParameters_new(&rv.sp, sel); c != 0 { return nil, fmt.Errorf("failed to create faiss search params") } + + if len(params) == 0 && sel == nil { + return rv, nil + } + + var nlist, nprobe, nvecs, maxCodes int + var ivfParams searchParamsIVF + + rv.sp = C.faiss_SearchParametersIVF_cast(rv.sp) + // check if the index is IVF and set the search params if ivfIdx := C.faiss_IndexIVF_cast(idx.cPtr()); ivfIdx != nil { - rv.sp = C.faiss_SearchParametersIVF_cast(rv.sp) - if len(params) == 0 && sel == nil { - return rv, nil - } - var nlist, nprobe, nvecs, maxCodes int nlist = int(C.faiss_IndexIVF_nlist(ivfIdx)) nprobe = int(C.faiss_IndexIVF_nprobe(ivfIdx)) nvecs = int(C.faiss_Index_ntotal(idx.cPtr())) - if defaultParams != nil { - if defaultParams.Nlist > 0 { - nlist = defaultParams.Nlist - } - if defaultParams.Nprobe > 0 { - nprobe = defaultParams.Nprobe - } + } else if bivfIdx := C.faiss_IndexBinaryIVF_cast(idx.cPtrBinary()); bivfIdx != nil { + nlist = int(C.faiss_IndexBinaryIVF_nlist(bivfIdx)) + nprobe = int(C.faiss_IndexBinaryIVF_nprobe(bivfIdx)) + nvecs = int(C.faiss_IndexBinary_ntotal(idx.cPtrBinary())) + } + + if defaultParams != nil { + if defaultParams.Nlist > 0 { + nlist = defaultParams.Nlist } - var ivfParams searchParamsIVF - if len(params) > 0 { - if err := json.Unmarshal(params, &ivfParams); err != nil { - rv.Delete() - return nil, fmt.Errorf("failed to unmarshal IVF search params, "+ - "err:%v", err) - } - if err := ivfParams.Validate(); err != nil { - rv.Delete() - return nil, err - } + if defaultParams.Nprobe > 0 { + nprobe = defaultParams.Nprobe } - if ivfParams.NprobePct > 0 { - nprobe = max(int(float32(nlist)*(ivfParams.NprobePct/100)), 1) + } + + if len(params) > 0 { + if err := json.Unmarshal(params, &ivfParams); err != nil { + rv.Delete() + return nil, fmt.Errorf("failed to unmarshal IVF search params, "+ + "err:%v", err) } - if ivfParams.MaxCodesPct > 0 { - maxCodes = int(float32(nvecs) * (ivfParams.MaxCodesPct / 100)) - } // else, maxCodes will be set to the default value of 0, which means no limit + if err := ivfParams.Validate(); err != nil { + rv.Delete() + return nil, err + } + } + if ivfParams.NprobePct > 0 { + nprobe = max(int(float32(nlist)*(ivfParams.NprobePct/100)), 1) + } + if ivfParams.MaxCodesPct > 0 { + maxCodes = int(float32(nvecs) * (ivfParams.MaxCodesPct / 100)) + } // else, maxCodes will be set to the default value of 0, which means no limit + + if ivfIdx := C.faiss_IndexIVF_cast(idx.cPtr()); ivfIdx != nil { if c := C.faiss_SearchParametersIVF_new_with( &rv.sp, sel, @@ -110,40 +123,6 @@ func NewSearchParams(idx Index, params json.RawMessage, sel *C.FaissIDSelector, return nil, fmt.Errorf("failed to create faiss IVF search params") } } else if bivfIdx := C.faiss_IndexBinaryIVF_cast(idx.cPtrBinary()); bivfIdx != nil { - rv.sp = C.faiss_SearchParametersIVF_cast(rv.sp) - if len(params) == 0 && sel == nil { - return rv, nil - } - var nlist, nprobe, nvecs, maxCodes int - nlist = int(C.faiss_IndexBinaryIVF_nlist(bivfIdx)) - nprobe = int(C.faiss_IndexBinaryIVF_nprobe(bivfIdx)) - nvecs = int(C.faiss_IndexBinaryIVF_ntotal(bivfIdx)) - if defaultParams != nil { - if defaultParams.Nlist > 0 { - nlist = defaultParams.Nlist - } - if defaultParams.Nprobe > 0 { - nprobe = defaultParams.Nprobe - } - } - var ivfParams searchParamsIVF - if len(params) > 0 { - if err := json.Unmarshal(params, &ivfParams); err != nil { - rv.Delete() - return nil, fmt.Errorf("failed to unmarshal IVF search params, "+ - "err:%v", err) - } - if err := ivfParams.Validate(); err != nil { - rv.Delete() - return nil, err - } - } - if ivfParams.NprobePct > 0 { - nprobe = max(int(float32(nlist)*(ivfParams.NprobePct/100)), 1) - } - if ivfParams.MaxCodesPct > 0 { - maxCodes = int(float32(nvecs) * (ivfParams.MaxCodesPct / 100)) - } // else, maxCodes will be set to the default value of 0, which means no limit if c := C.faiss_SearchParametersIVF_new_with( &rv.sp, sel, @@ -151,7 +130,7 @@ func NewSearchParams(idx Index, params json.RawMessage, sel *C.FaissIDSelector, C.size_t(maxCodes), ); c != 0 { rv.Delete() - return nil, fmt.Errorf("failed to create faiss IVF search params") + return nil, fmt.Errorf("failed to create faiss BIVF search params") } } return rv, nil From 4be47b7c485e2d694f2ecbd7c061a168de12ed6d Mon Sep 17 00:00:00 2001 From: Aditi Ahuja Date: Mon, 12 May 2025 15:39:59 +0530 Subject: [PATCH 05/13] better error messages --- index_ivf.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/index_ivf.go b/index_ivf.go index 552e98d..e063b6d 100644 --- a/index_ivf.go +++ b/index_ivf.go @@ -41,7 +41,7 @@ func (idx *IndexImpl) SetDirectMap(mapType int) (err error) { } // Get index type for better error message - return fmt.Errorf("index is not of ivf type 2") + return fmt.Errorf("unable to set direct map") } func (idx *IndexImpl) GetSubIndex() (*IndexImpl, error) { @@ -74,8 +74,7 @@ func (idx *IndexImpl) SetNProbe(nprobe int32) error { return nil } - // Get index type for better error message - return fmt.Errorf("index is not of ivf type 3") + return fmt.Errorf("unable to get nprobe") } func (idx *IndexImpl) GetNProbe() int32 { From a33ef3eea1b7430548c0f0e5a6cd21932e823a44 Mon Sep 17 00:00:00 2001 From: Aditi Ahuja Date: Tue, 13 May 2025 12:33:14 +0530 Subject: [PATCH 06/13] generalised dist compute --- index.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/index.go b/index.go index df20792..7e6d07a 100644 --- a/index.go +++ b/index.go @@ -119,7 +119,7 @@ type Index interface { cPtrBinary() *C.FaissIndexBinary - IVFDistCompute(queryData []float32, ids []int64, k int, distances []float32) + DistCompute(queryData []float32, ids []int64, k int, distances []float32) error } type faissIndex struct { @@ -131,9 +131,13 @@ func (idx *faissIndex) cPtr() *C.FaissIndex { return idx.idx } -func (idx *faissIndex) IVFDistCompute(queryData []float32, ids []int64, k int, distances []float32) { - C.faiss_IndexIVF_dist_compute(idx.idx, (*C.float)(&queryData[0]), - (*C.idx_t)(&ids[0]), (C.size_t)(k), (*C.float)(&distances[0])) +func (idx *faissIndex) DistCompute(queryData []float32, ids []int64, k int, distances []float32) error { + if c := C.faiss_Index_dist_compute(idx.idx, (*C.float)(&queryData[0]), + (*C.idx_t)(&ids[0]), (C.size_t)(k), (*C.float)(&distances[0])); c != 0 { + return getLastError() + } + + return nil } func (idx *faissIndex) cPtrBinary() *C.FaissIndexBinary { From 42a99f0acf64e706aad1194f6b949cb645464104 Mon Sep 17 00:00:00 2001 From: Aditi Ahuja Date: Sat, 17 May 2025 07:26:29 +0530 Subject: [PATCH 07/13] small fix --- index.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/index.go b/index.go index 7e6d07a..8b06c43 100644 --- a/index.go +++ b/index.go @@ -479,7 +479,7 @@ func (idx *faissIndex) SearchBinary(x []uint8, k int64) (distances []int32, labe err = getLastError() } - return distances, labels, nil + return distances, labels, err } func (idx *faissIndex) SearchWithIDs(x []float32, k int64, include []int64, From 0d9a4f3301c01db578d67d306e400db2e9ff18c1 Mon Sep 17 00:00:00 2001 From: Aditi Ahuja Date: Mon, 19 May 2025 12:19:13 +0530 Subject: [PATCH 08/13] Revert "hacky - search binary without ids" This reverts commit 73e9e4b6596960bf3ae6e15b2dd321d29a2fc628. --- index.go | 99 +++++++++--------------------------------------- search_params.go | 17 +-------- 2 files changed, 18 insertions(+), 98 deletions(-) diff --git a/index.go b/index.go index 8b06c43..a93c041 100644 --- a/index.go +++ b/index.go @@ -81,9 +81,6 @@ type Index interface { SearchBinaryWithIDs(x []uint8, k int64, params json.RawMessage) (distances []int32, labels []int64, err error) - SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64, - params json.RawMessage) (distances []int32, labels []int64, err error) - SearchBinary(x []uint8, k int64) (distances []int32, labels []int64, err error) @@ -397,46 +394,11 @@ func (idx *faissIndex) SearchWithoutIDs(x []float32, k int64, exclude []int64, p } defer searchParams.Delete() - d, labels, err := idx.searchWithParams(x, k, searchParams.sp) - distances = d.([]float32) + distances, labels, err = idx.searchWithParams(x, k, searchParams.sp) return } -func (idx *faissIndex) SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64, - params json.RawMessage) (distances []int32, labels []int64, err error, -) { - if params == nil && len(exclude) == 0 { - return idx.SearchBinary(x, k) - } - - var selector *C.FaissIDSelector - if len(exclude) > 0 { - excludeSelector, err := NewIDSelectorNot(exclude) - if err != nil { - return nil, nil, err - } - selector = excludeSelector.Get() - defer excludeSelector.Delete() - } - - searchParams, err := NewSearchParams(idx, params, selector, nil) - if err != nil { - return nil, nil, err - } - defer searchParams.Delete() - - nq := (len(x) * 8) / idx.D() - - distances = make([]int32, int64(nq)*k) - labels = make([]int64, int64(nq)*k) - - d, labels, err := idx.searchWithParams(x, k, searchParams.sp) - distances = d.([]int32) - - return distances, labels, nil -} - func (idx *faissIndex) SearchBinaryWithIDs(x []uint8, k int64, params json.RawMessage) (distances []int32, labels []int64, err error, ) { @@ -497,8 +459,7 @@ func (idx *faissIndex) SearchWithIDs(x []float32, k int64, include []int64, } defer searchParams.Delete() - d, labels, err := idx.searchWithParams(x, k, searchParams.sp) - distances = d.([]float32) + distances, labels, err = idx.searchWithParams(x, k, searchParams.sp) return } @@ -594,49 +555,23 @@ func (idx *faissIndex) Close() { C.faiss_IndexBinary_free(idx.idxBinary) } -func (idx *faissIndex) searchWithParams(x interface{}, k int64, searchParams *C.FaissSearchParameters) ( - distances interface{}, labels []int64, err error, +func (idx *faissIndex) searchWithParams(x []float32, k int64, searchParams *C.FaissSearchParameters) ( + distances []float32, labels []int64, err error, ) { - floatVec, ok := x.([]float32) - if ok { - n := len(floatVec) / idx.D() - distancesFloat := make([]float32, int64(n)*k) - labels = make([]int64, int64(n)*k) - - if c := C.faiss_Index_search_with_params( - idx.idx, - C.idx_t(n), - (*C.float)(&floatVec[0]), - C.idx_t(k), - searchParams, - (*C.float)(&distancesFloat[0]), - (*C.idx_t)(&labels[0]), - ); c != 0 { - err = getLastError() - } - - distances = distancesFloat - } else { - c, ok := x.([]uint8) - if ok { - n := (len(c) * 8) / idx.D() - distancesBinary := make([]int32, int64(n)*k) - labels = make([]int64, int64(n)*k) - - if c := C.faiss_IndexBinary_search_with_params( - idx.idxBinary, - C.idx_t(n), - (*C.uint8_t)(&c[0]), - C.idx_t(k), - searchParams, - (*C.int32_t)(&distancesBinary[0]), - (*C.idx_t)(&labels[0]), - ); c != 0 { - err = getLastError() - } + n := len(x) / idx.D() + distances = make([]float32, int64(n)*k) + labels = make([]int64, int64(n)*k) - distances = distancesBinary - } + if c := C.faiss_Index_search_with_params( + idx.idx, + C.idx_t(n), + (*C.float)(&x[0]), + C.idx_t(k), + searchParams, + (*C.float)(&distances[0]), + (*C.idx_t)(&labels[0]), + ); c != 0 { + err = getLastError() } return diff --git a/search_params.go b/search_params.go index 4fe700b..06221e9 100644 --- a/search_params.go +++ b/search_params.go @@ -2,7 +2,6 @@ package faiss /* #include -#include #include #include */ @@ -79,11 +78,7 @@ func NewSearchParams(idx Index, params json.RawMessage, sel *C.FaissIDSelector, nlist = int(C.faiss_IndexIVF_nlist(ivfIdx)) nprobe = int(C.faiss_IndexIVF_nprobe(ivfIdx)) nvecs = int(C.faiss_Index_ntotal(idx.cPtr())) - } else if bivfIdx := C.faiss_IndexBinaryIVF_cast(idx.cPtrBinary()); bivfIdx != nil { - nlist = int(C.faiss_IndexBinaryIVF_nlist(bivfIdx)) - nprobe = int(C.faiss_IndexBinaryIVF_nprobe(bivfIdx)) - nvecs = int(C.faiss_IndexBinary_ntotal(idx.cPtrBinary())) - } + } if defaultParams != nil { if defaultParams.Nlist > 0 { @@ -122,16 +117,6 @@ func NewSearchParams(idx Index, params json.RawMessage, sel *C.FaissIDSelector, rv.Delete() return nil, fmt.Errorf("failed to create faiss IVF search params") } - } else if bivfIdx := C.faiss_IndexBinaryIVF_cast(idx.cPtrBinary()); bivfIdx != nil { - if c := C.faiss_SearchParametersIVF_new_with( - &rv.sp, - sel, - C.size_t(nprobe), - C.size_t(maxCodes), - ); c != 0 { - rv.Delete() - return nil, fmt.Errorf("failed to create faiss BIVF search params") - } } return rv, nil } From 5ad828de4e405c8be92003a2dce87b32d7e8d5b5 Mon Sep 17 00:00:00 2001 From: Aditi Ahuja Date: Wed, 21 May 2025 12:15:02 +0530 Subject: [PATCH 09/13] nicer binary search without ids --- index.go | 38 +++++++++++++++++++++++++++++++------- search_params.go | 12 ++++++++++-- 2 files changed, 41 insertions(+), 9 deletions(-) diff --git a/index.go b/index.go index a93c041..0d921a7 100644 --- a/index.go +++ b/index.go @@ -81,7 +81,7 @@ type Index interface { SearchBinaryWithIDs(x []uint8, k int64, params json.RawMessage) (distances []int32, labels []int64, err error) - SearchBinary(x []uint8, k int64) (distances []int32, + SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64, params json.RawMessage) (distances []int32, labels []int64, err error) // Applicable only to IVF indexes: Search clusters whose IDs are in eligibleCentroidIDs @@ -408,11 +408,18 @@ func (idx *faissIndex) SearchBinaryWithIDs(x []uint8, k int64, distances = make([]int32, int64(nq)*k) labels = make([]int64, int64(nq)*k) - if c := C.faiss_IndexBinary_search( + searchParams, err := NewSearchParams(idx, params, nil, nil) + if err != nil { + return nil, nil, err + } + defer searchParams.Delete() + + if c := C.faiss_IndexBinary_search_with_params( idx.idxBinary, C.idx_t(nq), (*C.uint8_t)(&x[0]), C.idx_t(k), + searchParams.sp, (*C.int32_t)(&distances[0]), (*C.idx_t)(&labels[0]), ); c != 0 { @@ -422,26 +429,43 @@ func (idx *faissIndex) SearchBinaryWithIDs(x []uint8, k int64, return distances, labels, nil } -func (idx *faissIndex) SearchBinary(x []uint8, k int64) (distances []int32, labels []int64, err error, -) { +func (idx *faissIndex) SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64,params json.RawMessage) (distances []int32, + labels []int64, err error) { d := idx.D() nq := (len(x) * 8) / d distances = make([]int32, int64(nq)*k) labels = make([]int64, int64(nq)*k) - if c := C.faiss_IndexBinary_search( + var selector *C.FaissIDSelector + if len(exclude) > 0 { + excludeSelector, err := NewIDSelectorNot(exclude) + if err != nil { + return nil, nil, err + } + selector = excludeSelector.Get() + defer excludeSelector.Delete() + } + + searchParams, err := NewSearchParams(idx, params, selector, nil) + if err != nil { + return nil, nil, err + } + defer searchParams.Delete() + + if c := C.faiss_IndexBinary_search_with_params( idx.idxBinary, C.idx_t(nq), (*C.uint8_t)(&x[0]), C.idx_t(k), + searchParams.sp, (*C.int32_t)(&distances[0]), (*C.idx_t)(&labels[0]), - ); c != 0 { + ); c != 0 { err = getLastError() } - return distances, labels, err + return distances, labels, err } func (idx *faissIndex) SearchWithIDs(x []float32, k int64, include []int64, diff --git a/search_params.go b/search_params.go index 06221e9..f210f7b 100644 --- a/search_params.go +++ b/search_params.go @@ -3,6 +3,7 @@ package faiss /* #include #include +#include #include */ import "C" @@ -78,7 +79,11 @@ func NewSearchParams(idx Index, params json.RawMessage, sel *C.FaissIDSelector, nlist = int(C.faiss_IndexIVF_nlist(ivfIdx)) nprobe = int(C.faiss_IndexIVF_nprobe(ivfIdx)) nvecs = int(C.faiss_Index_ntotal(idx.cPtr())) - } + } else if bivfIdx := C.faiss_IndexBinaryIVF_cast(idx.cPtrBinary()); bivfIdx != nil { + nlist = int(C.faiss_IndexBinaryIVF_nlist(bivfIdx)) + nprobe = int(C.faiss_IndexBinaryIVF_nprobe(bivfIdx)) + nvecs = int(C.faiss_IndexBinary_ntotal(idx.cPtrBinary())) + } if defaultParams != nil { if defaultParams.Nlist > 0 { @@ -107,7 +112,10 @@ func NewSearchParams(idx Index, params json.RawMessage, sel *C.FaissIDSelector, maxCodes = int(float32(nvecs) * (ivfParams.MaxCodesPct / 100)) } // else, maxCodes will be set to the default value of 0, which means no limit - if ivfIdx := C.faiss_IndexIVF_cast(idx.cPtr()); ivfIdx != nil { + ivfIdx := C.faiss_IndexIVF_cast(idx.cPtr()) + bivfIdx := C.faiss_IndexBinaryIVF_cast(idx.cPtrBinary()) + + if ivfIdx != nil || bivfIdx != nil { if c := C.faiss_SearchParametersIVF_new_with( &rv.sp, sel, From 329d981a54d0ac9a82455761a9fc92c738a0bdb3 Mon Sep 17 00:00:00 2001 From: Aditi Ahuja Date: Thu, 5 Jun 2025 18:35:06 +0530 Subject: [PATCH 10/13] clean up --- index.go | 50 +++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 45 insertions(+), 5 deletions(-) diff --git a/index.go b/index.go index 0d921a7..57dfc68 100644 --- a/index.go +++ b/index.go @@ -78,7 +78,10 @@ type Index interface { SearchWithIDs(x []float32, k int64, include []int64, params json.RawMessage) (distances []float32, labels []int64, err error) - SearchBinaryWithIDs(x []uint8, k int64, params json.RawMessage) (distances []int32, + SearchBinary(x []uint8, k int64) (distances []int32, + labels []int64, err error) + + SearchBinaryWithIDs(x []uint8, k int64, include []int64, params json.RawMessage) (distances []int32, labels []int64, err error) SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64, params json.RawMessage) (distances []int32, @@ -399,7 +402,30 @@ func (idx *faissIndex) SearchWithoutIDs(x []float32, k int64, exclude []int64, p return } -func (idx *faissIndex) SearchBinaryWithIDs(x []uint8, k int64, +func (idx *faissIndex) SearchBinary(x []uint8, k int64) (distances []int32, + labels []int64, err error, +) { + d := idx.D() + nq := (len(x) * 8) / d + + distances = make([]int32, int64(nq)*k) + labels = make([]int64, int64(nq)*k) + + if c := C.faiss_IndexBinary_search( + idx.idxBinary, + C.idx_t(nq), + (*C.uint8_t)(&x[0]), + C.idx_t(k), + (*C.int32_t)(&distances[0]), + (*C.idx_t)(&labels[0]), + ); c != 0 { + err = getLastError() + } + + return distances, labels, nil +} + +func (idx *faissIndex) SearchBinaryWithIDs(x []uint8, k int64, include []int64, params json.RawMessage) (distances []int32, labels []int64, err error, ) { d := idx.D() @@ -408,7 +434,17 @@ func (idx *faissIndex) SearchBinaryWithIDs(x []uint8, k int64, distances = make([]int32, int64(nq)*k) labels = make([]int64, int64(nq)*k) - searchParams, err := NewSearchParams(idx, params, nil, nil) + var selector *C.FaissIDSelector + if len(include) > 0 { + includeSelector, err := NewIDSelectorBatch(include) + if err != nil { + return nil, nil, err + } + selector = includeSelector.Get() + defer includeSelector.Delete() + } + + searchParams, err := NewSearchParams(idx, params, selector, nil) if err != nil { return nil, nil, err } @@ -429,8 +465,12 @@ func (idx *faissIndex) SearchBinaryWithIDs(x []uint8, k int64, return distances, labels, nil } -func (idx *faissIndex) SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64,params json.RawMessage) (distances []int32, +func (idx *faissIndex) SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64, params json.RawMessage) (distances []int32, labels []int64, err error) { + if len(exclude) == 0 && params == nil { + return idx.SearchBinary(x, k) + } + d := idx.D() nq := (len(x) * 8) / d @@ -461,7 +501,7 @@ func (idx *faissIndex) SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int6 searchParams.sp, (*C.int32_t)(&distances[0]), (*C.idx_t)(&labels[0]), - ); c != 0 { + ); c != 0 { err = getLastError() } From bfb5436b8013efbb1648b6da1c4ef84df99ca85b Mon Sep 17 00:00:00 2001 From: Aditi Ahuja Date: Mon, 9 Jun 2025 19:03:42 +0530 Subject: [PATCH 11/13] addressed reviews - 1 --- autotune.go | 18 +- index.go | 807 +++++++++++++++++++++++------------------------ index_flat.go | 14 +- index_io.go | 128 +++++--- index_ivf.go | 88 +++--- search_params.go | 57 ++-- 6 files changed, 588 insertions(+), 524 deletions(-) diff --git a/autotune.go b/autotune.go index 4b818d3..15e71ef 100644 --- a/autotune.go +++ b/autotune.go @@ -6,6 +6,7 @@ package faiss */ import "C" import ( + "fmt" "unsafe" ) @@ -30,11 +31,20 @@ func (p *ParameterSpace) SetIndexParameter(idx Index, name string, val float64) C.free(unsafe.Pointer(cname)) }() - c := C.faiss_ParameterSpace_set_index_parameter( - p.ps, idx.cPtr(), cname, C.double(val)) - if c != 0 { - return getLastError() + switch idx.(type) { + case FloatIndex: + idx := idx.(*IndexImpl) + c := C.faiss_ParameterSpace_set_index_parameter( + p.ps, idx.cPtrFloat(), cname, C.double(val)) + if c != 0 { + return getLastError() + } + case BinaryIndex: + return fmt.Errorf("binary indexes not supported for auto-tuning") + default: + return fmt.Errorf("unsupported index type") } + return nil } diff --git a/index.go b/index.go index 57dfc68..19b61d3 100644 --- a/index.go +++ b/index.go @@ -5,13 +5,12 @@ package faiss #include #include #include -#include #include -#include -#include +#include #include #include -#include +#include +#include */ import "C" import ( @@ -20,35 +19,66 @@ import ( "unsafe" ) -// Index is a Faiss index. -// -// Note that some index implementations do not support all methods. -// Check the Faiss wiki to see what operations an index supports. +// Index is the common interface for both binary and float vector indexes type Index interface { - // D returns the dimension of the indexed vectors. + // Core index operations D() int - // IsTrained returns true if the index has been trained or does not require - // training. - IsTrained() bool - // Ntotal returns the number of indexed vectors. Ntotal() int64 // MetricType returns the metric type of the index. MetricType() int - // Train trains the index on a representative set of vectors. - Train(x interface{}) error + Size() uint64 - // Add adds vectors to the index. - Add(x interface{}) error + // IVF-specific operations, common to both float and binary IVF indexes + IsIVFIndex() bool + SetNProbe(nprobe int32) + GetNProbe() int32 + SetDirectMap(directMapType int) error - // AddWithIDs is like Add, but stores xids instead of sequential IDs. - AddWithIDs(x interface{}, xids []int64) error + Close() +} - // Returns true if the index is an IVF index. - IsIVFIndex() bool +// BinaryIndex defines methods specific to binary FAISS indexes +type BinaryIndex interface { + Index + + cPtrBinary() *C.FaissIndexBinary + // Binary-specific operations + TrainBinary(vectors []uint8) error + AddBinary(vectors []uint8) error + AddBinaryWithIDs(vectors []uint8, ids []int64) error + SearchBinary(x []uint8, k int64) ([]int32, []int64, error) + SearchBinaryWithIDs(x []uint8, k int64, include []int64, params json.RawMessage) ([]int32, []int64, error) + SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64, params json.RawMessage) (distances []int32, + labels []int64, err error) +} + +// FloatIndex defines methods specific to float-based FAISS indexes +type FloatIndex interface { + Index + + cPtrFloat() *C.FaissIndex + // Float-specific operations + // Train trains the index on a representative set of vectors. + Train(vectors []float32) error + Add(vectors []float32) error + // AddWithIDs is like Add, but stores xids instead of sequential IDs. + AddWithIDs(vectors []float32, xids []int64) error + // Search queries the index with the vectors in x. + // Returns the IDs of the k nearest neighbors for each query vector and the + // corresponding distances. + Search(x []float32, k int64) (distances []float32, labels []int64, err error) + // RangeSearch queries the index with the vectors in x. + // Returns all vectors with distance < radius. + RangeSearch(x []float32, radius float32) (*RangeSearchResult, error) + SearchWithIDs(x []float32, k int64, include []int64, params json.RawMessage) ([]float32, []int64, error) + // SearchWithoutIDs is like Search, but excludes the vectors with IDs in exclude. + SearchWithoutIDs(x []float32, k int64, exclude []int64, params json.RawMessage) ([]float32, []int64, error) + Reconstruct(key int64) (recons []float32, err error) + ReconstructBatch(ids []int64, vectors []float32) ([]float32, error) // Applicable only to IVF indexes: Returns a map where the keys // are cluster IDs and the values represent the count of input vectors that belong @@ -67,40 +97,14 @@ type Index interface { ObtainClustersWithDistancesFromIVFIndex(x []float32, centroidIDs []int64) ( []int64, []float32, error) - // Search queries the index with the vectors in x. - // Returns the IDs of the k nearest neighbors for each query vector and the - // corresponding distances. - Search(x []float32, k int64) (distances []float32, labels []int64, err error) - - SearchWithoutIDs(x []float32, k int64, exclude []int64, params json.RawMessage) (distances []float32, - labels []int64, err error) - - SearchWithIDs(x []float32, k int64, include []int64, params json.RawMessage) (distances []float32, - labels []int64, err error) - - SearchBinary(x []uint8, k int64) (distances []int32, - labels []int64, err error) - - SearchBinaryWithIDs(x []uint8, k int64, include []int64, params json.RawMessage) (distances []int32, - labels []int64, err error) - - SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64, params json.RawMessage) (distances []int32, - labels []int64, err error) + DistCompute(queryData []float32, ids []int64, k int, distances []float32) error // Applicable only to IVF indexes: Search clusters whose IDs are in eligibleCentroidIDs SearchClustersFromIVFIndex(selector Selector, eligibleCentroidIDs []int64, minEligibleCentroids int, k int64, x, centroidDis []float32, params json.RawMessage) ([]float32, []int64, error) - Reconstruct(key int64) ([]float32, error) - - ReconstructBatch(keys []int64, recons []float32) ([]float32, error) - - MergeFrom(other Index, add_id int64) error - - // RangeSearch queries the index with the vectors in x. - // Returns all vectors with distance < radius. - RangeSearch(x []float32, radius float32) (*RangeSearchResult, error) + MergeFrom(other IndexImpl, add_id int64) error // Reset removes all vectors from the index. Reset() error @@ -108,123 +112,243 @@ type Index interface { // RemoveIDs removes the vectors specified by sel from the index. // Returns the number of elements removed and error. RemoveIDs(sel *IDSelector) (int, error) +} - // Close frees the memory used by the index. - Close() +// IndexImpl represents a float vector index +type IndexImpl struct { + indexPtr *C.FaissIndex + d int + metric int +} - // consults the C++ side to get the size of the index - Size() uint64 +// BinaryIndexImpl represents a binary vector index +type BinaryIndexImpl struct { + indexPtr *C.FaissIndexBinary + d int + metric int +} - cPtr() *C.FaissIndex +// NewBinaryIndexImpl creates a new binary index implementation +func NewBinaryIndexImpl(d int, description string, metric int) (*BinaryIndexImpl, error) { + idx := &BinaryIndexImpl{ + d: d, + metric: metric, + } + var cDescription *C.char + if description != "" { + cDescription = C.CString(description) + defer C.free(unsafe.Pointer(cDescription)) + } - cPtrBinary() *C.FaissIndexBinary + var cIdx *C.FaissIndexBinary + if c := C.faiss_index_binary_factory(&cIdx, C.int(idx.d), cDescription); c != 0 { + return nil, getLastError() + } + idx.indexPtr = cIdx + return idx, nil +} - DistCompute(queryData []float32, ids []int64, k int, distances []float32) error +// Core index operations +func (idx *BinaryIndexImpl) Close() { + if idx.indexPtr != nil { + C.faiss_IndexBinary_free(idx.indexPtr) + idx.indexPtr = nil + } } -type faissIndex struct { - idx *C.FaissIndex - idxBinary *C.FaissIndexBinary +func (idx *BinaryIndexImpl) Size() uint64 { + return 0 } -func (idx *faissIndex) cPtr() *C.FaissIndex { - return idx.idx +func (idx *BinaryIndexImpl) cPtrBinary() *C.FaissIndexBinary { + return idx.indexPtr } -func (idx *faissIndex) DistCompute(queryData []float32, ids []int64, k int, distances []float32) error { - if c := C.faiss_Index_dist_compute(idx.idx, (*C.float)(&queryData[0]), - (*C.idx_t)(&ids[0]), (C.size_t)(k), (*C.float)(&distances[0])); c != 0 { - return getLastError() - } +func (idx *BinaryIndexImpl) D() int { + return idx.d +} - return nil +func (idx *BinaryIndexImpl) MetricType() int { + return idx.metric } -func (idx *faissIndex) cPtrBinary() *C.FaissIndexBinary { - return idx.idxBinary +func (idx *BinaryIndexImpl) Ntotal() int64 { + return int64(C.faiss_IndexBinary_ntotal(idx.indexPtr)) } -func (idx *faissIndex) Size() uint64 { - size := C.faiss_Index_size(idx.idx) - return uint64(size) +func (idx *BinaryIndexImpl) IsIVFIndex() bool { + return C.faiss_IndexBinaryIVF_cast(idx.indexPtr) != nil } -func (idx *faissIndex) D() int { - if idx.idx != nil { - return int(C.faiss_Index_d(idx.idx)) +// Binary-specific operations +func (idx *BinaryIndexImpl) TrainBinary(vectors []uint8) error { + n := (len(vectors) * 8) / idx.d + if c := C.faiss_IndexBinary_train(idx.indexPtr, C.idx_t(n), (*C.uint8_t)(&vectors[0])); c != 0 { + return getLastError() } - return int(C.faiss_IndexBinary_d(idx.idxBinary)) + return nil } -func (idx *faissIndex) IsTrained() bool { - return C.faiss_Index_is_trained(idx.idx) != 0 +func (idx *BinaryIndexImpl) AddBinary(vectors []uint8) error { + n := (len(vectors) * 8) / idx.d + if c := C.faiss_IndexBinary_add(idx.indexPtr, C.idx_t(n), (*C.uint8_t)(&vectors[0])); c != 0 { + return getLastError() + } + return nil } -func (idx *faissIndex) Ntotal() int64 { - if idx.idxBinary != nil { - return int64(C.faiss_IndexBinary_ntotal(idx.idxBinary)) +func (idx *BinaryIndexImpl) AddBinaryWithIDs(vectors []uint8, ids []int64) error { + n := (len(vectors) * 8) / idx.d + if c := C.faiss_IndexBinary_add_with_ids(idx.indexPtr, C.idx_t(n), (*C.uint8_t)(&vectors[0]), (*C.idx_t)(&ids[0])); c != 0 { + return getLastError() } - return int64(C.faiss_Index_ntotal(idx.idx)) + return nil } -func (idx *faissIndex) MetricType() int { - return int(C.faiss_Index_metric_type(idx.idx)) +func (idx *BinaryIndexImpl) SearchBinary(x []uint8, k int64) ([]int32, []int64, error) { + nq := (len(x) * 8) / idx.d + distances := make([]int32, int64(nq)*k) + labels := make([]int64, int64(nq)*k) + + if c := C.faiss_IndexBinary_search( + idx.indexPtr, + C.idx_t(nq), + (*C.uint8_t)(&x[0]), + C.idx_t(k), + (*C.int32_t)(&distances[0]), + (*C.idx_t)(&labels[0]), + ); c != 0 { + return nil, nil, getLastError() + } + return distances, labels, nil } -func (idx *faissIndex) Train(x interface{}) error { - floatVec, ok := x.([]float32) - if ok { - n := len(floatVec) / idx.D() - if c := C.faiss_Index_train(idx.idx, C.idx_t(n), (*C.float)(&floatVec[0])); c != 0 { - return getLastError() - } - } else { - c, ok := x.([]uint8) - if ok { - n := (len(c) * 8) / idx.D() - if c := C.faiss_IndexBinary_train(idx.idxBinary, C.idx_t(n), (*C.uint8_t)(&c[0])); c != 0 { - return getLastError() - } - } +func (idx *BinaryIndexImpl) SearchBinaryWithIDs(x []uint8, k int64, include []int64, params json.RawMessage) ([]int32, []int64, error) { + nq := (len(x) * 8) / idx.d + distances := make([]int32, int64(nq)*k) + labels := make([]int64, int64(nq)*k) + + includeSelector, err := NewIDSelectorBatch(include) + if err != nil { + return nil, nil, err + } + defer includeSelector.Delete() + + searchParams, err := NewSearchParams(idx, params, includeSelector.Get(), nil) + if err != nil { + return nil, nil, err + } + defer searchParams.Delete() + + if c := C.faiss_IndexBinary_search_with_params( + idx.indexPtr, + C.idx_t(nq), + (*C.uint8_t)(&x[0]), + C.idx_t(k), + searchParams.sp, + (*C.int32_t)(&distances[0]), + (*C.idx_t)(&labels[0]), + ); c != 0 { + return nil, nil, getLastError() + } + return distances, labels, nil +} + +func (idx *BinaryIndexImpl) Train(vectors []uint8) error { + n := (len(vectors) * 8) / idx.d + if c := C.faiss_IndexBinary_train(idx.indexPtr, C.idx_t(n), (*C.uint8_t)(&vectors[0])); c != 0 { + return getLastError() } return nil } -func (idx *faissIndex) Add(x interface{}) error { - floatVec, ok := x.([]float32) - if ok { - n := len(floatVec) / idx.D() - if c := C.faiss_Index_add( - idx.idx, - C.idx_t(n), - (*C.float)(&floatVec[0]), - ); c != 0 { - return getLastError() - } - } else { - c, ok := x.([]uint8) - if ok { - n := (len(c) * 8) / idx.D() - if c := C.faiss_IndexBinary_add( - idx.idxBinary, - C.idx_t(n), - (*C.uint8_t)(&c[0]), - ); c != 0 { - return getLastError() - } +func (idx *BinaryIndexImpl) SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64, params json.RawMessage) (distances []int32, labels []int64, err error) { + if len(exclude) == 0 && params == nil { + return idx.SearchBinary(x, k) + } + + nq := (len(x) * 8) / idx.d + distances = make([]int32, int64(nq)*k) + labels = make([]int64, int64(nq)*k) + + var selector *C.FaissIDSelector + if len(exclude) > 0 { + excludeSelector, err := NewIDSelectorNot(exclude) + if err != nil { + return nil, nil, err } + selector = excludeSelector.Get() + defer excludeSelector.Delete() + } + + searchParams, err := NewSearchParams(idx, params, selector, nil) + if err != nil { + return nil, nil, err + } + defer searchParams.Delete() + + if c := C.faiss_IndexBinary_search_with_params( + idx.indexPtr, + C.idx_t(nq), + (*C.uint8_t)(&x[0]), + C.idx_t(k), + searchParams.sp, + (*C.int32_t)(&distances[0]), + (*C.idx_t)(&labels[0]), + ); c != 0 { + err = getLastError() + } + + return distances, labels, err +} + +// Factory functions +func IndexBinaryFactory(d int, description string, metric int) (BinaryIndex, error) { + return NewBinaryIndexImpl(d, description, metric) +} + +// Ensure BinaryIndexImpl implements BinaryIndex interface +var _ BinaryIndex = (*BinaryIndexImpl)(nil) + +func (idx *IndexImpl) searchWithParams(x []float32, k int64, params *C.FaissSearchParameters) (distances []float32, labels []int64, err error) { + n := len(x) / idx.D() + distances = make([]float32, int64(n)*k) + labels = make([]int64, int64(n)*k) + + if c := C.faiss_Index_search_with_params( + idx.indexPtr, + C.idx_t(n), + (*C.float)(&x[0]), + C.idx_t(k), + params, + (*C.float)(&distances[0]), + (*C.idx_t)(&labels[0]), + ); c != 0 { + err = getLastError() } + return +} + +func (idx *IndexImpl) Size() uint64 { + return uint64(C.faiss_Index_size(idx.cPtrFloat())) +} + +func (idx *IndexImpl) Train(x []float32) error { + n := len(x) / idx.D() + if c := C.faiss_Index_train(idx.indexPtr, C.idx_t(n), (*C.float)(&x[0])); c != 0 { + return getLastError() + } return nil } -func (idx *faissIndex) ObtainClusterVectorCountsFromIVFIndex(vecIDs []int64) (map[int64]int64, error) { +func (idx *IndexImpl) ObtainClusterVectorCountsFromIVFIndex(vecIDs []int64) (map[int64]int64, error) { if !idx.IsIVFIndex() { return nil, fmt.Errorf("index is not an IVF index") } clusterIDs := make([]int64, len(vecIDs)) if c := C.faiss_get_lists_for_keys( - idx.idx, + idx.indexPtr, (*C.idx_t)(unsafe.Pointer(&vecIDs[0])), (C.size_t)(len(vecIDs)), (*C.idx_t)(unsafe.Pointer(&clusterIDs[0])), @@ -238,14 +362,16 @@ func (idx *faissIndex) ObtainClusterVectorCountsFromIVFIndex(vecIDs []int64) (ma return rv, nil } -func (idx *faissIndex) IsIVFIndex() bool { - if ivfIdx := C.faiss_IndexIVF_cast(idx.cPtr()); ivfIdx == nil { - return false +func (idx *IndexImpl) DistCompute(queryData []float32, ids []int64, k int, distances []float32) error { + if c := C.faiss_Index_dist_compute(idx.indexPtr, (*C.float)(&queryData[0]), + (*C.idx_t)(&ids[0]), (C.size_t)(k), (*C.float)(&distances[0])); c != 0 { + return getLastError() } - return true + + return nil } -func (idx *faissIndex) ObtainClustersWithDistancesFromIVFIndex(x []float32, centroidIDs []int64) ( +func (idx *IndexImpl) ObtainClustersWithDistancesFromIVFIndex(x []float32, centroidIDs []int64) ( []int64, []float32, error) { // Selector to include only the centroids whose IDs are part of 'centroidIDs'. includeSelector, err := NewIDSelectorBatch(centroidIDs) @@ -267,7 +393,7 @@ func (idx *faissIndex) ObtainClustersWithDistancesFromIVFIndex(x []float32, cent n := len(x) / idx.D() c := C.faiss_Search_closest_eligible_centroids( - idx.idx, + idx.indexPtr, (C.idx_t)(n), (*C.float)(&x[0]), (C.idx_t)(len(centroidIDs)), @@ -281,7 +407,7 @@ func (idx *faissIndex) ObtainClustersWithDistancesFromIVFIndex(x []float32, cent return centroids, centroidDistances, nil } -func (idx *faissIndex) SearchClustersFromIVFIndex(selector Selector, +func (idx *IndexImpl) SearchClustersFromIVFIndex(selector Selector, eligibleCentroidIDs []int64, minEligibleCentroids int, k int64, x, centroidDis []float32, params json.RawMessage) ([]float32, []int64, error) { @@ -308,7 +434,7 @@ func (idx *faissIndex) SearchClustersFromIVFIndex(selector Selector, centroidDis = centroidDis[:effectiveNprobe] if c := C.faiss_IndexIVF_search_preassigned_with_params( - idx.idx, + idx.indexPtr, (C.idx_t)(n), (*C.float)(&x[0]), (C.idx_t)(k), @@ -324,158 +450,77 @@ func (idx *faissIndex) SearchClustersFromIVFIndex(selector Selector, return distances, labels, nil } -func (idx *faissIndex) AddWithIDs(x interface{}, xids []int64) error { - floatVec, ok := x.([]float32) - if ok { - n := len(floatVec) / idx.D() - if c := C.faiss_Index_add_with_ids( - idx.idx, - C.idx_t(n), - (*C.float)(&floatVec[0]), - (*C.idx_t)(&xids[0]), - ); c != 0 { - return getLastError() - } - } else { - c, ok := x.([]uint8) - if ok { - n := (len(c) * 8) / idx.D() - if c := C.faiss_IndexBinary_add_with_ids( - idx.idxBinary, - C.idx_t(n), - (*C.uint8_t)(&c[0]), - (*C.idx_t)(&xids[0]), - ); c != 0 { - return getLastError() - } - } - } - - return nil -} - -func (idx *faissIndex) Search(x []float32, k int64) ( - distances []float32, labels []int64, err error, -) { - n := len(x) / idx.D() - distances = make([]float32, int64(n)*k) - labels = make([]int64, int64(n)*k) - if c := C.faiss_Index_search( - idx.idx, - C.idx_t(n), - (*C.float)(&x[0]), - C.idx_t(k), - (*C.float)(&distances[0]), - (*C.idx_t)(&labels[0]), - ); c != 0 { - err = getLastError() +func (idx *IndexImpl) IsIVFIndex() bool { + if ivfIdx := C.faiss_IndexIVF_cast(idx.cPtrFloat()); ivfIdx == nil { + return false } - - return + return true } -func (idx *faissIndex) SearchWithoutIDs(x []float32, k int64, exclude []int64, params json.RawMessage) ( - distances []float32, labels []int64, err error, -) { - if params == nil && len(exclude) == 0 { - return idx.Search(x, k) - } +// SearchWithIDs performs a search with ID filtering and search parameters +func (idx *IndexImpl) SearchWithIDs(queries []float32, k int64, include []int64, params json.RawMessage) ([]float32, []int64, error) { + nq := len(queries) / idx.d + distances := make([]float32, int64(nq)*k) + labels := make([]int64, int64(nq)*k) - var selector *C.FaissIDSelector - if len(exclude) > 0 { - excludeSelector, err := NewIDSelectorNot(exclude) - if err != nil { - return nil, nil, err - } - selector = excludeSelector.Get() - defer excludeSelector.Delete() + includeSelector, err := NewIDSelectorBatch(include) + if err != nil { + return nil, nil, err } + defer includeSelector.Delete() - searchParams, err := NewSearchParams(idx, params, selector, nil) + searchParams, err := NewSearchParams(nil, params, includeSelector.Get(), nil) if err != nil { return nil, nil, err } defer searchParams.Delete() - distances, labels, err = idx.searchWithParams(x, k, searchParams.sp) - - return -} - -func (idx *faissIndex) SearchBinary(x []uint8, k int64) (distances []int32, - labels []int64, err error, -) { - d := idx.D() - nq := (len(x) * 8) / d - - distances = make([]int32, int64(nq)*k) - labels = make([]int64, int64(nq)*k) - - if c := C.faiss_IndexBinary_search( - idx.idxBinary, + if c := C.faiss_Index_search_with_params( + idx.indexPtr, C.idx_t(nq), - (*C.uint8_t)(&x[0]), + (*C.float)(&queries[0]), C.idx_t(k), - (*C.int32_t)(&distances[0]), + searchParams.sp, + (*C.float)(&distances[0]), (*C.idx_t)(&labels[0]), ); c != 0 { - err = getLastError() + return nil, nil, getLastError() } - return distances, labels, nil } -func (idx *faissIndex) SearchBinaryWithIDs(x []uint8, k int64, include []int64, - params json.RawMessage) (distances []int32, labels []int64, err error, -) { - d := idx.D() - nq := (len(x) * 8) / d - - distances = make([]int32, int64(nq)*k) - labels = make([]int64, int64(nq)*k) - - var selector *C.FaissIDSelector - if len(include) > 0 { - includeSelector, err := NewIDSelectorBatch(include) - if err != nil { - return nil, nil, err - } - selector = includeSelector.Get() - defer includeSelector.Delete() - } - - searchParams, err := NewSearchParams(idx, params, selector, nil) - if err != nil { - return nil, nil, err - } - defer searchParams.Delete() - - if c := C.faiss_IndexBinary_search_with_params( - idx.idxBinary, - C.idx_t(nq), - (*C.uint8_t)(&x[0]), +func (idx *IndexImpl) Search(x []float32, k int64) (distances []float32, labels []int64, err error) { + n := len(x) / idx.D() + distances = make([]float32, int64(n)*k) + labels = make([]int64, int64(n)*k) + if c := C.faiss_Index_search( + idx.indexPtr, + C.idx_t(n), + (*C.float)(&x[0]), C.idx_t(k), - searchParams.sp, - (*C.int32_t)(&distances[0]), + (*C.float)(&distances[0]), (*C.idx_t)(&labels[0]), ); c != 0 { err = getLastError() } - return distances, labels, nil + return distances, labels, err } -func (idx *faissIndex) SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64, params json.RawMessage) (distances []int32, - labels []int64, err error) { - if len(exclude) == 0 && params == nil { - return idx.SearchBinary(x, k) - } +func (idx *IndexImpl) Ntotal() int64 { + return int64(C.faiss_Index_ntotal(idx.indexPtr)) +} - d := idx.D() - nq := (len(x) * 8) / d +// SearchWithoutIDs performs a search without ID filtering +func (idx *IndexImpl) SearchWithoutIDs(x []float32, k int64, exclude []int64, params json.RawMessage) ( + []float32, []int64, error) { + if params == nil && len(exclude) == 0 { + return idx.Search(x, k) + } - distances = make([]int32, int64(nq)*k) - labels = make([]int64, int64(nq)*k) + nq := len(x) / idx.d + distances := make([]float32, int64(nq)*k) + labels := make([]int64, int64(nq)*k) var selector *C.FaissIDSelector if len(exclude) > 0 { @@ -493,93 +538,86 @@ func (idx *faissIndex) SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int6 } defer searchParams.Delete() - if c := C.faiss_IndexBinary_search_with_params( - idx.idxBinary, - C.idx_t(nq), - (*C.uint8_t)(&x[0]), - C.idx_t(k), - searchParams.sp, - (*C.int32_t)(&distances[0]), - (*C.idx_t)(&labels[0]), - ); c != 0 { - err = getLastError() - } + distances, labels, err = idx.searchWithParams(x, k, searchParams.sp) return distances, labels, err } -func (idx *faissIndex) SearchWithIDs(x []float32, k int64, include []int64, - params json.RawMessage) (distances []float32, labels []int64, err error, -) { - includeSelector, err := NewIDSelectorBatch(include) - if err != nil { - return nil, nil, err - } - defer includeSelector.Delete() - - searchParams, err := NewSearchParams(idx, params, includeSelector.Get(), nil) - if err != nil { - return nil, nil, err - } - defer searchParams.Delete() +// ----------------------------------------------------------------------------- - distances, labels, err = idx.searchWithParams(x, k, searchParams.sp) - return +// RangeSearchResult is the result of a range search. +type RangeSearchResult struct { + rsr *C.FaissRangeSearchResult } -func (idx *faissIndex) Reconstruct(key int64) (recons []float32, err error) { - rv := make([]float32, idx.D()) - if c := C.faiss_Index_reconstruct( - idx.idx, - C.idx_t(key), - (*C.float)(&rv[0]), - ); c != 0 { - err = getLastError() - } +// Nq returns the number of queries. +func (r *RangeSearchResult) Nq() int { + return int(C.faiss_RangeSearchResult_nq(r.rsr)) +} - return rv, err +// Lims returns a slice containing start and end indices for queries in the +// distances and labels slices returned by Labels. +func (r *RangeSearchResult) Lims() []int { + var lims *C.size_t + C.faiss_RangeSearchResult_lims(r.rsr, &lims) + length := r.Nq() + 1 + return (*[1 << 30]int)(unsafe.Pointer(lims))[:length:length] } -func (idx *faissIndex) ReconstructBatch(keys []int64, recons []float32) ([]float32, error) { - var err error - n := int64(len(keys)) - if c := C.faiss_Index_reconstruct_batch( - idx.idx, - C.idx_t(n), - (*C.idx_t)(&keys[0]), - (*C.float)(&recons[0]), - ); c != 0 { - err = getLastError() - } +// Labels returns the unsorted IDs and respective distances for each query. +// The result for query i is labels[lims[i]:lims[i+1]]. +func (r *RangeSearchResult) Labels() (labels []int64, distances []float32) { + lims := r.Lims() + length := lims[len(lims)-1] + var clabels *C.idx_t + var cdist *C.float + C.faiss_RangeSearchResult_labels(r.rsr, &clabels, &cdist) + labels = (*[1 << 30]int64)(unsafe.Pointer(clabels))[:length:length] + distances = (*[1 << 30]float32)(unsafe.Pointer(cdist))[:length:length] + return +} - return recons, err +// Delete frees the memory associated with r. +func (r *RangeSearchResult) Delete() { + C.faiss_RangeSearchResult_free(r.rsr) } -func (i *IndexImpl) MergeFrom(other Index, add_id int64) error { - if impl, ok := other.(*IndexImpl); ok { - return i.Index.MergeFrom(impl.Index, add_id) +// IndexFactory creates a new index using the factory function +func IndexFactory(d int, description string, metric int) (FloatIndex, error) { + var cDescription *C.char + if description != "" { + cDescription = C.CString(description) + defer C.free(unsafe.Pointer(cDescription)) } - return fmt.Errorf("merge not support") -} -func (idx *faissIndex) MergeFrom(other Index, add_id int64) (err error) { - otherIdx, ok := other.(*faissIndex) - if !ok { - return fmt.Errorf("merge api not supported") + var idx *C.FaissIndex + if c := C.faiss_index_factory(&idx, C.int(d), cDescription, C.FaissMetricType(metric)); c != 0 { + return nil, getLastError() } - if c := C.faiss_Index_merge_from( - idx.idx, - otherIdx.idx, - (C.idx_t)(add_id), - ); c != 0 { - err = getLastError() + return &IndexImpl{ + indexPtr: idx, + d: d, + metric: metric, + }, nil +} + +func (idx *IndexImpl) Close() { + if idx.indexPtr != nil { + C.faiss_Index_free(idx.indexPtr) + idx.indexPtr = nil } +} - return err +func (idx *IndexImpl) D() int { + return idx.d } -func (idx *faissIndex) RangeSearch(x []float32, radius float32) ( +func (idx *IndexImpl) MetricType() int { + return idx.metric +} + +func (idx *IndexImpl) RangeSearch(x []float32, radius float32) ( *RangeSearchResult, error, ) { n := len(x) / idx.D() @@ -588,7 +626,7 @@ func (idx *faissIndex) RangeSearch(x []float32, radius float32) ( return nil, getLastError() } if c := C.faiss_Index_range_search( - idx.idx, + idx.indexPtr, C.idx_t(n), (*C.float)(&x[0]), C.float(radius), @@ -599,114 +637,75 @@ func (idx *faissIndex) RangeSearch(x []float32, radius float32) ( return &RangeSearchResult{rsr}, nil } -func (idx *faissIndex) Reset() error { - if c := C.faiss_Index_reset(idx.idx); c != 0 { +func (idx *IndexImpl) Reset() error { + if c := C.faiss_Index_reset(idx.indexPtr); c != 0 { return getLastError() } return nil } -func (idx *faissIndex) RemoveIDs(sel *IDSelector) (int, error) { +func (idx *IndexImpl) RemoveIDs(sel *IDSelector) (int, error) { var nRemoved C.size_t - if c := C.faiss_Index_remove_ids(idx.idx, sel.sel, &nRemoved); c != 0 { + if c := C.faiss_Index_remove_ids(idx.indexPtr, sel.sel, &nRemoved); c != 0 { return 0, getLastError() } return int(nRemoved), nil } -func (idx *faissIndex) Close() { - C.faiss_Index_free(idx.idx) - C.faiss_IndexBinary_free(idx.idxBinary) -} - -func (idx *faissIndex) searchWithParams(x []float32, k int64, searchParams *C.FaissSearchParameters) ( - distances []float32, labels []int64, err error, -) { - n := len(x) / idx.D() - distances = make([]float32, int64(n)*k) - labels = make([]int64, int64(n)*k) - - if c := C.faiss_Index_search_with_params( - idx.idx, - C.idx_t(n), - (*C.float)(&x[0]), - C.idx_t(k), - searchParams, - (*C.float)(&distances[0]), - (*C.idx_t)(&labels[0]), - ); c != 0 { - err = getLastError() +func (idx *IndexImpl) MergeFrom(other IndexImpl, add_id int64) error { + if c := C.faiss_Index_merge_from(idx.indexPtr, other.cPtrFloat(), C.idx_t(add_id)); c != 0 { + return getLastError() } - - return -} - -// ----------------------------------------------------------------------------- - -// RangeSearchResult is the result of a range search. -type RangeSearchResult struct { - rsr *C.FaissRangeSearchResult + return nil } -// Nq returns the number of queries. -func (r *RangeSearchResult) Nq() int { - return int(C.faiss_RangeSearchResult_nq(r.rsr)) +// Float-specific operations +func (idx *IndexImpl) Add(vectors []float32) error { + n := len(vectors) / idx.d + if c := C.faiss_Index_add(idx.indexPtr, C.idx_t(n), (*C.float)(&vectors[0])); c != 0 { + return getLastError() + } + return nil } -// Lims returns a slice containing start and end indices for queries in the -// distances and labels slices returned by Labels. -func (r *RangeSearchResult) Lims() []int { - var lims *C.size_t - C.faiss_RangeSearchResult_lims(r.rsr, &lims) - length := r.Nq() + 1 - return (*[1 << 30]int)(unsafe.Pointer(lims))[:length:length] +func (idx *IndexImpl) cPtrFloat() *C.FaissIndex { + return idx.indexPtr } -// Labels returns the unsorted IDs and respective distances for each query. -// The result for query i is labels[lims[i]:lims[i+1]]. -func (r *RangeSearchResult) Labels() (labels []int64, distances []float32) { - lims := r.Lims() - length := lims[len(lims)-1] - var clabels *C.idx_t - var cdist *C.float - C.faiss_RangeSearchResult_labels(r.rsr, &clabels, &cdist) - labels = (*[1 << 30]int64)(unsafe.Pointer(clabels))[:length:length] - distances = (*[1 << 30]float32)(unsafe.Pointer(cdist))[:length:length] - return +func (idx *IndexImpl) AddWithIDs(vectors []float32, xids []int64) error { + n := len(vectors) / idx.d + if c := C.faiss_Index_add_with_ids(idx.indexPtr, C.idx_t(n), (*C.float)(&vectors[0]), (*C.idx_t)(&xids[0])); c != 0 { + return getLastError() + } + return nil } -// Delete frees the memory associated with r. -func (r *RangeSearchResult) Delete() { - C.faiss_RangeSearchResult_free(r.rsr) -} +func (idx *IndexImpl) Reconstruct(key int64) (recons []float32, err error) { + rv := make([]float32, idx.D()) + if c := C.faiss_Index_reconstruct( + idx.indexPtr, + C.idx_t(key), + (*C.float)(&rv[0]), + ); c != 0 { + err = getLastError() + } -// IndexImpl is an abstract structure for an index. -type IndexImpl struct { - Index + return rv, err } -// IndexFactory builds a composite index. -// description is a comma-separated list of components. -func IndexFactory(d int, description string, metric int) (*IndexImpl, error) { - cdesc := C.CString(description) - defer C.free(unsafe.Pointer(cdesc)) - var idx faissIndex - c := C.faiss_index_factory(&idx.idx, C.int(d), cdesc, C.FaissMetricType(metric)) - if c != 0 { - return nil, getLastError() +func (idx *IndexImpl) ReconstructBatch(keys []int64, recons []float32) ([]float32, error) { + var err error + n := int64(len(keys)) + if c := C.faiss_Index_reconstruct_batch( + idx.indexPtr, + C.idx_t(n), + (*C.idx_t)(&keys[0]), + (*C.float)(&recons[0]), + ); c != 0 { + err = getLastError() } - return &IndexImpl{&idx}, nil -} -func IndexBinaryFactory(d int, description string, metric int) (*IndexImpl, error) { - cdesc := C.CString(description) - defer C.free(unsafe.Pointer(cdesc)) - var idx faissIndex - c := C.faiss_index_binary_factory(&idx.idxBinary, C.int(d), cdesc) - if c != 0 { - return nil, getLastError() - } - return &IndexImpl{&idx}, nil + return recons, err } func SetOMPThreads(n uint) { diff --git a/index_flat.go b/index_flat.go index a97d6f8..27bf0b5 100644 --- a/index_flat.go +++ b/index_flat.go @@ -10,20 +10,20 @@ import "unsafe" // IndexFlat is an index that stores the full vectors and performs exhaustive // search. type IndexFlat struct { - Index + *IndexImpl } // NewIndexFlat creates a new flat index. func NewIndexFlat(d int, metric int) (*IndexFlat, error) { - var idx faissIndex + var idx *C.FaissIndex if c := C.faiss_IndexFlat_new_with( - &idx.idx, + &idx, C.idx_t(d), C.FaissMetricType(metric), ); c != 0 { return nil, getLastError() } - return &IndexFlat{&idx}, nil + return &IndexFlat{&IndexImpl{indexPtr: idx, d: d, metric: metric}}, nil } // NewIndexFlatIP creates a new flat index with the inner product metric type. @@ -41,16 +41,16 @@ func NewIndexFlatL2(d int) (*IndexFlat, error) { func (idx *IndexFlat) Xb() []float32 { var size C.size_t var ptr *C.float - C.faiss_IndexFlat_xb(idx.cPtr(), &ptr, &size) + C.faiss_IndexFlat_xb(idx.cPtrFloat(), &ptr, &size) return (*[1 << 30]float32)(unsafe.Pointer(ptr))[:size:size] } // AsFlat casts idx to a flat index. // AsFlat panics if idx is not a flat index. func (idx *IndexImpl) AsFlat() *IndexFlat { - ptr := C.faiss_IndexFlat_cast(idx.cPtr()) + ptr := C.faiss_IndexFlat_cast(idx.cPtrFloat()) if ptr == nil { panic("index is not a flat index") } - return &IndexFlat{&faissIndex{idx: ptr}} + return &IndexFlat{&IndexImpl{indexPtr: ptr, d: idx.d, metric: idx.metric}} } diff --git a/index_io.go b/index_io.go index 88425c1..3ac50c3 100644 --- a/index_io.go +++ b/index_io.go @@ -5,29 +5,89 @@ package faiss #include #include #include +#include +#include */ import "C" import ( + "fmt" "unsafe" ) -// WriteIndex writes an index to a file. -func WriteIndex(idx Index, filename string) error { +const ( + IOFlagMmap = C.FAISS_IO_FLAG_MMAP + IOFlagReadOnly = C.FAISS_IO_FLAG_READ_ONLY + IOFlagReadMmap = C.FAISS_IO_FLAG_READ_MMAP | C.FAISS_IO_FLAG_ONDISK_IVF + IOFlagSkipPrefetch = C.FAISS_IO_FLAG_SKIP_PREFETCH +) + +// WriteIndex writes a float index to a file +func WriteIndex(idx FloatIndex, filename string) error { + impl, ok := idx.(*IndexImpl) + if !ok { + return fmt.Errorf("invalid index type for float index serialization") + } + + cfname := C.CString(filename) + defer C.free(unsafe.Pointer(cfname)) + if c := C.faiss_write_index_fname(impl.cPtrFloat(), cfname); c != 0 { + return getLastError() + } + return nil +} + +// WriteBinaryIndex writes a binary index to a file +func WriteBinaryIndex(idx BinaryIndex, filename string) error { + impl, ok := idx.(*BinaryIndexImpl) + if !ok { + return fmt.Errorf("invalid index type for binary index serialization") + } + cfname := C.CString(filename) defer C.free(unsafe.Pointer(cfname)) - if c := C.faiss_write_index_fname(idx.cPtr(), cfname); c != 0 { + if c := C.faiss_write_index_binary_fname(impl.cPtrBinary(), cfname); c != 0 { return getLastError() } return nil } -func WriteIndexIntoBuffer(idx Index) ([]byte, error) { +// ReadIndex reads a float index from a file +func ReadIndex(filename string, ioflags int) (FloatIndex, error) { + cfname := C.CString(filename) + defer C.free(unsafe.Pointer(cfname)) + var idx *C.FaissIndex + if c := C.faiss_read_index_fname(cfname, C.int(ioflags), &idx); c != 0 { + return nil, getLastError() + } + return &IndexImpl{ + indexPtr: idx, + d: int(C.faiss_Index_d(idx)), + metric: int(C.faiss_Index_metric_type(idx)), + }, nil +} + +// ReadBinaryIndex reads a binary index from a file +func ReadBinaryIndex(filename string, ioflags int) (BinaryIndex, error) { + cfname := C.CString(filename) + defer C.free(unsafe.Pointer(cfname)) + var idx *C.FaissIndexBinary + if c := C.faiss_read_index_binary_fname(cfname, C.int(ioflags), &idx); c != 0 { + return nil, getLastError() + } + return &BinaryIndexImpl{ + indexPtr: idx, + d: int(C.faiss_IndexBinary_d(idx)), + metric: int(C.faiss_IndexBinary_metric_type(idx)), + }, nil +} + +func WriteIndexIntoBuffer(idx FloatIndex) ([]byte, error) { // the values to be returned by the faiss APIs tempBuf := (*C.uchar)(nil) bufSize := C.size_t(0) if c := C.faiss_write_index_buf( - idx.cPtr(), + idx.cPtrFloat(), &bufSize, &tempBuf, ); c != 0 { @@ -79,7 +139,7 @@ func WriteIndexIntoBuffer(idx Index) ([]byte, error) { return rv, nil } -func WriteBinaryIndexIntoBuffer(idx Index) ([]byte, error) { +func WriteBinaryIndexIntoBuffer(idx BinaryIndex) ([]byte, error) { // the values to be returned by the faiss APIs tempBuf := (*C.uchar)(nil) bufSize := C.size_t(0) @@ -137,64 +197,48 @@ func WriteBinaryIndexIntoBuffer(idx Index) ([]byte, error) { return rv, nil } -func ReadIndexFromBuffer(buf []byte, ioflags int) (*IndexImpl, error) { +// ReadIndexFromBuffer deserializes a float index from a byte buffer +func ReadIndexFromBuffer(buf []byte, ioFlags int) (FloatIndex, error) { ptr := (*C.uchar)(unsafe.Pointer(&buf[0])) size := C.size_t(len(buf)) // the idx var has C.FaissIndex within the struct which is nil as of now. - var idx faissIndex + var idx *C.FaissIndex if c := C.faiss_read_index_buf(ptr, size, - C.int(ioflags), - &idx.idx); c != 0 { + C.int(ioFlags), + &idx); c != 0 { return nil, getLastError() } ptr = nil - // after exiting the faiss_read_index_buf, the ref count to the memory allocated - // for the freshly created faiss::index becomes 1 (held by idx.idx of type C.FaissIndex) - // this is allocated on the C heap, so not available for golang's GC. hence needs - // to be cleaned up after the index is longer being used - to be done at zap layer. - return &IndexImpl{&idx}, nil + return &IndexImpl{ + indexPtr: idx, + d: int(C.faiss_Index_d(idx)), + metric: int(C.faiss_Index_metric_type(idx)), + }, nil } -func ReadBinaryIndexFromBuffer(buf []byte, ioflags int) (*IndexImpl, error) { +// ReadBinaryIndexFromBuffer deserializes a binary index from a byte buffer +func ReadBinaryIndexFromBuffer(buf []byte, ioFlags int) (BinaryIndex, error) { ptr := (*C.uchar)(unsafe.Pointer(&buf[0])) size := C.size_t(len(buf)) // the idx var has C.FaissIndex within the struct which is nil as of now. - var idxBinary faissIndex + var idxBinary *C.FaissIndexBinary if c := C.faiss_read_index_binary_buf(ptr, size, - C.int(ioflags), - &idxBinary.idxBinary); c != 0 { + C.int(ioFlags), + &idxBinary); c != 0 { return nil, getLastError() } ptr = nil - // after exiting the faiss_read_index_buf, the ref count to the memory allocated - // for the freshly created faiss::index becomes 1 (held by idx.idx of type C.FaissIndex) - // this is allocated on the C heap, so not available for golang's GC. hence needs - // to be cleaned up after the index is longer being used - to be done at zap layer. - return &IndexImpl{&idxBinary}, nil -} - -const ( - IOFlagMmap = C.FAISS_IO_FLAG_MMAP - IOFlagReadOnly = C.FAISS_IO_FLAG_READ_ONLY - IOFlagReadMmap = C.FAISS_IO_FLAG_READ_MMAP | C.FAISS_IO_FLAG_ONDISK_IVF - IOFlagSkipPrefetch = C.FAISS_IO_FLAG_SKIP_PREFETCH -) - -// ReadIndex reads an index from a file. -func ReadIndex(filename string, ioflags int) (*IndexImpl, error) { - cfname := C.CString(filename) - defer C.free(unsafe.Pointer(cfname)) - var idx faissIndex - if c := C.faiss_read_index_fname(cfname, C.int(ioflags), &idx.idx); c != 0 { - return nil, getLastError() - } - return &IndexImpl{&idx}, nil + return &BinaryIndexImpl{ + indexPtr: idxBinary, + d: int(C.faiss_IndexBinary_d(idxBinary)), + metric: int(C.faiss_IndexBinary_metric_type(idxBinary)), + }, nil } diff --git a/index_ivf.go b/index_ivf.go index e063b6d..bc1fcfb 100644 --- a/index_ivf.go +++ b/index_ivf.go @@ -13,15 +13,33 @@ import ( "fmt" ) -func (idx *IndexImpl) SetDirectMap(mapType int) (err error) { - // Try to get either regular or binary IVF pointer - ivfPtr := C.faiss_IndexIVF_cast(idx.cPtr()) +// IndexIVF represents an IVF index +type IndexIVF struct { + *IndexImpl +} + +func (idx *IndexImpl) GetNProbe() int32 { + ivfPtr := C.faiss_IndexIVF_cast(idx.cPtrFloat()) + if ivfPtr == nil { + return 0 + } + return int32(C.faiss_IndexIVF_nprobe(ivfPtr)) +} + +func (idx *BinaryIndexImpl) GetNProbe() int32 { ivfPtrBinary := C.faiss_IndexBinaryIVF_cast(idx.cPtrBinary()) + if ivfPtrBinary == nil { + return 0 + } + return int32(C.faiss_IndexBinaryIVF_nprobe(ivfPtrBinary)) +} - // If we have a regular IVF index - if ivfPtr != nil { - if c := C.faiss_IndexIVF_set_direct_map( - ivfPtr, +func (idx *BinaryIndexImpl) SetDirectMap(mapType int) (err error) { + ivfPtrBinary := C.faiss_IndexBinaryIVF_cast(idx.cPtrBinary()) + // If we have a binary IVF index + if ivfPtrBinary != nil { + if c := C.faiss_IndexBinaryIVF_set_direct_map( + ivfPtrBinary, C.int(mapType), ); c != 0 { err = getLastError() @@ -29,10 +47,17 @@ func (idx *IndexImpl) SetDirectMap(mapType int) (err error) { return err } - // If we have a binary IVF index - if ivfPtrBinary != nil { - if c := C.faiss_IndexBinaryIVF_set_direct_map( - ivfPtrBinary, + return fmt.Errorf("unable to set direct map") +} + +func (idx *IndexImpl) SetDirectMap(mapType int) (err error) { + // Try to get either regular or binary IVF pointer + ivfPtr := C.faiss_IndexIVF_cast(idx.cPtrFloat()) + + // If we have a regular IVF index + if ivfPtr != nil { + if c := C.faiss_IndexIVF_set_direct_map( + ivfPtr, C.int(mapType), ); c != 0 { err = getLastError() @@ -45,8 +70,7 @@ func (idx *IndexImpl) SetDirectMap(mapType int) (err error) { } func (idx *IndexImpl) GetSubIndex() (*IndexImpl, error) { - - ptr := C.faiss_IndexIDMap2_cast(idx.cPtr()) + ptr := C.faiss_IndexIDMap2_cast(idx.indexPtr) if ptr == nil { return nil, fmt.Errorf("index is not a id map") } @@ -56,37 +80,23 @@ func (idx *IndexImpl) GetSubIndex() (*IndexImpl, error) { return nil, fmt.Errorf("couldn't retrieve the sub index") } - return &IndexImpl{&faissIndex{idx: subIdx}}, nil + return &IndexImpl{indexPtr: subIdx}, nil } -// pass nprobe to be set as index time option for IVF/BIVF indexes only. -// varying nprobe impacts recall but with an increase in latency. -func (idx *IndexImpl) SetNProbe(nprobe int32) error { - ivfPtr := C.faiss_IndexIVF_cast(idx.cPtr()) - if ivfPtr != nil { - C.faiss_IndexIVF_set_nprobe(ivfPtr, C.size_t(nprobe)) - return nil - } - +func (idx *BinaryIndexImpl) SetNProbe(nprobe int32) { ivfPtrBinary := C.faiss_IndexBinaryIVF_cast(idx.cPtrBinary()) - if ivfPtrBinary != nil { - C.faiss_IndexBinaryIVF_set_nprobe(ivfPtrBinary, C.size_t(nprobe)) - return nil + if ivfPtrBinary == nil { + return } - - return fmt.Errorf("unable to get nprobe") + C.faiss_IndexBinaryIVF_set_nprobe(ivfPtrBinary, C.size_t(nprobe)) } -func (idx *IndexImpl) GetNProbe() int32 { - ivfPtr := C.faiss_IndexIVF_cast(idx.cPtr()) - if ivfPtr != nil { - return int32(C.faiss_IndexIVF_nprobe(ivfPtr)) - } - - ivfPtrBinary := C.faiss_IndexBinaryIVF_cast(idx.cPtrBinary()) - if ivfPtrBinary != nil { - return int32(C.faiss_IndexBinaryIVF_nprobe(ivfPtrBinary)) +// pass nprobe to be set as index time option for IVF/BIVF indexes only. +// varying nprobe impacts recall but with an increase in latency. +func (idx *IndexImpl) SetNProbe(nprobe int32) { + ivfPtr := C.faiss_IndexIVF_cast(idx.cPtrFloat()) + if ivfPtr == nil { + return } - - return 0 + C.faiss_IndexIVF_set_nprobe(ivfPtr, C.size_t(nprobe)) } diff --git a/search_params.go b/search_params.go index f210f7b..a8ccb07 100644 --- a/search_params.go +++ b/search_params.go @@ -12,16 +12,18 @@ import ( "fmt" ) +// SearchParams represents search parameters for both float and binary indexes type SearchParams struct { - sp *C.FaissSearchParameters + sp *C.FaissSearchParameters + idx Index } -// Delete frees the memory associated with s. -func (s *SearchParams) Delete() { - if s == nil || s.sp == nil { - return +// Delete frees the search parameters +func (sp *SearchParams) Delete() { + if sp.sp != nil { + C.faiss_SearchParameters_free(sp.sp) + sp.sp = nil } - C.faiss_SearchParameters_free(s.sp) } type searchParamsIVF struct { @@ -74,15 +76,19 @@ func NewSearchParams(idx Index, params json.RawMessage, sel *C.FaissIDSelector, rv.sp = C.faiss_SearchParametersIVF_cast(rv.sp) - // check if the index is IVF and set the search params - if ivfIdx := C.faiss_IndexIVF_cast(idx.cPtr()); ivfIdx != nil { - nlist = int(C.faiss_IndexIVF_nlist(ivfIdx)) - nprobe = int(C.faiss_IndexIVF_nprobe(ivfIdx)) - nvecs = int(C.faiss_Index_ntotal(idx.cPtr())) - } else if bivfIdx := C.faiss_IndexBinaryIVF_cast(idx.cPtrBinary()); bivfIdx != nil { - nlist = int(C.faiss_IndexBinaryIVF_nlist(bivfIdx)) - nprobe = int(C.faiss_IndexBinaryIVF_nprobe(bivfIdx)) - nvecs = int(C.faiss_IndexBinary_ntotal(idx.cPtrBinary())) + switch idx.(type) { + case FloatIndex: + ivfIdx := idx.(*IndexImpl) + nlist = int(C.faiss_IndexIVF_nlist(ivfIdx.cPtrFloat())) + nprobe = int(C.faiss_IndexIVF_nprobe(ivfIdx.cPtrFloat())) + nvecs = int(C.faiss_Index_ntotal(ivfIdx.cPtrFloat())) + case BinaryIndex: + ivfIdx := idx.(*BinaryIndexImpl) + nlist = int(C.faiss_IndexBinaryIVF_nlist(ivfIdx.cPtrBinary())) + nprobe = int(C.faiss_IndexBinaryIVF_nprobe(ivfIdx.cPtrBinary())) + nvecs = int(C.faiss_IndexBinary_ntotal(ivfIdx.cPtrBinary())) + default: + return nil, fmt.Errorf("unsupported index type") } if defaultParams != nil { @@ -112,19 +118,14 @@ func NewSearchParams(idx Index, params json.RawMessage, sel *C.FaissIDSelector, maxCodes = int(float32(nvecs) * (ivfParams.MaxCodesPct / 100)) } // else, maxCodes will be set to the default value of 0, which means no limit - ivfIdx := C.faiss_IndexIVF_cast(idx.cPtr()) - bivfIdx := C.faiss_IndexBinaryIVF_cast(idx.cPtrBinary()) - - if ivfIdx != nil || bivfIdx != nil { - if c := C.faiss_SearchParametersIVF_new_with( - &rv.sp, - sel, - C.size_t(nprobe), - C.size_t(maxCodes), - ); c != 0 { - rv.Delete() - return nil, fmt.Errorf("failed to create faiss IVF search params") - } + if c := C.faiss_SearchParametersIVF_new_with( + &rv.sp, + sel, + C.size_t(nprobe), + C.size_t(maxCodes), + ); c != 0 { + rv.Delete() + return nil, fmt.Errorf("failed to create faiss IVF search params") } return rv, nil } From b3fff7eba2dae3191c1571058602681c6ce46861 Mon Sep 17 00:00:00 2001 From: Aditi Ahuja Date: Thu, 12 Jun 2025 19:10:19 +0530 Subject: [PATCH 12/13] bivf pre-filtering utils --- index.go | 111 ++++++++++++++++++++++++++++++++++++++++++++++- index_ivf.go | 1 + search_params.go | 9 ++++ 3 files changed, 119 insertions(+), 2 deletions(-) diff --git a/index.go b/index.go index 19b61d3..d2c6862 100644 --- a/index.go +++ b/index.go @@ -7,6 +7,7 @@ package faiss #include #include #include +#include #include #include #include @@ -54,6 +55,14 @@ type BinaryIndex interface { SearchBinaryWithIDs(x []uint8, k int64, include []int64, params json.RawMessage) ([]int32, []int64, error) SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64, params json.RawMessage) (distances []int32, labels []int64, err error) + + ObtainClusterVectorCountsFromIVFIndex(vecIDs []int64) (map[int64]int64, error) + ObtainClustersWithDistancesFromIVFIndex(x []uint8, centroidIDs []int64) ( + []int64, []int32, error) + // Applicable only to IVF indexes: Search clusters whose IDs are in eligibleCentroidIDs + SearchClustersFromIVFIndex(selector Selector, eligibleCentroidIDs []int64, + minEligibleCentroids int, k int64, x []uint8, centroidDis []int32, + params json.RawMessage) ([]int32, []int64, error) } // FloatIndex defines methods specific to float-based FAISS indexes @@ -156,6 +165,61 @@ func (idx *BinaryIndexImpl) Close() { } } +func (idx *BinaryIndexImpl) ObtainClusterVectorCountsFromIVFIndex(vecIDs []int64) (map[int64]int64, error) { + if !idx.IsIVFIndex() { + return nil, fmt.Errorf("index is not an IVF index") + } + clusterIDs := make([]int64, len(vecIDs)) + if c := C.faiss_get_lists_for_keys_binary( + idx.indexPtr, + (*C.idx_t)(unsafe.Pointer(&vecIDs[0])), + (C.size_t)(len(vecIDs)), + (*C.idx_t)(unsafe.Pointer(&clusterIDs[0])), + ); c != 0 { + return nil, getLastError() + } + rv := make(map[int64]int64, len(vecIDs)) + for _, v := range clusterIDs { + rv[v]++ + } + return rv, nil +} + +func (idx *BinaryIndexImpl) ObtainClustersWithDistancesFromIVFIndex(x []uint8, centroidIDs []int64) ( + []int64, []int32, error) { + // Selector to include only the centroids whose IDs are part of 'centroidIDs'. + includeSelector, err := NewIDSelectorBatch(centroidIDs) + if err != nil { + return nil, nil, err + } + defer includeSelector.Delete() + + params, err := NewSearchParams(idx, json.RawMessage{}, includeSelector.Get(), nil) + if err != nil { + return nil, nil, err + } + defer params.Delete() + + // Populate these with the centroids and their distances. + centroidDistances := make([]int32, len(centroidIDs)) + + n := len(x) / idx.D() + + c := C.faiss_Search_closest_eligible_centroids_binary( + idx.indexPtr, + (C.idx_t)(n), + (*C.uint8_t)(&x[0]), + (C.idx_t)(len(centroidIDs)), + (*C.int32_t)(¢roidDistances[0]), + (*C.idx_t)(¢roidIDs[0]), + params.sp) + if c != 0 { + return nil, nil, getLastError() + } + + return centroidIDs, centroidDistances, nil +} + func (idx *BinaryIndexImpl) Size() uint64 { return 0 } @@ -263,7 +327,7 @@ func (idx *BinaryIndexImpl) Train(vectors []uint8) error { } func (idx *BinaryIndexImpl) SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64, params json.RawMessage) (distances []int32, labels []int64, err error) { - if len(exclude) == 0 && params == nil { + if len(exclude) == 0 && len(params) == 0 { return idx.SearchBinary(x, k) } @@ -302,6 +366,49 @@ func (idx *BinaryIndexImpl) SearchBinaryWithoutIDs(x []uint8, k int64, exclude [ return distances, labels, err } +func (idx *BinaryIndexImpl) SearchClustersFromIVFIndex(selector Selector, + eligibleCentroidIDs []int64, minEligibleCentroids int, k int64, x []uint8, + centroidDis []int32, params json.RawMessage) ([]int32, []int64, error) { + tempParams := &defaultSearchParamsIVF{ + Nlist: len(eligibleCentroidIDs), + // Have to override nprobe so that more clusters will be searched for this + // query, if required. + Nprobe: minEligibleCentroids, + } + + searchParams, err := NewSearchParams(idx, params, selector.Get(), tempParams) + if err != nil { + return nil, nil, err + } + defer searchParams.Delete() + + n := (len(x) * 8) / idx.D() + + distances := make([]int32, int64(n)*k) + labels := make([]int64, int64(n)*k) + + effectiveNprobe := getNProbeFromSearchParams(searchParams) + + eligibleCentroidIDs = eligibleCentroidIDs[:effectiveNprobe] + centroidDis = centroidDis[:effectiveNprobe] + + if c := C.faiss_IndexBinaryIVF_search_preassigned_with_params( + idx.indexPtr, + (C.idx_t)(n), + (*C.uint8_t)(&x[0]), + (C.idx_t)(k), + (*C.idx_t)(&eligibleCentroidIDs[0]), + (*C.int32_t)(¢roidDis[0]), + (*C.int32_t)(&distances[0]), + (*C.idx_t)(&labels[0]), + (C.int)(0), + searchParams.sp); c != 0 { + return nil, nil, getLastError() + } + + return distances, labels, nil +} + // Factory functions func IndexBinaryFactory(d int, description string, metric int) (BinaryIndex, error) { return NewBinaryIndexImpl(d, description, metric) @@ -469,7 +576,7 @@ func (idx *IndexImpl) SearchWithIDs(queries []float32, k int64, include []int64, } defer includeSelector.Delete() - searchParams, err := NewSearchParams(nil, params, includeSelector.Get(), nil) + searchParams, err := NewSearchParams(idx, params, includeSelector.Get(), nil) if err != nil { return nil, nil, err } diff --git a/index_ivf.go b/index_ivf.go index bc1fcfb..4ecea58 100644 --- a/index_ivf.go +++ b/index_ivf.go @@ -6,6 +6,7 @@ package faiss #include #include #include +#include #include */ import "C" diff --git a/search_params.go b/search_params.go index a8ccb07..2210a02 100644 --- a/search_params.go +++ b/search_params.go @@ -71,6 +71,15 @@ func NewSearchParams(idx Index, params json.RawMessage, sel *C.FaissIDSelector, return rv, nil } + if !idx.IsIVFIndex() { + c := C.faiss_SearchParameters_new_with_selector(&rv.sp, sel) + if c != 0 { + rv.Delete() + return nil, fmt.Errorf("failed to create faiss search params") + } + return rv, nil + } + var nlist, nprobe, nvecs, maxCodes int var ivfParams searchParamsIVF From 5774cace5e19e23446384bbf06c19a214c6237c9 Mon Sep 17 00:00:00 2001 From: Aditi Ahuja Date: Mon, 30 Jun 2025 00:06:02 +0530 Subject: [PATCH 13/13] utils to re-use centroids from ivf index --- index.go | 59 ++++++++++++++++++++++++++++++++++++++++++++++++ search_params.go | 1 + 2 files changed, 60 insertions(+) diff --git a/index.go b/index.go index d2c6862..9191f3c 100644 --- a/index.go +++ b/index.go @@ -37,6 +37,7 @@ type Index interface { IsIVFIndex() bool SetNProbe(nprobe int32) GetNProbe() int32 + GetNlist() int SetDirectMap(directMapType int) error Close() @@ -63,6 +64,9 @@ type BinaryIndex interface { SearchClustersFromIVFIndex(selector Selector, eligibleCentroidIDs []int64, minEligibleCentroids int, k int64, x []uint8, centroidDis []int32, params json.RawMessage) ([]int32, []int64, error) + + BinaryQuantizer() BinaryIndex + SetIsTrained(isTrained bool) } // FloatIndex defines methods specific to float-based FAISS indexes @@ -89,6 +93,8 @@ type FloatIndex interface { Reconstruct(key int64) (recons []float32, err error) ReconstructBatch(ids []int64, vectors []float32) ([]float32, error) + GetCentroids() ([]float32, error) + // Applicable only to IVF indexes: Returns a map where the keys // are cluster IDs and the values represent the count of input vectors that belong // to each cluster. @@ -121,6 +127,8 @@ type FloatIndex interface { // RemoveIDs removes the vectors specified by sel from the index. // Returns the number of elements removed and error. RemoveIDs(sel *IDSelector) (int, error) + + Quantizer() *C.FaissIndex } // IndexImpl represents a float vector index @@ -220,6 +228,50 @@ func (idx *BinaryIndexImpl) ObtainClustersWithDistancesFromIVFIndex(x []uint8, c return centroidIDs, centroidDistances, nil } +func (idx *IndexImpl) GetNlist() int { + if ivfIdx := C.faiss_IndexIVF_cast(idx.cPtrFloat()); ivfIdx != nil { + return int(C.faiss_IndexIVF_nlist(ivfIdx)) + } + return 0 +} + +func (idx *IndexImpl) GetCentroids() ([]float32, error) { + if ivfIdx := C.faiss_IndexIVF_cast(idx.cPtrFloat()); ivfIdx != nil { + ivfCentroids := make([]float32, idx.D()*idx.GetNlist()) + C.faiss_IndexIVF_get_centroids(ivfIdx, (*C.float)(&ivfCentroids[0])) + return ivfCentroids, nil + } + return nil, fmt.Errorf("index is not an IVF index") +} + +func (idx *IndexImpl) Quantizer() *C.FaissIndex { + if ivfIdx := C.faiss_IndexIVF_cast(idx.cPtrFloat()); ivfIdx != nil { + return C.faiss_IndexIVF_quantizer(ivfIdx) + } + return nil +} + +func (idx *BinaryIndexImpl) SetIsTrained(isTrained bool) { + if isTrained { + C.faiss_IndexBinaryIVF_set_is_trained((*C.FaissIndexBinaryIVF)(idx.cPtrBinary()), + C.int(1)) + } else { + C.faiss_IndexBinaryIVF_set_is_trained((*C.FaissIndexBinaryIVF)(idx.cPtrBinary()), + C.int(0)) + } +} + +func (idx *BinaryIndexImpl) BinaryQuantizer() BinaryIndex { + if bivfIdx := C.faiss_IndexBinaryIVF_cast(idx.cPtrBinary()); bivfIdx != nil { + return &BinaryIndexImpl{ + indexPtr: C.faiss_IndexBinaryIVF_quantizer(bivfIdx), + d: idx.d, + metric: idx.metric, + } + } + return nil +} + func (idx *BinaryIndexImpl) Size() uint64 { return 0 } @@ -244,6 +296,13 @@ func (idx *BinaryIndexImpl) IsIVFIndex() bool { return C.faiss_IndexBinaryIVF_cast(idx.indexPtr) != nil } +func (idx *BinaryIndexImpl) GetNlist() int { + if ivfIdx := C.faiss_IndexBinaryIVF_cast(idx.indexPtr); ivfIdx != nil { + return int(C.faiss_IndexBinaryIVF_nlist(ivfIdx)) + } + return 0 +} + // Binary-specific operations func (idx *BinaryIndexImpl) TrainBinary(vectors []uint8) error { n := (len(vectors) * 8) / idx.d diff --git a/search_params.go b/search_params.go index 2210a02..bb79888 100644 --- a/search_params.go +++ b/search_params.go @@ -4,6 +4,7 @@ package faiss #include #include #include +#include #include */ import "C"