Skip to content

Commit dbcef01

Browse files
authored
Merge pull request #514 from lukaszsmolinski/develop
Fix incorrect results in bruteforce with filter
2 parents 5a8fd34 + 39bc6af commit dbcef01

File tree

2 files changed

+40
-27
lines changed

2 files changed

+40
-27
lines changed

hnswlib/bruteforce.h

+7 -17
Original file line number | Diff line number | Diff line change
@@ -107,27 +107,17 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
107107
searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
108108
assert(k <= cur_element_count);
109109
std::priority_queue<std::pair<dist_t, labeltype >> topResults;
110-
if (cur_element_count == 0) return topResults;
111-
for (int i = 0; i < k; i++) {
110+
dist_t lastdist = std::numeric_limits<dist_t>::max();
111+
for (int i = 0; i < cur_element_count; i++) {
112112
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
113-
labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
114-
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
115-
topResults.emplace(dist, label);
116-
}
117-
}
118-
dist_t lastdist = topResults.empty() ? std::numeric_limits<dist_t>::max() : topResults.top().first;
119-
for (int i = k; i < cur_element_count; i++) {
120-
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
121-
if (dist <= lastdist) {
113+
if (dist <= lastdist || topResults.size() < k) {
122114
labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
123115
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
124116
topResults.emplace(dist, label);
125-
}
126-
if (topResults.size() > k)
127-
topResults.pop();
128-
129-
if (!topResults.empty()) {
130-
lastdist = topResults.top().first;
117+
if (topResults.size() > k)
118+
topResults.pop();
119+
if (!topResults.empty())
120+
lastdist = topResults.top().first;
131121
}
132122
}
133123
}

python_bindings/bindings.cpp

+33 -10
Original file line number | Diff line number | Diff line change
@@ -871,16 +871,39 @@ class BFIndex {
871871
CustomFilterFunctor idFilter(filter);
872872
CustomFilterFunctor* p_idFilter = filter ? &idFilter : nullptr;
873873

874-
ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) {
875-
std::priority_queue<std::pair<dist_t, hnswlib::labeltype >> result = alg->searchKnn(
876-
(void*)items.data(row), k, p_idFilter);
877-
for (int i = k - 1; i >= 0; i--) {
878-
auto& result_tuple = result.top();
879-
data_numpy_d[row * k + i] = result_tuple.first;
880-
data_numpy_l[row * k + i] = result_tuple.second;
881-
result.pop();
882-
}
883-
});
874+
if (!normalize) {
875+
ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) {
876+
std::priority_queue<std::pair<dist_t, hnswlib::labeltype >> result = alg->searchKnn(
877+
(void*)items.data(row), k, p_idFilter);
878+
if (result.size() != k)
879+
throw std::runtime_error(
880+
"Cannot return the results in a contiguous 2D array. There are not enough elements.");
881+
for (int i = k - 1; i >= 0; i--) {
882+
auto& result_tuple = result.top();
883+
data_numpy_d[row * k + i] = result_tuple.first;
884+
data_numpy_l[row * k + i] = result_tuple.second;
885+
result.pop();
886+
}
887+
});
888+
} else {
889+
std::vector<float> norm_array(num_threads * features);
890+
ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) {
891+
size_t start_idx = threadId * dim;
892+
normalize_vector((float*)items.data(row), norm_array.data() + start_idx);
893+
894+
std::priority_queue<std::pair<dist_t, hnswlib::labeltype >> result = alg->searchKnn(
895+
(void*)(norm_array.data() + start_idx), k, p_idFilter);
896+
if (result.size() != k)
897+
throw std::runtime_error(
898+
"Cannot return the results in a contiguous 2D array. There are not enough elements.");
899+
for (int i = k - 1; i >= 0; i--) {
900+
auto& result_tuple = result.top();
901+
data_numpy_d[row * k + i] = result_tuple.first;
902+
data_numpy_l[row * k + i] = result_tuple.second;
903+
result.pop();
904+
}
905+
});
906+
}
884907
}
885908

886909
py::capsule free_when_done_l(data_numpy_l, [](void *f) {

0 commit comments

Comments
 (0)