Skip to content

Commit 30005a1

Browse files
committed
Fixed issue where python may deconstruct the diffeq while dense output still had reference and could call it
1 parent 907fad4 commit 30005a1

8 files changed

+433
-235
lines changed

CyRK/cy/cysolver.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,7 @@ void CySolverBase::take_step()
317317
0, // Fake Q order just for consistent constructor call
318318
this,
319319
this->diffeq,
320+
this->cython_extension_class_instance,
320321
this->t_now_ptr,
321322
this->y_now_ptr,
322323
this->dy_now_ptr
@@ -526,6 +527,7 @@ CySolverDense* CySolverBase::p_dense_output_heap()
526527
0, // Fake Q order just for consistent constructor call
527528
this,
528529
this->diffeq,
530+
this->cython_extension_class_instance,
529531
this->t_now_ptr,
530532
this->y_now_ptr,
531533
this->dy_now_ptr

CyRK/cy/cysolver_api.pyx

+1-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ cdef class WrapCySolverResult:
8888
if type(t) == np.ndarray:
8989
return self.call_vectorize(t)
9090
else:
91-
return self.call(t).reshape(self.cyresult_ptr.num_y, 1)
91+
return self.call(t).reshape(self.cyresult_ptr.num_dy, 1)
9292

9393

9494
# =====================================================================================================================

CyRK/cy/dense.cpp

+24-12
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ CySolverDense::CySolverDense(
1313
unsigned int Q_order,
1414
CySolverBase* cysolver_instance_ptr,
1515
std::function<void (CySolverBase *)> cysolver_diffeq_ptr,
16+
PyObject* cython_extension_class_instance,
1617
double* cysolver_t_now_ptr,
1718
double* cysolver_y_now_ptr,
1819
double* cysolver_dy_now_ptr
@@ -22,6 +23,7 @@ CySolverDense::CySolverDense(
2223
num_extra(num_extra),
2324
cysolver_instance_ptr(cysolver_instance_ptr),
2425
cysolver_diffeq_ptr(cysolver_diffeq_ptr),
26+
cython_extension_class_instance(cython_extension_class_instance),
2527
cysolver_t_now_ptr(cysolver_t_now_ptr),
2628
cysolver_y_now_ptr(cysolver_y_now_ptr),
2729
cysolver_dy_now_ptr(cysolver_dy_now_ptr),
@@ -33,9 +35,17 @@ CySolverDense::CySolverDense(
3335
std::memcpy(this->y_stored_ptr, y_in_ptr, sizeof(double) * this->num_y);
3436
// Calculate step
3537
this->step = this->t_now - this->t_old;
38+
39+
// Make a strong reference to the python class (if this dense output was built using the python hooks).
40+
if (cython_extension_class_instance)
41+
{
42+
// TODO: Do we need to decref this at some point? During CySolver's deconstruction?
43+
Py_XINCREF(this->cython_extension_class_instance);
44+
}
45+
3646
}
3747

38-
void CySolverDense::call(double t_interp, double* y_intepret)
48+
void CySolverDense::call(double t_interp, double* y_interp_ptr)
3949
{
4050
double step_factor = (t_interp - this->t_old) / this->step;
4151

@@ -65,7 +75,7 @@ void CySolverDense::call(double t_interp, double* y_intepret)
6575
// Finally multiply by step
6676
temp_double *= this->step;
6777

68-
y_intepret[y_i] = this->y_stored_ptr[y_i] + temp_double;
78+
y_interp_ptr[y_i] = this->y_stored_ptr[y_i] + temp_double;
6979
}
7080
break;
7181

@@ -90,7 +100,7 @@ void CySolverDense::call(double t_interp, double* y_intepret)
90100
// Finally multiply by step
91101
temp_double *= this->step;
92102

93-
y_intepret[y_i] = this->y_stored_ptr[y_i] + temp_double;
103+
y_interp_ptr[y_i] = this->y_stored_ptr[y_i] + temp_double;
94104
}
95105
break;
96106

@@ -127,13 +137,13 @@ void CySolverDense::call(double t_interp, double* y_intepret)
127137
temp_double += this->Q_ptr[Q_stride + 6];
128138
temp_double *= step_factor;
129139

130-
y_intepret[y_i] = this->y_stored_ptr[y_i] + temp_double;
140+
y_interp_ptr[y_i] = this->y_stored_ptr[y_i] + temp_double;
131141
}
132142
break;
133143

134144
[[unlikely]] default:
135145
// Don't know the model. Just return the input.
136-
std::memcpy(y_intepret, this->y_stored_ptr, sizeof(double) * this->num_y);
146+
std::memcpy(y_interp_ptr, this->y_stored_ptr, sizeof(double) * this->num_y);
137147
break;
138148
}
139149

@@ -151,31 +161,33 @@ void CySolverDense::call(double t_interp, double* y_intepret)
151161
// y array
152162
double y_tmp[Y_LIMIT];
153163
double* y_tmp_ptr = &y_tmp[0];
154-
memcpy(y_tmp_ptr, this->cysolver_y_now_ptr, this->num_y);
164+
memcpy(y_tmp_ptr, this->cysolver_y_now_ptr, sizeof(double) * this->num_y);
155165
// dy array
156166
double dy_tmp[DY_LIMIT];
157167
double* dy_tmp_ptr = &dy_tmp[0];
158-
memcpy(dy_tmp_ptr, this->cysolver_dy_now_ptr, num_dy);
168+
memcpy(dy_tmp_ptr, this->cysolver_dy_now_ptr, sizeof(double) * num_dy);
159169
// t
160170
double t_tmp = cysolver_t_now_ptr[0];
161171

162172
// Load new values into t and y
163-
memcpy(this->cysolver_y_now_ptr, y_intepret, this->num_y);
173+
memcpy(this->cysolver_y_now_ptr, y_interp_ptr, sizeof(double) * this->num_y);
164174
cysolver_t_now_ptr[0] = t_interp;
165175

166176
// Call diffeq to update dy_now pointer
177+
printf("DEBUG!! About to call diffeq from dense...\n");
167178
this->cysolver_diffeq_ptr(this->cysolver_instance_ptr);
179+
printf("DEBUG!! After diffeq call\n");
168180

169-
// Capture extra output and add to the y_intepret array
181+
// Capture extra output and add to the y_interp_ptr array
170182
// We already have y interpolated from above so start at num_y
171183
for (size_t i = this->num_y; i < num_dy; i++)
172184
{
173-
y_intepret[i] = this->cysolver_dy_now_ptr[i];
185+
y_interp_ptr[i] = this->cysolver_dy_now_ptr[i];
174186
}
175187

176188
// Reset CySolver state to what it was before
177189
cysolver_t_now_ptr[0] = t_tmp;
178-
memcpy(this->cysolver_y_now_ptr, y_tmp_ptr, num_dy);
179-
memcpy(this->cysolver_dy_now_ptr, dy_tmp_ptr, num_dy);
190+
memcpy(this->cysolver_y_now_ptr, y_tmp_ptr, sizeof(double) * num_y);
191+
memcpy(this->cysolver_dy_now_ptr, dy_tmp_ptr, sizeof(double) * num_dy);
180192
}
181193
}

CyRK/cy/dense.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <functional>
88
#include <cstring>
99

10+
#include "Python.h"
1011
#include "common.hpp"
1112

1213
// We need a pointer to the CySolverBase class. But that file includes this one. So we need to do a forward declaration
@@ -39,6 +40,7 @@ class CySolverDense
3940
double* cysolver_t_now_ptr = nullptr;
4041
double* cysolver_y_now_ptr = nullptr;
4142
double* cysolver_dy_now_ptr = nullptr;
43+
PyObject* cython_extension_class_instance = nullptr;
4244

4345
// Time step info
4446
double step = 0.0;
@@ -69,6 +71,7 @@ class CySolverDense
6971
unsigned int Q_order,
7072
CySolverBase* cysolver_instance_ptr,
7173
std::function<void (CySolverBase *)> cysolver_diffeq_ptr,
74+
PyObject* cython_extension_class_instance,
7275
double* cysolver_t_now_ptr,
7376
double* cysolver_y_now_ptr,
7477
double* cysolver_dy_now_ptr

CyRK/cy/pysolver_cyhook.pyx

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# distutils: language = c++
22
# cython: boundscheck=False, wraparound=False, nonecheck=False, cdivision=True, initializedcheck=False
33

4+
from libc.stdio cimport printf
45

56
cdef public api void call_diffeq_from_cython(object py_instance, DiffeqMethod diffeq):
67
"""Callback function used by the C++ model.
@@ -9,4 +10,5 @@ cdef public api void call_diffeq_from_cython(object py_instance, DiffeqMethod di
910
"""
1011

1112
# Call the python diffeq.
13+
printf("DEBUG!!! \t\t Calling python diffeq from api'd func.\n")
1214
diffeq(py_instance)

CyRK/cy/rk.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -1025,6 +1025,7 @@ CySolverDense* RKSolver::p_dense_output_heap()
10251025
this->len_Pcols,
10261026
this,
10271027
this->diffeq,
1028+
this->cython_extension_class_instance,
10281029
this->t_now_ptr,
10291030
this->y_now_ptr,
10301031
this->dy_now_ptr);

0 commit comments

Comments
 (0)