[wip]
blaise-muhirwa committed Dec 6, 2023
1 parent ef7d6aa commit adc7c8d
Showing 3 changed files with 124 additions and 44 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -215,6 +215,7 @@ set(HEADERS
${PROJECT_SOURCE_DIR}/flatnav/util/GorderPriorityQueue.h
${PROJECT_SOURCE_DIR}/flatnav/util/Reordering.h
${PROJECT_SOURCE_DIR}/flatnav/util/SIMDDistanceSpecializations.h
${PROJECT_SOURCE_DIR}/flatnav/util/ParallelConstructs.h
${PROJECT_SOURCE_DIR}/flatnav/DistanceInterface.h
${PROJECT_SOURCE_DIR}/flatnav/Index.h
${PROJECT_SOURCE_DIR}/quantization/ProductQuantization.h
106 changes: 62 additions & 44 deletions flatnav/Index.h
@@ -11,6 +11,7 @@
#include <cstring>
#include <flatnav/DistanceInterface.h>
#include <flatnav/util/ExplicitSet.h>
#include <flatnav/util/ParallelConstructs.h>
#include <flatnav/util/Reordering.h>
#include <flatnav/util/SIMDDistanceSpecializations.h>
#include <fstream>
@@ -79,7 +80,7 @@ template <typename dist_t, typename label_t> class Index {
// after benchmarking - it's slightly more cache-efficient than others.
size_t _node_size_bytes;
size_t _max_node_count; // Determines size of internal pre-allocated memory
size_t _cur_num_nodes;
std::atomic<size_t> _cur_num_nodes;
std::shared_ptr<DistanceInterface<dist_t>> _distance;
std::mutex _cur_num_nodes_global_lock;
std::condition_variable _cur_num_nodes_global_cv;
@@ -95,9 +96,8 @@ template <typename dist_t, typename label_t> class Index {
friend class cereal::access;

template <typename Archive> void serialize(Archive &archive) {
archive(_M, _data_size_bytes, _node_size_bytes, _max_node_count,
_cur_num_nodes, *_distance, _visited_nodes,
*_sharded_visited_nodes);
archive(_M, _data_size_bytes, _node_size_bytes, _max_node_count, *_distance,
_visited_nodes, *_sharded_visited_nodes);

// Serialize the allocated memory for the index & query.
archive(
@@ -182,29 +182,50 @@ template <typename dist_t, typename label_t> class Index {

void addParallel(void *data, std::vector<label_t> &labels,
int ef_construction, int num_initializations = 100) {

if (num_initializations <= 0) {
throw std::invalid_argument(
"num_initializations must be greater than 0.");
}
uint32_t thread_count =
_num_threads <= 0 ? std::thread::hardware_concurrency() : _num_threads;

std::vector<std::thread> thread_pool(thread_count);
uint32_t batch_size = labels.size() / thread_count;
uint32_t total_num_nodes = labels.size();
uint32_t data_dimension = _distance->dimension();

parallelFor(/* start = */ 0, /* end = */ total_num_nodes,
/* num_threads = */ thread_count, /* fn = */
[&](uint32_t row, uint32_t thread_id) {
void *vector = (float *)data + (row * data_dimension);
label_t label = labels[row];
concurrentAdd(vector, label, ef_construction,
num_initializations);
});
}

std::cout << "Starting parallel add"
<< "\n"
<< std::flush;
void concurrentAdd(void *data, label_t &label, int ef_construction,
int num_initializations = 100) {
// Lock the global counter to prevent multiple threads from
// trying to insert the same node.
// std::unique_lock<std::mutex> lock(_cur_num_nodes_global_lock);

for (uint32_t thread_id = 0; thread_id < thread_count; thread_id++) {
void *current_batch =
(float *)data + (thread_id * batch_size * _data_size_bytes);
uint32_t label_start = thread_id * batch_size;
thread_pool[thread_id] = std::thread(
&Index::addParallelBatch, this, current_batch, batch_size,
label_start, std::ref(labels), ef_construction, num_initializations);
if (_cur_num_nodes >= _max_node_count) {
throw std::runtime_error("Maximum number of nodes reached. Consider "
"increasing the `max_node_count` parameter to "
"create a larger index.");
}
auto entry_node = initializeSearch(data, num_initializations);
node_id_t new_node_id;
allocateNode(data, label, new_node_id);

for (uint32_t thread_id = 0; thread_id < thread_count; thread_id++) {
thread_pool[thread_id].join();
}
// if (new_node_id == 0) {
// return;
// }

// PriorityQueue neighbors = beamSearch(
// /* query = */ data, /* entry_node = */ entry_node,
// /* buffer_size = */ ef_construction);
// selectNeighbors(/* neighbors = */ neighbors);
// connectNeighbors(neighbors, new_node_id);
}

void addParallelBatch(void *batch, uint32_t batch_size, uint32_t label_start,
@@ -290,8 +311,8 @@ template <typename dist_t, typename label_t> class Index {
* @param num_initializations The number of random initializations to use.
*/
std::vector<dist_label_t> search(const void *query, const int K,
int ef_search,
int num_initializations = 100) {
int ef_search,
int num_initializations = 100) {
node_id_t entry_node = initializeSearch(query, num_initializations);
PriorityQueue neighbors =
concurrentBeamSearch(/* query = */ query,
@@ -363,10 +384,11 @@ template <typename dist_t, typename label_t> class Index {

// 1. Deserialize metadata
archive(index->_M, index->_data_size_bytes, index->_node_size_bytes,
index->_max_node_count, index->_cur_num_nodes, *dist,
index->_visited_nodes, *sharded_visited_nodes);
index->_max_node_count, *dist, index->_visited_nodes,
*sharded_visited_nodes);
index->_distance = dist;
index->_sharded_visited_nodes = sharded_visited_nodes;
index->_cur_num_nodes = 0;

// 3. Allocate memory using deserialized metadata
index->_index_memory =
@@ -442,24 +464,24 @@ template <typename dist_t, typename label_t> class Index {
* @param new_node_id The id of the new node.
*/
void allocateNode(void *data, label_t &label, node_id_t &new_node_id) {
if (_cur_num_nodes >= _max_node_count) {
throw std::runtime_error("Maximum number of nodes reached. Consider "
"increasing the `max_node_count` parameter to "
"create a larger index.");
}
new_node_id = _cur_num_nodes;

_distance->transformData(
/* destination = */ (void *)getNodeData(new_node_id),
/* src = */ data);
*(getNodeLabel(_cur_num_nodes)) = label;

node_id_t *links = getNodeLinks(_cur_num_nodes);
{
std::cout << "allocateNode: " << new_node_id << std::endl;
std::unique_lock<std::mutex> lock(_cur_num_nodes_global_lock);
new_node_id = _cur_num_nodes.fetch_add(1);
}
_distance->transformData(
/* destination = */ (void *)getNodeData(new_node_id),
/* src = */ data);
std::cout << "Setting label for node " << new_node_id << std::endl;
*(getNodeLabel(new_node_id)) = label;

node_id_t *links = getNodeLinks(new_node_id);
std::cout << "Inserting links now" << std::endl;
for (uint32_t i = 0; i < _M; i++) {
links[i] = _cur_num_nodes;
links[i] = new_node_id;
}

_cur_num_nodes++;
std::cout << "Finished inserting links" << std::endl;
}

inline void swapNodes(node_id_t a, node_id_t b, void *temp_data,
@@ -740,11 +762,7 @@ template <typename dist_t, typename label_t> class Index {
throw std::invalid_argument(
"num_initializations must be greater than 0.");
}

int step_size = _cur_num_nodes / num_initializations;
if (step_size <= 0) {
step_size = 1;
}
int step_size = _cur_num_nodes ? _cur_num_nodes / num_initializations : 1;

float min_dist = std::numeric_limits<float>::max();
node_id_t entry_node = 0;
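The allocateNode change above reserves each node id with an atomic fetch_add taken under the global lock. A minimal standalone sketch of that reservation pattern follows; the class and member names are illustrative, not part of flatnav.

#include <atomic>
#include <cstddef>
#include <mutex>
#include <stdexcept>

// Illustrative sketch of the id-reservation pattern used by the patched
// allocateNode: each caller receives a unique, dense slot id, and the
// capacity check happens before the id is handed out.
class SlotReserver {
public:
  explicit SlotReserver(std::size_t capacity) : _capacity(capacity), _next(0) {}

  std::size_t reserve() {
    std::unique_lock<std::mutex> lock(_mutex);
    if (_next.load() >= _capacity) {
      throw std::runtime_error("Maximum number of nodes reached.");
    }
    // fetch_add returns the pre-increment value, so ids stay unique even if
    // the surrounding lock is later narrowed or removed.
    return _next.fetch_add(1);
  }

private:
  std::size_t _capacity;
  std::atomic<std::size_t> _next;
  std::mutex _mutex;
};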
61 changes: 61 additions & 0 deletions flatnav/util/ParallelConstructs.h
@@ -0,0 +1,61 @@
#pragma once

#include <atomic>
#include <cstdint>
#include <exception>
#include <mutex>
#include <stdexcept>
#include <thread>
#include <vector>

namespace flatnav {

template <typename Function>
void parallelFor(uint32_t start, uint32_t end, uint32_t num_threads,
Function fn) {
if (num_threads == 0) {
throw std::invalid_argument("Invalid number of threads");
}

if (num_threads == 1) {
for (uint32_t i = start; i < end; i++) {
fn(i, 0);
}
return;
}
std::vector<std::thread> threads;
std::atomic<uint32_t> current(start);

std::exception_ptr last_exception = nullptr;
std::mutex last_exception_mutex;

for (uint32_t thread_id = 0; thread_id < num_threads; thread_id++) {
threads.push_back(std::thread([&, thread_id] {
while (true) {
uint32_t current_value = current.fetch_add(1);
if (current_value >= end) {
break;
}

try {
fn(current_value, thread_id);
} catch (...) {
std::unique_lock<std::mutex> lock(last_exception_mutex);
last_exception = std::current_exception();

current = end;
break;
}
}
}));
}

for (auto &thread : threads) {
thread.join();
}

if (last_exception) {
std::rethrow_exception(last_exception);
}
}

} // namespace flatnav
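A minimal usage sketch of the new helper. Only parallelFor and its fn(row, thread_id) contract come from the header above; the driver program, the vector being filled, and the thread count are illustrative.

#include <flatnav/util/ParallelConstructs.h>

#include <cstdint>
#include <iostream>
#include <mutex>
#include <vector>

int main() {
  std::vector<float> squares(1000);
  std::mutex print_mutex;

  // Work items are handed out one index at a time through the shared atomic
  // counter inside parallelFor; the lambda receives the item index and the
  // id of the worker thread executing it.
  flatnav::parallelFor(
      /* start = */ 0, /* end = */ static_cast<uint32_t>(squares.size()),
      /* num_threads = */ 4, [&](uint32_t row, uint32_t thread_id) {
        squares[row] = static_cast<float>(row) * static_cast<float>(row);
        if (row % 250 == 0) {
          std::lock_guard<std::mutex> lock(print_mutex);
          std::cout << "thread " << thread_id << " handled row " << row
                    << "\n";
        }
      });

  return 0;
}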
