diff --git a/CMakeLists.txt b/CMakeLists.txt index e3a8f63..3157628 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -72,6 +72,8 @@ endif() set(HEADERS ${PROJECT_SOURCE_DIR}/flatnav/distances/InnerProductDistance.h ${PROJECT_SOURCE_DIR}/flatnav/distances/SquaredL2Distance.h + ${PROJECT_SOURCE_DIR}/flatnav/distances/L2DistanceDispatcher.h + ${PROJECT_SOURCE_DIR}/flatnav/distances/IPDistanceDispatcher.h ${PROJECT_SOURCE_DIR}/flatnav/util/SquaredL2SimdExtensions.h ${PROJECT_SOURCE_DIR}/flatnav/util/InnerProductSimdExtensions.h ${PROJECT_SOURCE_DIR}/flatnav/util/VisitedSetPool.h diff --git a/docs/cpp_api.rst b/docs/cpp_api.rst index 936f006..3503913 100644 --- a/docs/cpp_api.rst +++ b/docs/cpp_api.rst @@ -38,3 +38,9 @@ C++ API Documentation :members: :protected-members: :undoc-members: + +.. doxygenclass:: flatnav::util::DataType + :project: FlatNav + :members: + :protected-members: + :undoc-members: diff --git a/docs/flatnav_python.rst b/docs/flatnav_python.rst index ef6e893..950257e 100644 --- a/docs/flatnav_python.rst +++ b/docs/flatnav_python.rst @@ -3,10 +3,10 @@ FlatNav Index Module This module provides interfaces to create and manipulate FlatNav index structures. -index_factory +create ------------- -.. autofunction:: flatnav.index.index_factory +.. autofunction:: flatnav.index.create Index Classes ------------- diff --git a/experiments/Makefile b/experiments/Makefile index 8c08a5c..cc96e42 100644 --- a/experiments/Makefile +++ b/experiments/Makefile @@ -413,7 +413,7 @@ sift-bench-flatnav: --queries /root/data/sift-128-euclidean/sift-128-euclidean.test.npy \ --gtruth /root/data/sift-128-euclidean/sift-128-euclidean.gtruth.npy \ --index-type flatnav \ - --data-type uint8 \ + --data-type float32 \ --num-node-links 32 \ --ef-construction 30 40 50 100 200 300 \ --ef-search 100 200 300 500 1000 3000 \ @@ -621,6 +621,9 @@ s3-push: cleanup: rm -rf hnswlib-original +test-flatnav: + poetry run python test_flatnav.py + # If passed an invalid argument, print help message %: @echo "Invalid argument: $@" diff --git a/experiments/run-benchmark.py b/experiments/run-benchmark.py index 581144a..2681464 100644 --- a/experiments/run-benchmark.py +++ b/experiments/run-benchmark.py @@ -37,7 +37,7 @@ def compute_metrics( requested_metrics: List[str], - index: Union[flatnav.index.L2Index, flatnav.index.IPIndex, hnswlib.Index], + index: Union[hnswlib.Index, flatnav.index.IndexL2Float, flatnav.index.IndexIPFloat], queries: np.ndarray, ground_truth: np.ndarray, ef_search: int, @@ -169,7 +169,7 @@ def train_index( use_hnsw_base_layer: bool = False, hnsw_base_layer_filename: Optional[str] = None, num_build_threads: int = 1, -) -> Union[flatnav.index.L2Index, flatnav.index.IPIndex, hnswlib.Index]: +) -> Union[flatnav.index.IndexL2Float, flatnav.index.IndexIPFloat, hnswlib.Index]: """ Creates and trains an index on the given dataset. :param train_dataset: The dataset to train the index on. @@ -220,7 +220,7 @@ def train_index( if not os.path.exists(hnsw_base_layer_filename): raise ValueError(f"Failed to create {hnsw_base_layer_filename=}") - index = flatnav.index.index_factory( + index = flatnav.index.create( distance_type=distance_type, index_data_type=FLATNAV_DATA_TYPES[data_type], dim=dim, @@ -239,7 +239,7 @@ def train_index( os.remove(hnsw_base_layer_filename) else: - index = flatnav.index.index_factory( + index = flatnav.index.create( distance_type=distance_type, index_data_type=FLATNAV_DATA_TYPES[data_type], dim=dim, @@ -542,7 +542,7 @@ def plot_all_metrics( create_plot( experiment_runs=experiment_runs, - raw=True, + raw=False, x_scale="linear", y_scale="linear", x_axis_metric=x_metric, diff --git a/flatnav/distances/DistanceInterface.h b/flatnav/distances/DistanceInterface.h index 44e6813..3a58d61 100644 --- a/flatnav/distances/DistanceInterface.h +++ b/flatnav/distances/DistanceInterface.h @@ -2,20 +2,13 @@ #include #include // for size_t -#include #include // for ifstream, ofstream #include #include namespace flatnav::distances { -using util::DataType; -typedef std::function - DistanceFunction; - -typedef std::unique_ptr DistanceFunctionPtr; - -enum class METRIC_TYPE { EUCLIDEAN, INNER_PRODUCT }; +enum class MetricType { L2, IP }; // We use the CRTP to implement static polymorphism on the distance. This is // done to allow for metrics and distance functions that support arbitrary @@ -43,7 +36,6 @@ template class DistanceInterface { // Prints the parameters of the distance function. void getSummary() { static_cast(this)->getSummaryImpl(); } - // This transforms the data located at src into a form that is writeable // to disk / storable in RAM. For distance functions that don't // compress the input, this just passses through a copy from src to diff --git a/flatnav/distances/IPDistanceDispatcher.h b/flatnav/distances/IPDistanceDispatcher.h new file mode 100644 index 0000000..1e5149c --- /dev/null +++ b/flatnav/distances/IPDistanceDispatcher.h @@ -0,0 +1,103 @@ +#pragma once + +#include +#include +#include + +namespace flatnav::distances { + +template +static float defaultInnerProduct(const T *x, const T *y, + const size_t &dimension) { + float inner_product = 0; + for (size_t i = 0; i < dimension; i++) { + inner_product += x[i] * y[i]; + } + return 1.0f - inner_product; +} + +template struct InnerProductImpl { + static float computeDistance(const T *x, const T *y, + const size_t &dimension) { + return defaultInnerProduct(x, y, dimension); + } +}; + +template <> struct InnerProductImpl { + static float computeDistance(const float *x, const float *y, + const size_t &dimension) { +#if defined(USE_AVX512) + if (platformSupportsAvx512()) { + if (dimension % 16 == 0) { + return util::computeIP_Avx512(x, y, dimension); + } + if (dimension % 4 == 0) { +#if defined(USE_AVX) + return util::computeIP_Avx_4aligned(x, y, dimension); +#else + return util::computeIP_Sse4Aligned(x, y, dimension); +#endif + } else if (dimension > 16) { + return util::computeIP_SseWithResidual_16(x, y, dimension); + } else if (dimension > 4) { + return util::computeIP_SseWithResidual_4(x, y, dimension); + } + } +#endif + +#if defined(USE_AVX) + if (platformSupportsAvx()) { + if (dimension % 16 == 0) { + return util::computeIP_Avx(x, y, dimension); + } + if (dimension % 4 == 0) { + return util::computeIP_Avx_4aligned(x, y, dimension); + } else if (dimension > 16) { + return util::computeIP_SseWithResidual_16(x, y, dimension); + } else if (dimension > 4) { + return util::computeIP_SseWithResidual_4(x, y, dimension); + } + } +#endif + +#if defined(USE_SSE) + if (dimension % 16 == 0) { + return util::computeIP_Sse(x, y, dimension); + } + if (dimension % 4 == 0) { + return util::computeIP_Sse_4aligned(x, y, dimension); + } else if (dimension > 16) { + return util::computeIP_SseWithResidual_16(x, y, dimension); + } else if (dimension > 4) { + return util::computeIP_SseWithResidual_4(x, y, dimension); + } + +#endif + return defaultInnerProduct(x, y, dimension); + } +}; + +// TODO: Include SIMD optimized implementations for int8_t. +template <> struct InnerProductImpl { + static float computeDistance(const int8_t *x, const int8_t *y, + const size_t &dimension) { + return defaultInnerProduct(x, y, dimension); + } +}; + +// TODO: Include SIMD optimized implementations for uint8_t. +template <> struct InnerProductImpl { + static float computeDistance(const uint8_t *x, const uint8_t *y, + const size_t &dimension) { + return defaultInnerProduct(x, y, dimension); + } +}; + +struct IPDistanceDispatcher { + template + static float dispatch(const T *x, const T *y, const size_t &dimension) { + return InnerProductImpl::computeDistance(x, y, dimension); + } +}; + +} // namespace flatnav::distances \ No newline at end of file diff --git a/flatnav/distances/InnerProductDistance.h b/flatnav/distances/InnerProductDistance.h index ac59195..79c2f1c 100644 --- a/flatnav/distances/InnerProductDistance.h +++ b/flatnav/distances/InnerProductDistance.h @@ -5,6 +5,7 @@ #include // for size_t #include // for memcpy #include +#include #include #include #include @@ -17,129 +18,11 @@ namespace flatnav::distances { // on floating-point inputs. using util::DataType; +using util::type_for_data_type; -template -struct OptimalInnerProductSimdSelector { - - static void - adjustForNonOptimalDimensions(DistanceFunctionPtr &distance_function, - const size_t &dimension) { -#if defined(USE_SSE) || defined(USE_AVX) - - if (dimension % 16 != 0) { - if (dimension % 4 == 0) { -#if defined(USE_AVX) - distance_function = std::make_unique( - std::bind(&util::computeIP_Avx_4aligned, std::placeholders::_1, - std::placeholders::_2, std::cref(dimension))); -#else - distance_function = std::make_unique( - std::bind(&util::computeIP_Sse_4aligned, std::placeholders::_1, - std::placeholders::_2, std::cref(dimension))); - -#endif // USE_AVX - } else if (dimension > 16) { - distance_function = std::make_unique(std::bind( - &util::computeIP_SseWithResidual_16, std::placeholders::_1, - std::placeholders::_2, std::cref(dimension))); - - } else if (dimension > 4) { - distance_function = std::make_unique( - std::bind(&util::computeIP_SseWithResidual_4, std::placeholders::_1, - std::placeholders::_2, std::cref(dimension))); - } - } -#endif // USE_SSE || USE_AVX - } - - static void selectInt8(DistanceFunctionPtr &distance_function, - const size_t &dimension) { - (void)dimension; - (void)distance_function; - throw std::runtime_error("Not implemented"); - } - - static void selectUint8(DistanceFunctionPtr &distance_function, - const size_t &dimension) { - (void)dimension; - (void)distance_function; - throw std::runtime_error("Not implemented"); - } - - static void selectFloat32(DistanceFunctionPtr &distance_function, - const size_t &dimension) { -#if defined(USE_SSE) - distance_function = std::make_unique( - std::bind(&util::computeIP_Sse, std::placeholders::_1, - std::placeholders::_2, std::cref(dimension))); - -#endif // USE_SSE - -#if defined(USE_AVX512) - if (platformSupportsAvx512) { - distance_function = std::make_unique( - std::bind(&util::computeIP_Avx512, std::placeholders::_1, - std::placeholders::_2, std::cref(dimension))); - adjustForNonOptimalDimensions(distance_function, dimension); - return; - } - -#endif // USE_AVX512 - -#if defined(USE_AVX) - if (platformSupportsAvx) { - distance_function = std::make_unique( - std::bind(&util::computeIP_Avx, std::placeholders::_1, - std::placeholders::_2, std::cref(dimension))); - adjustForNonOptimalDimensions(distance_function, dimension); - return; - } - -#endif // USE_AVX - - adjustForNonOptimalDimensions(distance_function, dimension); - } - - /** - * @brief Select the optimal distance function based on the input dimension - * @param dimension The dimension of the input data - * - * @note There are different SIMD functions for float32, int8_t and uint8_t. - * This is why we are templating this class on the data type. - */ - static void select(DistanceFunctionPtr &distance_function, - const size_t &dimension) { - switch (data_type) { - case DataType::float32: - selectFloat32(distance_function, dimension); - break; - case DataType::uint8: - selectUint8(distance_function, dimension); - break; - case DataType::int8: - selectInt8(distance_function, dimension); - break; - default: - throw std::runtime_error("Unsupported data type"); - } - } -}; - -struct DefaultInnerProduct { - template - static constexpr float compute(const void *x, const void *y, - const size_t &dimension) { - T *p_x = const_cast(static_cast(x)); - T *p_y = const_cast(static_cast(y)); - float result = 0; - for (size_t i = 0; i < dimension; i++) { - result += p_x[i] * p_y[i]; - } - return 1.0 - result; - } -}; - -class InnerProductDistance : public DistanceInterface { +template +class InnerProductDistance + : public DistanceInterface> { friend class DistanceInterface; // Enum for compile-time constant @@ -147,75 +30,30 @@ class InnerProductDistance : public DistanceInterface { public: InnerProductDistance() = default; - InnerProductDistance(size_t dim, DataType data_type = DataType::float32) - : _data_type(data_type), _dimension(dim), - _data_size_bytes(dim * flatnav::util::size(data_type)), - _distance_computer(nullptr) {} - - template - static std::unique_ptr create(size_t dim) { - if (data_type == DataType::undefined) { - throw std::runtime_error("Undefined data type"); - } - - return std::make_unique(dim, data_type); - } - - float distanceImpl(const void *x, const void *y, - bool asymmetric = false) const { - (void)asymmetric; - return (*_distance_computer)(x, y, _dimension); - } + InnerProductDistance(size_t dim) + : _dimension(dim), + _data_size_bytes(dim * flatnav::util::size(data_type)) {} - inline constexpr DataType dataTypeImpl() const { return _data_type; } - - /** - * @brief Dispatcher for templating setDistanceFunction the right data type - */ - void setDistanceFunctionWithType() { -#ifndef NO_SIMD_VECTORIZATION - switch (_data_type) { - case DataType::float32: - setDistanceFunction(); - break; - case DataType::uint8: - setDistanceFunction(); - break; - - case DataType::int8: - setDistanceFunction(); - break; - - default: - throw std::runtime_error("Unsupported data type"); - } -#endif // NO_SIMD_VECTORIZATION + static std::unique_ptr> create(size_t dim) { + return std::make_unique>(dim); } - template void setDistanceFunction() { - _distance_computer = std::make_unique(std::bind( - &InnerProductDistance::defaultDistanceImpl, this, std::placeholders::_1, - std::placeholders::_2, std::placeholders::_3)); - - OptimalInnerProductSimdSelector::select(_distance_computer, - _dimension); + constexpr float distanceImpl(const void *x, const void *y, + [[maybe_unused]] bool asymmetric = false) const { + return IPDistanceDispatcher::dispatch( + static_cast::type *>(x), + static_cast::type *>(y), + _dimension); } private: - DataType _data_type; size_t _dimension; size_t _data_size_bytes; - DistanceFunctionPtr _distance_computer; friend class cereal::access; template void serialize(Archive &ar) { - ar(_data_type, _dimension, _data_size_bytes); - - // If loading, we need to set the data size bytes - if (Archive::is_loading::value) { - setDistanceFunctionWithType(); - } + ar(_dimension, _data_size_bytes); } inline size_t getDimension() const { return _dimension; } @@ -233,23 +71,6 @@ class InnerProductDistance : public DistanceInterface { << std::flush; std::cout << "Dimension: " << _dimension << "\n" << std::flush; } - - float defaultDistanceImpl(const void *x, const void *y, - const size_t &dimension) const { - // Default implementation of inner product distance, in case we cannot - // support the SIMD specializations for special input _dimension sizes. - - if (_data_type == DataType::float32) { - return DefaultInnerProduct::compute(x, y, dimension); - } - else if (_data_type == DataType::uint8) { - return DefaultInnerProduct::compute(x, y, dimension); - } - else if (_data_type == DataType::int8) { - return DefaultInnerProduct::compute(x, y, dimension); - } - throw std::runtime_error("Unsupported data type"); - } }; } // namespace flatnav::distances \ No newline at end of file diff --git a/flatnav/distances/L2DistanceDispatcher.h b/flatnav/distances/L2DistanceDispatcher.h index 1c02b04..7bd48b5 100644 --- a/flatnav/distances/L2DistanceDispatcher.h +++ b/flatnav/distances/L2DistanceDispatcher.h @@ -41,18 +41,45 @@ template <> struct SquaredL2Impl { const size_t &dimension) { #if defined(USE_AVX512) if (platformSupportsAvx512()) { - return util::computeL2_Avx512(x, y, dimension); + if (dimension % 16 == 0) { + return util::computeL2_Avx512(x, y, dimension); + } + if (dimension % 4 == 0) { + return util::computeL2_Sse4Aligned(x, y, dimension); + } else if (dimension > 16) { + return util::computeL2_SseWithResidual_16(x, y, dimension); + } else if (dimension > 4) { + return util::computeL2_SseWithResidual_4(x, y, dimension); + } } #endif #if defined(USE_AVX) if (platformSupportsAvx()) { - return util::computeL2_Avx2(x, y, dimension); + if (dimension % 16 == 0) { + return util::computeL2_Avx2(x, y, dimension); + } + if (dimension % 4 == 0) { + return util::computeL2_Sse4Aligned(x, y, dimension); + } else if (dimension > 16) { + return util::computeL2_SseWithResidual_16(x, y, dimension); + } else if (dimension > 4) { + return util::computeL2_SseWithResidual_4(x, y, dimension); + } } #endif #if defined(USE_SSE) - return util::computeL2_Sse(x, y, dimension); + if (dimension % 16 == 0) { + return util::computeL2_Sse(x, y, dimension); + } + if (dimension % 4 == 0) { + return util::computeL2_Sse4Aligned(x, y, dimension); + } else if (dimension > 16) { + return util::computeL2_SseWithResidual_16(x, y, dimension); + } else if (dimension > 4) { + return util::computeL2_SseWithResidual_4(x, y, dimension); + } #else return defaultSquaredL2(x, y, dimension); #endif @@ -62,11 +89,11 @@ template <> struct SquaredL2Impl { template <> struct SquaredL2Impl { static float computeDistance(const int8_t *x, const int8_t *y, const size_t &dimension) { -#if defined(USE_AVX512BW) && defined(USE_AVX512VNNI) - if (platformSupportsAvx512()) { - return flatnav::util::computeL2_Avx512_int8(x, y, dimension); - } -#endif +// #if defined(USE_AVX512BW) && defined(USE_AVX512VNNI) +// if (platformSupportsAvx512()) { +// return flatnav::util::computeL2_Avx512_int8(x, y, dimension); +// } +// #endif #if defined(USE_SSE) return flatnav::util::computeL2_Sse_int8(x, y, dimension); #endif @@ -79,7 +106,9 @@ template <> struct SquaredL2Impl { const size_t &dimension) { #if defined(USE_AVX512) if (platformSupportsAvx512()) { - return util::computeL2_Avx512_Uint8(x, y, dimension); + if (dimension % 64 == 0) { + return util::computeL2_Avx512_Uint8(x, y, dimension); + } } #endif diff --git a/flatnav/distances/SquaredL2Distance.h b/flatnav/distances/SquaredL2Distance.h index dc61c1d..7e6c67c 100644 --- a/flatnav/distances/SquaredL2Distance.h +++ b/flatnav/distances/SquaredL2Distance.h @@ -21,128 +21,6 @@ namespace flatnav::distances { using util::DataType; using util::type_for_data_type; -template struct OptimalL2SimdSelector { - - static void - adjustForNonOptimalDimensions(DistanceFunctionPtr &distance_function, - const size_t &dimension) { -#if defined(USE_SSE) - if (dimension % 16 != 0) { - if (dimension % 4 == 0) { - distance_function = std::make_unique( - std::bind(&util::computeL2_Sse4Aligned, std::placeholders::_1, - std::placeholders::_2, std::cref(dimension))); - - } else if (dimension > 16) { - distance_function = std::make_unique(std::bind( - &util::computeL2_SseWithResidual_16, std::placeholders::_1, - std::placeholders::_2, std::cref(dimension))); - - } else if (dimension > 4) { - distance_function = std::make_unique( - std::bind(&util::computeL2_SseWithResidual_4, std::placeholders::_1, - std::placeholders::_2, std::cref(dimension))); - } - } -#endif // USE_SSE - } - - static void selectInt8(DistanceFunctionPtr &distance_function, - const size_t &dimension) { -#if defined(USE_AVX512) - if (platformSupportsAvx512()) { - distance_function = std::make_unique( - std::bind(&util::computeL2_Sse_int8, std::placeholders::_1, - std::placeholders::_2, std::cref(dimension))); - } -#endif // USE_AVX512 - } - - static void selectUint8(DistanceFunctionPtr &distance_function, - const size_t &dimension) { -#if defined(USE_AVX512) - distance_function = std::make_unique( - std::bind(&util::computeL2_Avx512_Uint8, std::placeholders::_1, - std::placeholders::_2, std::cref(dimension))); -#endif // USE_AVX512 - } - - static void selectFloat32(DistanceFunctionPtr &distance_function, - const size_t &dimension) { -#if defined(USE_SSE) - distance_function = std::make_unique( - std::bind(&util::computeL2_Sse, std::placeholders::_1, - std::placeholders::_2, std::cref(dimension))); -#endif // USE_SSE - -#if defined(USE_AVX512) - if (platformSupportsAvx512()) { - distance_function = std::make_unique( - std::bind(&util::computeL2_Avx512, std::placeholders::_1, - std::placeholders::_2, std::cref(dimension))); - - adjustForNonOptimalDimensions(distance_function, dimension); - return; - } -#endif // USE_AVX512 - -#if defined(USE_AVX) - if (platformSupportsAvx()) { - - distance_function = std::make_unique( - std::bind(&util::computeL2_Avx2, std::placeholders::_1, - std::placeholders::_2, std::cref(dimension))); - adjustForNonOptimalDimensions(distance_function, dimension); - return; - } -#endif // USE_AVX - - adjustForNonOptimalDimensions(distance_function, dimension); - } - - /** - * @brief Select the optimal distance function based on the input dimension - * @param dimension The dimension of the input data - * - * @note There are different SIMD functions for float32, int8_t and uint8_t. - * This is why we are templating this class on the data type. - */ - static void select(DistanceFunctionPtr &distance_function, - const size_t &dimension) { - switch (data_type) { - case DataType::float32: - selectFloat32(distance_function, dimension); - break; - case DataType::int8: - selectInt8(distance_function, dimension); - break; - case DataType::uint8: - selectUint8(distance_function, dimension); - break; - default: - throw std::runtime_error("Unsupported data type"); - } - } -}; - -struct DefaultSquaredL2 { - template - static constexpr float compute(const void *x, const void *y, - const size_t &dimension) { - T *p_x = const_cast(static_cast(x)); - T *p_y = const_cast(static_cast(y)); - float squared_distance = 0; - for (size_t i = 0; i < dimension; i++) { - float difference = *p_x - *p_y; - p_x++; - p_y++; - squared_distance += difference * difference; - } - return squared_distance; - } -}; - - /** * @brief The SquaredL2Distance class is designed to balance compile-time and * runtime dispatching for efficient distance computation. @@ -199,8 +77,7 @@ class SquaredL2Distance inline constexpr size_t getDimension() const { return _dimension; } constexpr float distanceImpl(const void *x, const void *y, - bool asymmetric = false) const { - (void)asymmetric; + [[maybe_unused]] bool asymmetric = false) const { return L2DistanceDispatcher::dispatch( static_cast::type *>(x), static_cast::type *>(y), @@ -230,18 +107,6 @@ class SquaredL2Distance << std::flush; std::cout << "Dimension: " << _dimension << "\n" << std::flush; } - - float defaultDistanceImpl(const void *x, const void *y, - const size_t &dimension) const { - if (data_type == DataType::float32) { - return DefaultSquaredL2::compute(x, y, dimension); - } else if (data_type == DataType::int8) { - return DefaultSquaredL2::compute(x, y, dimension); - } else if (data_type == DataType::uint8) { - return DefaultSquaredL2::compute(x, y, dimension); - } - throw std::runtime_error("Unsupported data type"); - } }; } // namespace flatnav::distances diff --git a/flatnav/index/Index.h b/flatnav/index/Index.h index 7c5330e..97d4593 100644 --- a/flatnav/index/Index.h +++ b/flatnav/index/Index.h @@ -22,10 +22,9 @@ #include #include -using flatnav::util::VisitedSetPool; -using flatnav::util::VisitedSet; using flatnav::distances::DistanceInterface; - +using flatnav::util::VisitedSet; +using flatnav::util::VisitedSetPool; namespace flatnav { @@ -445,7 +444,8 @@ template class Index { /* initial_pool_size = */ 1, /* num_elements = */ index->_max_node_count); index->_distance = std::move(dist); - index->_num_threads = std::max((uint32_t)1, (uint32_t)std::thread::hardware_concurrency() / 2); + index->_num_threads = std::max( + (uint32_t)1, (uint32_t)std::thread::hardware_concurrency() / 2); index->_node_links_mutexes = std::vector(index->_max_node_count); @@ -509,10 +509,6 @@ template class Index { inline size_t currentNumNodes() const { return _cur_num_nodes; } inline size_t dataDimension() const { return _distance->dimension(); } - inline constexpr util::DataType dataType() const { - return _distance->dataType(); - } - inline uint64_t distanceComputations() const { return _distance_computations.load(); } @@ -592,6 +588,11 @@ template class Index { auto *visited_set = _visited_set_pool->pollAvailableSet(); visited_set->clear(); + // Prefetch the data for entry node before computing its distance. +#ifdef USE_SSE + _mm_prefetch(getNodeData(entry_node), _MM_HINT_T0); +#endif + float dist = _distance->distance(/* x = */ query, /* y = */ getNodeData(entry_node), /* asymmetric = */ true); @@ -602,15 +603,28 @@ template class Index { visited_set->insert(entry_node); while (!candidates.empty()) { - dist_node_t d_node = candidates.top(); + auto [distance, node] = candidates.top(); - if ((-d_node.first) > max_dist && neighbors.size() >= buffer_size) { + if (-distance > max_dist && neighbors.size() >= buffer_size) { break; } candidates.pop(); + // Prefetching the next candidate node data and visited set marker + // before processing it. Note that this might not be useful if the current + // iteration finds a neighbor that is closer than the current max + // distance. In that case we would have prefetched data that is not used + // immediately, but I think the cost of prefetching is low enough that + // it's probably worth it. +#ifdef USE_SSE + if (!candidates.empty()) { + _mm_prefetch(getNodeData(candidates.top().second), _MM_HINT_T0); + visited_set->prefetch(candidates.top().second); + } +#endif + processCandidateNode( - /* query = */ query, /* node = */ d_node.second, + /* query = */ query, /* node = */ node, /* max_dist = */ max_dist, /* buffer_size = */ buffer_size, /* visited_set = */ visited_set, /* neighbors = */ neighbors, /* candidates = */ candidates); diff --git a/flatnav/tests/CMakeLists.txt b/flatnav/tests/CMakeLists.txt index 156d07b..e6676de 100644 --- a/flatnav/tests/CMakeLists.txt +++ b/flatnav/tests/CMakeLists.txt @@ -3,7 +3,7 @@ enable_testing() include(GoogleTest) # Add test executables here -set(FLAT_NAV_LIB_TESTS test_distances)# test_simd_int8_instructions test_serialization) +set(FLAT_NAV_LIB_TESTS test_distances test_serialization) foreach(TEST IN LISTS FLAT_NAV_LIB_TESTS) add_executable(${TEST} ${TEST}.cpp) diff --git a/flatnav/tests/test_distances.cpp b/flatnav/tests/test_distances.cpp index 29e2bd4..d7a4f00 100644 --- a/flatnav/tests/test_distances.cpp +++ b/flatnav/tests/test_distances.cpp @@ -37,7 +37,8 @@ class DistanceTest : public ::testing::Test { TEST_F(DistanceTest, TestAvx512L2Distance) { #if defined(USE_AVX512) float result = flatnav::util::computeL2_Avx512(x, y, dimensions); - float expected = flatnav::distances::DefaultSquaredL2::compute(x, y, dimensions); + float expected = + flatnav::distances::defaultSquaredL2(x, y, dimensions); ASSERT_NEAR(result, expected, epsilon); #endif @@ -60,7 +61,7 @@ TEST_F(DistanceTest, TestAvx512L2DistanceUint8) { uint8_t *y = y_matrix + i * dimensions; float result = flatnav::util::computeL2_Avx512_Uint8(x, y, dimensions); float expected = - flatnav::distances::DefaultSquaredL2::compute(x, y, dimensions); + flatnav::distances::defaultSquaredL2(x, y, dimensions); ASSERT_NEAR(result, expected, epsilon); } @@ -75,7 +76,8 @@ TEST_F(DistanceTest, TestAvxL2Distance) { #if defined(USE_AVX) float result = flatnav::util::computeL2_Avx2(x, y, dimensions); - float expected = flatnav::distances::DefaultSquaredL2::compute(x, y, dimensions); + float expected = + flatnav::distances::defaultSquaredL2(x, y, dimensions); ASSERT_NEAR(result, expected, epsilon); @@ -104,24 +106,25 @@ TEST(TestSingleIntrinsic, TestReduceAddSse) { TEST_F(DistanceTest, TestSseL2Distance) { #if defined(USE_SSE) float result = flatnav::util::computeL2_Sse(x, y, dimensions); - float expected = flatnav::distances::DefaultSquaredL2::compute(x, y, dimensions); + float expected = + flatnav::distances::defaultSquaredL2(x, y, dimensions); ASSERT_NEAR(result, expected, epsilon); // try with dimensions not divisible by 16 // this will just take the first 100 elements in the arrays result = flatnav::util::computeL2_Sse4Aligned(x, y, 100); - expected = flatnav::distances::DefaultSquaredL2::compute(x, y, 100); + expected = flatnav::distances::defaultSquaredL2(x, y, 100); ASSERT_NEAR(result, expected, epsilon); // try with dimensions not divisible by 4 result = flatnav::util::computeL2_SseWithResidual_16(x, y, 37); - expected = flatnav::distances::DefaultSquaredL2::compute(x, y, 37); + expected = flatnav::distances::defaultSquaredL2(x, y, 37); ASSERT_NEAR(result, expected, epsilon); // try with dimensions not divisible by 4 and less than 16 result = flatnav::util::computeL2_SseWithResidual_4(x, y, 7); - expected = flatnav::distances::DefaultSquaredL2::compute(x, y, 7); + expected = flatnav::distances::defaultSquaredL2(x, y, 7); ASSERT_NEAR(result, expected, epsilon); #endif @@ -131,8 +134,7 @@ TEST_F(DistanceTest, TestSseL2Distance) { TEST_F(DistanceTest, TestAvx512InnerProductDistance) { #if defined(USE_AVX512) float result = flatnav::util::computeIP_Avx512(x, y, dimensions); - float expected = - flatnav::distances::DefaultInnerProduct::compute(x, y, dimensions); + float expected = flatnav::distances::defaultInnerProduct(x, y, dimensions); ASSERT_NEAR(result, expected, epsilon); #endif @@ -142,8 +144,7 @@ TEST_F(DistanceTest, TestAvx512InnerProductDistance) { TEST_F(DistanceTest, TestAvxInnerProductDistance) { #if defined(USE_AVX) float result = flatnav::util::computeIP_Avx(x, y, dimensions); - float expected = - flatnav::distances::DefaultInnerProduct::compute(x, y, dimensions); + float expected = flatnav::distances::defaultInnerProduct(x, y, dimensions); ASSERT_NEAR(result, expected, epsilon); #endif @@ -153,50 +154,32 @@ TEST_F(DistanceTest, TestAvxInnerProductDistance) { TEST_F(DistanceTest, TestSseInnerProductDistance) { #if defined(USE_SSE) float result = flatnav::util::computeIP_Sse(x, y, dimensions); - float expected = - flatnav::distances::DefaultInnerProduct::compute(x, y, dimensions); + float expected = flatnav::distances::defaultInnerProduct(x, y, dimensions); ASSERT_NEAR(result, expected, epsilon); // try with dimensions not divisible by 16 // this will just take the first 100 elements in the arrays result = flatnav::util::computeIP_Sse_4aligned(x, y, 100); - expected = flatnav::distances::DefaultInnerProduct::compute(x, y, 100); + expected = flatnav::distances::defaultInnerProduct(x, y, 100); ASSERT_NEAR(result, expected, epsilon); #if defined(USE_AVX) result = flatnav::util::computeIP_Avx_4aligned(x, y, 100); - expected = flatnav::distances::DefaultInnerProduct::compute(x, y, 100); + expected = flatnav::distances::defaultInnerProduct(x, y, 100); ASSERT_NEAR(result, expected, epsilon); #endif // try with dimensions not divisible by 4 result = flatnav::util::computeIP_SseWithResidual_16(x, y, 37); - expected = flatnav::distances::DefaultInnerProduct::compute(x, y, 37); + expected = flatnav::distances::defaultInnerProduct(x, y, 37); ASSERT_NEAR(result, expected, epsilon); // try with dimensions not divisible by 4 and less than 16 result = flatnav::util::computeIP_SseWithResidual_4(x, y, 7); - expected = flatnav::distances::DefaultInnerProduct::compute(x, y, 7); + expected = flatnav::distances::defaultInnerProduct(x, y, 7); ASSERT_NEAR(result, expected, epsilon); #endif } -TEST(TestSquaredL2Distance, TestSimple) { - float x[4] = {1.0f, 2.0f, 3.0f, 4.0f}; - float y[4] = {5.0f, 6.0f, 7.0f, 8.0f}; - - #undef USE_SSE - #undef USE_AVX - #undef USE_AVX512 - - auto distance = flatnav::distances::SquaredL2Distance::create(4); - float result = distance->distanceImpl(x, y); - float expected = 32.0f; - ASSERT_NEAR(result, expected, 1e-6); - -} - - - } // namespace flatnav::testing \ No newline at end of file diff --git a/flatnav/tests/test_serialization.cpp b/flatnav/tests/test_serialization.cpp index 3bbc79d..3002ba3 100644 --- a/flatnav/tests/test_serialization.cpp +++ b/flatnav/tests/test_serialization.cpp @@ -2,13 +2,13 @@ #include #include // for remove #include -#include #include #include +#include #include -using flatnav::distances::DistanceInterface; using flatnav::Index; +using flatnav::distances::DistanceInterface; using flatnav::distances::InnerProductDistance; using flatnav::distances::SquaredL2Distance; diff --git a/flatnav/tests/test_simd_int8_instructions.cpp b/flatnav/tests/test_simd_int8_instructions.cpp deleted file mode 100644 index 1d9cea0..0000000 --- a/flatnav/tests/test_simd_int8_instructions.cpp +++ /dev/null @@ -1,69 +0,0 @@ -#include "gtest/gtest.h" -#include -#include // for remove -#include -#include -#include -#include -#include -#include - -using flatnav::distances::DistanceInterface; -using flatnav::Index; -using flatnav::distances::SquaredL2Distance; -using flatnav::util::DataType; - -namespace flatnav::testing { - -static const uint32_t INDEXED_VECTORS = 2; -static const uint32_t VEC_DIM = 100; - -void printVector(void *vector, uint32_t dim) { - for (uint32_t i = 0; i < dim; i++) { - printf("%d ", ((int8_t *)vector)[i]); - } - printf("\n"); -} - -std::vector generateTestVectors(uint32_t num_vectors, uint32_t dim) { - std::vector vectors(num_vectors * dim); - for (uint32_t i = 0; i < num_vectors * dim; i++) { - vectors[i] = (int8_t)(rand() % 256); - } - return vectors; -} - -TEST(SIMD_INT8_TESTS, TestSquaredL2Distance) { - // This test checks that the computed distance with the int8_t simd - // instructions is the same as the float simd instructions. - auto vectors = generateTestVectors(INDEXED_VECTORS, VEC_DIM); - - // Make a copy that casts each value to float. - std::vector float_vectors(vectors.begin(), vectors.end()); - - auto int8_distance = SquaredL2Distance::create(VEC_DIM); - int8_distance->setDistanceFunction(); - - auto float_distance = SquaredL2Distance::create(VEC_DIM); - float_distance->setDistanceFunction(); - - for (uint32_t i = 0; i < INDEXED_VECTORS - 1; i++) { - for (uint32_t j = i + 1; j < INDEXED_VECTORS; j++) { - int8_t *first_int8_vector = vectors.data() + (VEC_DIM * i); - int8_t *second_int8_vector = vectors.data() + (VEC_DIM * j); - - printVector(first_int8_vector, VEC_DIM); - printVector(second_int8_vector, VEC_DIM); - - float *first_float_vector = float_vectors.data() + (VEC_DIM * i); - float *second_float_vector = float_vectors.data() + (VEC_DIM * j); - - ASSERT_FLOAT_EQ( - int8_distance->distanceImpl(first_int8_vector, second_int8_vector), - float_distance->distanceImpl(first_float_vector, - second_float_vector)); - } - } -} - -} // namespace flatnav::testing diff --git a/flatnav/util/Datatype.h b/flatnav/util/Datatype.h index 2a9ecfd..c7dff7d 100644 --- a/flatnav/util/Datatype.h +++ b/flatnav/util/Datatype.h @@ -6,27 +6,27 @@ namespace flatnav::util { /** * @brief Enum class for data types - * Currently, only float32 is supported for index building. + * We currently support indexes of type float32, uint8 and int8. */ enum class DataType { - uint8, - uint16, - uint32, - uint64, - int8, - int16, - int32, - int64, - float16, - float32, - float64, - undefined + uint8, /** Unsigned 8-bit integer */ + uint16, /** Unsigned 16-bit integer */ + uint32, /** Unsigned 32-bit integer */ + uint64, /** Unsigned 64-bit integer */ + int8, /** Signed 8-bit integer */ + int16, /** Signed 16-bit integer */ + int32, /** Signed 32-bit integer */ + int64, /** Signed 64-bit integer */ + float16, /** 16-bit floating-point number */ + float32, /** 32-bit floating-point number */ + float64, /** 64-bit floating-point number */ + undefined /** Undefined data type */ }; /** * @brief Get a string representation of the data type */ -inline constexpr const char* name(DataType data_type) { +inline constexpr const char *name(DataType data_type) { switch (data_type) { case DataType::uint8: return "uint8"; @@ -130,6 +130,51 @@ template <> struct type_for_data_type { using type = uint8_t; }; +/** + * @brief Template metaprogramming to allow compile-time distance dispatching + * for each data type + * This is useful for iterating over each data type in a compile-time loop. + * One place where this is used is in python bindings to generate the Index + * class for each one of the supported data types. Here is a simple example of + * how to use this: + * @code + * struct Callable { + * template void operator()() { + * std::cout << "Data type: " << name(data_type) << std::endl; + * } + * }; + * for_each_data_type::apply(Callable()); + * // If you have multiple data types, you can pass them as template arguments + * like this: for_each_data_type::apply(Callable()); + * @endcode + * @tparam F A callable object + * @tparam data_types The data types to iterate over + */ +template struct for_each_data_type; + +/** + * @brief Template specialization for for_each_data_type when there are data + * types to iterate over + * @tparam F A callable object + * @tparam data_type The current data type + * @tparam rest The remaining data types + */ +template +struct for_each_data_type { + static void apply(F &&f) { + f.template operator()(); + for_each_data_type::apply(std::forward(f)); + } +}; +/** + * @brief Template specialization for for_each_data_type when there are no data + * types to iterate over + * @tparam F A callable object + */ +template struct for_each_data_type { + static void apply(F &&) {} +}; } // namespace flatnav::util \ No newline at end of file diff --git a/flatnav/util/Macros.h b/flatnav/util/Macros.h index b9ad49f..1d58ad0 100644 --- a/flatnav/util/Macros.h +++ b/flatnav/util/Macros.h @@ -20,15 +20,11 @@ #ifdef __AVX512BW__ #define USE_AVX512BW -#else -#error "AVX512BW not supported by the compiler" #endif // __AVX512BW__ -// #ifdef __AVX512VNNI__ -// #define USE_AVX512VNNI -// #else -// #error "AVX512VNNI not supported by the compiler" -// #endif // __AVX512VNNI__ +#ifdef __AVX512VNNI__ +#define USE_AVX512VNNI +#endif // __AVX512VNNI__ #define USE_AVX512 #endif // __AVX512F__ @@ -94,73 +90,81 @@ uint64_t xgetbv(unsigned int index) { #define _XCR_XFEATURE_ENABLED_MASK 0 // Cache for AVX and AVX512 support -std::atomic avxSupportCache{false}; -std::atomic avx512SupportCache{false}; -std::atomic avxInitialized{false}; -std::atomic avx512Initialized{false}; +std::atomic avx_support_cache{false}; +std::atomic avx_512_support_cache{false}; +std::atomic avx_initialized{false}; +std::atomic avx_512_initialized{false}; +/** + * @brief Initializes the platform support for AVX and AVX512 instructions. + * This function checks if the CPU and operating system support AVX and AVX512 + * instructions and caches the result for future use. + * + * @note This function should be called before using any AVX or AVX512 + * instructions. + */ void initializePlatformSupport() { - if (!avxInitialized.load(std::memory_order_acquire)) { - bool avxSupport = false; + if (!avx_initialized.load(std::memory_order_acquire)) { + bool avx_support = false; int cpu_info[4]; cpuid(cpu_info, 0, 0); int n_ids = cpu_info[0]; if (n_ids >= 1) { cpuid(cpu_info, 1, 0); - bool osUsesXSAVE_XRSTORE = (cpu_info[2] & (1 << 27)) != 0; - bool cpuAVXSuport = (cpu_info[2] & (1 << 28)) != 0; - if (osUsesXSAVE_XRSTORE && cpuAVXSuport) { - uint64_t xcrFeatureMask = xgetbv(0); - avxSupport = (xcrFeatureMask & 0x6) == 0x6; + bool os_uses_xsave_xrstore = (cpu_info[2] & (1 << 27)) != 0; + bool cpu_avx_support = (cpu_info[2] & (1 << 28)) != 0; + if (os_uses_xsave_xrstore && cpu_avx_support) { + uint64_t xcr_feature_mask = xgetbv(0); + avx_support = (xcr_feature_mask & 0x6) == 0x6; } } - avxSupportCache.store(avxSupport, std::memory_order_release); - avxInitialized.store(true, std::memory_order_release); + avx_support_cache.store(avx_support, std::memory_order_release); + avx_initialized.store(true, std::memory_order_release); } - if (!avx512Initialized.load(std::memory_order_acquire)) { + if (!avx_512_initialized.load(std::memory_order_acquire)) { bool avx512Support = false; - if (avxSupportCache.load(std::memory_order_acquire)) { + if (avx_support_cache.load(std::memory_order_acquire)) { int cpu_info[4]; cpuid(cpu_info, 0, 0); int n_ids = cpu_info[0]; if (n_ids >= 0x00000007) { cpuid(cpu_info, 0x00000007, 0); - bool HW_AVX512F = (cpu_info[1] & ((int)1 << 16)) != 0; + bool hw_avx512f = (cpu_info[1] & ((int)1 << 16)) != 0; - if (HW_AVX512F) { + if (hw_avx512f) { cpuid(cpu_info, 1, 0); - bool osUsesXSAVE_XRSTORE = (cpu_info[2] & (1 << 27)) != 0; - bool cpuAVXSuport = (cpu_info[2] & (1 << 28)) != 0; + bool os_uses_xsave_xrstore = (cpu_info[2] & (1 << 27)) != 0; + bool cpu_avx_support = (cpu_info[2] & (1 << 28)) != 0; - if (osUsesXSAVE_XRSTORE && cpuAVXSuport) { - uint64_t xcrFeatureMask = xgetbv(0); - avx512Support = (xcrFeatureMask & 0xe6) == 0xe6; + if (os_uses_xsave_xrstore && cpu_avx_support) { + uint64_t xcr_feature_mask = xgetbv(0); + avx512Support = (xcr_feature_mask & 0xe6) == 0xe6; } } } } - avx512SupportCache.store(avx512Support, std::memory_order_release); - avx512Initialized.store(true, std::memory_order_release); + avx_512_support_cache.store(avx512Support, std::memory_order_release); + avx_512_initialized.store(true, std::memory_order_release); } } bool platformSupportsAvx() { - if (!avxInitialized.load(std::memory_order_acquire)) { + if (!avx_initialized.load(std::memory_order_acquire)) { initializePlatformSupport(); } - return avxSupportCache.load(std::memory_order_acquire); + return avx_support_cache.load(std::memory_order_acquire); } bool platformSupportsAvx512() { - if (!avx512Initialized.load(std::memory_order_acquire)) { + if (!avx_512_initialized.load(std::memory_order_acquire)) { initializePlatformSupport(); } - return avx512SupportCache.load(std::memory_order_acquire); + return avx_512_support_cache.load(std::memory_order_acquire); } #endif \ No newline at end of file diff --git a/flatnav/util/SimdSelector.h b/flatnav/util/SimdSelector.h deleted file mode 100644 index 8b13789..0000000 --- a/flatnav/util/SimdSelector.h +++ /dev/null @@ -1 +0,0 @@ - diff --git a/flatnav/util/SimdUtils.h b/flatnav/util/SimdUtils.h index ef8c74b..f8f913b 100644 --- a/flatnav/util/SimdUtils.h +++ b/flatnav/util/SimdUtils.h @@ -14,32 +14,6 @@ namespace flatnav::util { -template struct MaskRepr {}; -template <> struct MaskRepr<2> { using type = uint8_t; }; -template <> struct MaskRepr<4> { using type = uint8_t; }; -template <> struct MaskRepr<8> { using type = uint8_t; }; -template <> struct MaskRepr<16> { using type = uint16_t; }; -template <> struct MaskRepr<32> { using type = uint32_t; }; -template <> struct MaskRepr<64> { using type = uint64_t; }; - -template struct MaskIntrinsic {}; -template <> struct MaskIntrinsic { using mask_type = __mmask8; }; -template <> struct MaskIntrinsic { using mask_type = __mmask16; }; -template <> struct MaskIntrinsic { using mask_type = __mmask32; }; -template <> struct MaskIntrinsic { using mask_type = __mmask64; }; - -// Given a length `N`, obtain an appropriate integer type used as a mask for `N` -// lanes in an AVX vector operation. -template using mask_repr_t = typename MaskRepr::type; - -// Given an unsigned integer type, return the corresponding mask type -template -using mask_intrinsic_t = typename MaskIntrinsic::mask_type; - -// Given a length `N`, obtain an appropriate mask intrinsic type. -template -using mask_intrinsic_from_length = mask_intrinsic_t>; - // clang-format off /** * @file SimdUtils.h diff --git a/flatnav/util/SquaredL2SimdExtensions.h b/flatnav/util/SquaredL2SimdExtensions.h index b084887..ca5760a 100644 --- a/flatnav/util/SquaredL2SimdExtensions.h +++ b/flatnav/util/SquaredL2SimdExtensions.h @@ -4,10 +4,6 @@ namespace flatnav::util { -// Explicitly expresses that narrowing is either acceptable or known impossible. -template constexpr T narrow_cast(U &&u) noexcept { - return static_cast(std::forward(u)); -} #if defined(USE_AVX512) static float computeL2_Avx512(const void *x, const void *y, @@ -32,6 +28,9 @@ static float computeL2_Avx512(const void *x, const void *y, return sum.reduce_add(); } +/** + * @todo Make this support dimensions that are not multiples of 64 + */ static float computeL2_Avx512_Uint8(const void *x, const void *y, const size_t &dimension) { const uint8_t *pointer_x = static_cast(x); @@ -81,82 +80,6 @@ static float computeL2_Avx512_Uint8(const void *x, const void *y, return static_cast(total_sum); } -#if defined(USE_AVX512BW) && defined(USE_AVX512VNNI) - -// template static inline __mmask32 create_mask(const size_t &length) -// { -// __mmask32 mask = 0; -// for (size_t i = 0; i < N; ++i) { -// mask |= (i < length) ? (1UL << i) : 0; -// } -// return mask; -// } - -constexpr __mmask32 create_mask(size_t remaining) { - // If remaining is 32 or more, we want to load everything, so the mask is all - // 1s. If remaining is less, shift a 1 up to the remaining bit, subtracting - // one to get a mask with that many 1s. - // return remaining >= 32 ? static_cast<__mmask32>(-1) : (1UL << remaining) - - // 1; - return (1UL << remaining) - 1; -} - -template -constexpr mask_intrinsic_from_length create_mask(size_t dimension) { - using MaskType = mask_repr_t; - constexpr MaskType one{0x1}; - MaskType shift = dimension % VecLength; - MaskType mask_raw = - shift == 0 ? std::numeric_limits::max() : (one << shift) - one; - return mask_raw; -} - -template -constexpr mask_intrinsic_from_length no_mask() { - return std::numeric_limits>::max(); -} - -static constexpr size_t div_round_up(size_t x, size_t y) { - return (x / y) + static_cast((x % y) != 0); -} - -template static constexpr bool islast(size_t N, size_t i) { - size_t last_iter = Step * (div_round_up(N, Step) - 1); - return i == last_iter; -} - -static float compute(const int8_t *a, const int8_t *b, const size_t &length) { - auto sum = _mm512_setzero_epi32(); - size_t j = 0; - - auto mask = create_mask<32>(length); - auto all = no_mask<32>(); - - for (; j < length; j += 32) { - auto temp_a = - _mm256_maskz_loadu_epi8(islast<32>(length, j) ? mask : all, a + j); - auto va = _mm512_cvtepi8_epi16(temp_a); - - auto temp_b = - _mm256_maskz_loadu_epi8(islast<32>(length, j) ? mask : all, b + j); - auto vb = _mm512_cvtepi8_epi16(temp_b); - - auto diff = _mm512_sub_epi16(va, vb); - sum = _mm512_dpwssd_epi32(sum, diff, diff); - } - return narrow_cast(_mm512_reduce_add_epi32(sum)); -} - -static float computeL2_Avx512_int8(const void *x, const void *y, - const size_t &dimension) { - int8_t *pointer_x = static_cast(const_cast(x)); - int8_t *pointer_y = static_cast(const_cast(y)); - - return flatnav::util::compute(pointer_x, pointer_y, dimension); -} - -#endif // USE_AVX512BW && USE_AVX512VNNI - #endif // USE_AVX512 #if defined(USE_AVX) diff --git a/flatnav_python/docs.h b/flatnav_python/docs.h new file mode 100644 index 0000000..edec900 --- /dev/null +++ b/flatnav_python/docs.h @@ -0,0 +1,142 @@ +#pragma once + +// One sad thing about this is that the docstrings are likely to become stale +// as the code evolves. Nonetheless, it's good to have them in one place. + + +static const char *ADD_DOCSTRING = R"pbdoc( +Add vectors(data) to the index with the given `ef_construction` parameter and optional labels. +`ef_construction` determines how many vertices are visited while inserting every vector in +the underlying graph structure. +Args: + data (np.ndarray): The data to add to the index. + ef_construction (int): The number of vertices to visit while inserting every vector in the graph. + num_initializations (int, optional): The number of initializations to perform. Defaults to 100. + labels (Optional[np.ndarray], optional): The labels for the data. Defaults to None. +Returns: + None +)pbdoc"; + +static const char *ALLOCATE_NODES_DOCSTRING = R"pbdoc( +Allocate nodes in the underlying graph structure for the given data. Unlike the add method, +this method does not construct the edge connectivity. It only allocates memory for each node +in the graph. When using this method, you should invoke `build_graph_links` explicity. +```NOTE```: In most cases you should not need to use this method. +Args: + data (np.ndarray): The data to add to the index. +Returns: + None +)pbdoc"; + +static const char *SEARCH_SINGLE_DOCSTRING = R"pbdoc( +Return top `K` closest data points for the given `query`. The results are returned as a Tuple of +distances and label ID's. The `ef_search` parameter determines how many neighbors are visited +while finding the closest neighbors for the query. + +Args: + query (np.ndarray): The query vector. + K (int): The number of neighbors to return. + ef_search (int): The number of neighbors to visit while finding the closest neighbors for the query. + num_initializations (int, optional): The number of initializations to perform. Defaults to 100. +Returns: + Tuple[np.ndarray, np.ndarray]: The distances and label ID's of the closest neighbors. +)pbdoc"; + +static const char *SEARCH_DOCSTRING = R"pbdoc( +This is a batched version of the `search_single` method. +Return top `K` closest data points for every query in the provided `queries`. The results are returned as a Tuple of +distances and label ID's. The `ef_search` parameter determines how many neighbors are visited while finding the closest neighbors +for every query. + +Args: + queries (np.ndarray): The query vectors. + K (int): The number of neighbors to return. + ef_search (int): The number of neighbors to visit while finding the closest neighbors for every query. + num_initializations (int, optional): The number of initializations to perform. Defaults to 100. +Returns: + Tuple[np.ndarray, np.ndarray]: The distances and label ID's of the closest neighbors. +)pbdoc"; + +static const char *GET_GRAPH_OUTDEGREE_TABLE_DOCSTRING = R"pbdoc( +Returns the outdegree table (adjacency list) representation of the underlying graph. +Returns: + List[List[int]]: The outdegree table. +)pbdoc"; + +static const char *BUILD_GRAPH_LINKS_DOCSTRING = R"pbdoc( +Construct the edge connectivity of the underlying graph. This method should be invoked after +allocating nodes using the `allocate_nodes` method. +Args: + mtx_filename (str): The filename of the matrix file. + +Returns: + None +)pbdoc"; + +static const char *REORDER_DOCSTRING = R"pbdoc( +Perform graph re-ordering based on the given sequence of re-ordering strategies. +Supported re-ordering strategies include `gorder` and `rcm`. +Reference: + 1. Graph Reordering for Cache-Efficient Near Neighbor Search: https://arxiv.org/pdf/2104.03221 +Args: + strategies (List[str]): The sequence of re-ordering strategies. +Returns: + None +)pbdoc"; + +static const char *SET_NUM_THREADS_DOCSTRING = R"pbdoc( +Set the number of threads to use for constructing the graph and/or performing KNN search. +Args: + num_threads (int): The number of threads to use. +Returns: + None +)pbdoc"; + +static const char *NUM_THREADS_DOCSTRING = R"pbdoc( +Returns the number of threads used for constructing the graph and/or performing KNN search. +Returns: + int: The number of threads. +)pbdoc"; + +static const char *MAX_EDGES_PER_NODE_DOCSTRING = R"pbdoc( +Maximum number of edges(links) per node in the underlying NSW graph data structure. +Returns: + int: The maximum number of edges per node. +)pbdoc"; + +static const char *SAVE_DOCSTRING = R"pbdoc( +Save a FlatNav index at the given file location. +Args: + filename (str): The file location to save the index. +Returns: + None +)pbdoc"; + +static const char *LOAD_DOCSTRING = R"pbdoc( +Load a FlatNav index from a given file location. +Args: + filename (str): The file location to load the index from. +Returns: + Union[L2Inde, IPIndex]: The loaded index. +)pbdoc"; + +static const char *GET_QUERY_DISTANCE_COMPUTATIONS_DOCSTRING = R"pbdoc( +Returns the number of distance computations performed during the last search operation. +This method also resets the distance computations counter. +Returns: + int: The number of distance computations. +)pbdoc"; + +static const char *CONSTRUCTOR_DOCSTRING = R"pbdoc( +Constructs a an in-memory index with the parameters. +Args: + distance_type (str): The type of distance metric to use ('l2' for Euclidean, 'angular' for inner product). + dim (int): The number of dimensions in the dataset. + dataset_size (int): The number of vectors in the dataset. + max_edges_per_node (int): The maximum number of edges per node in the graph. + verbose (bool, optional): Enables verbose output. Defaults to False. + collect_stats (bool, optional): Collects performance statistics. Defaults to False. + +Returns: + Union[L2Index, IPIndex]: The constructed index. +)pbdoc"; diff --git a/flatnav_python/python_bindings.cpp b/flatnav_python/python_bindings.cpp index 4ac0479..01c4c85 100644 --- a/flatnav_python/python_bindings.cpp +++ b/flatnav_python/python_bindings.cpp @@ -1,9 +1,10 @@ +#include "docs.h" #include #include #include -#include #include #include +#include #include #include #include @@ -17,108 +18,54 @@ #include #include -using flatnav::distances::DistanceInterface; using flatnav::Index; +using flatnav::distances::DistanceInterface; using flatnav::distances::InnerProductDistance; using flatnav::distances::SquaredL2Distance; using flatnav::util::DataType; +using flatnav::util::for_each_data_type; namespace py = pybind11; +template +auto cast_and_call(DataType data_type, const py::array &array, Func &&function, + Args &&... args) { + switch (data_type) { + case DataType::float32: + return function( + array.cast< + py::array_t>(), + std::forward(args)...); + case DataType::int8: + return function( + array.cast< + py::array_t>(), + std::forward(args)...); + case DataType::uint8: + return function( + array.cast< + py::array_t>(), + std::forward(args)...); + default: + throw std::invalid_argument("Unsupported data type."); + } +} + template class PyIndex : public std::enable_shared_from_this> { - const uint32_t NUM_LOG_STEPS = 10000; -private: int _dim; label_t _label_id; bool _verbose; Index *_index; - std::string _data_type = "float32"; - -public: - typedef std::pair, py::array_t> - DistancesLabelsPair; - - explicit PyIndex(std::unique_ptr> index) - : _dim(index->dataDimension()), _label_id(0), _verbose(false), - _index(index.release()) { - - if (_verbose) { - _index->getIndexSummary(); - } - } - - PyIndex(std::unique_ptr> &&distance, - int dataset_size, int max_edges_per_node, bool verbose = false, - bool collect_stats = false, const std::string &data_type = "float32") - : _dim(distance->dimension()), _label_id(0), _verbose(verbose), - _index(new Index( - /* dist = */ std::move(distance), - /* dataset_size = */ dataset_size, - /* max_edges_per_node = */ max_edges_per_node, - /* collect_stats = */ collect_stats)), - _data_type(data_type) { - - if (_verbose) { - uint64_t total_index_memory = _index->getTotalIndexMemory(); - uint64_t visited_set_allocated_memory = - _index->visitedSetPoolAllocatedMemory(); - uint64_t mutexes_allocated_memory = _index->mutexesAllocatedMemory(); - - auto total_memory = total_index_memory + visited_set_allocated_memory + - mutexes_allocated_memory; - - std::cout << "Total allocated index memory: " << (float)(total_memory / 1e9) - << " GB \n" - << std::flush; - std::cout << "[WARN]: More memory might be allocated due to visited sets " - "in multi-threaded environments.\n" - << std::flush; - _index->getIndexSummary(); - } - } - - Index *getIndex() { return _index; } - - ~PyIndex() { delete _index; } - - uint64_t getQueryDistanceComputations() const { - auto distance_computations = _index->distanceComputations(); - _index->resetStats(); - return distance_computations; - } - - static std::shared_ptr> - loadIndex(const std::string &filename) { - auto index = Index::loadIndex(/* filename = */ filename); - return std::make_shared>(std::move(index)); - } - - std::shared_ptr> allocateNodes( - const py::array_t - &data) { - auto num_vectors = data.shape(0); - auto data_dim = data.shape(1); - if (data.ndim() != 2 || data_dim != _dim) { - throw std::invalid_argument("Data has incorrect dimensions."); - } - for (size_t vec_index = 0; vec_index < num_vectors; vec_index++) { - uint32_t new_node_id; - - this->_index->allocateNode(/* data = */ (void *)data.data(vec_index), - /* label = */ _label_id, - /* new_node_id = */ new_node_id); - _label_id++; - } - return this->shared_from_this(); - } + DataType _data_type; + // Internal add method that handles templated dispatch template - void add(const py::array_t &data, - int ef_construction, int num_initializations = 100, - py::object labels = py::none()) { + void addImpl(const py::array_t &data, + int ef_construction, int num_initializations = 100, + py::object labels = py::none()) { // py::array_t means that // the functions expects either a Numpy array of floats or a castable type // to that type. If the given type can't be casted, pybind11 will throw an @@ -175,10 +122,10 @@ class PyIndex : public std::enable_shared_from_this> { } template - DistancesLabelsPair - searchSingle(const py::array_t &query, - int K, int ef_search, int num_initializations = 100) { + DistancesLabelsPair searchSingleImpl( + const py::array_t + &query, + int K, int ef_search, int num_initializations = 100) { if (query.ndim() != 1 || query.shape(0) != _dim) { throw std::invalid_argument("Query has incorrect dimensions."); } @@ -220,9 +167,9 @@ class PyIndex : public std::enable_shared_from_this> { template DistancesLabelsPair - search(const py::array_t - &queries, - int K, int ef_search, int num_initializations = 100) { + searchImpl(const py::array_t &queries, + int K, int ef_search, int num_initializations = 100) { size_t num_queries = queries.shape(0); size_t queries_dim = queries.shape(1); @@ -291,322 +238,186 @@ class PyIndex : public std::enable_shared_from_this> { return {dists, labels}; } -}; -using L2FlatNavIndex = PyIndex; -using InnerProductFlatNavIndex = PyIndex; +public: + typedef std::pair, py::array_t> + DistancesLabelsPair; -/** - * Dispatches the given function based on the data type of the index. - */ -template -auto dispatch(IndexType &index_type, DataType data_type, Function function, - Args &&...args) { - switch (data_type) { - case DataType::float32: - return function.template operator()(index_type, - std::forward(args)...); - case DataType::int8: - return function.template operator()(index_type, - std::forward(args)...); - case DataType::uint8: - return function.template operator()(index_type, - std::forward(args)...); - default: - throw std::runtime_error("Unsupported data type"); + explicit PyIndex(std::unique_ptr> index) + : _dim(index->dataDimension()), _label_id(0), _verbose(false), + _index(index.release()) { + + if (_verbose) { + _index->getIndexSummary(); + } } -} -template struct SearchSingle { - /** - * @note: One less nice thing here is that we might still incur cost due to - * a copy of the vector if data_type is not float32. A to-do item is to - * figure out how to avoid this. - */ - template - auto operator()(IndexType &index_type, - const py::array_t &query, - int K, int ef_search, int num_initializations) - -> decltype(auto) { - return index_type.template searchSingle(query, K, ef_search, - num_initializations); + PyIndex(std::unique_ptr> &&distance, + DataType data_type, int dataset_size, int max_edges_per_node, + bool verbose = false, bool collect_stats = false) + : _dim(distance->dimension()), _label_id(0), _verbose(verbose), + _index(new Index( + /* dist = */ std::move(distance), + /* dataset_size = */ dataset_size, + /* max_edges_per_node = */ max_edges_per_node, + /* collect_stats = */ collect_stats)) { + + _data_type = data_type; + + if (_verbose) { + uint64_t total_index_memory = _index->getTotalIndexMemory(); + uint64_t visited_set_allocated_memory = + _index->visitedSetPoolAllocatedMemory(); + uint64_t mutexes_allocated_memory = _index->mutexesAllocatedMemory(); + + auto total_memory = total_index_memory + visited_set_allocated_memory + + mutexes_allocated_memory; + + std::cout << "Total allocated index memory: " + << (float)(total_memory / 1e9) << " GB \n" + << std::flush; + std::cout << "[WARN]: More memory might be allocated due to visited sets " + "in multi-threaded environments.\n" + << std::flush; + _index->getIndexSummary(); + } } -}; -template struct BatchedSearch { - template - auto - operator()(IndexType &index_type, - const py::array_t &queries, - int K, int ef_search, int num_initializations) -> decltype(auto) { - return index_type.template search(queries, K, ef_search, - num_initializations); + Index *getIndex() { return _index; } + + ~PyIndex() { delete _index; } + + uint64_t getQueryDistanceComputations() const { + auto distance_computations = _index->distanceComputations(); + _index->resetStats(); + return distance_computations; } -}; -template struct Add { - template - void operator()(IndexType &index_type, - const py::array_t &data, - int ef_construction, int num_initializations, - py::object labels) { - index_type.template add(data, ef_construction, - num_initializations, labels); + void buildGraphLinks(const std::string &mtx_filename) { + _index->buildGraphLinks(/* mtx_filename = */ mtx_filename); + } + + std::vector> getGraphOutdegreeTable() { + return _index->getGraphOutdegreeTable(); + } + + void reorder(const std::vector &strategies) { + // validate the given strategies + for (auto &strategy : strategies) { + auto alg = strategy; + std::transform(alg.begin(), alg.end(), alg.begin(), + [](unsigned char c) { return std::tolower(c); }); + if (alg != "gorder" && alg != "rcm") { + throw std::invalid_argument( + "`" + strategy + + "` is not a supported graph re-ordering strategy."); + } + } + _index->doGraphReordering(strategies); + } + + void setNumThreads(uint32_t num_threads) { + _index->setNumThreads(num_threads); + } + + uint32_t getNumThreads() { return _index->getNumThreads(); } + + void save(const std::string &filename) { + _index->saveIndex(/* filename = */ filename); + } + + static std::shared_ptr> + loadIndex(const std::string &filename) { + auto index = Index::loadIndex(/* filename = */filename); + return std::make_shared>(std::move(index)); + } + + std::shared_ptr> allocateNodes( + const py::array_t &data) { + auto num_vectors = data.shape(0); + auto data_dim = data.shape(1); + if (data.ndim() != 2 || data_dim != _dim) { + throw std::invalid_argument("Data has incorrect dimensions."); + } + for (size_t vec_index = 0; vec_index < num_vectors; vec_index++) { + uint32_t new_node_id; + + this->_index->allocateNode(/* data = */ (void *)data.data(vec_index), + /* label = */ _label_id, + /* new_node_id = */ new_node_id); + _label_id++; + } + return this->shared_from_this(); + } + + void add(const py::array &data, int ef_construction, int num_initializations, + py::object labels = py::none()) { + cast_and_call( + _data_type, data, + [this](auto &&casted_data, int ef, int num_init, py::object lbls) { + this->addImpl(std::forward(casted_data), ef, + num_init, lbls); + }, + ef_construction, num_initializations, labels); + } + + DistancesLabelsPair search(const py::array &queries, int K, int ef_search, + int num_initializations) { + return cast_and_call( + _data_type, queries, + [this](auto &&casted_queries, int k, int ef, int num_init) { + return this->searchImpl( + std::forward(casted_queries), k, ef, + num_init); + }, + K, ef_search, num_initializations); + } + + DistancesLabelsPair searchSingle(const py::array &query, int K, int ef_search, + int num_initializations) { + return cast_and_call( + _data_type, query, + [this](auto &&casted_query, int k, int ef, int num_init) { + return this->searchSingleImpl( + std::forward(casted_query), k, ef, + num_init); + }, + K, ef_search, num_initializations); } }; -static const char *ADD_DOCSTRING = R"pbdoc( -Add vectors(data) to the index with the given `ef_construction` parameter and optional labels. -`ef_construction` determines how many vertices are visited while inserting every vector in -the underlying graph structure. -Args: - data (np.ndarray): The data to add to the index. - ef_construction (int): The number of vertices to visit while inserting every vector in the graph. - num_initializations (int, optional): The number of initializations to perform. Defaults to 100. - labels (Optional[np.ndarray], optional): The labels for the data. Defaults to None. -Returns: - None -)pbdoc"; - -static const char *ALLOCATE_NODES_DOCSTRING = R"pbdoc( -Allocate nodes in the underlying graph structure for the given data. Unlike the add method, -this method does not construct the edge connectivity. It only allocates memory for each node -in the graph. When using this method, you should invoke `build_graph_links` explicity. -```NOTE```: In most cases you should not need to use this method. -Args: - data (np.ndarray): The data to add to the index. -Returns: - None -)pbdoc"; - -static const char *SEARCH_SINGLE_DOCSTRING = R"pbdoc( -Return top `K` closest data points for the given `query`. The results are returned as a Tuple of -distances and label ID's. The `ef_search` parameter determines how many neighbors are visited -while finding the closest neighbors for the query. - -Args: - query (np.ndarray): The query vector. - K (int): The number of neighbors to return. - ef_search (int): The number of neighbors to visit while finding the closest neighbors for the query. - num_initializations (int, optional): The number of initializations to perform. Defaults to 100. -Returns: - Tuple[np.ndarray, np.ndarray]: The distances and label ID's of the closest neighbors. -)pbdoc"; - -static const char *SEARCH_DOCSTRING = R"pbdoc( -This is a batched version of the `search_single` method. -Return top `K` closest data points for every query in the provided `queries`. The results are returned as a Tuple of -distances and label ID's. The `ef_search` parameter determines how many neighbors are visited while finding the closest neighbors -for every query. - -Args: - queries (np.ndarray): The query vectors. - K (int): The number of neighbors to return. - ef_search (int): The number of neighbors to visit while finding the closest neighbors for every query. - num_initializations (int, optional): The number of initializations to perform. Defaults to 100. -Returns: - Tuple[np.ndarray, np.ndarray]: The distances and label ID's of the closest neighbors. -)pbdoc"; - -static const char *GET_GRAPH_OUTDEGREE_TABLE_DOCSTRING = R"pbdoc( -Returns the outdegree table (adjacency list) representation of the underlying graph. -Returns: - List[List[int]]: The outdegree table. -)pbdoc"; - -static const char *BUILD_GRAPH_LINKS_DOCSTRING = R"pbdoc( -Construct the edge connectivity of the underlying graph. This method should be invoked after -allocating nodes using the `allocate_nodes` method. -Args: - mtx_filename (str): The filename of the matrix file. - -Returns: - None -)pbdoc"; - -static const char *REORDER_DOCSTRING = R"pbdoc( -Perform graph re-ordering based on the given sequence of re-ordering strategies. -Supported re-ordering strategies include `gorder` and `rcm`. -Reference: - 1. Graph Reordering for Cache-Efficient Near Neighbor Search: https://arxiv.org/pdf/2104.03221 -Args: - strategies (List[str]): The sequence of re-ordering strategies. -Returns: - None -)pbdoc"; - -static const char *SET_NUM_THREADS_DOCSTRING = R"pbdoc( -Set the number of threads to use for constructing the graph and/or performing KNN search. -Args: - num_threads (int): The number of threads to use. -Returns: - None -)pbdoc"; - -static const char *NUM_THREADS_DOCSTRING = R"pbdoc( -Returns the number of threads used for constructing the graph and/or performing KNN search. -Returns: - int: The number of threads. -)pbdoc"; - -static const char *MAX_EDGES_PER_NODE_DOCSTRING = R"pbdoc( -Maximum number of edges(links) per node in the underlying NSW graph data structure. -Returns: - int: The maximum number of edges per node. -)pbdoc"; - -static const char *SAVE_DOCSTRING = R"pbdoc( -Save a FlatNav index at the given file location. -Args: - filename (str): The file location to save the index. -Returns: - None -)pbdoc"; - -static const char *LOAD_DOCSTRING = R"pbdoc( -Load a FlatNav index from a given file location. -Args: - filename (str): The file location to load the index from. -Returns: - Union[L2Inde, IPIndex]: The loaded index. -)pbdoc"; - -static const char *GET_QUERY_DISTANCE_COMPUTATIONS_DOCSTRING = R"pbdoc( -Returns the number of distance computations performed during the last search operation. -This method also resets the distance computations counter. -Returns: - int: The number of distance computations. -)pbdoc"; - -static const char *CONSTRUCTOR_DOCSTRING = R"pbdoc( -Constructs a an in-memory index with the parameters. -Args: - distance_type (str): The type of distance metric to use ('l2' for Euclidean, 'angular' for inner product). - dim (int): The number of dimensions in the dataset. - dataset_size (int): The number of vectors in the dataset. - max_edges_per_node (int): The maximum number of edges per node in the graph. - verbose (bool, optional): Enables verbose output. Defaults to False. - collect_stats (bool, optional): Collects performance statistics. Defaults to False. - -Returns: - Union[L2Index, IPIndex]: The constructed index. -)pbdoc"; - -template -void bindIndexMethods( - py::class_> &index_class) { - index_class - .def( - "save", - [](IndexType &index_type, const std::string &filename) { - auto index = index_type.getIndex(); - index->saveIndex(/* filename = */ filename); - }, - py::arg("filename"), SAVE_DOCSTRING) - .def_static("load", &IndexType::loadIndex, py::arg("filename"), - LOAD_DOCSTRING) - .def( - "add", - [](IndexType &index_type, - const py::array_t - &data, - int ef_construction, int num_initializations = 100, - py::object labels = py::none()) { - DataType data_type = index_type.getIndex()->dataType(); - dispatch(index_type, data_type, Add{}, data, - ef_construction, num_initializations, labels); - }, - py::arg("data"), py::arg("ef_construction"), - py::arg("num_initializations") = 100, py::arg("labels") = py::none(), - ADD_DOCSTRING) - .def("allocate_nodes", &IndexType::allocateNodes, py::arg("data"), - ALLOCATE_NODES_DOCSTRING) - .def( - "search_single", - [](IndexType &index_type, - const py::array_t - &query, - int K, int ef_search, int num_initializations = 100) { - DataType data_type = index_type.getIndex()->dataType(); - return dispatch(index_type, data_type, SearchSingle{}, - query, K, ef_search, num_initializations); - }, - py::arg("query"), py::arg("K"), py::arg("ef_search"), - py::arg("num_initializations") = 100, - SEARCH_SINGLE_DOCSTRING) - .def("get_query_distance_computations", - &IndexType::getQueryDistanceComputations, - GET_QUERY_DISTANCE_COMPUTATIONS_DOCSTRING) - .def( - "search", - [](IndexType &index_type, - const py::array_t - &queries, - int K, int ef_search, int num_initializations = 100) { - DataType data_type = index_type.getIndex()->dataType(); - return dispatch(index_type, data_type, BatchedSearch{}, - queries, K, ef_search, num_initializations); - }, - py::arg("queries"), py::arg("K"), py::arg("ef_search"), - py::arg("num_initializations") = 100, - SEARCH_DOCSTRING) - .def( - "get_graph_outdegree_table", - [](IndexType &index_type) -> std::vector> { - auto index = index_type.getIndex(); - return index->getGraphOutdegreeTable(); - }, - GET_GRAPH_OUTDEGREE_TABLE_DOCSTRING) - .def( - "build_graph_links", - [](IndexType &index_type, const std::string &mtx_filename) { - auto index = index_type.getIndex(); - index->buildGraphLinks(/* mtx_filename = */ mtx_filename); - }, - py::arg("mtx_filename"), BUILD_GRAPH_LINKS_DOCSTRING) - .def( - "reorder", - [](IndexType &index_type, - const std::vector &strategies) { - auto index = index_type.getIndex(); - // validate the given strategies - for (auto &strategy : strategies) { - auto alg = strategy; - std::transform(alg.begin(), alg.end(), alg.begin(), - [](unsigned char c) { return std::tolower(c); }); - if (alg != "gorder" && alg != "rcm") { - throw std::invalid_argument( - "`" + strategy + - "` is not a supported graph re-ordering strategy."); - } - } - index->doGraphReordering(strategies); - }, - py::arg("strategies"), REORDER_DOCSTRING) - .def( - "set_num_threads", - [](IndexType &index_type, uint32_t num_threads) { - auto *index = index_type.getIndex(); - index->setNumThreads(num_threads); - }, - py::arg("num_threads"), SET_NUM_THREADS_DOCSTRING) - .def_property_readonly( - "num_threads", - [](IndexType &index_type) { - auto *index = index_type.getIndex(); - return index->getNumThreads(); - }, - NUM_THREADS_DOCSTRING) - .def_property_readonly( - "max_edges_per_node", - [](IndexType &index_type) { - return index_type.getIndex()->maxEdgesPerNode(); - }, - MAX_EDGES_PER_NODE_DOCSTRING); -} +template struct IndexSpecialization; + +template <> struct IndexSpecialization> { + using type = PyIndex, int>; + static constexpr char *name = "IndexL2Float"; +}; + +template <> struct IndexSpecialization> { + using type = PyIndex, int>; + static constexpr char *name = "IndexL2Uint8"; +}; + +template <> struct IndexSpecialization> { + using type = PyIndex, int>; + static constexpr char *name = "IndexL2Int8"; +}; + +template <> +struct IndexSpecialization> { + using type = PyIndex, int>; + static constexpr char *name = "IndexIPFloat"; +}; + +template <> struct IndexSpecialization> { + using type = PyIndex, int>; + static constexpr char *name = "IndexIPUint8"; +}; + +template <> struct IndexSpecialization> { + using type = PyIndex, int>; + static constexpr char *name = "IndexIPInt8"; +}; void validateDistanceType(const std::string &distance_type) { auto dist_type = distance_type; @@ -620,95 +431,142 @@ void validateDistanceType(const std::string &distance_type) { } } -std::unique_ptr createL2Distance(DataType data_type, - int dim) { - switch (data_type) { - case DataType::float32: - return SquaredL2Distance::create(dim); - case DataType::int8: - return SquaredL2Distance::create(dim); - case DataType::uint8: - return SquaredL2Distance::create(dim); - default: - throw std::runtime_error("Unsupported data type"); - } -} - -std::unique_ptr -createInnerProductDistance(DataType data_type, int dim) { - switch (data_type) { - case DataType::float32: - return InnerProductDistance::create(dim); - case DataType::int8: - return InnerProductDistance::create(dim); - case DataType::uint8: - return InnerProductDistance::create(dim); - default: - throw std::runtime_error("Unsupported data type"); - } -} - -template +template py::object createIndex(const std::string &distance_type, - DataType index_data_type, int dim, - Args &&...args) { + int dim, Args &&... args) { validateDistanceType(distance_type); if (distance_type == "l2") { - auto distance = createL2Distance(index_data_type, dim); - distance->setDistanceFunctionWithType(); - return py::cast(std::make_shared( - std::move(distance), std::forward(args)...)); + auto distance = SquaredL2Distance::create(dim); + return std::make_shared, int>>( + std::move(distance), data_type, std::forward(args)...); } - auto distance = createInnerProductDistance(index_data_type, dim); - distance->setDistanceFunctionWithType(); - return py::cast(std::make_shared( - std::move(distance), std::forward(args)...)); + auto distance = InnerProductDistance::create(dim); + return std::make_shared, int>>( + std::move(distance), data_type, std::forward(args)...); } +template +void bindSpecialization(py::module_ &index_submodule) { + using IndexType = typename IndexSpecialization::type; + auto index_class = py::class_>( + index_submodule, IndexSpecialization::name); + + index_class + .def( + "add", + [](IndexType &index, const py::array &data, int ef_construction, + int num_initializations = 100, py::object labels = py::none()) { + index.add(data, ef_construction, num_initializations, labels); + }, + py::arg("data"), py::arg("ef_construction"), + py::arg("num_initializations") = 100, py::arg("labels") = py::none(), + ADD_DOCSTRING) + .def( + "allocate_nodes", + [](IndexType &index, + const py::array_t + &data) { return index.allocateNodes(data); }, + py::arg("data"), ALLOCATE_NODES_DOCSTRING) + .def( + "search_single", + [](IndexType &index, const py::array &query, int K, int ef_search, + int num_initializations = 100) { + return index.searchSingle(query, K, ef_search, num_initializations); + }, + py::arg("query"), py::arg("K"), py::arg("ef_search"), + py::arg("num_initializations") = 100, SEARCH_SINGLE_DOCSTRING) + .def( + "search", + [](IndexType &index, const py::array &queries, int K, int ef_search, + int num_initializations = 100) { + return index.search(queries, K, ef_search, num_initializations); + }, + py::arg("queries"), py::arg("K"), py::arg("ef_search"), + py::arg("num_initializations") = 100, SEARCH_DOCSTRING) + .def("get_query_distance_computations", + &IndexType::getQueryDistanceComputations, + GET_QUERY_DISTANCE_COMPUTATIONS_DOCSTRING) + .def("save", &IndexType::save, py::arg("filename"), SAVE_DOCSTRING) + .def("build_graph_links", &IndexType::buildGraphLinks, + py::arg("mtx_filename"), BUILD_GRAPH_LINKS_DOCSTRING) + .def("get_graph_outdegree_table", &IndexType::getGraphOutdegreeTable, + GET_GRAPH_OUTDEGREE_TABLE_DOCSTRING) + .def("reorder", &IndexType::reorder, py::arg("strategies"), + REORDER_DOCSTRING) + .def("set_num_threads", &IndexType::setNumThreads, py::arg("num_threads"), + SET_NUM_THREADS_DOCSTRING) + .def_static("load_index", &IndexType::loadIndex, py::arg("filename"), + LOAD_INDEX_DOCSTRING); + .def_property_readonly("num_threads", &IndexType::getNumThreads, + NUM_THREADS_DOCSTRING); +} + + void defineIndexSubmodule(py::module_ &index_submodule) { + bindSpecialization, int>( + index_submodule); + bindSpecialization, int>(index_submodule); + bindSpecialization, int>(index_submodule); + bindSpecialization, int>( + index_submodule); + bindSpecialization, int>( + index_submodule); + bindSpecialization, int>( + index_submodule); + index_submodule.def( - "index_factory", + "create", [](const std::string &distance_type, int dim, int dataset_size, - int max_edges_per_node, DataType index_data_type, - bool verbose = false, bool collect_stats = false) { - return createIndex(distance_type, index_data_type, dim, dataset_size, - max_edges_per_node, verbose, collect_stats); + int max_edges_per_node, DataType index_data_type, bool verbose = false, + bool collect_stats = false) { + switch (index_data_type) { + case DataType::float32: + return createIndex( + distance_type, dim, dataset_size, max_edges_per_node, verbose, + collect_stats); + case DataType::int8: + return createIndex(distance_type, dim, dataset_size, + max_edges_per_node, verbose, + collect_stats); + case DataType::uint8: + return createIndex(distance_type, dim, dataset_size, + max_edges_per_node, verbose, + collect_stats); + default: + throw std::runtime_error("Unsupported data type"); + } }, py::arg("distance_type"), py::arg("dim"), py::arg("dataset_size"), - py::arg("max_edges_per_node"), py::arg("index_data_type") = DataType::float32, + py::arg("max_edges_per_node"), + py::arg("index_data_type") = DataType::float32, py::arg("verbose") = false, py::arg("collect_stats") = false, CONSTRUCTOR_DOCSTRING); - - py::class_> l2_index_class( - index_submodule, "L2Index"); - bindIndexMethods(l2_index_class); - - py::class_> - ip_index_class(index_submodule, "IPIndex"); - bindIndexMethods(ip_index_class); } -void defineDataTypeSubmodule(py::module_ &data_type_submodule) { - // More enums are available, but these are the only ones that we support - // for index construction. - py::enum_(data_type_submodule, "DataType") +void defineDatatypeEnums(py::module_ &module) { + // More enums are available, but these are the only ones that we support + // for index construction. + py::enum_(module, "DataType") .value(flatnav::util::name(DataType::float32), DataType::float32) .value(flatnav::util::name(DataType::int8), DataType::int8) .value(flatnav::util::name(DataType::uint8), DataType::uint8) .export_values(); } +void defineDistanceEnums(py::module_ &module) { + py::enum_(module, "MetricType") + .value("L2", flatnav::distances::MetricType::L2) + .value("IP", flatnav::distances::MetricType::IP) + .export_values(); +} PYBIND11_MODULE(flatnav, module) { auto data_type_submodule = module.def_submodule("data_type"); - defineDataTypeSubmodule(data_type_submodule); + defineDatatypeEnums(data_type_submodule); auto index_submodule = module.def_submodule("index"); defineIndexSubmodule(index_submodule); - - - + defineDistanceEnums(module); } \ No newline at end of file diff --git a/flatnav_python/unit_tests/test_utils.py b/flatnav_python/unit_tests/test_utils.py index d4a8028..2902659 100644 --- a/flatnav_python/unit_tests/test_utils.py +++ b/flatnav_python/unit_tests/test_utils.py @@ -6,13 +6,13 @@ import os import time import flatnav -from flatnav.index import L2Index, IPIndex, index_factory +from flatnav.index import L2Index, IPIndex, create def create_index( distance_type: str, dim: int, dataset_size: int, max_edges_per_node: int ) -> Union[L2Index, IPIndex]: - index = index_factory( + index = create( distance_type=distance_type, dim=dim, dataset_size=dataset_size, diff --git a/quantization/ProductQuantization.h b/quantization/ProductQuantization.h index 06390c6..f5d6bcc 100644 --- a/quantization/ProductQuantization.h +++ b/quantization/ProductQuantization.h @@ -30,7 +30,7 @@ namespace flatnav::quantization { -using flatnav::METRIC_TYPE; +using flatnav::MetricType; using flatnav::quantization::CentroidsGenerator; template struct PQCodeManager { @@ -84,7 +84,8 @@ template struct PQCodeManager { * */ -class ProductQuantizer : public flatnav::distances::DistanceInterface { +class ProductQuantizer + : public flatnav::distances::DistanceInterface { friend class flatnav::distances::DistanceInterface; // Represents the block size used in ProductQuantizer::computePQCodes @@ -105,7 +106,7 @@ class ProductQuantizer : public flatnav::distances::DistanceInterface(_subvector_dim); - } else if (_metric_type == METRIC_TYPE::INNER_PRODUCT) { + } else if (_metric_type == MetricType::IP) { _distance = InnerProductDistance::create(_subvector_dim); } else { @@ -542,7 +543,7 @@ class ProductQuantizer : public flatnav::distances::DistanceInterface(_subvector_dim); - } else if (_metric_type == METRIC_TYPE::INNER_PRODUCT) { + } else if (_metric_type == MetricType::IP) { _distance = InnerProductDistance::create(_subvector_dim); } else { diff --git a/tools/cereal_tests.cpp b/tools/cereal_tests.cpp index 9e7a0bc..4f6cceb 100644 --- a/tools/cereal_tests.cpp +++ b/tools/cereal_tests.cpp @@ -1,13 +1,13 @@ #include "cnpy.h" #include #include -#include #include #include +#include #include -using flatnav::distances::DistanceInterface; using flatnav::Index; +using flatnav::distances::DistanceInterface; using flatnav::distances::InnerProductDistance; using flatnav::distances::SquaredL2Distance; using flatnav::util::DataType; @@ -61,9 +61,9 @@ int main(int argc, char **argv) { int N = 60000; float *data = datafile.data(); auto l2_distance = SquaredL2Distance::create(dim); - serializeIndex>(data, std::move(l2_distance), N, M, dim, - ef_construction, - std::string("l2_flatnav.bin")); + serializeIndex>( + data, std::move(l2_distance), N, M, dim, ef_construction, + std::string("l2_flatnav.bin")); // auto inner_product_distance = // std::make_unique>(dim); diff --git a/tools/construct_npy.cpp b/tools/construct_npy.cpp index a65ac8b..8ce0773 100644 --- a/tools/construct_npy.cpp +++ b/tools/construct_npy.cpp @@ -30,8 +30,7 @@ using flatnav::util::DataType; template void buildIndex(float *data, std::unique_ptr> distance, int N, - int M, int dim, int ef_construction, - int build_num_threads, + int M, int dim, int ef_construction, int build_num_threads, const std::string &save_file) { auto index = new Index( @@ -60,11 +59,9 @@ void buildIndex(float *data, delete index; } -void run(float *data, flatnav::distances::METRIC_TYPE metric_type, int N, int M, - int dim, int ef_construction, - int build_num_threads, - const std::string &save_file, - bool quantize = false) { +void run(float *data, flatnav::distances::MetricType metric_type, int N, int M, + int dim, int ef_construction, int build_num_threads, + const std::string &save_file, bool quantize = false) { if (quantize) { // Parameters M and nbits should be adjusted accordingly. @@ -84,16 +81,17 @@ void run(float *data, flatnav::distances::METRIC_TYPE metric_type, int N, int M, // ef_construction, save_file); } else { - if (metric_type == flatnav::distances::METRIC_TYPE::EUCLIDEAN) { + if (metric_type == flatnav::distances::MetricType::L2) { auto distance = SquaredL2Distance::create(dim); buildIndex>( - data, std::move(distance), N, M, dim, ef_construction, build_num_threads, save_file); - - } else if (metric_type == flatnav::distances::METRIC_TYPE::INNER_PRODUCT) { - // auto distance = InnerProductDistance::create(dim); - // distance->setDistanceFunction(); - // buildIndex(data, std::move(distance), N, M, dim, - // ef_construction, save_file); + data, std::move(distance), N, M, dim, ef_construction, + build_num_threads, save_file); + + } else if (metric_type == flatnav::distances::MetricType::IP) { + auto distance = InnerProductDistance::create(dim); + buildIndex>( + data, std::move(distance), N, M, dim, ef_construction, + build_num_threads, save_file); } } } @@ -134,14 +132,14 @@ int main(int argc, char **argv) { std::clog << "Loading " << dim << "-dimensional dataset with N = " << N << std::endl; float *data = datafile.data(); - flatnav::distances::METRIC_TYPE metric_type = - metric_id == 0 ? flatnav::distances::METRIC_TYPE::EUCLIDEAN - : flatnav::distances::METRIC_TYPE::INNER_PRODUCT; + flatnav::distances::MetricType metric_type = + metric_id == 0 ? flatnav::distances::MetricType::L2 + : flatnav::distances::MetricType::IP; run(/* data = */ data, /* metric_type = */ metric_type, /* N = */ N, /* M = */ M, /* dim = */ dim, - /* ef_construction = */ ef_construction, + /* ef_construction = */ ef_construction, /* build_num_threads = */ std::stoi(argv[6]), /* save_file = */ argv[7], /* quantize = */ quantize); diff --git a/tools/flatnav_pq.cpp b/tools/flatnav_pq.cpp index 7e0349b..a9da320 100644 --- a/tools/flatnav_pq.cpp +++ b/tools/flatnav_pq.cpp @@ -2,9 +2,9 @@ #include #include #include -#include #include #include +#include #include #include #include @@ -18,9 +18,10 @@ using flatnav::distances::InnerProductDistance; using flatnav::distances::SquaredL2Distance; template -void run(float *data, - std::unique_ptr> &&distance, int N, - int M, int dim, int ef_construction, const std::string &save_file) { +void run( + float *data, + std::unique_ptr> &&distance, + int N, int M, int dim, int ef_construction, const std::string &save_file) { auto index = new Index( /* dist = */ std::move(distance), /* dataset_size = */ N, /* max_edges = */ M); diff --git a/tools/query_npy.cpp b/tools/query_npy.cpp index e38005d..fe2d59d 100644 --- a/tools/query_npy.cpp +++ b/tools/query_npy.cpp @@ -1,8 +1,8 @@ #include #include -#include #include #include +#include #include #include #include @@ -144,23 +144,25 @@ int main(int argc, char **argv) { // /* num_gtruth = */ n_gt, /* dim = */ dim, // /* reorder = */ reorder); } else if (space_ID == 0) { - run>(/* queries = */ queries, - /* gtruth = */ gtruth, - /* index_filename = */ indexfilename, - /* ef_searches = */ ef_searches, /* K = */ k, - /* num_queries = */ num_queries, - /* num_gtruth = */ n_gt, /* dim = */ dim, - /* reorder = */ reorder); + run>( + /* queries = */ queries, + /* gtruth = */ gtruth, + /* index_filename = */ indexfilename, + /* ef_searches = */ ef_searches, /* K = */ k, + /* num_queries = */ num_queries, + /* num_gtruth = */ n_gt, /* dim = */ dim, + /* reorder = */ reorder); } else if (space_ID == 1) { - run(/* queries = */ queries, /* gtruth = */ - gtruth, - /* index_filename = */ indexfilename, - /* ef_searches = */ ef_searches, - /* K = */ k, - /* num_queries = */ num_queries, - /* num_gtruth = */ n_gt, /* dim = */ dim, - /* reorder = */ reorder); + run>( + /* queries = */ queries, /* gtruth = */ + gtruth, + /* index_filename = */ indexfilename, + /* ef_searches = */ ef_searches, + /* K = */ k, + /* num_queries = */ num_queries, + /* num_gtruth = */ n_gt, /* dim = */ dim, + /* reorder = */ reorder); } else { throw std::invalid_argument("Invalid space ID. Valid IDs are 0 and 1."); diff --git a/tools/run_query.sh b/tools/run_query.sh old mode 100644 new mode 100755 index e69de29..c36ebb3 --- a/tools/run_query.sh +++ b/tools/run_query.sh @@ -0,0 +1,71 @@ +#!/bin/bash + +# Here is how to use the script. +# This script expects a single argument, which is the name of the benchmark +# dataset you want to query with. +# For example, to query with the 'sift-128-euclidean' dataset, you would run: +# ./tools/run_query.sh sift-128-euclidean +# The script will assume that you have already built the index with the same +# dataset. If you haven't, you should run the 'run_build.sh' script first. +# The index should be under data/$1/$1.index + + + +# Reference doc from query_npy.cpp +# std::clog << "Usage: " << std::endl; +# std::clog << "query " +# " " +# << std::endl; +# std::clog << "\t : .npy files (float, float, int) " +# "from ann-benchmarks" +# << std::endl; +# std::clog << "\t : int number of links" << std::endl; +# std::clog << "\t : int " << std::endl; +# std::clog << "\t : int,int,int,int...,int " << std::endl; +# std::clog << "\t : number of neighbors " << std::endl; +# std::clog << "\t : 0 for no reordering, 1 for reordering" +# << std::endl; +# std::clog << "\t : 0 for no quantization, 1 for quantization" +# << std::endl; + + +# Make sure we're at the top level directory. +cd "$(dirname "$0")/.." + +# Make sure the user provided a dataset name. +if [ -z "$1" ] + then + echo "No dataset name provided. Please provide a dataset name." + exit 1 +fi + +INDEX="data/$1/$1.index" +if [ ! -f "$INDEX" ] + then + echo "Index $INDEX does not exist. Please build the index first." + exit 1 +fi + +# Now query the index +echo "Querying with dataset $1" +DATASET_NAME=$1 +QUERY_DATASET_PATH="data/$DATASET_NAME/$DATASET_NAME.test.npy" +GROUNDTRUTH_DATASET_PATH="data/$DATASET_NAME/$DATASET_NAME.gtruth.npy" + +# SPACE will be 0 if dataset name contains 'euclidean' otherwise 1 +if [[ $DATASET_NAME == *"euclidean"* ]] + then + SPACE=0 + else + SPACE=1 +fi + +./build/query_npy \ + $SPACE \ + $INDEX \ + $QUERY_DATASET_PATH \ + $GROUNDTRUTH_DATASET_PATH \ + 100,200,300 \ + 100 \ + 0 \ + 0 \ No newline at end of file