Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ jobs:
- name: Run examples
run: |
cd build/
./examples
./examples/examples
shell: bash

- name: Test
Expand Down
14 changes: 1 addition & 13 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,6 @@ if(WALNUTS_BUILD_STAN)
target_include_directories(BridgeStan::BridgeStan INTERFACE
"${CMAKE_BINARY_DIR}/_deps/bridgestan-src/src"
)

add_executable(examples_stan ${CMAKE_CURRENT_SOURCE_DIR}/examples/examples_stan.cpp)
target_link_libraries(examples_stan PRIVATE Eigen3::Eigen nuts::nuts BridgeStan::BridgeStan)
target_compile_options(examples_stan PRIVATE -Wall)
endif()

#############################
Expand All @@ -132,17 +128,9 @@ add_library(nuts::nuts ALIAS walnuts)
## Examples ##
##########################
if (WALNUTS_BUILD_EXAMPLES)
add_executable(examples ${CMAKE_CURRENT_SOURCE_DIR}/examples/examples.cpp)
target_link_libraries(examples PRIVATE nuts::nuts)
target_compile_options(examples PRIVATE -O3 -Wall)
add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/examples")
endif()

##########################
## Extras ##
##########################
if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/extras/CMakeLists.txt")
add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/extras")
endif()

##########################
## Tests ##
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ For example, to build and run the example:

```bash
cmake --build . --target examples
./examples
./examples/examples
```


Expand Down
24 changes: 24 additions & 0 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
add_executable(examples examples.cpp)
target_link_libraries(examples PRIVATE nuts::nuts)
target_compile_options(examples PRIVATE -O3 -Wall)



if(WALNUTS_BUILD_STAN)
add_executable(examples_stan examples_stan.cpp)
target_link_libraries(examples_stan PRIVATE Eigen3::Eigen nuts::nuts BridgeStan::BridgeStan)
target_compile_options(examples_stan PRIVATE -Wall)


FetchContent_Declare(
cli11
GIT_REPOSITORY https://github.com/CLIUtils/CLI11.git
GIT_TAG v2.5.0
)

FetchContent_MakeAvailable(cli11)

add_executable(stan_cli stan_cli.cpp)
target_link_libraries(stan_cli PRIVATE Eigen3::Eigen nuts::nuts BridgeStan::BridgeStan CLI11::CLI11)
target_compile_options(stan_cli PRIVATE -Wall)
endif()
75 changes: 39 additions & 36 deletions examples/load_stan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,21 +34,23 @@ static char* dlerror() {

struct dlclose_deleter {
void operator()(void* handle) const {
if (handle) {
dlclose(handle);
}
// TODO: Crashes on some systems, see
// https://github.com/flatironinstitute/walnuts/pull/25#discussion_r2298576937
// if (handle) {
// dlclose(handle);
// }
}
};

using dynamic_library = std::unique_ptr<void, dlclose_deleter>;

inline dynamic_library dlopen_safe(const char* path) {
auto handle = dlopen(path, RTLD_NOW);
auto handle = dlopen(path, RTLD_NOW | RTLD_NODELETE);
if (!handle) {
throw std::runtime_error(std::string("Error loading library '") + path +
"': " + dlerror());
}
return std::unique_ptr<void, dlclose_deleter>(handle);
return dynamic_library(handle);
}

template <typename T>
Expand All @@ -64,42 +66,43 @@ inline T dlsym_cast_impl(dynamic_library& library, const char* name) {
#define dlsym_cast(library, func) \
dlsym_cast_impl<decltype(&func)>(library, #func)

template <typename T>
void no_op_deleter(T*) {}
using unique_bs_model = std::unique_ptr<bs_model, decltype(&bs_model_destruct)>;

inline unique_bs_model make_model(dynamic_library& library, const char* data,
unsigned int seed) {
auto model_construct = dlsym_cast(library, bs_model_construct);
auto model_destruct = dlsym_cast(library, bs_model_destruct);
char* err = nullptr;
auto model_ptr =
unique_bs_model(model_construct(data, seed, &err), model_destruct);
if (!model_ptr) {
if (err) {
std::string error_string(err);
dlsym_cast(library, bs_free_error_msg)(err);
throw std::runtime_error(error_string);
}
throw std::runtime_error("Failed to construct model");
}
return model_ptr;
}

class DynamicStanModel {
public:
DynamicStanModel(const char* model_path, const char* data, unsigned int seed)
: library_(dlopen_safe(model_path)),
model_ptr_(nullptr, no_op_deleter<bs_model>),
rng_ptr_(nullptr, no_op_deleter<bs_rng>) {
auto model_construct = dlsym_cast(library_, bs_model_construct);
auto model_destruct = dlsym_cast(library_, bs_model_destruct);
model_ptr_(make_model(library_, data, seed)),
free_error_msg_(dlsym_cast(library_, bs_free_error_msg)),
param_unc_num_(dlsym_cast(library_, bs_param_unc_num)),
param_num_(dlsym_cast(library_, bs_param_num)),
log_density_gradient_(dlsym_cast(library_, bs_log_density_gradient)),
param_constrain_(dlsym_cast(library_, bs_param_constrain)),
param_names_(dlsym_cast(library_, bs_param_names)),
rng_ptr_(nullptr, [](auto) {}) {
// temporary: we probably don't want to store the RNG in the model
// due to thread safety concerns
auto rng_construct = dlsym_cast(library_, bs_rng_construct);
auto rng_destruct = dlsym_cast(library_, bs_rng_destruct);

free_error_msg_ = dlsym_cast(library_, bs_free_error_msg);
param_unc_num_ = dlsym_cast(library_, bs_param_unc_num);
param_num_ = dlsym_cast(library_, bs_param_num);
log_density_gradient_ = dlsym_cast(library_, bs_log_density_gradient);
param_constrain_ = dlsym_cast(library_, bs_param_constrain);
param_names_ = dlsym_cast(library_, bs_param_names);

char* err = nullptr;
model_ptr_ = std::unique_ptr<bs_model, decltype(&bs_model_destruct)>(
model_construct(data, seed, &err), model_destruct);

if (!model_ptr_) {
if (err) {
std::string error_string(err);
free_error_msg_(err);
throw std::runtime_error(error_string);
}
throw std::runtime_error("Failed to construct model");
}

// temporary: we probably don't want to store the RNG in the model
// due to thread safety concerns
rng_ptr_ = std::unique_ptr<bs_rng, decltype(&bs_rng_destruct)>(
rng_construct(seed, &err), rng_destruct);
if (!rng_ptr_) {
Expand Down Expand Up @@ -177,15 +180,15 @@ class DynamicStanModel {
}

private:
std::unique_ptr<void, dlclose_deleter> library_;
std::unique_ptr<bs_model, decltype(&bs_model_destruct)> model_ptr_;
std::unique_ptr<bs_rng, decltype(&bs_rng_destruct)> rng_ptr_;
dynamic_library library_;
unique_bs_model model_ptr_;
decltype(&bs_free_error_msg) free_error_msg_;
decltype(&bs_param_unc_num) param_unc_num_;
decltype(&bs_param_num) param_num_;
decltype(&bs_log_density_gradient) log_density_gradient_;
decltype(&bs_param_constrain) param_constrain_;
decltype(&bs_param_names) param_names_;
std::unique_ptr<bs_rng, decltype(&bs_rng_destruct)> rng_ptr_;
};

#endif
Loading
Loading