Merge pull request #16 from BlaiseMuhirwa/compiler-flags-for-avx
Compiler flags for AVX
BlaiseMuhirwa authored Nov 30, 2023
2 parents 8c0f184 + cce2574 commit 5fcc8e8
Showing 4 changed files with 34 additions and 26 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -33,6 +33,8 @@ set(CMAKE_CXX_FLAGS
-w \
-ffast-math \
-funroll-loops \
-mavx \
-mavx512f \
-ftree-vectorize")

option(CMAKE_BUILD_TYPE "Build type" Release)
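The two new flags allow the compiler to emit AVX and AVX-512F instructions and, as a side effect, define the __AVX__ and __AVX512F__ preprocessor macros. Below is a minimal sketch, not taken from this repository, of how code can key off those macros so a translation unit still builds when the flags are absent:

#include <cstddef>
#if defined(__AVX512F__)
#include <immintrin.h>
#endif

// Sums 16 floats. The AVX-512 path is compiled only when -mavx512f (or an
// equivalent -march=... setting) defines __AVX512F__.
static float sum16(const float *p) {
#if defined(__AVX512F__)
  return _mm512_reduce_add_ps(_mm512_loadu_ps(p));
#else
  float s = 0.0f;
  for (std::size_t i = 0; i < 16; ++i) {
    s += p[i];
  }
  return s;
#endif
}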
36 changes: 18 additions & 18 deletions flatnav/util/SIMDDistanceSpecializations.h
@@ -27,16 +27,6 @@
#include <intrin.h>
#include <stdexcept>

void cpu_x86::cpuid(int32_t out[4], int32_t eax, int32_t ecx) {
__cpuidex(out, eax, ecx);
}
__int64 xgetbv(unsigned int x) { return _xgetbv(x); }

#else
#include <cpuid.h>
#include <stdint.h>
#include <x86intrin.h>

/**
* @brief Queries the CPU for various bits of information about its
* capabilities, including supported instruction sets and features. This is done
@@ -63,6 +53,16 @@ __int64 xgetbv(unsigned int x) { return _xgetbv(x); }
* @param ecx An additional parameter used by some CPUID function numbers to
* provide further information about what information to retrieve.
*/
void cpu_x86::cpuid(int32_t out[4], int32_t eax, int32_t ecx) {
__cpuidex(out, eax, ecx);
}
__int64 xgetbv(unsigned int x) { return _xgetbv(x); }

#else
#include <cpuid.h>
#include <stdint.h>
#include <x86intrin.h>

void cpuid(int32_t cpu_info[4], int32_t eax, int32_t ecx) {
__cpuid_count(eax, ecx, cpu_info[0], cpu_info[1], cpu_info[2], cpu_info[3]);
}
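
The comment block above describes what CPUID reports. Below is a minimal sketch, not part of this commit, showing how the cpuid() wrapper defined above can be used to test for AVX-512 Foundation support at runtime; bit positions follow the Intel SDM (CPUID.(EAX=7,ECX=0):EBX, bit 16 is AVX512F), and a complete check would also call xgetbv() to confirm that the OS saves ZMM register state.

#include <cstdint>

// Sketch only: true when the CPU reports AVX-512 Foundation instructions.
// Uses the cpuid() helper from the Linux branch above; the Windows branch
// exposes the same query as cpu_x86::cpuid.
static bool cpuSupportsAvx512f() {
  int32_t info[4] = {0, 0, 0, 0}; // EAX, EBX, ECX, EDX
  cpuid(info, /* eax = */ 7, /* ecx = */ 0);
  return (info[1] & (1 << 16)) != 0; // EBX bit 16: AVX512F
}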
@@ -220,7 +220,7 @@ static float distanceImplInnerProductSIMD16ExtAVX512(const void *x,
float PORTABLE_ALIGN64 temp_res[16];
size_t dimension_1_16 = dimension >> 4;
const float *p_end_x = p_x + (dimension_1_16 << 4);
_m512 sum = _mm512_set1_ps(0.0f);
__m512 sum = _mm512_set1_ps(0.0f);

while (p_x != p_end_x) {
__m512 v1 = _mm512_loadu_ps(p_x);
@@ -243,7 +243,7 @@ static float distanceImplSquaredL2SIMD16ExtAVX512(const void *x, const void *y,
float *p_x = (float *)(x);
float *p_y = (float *)(y);

float PORTABLE_ALIGN64 tmp_res[16];
float PORTABLE_ALIGN64 temp_res[16];
size_t dimension_1_16 = dimension >> 4;
const float *p_end_x = p_x + (dimension_1_16 << 4);

@@ -259,7 +259,7 @@ static float distanceImplSquaredL2SIMD16ExtAVX512(const void *x, const void *y,
p_y += 16;
}

_mm512_store_ps(tmp_res, sum);
_mm512_store_ps(temp_res, sum);
return temp_res[0] + temp_res[1] + temp_res[2] + temp_res[3] + temp_res[4] +
temp_res[5] + temp_res[6] + temp_res[7] + temp_res[8] + temp_res[9] +
temp_res[10] + temp_res[11] + temp_res[12] + temp_res[13] +
@@ -338,17 +338,17 @@ static float distanceImplInnerProductSIMD16ExtAVX(const void *x, const void *y,
}

_mm256_store_ps(temp_res, sum);
float sum = temp_res[0] + temp_res[1] + temp_res[2] + temp_res[3] +
temp_res[4] + temp_res[5] + temp_res[6] + temp_res[7];
return 1.0f - sum;
float total = temp_res[0] + temp_res[1] + temp_res[2] + temp_res[3] +
temp_res[4] + temp_res[5] + temp_res[6] + temp_res[7];
return 1.0f - total;
}

static float distanceImplSquaredL2SIMD16ExtAVX(const void *x, const void *y,
const size_t &dimension) {
float *p_x = (float *)(x);
float *p_y = (float *)(y);

float PORTABLE_ALIGN32 tmp_res[8];
float PORTABLE_ALIGN32 temp_res[8];
size_t dimension_1_16 = dimension >> 4;
const float *p_end_x = p_x + (dimension_1_16 << 4);

@@ -371,7 +371,7 @@ static float distanceImplSquaredL2SIMD16ExtAVX(const void *x, const void *y,
p_y += 8;
}

_mm256_store_ps(tmp_res, sum);
_mm256_store_ps(temp_res, sum);

return temp_res[0] + temp_res[1] + temp_res[2] + temp_res[3] + temp_res[4] +
temp_res[5] + temp_res[6] + temp_res[7];
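For reference, the AVX-512 kernels touched above follow the same 16-floats-per-iteration pattern: load, combine, accumulate into a __m512, then spill the accumulator to an aligned buffer and sum its lanes. The following is a self-contained sketch of a squared-L2 kernel in that style; it assumes that the loop body elided by the diff computes an element-wise difference and accumulates its square, and it is not a verbatim copy of the file.

#include <immintrin.h>
#include <cstddef>

static float squaredL2Avx512Sketch(const float *x, const float *y,
                                   std::size_t dimension) {
  alignas(64) float temp_res[16];
  std::size_t dimension_1_16 = dimension >> 4; // number of full 16-float blocks
  const float *x_end = x + (dimension_1_16 << 4);

  __m512 sum = _mm512_set1_ps(0.0f);
  while (x != x_end) {
    __m512 v1 = _mm512_loadu_ps(x);
    __m512 v2 = _mm512_loadu_ps(y);
    __m512 diff = _mm512_sub_ps(v1, v2);                 // element-wise difference
    sum = _mm512_add_ps(sum, _mm512_mul_ps(diff, diff)); // accumulate squares
    x += 16;
    y += 16;
  }

  _mm512_store_ps(temp_res, sum); // spill the 16 partial sums
  float total = 0.0f;
  for (int i = 0; i < 16; ++i) {
    total += temp_res[i];
  }
  return total; // any remaining dimension % 16 elements are not handled here
}

On compilers that provide it, the final horizontal sum can also be written as _mm512_reduce_add_ps(sum) instead of the store-and-add loop.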
18 changes: 11 additions & 7 deletions flatnav_python/python_bindings.cpp
@@ -56,7 +56,8 @@ template <typename dist_t, typename label_t> class PyIndex {

void
add(const py::array_t<float, py::array::c_style | py::array::forcecast> &data,
int ef_construction, py::object labels = py::none()) {
int ef_construction, py::object labels = py::none(),
int num_initializations = 100) {
// py::array_t<float, py::array::c_style | py::array::forcecast> means that
// the functions expects either a Numpy array of floats or a castable type
// to that type. If the given type can't be casted, pybind11 will throw an
@@ -71,7 +72,8 @@ template <typename dist_t, typename label_t> class PyIndex {
for (size_t vec_index = 0; vec_index < num_vectors; vec_index++) {
this->_index->add(/* data = */ (void *)data.data(vec_index),
/* label = */ label_id,
/* ef_construction = */ ef_construction);
/* ef_construction = */ ef_construction,
/* num_initializations = */ 100);
if (_verbose && vec_index % NUM_LOG_STEPS == 0) {
std::clog << "." << std::flush;
}
@@ -92,7 +94,8 @@ template <typename dist_t, typename label_t> class PyIndex {
label_t label_id = *node_labels.data(vec_index);
this->_index->add(/* data = */ (void *)data.data(vec_index),
/* label = */ label_id,
/* ef_construction = */ ef_construction);
/* ef_construction = */ ef_construction,
/* num_initializations = */ 100);

if (_verbose && vec_index % NUM_LOG_STEPS == 0) {
std::clog << "." << std::flush;
@@ -104,7 +107,7 @@ template <typename dist_t, typename label_t> class PyIndex {
DistancesLabelsPair
search(const py::array_t<float, py::array::c_style | py::array::forcecast>
queries,
int K, int ef_search) {
int K, int ef_search, int num_initializations = 100) {
size_t num_queries = queries.shape(0);
size_t queries_dim = queries.shape(1);

@@ -118,7 +121,8 @@ template <typename dist_t, typename label_t> class PyIndex {
for (size_t query_index = 0; query_index < num_queries; query_index++) {
std::vector<std::pair<float, label_t>> top_k = this->_index->search(
/* query = */ (const void *)queries.data(query_index), /* K = */ K,
/* ef_search = */ ef_search);
/* ef_search = */ ef_search,
/* num_initializations = */ num_initializations);

for (size_t i = 0; i < top_k.size(); i++) {
distances[query_index * K + i] = top_k[i].first;
@@ -164,14 +168,14 @@ void bindIndexMethods(py::class_<IndexType> &index_class) {
.def_static("load", &IndexType::loadIndex, py::arg("filename"),
"Load a FlatNav index from a given file location")
.def("add", &IndexType::add, py::arg("data"), py::arg("ef_construction"),
py::arg("labels") = py::none(),
py::arg("labels") = py::none(), py::arg("num_initializations") = 100,
"Add vectors(data) to the index with the given `ef_construction` "
"parameter and optional labels. `ef_construction` determines how "
"many "
"vertices are visited while inserting every vector in the "
"underlying graph structure.")
.def("search", &IndexType::search, py::arg("queries"), py::arg("K"),
py::arg("ef_search"),
py::arg("ef_search"), py::arg("num_initializations") = 100,
"Return top `K` closest data points for every query in the "
"provided `queries`. The results are returned as a Tuple of "
"distances and label ID's. The `ef_search` parameter determines how "
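The bindings above combine two pybind11 features: py::array_t<float, py::array::c_style | py::array::forcecast>, which converts any array-like argument into a contiguous float32 buffer (and raises a TypeError when that is impossible), and defaulted py::arg values such as py::arg("num_initializations") = 100, which become optional keyword arguments on the Python side. A standalone sketch of the same pattern follows; the module and function names here are illustrative, not part of flatnav.

#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>

namespace py = pybind11;

// Illustrative only: sums a float32 array, mirroring the argument style of
// PyIndex::add / PyIndex::search above.
float sum_array(
    const py::array_t<float, py::array::c_style | py::array::forcecast> &data,
    int num_initializations) {
  (void)num_initializations; // present only to demonstrate the defaulted keyword argument
  float total = 0.0f;
  for (py::ssize_t i = 0; i < data.size(); i++) {
    total += data.data()[i];
  }
  return total;
}

PYBIND11_MODULE(example_module, m) {
  m.def("sum_array", &sum_array, py::arg("data"),
        py::arg("num_initializations") = 100,
        "Sum a float32 array; num_initializations is optional.");
}

From Python such a function can be called as sum_array(data) or sum_array(data, num_initializations=32), which is how the new num_initializations arguments on add and search behave after this change.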
4 changes: 3 additions & 1 deletion flatnav_python/setup.py
@@ -19,7 +19,7 @@
omp_flag = "-Xclang -fopenmp"
INCLUDE_DIRS.extend(["/opt/homebrew/opt/libomp/include"])
EXTRA_LINK_ARGS.extend(["-lomp", "-L/opt/homebrew/opt/libomp/lib"])
elif sys.platform() == "linux":
elif sys.platform == "linux":
omp_flag = "-fopenmp"
EXTRA_LINK_ARGS.extend(["-fopenmp"])

@@ -39,6 +39,8 @@
"-ffast-math", # Enable fast math optimizations
"-funroll-loops", # Unroll loops
"-ftree-vectorize", # Vectorize where possible
"-mavx", # Enable AVX instructions
"-mavx512f", # Enable AVX-512 instructions
],
extra_link_args=EXTRA_LINK_ARGS, # Link OpenMP when linking the extension
)
