Skip to content

Commit

Permalink
cleaned up the python bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
blaise-muhirwa committed Jul 30, 2024
1 parent 0e4f180 commit 51d67dc
Show file tree
Hide file tree
Showing 29 changed files with 913 additions and 1,146 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ endif()
set(HEADERS
${PROJECT_SOURCE_DIR}/flatnav/distances/InnerProductDistance.h
${PROJECT_SOURCE_DIR}/flatnav/distances/SquaredL2Distance.h
${PROJECT_SOURCE_DIR}/flatnav/distances/L2DistanceDispatcher.h
${PROJECT_SOURCE_DIR}/flatnav/distances/IPDistanceDispatcher.h
${PROJECT_SOURCE_DIR}/flatnav/util/SquaredL2SimdExtensions.h
${PROJECT_SOURCE_DIR}/flatnav/util/InnerProductSimdExtensions.h
${PROJECT_SOURCE_DIR}/flatnav/util/VisitedSetPool.h
Expand Down
6 changes: 6 additions & 0 deletions docs/cpp_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,9 @@ C++ API Documentation
:members:
:protected-members:
:undoc-members:

.. doxygenclass:: flatnav::util::DataType
:project: FlatNav
:members:
:protected-members:
:undoc-members:
4 changes: 2 additions & 2 deletions docs/flatnav_python.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ FlatNav Index Module

This module provides interfaces to create and manipulate FlatNav index structures.

index_factory
create
-------------

.. autofunction:: flatnav.index.index_factory
.. autofunction:: flatnav.index.create

Index Classes
-------------
Expand Down
5 changes: 4 additions & 1 deletion experiments/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ sift-bench-flatnav:
--queries /root/data/sift-128-euclidean/sift-128-euclidean.test.npy \
--gtruth /root/data/sift-128-euclidean/sift-128-euclidean.gtruth.npy \
--index-type flatnav \
--data-type uint8 \
--data-type float32 \
--num-node-links 32 \
--ef-construction 30 40 50 100 200 300 \
--ef-search 100 200 300 500 1000 3000 \
Expand Down Expand Up @@ -621,6 +621,9 @@ s3-push:
cleanup:
rm -rf hnswlib-original

test-flatnav:
poetry run python test_flatnav.py

# If passed an invalid argument, print help message
%:
@echo "Invalid argument: $@"
Expand Down
10 changes: 5 additions & 5 deletions experiments/run-benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

def compute_metrics(
requested_metrics: List[str],
index: Union[flatnav.index.L2Index, flatnav.index.IPIndex, hnswlib.Index],
index: Union[hnswlib.Index, flatnav.index.IndexL2Float, flatnav.index.IndexIPFloat],
queries: np.ndarray,
ground_truth: np.ndarray,
ef_search: int,
Expand Down Expand Up @@ -169,7 +169,7 @@ def train_index(
use_hnsw_base_layer: bool = False,
hnsw_base_layer_filename: Optional[str] = None,
num_build_threads: int = 1,
) -> Union[flatnav.index.L2Index, flatnav.index.IPIndex, hnswlib.Index]:
) -> Union[flatnav.index.IndexL2Float, flatnav.index.IndexIPFloat, hnswlib.Index]:
"""
Creates and trains an index on the given dataset.
:param train_dataset: The dataset to train the index on.
Expand Down Expand Up @@ -220,7 +220,7 @@ def train_index(
if not os.path.exists(hnsw_base_layer_filename):
raise ValueError(f"Failed to create {hnsw_base_layer_filename=}")

index = flatnav.index.index_factory(
index = flatnav.index.create(
distance_type=distance_type,
index_data_type=FLATNAV_DATA_TYPES[data_type],
dim=dim,
Expand All @@ -239,7 +239,7 @@ def train_index(
os.remove(hnsw_base_layer_filename)

else:
index = flatnav.index.index_factory(
index = flatnav.index.create(
distance_type=distance_type,
index_data_type=FLATNAV_DATA_TYPES[data_type],
dim=dim,
Expand Down Expand Up @@ -542,7 +542,7 @@ def plot_all_metrics(

create_plot(
experiment_runs=experiment_runs,
raw=True,
raw=False,
x_scale="linear",
y_scale="linear",
x_axis_metric=x_metric,
Expand Down
10 changes: 1 addition & 9 deletions flatnav/distances/DistanceInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,13 @@

#include <cereal/access.hpp>
#include <cstddef> // for size_t
#include <flatnav/util/Datatype.h>
#include <fstream> // for ifstream, ofstream
#include <functional>
#include <iostream>

namespace flatnav::distances {

using util::DataType;
typedef std::function<float(const void *, const void *, const size_t &)>
DistanceFunction;

typedef std::unique_ptr<DistanceFunction> DistanceFunctionPtr;

enum class METRIC_TYPE { EUCLIDEAN, INNER_PRODUCT };
enum class MetricType { L2, IP };

// We use the CRTP to implement static polymorphism on the distance. This is
// done to allow for metrics and distance functions that support arbitrary
Expand Down Expand Up @@ -43,7 +36,6 @@ template <typename T> class DistanceInterface {
// Prints the parameters of the distance function.
void getSummary() { static_cast<T *>(this)->getSummaryImpl(); }


// This transforms the data located at src into a form that is writeable
// to disk / storable in RAM. For distance functions that don't
// compress the input, this just passes through a copy from src to
Expand Down
103 changes: 103 additions & 0 deletions flatnav/distances/IPDistanceDispatcher.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#pragma once

#include <flatnav/util/Datatype.h>
#include <flatnav/util/InnerProductSimdExtensions.h>
#include <flatnav/util/Macros.h>

namespace flatnav::distances {

// Portable scalar fallback. Walks both arrays in lockstep, accumulating the
// element-wise products of x and y in a float, and returns one minus that
// sum. With dimension == 0 the loop body never runs and the result is 1.0f.
template <typename T>
static float defaultInnerProduct(const T *x, const T *y,
                                 const size_t &dimension) {
  float dot = 0;
  const T *lhs = x;
  const T *rhs = y;
  for (const T *const lhs_end = x + dimension; lhs != lhs_end; ++lhs, ++rhs) {
    dot += (*lhs) * (*rhs);
  }
  return 1.0f - dot;
}

// Generic inner-product kernel. Element types without a dedicated
// specialization are routed to the scalar defaultInnerProduct loop.
template <typename T> struct InnerProductImpl {
  static float computeDistance(const T *x, const T *y,
                               const size_t &dimension) {
    const float distance = defaultInnerProduct<T>(x, y, dimension);
    return distance;
  }
};

// Inner-product kernel for float vectors.
//
// Tries up to three instruction-set tiers in order (AVX512, AVX, SSE), each
// compiled in only when the matching USE_* macro is defined. The AVX512 and
// AVX tiers are additionally gated at runtime by platformSupportsAvx512() /
// platformSupportsAvx(); the SSE tier has no runtime check.
//
// Within a tier the kernel is chosen from `dimension`:
//   - multiple of 16      -> the tier's widest kernel
//   - multiple of 4       -> a 4-aligned kernel
//   - otherwise, > 16     -> computeIP_SseWithResidual_16
//   - otherwise, > 4      -> computeIP_SseWithResidual_4
// If no branch of a tier returns (runtime check fails, or dimension is 1, 2
// or 3), control falls through to the next tier and finally to the scalar
// defaultInnerProduct loop.
//
// NOTE(review): the util::computeIP_* helpers are presumably declared in
// flatnav/util/InnerProductSimdExtensions.h and guarded by the same USE_*
// macros -- confirm that the SSE helpers are visible whenever USE_AVX512 or
// USE_AVX is defined.
template <> struct InnerProductImpl<float> {
  static float computeDistance(const float *x, const float *y,
                               const size_t &dimension) {
#if defined(USE_AVX512)
    if (platformSupportsAvx512()) {
      if (dimension % 16 == 0) {
        return util::computeIP_Avx512(x, y, dimension);
      }
      if (dimension % 4 == 0) {
#if defined(USE_AVX)
        return util::computeIP_Avx_4aligned(x, y, dimension);
#else
        // Same 4-aligned SSE kernel (and spelling) as the SSE tier below.
        return util::computeIP_Sse_4aligned(x, y, dimension);
#endif
      } else if (dimension > 16) {
        return util::computeIP_SseWithResidual_16(x, y, dimension);
      } else if (dimension > 4) {
        return util::computeIP_SseWithResidual_4(x, y, dimension);
      }
    }
#endif

#if defined(USE_AVX)
    if (platformSupportsAvx()) {
      if (dimension % 16 == 0) {
        return util::computeIP_Avx(x, y, dimension);
      }
      if (dimension % 4 == 0) {
        return util::computeIP_Avx_4aligned(x, y, dimension);
      } else if (dimension > 16) {
        return util::computeIP_SseWithResidual_16(x, y, dimension);
      } else if (dimension > 4) {
        return util::computeIP_SseWithResidual_4(x, y, dimension);
      }
    }
#endif

#if defined(USE_SSE)
    if (dimension % 16 == 0) {
      return util::computeIP_Sse(x, y, dimension);
    }
    if (dimension % 4 == 0) {
      return util::computeIP_Sse_4aligned(x, y, dimension);
    } else if (dimension > 16) {
      return util::computeIP_SseWithResidual_16(x, y, dimension);
    } else if (dimension > 4) {
      return util::computeIP_SseWithResidual_4(x, y, dimension);
    }

#endif
    // Scalar fallback: no SIMD tier compiled in, or none of them applied.
    return defaultInnerProduct<float>(x, y, dimension);
  }
};

// TODO: Include SIMD optimized implementations for int8_t.
// Until then, int8_t vectors go through the scalar defaultInnerProduct loop.
template <> struct InnerProductImpl<int8_t> {
  static float computeDistance(const int8_t *x, const int8_t *y,
                               const size_t &dimension) {
    const float distance = defaultInnerProduct<int8_t>(x, y, dimension);
    return distance;
  }
};

// TODO: Include SIMD optimized implementations for uint8_t.
// Until then, uint8_t vectors go through the scalar defaultInnerProduct loop.
template <> struct InnerProductImpl<uint8_t> {
  static float computeDistance(const uint8_t *x, const uint8_t *y,
                               const size_t &dimension) {
    const float distance = defaultInnerProduct<uint8_t>(x, y, dimension);
    return distance;
  }
};

struct IPDistanceDispatcher {
template <typename T>
static float dispatch(const T *x, const T *y, const size_t &dimension) {
return InnerProductImpl<T>::computeDistance(x, y, dimension);
}
};

} // namespace flatnav::distances
Loading

0 comments on commit 51d67dc

Please sign in to comment.