Skip to content

Commit

Permalink
fix memory leak
Browse files Browse the repository at this point in the history
  • Loading branch information
blaise-muhirwa committed Dec 17, 2023
1 parent ffb2f7d commit ca6a6b4
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 89 deletions.
26 changes: 10 additions & 16 deletions flatnav/Index.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,14 +141,14 @@ template <typename dist_t, typename label_t> class Index {
return;
}

parallelFor(/* start = */ 0, /* end = */ total_num_nodes,
/* num_threads = */ _num_threads, /* fn = */
[&](uint32_t row_index) {
void *vector = (float *)data + (row_index * data_dimension);
label_t label = labels[row_index];
concurrentAdd(vector, label, ef_construction,
num_initializations);
});
flatnav::parallelExecutor(
/* start = */ 0, /* end = */ total_num_nodes,
/* num_threads = */ _num_threads, /* function = */
[&](uint32_t row_index) {
void *vector = (float *)data + (row_index * data_dimension);
label_t label = labels[row_index];
concurrentAdd(vector, label, ef_construction, num_initializations);
});
}

void concurrentAdd(void *data, label_t &label, int ef_construction,
Expand Down Expand Up @@ -467,9 +467,6 @@ template <typename dist_t, typename label_t> class Index {
}
}
}

// Release the lock(unnecessary since we are exiting the scope)
lock.unlock();
}

/**
Expand Down Expand Up @@ -600,10 +597,6 @@ template <typename dist_t, typename label_t> class Index {
}
neighbors.pop();
}

// Release the lock. I don't think this is necessary since are actually
// exiting the function scope, but just in case
lock.unlock();
}

/**
Expand Down Expand Up @@ -689,7 +682,8 @@ template <typename dist_t, typename label_t> class Index {
}
}

_visited_nodes_handlers->pushHandler(/* handler = */ visited_nodes);
_visited_nodes_handlers->pushHandler(
/* handler = */ visited_nodes);

delete[] temp_data;
delete[] temp_links;
Expand Down
22 changes: 17 additions & 5 deletions flatnav/util/ParallelConstructs.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,37 @@

namespace flatnav {

/**
* @brief Template for executing a function in parallel using the STL's
* threading library. This is preferred over OpenMP only because it does not
* require logic for installing OpenMP on the host system while installing the
* Python library.
*
* @tparam Function
* @param start_index
* @param end_index
* @param num_threads
* @param function
*/
template <typename Function>
void parallelFor(uint32_t start, uint32_t end, uint32_t num_threads,
Function fn) {
void parallelExecutor(uint32_t start_index, uint32_t end_index,
uint32_t num_threads, Function function) {
if (num_threads == 0) {
throw std::invalid_argument("Invalid number of threads");
}

// This needs to be an atomic because multiple threads will be
// modifying it concurrently.
std::atomic<uint32_t> current(start);
std::atomic<uint32_t> current(start_index);
std::thread thread_objects[num_threads];

auto parallel_executor = [&] {
while (true) {
uint32_t current_vector_idx = current.fetch_add(1);
if (current_vector_idx >= end) {
if (current_vector_idx >= end_index) {
break;
}
fn(current_vector_idx);
function(current_vector_idx);
}
};

Expand Down
114 changes: 61 additions & 53 deletions flatnav/util/VisitedNodesHandler.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@

#include <flatnav/util/SIMDDistanceSpecializations.h>

#include <cereal/access.hpp>
#include <cereal/archives/binary.hpp>
#include <cereal/cereal.hpp>
#include <cereal/types/memory.hpp>
#include <cstring>
#include <iostream>
#include <memory>
Expand All @@ -21,23 +17,7 @@ class VisitedNodesHandler {
uint32_t *_table;
uint32_t _table_size;

friend class cereal::access;
template <typename Archive> void serialize(Archive &archive) {
archive(_mark, _table_size);

if (Archive::is_loading::value) {
// If we are loading, allocate memory for the table and delete
// previously allocated memory if any.
delete[] _table;
_table = new uint32_t[_table_size];
}

archive(cereal::binary_data(_table, _table_size * sizeof(uint32_t)));
}

public:
VisitedNodesHandler() = default;

VisitedNodesHandler(const uint32_t size) : _mark(0), _table_size(size) {
// initialize values to 0
_table = new uint32_t[_table_size]();
Expand Down Expand Up @@ -66,19 +46,18 @@ class VisitedNodesHandler {
~VisitedNodesHandler() { delete[] _table; }

// copy constructor
VisitedNodesHandler(const VisitedNodesHandler &other) {
_table_size = other._table_size;
_mark = other._mark;
VisitedNodesHandler(const VisitedNodesHandler &other)
: _table_size(other._table_size), _mark(other._mark) {

_table = new uint32_t[_table_size];
std::memcpy(_table, other._table, _table_size * sizeof(uint32_t));
}

// move constructor
VisitedNodesHandler(VisitedNodesHandler &&other) noexcept {
_table_size = other._table_size;
_mark = other._mark;
_table = other._table;
other._table = NULL;
VisitedNodesHandler(VisitedNodesHandler &&other) noexcept
: _table_size(other._table_size), _mark(other._mark),
_table(other._table) {
other._table = nullptr;
other._table_size = 0;
other._mark = 0;
}
Expand Down Expand Up @@ -107,61 +86,90 @@ class VisitedNodesHandler {
}
};

/**
*
* @brief Manages a pool of VisitedNodesHandler objects in a thread-safe manner.
*
* This class is designed to efficiently provide and manage a pool of
* VisitedNodesHandler instances for concurrent use in multi-threaded
* environments. It ensures that each handler can be used by only one thread at
* a time without the risk of concurrent access and modification.
*
* The class preallocates a specified number of VisitedNodesHandler objects to
* eliminate the overhead of dynamic allocation during runtime. It uses a mutex
* to synchronize access to the handler pool, ensuring that only one thread can
* modify the pool at any given time. This mechanism provides both thread safety
* and improved performance by reusing handler objects instead of continuously
* creating and destroying them.
*
* When a thread requires a VisitedNodesHandler, it can call
* `pollAvailableHandler()` to retrieve an available handler from the pool. If
* the pool is empty, the function will dynamically allocate a new handler to
* ensure that the requesting thread can proceed with its task. Once the thread
* has finished using the handler, it should return it to the pool by calling
* `pushHandler()`.
*
* @note The class assumes that all threads will properly return the handlers to
* the pool after use. Failing to return a handler will deplete the pool and
* lead to dynamic allocation, negating the performance benefits.
*
* Usage example:
* @code
* ThreadSafeVisitedNodesHandler handler_pool(10, 1000);
* VisitedNodesHandler* handler = handler_pool.pollAvailableHandler();
* // Use the handler in a thread...
* handler_pool.pushHandler(handler);
* @endcode
*
* @param initial_pool_size The number of handler objects to initially create
* and store in the pool.
* @param num_elements The size of each VisitedNodesHandler, which typically
* corresponds to the number of nodes or elements that each handler is expected
* to manage.
*/
class ThreadSafeVisitedNodesHandler {
std::vector<std::unique_ptr<VisitedNodesHandler>> _handler_pool;
std::vector<VisitedNodesHandler *> _handler_pool;
std::mutex _pool_guard;
uint32_t _num_elements;
uint32_t _total_handlers_in_use;

friend class cereal::access;

template <typename Archive> void serialize(Archive &archive) {
archive(_handler_pool, _num_elements, _total_handlers_in_use);
}

public:
ThreadSafeVisitedNodesHandler() = default;
ThreadSafeVisitedNodesHandler(uint32_t initial_pool_size,
uint32_t num_elements)
: _handler_pool(initial_pool_size), _num_elements(num_elements),
_total_handlers_in_use(1) {
: _handler_pool(initial_pool_size), _num_elements(num_elements) {
for (uint32_t handler_id = 0; handler_id < _handler_pool.size();
handler_id++) {
_handler_pool[handler_id] =
std::make_unique<VisitedNodesHandler>(/* size = */ _num_elements);
new VisitedNodesHandler(/* size = */ _num_elements);
}
}

VisitedNodesHandler *pollAvailableHandler() {
std::unique_lock<std::mutex> lock(_pool_guard);

if (!_handler_pool.empty()) {
// NOTE: release() call is required here to ensure that we don't free
// the handler's memory before using it since it's under a unique pointer.
auto *handler = _handler_pool.back().release();
auto *handler = _handler_pool.back();
_handler_pool.pop_back();
return handler;
} else {
// TODO: This is not great because it assumes the caller is responsible
// enough to return this handler to the pool. If the caller doesn't return
// the handler to the pool, we will have a memory leak. This can be
// resolved by std::unique_ptr but I prefer to use a raw pointer here.
auto *handler = new VisitedNodesHandler(/* size = */ _num_elements);
_total_handlers_in_use++;
return handler;
return new VisitedNodesHandler(/* size = */ _num_elements);
}
}

void pushHandler(VisitedNodesHandler *handler) {
std::unique_lock<std::mutex> lock(_pool_guard);

_handler_pool.push_back(std::make_unique<VisitedNodesHandler>(*handler));
_handler_pool.shrink_to_fit();
_handler_pool.push_back(handler);
}

inline uint32_t getPoolSize() { return _handler_pool.size(); }

~ThreadSafeVisitedNodesHandler() = default;
~ThreadSafeVisitedNodesHandler() {
while (!_handler_pool.empty()) {
auto *handler = _handler_pool.back();
_handler_pool.pop_back();
delete handler;
}
}
};

} // namespace flatnav
40 changes: 39 additions & 1 deletion flatnav_python/install_flatnav.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,45 @@

set -ex

# Ensure the `poetry` executable is available, installing it when it is not.
#
# When `poetry` is not found on PATH this function:
#   1. runs the official installer script with python3,
#   2. picks the directory that holds the installed `poetry` binary,
#   3. appends an `export PATH=...` line to the rc file of the user's login
#      shell (zsh or bash) so that future sessions can find it, and
#   4. exports the updated PATH in the current process so the rest of this
#      script can invoke `poetry` right away.
#
# Sets the globals SHELL_NAME and POETRY_PATH.
# Exits with status 1 when the login shell is neither zsh nor bash, or when
# `poetry` still cannot be found after installation.
function check_poetry_installed() {
    if command -v poetry &> /dev/null; then
        return 0
    fi

    echo "Poetry not found. Installing it now..."

    curl -sSL https://install.python-poetry.org | python3 -

    # Check the login shell so the PATH entry goes into the matching rc file.
    SHELL_NAME=$(basename "$SHELL")

    # Legacy default location. Newer poetry installers place the executable
    # elsewhere (e.g. $HOME/.local/bin, or $HOME/.local/share/pypoetry/venv/bin
    # as observed on ubuntu x86-64), so prefer whichever candidate actually
    # contains an executable `poetry`.
    POETRY_PATH="$HOME/.poetry/bin"
    local candidate
    for candidate in "$HOME/.poetry/bin" \
                     "$HOME/.local/bin" \
                     "$HOME/.local/share/pypoetry/venv/bin"; do
        if [[ -x "$candidate/poetry" ]]; then
            POETRY_PATH="$candidate"
            break
        fi
    done

    local rc_file
    if [[ "$SHELL_NAME" == "zsh" ]]; then
        echo "Detected zsh shell."
        rc_file="$HOME/.zshrc"
    elif [[ "$SHELL_NAME" == "bash" ]]; then
        echo "Detected bash shell."
        rc_file="$HOME/.bashrc"
    else
        echo "Unsupported shell for poetry installation. $SHELL_NAME"
        exit 1
    fi

    # Persist the PATH entry for future sessions; skip it if the exact line is
    # already present so repeated runs do not keep growing the rc file.
    local export_line="export PATH=\"$POETRY_PATH:\$PATH\""
    if ! grep -qsF "$export_line" "$rc_file"; then
        echo "$export_line" >> "$rc_file"
    fi

    # Update PATH for this process directly instead of sourcing the rc file:
    # the rc file is written for the user's interactive shell and may contain
    # syntax this script's interpreter cannot run, or may return early when
    # the shell is not interactive.
    export PATH="$POETRY_PATH:$PATH"

    if ! command -v poetry &> /dev/null; then
        echo "Poetry was installed but could not be found in $POETRY_PATH." >&2
        exit 1
    fi
}


# Make sure we are in this directory
cd "$(dirname "$0")"

# Install poetry if not yet installed
check_poetry_installed

poetry lock && poetry install --no-root

# Activate the poetry environment
POETRY_ENV=$(poetry env info --path)

Expand All @@ -19,4 +58,3 @@ echo "Installation of wheel completed"

#Testing the wheel
$POETRY_ENV/bin/python -c "import flatnav"

37 changes: 23 additions & 14 deletions flatnav_python/python_bindings.cpp
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
#include <algorithm>
#include <flatnav/DistanceInterface.h>
#include <flatnav/Index.h>
#include <flatnav/distances/InnerProductDistance.h>
#include <flatnav/distances/SquaredL2Distance.h>
#include <iostream>
#include <memory>
#include <ostream>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include <flatnav/DistanceInterface.h>
#include <flatnav/Index.h>
#include <flatnav/distances/InnerProductDistance.h>
#include <flatnav/distances/SquaredL2Distance.h>

using flatnav::DistanceInterface;
using flatnav::Index;
using flatnav::InnerProductDistance;
Expand Down Expand Up @@ -68,14 +68,20 @@ template <typename dist_t, typename label_t> class PyIndex {
if (data.ndim() != 2 || data_dim != _dim) {
throw std::invalid_argument("Data has incorrect dimensions.");
}

if (labels.is_none()) {
std::vector<label_t> vec_labels(num_vectors);
std::iota(vec_labels.begin(), vec_labels.end(), 0);

this->_index->addParallel(
/* data = */ (void *)data.data(0), /* labels = */ vec_labels,
/* ef_construction = */ ef_construction,
/* num_initializations = */ num_initializations);
{
// Release python GIL while threads are running
py::gil_scoped_release gil;
this->_index->addParallel(
/* data = */ (void *)data.data(0),
/* labels = */ vec_labels,
/* ef_construction = */ ef_construction,
/* num_initializations = */ num_initializations);
}
return;
}

Expand All @@ -85,11 +91,14 @@ template <typename dist_t, typename label_t> class PyIndex {
if (vec_labels.size() != num_vectors) {
throw std::invalid_argument("Incorrect numbe of labels.");
}

this->_index->addParallel(
/* data = */ (void *)data.data(0), /* labels = */ vec_labels,
/* ef_construction = */ ef_construction,
/* num_initializations = */ num_initializations);
{
// Release the Python GIL while threads are running
py::gil_scoped_release gil;
this->_index->addParallel(
/* data = */ (void *)data.data(0), /* labels = */ vec_labels,
/* ef_construction = */ ef_construction,
/* num_initializations = */ num_initializations);
}
} catch (const py::cast_error &error) {
throw std::invalid_argument("Invalid labels provided.");
}
Expand Down

0 comments on commit ca6a6b4

Please sign in to comment.