Add load and store index to the bindings, update test recall
alonre24 committed Jul 25, 2021
1 parent d4c881d commit 079c71e
Showing 5 changed files with 64 additions and 79 deletions.
71 changes: 3 additions & 68 deletions examples/searchKnnCloserFirst_test.cpp
@@ -9,69 +9,12 @@

#include <vector>
#include <iostream>
#include <thread>

namespace
{

using idx_t = hnswlib::labeltype;

template<class Function>
inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) {
if (numThreads <= 0) {
numThreads = std::thread::hardware_concurrency();
}

if (numThreads == 1) {
for (size_t id = start; id < end; id++) {
fn(id, 0);
}
} else {
std::vector<std::thread> threads;
std::atomic<size_t> current(start);

// keep track of exceptions in threads
// https://stackoverflow.com/a/32428427/1713196
std::exception_ptr lastException = nullptr;
std::mutex lastExceptMutex;

for (size_t threadId = 0; threadId < numThreads; ++threadId) {
threads.push_back(std::thread([&, threadId] {
while (true) {
size_t id = current.fetch_add(1);

if ((id >= end)) {
break;
}

try {
fn(id, threadId);
} catch (...) {
std::unique_lock<std::mutex> lastExcepLock(lastExceptMutex);
lastException = std::current_exception();
/*
* This will work even when current is the largest value that
* size_t can fit, because fetch_add returns the previous value
* before the increment (what will result in overflow
* and produce 0 instead of current + 1).
*/
current = end;
break;
}
}
}));
}
for (auto &thread : threads) {
thread.join();
}
if (lastException) {
std::rethrow_exception(lastException);
}
}


}

void test() {
int d = 4;
idx_t n = 100;
@@ -97,18 +40,10 @@ void test() {
hnswlib::AlgorithmInterface<float>* alg_brute = new hnswlib::BruteforceSearch<float>(&space, 2 * n);
hnswlib::AlgorithmInterface<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, 2 * n);

-// for (size_t i = 0; i < n; ++i) {
-//     alg_brute->addPoint(data.data() + d * i, i);
-//     alg_hnsw->addPoint(data.data() + d * i, i);
-// }
-ParallelFor(0, n, 4, [&](size_t i, size_t threadId) {
-    alg_hnsw->addPoint(data.data() + d * i, i);
-});
-ParallelFor(0, n, 4, [&](size_t i, size_t threadId) {
-    alg_brute->addPoint(data.data() + d * i, i);
-});
+for (size_t i = 0; i < n; ++i) {
+    alg_brute->addPoint(data.data() + d * i, i);
+    alg_hnsw->addPoint(data.data() + d * i, i);
+}

// test searchKnnCloserFirst of BruteforceSearch
for (size_t j = 0; j < nq; ++j) {
3 changes: 2 additions & 1 deletion hnswlib/bruteforce.h
@@ -68,6 +68,8 @@ namespace hnswlib {
memcpy(data_ + size_per_element_ * idx, datapoint, data_size_);




};

void removePoint(labeltype cur_external) {
@@ -97,7 +99,6 @@ namespace hnswlib {
dist_t lastdist = topResults.top().first;
for (int i = k; i < cur_element_count; i++) {
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);

if (dist <= lastdist) {
topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i +
data_size_))));
Empty file added python_bindings/__init__.py
33 changes: 27 additions & 6 deletions python_bindings/bindings.cpp
@@ -743,21 +743,39 @@ class BFIndex {
throw std::runtime_error("wrong dimensionality of the labels");
}
{
-    int start = 0;
    py::gil_scoped_release l;

-    std::vector<float> norm_array(dim);
-    for (size_t i = start; i < rows; i++) {
-        alg->addPoint((void *) items.data(i), (size_t) i);
+    for (size_t row = 0; row < rows; row++) {
+        size_t id = ids.size() ? ids.at(row) : cur_l + row;
+        if (!normalize) {
+            alg->addPoint((void *) items.data(row), (size_t) id);
+        } else {
+            float normalized_vector[dim];
+            normalize_vector((float *) items.data(row), normalized_vector);
+            alg->addPoint((void *) normalized_vector, (size_t) id);
+        }
    }
+    cur_l += rows;
}
}

-void deletedVector(size_t label) {
+void deleteVector(size_t label) {
alg->removePoint(label);
}

void saveIndex(const std::string &path_to_index) {
alg->saveIndex(path_to_index);
}

void loadIndex(const std::string &path_to_index, size_t max_elements) {
if (alg) {
std::cerr << "Warning: Calling load_index for an already initialized index. Old index is being deallocated.";
delete alg;
}
alg = new hnswlib::BruteforceSearch<dist_t>(space, path_to_index);
cur_l = alg->cur_element_count;
index_inited = true;
}

py::object knnQuery_return_numpy(py::object input, size_t k = 1) {

py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input);
@@ -885,6 +903,9 @@ PYBIND11_PLUGIN(hnswlib) {
.def("init_index", &BFIndex<float>::init_new_index, py::arg("max_elements"))
.def("knn_query", &BFIndex<float>::knnQuery_return_numpy, py::arg("data"), py::arg("k")=1)
.def("add_items", &BFIndex<float>::addItems, py::arg("data"), py::arg("ids") = py::none())
.def("delete_vector", &BFIndex<float>::deleteVector, py::arg("label"))
.def("save_index", &BFIndex<float>::saveIndex, py::arg("path_to_index"))
.def("load_index", &BFIndex<float>::loadIndex, py::arg("path_to_index"), py::arg("max_elements")=0)
.def("__repr__", [](const BFIndex<float> &a) {
return "<hnswlib.BFIndex(space='" + a.space_name + "', dim="+std::to_string(a.dim)+")>";
});
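Taken together, the new bindings give the brute-force index a small persistence API. Below is a minimal sketch of how it might be exercised from Python, assuming the module is built as hnswlib; the dimension, data, and file name are illustrative assumptions, while the method names and arguments come from the .def registrations above.

import hnswlib
import numpy as np

dim = 16                                          # illustrative dimension
data = np.float32(np.random.random((1000, dim)))  # illustrative data

bf = hnswlib.BFIndex(space='l2', dim=dim)
bf.init_index(max_elements=2000)
bf.add_items(data, np.arange(len(data)))          # add with explicit ids

bf.delete_vector(5)                               # remove the vector labeled 5
bf.save_index('bf.bin')                           # persist the index to disk
del bf

bf = hnswlib.BFIndex(space='l2', dim=dim)
bf.load_index('bf.bin')                           # max_elements defaults to 0
labels, distances = bf.knn_query(data[:3], k=5)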
@@ -1,7 +1,7 @@
import hnswlib
import numpy as np

-dim = 128
+dim = 32
num_elements = 100000
k = 10
num_queries = 10
@@ -24,12 +24,12 @@
# M is tightly connected with the internal dimensionality of the data. It strongly affects memory consumption (~M).
# Higher M leads to higher accuracy/run_time at fixed ef/efConstruction.

-hnsw_index.init_index(max_elements=num_elements, ef_construction=10, M=6)
+hnsw_index.init_index(max_elements=num_elements, ef_construction=200, M=16)
bf_index.init_index(max_elements=num_elements)

# Controlling the recall for hnsw by setting ef:
# higher ef leads to better accuracy, but slower search
-hnsw_index.set_ef(10)
+hnsw_index.set_ef(200)

# Set the number of threads used during batch search/construction in hnsw
# By default, all available cores are used
@@ -42,7 +42,7 @@
print("Indices built")

# Generating query data
-query_data = np.float32(np.random.random((10, dim)))
+query_data = np.float32(np.random.random((num_queries, dim)))

# Query the elements and measure recall:
labels_hnsw, distances_hnsw = hnsw_index.knn_query(query_data, k)
@@ -58,3 +58,31 @@
break

print("recall is :", float(correct)/(k*nun_queries))

# Test serializing the brute-force index
index_path = 'bf_index.bin'
print("Saving index to '%s'" % index_path)
bf_index.save_index(index_path)
del bf_index

# Re-initializing and loading the index
bf_index = hnswlib.BFIndex(space='l2', dim=dim)

print("\nLoading index from '%s'\n" % index_path)
bf_index.load_index(index_path)

# Query the brute force index again to verify that we get the same results
labels_bf, distances_bf = bf_index.knn_query(query_data, k)

# Measure recall
correct = 0
for i in range(num_queries):
for label in labels_hnsw[i]:
for correct_label in labels_bf[i]:
if label == correct_label:
correct += 1
break

print("recall after reloading is :", float(correct)/(k*nun_queries))

