Skip to content

Commit b6d50de

Browse files
committed
[DL] check register_alpaqa_problem_version() before loading
1 parent ba7fdb7 commit b6d50de

File tree

9 files changed

+71
-29
lines changed

9 files changed

+71
-29
lines changed

doxygen/pages/Problem-formulations.md

+4
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,10 @@ register_alpaqa_problem(alpaqa_register_arg_t user_data) noexcept try {
275275
} catch (...) {
276276
return {.exception = new alpaqa_exception_ptr_t{std::current_exception()}};
277277
}
278+
279+
/// Used by @ref alpaqa::dl::DLProblem to ensure binary compatibility.
280+
extern "C" alpaqa_dl_abi_version_t
281+
register_alpaqa_problem_version() { return ALPAQA_DL_ABI_VERSION; }
278282
```
279283

280284
A full example can be found in @ref problems/sparse-logistic-regression.cpp.

examples/C++/DLProblem/CMakeLists.txt

+7-15
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,16 @@
1-
add_library(problem-c MODULE problem.c)
2-
set_target_properties(problem-c PROPERTIES PREFIX "" DEBUG_POSTFIX "")
3-
4-
# DLL import/export
5-
include(GenerateExportHeader)
6-
generate_export_header(problem-c
7-
EXPORT_FILE_NAME problem-c-export.h)
8-
set_target_properties(problem-c PROPERTIES
9-
C_VISIBILITY_PRESET "hidden"
10-
VISIBILITY_INLINES_HIDDEN true)
1+
# Build the loadable problem
2+
alpaqa_add_dl_problem_module(problem-c FILES problem.c)
113
target_compile_features(problem-c PRIVATE c_std_11)
12-
target_include_directories(problem-c PRIVATE
13-
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}>)
14-
target_link_libraries(problem-c PRIVATE alpaqa::dl-api alpaqa::warnings)
4+
target_link_libraries(problem-c PRIVATE alpaqa::warnings)
155

6+
# Build the driver program that loads the problem
167
add_executable(dl-problem-example main.cpp)
178
target_link_libraries(dl-problem-example
189
PRIVATE alpaqa::alpaqa alpaqa::warnings alpaqa::dl-loader)
19-
target_compile_definitions(dl-problem-example PRIVATE
20-
DLPROBLEM_DLL=\"$<TARGET_FILE_NAME:problem-c>\")
10+
target_compile_definitions(dl-problem-example
11+
PRIVATE DLPROBLEM_DLL=\"$<TARGET_FILE_NAME:problem-c>\")
2112
add_dependencies(dl-problem-example problem-c)
2213

14+
# Stand-alone executable to test the matrix functions used by problem-c
2315
add_executable(test-matmul "test-matmul.c")
2416
target_link_libraries(test-matmul PRIVATE alpaqa::dl-api alpaqa::warnings)

examples/C++/DLProblem/problem.c

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#include <alpaqa/dl/dl-problem.h>
2-
#include <problem-c-export.h>
2+
#include <problem-c/export.h>
33

44
#include <stdlib.h>
55
#include <string.h>
@@ -114,3 +114,7 @@ register_alpaqa_problem(alpaqa_register_arg_t user_data) {
114114
result.functions = &problem->functions;
115115
return result;
116116
}
117+
118+
PROBLEM_C_EXPORT alpaqa_dl_abi_version_t register_alpaqa_problem_version(void) {
119+
return ALPAQA_DL_ABI_VERSION;
120+
}

examples/examples.dox

+1-1
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@
187187
* problem, and a pointer to the provided functions (the `funcs` member from
188188
* earlier).
189189
*
190-
* The entry point also accepts a pointer to type-erased user data. You can
190+
* The entry point also accepts an argument with type-erased user data. You can
191191
* use this to pass any additional data to the problem constructor, such as
192192
* problem parameters. When using the `alpaqa-driver` program, the type of this
193193
* user data is always `std::any *`. Inside of the `std::any`, there is a span

examples/problems/sparse-logistic-regression.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -260,3 +260,11 @@ register_alpaqa_problem(alpaqa_register_arg_t user_data_v) noexcept try {
260260
} catch (...) {
261261
return {.exception = new alpaqa_exception_ptr_t{std::current_exception()}};
262262
}
263+
264+
/// Returns the alpaqa DL ABI version. This version is verified for
265+
/// compatibility by the @ref alpaqa::dl::DLProblem constructor before
266+
/// registering the problem.
267+
extern "C" SPARSE_LOGISTIC_REGRESSION_EXPORT alpaqa_dl_abi_version_t
268+
register_alpaqa_problem_version() {
269+
return ALPAQA_DL_ABI_VERSION;
270+
}

src/cmake/dl-problem.cmake

+5-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@ function(alpaqa_configure_dl_problem_visibility target)
99
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
1010
set(VERSION_SCRIPT "${CMAKE_CURRENT_BINARY_DIR}/${target}-export.lds")
1111
file(WRITE ${VERSION_SCRIPT}
12-
"{ global: ${ALPAQA_CONFIG_VIS_FUNCTION_NAME}; local: *; };")
12+
"{ local: *;"
13+
" global: ${ALPAQA_CONFIG_VIS_FUNCTION_NAME};"
14+
" global: ${ALPAQA_CONFIG_VIS_FUNCTION_NAME}_version;"
15+
" global: _ZTI*;"
16+
" global: _ZTS*; };")
1317
target_link_options(${target} PRIVATE
1418
"LINKER:--version-script=${VERSION_SCRIPT}"
1519
"LINKER:--exclude-libs,ALL")

src/interop/dl-api/include/alpaqa/dl/dl-problem.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ extern "C" {
3131
typedef double alpaqa_real_t;
3232
typedef ptrdiff_t alpaqa_length_t;
3333
typedef alpaqa_length_t alpaqa_index_t;
34+
typedef uint64_t alpaqa_dl_abi_version_t;
3435

3536
/// User-provided argument that is passed to the problem registration functions.
3637
ALPAQA_BEGIN_STRUCT(alpaqa_register_arg_t) {
@@ -418,7 +419,7 @@ typedef struct alpaqa_exception_ptr_s alpaqa_exception_ptr_t;
418419
ALPAQA_BEGIN_STRUCT(alpaqa_problem_register_t) {
419420
/// To check whether the loaded problem is compatible with the version of
420421
/// the solver.
421-
uint64_t abi_version ALPAQA_DEFAULT(ALPAQA_DL_ABI_VERSION);
422+
alpaqa_dl_abi_version_t abi_version ALPAQA_DEFAULT(ALPAQA_DL_ABI_VERSION);
422423
/// Owning pointer.
423424
void *instance ALPAQA_DEFAULT(nullptr);
424425
/// Non-owning pointer, lifetime at least as long as @ref instance.
@@ -594,7 +595,7 @@ ALPAQA_END_STRUCT(alpaqa_control_problem_functions_t);
594595
ALPAQA_BEGIN_STRUCT(alpaqa_control_problem_register_t) {
595596
/// To check whether the loaded problem is compatible with the version of
596597
/// the solver.
597-
uint64_t abi_version ALPAQA_DEFAULT(ALPAQA_DL_ABI_VERSION);
598+
alpaqa_dl_abi_version_t abi_version ALPAQA_DEFAULT(ALPAQA_DL_ABI_VERSION);
598599
/// Owning pointer.
599600
void *instance ALPAQA_DEFAULT(nullptr);
600601
/// Non-owning pointer, lifetime at least as long as @ref instance.

src/interop/dl/include/alpaqa/dl/dl-problem.hpp

+8
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,14 @@
1717

1818
namespace alpaqa::dl {
1919

20+
struct DL_LOADER_EXPORT invalid_abi_error : std::runtime_error {
21+
using std::runtime_error::runtime_error;
22+
};
23+
24+
struct DL_LOADER_EXPORT function_load_error : std::runtime_error {
25+
using std::runtime_error::runtime_error;
26+
};
27+
2028
class ExtraFuncs {
2129
public:
2230
/// Unique type for calling an extra function that is a member function.

src/interop/dl/src/dl-problem.cpp

+30-9
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <algorithm>
77
#include <cassert>
88
#include <charconv>
9+
#include <iostream>
910
#include <list>
1011
#include <memory>
1112
#include <mutex>
@@ -35,7 +36,7 @@ void check_abi_version(uint64_t abi_version) {
3536
if (abi_version != ALPAQA_DL_ABI_VERSION) {
3637
auto prob_version = format_abi_version(abi_version);
3738
auto alpaqa_version = format_abi_version(ALPAQA_DL_ABI_VERSION);
38-
throw std::runtime_error(
39+
throw invalid_abi_error(
3940
"alpaqa::dl::DLProblem::DLProblem: "
4041
"Incompatible problem definition (problem ABI version 0x" +
4142
prob_version + ", this version of alpaqa supports 0x" +
@@ -62,8 +63,8 @@ std::shared_ptr<void> load_lib(const std::filesystem::path &so_filename) {
6263
assert(!so_filename.empty());
6364
void *h = LoadLibraryW(so_filename.c_str());
6465
if (!h)
65-
throw std::runtime_error("Unable to load \"" + so_filename.string() +
66-
"\": " + get_last_error_msg().get());
66+
throw function_load_error("Unable to load \"" + so_filename.string() +
67+
"\": " + get_last_error_msg().get());
6768
#if ALPAQA_NO_DLCLOSE
6869
return std::shared_ptr<void>{h, +[](void *) {}};
6970
#else
@@ -77,8 +78,8 @@ F *load_func(void *handle, const std::string &name) {
7778
assert(handle);
7879
auto *h = GetProcAddress(static_cast<HMODULE>(handle), name.c_str());
7980
if (!h)
80-
throw std::runtime_error("Unable to load function '" + name +
81-
"': " + get_last_error_msg().get());
81+
throw function_load_error("Unable to load function '" + name +
82+
"': " + get_last_error_msg().get());
8283
// We can only hope that the user got the signature right ...
8384
return reinterpret_cast<F *>(h);
8485
}
@@ -88,7 +89,7 @@ std::shared_ptr<void> load_lib(const std::filesystem::path &so_filename) {
8889
::dlerror();
8990
void *h = ::dlopen(so_filename.c_str(), RTLD_LOCAL | RTLD_NOW);
9091
if (auto *err = ::dlerror())
91-
throw std::runtime_error(err);
92+
throw function_load_error(err);
9293
#if ALPAQA_NO_DLCLOSE
9394
return std::shared_ptr<void>{h, +[](void *) {}};
9495
#else
@@ -102,8 +103,8 @@ F *load_func(void *handle, const std::string &name) {
102103
::dlerror();
103104
auto *h = ::dlsym(handle, name.c_str());
104105
if (auto *err = ::dlerror())
105-
throw std::runtime_error("Unable to load function '" + name +
106-
"': " + err);
106+
throw function_load_error("Unable to load function '" + name +
107+
"': " + err);
107108
// We can only hope that the user got the signature right ...
108109
return reinterpret_cast<F *>(h);
109110
}
@@ -205,7 +206,17 @@ DLProblem::DLProblem(const std::filesystem::path &so_filename,
205206
: BoxConstrProblem{0, 0} {
206207
if (so_filename.empty())
207208
throw std::invalid_argument("Invalid problem filename");
208-
handle = load_lib(so_filename);
209+
handle = load_lib(so_filename);
210+
try {
211+
auto *version_func = load_func<alpaqa_dl_abi_version_t(void)>(
212+
handle.get(), function_name + "_version");
213+
check_abi_version(version_func());
214+
} catch (const function_load_error &) {
215+
std::cerr << "Warning: problem " << so_filename
216+
<< " does not provide a function to query the ABI version, "
217+
"alpaqa_dl_abi_version_t "
218+
<< function_name << "_version(void)\n";
219+
}
209220
auto *register_func = load_func<problem_register_t(alpaqa_register_arg_t)>(
210221
handle.get(), function_name);
211222
auto r = register_func(user_param);
@@ -342,6 +353,16 @@ DLControlProblem::DLControlProblem(const std::filesystem::path &so_filename,
342353
if (so_filename.empty())
343354
throw std::invalid_argument("Invalid problem filename");
344355
handle = load_lib(so_filename);
356+
try {
357+
auto *version_func = load_func<alpaqa_dl_abi_version_t(void)>(
358+
handle.get(), function_name + "_version");
359+
check_abi_version(version_func());
360+
} catch (const function_load_error &) {
361+
std::cerr << "Warning: problem " << so_filename
362+
<< " does not provide a function to query the ABI version, "
363+
"alpaqa_dl_abi_version_t "
364+
<< function_name << "_version(void)\n";
365+
}
345366
auto *register_func =
346367
load_func<control_problem_register_t(alpaqa_register_arg_t)>(
347368
handle.get(), function_name);

0 commit comments

Comments
 (0)