Skip to content

Commit e65d848

Browse files
committed
Improve CasADi Function interface
1 parent cf070cd commit e65d848

File tree

3 files changed

+33
-22
lines changed

3 files changed

+33
-22
lines changed

src/interop/casadi/include/alpaqa/casadi/casadi-external-function.hpp

+21-16
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,26 @@ namespace casadi {
2424
/// Designed to match (part of) the `casadi::Function` API.
2525
class CASADI_LOADER_EXPORT Function {
2626
public:
27+
struct Functions {
28+
fname_incref::signature_t *incref = nullptr;
29+
fname_decref::signature_t *decref = nullptr;
30+
fname_n_in::signature_t *n_in = nullptr;
31+
fname_n_out::signature_t *n_out = nullptr;
32+
fname_name_in::signature_t *name_in = nullptr;
33+
fname_name_out::signature_t *name_out = nullptr;
34+
fname_sparsity_in::signature_t *sparsity_in = nullptr;
35+
fname_sparsity_out::signature_t *sparsity_out = nullptr;
36+
fname_alloc_mem::signature_t *alloc_mem = nullptr;
37+
fname_init_mem::signature_t *init_mem = nullptr;
38+
fname_free_mem::signature_t *free_mem = nullptr;
39+
fname_work::signature_t *work = nullptr;
40+
fname::signature_t *call = nullptr;
41+
};
42+
43+
public:
44+
Function();
2745
Function(std::shared_ptr<void> so_handle, const std::string &func_name);
46+
Function(const Functions &functions);
2847
Function(const Function &);
2948
Function(Function &&) noexcept;
3049
~Function();
@@ -80,29 +99,15 @@ class CASADI_LOADER_EXPORT Function {
8099

81100
private:
82101
std::shared_ptr<void> so_handle;
83-
struct Functions {
84-
fname_incref::signature_t *incref = nullptr;
85-
fname_decref::signature_t *decref = nullptr;
86-
fname_n_in::signature_t *n_in = nullptr;
87-
fname_n_out::signature_t *n_out = nullptr;
88-
fname_name_in::signature_t *name_in = nullptr;
89-
fname_name_out::signature_t *name_out = nullptr;
90-
fname_sparsity_in::signature_t *sparsity_in = nullptr;
91-
fname_sparsity_out::signature_t *sparsity_out = nullptr;
92-
fname_alloc_mem::signature_t *alloc_mem = nullptr;
93-
fname_init_mem::signature_t *init_mem = nullptr;
94-
fname_free_mem::signature_t *free_mem = nullptr;
95-
fname_work::signature_t *work = nullptr;
96-
fname::signature_t *call = nullptr;
97-
} functions;
102+
Functions functions;
98103
struct Work {
99104
std::vector<const casadi_real *> arg;
100105
std::vector<casadi_real *> res;
101106
std::vector<casadi_int> iw;
102107
std::vector<casadi_real> w;
103108
};
104109
std::optional<Work> work;
105-
void *mem = nullptr;
110+
int mem = 0;
106111
};
107112

108113
inline std::pair<casadi_int, casadi_int> Function::Sparsity::size() const {

src/interop/casadi/include/alpaqa/casadi/casadi-functions.hpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ using fname_name_in = ExternalFunction<"_name_in", const char *(casadi_int
1919
using fname_name_out = ExternalFunction<"_name_out", const char *(casadi_int ind)>;
2020
using fname_sparsity_in = ExternalFunction<"_sparsity_in", const casadi_int *(casadi_int ind)>;
2121
using fname_sparsity_out = ExternalFunction<"_sparsity_out", const casadi_int *(casadi_int ind)>;
22-
using fname_alloc_mem = ExternalFunction<"_alloc_mem", void *(void)>;
23-
using fname_init_mem = ExternalFunction<"_init_mem", int(void *mem)>;
24-
using fname_free_mem = ExternalFunction<"_free_mem", int(void *mem)>;
22+
using fname_alloc_mem = ExternalFunction<"_alloc_mem", int(void)>;
23+
using fname_init_mem = ExternalFunction<"_init_mem", int(int mem)>;
24+
using fname_free_mem = ExternalFunction<"_free_mem", void(int mem)>;
2525
using fname_work = ExternalFunction<"_work", int(casadi_int *sz_arg, casadi_int *sz_res, casadi_int *sz_iw, casadi_int *sz_w)>;
26-
using fname = ExternalFunction<"", int(const casadi_real **arg, casadi_real **res, casadi_int *iw, casadi_real *w, void *mem)>;
26+
using fname = ExternalFunction<"", int(const casadi_real **arg, casadi_real **res, casadi_int *iw, casadi_real *w, int mem)>;
2727
// clang-format on
2828

2929
template <Name Nm, class Sgn>

src/interop/casadi/src/casadi-external-function.cpp

+8-2
Original file line numberDiff line numberDiff line change
@@ -42,21 +42,27 @@ void Function::init_work() {
4242
w.w.resize(static_cast<size_t>(sz_w));
4343
}
4444

45+
Function::Function() = default;
4546
Function::Function(std::shared_ptr<void> so_handle,
4647
const std::string &func_name)
4748
: so_handle{std::move(so_handle)} {
4849
load(this->so_handle.get(), func_name);
4950
}
51+
static char no_handle;
52+
Function::Function(const Functions &functions)
53+
: so_handle{&no_handle, [](void *) {}}, functions{functions} {
54+
functions.incref();
55+
}
5056
Function::Function(const Function &o)
5157
: so_handle{o.so_handle}, functions{o.functions} {
5258
functions.incref();
5359
}
5460
Function::Function(Function &&o) noexcept
5561
: so_handle{std::move(o.so_handle)}, functions{o.functions},
56-
work{std::move(o.work)}, mem{std::exchange(o.mem, nullptr)} {}
62+
work{std::move(o.work)}, mem{std::exchange(o.mem, 0)} {}
5763
Function::~Function() {
5864
if (so_handle) {
59-
if (mem)
65+
if (work)
6066
functions.free_mem(mem);
6167
functions.decref();
6268
}

0 commit comments

Comments
 (0)