Skip to content

Commit

Permalink
testing index
Browse files Browse the repository at this point in the history
  • Loading branch information
blaise-muhirwa committed Nov 20, 2023
1 parent 6004a2c commit 15a20c8
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 9 deletions.
33 changes: 25 additions & 8 deletions flatnav_python/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include <pybind11/pybind11.h>
#include <string>
#include <utility>
#include <iostream>
#include <ostream>
#include <vector>

#include <flatnav/DistanceInterface.h>
Expand All @@ -19,20 +21,23 @@ using flatnav::SquaredL2Distance;
namespace py = pybind11;

template <typename dist_t, typename label_t> class PythonIndex {
const uint32_t NUM_LOG_STEPS = 1000;
private:
int _dim, label_id;
bool _verbose;
Index<dist_t, label_t> *_index;

public:
typedef std::pair<py::array_t<float>, py::array_t<label_t>>
DistancesLabelsPair;

explicit PythonIndex(std::unique_ptr<Index<dist_t, label_t>> index)
: _dim(index->dataDimension()), label_id(0), _index(index.get()) {}
: _dim(index->dataDimension()), label_id(0), _verbose(false),
_index(index.get()) {}

PythonIndex(std::shared_ptr<DistanceInterface<dist_t>> distance, int dim,
int dataset_size, int max_edges_per_node)
: _dim(dim), label_id(0),
int dataset_size, int max_edges_per_node, bool verbose = false)
: _dim(dim), label_id(0), _verbose(verbose),
_index(new Index<dist_t, label_t>(
/* dist = */ std::move(distance),
/* dataset_size = */ dataset_size,
Expand Down Expand Up @@ -62,13 +67,19 @@ template <typename dist_t, typename label_t> class PythonIndex {
throw std::invalid_argument("Data has incorrect dimensions.");
}

std::clog << "[num-vectors] = " << num_vectors << std::flush;
std::clog << "[data_dim] = " << data_dim << std::flush;
if (labels.is_none()) {
for (size_t vec_index = 0; vec_index < num_vectors; vec_index++) {
this->_index->add(/* data = */ (void *)data.data(vec_index),
/* label = */ label_id,
/* ef_construction = */ ef_construction);
if (_verbose && vec_index % NUM_LOG_STEPS == 0) {
std::clog << "." << std::flush;
}
label_id++;
}
std::clog << std::endl;
return;
}

Expand All @@ -84,7 +95,12 @@ template <typename dist_t, typename label_t> class PythonIndex {
this->_index->add(/* data = */ (void *)data.data(vec_index),
/* label = */ label_id,
/* ef_construction = */ ef_construction);

if (_verbose && vec_index % NUM_LOG_STEPS == 0) {
std::clog << "." << std::flush;
}
}
std::clog << std::endl;
}

DistancesLabelsPair
Expand Down Expand Up @@ -182,7 +198,7 @@ void bindIndexMethods(py::class_<IndexType> &index_class) {
},
py::arg("algorithm"),
"Perform graph re-ordering based on the given re-ordering strategy.")
.def_property_read_only(
.def_property_readonly(
"max_edges_per_node",
[](IndexType &index_type) {
return index_type.getIndex()->maxEdgesPerNode();
Expand All @@ -192,19 +208,20 @@ void bindIndexMethods(py::class_<IndexType> &index_class) {
}

py::object createIndex(const std::string &distance_type, int dim,
int dataset_size, int max_edges_per_node) {
int dataset_size, int max_edges_per_node,
bool verbose = false) {
auto dist_type = distance_type;
std::transform(dist_type.begin(), dist_type.end(), dist_type.begin(),
[](unsigned char c) { return std::tolower(c); });

if (dist_type == "l2") {
auto distance = std::make_shared<SquaredL2Distance>(/* dim = */ dim);
return py::cast(new L2FlatNavIndex(std::move(distance), dim, dataset_size,
max_edges_per_node));
max_edges_per_node, verbose));
} else if (dist_type == "angular") {
auto distance = std::make_shared<InnerProductDistance>(/* dim = */ dim);
return py::cast(new InnerProductFlatNavIndex(
std::move(distance), dim, dataset_size, max_edges_per_node));
std::move(distance), dim, dataset_size, max_edges_per_node, verbose));
}
throw std::invalid_argument("Invalid distance type: `" + dist_type +
"` during index construction. Valid options "
Expand All @@ -214,7 +231,7 @@ py::object createIndex(const std::string &distance_type, int dim,
void defineIndexSubmodule(py::module_ &index_submodule) {
index_submodule.def("index_factory", &createIndex, py::arg("distance_type"),
py::arg("dim"), py::arg("dataset_size"),
py::arg("max_edges_per_node"),
py::arg("max_edges_per_node"), py::arg("verbose") = false,
"Creates a FlatNav index given the corresponding "
"parameters. The `distance_type` argument determines the "
"kind of index created (either L2Index or IPIndex)");
Expand Down
2 changes: 2 additions & 0 deletions flatnav_python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ setuptools = "68.2.2"

[tool.poetry.dev-dependencies]
black = "^23.11.0"
pytest = "^7.4.3"
numpy = "^1.26.2"


[build-system]
Expand Down
2 changes: 1 addition & 1 deletion flatnav_python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
os.path.join(CURRENT_DIR, "..", "external", "cereal", "include"),
],
# Passing `-Wno-sign-compare` silences the warning emitted when you compare int with something like
# uint64_t.
# uint64_t.
extra_compile_args=["-Wno-sign-compare", "-fopenmp"],
)
]
Expand Down
61 changes: 61 additions & 0 deletions flatnav_python/test_index.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,64 @@
import flatnav
from flatnav.index import index_factory
from flatnav.index import L2Index, IPIndex
from typing import Union
import pytest
import numpy as np
import time


def generate_random_data(dataset_length: int, dim: int) -> np.ndarray:
    """Return a random matrix with one row per vector.

    Args:
        dataset_length: Number of rows (vectors) to generate.
        dim: Number of columns (components per vector).

    Returns:
        A ``(dataset_length, dim)`` array produced by ``np.random.rand``,
        i.e. samples drawn from NumPy's global random state.
    """
    shape = (dataset_length, dim)
    return np.random.rand(*shape)


def create_index(
    distance_type: str, dim: int, dataset_size: int, max_edges_per_node: int
) -> Union[L2Index, IPIndex]:
    """Build an index through ``index_factory`` and verify its type.

    The factory is always called with ``verbose=True``.

    Args:
        distance_type: Distance name forwarded to ``index_factory``.
        dim: Vector dimensionality forwarded to ``index_factory``.
        dataset_size: Dataset size forwarded to ``index_factory``.
        max_edges_per_node: Edge limit forwarded to ``index_factory``.

    Returns:
        The object returned by ``index_factory``.

    Raises:
        RuntimeError: If the returned object is neither a
            ``flatnav.index.L2Index`` nor a ``flatnav.index.IPIndex``.
    """
    index = index_factory(
        distance_type=distance_type,
        dim=dim,
        dataset_size=dataset_size,
        max_edges_per_node=max_edges_per_node,
        verbose=True,
    )

    # Accept only the two index classes exposed by the bindings.
    expected_types = (flatnav.index.L2Index, flatnav.index.IPIndex)
    if isinstance(index, expected_types):
        return index

    raise RuntimeError("Invalid index.")


def test_flatnav_l2_index():
    """Smoke-test an L2 index: build it on random data, then query it.

    Generates 60,000 random 784-dimensional vectors to index and 10,000
    random queries, checks the ``max_edges_per_node`` property, prints the
    wall-clock time of the ``add`` and ``search`` calls, and finally asserts
    that the returned distances and node ids have matching shapes.
    """
    dim = 784
    dataset_to_index = generate_random_data(dataset_length=60_000, dim=dim)
    queries = generate_random_data(dataset_length=10_000, dim=dim)

    index = create_index(
        distance_type="l2",
        dim=dim,
        dataset_size=dataset_to_index.shape[0],
        max_edges_per_node=32,
    )

    # The property must exist and echo back the value given at construction.
    assert hasattr(index, "max_edges_per_node")
    assert index.max_edges_per_node == 32

    # Time index construction.
    start = time.time()
    index.add(data=dataset_to_index, ef_construction=64)
    end = time.time()
    print(f"Indexing time = {end - start}")

    # Time the batched query.
    start = time.time()
    distances, node_ids = index.search(queries=queries, ef_search=64, K=100)
    end = time.time()
    print(f"Querying time = {end - start}")

    # Only shape agreement between the two outputs is verified here.
    assert distances.shape == node_ids.shape


"""
Indexing time = 693.3694415092468
Querying time = 48.112215518951416
"""

0 comments on commit 15a20c8

Please sign in to comment.