add a method to parallelize search
blaise-muhirwa committed Dec 18, 2023
1 parent 7746c23 commit 34c34c9
Showing 5 changed files with 71 additions and 10 deletions.
4 changes: 2 additions & 2 deletions flatnav/Index.h
@@ -141,8 +141,8 @@ template <typename dist_t, typename label_t> class Index {
       return;
     }
 
-    flatnav::parallelExecutor(
-        /* start = */ 0, /* end = */ total_num_nodes,
+    flatnav::executeInParallel(
+        /* start_index = */ 0, /* end_index = */ total_num_nodes,
         /* num_threads = */ _num_threads, /* function = */
         [&](uint32_t row_index) {
           void *vector = (float *)data + (row_index * data_dimension);
4 changes: 4 additions & 0 deletions flatnav/distances/InnerProductDistance.h
@@ -91,7 +91,11 @@ class InnerProductDistance : public DistanceInterface<InnerProductDistance> {
 #endif
     if (!(_dimension % 16 == 0)) {
       if (_dimension % 4 == 0) {
+#if defined(USE_AVX)
+        _distance_computer = distanceImplInnerProductSIMD4ExtAVX;
+#else
         _distance_computer = distanceImplInnerProductSIMD4ExtSSE;
+#endif
       } else if (_dimension > 16) {
         _distance_computer = distanceImplInnerProductSIMD16ExtResiduals;
       } else if (_dimension > 4) {
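Read as a decision tree, the dispatch now prefers the AVX kernel for dimensions divisible by 4 whenever the build enables it, falling back to the SSE kernel otherwise. The following is a self-contained sketch of the same selection logic, not library code: the returned strings stand in for the real distanceImplInnerProduct* function pointers, and the name for the final `> 4` branch is an assumption, since the hunk is truncated before that assignment.

#include <cstddef>
#include <iostream>

// Sketch only: mirrors the dispatch above with printable kernel names.
const char *selectInnerProductKernel(size_t dimension) {
  if (dimension % 16 == 0)
    return "SIMD16";              // handled by the code above this hunk
  if (dimension % 4 == 0)
#if defined(USE_AVX)
    return "SIMD4ExtAVX";         // newly preferred when AVX is compiled in
#else
    return "SIMD4ExtSSE";         // fallback without AVX
#endif
  if (dimension > 16)
    return "SIMD16ExtResiduals";
  if (dimension > 4)
    return "SIMD4ExtResiduals";   // assumed name (truncated in the diff)
  return "scalar";
}

int main() {
  for (size_t d : {128, 100, 77, 7, 3})
    std::cout << d << " -> " << selectInnerProductKernel(d) << '\n';
}

For example, d = 100 selects the 4-wide kernel (100 % 4 == 0), while d = 77 falls through to the 16-wide residuals kernel.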
4 changes: 2 additions & 2 deletions flatnav/util/ParallelConstructs.h
@@ -22,8 +22,8 @@ namespace flatnav {
  * @param function
  */
 template <typename Function>
-void parallelExecutor(uint32_t start_index, uint32_t end_index,
-                      uint32_t num_threads, Function function) {
+void executeInParallel(uint32_t start_index, uint32_t end_index,
+                       uint32_t num_threads, Function function) {
   if (num_threads == 0) {
     throw std::invalid_argument("Invalid number of threads");
   }
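The hunk shows only the rename (parallelExecutor to executeInParallel) and the zero-thread guard; the loop body is truncated. For orientation, here is a minimal sketch of what an executor with this signature could look like, assuming strided work assignment across std::thread workers; the library's actual implementation may differ.

#include <cstdint>
#include <stdexcept>
#include <thread>
#include <vector>

// Hypothetical sketch with the same signature as flatnav::executeInParallel.
template <typename Function>
void executeInParallelSketch(uint32_t start_index, uint32_t end_index,
                             uint32_t num_threads, Function function) {
  if (num_threads == 0) {
    throw std::invalid_argument("Invalid number of threads");
  }
  std::vector<std::thread> workers;
  workers.reserve(num_threads);
  for (uint32_t t = 0; t < num_threads; t++) {
    // Each worker handles indices t, t + num_threads, t + 2 * num_threads, ...
    workers.emplace_back([=]() {
      for (uint32_t i = start_index + t; i < end_index; i += num_threads) {
        function(i);
      }
    });
  }
  for (auto &worker : workers) {
    worker.join();
  }
}

Strided assignment keeps the sketch simple and load-balanced when per-item cost is uniform; a chunked split would give each worker a contiguous range instead.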
2 changes: 2 additions & 0 deletions flatnav/util/SIMDDistanceSpecializations.h
@@ -270,6 +270,8 @@ static float distanceImplSquaredL2SIMD16ExtAVX512(const void *x, const void *y,
 #if defined(USE_AVX)
 static float distanceImplInnerProductSIMD4ExtAVX(const void *x, const void *y,
                                                  const size_t &dimension) {
+
+  std::cout << "[info] invoking distanceImplInnerProductSIMD4ExtAVX" << std::flush;
   float *p_x = (float *)(x);
   float *p_y = (float *)(y);
   float PORTABLE_ALIGN32 temp_res[8];
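The only functional change here is a debug trace, printed on every distance call, so presumably temporary. For readers unfamiliar with these kernels, below is an illustrative 4-wide SIMD inner product using SSE intrinsics, assuming dimension % 4 == 0 as the dispatch guarantees. It is not the library's actual kernel, which may also convert the raw dot product into a distance value.

#include <immintrin.h>
#include <cstddef>

// Illustrative sketch: accumulate 4 products per iteration, then reduce.
static float innerProductSIMD4Sketch(const float *x, const float *y,
                                     size_t dimension) {
  __m128 sum = _mm_setzero_ps();
  for (size_t i = 0; i < dimension; i += 4) {
    __m128 vx = _mm_loadu_ps(x + i);               // load 4 floats from x
    __m128 vy = _mm_loadu_ps(y + i);               // load 4 floats from y
    sum = _mm_add_ps(sum, _mm_mul_ps(vx, vy));     // 4 partial dot products
  }
  float partial[4];
  _mm_storeu_ps(partial, sum);                     // horizontal reduction
  return partial[0] + partial[1] + partial[2] + partial[3];
}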
67 changes: 61 additions & 6 deletions flatnav_python/python_bindings.cpp
@@ -3,6 +3,7 @@
 #include <flatnav/Index.h>
 #include <flatnav/distances/InnerProductDistance.h>
 #include <flatnav/distances/SquaredL2Distance.h>
+#include <flatnav/util/ParallelConstructs.h>
 #include <iostream>
 #include <memory>
 #include <ostream>
@@ -104,17 +105,70 @@ template <typename dist_t, typename label_t> class PyIndex {
     }
   }
 
-  DistancesLabelsPair
-  search(const py::array_t<float, py::array::c_style | py::array::forcecast>
-             queries,
-         int K, int ef_search, int num_initializations = 100) {
+  DistancesLabelsPair searchParallel(
+      const py::array_t<float, py::array::c_style | py::array::forcecast>
+          queries,
+      int K, int ef_search, int num_initializations = 100) {
+
+    size_t num_queries = queries.shape(0);
+    size_t queries_dim = queries.shape(1);
+
+    if (queries.ndim() != 2 || queries_dim != _dim) {
+      throw std::invalid_argument("Queries have incorrect dimensions.");
+    }
+
+    auto num_threads = _index->getNumThreads();
+
+    // No need to spawn any threads if we are in a single-threaded environment
+    if (num_threads == 1) {
+      return search(/* queries = */ queries, /* K = */ K,
+                    /* ef_search = */ ef_search,
+                    /* num_initializations = */ num_initializations);
+    }
+
+    label_t *results = new label_t[num_queries * K];
+    float *distances = new float[num_queries * K];
+
+    flatnav::executeInParallel(
+        /* start_index = */ 0, /* end_index = */ num_queries,
+        /* num_threads = */ num_threads,
+        /* function = */ [&](uint32_t row_index) {
+          auto *query = (const void *)queries.data(row_index);
+          std::vector<std::pair<float, label_t>> top_k = this->_index->search(
+              /* query = */ query, /* K = */ K, /* ef_search = */ ef_search,
+              /* num_initializations = */ num_initializations);
+
+          for (uint32_t result_id = 0; result_id < K; result_id++) {
+            distances[(row_index * K) + result_id] = top_k[result_id].first;
+            results[(row_index * K) + result_id] = top_k[result_id].second;
+          }
+        });
+    // Allows transferring ownership of the result buffers to Python
+    py::capsule free_results_when_done(
+        results, [](void *ptr) { delete (label_t *)ptr; });
+    py::capsule free_distances_when_done(
+        distances, [](void *ptr) { delete (float *)ptr; });
+
+    py::array_t<label_t> labels =
+        py::array_t<label_t>({num_queries, (size_t)K}, // shape of the array
+                             {K * sizeof(label_t), sizeof(label_t)}, // strides
+                             results,               // data pointer
+                             free_results_when_done // capsule
+        );
+
+    py::array_t<float> dists = py::array_t<float>(
+        {num_queries, (size_t)K}, {K * sizeof(float), sizeof(float)}, distances,
+        free_distances_when_done);
+
+    return {dists, labels};
+  }

   DistancesLabelsPair
   search(const py::array_t<float, py::array::c_style | py::array::forcecast>
              queries,
          int K, int ef_search, int num_initializations = 100) {
     size_t num_queries = queries.shape(0);
     size_t queries_dim = queries.shape(1);
     label_t *results = new label_t[num_queries * K];
     float *distances = new float[num_queries * K];

@@ -174,8 +228,9 @@ void bindIndexMethods(py::class_<IndexType> &index_class) {
"many "
"vertices are visited while inserting every vector in the "
"underlying graph structure.")
.def("search", &IndexType::search, py::arg("queries"), py::arg("K"),
py::arg("ef_search"), py::arg("num_initializations") = 100,
.def("search", &IndexType::searchParallel, py::arg("queries"),
py::arg("K"), py::arg("ef_search"),
py::arg("num_initializations") = 100,
"Return top `K` closest data points for every query in the "
"provided `queries`. The results are returned as a Tuple of "
"distances and label ID's. The `ef_search` parameter determines how "
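With this rebinding, Python's index.search(...) now dispatches to searchParallel, so multi-threaded search is transparent to callers. The key memory-management device above is py::capsule: it lets the returned NumPy arrays adopt the raw C++ buffers without copying and free them when Python garbage-collects the arrays. Below is a stand-alone sketch of that pattern; the module and function names are hypothetical. Note that it uses delete[] to pair with the new[] allocation, whereas the committed lambdas use scalar delete, which is undefined behavior for arrays.

#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>

namespace py = pybind11;

// Hypothetical example of the capsule ownership pattern used in
// searchParallel: a heap buffer is wrapped in a py::array_t whose base
// capsule frees the memory once Python drops the last reference.
py::array_t<float> makeOwnedArray(size_t rows, size_t cols) {
  float *data = new float[rows * cols]();  // zero-initialized C++ buffer
  py::capsule owner(data, [](void *ptr) {
    delete[] static_cast<float *>(ptr);    // delete[] pairs with new[]
  });
  return py::array_t<float>({rows, cols},                          // shape
                            {cols * sizeof(float), sizeof(float)}, // strides
                            data, owner);
}

PYBIND11_MODULE(capsule_demo, m) {
  m.def("make_owned_array", &makeOwnedArray, py::arg("rows"), py::arg("cols"));
}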
