diff --git a/flatnav_python/bindings.cpp b/flatnav_python/bindings.cpp index 1220e02..0338fba 100644 --- a/flatnav_python/bindings.cpp +++ b/flatnav_python/bindings.cpp @@ -4,6 +4,8 @@ #include #include #include +#include +#include #include #include @@ -19,8 +21,10 @@ using flatnav::SquaredL2Distance; namespace py = pybind11; template class PythonIndex { + const uint32_t NUM_LOG_STEPS = 1000; private: int _dim, label_id; + bool _verbose; Index *_index; public: @@ -28,11 +32,12 @@ template class PythonIndex { DistancesLabelsPair; explicit PythonIndex(std::unique_ptr> index) - : _dim(index->dataDimension()), label_id(0), _index(index.get()) {} + : _dim(index->dataDimension()), label_id(0), _verbose(false), + _index(index.get()) {} PythonIndex(std::shared_ptr> distance, int dim, - int dataset_size, int max_edges_per_node) - : _dim(dim), label_id(0), + int dataset_size, int max_edges_per_node, bool verbose = false) + : _dim(dim), label_id(0), _verbose(verbose), _index(new Index( /* dist = */ std::move(distance), /* dataset_size = */ dataset_size, @@ -62,13 +67,19 @@ template class PythonIndex { throw std::invalid_argument("Data has incorrect dimensions."); } + std::clog << "[num-vectors] = " << num_vectors << std::flush; + std::clog << "[data_dim] = " << data_dim << std::flush; if (labels.is_none()) { for (size_t vec_index = 0; vec_index < num_vectors; vec_index++) { this->_index->add(/* data = */ (void *)data.data(vec_index), /* label = */ label_id, /* ef_construction = */ ef_construction); + if (_verbose && vec_index % NUM_LOG_STEPS == 0) { + std::clog << "." << std::flush; + } label_id++; } + std::clog << std::endl; return; } @@ -84,7 +95,12 @@ template class PythonIndex { this->_index->add(/* data = */ (void *)data.data(vec_index), /* label = */ label_id, /* ef_construction = */ ef_construction); + + if (_verbose && vec_index % NUM_LOG_STEPS == 0) { + std::clog << "." << std::flush; + } } + std::clog << std::endl; } DistancesLabelsPair @@ -182,7 +198,7 @@ void bindIndexMethods(py::class_ &index_class) { }, py::arg("algorithm"), "Perform graph re-ordering based on the given re-ordering strategy.") - .def_property_read_only( + .def_property_readonly( "max_edges_per_node", [](IndexType &index_type) { return index_type.getIndex()->maxEdgesPerNode(); @@ -192,7 +208,8 @@ void bindIndexMethods(py::class_ &index_class) { } py::object createIndex(const std::string &distance_type, int dim, - int dataset_size, int max_edges_per_node) { + int dataset_size, int max_edges_per_node, + bool verbose = false) { auto dist_type = distance_type; std::transform(dist_type.begin(), dist_type.end(), dist_type.begin(), [](unsigned char c) { return std::tolower(c); }); @@ -200,11 +217,11 @@ py::object createIndex(const std::string &distance_type, int dim, if (dist_type == "l2") { auto distance = std::make_shared(/* dim = */ dim); return py::cast(new L2FlatNavIndex(std::move(distance), dim, dataset_size, - max_edges_per_node)); + max_edges_per_node, verbose)); } else if (dist_type == "angular") { auto distance = std::make_shared(/* dim = */ dim); return py::cast(new InnerProductFlatNavIndex( - std::move(distance), dim, dataset_size, max_edges_per_node)); + std::move(distance), dim, dataset_size, max_edges_per_node, verbose)); } throw std::invalid_argument("Invalid distance type: `" + dist_type + "` during index construction. Valid options " @@ -214,7 +231,7 @@ py::object createIndex(const std::string &distance_type, int dim, void defineIndexSubmodule(py::module_ &index_submodule) { index_submodule.def("index_factory", &createIndex, py::arg("distance_type"), py::arg("dim"), py::arg("dataset_size"), - py::arg("max_edges_per_node"), + py::arg("max_edges_per_node"), py::arg("verbose") = false, "Creates a FlatNav index given the corresponding " "parameters. The `distance_type` argument determines the " "kind of index created (either L2Index or IPIndex)"); diff --git a/flatnav_python/pyproject.toml b/flatnav_python/pyproject.toml index 5e8031b..e1cb0fb 100644 --- a/flatnav_python/pyproject.toml +++ b/flatnav_python/pyproject.toml @@ -19,6 +19,8 @@ setuptools = "68.2.2" [tool.poetry.dev-dependencies] black = "^23.11.0" +pytest = "^7.4.3" +numpy = "^1.26.2" [build-system] diff --git a/flatnav_python/setup.py b/flatnav_python/setup.py index bbc18c4..85bc91c 100644 --- a/flatnav_python/setup.py +++ b/flatnav_python/setup.py @@ -29,7 +29,7 @@ os.path.join(CURRENT_DIR, "..", "external", "cereal", "include"), ], # Ignoring the `Wno-sign-compare` which warns you when you compare int with something like - # uint64_t. + # uint64_t. extra_compile_args=["-Wno-sign-compare", "-fopenmp"], ) ] diff --git a/flatnav_python/test_index.py b/flatnav_python/test_index.py index b610eb4..76f0d4f 100644 --- a/flatnav_python/test_index.py +++ b/flatnav_python/test_index.py @@ -1,3 +1,64 @@ +import flatnav from flatnav.index import index_factory +from flatnav.index import L2Index, IPIndex +from typing import Union +import pytest +import numpy as np +import time +def generate_random_data(dataset_length: int, dim: int) -> np.ndarray: + return np.random.rand(dataset_length, dim) + + +def create_index( + distance_type: str, dim: int, dataset_size: int, max_edges_per_node: int +) -> Union[L2Index, IPIndex]: + index = index_factory( + distance_type=distance_type, + dim=dim, + dataset_size=dataset_size, + max_edges_per_node=max_edges_per_node, + verbose=True + ) + if not ( + isinstance(index, flatnav.index.L2Index) + or isinstance(index, flatnav.index.IPIndex) + ): + raise RuntimeError("Invalid index.") + + return index + + +def test_flatnav_l2_index(): + dataset_to_index = generate_random_data(dataset_length=60_000, dim=784) + queries = generate_random_data(dataset_length=10_000, dim=784) + index = create_index( + distance_type="l2", + dim=dataset_to_index.shape[1], + dataset_size=len(dataset_to_index), + max_edges_per_node=32, + ) + + assert hasattr(index, "max_edges_per_node") + assert index.max_edges_per_node == 32 + + start = time.time() + index.add(data=dataset_to_index, ef_construction=64) + end = time.time() + + print(f"Indexing time = {end - start}") + + + start = time.time() + distances, node_ids = index.search(queries=queries, ef_search=64, K=100) + end = time.time() + print(f"Querying time = {end - start}") + + assert distances.shape == node_ids.shape + + +""" +Indexing time = 693.3694415092468 +Querying time = 48.112215518951416 +""" \ No newline at end of file