Skip to content

Commit

Permalink
add compiler flags for avx
Browse files Browse the repository at this point in the history
  • Loading branch information
blaise-muhirwa committed Nov 30, 2023
1 parent 8c0f184 commit 5b9f452
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 95 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ set(CMAKE_CXX_FLAGS
-w \
-ffast-math \
-funroll-loops \
-mavx \
-mavx512f \
-ftree-vectorize")

option(CMAKE_BUILD_TYPE "Build type" Release)
Expand Down
102 changes: 8 additions & 94 deletions flatnav/util/SIMDDistanceSpecializations.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,59 +37,10 @@ __int64 xgetbv(unsigned int x) { return _xgetbv(x); }
#include <stdint.h>
#include <x86intrin.h>

/**
* @brief Queries the CPU for various bits of information about its
* capabilities, including supported instruction sets and features. This is done
* using the CPUID instruction, which is a processor supplementary instruction
* (PSI) for the x86 architecture allowing software to discover details of the
* processor.
*
* @param cpu_info An array of four 32-bit integers that will be filled with the
* CPU information. The specific information returned in cpu_info depends on the
* value of the `eax` and `ecx` registers.
* - cpu_info[0] (EAX): The function result value after the CPUID
* instruction.
* - cpu_info[1] (EBX): Additional information returned by the CPUID
* instruction.
* - cpu_info[2] (ECX): Additional information returned by the CPUID
* instruction.
* - cpu_info[3] (EDX): Additional information returned by the CPUID
* instruction.
*
* @param eax Specifies what information to retrieve. Different values
* of EAX will return different information in the cpu_info array, such
* as processor type, family, model, stepping, and feature flags.
*
* @param ecx An additional parameter used by some CPUID function numbers to
* provide further information about what information to retrieve.
*/
void cpuid(int32_t cpu_info[4], int32_t eax, int32_t ecx) {
__cpuid_count(eax, ecx, cpu_info[0], cpu_info[1], cpu_info[2], cpu_info[3]);
}

/**
* @brief Retrieves the value of an extended control register (XCR).
* This is particularly useful for checking the status of advanced CPU features.
*
* @param index The index of the XCR to query. For example, 0 for XCR0, which
* contains flags for x87 state, SSE state, and AVX state.
*
* @return A 64-bit value with the state of the specified XCR. The lower 32 bits
* are from the EAX register, and the higher 32 bits from the EDX register after
* the instruction executes.
*
* Inline assembly breakdown:
* - __volatile__ tells the compiler not to optimize this assembly block as its
* side effects are important.
* - "xgetbv": The assembly instruction to execute.
* - "=a"(eax), "=d"(edx): Output operands; after executing 'xgetbv', store EAX
* in 'eax', and EDX in 'edx'.
* - "c"(index): Input operand; provides the 'index' parameter to the ECX
* register before executing 'xgetbv'.
*
* The result is constructed by shifting 'edx' left by 32 bits and combining it
* with 'eax' using bitwise OR.
*/
uint64_t xgetbv(unsigned int index) {
uint32_t eax, edx;
__asm__ __volatile__("xgetbv" : "=a"(eax), "=d"(edx) : "c"(index));
Expand All @@ -114,22 +65,6 @@ uint64_t xgetbv(unsigned int index) {
// Adapted from https://github.com/Mysticial/FeatureDetector
#define _XCR_XFEATURE_ENABLED_MASK 0

/**
* Checks if the system's CPU and OS support AVX (Advanced Vector Extensions).
*
* - First, it uses the cpuid function to check if the CPU supports AVX
* instructions by examining the presence of the AVX bit in the CPU feature
* flags (ECX register, bit 28).
*
* - It then checks if the OS saves and restores AVX registers on context
* switches by checking the OS XSAVE feature flag (ECX register, bit 27) and
* confirming the OS has enabled AVX state saving with the xgetbv function. This
* function checks if the XMM and YMM registers (bits 1 and 2) are enabled in
* the XCR feature mask.
*
* - The function returns true if both hardware and OS-level AVX support are
* detected and enabled.
*/
bool platform_supports_avx() {
int cpu_info[4];

Expand Down Expand Up @@ -157,27 +92,6 @@ bool platform_supports_avx() {
return HW_AVX && avxSupported;
}

/**
* Checks if the system's CPU and OS support AVX-512 (Advanced Vector Extensions
* 512).
*
* - Initially, it verifies AVX capability since AVX-512 is an extension of AVX.
*
* - It uses the cpuid function to check for AVX-512 Foundation support by
* querying the presence of the AVX-512F feature flag (EBX register, bit 16) for
* the CPU.
*
* - Ensures the OS supports context switch saving for AVX-512 registers by
* checking the OS XSAVE feature flag (ECX register, bit 27) and that AVX is
* supported (bit 28).
*
* - Checks the OS has enabled AVX-512 state saving with the xgetbv function,
* looking for specific bits in the XCR feature mask that correspond to AVX-512
* registers.
*
* - Returns true if both hardware and OS-level AVX-512 support are present and
* enabled.
*/
bool platform_supports_avx512() {
if (!platform_supports_avx()) {
return false;
Expand Down Expand Up @@ -220,7 +134,7 @@ static float distanceImplInnerProductSIMD16ExtAVX512(const void *x,
float PORTABLE_ALIGN64 temp_res[16];
size_t dimension_1_16 = dimension >> 4;
const float *p_end_x = p_x + (dimension_1_16 << 4);
_m512 sum = _mm512_set1_ps(0.0f);
__m512 sum = _mm512_set1_ps(0.0f);

while (p_x != p_end_x) {
__m512 v1 = _mm512_loadu_ps(p_x);
Expand All @@ -243,7 +157,7 @@ static float distanceImplSquaredL2SIMD16ExtAVX512(const void *x, const void *y,
float *p_x = (float *)(x);
float *p_y = (float *)(y);

float PORTABLE_ALIGN64 tmp_res[16];
float PORTABLE_ALIGN64 temp_res[16];
size_t dimension_1_16 = dimension >> 4;
const float *p_end_x = p_x + (dimension_1_16 << 4);

Expand All @@ -259,7 +173,7 @@ static float distanceImplSquaredL2SIMD16ExtAVX512(const void *x, const void *y,
p_y += 16;
}

_mm512_store_ps(tmp_res, sum);
_mm512_store_ps(temp_res, sum);
return temp_res[0] + temp_res[1] + temp_res[2] + temp_res[3] + temp_res[4] +
temp_res[5] + temp_res[6] + temp_res[7] + temp_res[8] + temp_res[9] +
temp_res[10] + temp_res[11] + temp_res[12] + temp_res[13] +
Expand Down Expand Up @@ -338,17 +252,17 @@ static float distanceImplInnerProductSIMD16ExtAVX(const void *x, const void *y,
}

_mm256_store_ps(temp_res, sum);
float sum = temp_res[0] + temp_res[1] + temp_res[2] + temp_res[3] +
temp_res[4] + temp_res[5] + temp_res[6] + temp_res[7];
return 1.0f - sum;
float total = temp_res[0] + temp_res[1] + temp_res[2] + temp_res[3] +
temp_res[4] + temp_res[5] + temp_res[6] + temp_res[7];
return 1.0f - total;
}

static float distanceImplSquaredL2SIMD16ExtAVX(const void *x, const void *y,
const size_t &dimension) {
float *p_x = (float *)(x);
float *p_y = (float *)(y);

float PORTABLE_ALIGN32 tmp_res[8];
float PORTABLE_ALIGN32 temp_res[8];
size_t dimension_1_16 = dimension >> 4;
const float *p_end_x = p_x + (dimension_1_16 << 4);

Expand All @@ -371,7 +285,7 @@ static float distanceImplSquaredL2SIMD16ExtAVX(const void *x, const void *y,
p_y += 8;
}

_mm256_store_ps(tmp_res, sum);
_mm256_store_ps(temp_res, sum);

return temp_res[0] + temp_res[1] + temp_res[2] + temp_res[3] + temp_res[4] +
temp_res[5] + temp_res[6] + temp_res[7];
Expand Down
4 changes: 3 additions & 1 deletion flatnav_python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
omp_flag = "-Xclang -fopenmp"
INCLUDE_DIRS.extend(["/opt/homebrew/opt/libomp/include"])
EXTRA_LINK_ARGS.extend(["-lomp", "-L/opt/homebrew/opt/libomp/lib"])
elif sys.platform() == "linux":
elif sys.platform == "linux":
omp_flag = "-fopenmp"
EXTRA_LINK_ARGS.extend(["-fopenmp"])

Expand All @@ -39,6 +39,8 @@
"-ffast-math", # Enable fast math optimizations
"-funroll-loops", # Unroll loops
"-ftree-vectorize", # Vectorize where possible
"-mavx", # Enable AVX instructions
"-mavx512f", # Enable AVX-512 instructions
],
extra_link_args=EXTRA_LINK_ARGS, # Link OpenMP when linking the extension
)
Expand Down

0 comments on commit 5b9f452

Please sign in to comment.