From b17e4552f0311ccb7bc1368d972653b88dbd0aa4 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Fri, 5 May 2023 19:01:45 +0530 Subject: [PATCH 1/3] Fix heap buffer overflow caused by prefetch. The prefetching goes past `size` since `datal` is initialized 1 index past `data`. I uncovered this issue on ASAN when SSE is enabled. Also spotted earlier here: https://github.com/nmslib/hnswlib/issues/107 --- hnswlib/hnswalg.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index 7f34e62b..28666173 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -249,8 +249,8 @@ class HierarchicalNSW : public AlgorithmInterface { tableint candidate_id = *(datal + j); // if (candidate_id == 0) continue; #ifdef USE_SSE - _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0); - _mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0); + _mm_prefetch((char *) (visited_array + *(datal + j)), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*(datal + j)), _MM_HINT_T0); #endif if (visited_array[candidate_id] == visited_array_tag) continue; visited_array[candidate_id] = visited_array_tag; From b5c2ebae31cd124e3a625f2de789a3496ebb2286 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Sat, 13 May 2023 17:25:40 +0530 Subject: [PATCH 2/3] Fix prefetch overflow in other places. 
--- hnswlib/hnswalg.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index 00474321..4f47060c 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -343,8 +343,8 @@ class HierarchicalNSW : public AlgorithmInterface { int candidate_id = *(data + j); // if (candidate_id == 0) continue; #ifdef USE_SSE - _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0); - _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_, + _mm_prefetch((char *) (visited_array + *(data + j)), _MM_HINT_T0); + _mm_prefetch(data_level0_memory_ + (*(data + j)) * size_data_per_element_ + offsetData_, _MM_HINT_T0); //////////// #endif if (!(visited_array[candidate_id] == visited_array_tag)) { @@ -1007,7 +1007,7 @@ class HierarchicalNSW : public AlgorithmInterface { #endif for (int i = 0; i < size; i++) { #ifdef USE_SSE - _mm_prefetch(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*(datal + i)), _MM_HINT_T0); #endif tableint cand = datal[i]; dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_); From 573ab84a7f7645f98778cbb181ba762c5d2f19b5 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Mon, 15 May 2023 11:14:25 +0530 Subject: [PATCH 3/3] Use min to limit the index of next prefetch. 
--- hnswlib/hnswalg.h | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index 4f47060c..2aff4d48 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -260,8 +260,9 @@ class HierarchicalNSW : public AlgorithmInterface { tableint candidate_id = *(datal + j); // if (candidate_id == 0) continue; #ifdef USE_SSE - _mm_prefetch((char *) (visited_array + *(datal + j)), _MM_HINT_T0); - _mm_prefetch(getDataByInternalId(*(datal + j)), _MM_HINT_T0); + size_t next_index = std::min(size - 1, j + 1); + _mm_prefetch((char *) (visited_array + *(datal + next_index)), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*(datal + next_index)), _MM_HINT_T0); #endif if (visited_array[candidate_id] == visited_array_tag) continue; visited_array[candidate_id] = visited_array_tag; @@ -343,8 +344,9 @@ class HierarchicalNSW : public AlgorithmInterface { int candidate_id = *(data + j); // if (candidate_id == 0) continue; #ifdef USE_SSE - _mm_prefetch((char *) (visited_array + *(data + j)), _MM_HINT_T0); - _mm_prefetch(data_level0_memory_ + (*(data + j)) * size_data_per_element_ + offsetData_, + size_t next_index = std::min(size, j + 1); + _mm_prefetch((char *) (visited_array + *(data + next_index)), _MM_HINT_T0); + _mm_prefetch(data_level0_memory_ + (*(data + next_index)) * size_data_per_element_ + offsetData_, _MM_HINT_T0); //////////// #endif if (!(visited_array[candidate_id] == visited_array_tag)) { @@ -1007,7 +1009,8 @@ class HierarchicalNSW : public AlgorithmInterface { #endif for (int i = 0; i < size; i++) { #ifdef USE_SSE - _mm_prefetch(getDataByInternalId(*(datal + i)), _MM_HINT_T0); + size_t next_index = std::min(size - 1, i + 1); + _mm_prefetch(getDataByInternalId(*(datal + next_index)), _MM_HINT_T0); #endif tableint cand = datal[i]; dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_);