Skip to content

Commit

Permalink
more progress on inner product distance
Browse files Browse the repository at this point in the history
  • Loading branch information
blaise-muhirwa committed Nov 28, 2023
1 parent 362a337 commit 31d4f1f
Show file tree
Hide file tree
Showing 3 changed files with 293 additions and 16 deletions.
74 changes: 60 additions & 14 deletions flatnav/distances/InnerProductDistance.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
#include <cstddef> // for size_t
#include <cstring> // for memcpy
#include <flatnav/DistanceInterface.h>
#include <flatnav/util/SIMDIntrinsics.h>
#include <functional>
#include <iostream>
#include <limits>

// This is the base distance function implementation for inner product distances
// on floating-point inputs.

namespace flatnav {

class InnerProductDistance : public DistanceInterface<InnerProductDistance> {

friend class DistanceInterface<InnerProductDistance>;
Expand All @@ -22,36 +22,40 @@ class InnerProductDistance : public DistanceInterface<InnerProductDistance> {
public:
InnerProductDistance() = default;

explicit InnerProductDistance(size_t dim) {
_dimension = dim;
_data_size_bytes = dim * sizeof(float);
explicit InnerProductDistance(size_t dim) _dimension(dim),
_data_size_bytes(dim * sizeof(float)),
_distance_computer(std::bind(&SquaredL2Distance::defaultDistanceImpl,
this, std::placeholders::_1,
std::placeholders::_2,
std::placeholders::_3)) {
setDistanceFunction();
}

float distanceImpl(const void *x, const void *y,
bool asymmetric = false) const {
(void)asymmetric;
// Default implementation of inner product distance, in case we cannot
// support the SIMD specializations for special input _dimension sizes.
float *p_x = (float *)x;
float *p_y = (float *)y;
float result = 0;
for (size_t i = 0; i < _dimension; i++) {
result += p_x[i] * p_y[i];
}
return 1.0 - result;
_distance_computer(x, y, _dimension);
}

private:
size_t _dimension;
size_t _data_size_bytes;
std::function<float(const void *, const void *, const size_t &)>
_distance_computer;

friend class cereal::access;

template <typename Archive> void serialize(Archive &ar) {
ar(_dimension);

// If loading, we need to set the data size bytes
if (Archive::is_loading::value) {
_data_size_bytes = _dimension * sizeof(float);
_distance_computer = std::bind(
&SquaredL2Distance::defaultDistanceImpl, this, std::placeholders::_1,
std::placeholders::_2, std::placeholders::_3);

setDistanceFunction();
}
}

Expand All @@ -68,6 +72,48 @@ class InnerProductDistance : public DistanceInterface<InnerProductDistance> {
std::cout << "-----------------------------" << std::endl;
std::cout << "Dimension: " << _dimension << std::endl;
}

void setDistanceFunction() {
#ifndef NO_MANUAL_VECTORIZATION
#if defined(USE_AVX512) || defined(USE_AVX) || defined(USE_SSE)
_distance_computer = distanceImplInnerProductSIMD16ExtSSE;
#if defined(USE_AVX512)
if (platform_supports_avx512()) {
_distance_computer = distanceImplSquaredL2SIMD16ExtAVX512;
} else if (platform_supports_avx()) {
_distance_computer = distanceImplSquaredL2SIMD16ExtAVX;
}
#elif defined(USE_AVX)
if (platform_supports_avx()) {
_distance_computer = distanceImplSquaredL2SIMD16ExtAVX;
}
#endif
if (!_dimension % 16 == 0) {
if (_dimension % 4 == 0) {
_distance_computer = distanceImplSquaredL2SIMD4Ext;
} else if (_dimension > 16) {
_distance_computer = distanceImplSquaredL2SIMD16ExtResiduals;
} else if (_dimension > 4) {
_distance_computer = distanceImplSquaredL2SIMD4ExtResiduals;
}
}

#endif
#endif // NO_MANUAL_VECTORIZATION
}

float defaultDistanceImpl(const void *x, const void *y,
const size_t &dimension) const {
// Default implementation of inner product distance, in case we cannot
// support the SIMD specializations for special input _dimension sizes.
float *p_x = (float *)x;
float *p_y = (float *)y;
float result = 0;
for (size_t i = 0; i < dimension; i++) {
result += p_x[i] * p_y[i];
}
return 1.0 - result;
}
};

} // namespace flatnav
1 change: 0 additions & 1 deletion flatnav/distances/SquaredL2Distance.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ class SquaredL2Distance : public DistanceInterface<SquaredL2Distance> {
private:
size_t _dimension;
size_t _data_size_bytes;
// float (*_distance_computer)(const void *, const void *, size_t &) const;
std::function<float(const void *, const void *, const size_t &)>
_distance_computer;

Expand Down
234 changes: 233 additions & 1 deletion flatnav/util/SIMDIntrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,32 @@ bool platform_supports_avx512() {
#endif

#if defined(USE_AVX512)
static float distanceImplInnerProductSIMD16ExtAVX512(const void *x,
const void *y,
const size_t &dimension) {
float *p_x = (float *)(x);
float *p_y = (float *)(y);

float PORTABLE_ALIGN64 temp_res[16];
size_t dimension_1_16 = dimension >> 4;
const float *p_end_x = p_x + (dimension_1_16 << 4);
_m512 sum = _mm512_set1_ps(0.0f);

while (p_x != p_end_x) {
__m512 v1 = _mm512_loadu_ps(p_x);
__m512 v2 = _mm512_loadu_ps(p_y);
sum = _mm512_add_ps(sum, _mm512_mul_ps(v1, v2));
p_x += 16;
p_y += 16;
}

_mm512_store_ps(temp_res, sum);
float sum = temp_res[0] + temp_res[1] + temp_res[2] + temp_res[3] +
temp_res[4] + temp_res[5] + temp_res[6] + temp_res[7] +
temp_res[8] + temp_res[9] + temp_res[10] + temp_res[11] +
temp_res[12] + temp_res[13] + temp_res[14] + temp_res[15];
return 1.0f - sum;
}

static float distanceImplSquaredL2SIMD16ExtAVX512(const void *x, const void *y,
const size_t &dimension) {
Expand Down Expand Up @@ -242,6 +268,82 @@ static float distanceImplSquaredL2SIMD16ExtAVX512(const void *x, const void *y,
#endif

#if defined(USE_AVX)
static float distanceImplInnerProductSIMD4ExtAVX(const void *x, const void *y,
const size_t &dimension) {
float *p_x = (float *)(x);
float *p_y = (float *)(y);
float PORTABLE_ALIGN32 temp_res[8];

size_t dimension_1_16 = dimension >> 4;
size_t dimension_1_4 = dimension >> 2;
const float *p_end_x1 = p_x + (dimension_1_16 << 4);
const float *p_end_x2 = p_x + (dimension_1_4 << 2);

__m256 sum256 = _mm256_set1_ps(0.0f);

while (p_x != p_end_x1) {
__m256 v1 = _mm256_loadu_ps(p_x);
__m256 v2 = _mm256_loadu_ps(p_y);
sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2));
p_x += 8;
p_y += 8;

v1 = _mm256_loadu_ps(p_x);
v2 = _mm256_loadu_ps(p_y);
sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2));
p_x += 8;
p_y += 8;
}

__m128 v1, v2;
__m128 sum_prod = _mm_add_ps(_mm256_extractf128_ps(sum256, 0),
_mm256_extractf128_ps(sum256, 1));

while (p_x != p_end_x2) {
v1 = _mm_loadu_ps(p_x);
v2 = _mm_loadu_ps(p_y);
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
p_x += 4;
p_y += 4;
}

_mm_store_ps(temp_res, sum_prod);
float sum = temp_res[0] + temp_res[1] + temp_res[2] + temp_res[3];

return 1.0f - sum;
}


static float distanceImplInnerProductSIMD16ExtAVX(const void *x, const void *y,
const size_t &dimension) {
float *p_x = (float *)(x);
float *p_y = (float *)(y);

float PORTABLE_ALIGN32 temp_res[8];
size_t dimension_1_16 = dimension >> 4;
const float *p_end_x = p_x + (dimension_1_16 << 4);
__m256 sum = _mm256_set1_ps(0.0f);

while (p_x != p_end_x) {
__m256 v1 = _mm256_loadu_ps(p_x);
__m256 v2 = _mm256_loadu_ps(p_y);
sum = _mm256_add_ps(sum, _mm256_mul_ps(v1, v2));
p_x += 8;
p_y += 8;

v1 = _mm256_loadu_ps(p_x);
v2 = _mm256_loadu_ps(p_y);
sum = _mm256_add_ps(sum, _mm256_mul_ps(v1, v2));
p_x += 8;
p_y += 8;
}

_mm256_store_ps(temp_res, sum);
float sum = temp_res[0] + temp_res[1] + temp_res[2] + temp_res[3] +
temp_res[4] + temp_res[5] + temp_res[6] + temp_res[7];
return 1.0f - sum;
}

static float distanceImplSquaredL2SIMD16ExtAVX(const void *x, const void *y,
const size_t &dimension) {
float *p_x = (float *)(x);
Expand Down Expand Up @@ -277,7 +379,106 @@ static float distanceImplSquaredL2SIMD16ExtAVX(const void *x, const void *y,
}
#endif



#if defined(USE_SSE)

static float distanceImplInnerProductSIMD16ExtSSE(const void *x, const void *y,
const size_t &dimension) {
float *p_x = (float *)(x);
float *p_y = (float *)(y);

float PORTABLE_ALIGN32 temp_res[8];
size_t dimension_1_16 = dimension >> 4;
const float *p_end_x = p_x + (dimension_1_16 << 4);
__m128 sum = _mm_set1_ps(0.0f);
__m128 v1, v2;

while (p_x != p_end_x) {
v1 = _mm_loadu_ps(p_x);
v2 = _mm_loadu_ps(p_y);
sum = _mm_add_ps(sum, _mm_mul_ps(v1, v2));
p_x += 4;
p_y += 4;

v1 = _mm_loadu_ps(p_x);
v2 = _mm_loadu_ps(p_y);
sum = _mm_add_ps(sum, _mm_mul_ps(v1, v2));
p_x += 4;
p_y += 4;

v1 = _mm_loadu_ps(p_x);
v2 = _mm_loadu_ps(p_y);
sum = _mm_add_ps(sum, _mm_mul_ps(v1, v2));
p_x += 4;
p_y += 4;

v1 = _mm_loadu_ps(p_x);
v2 = _mm_loadu_ps(p_y);
sum = _mm_add_ps(sum, _mm_mul_ps(v1, v2));
p_x += 4;
p_y += 4;
}

_mm_store_ps(temp_res, sum);
float sum = temp_res[0] + temp_res[1] + temp_res[2] + temp_res[3];
return 1.0f - sum;
}

static float distanceImplInnerProductSIMD4ExtSSE(const void *x, const void *y, const size_t& dimension) {
float *p_x = (float *)(x);
float *p_y = (float *)(y);
float PORTABLE_ALIGN32 temp_res[8];
size_t dimension_1_4 = dimension >> 2;
size_t dimension_1_16 = dimension >> 4;

const float* p_end_x1 = p_x + (dimension_1_16 << 4);
const float* p_end_x2 = p_x + (dimension_1_4 << 2);

__m128 sum_prod = _mm_set1_ps(0.0f);
__m128 v1, v2;

while (p_x != p_end_x1) {
v1 = _mm_loadu_ps(p_x);
v2 = _mm_loadu_ps(p_y);
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
p_x += 4;
p_y += 4;

v1 = _mm_loadu_ps(p_x);
v2 = _mm_loadu_ps(p_y);
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
p_x += 4;
p_y += 4;

v1 = _mm_loadu_ps(p_x);
v2 = _mm_loadu_ps(p_y);
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
p_x += 4;
p_y += 4;

v1 = _mm_loadu_ps(p_x);
v2 = _mm_loadu_ps(p_y);
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
p_x += 4;
p_y += 4;
}

while (p_x != p_end_x2) {
v1 = _mm_loadu_ps(p_x);
v2 = _mm_loadu_ps(p_y);
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
p_x += 4;
p_y += 4;
}

_mm_store_ps(temp_res, sum_prod);
float sum = temp_res[0] + temp_res[1] + temp_res[2] + temp_res[3];
return 1.0f - sum;
}



static float distanceImplSquaredL2SIMD16ExtSSE(const void *x, const void *y,
const size_t &dimension) {
float *p_x = (float *)(x);
Expand Down Expand Up @@ -375,10 +576,41 @@ static float distanceImplSquaredL2SIMD4ExtResiduals(const void *x,

#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)

static float distanceImplInnerProductSIMD16ExtResiduals(const void *x, const void *y, const size_t& dimension) {
size_t dimension16 = dimension >> 4 << 4;
float res = distanceImplInnerProductSIMD16ExtSSE(x, y, dimension16);
size_t residual = dimension - dimension16;

float *p_x = (float *)(x) + dimension16;
float *p_y = (float *)(y) + dimension16;
float sum_res = 0;
for (size_t i = 0; i < residual; i++) {
sum_res += *p_x * *p_y;
p_x++;
p_y++;
}
return 1.0f - (res + sum_res);
}

static float distanceImplInnerProductSIMD4ExtResiduals(const void *x, const void *y, const size_t& dimension) {
size_t dimension4 = dimension >> 2 << 2;
float res = distanceImplInnerProductSIMD4ExtSSE(x, y, dimension4);
size_t residual = dimension - dimension4;

float *p_x = (float *)(x) + dimension4;
float *p_y = (float *)(y) + dimension4;
float sum_res = 0;
for (size_t i = 0; i < residual; i++) {
sum_res += *p_x * *p_y;
p_x++;
p_y++;
}
return 1.0f - (res + sum_res);
}

static float distanceImplSquaredL2SIMD16ExtResiduals(const void *x,
const void *y,
const size_t &dimension) {
// The purpose of this is to ensure that dimension is always a multiple of 16.
size_t dimension16 = dimension >> 4 << 4;
float res = distanceImplSquaredL2SIMD16ExtSSE(x, y, dimension16);
size_t residual = dimension - dimension16;
Expand Down

0 comments on commit 31d4f1f

Please sign in to comment.