Skip to content

Commit

Permalink
add methods for constructing graph from outdegree table
Browse files Browse the repository at this point in the history
  • Loading branch information
blaise-muhirwa committed Dec 8, 2023
1 parent 2e9f808 commit c154a05
Show file tree
Hide file tree
Showing 3 changed files with 231 additions and 30 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ set_target_properties(FLAT_NAV_LIB PROPERTIES LINKER_LANGUAGE CXX)

if(BUILD_EXAMPLES)
message(STATUS "Building examples for Flatnav")
foreach(CONSTRUCT_EXEC construct_npy query_npy cereal_tests)
foreach(CONSTRUCT_EXEC construct_npy query_npy cereal_tests load_mtx)
add_executable(${CONSTRUCT_EXEC}
${PROJECT_SOURCE_DIR}/tools/${CONSTRUCT_EXEC}.cpp ${HEADERS})
add_dependencies(${CONSTRUCT_EXEC} FLAT_NAV_LIB)
Expand Down
121 changes: 92 additions & 29 deletions flatnav/Index.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,10 @@
#include <limits>
#include <memory>
#include <optional>
#include <quantization/ProductQuantization.h>
#include <queue>
#include <utility>
#include <vector>

using flatnav::quantization::ProductQuantizer;

namespace flatnav {

// dist_t: A distance function implementing DistanceInterface.
Expand All @@ -30,6 +27,10 @@ template <typename dist_t, typename label_t> class Index {
public:
typedef std::pair<float, label_t> dist_label_t;

// Constructor for serialization with cereal. Do not use outside of
// this class.
Index() = default;

/**
* @brief Construct a new Index object for approximate near neighbor search
*
Expand All @@ -53,12 +54,75 @@ template <typename dist_t, typename label_t> class Index {
_index_memory = new char[index_memory_size];
}

// Constructor for serialization with cereal. Do not use outside of
// this class.
Index() = default;
/**
* @brief Construct a new Index object using a pre-computed outdegree table.
*
* @param dist A distance metric for the specific index
* distance. Options include l2(euclidean) and inner product.
* @param outdegree_table A table of outdegrees for each node in the graph.
* Each vector in the table contains the IDs of the nodes to which it is
* connected.
*/
Index(std::shared_ptr<DistanceInterface<dist_t>> dist,
std::vector<std::vector<uint32_t>> &outdegree_table)
: _M(outdegree_table[0].size()), _max_node_count(outdegree_table.size()),
_cur_num_nodes(0), _distance(dist),
_visited_nodes(outdegree_table.size() + 1),
_outdegree_table(std::move(outdegree_table)) {

_data_size_bytes = _distance->dataSize();
_node_size_bytes =
_data_size_bytes + (sizeof(node_id_t) * _M) + sizeof(label_t);

size_t index_memory_size = _node_size_bytes * _max_node_count;
_index_memory = new char[index_memory_size];
}

~Index() { delete[] _index_memory; }

void buildGraphLinks() {
if (!_outdegree_table.has_value()) {
throw std::runtime_error("Cannot build graph links without outdegree "
"table. Please construct index with outdegree "
"table.");
}

for (node_id_t node = 0; node < _outdegree_table.value().size(); node++) {
node_id_t *links = getNodeLinks(node);
for (int i = 0; i < _M; i++) {
if (_outdegree_table.value()[node].size() < _M) {
links[i] = node;
}
else {
auto linkvalue = _outdegree_table.value()[node][i];
links[i] = linkvalue;
}
}
}
}

std::vector<std::vector<uint32_t>> getGraphOutdegreeTable() {
std::vector<std::vector<uint32_t>> outdegree_table(_cur_num_nodes);
for (node_id_t node = 0; node < _cur_num_nodes; node++) {
node_id_t *links = getNodeLinks(node);
for (int i = 0; i < _M; i++) {
if (links[i] != node) {
outdegree_table[node].push_back(links[i]);
}
}
}
return outdegree_table;
}

/**
* @brief Add a new vector to the index.
*
* @param data The vector to add.
* @param label The label (meta-data) of the vector.
* @param ef_construction ef parameter in the HNSW paper.
* @param num_initializations Parameter determining how to choose an entry
* point.
*/
void add(void *data, label_t &label, int ef_construction,
int num_initializations = 100) {
// initialization must happen before alloc due to a bug where
Expand Down Expand Up @@ -116,6 +180,27 @@ template <typename dist_t, typename label_t> class Index {
return results;
}

// TODO: Add optional argument here for quantized data vector.
void allocateNode(void *data, label_t &label, uint32_t &new_node_id) {
if (_cur_num_nodes >= _max_node_count) {
throw std::runtime_error("Maximum number of nodes reached. Consider "
"increasing the `max_node_count` parameter to "
"create a larger index.");
}
new_node_id = _cur_num_nodes;

_distance->transformData(/* destination = */ getNodeData(new_node_id),
/* src = */ data);
*(getNodeLabel(_cur_num_nodes)) = label;

node_id_t *links = getNodeLinks(_cur_num_nodes);
for (uint32_t i = 0; i < _M; i++) {
links[i] = _cur_num_nodes;
}

_cur_num_nodes++;
}

void reorderGOrder(const int window_size = 5) {
std::vector<std::vector<node_id_t>> outdegree_table(_cur_num_nodes);
for (node_id_t node = 0; node < _cur_num_nodes; node++) {
Expand Down Expand Up @@ -244,8 +329,7 @@ template <typename dist_t, typename label_t> class Index {
// Remembers which nodes we've visited, to avoid re-computing distances.
// Might be a caching problem in beamSearch - needs to be profiled.
VisitedSet _visited_nodes;

// std::unique_ptr<ProductQuantizer<dist_t>> _product_quantizer;
std::optional<std::vector<std::vector<uint32_t>>> _outdegree_table;

friend class cereal::access;

Expand Down Expand Up @@ -274,27 +358,6 @@ template <typename dist_t, typename label_t> class Index {
return reinterpret_cast<label_t *>(location);
}

// TODO: Add optional argument here for quantized data vector.
void allocateNode(void *data, label_t &label, node_id_t &new_node_id) {
if (_cur_num_nodes >= _max_node_count) {
throw std::runtime_error("Maximum number of nodes reached. Consider "
"increasing the `max_node_count` parameter to "
"create a larger index.");
}
new_node_id = _cur_num_nodes;

_distance->transformData(/* destination = */ getNodeData(new_node_id),
/* src = */ data);
*(getNodeLabel(_cur_num_nodes)) = label;

node_id_t *links = getNodeLinks(_cur_num_nodes);
for (uint32_t i = 0; i < _M; i++) {
links[i] = _cur_num_nodes;
}

_cur_num_nodes++;
}

inline void swapNodes(node_id_t a, node_id_t b, void *temp_data,
node_id_t *temp_links, label_t *temp_label) {

Expand Down
138 changes: 138 additions & 0 deletions tools/load_mtx.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@

#include "cnpy.h"
#include <flatnav/DistanceInterface.h>
#include <flatnav/Index.h>
#include <flatnav/distances/SquaredL2Distance.h>

#include <fstream>
#include <iostream>
#include <sstream>
#include <vector>

struct Graph {
std::vector<std::vector<uint32_t>> adjacency_list;
int num_vertices;
};

// Function to load a graph from a Matrix Market file
Graph loadGraphFromMatrixMarket(const char *filename) {
std::ifstream input_file;
input_file.open(filename);

if (!input_file.is_open()) {
std::cerr << "Error opening file" << std::endl;
exit(1);
}

std::string line;
// Skip the header
while (std::getline(input_file, line)) {
if (line[0] != '%')
break;
}

std::istringstream iss(line);
int num_vertices, numEdges;
iss >> num_vertices >> num_vertices >> numEdges;

// Initialize graph
Graph graph;
graph.num_vertices = num_vertices;
graph.adjacency_list.resize(num_vertices);

int u, v;
while (input_file >> u >> v) {
// Adjust for 1-based indexing in Matrix Market format
u--;
v--;
graph.adjacency_list[u].push_back(v);
}

input_file.close();
return graph;
}

int main() {
// Replace with your filename
const char *ground_truth_file =
"/Users/blaisemunyampirwa/Desktop/flatnav-experimental/data/"
"sift-128-euclidean/sift-128-euclidean.gtruth.npy";
const char *train_file =
"/Users/blaisemunyampirwa/Desktop/flatnav-experimental/data/"
"sift-128-euclidean/sift-128-euclidean.train.npy";
const char *queries_file =
"/Users/blaisemunyampirwa/Desktop/flatnav-experimental/data/"
"sift-128-euclidean/sift-128-euclidean.test.npy";
const char *sift_mtx =
"/Users/blaisemunyampirwa/Desktop/flatnav-experimental/data/"
"sift-128-euclidean/sift.mtx";

Graph g = loadGraphFromMatrixMarket(sift_mtx);

cnpy::NpyArray trainfile = cnpy::npy_load(train_file);
cnpy::NpyArray queryfile = cnpy::npy_load(queries_file);
cnpy::NpyArray truthfile = cnpy::npy_load(ground_truth_file);
if ((queryfile.shape.size() != 2) || (truthfile.shape.size() != 2)) {
return -1;
}

float *data = trainfile.data<float>();
float *queries = queryfile.data<float>();
int *gtruth = truthfile.data<int>();

std::cout << "constructing the index" << std::endl;
auto distance = std::make_shared<flatnav::SquaredL2Distance>(128);
std::unique_ptr<flatnav::Index<flatnav::SquaredL2Distance, int>> index =
std::make_unique<flatnav::Index<flatnav::SquaredL2Distance, int>>(
distance, g.adjacency_list);

std::vector<int> ef_searches{100, 200};
int num_queries = queryfile.shape[0];
int num_gtruth = truthfile.shape[1];
int dim = 128;
int K = 100;

std::cout << "Adding vectors to the index" << std::endl;
for (int label = 0; label < 1000000; label++) {
float *element = data + (dim * label);
uint32_t node_id;
index->allocateNode(element, label, node_id);
}

std::cout << "Building graph links" << std::endl;
index->buildGraphLinks();

std::cout << "Querying" << std::endl;

for (const auto &ef_search : ef_searches) {
double mean_recall = 0;

auto start_q = std::chrono::high_resolution_clock::now();
for (int i = 0; i < num_queries; i++) {
float *q = queries + dim * i;
int *g = gtruth + num_gtruth * i;

std::vector<std::pair<float, int>> result =
index->search(q, K, ef_search);

double recall = 0;
for (int j = 0; j < K; j++) {
for (int l = 0; l < K; l++) {
if (result[j].second == g[l]) {
recall = recall + 1;
}
}
}
recall = recall / K;
mean_recall = mean_recall + recall;
}
auto stop_q = std::chrono::high_resolution_clock::now();
auto duration_q =
std::chrono::duration_cast<std::chrono::milliseconds>(stop_q - start_q);
std::cout << "[INFO] Mean Recall: " << mean_recall / num_queries
<< ", Duration: " << (float)(duration_q.count()) / num_queries
<< " for ef_search = " << ef_search << std::endl;
}

return 0;
}

0 comments on commit c154a05

Please sign in to comment.