Skip to content

Commit

Permalink
finish adding simd extensions for inner product distance
Browse files Browse the repository at this point in the history
  • Loading branch information
blaise-muhirwa committed Nov 29, 2023
1 parent 31d4f1f commit 54aa687
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 777 deletions.
3 changes: 0 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,7 @@ endif()
# every header file
set(HEADERS
${PROJECT_SOURCE_DIR}/flatnav/distances/InnerProductDistance.h
${PROJECT_SOURCE_DIR}/flatnav/distances/InnerProductDistanceSpecializations.h
${PROJECT_SOURCE_DIR}/flatnav/distances/inner_products_from_hnswlib.h
${PROJECT_SOURCE_DIR}/flatnav/distances/SquaredL2Distance.h
${PROJECT_SOURCE_DIR}/flatnav/distances/SquaredL2DistanceSpecializations.h
${PROJECT_SOURCE_DIR}/flatnav/util/ExplicitSet.h
${PROJECT_SOURCE_DIR}/flatnav/util/GorderPriorityQueue.h
${PROJECT_SOURCE_DIR}/flatnav/util/Reordering.h
Expand Down
30 changes: 16 additions & 14 deletions flatnav/distances/InnerProductDistance.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#include <iostream>
#include <limits>

namespace flatnav {

// This is the base distance function implementation for inner product distances
// on floating-point inputs.

Expand All @@ -22,12 +24,12 @@ class InnerProductDistance : public DistanceInterface<InnerProductDistance> {
public:
InnerProductDistance() = default;

explicit InnerProductDistance(size_t dim) _dimension(dim),
_data_size_bytes(dim * sizeof(float)),
_distance_computer(std::bind(&SquaredL2Distance::defaultDistanceImpl,
this, std::placeholders::_1,
std::placeholders::_2,
std::placeholders::_3)) {
explicit InnerProductDistance(size_t dim)
: _dimension(dim), _data_size_bytes(dim * sizeof(float)),
_distance_computer(std::bind(&InnerProductDistance::defaultDistanceImpl,
this, std::placeholders::_1,
std::placeholders::_2,
std::placeholders::_3)) {
setDistanceFunction();
}

Expand All @@ -52,7 +54,7 @@ class InnerProductDistance : public DistanceInterface<InnerProductDistance> {
if (Archive::is_loading::value) {
_data_size_bytes = _dimension * sizeof(float);
_distance_computer = std::bind(
&SquaredL2Distance::defaultDistanceImpl, this, std::placeholders::_1,
&InnerProductDistance::defaultDistanceImpl, this, std::placeholders::_1,
std::placeholders::_2, std::placeholders::_3);

setDistanceFunction();
Expand All @@ -79,26 +81,26 @@ class InnerProductDistance : public DistanceInterface<InnerProductDistance> {
_distance_computer = distanceImplInnerProductSIMD16ExtSSE;
#if defined(USE_AVX512)
if (platform_supports_avx512()) {
_distance_computer = distanceImplSquaredL2SIMD16ExtAVX512;
_distance_computer = distanceImplInnerProductSIMD16ExtAVX512;
} else if (platform_supports_avx()) {
_distance_computer = distanceImplSquaredL2SIMD16ExtAVX;
_distance_computer = distanceImplInnerProductSIMD16ExtAVX;
}
#elif defined(USE_AVX)
if (platform_supports_avx()) {
_distance_computer = distanceImplSquaredL2SIMD16ExtAVX;
_distance_computer = distanceImplInnerProductSIMD16ExtAVX;
}
#endif
if (!_dimension % 16 == 0) {
if (_dimension % 4 == 0) {
_distance_computer = distanceImplSquaredL2SIMD4Ext;
_distance_computer = distanceImplInnerProductSIMD4ExtSSE;
} else if (_dimension > 16) {
_distance_computer = distanceImplSquaredL2SIMD16ExtResiduals;
_distance_computer = distanceImplInnerProductSIMD16ExtResiduals;
} else if (_dimension > 4) {
_distance_computer = distanceImplSquaredL2SIMD4ExtResiduals;
_distance_computer = distanceImplInnerProductSIMD4ExtResiduals;
}
}

#endif
#endif // USE_AVX512 || USE_AVX || USE_SSE
#endif // NO_MANUAL_VECTORIZATION
}

Expand Down
88 changes: 0 additions & 88 deletions flatnav/distances/InnerProductDistanceSpecializations.h

This file was deleted.

Loading

0 comments on commit 54aa687

Please sign in to comment.