@@ -871,16 +871,39 @@ class BFIndex {
871
871
CustomFilterFunctor idFilter (filter);
872
872
CustomFilterFunctor* p_idFilter = filter ? &idFilter : nullptr ;
873
873
874
- ParallelFor (0 , rows, num_threads, [&](size_t row, size_t threadId) {
875
- std::priority_queue<std::pair<dist_t , hnswlib::labeltype >> result = alg->searchKnn (
876
- (void *)items.data (row), k, p_idFilter);
877
- for (int i = k - 1 ; i >= 0 ; i--) {
878
- auto & result_tuple = result.top ();
879
- data_numpy_d[row * k + i] = result_tuple.first ;
880
- data_numpy_l[row * k + i] = result_tuple.second ;
881
- result.pop ();
882
- }
883
- });
874
+ if (!normalize) {
875
+ ParallelFor (0 , rows, num_threads, [&](size_t row, size_t threadId) {
876
+ std::priority_queue<std::pair<dist_t , hnswlib::labeltype >> result = alg->searchKnn (
877
+ (void *)items.data (row), k, p_idFilter);
878
+ if (result.size () != k)
879
+ throw std::runtime_error (
880
+ " Cannot return the results in a contiguous 2D array. There are not enough elements." );
881
+ for (int i = k - 1 ; i >= 0 ; i--) {
882
+ auto & result_tuple = result.top ();
883
+ data_numpy_d[row * k + i] = result_tuple.first ;
884
+ data_numpy_l[row * k + i] = result_tuple.second ;
885
+ result.pop ();
886
+ }
887
+ });
888
+ } else {
889
+ std::vector<float > norm_array (num_threads * features);
890
+ ParallelFor (0 , rows, num_threads, [&](size_t row, size_t threadId) {
891
+ size_t start_idx = threadId * dim;
892
+ normalize_vector ((float *)items.data (row), norm_array.data () + start_idx);
893
+
894
+ std::priority_queue<std::pair<dist_t , hnswlib::labeltype >> result = alg->searchKnn (
895
+ (void *)(norm_array.data () + start_idx), k, p_idFilter);
896
+ if (result.size () != k)
897
+ throw std::runtime_error (
898
+ " Cannot return the results in a contiguous 2D array. There are not enough elements." );
899
+ for (int i = k - 1 ; i >= 0 ; i--) {
900
+ auto & result_tuple = result.top ();
901
+ data_numpy_d[row * k + i] = result_tuple.first ;
902
+ data_numpy_l[row * k + i] = result_tuple.second ;
903
+ result.pop ();
904
+ }
905
+ });
906
+ }
884
907
}
885
908
886
909
py::capsule free_when_done_l (data_numpy_l, [](void *f) {
0 commit comments