@@ -13,6 +13,7 @@ CySolverDense::CySolverDense(
13
13
unsigned int Q_order,
14
14
CySolverBase* cysolver_instance_ptr,
15
15
std::function<void (CySolverBase *)> cysolver_diffeq_ptr,
16
+ PyObject* cython_extension_class_instance,
16
17
double* cysolver_t_now_ptr,
17
18
double* cysolver_y_now_ptr,
18
19
double* cysolver_dy_now_ptr
@@ -22,6 +23,7 @@ CySolverDense::CySolverDense(
22
23
num_extra(num_extra),
23
24
cysolver_instance_ptr(cysolver_instance_ptr),
24
25
cysolver_diffeq_ptr(cysolver_diffeq_ptr),
26
+ cython_extension_class_instance(cython_extension_class_instance),
25
27
cysolver_t_now_ptr(cysolver_t_now_ptr),
26
28
cysolver_y_now_ptr(cysolver_y_now_ptr),
27
29
cysolver_dy_now_ptr(cysolver_dy_now_ptr),
@@ -33,9 +35,17 @@ CySolverDense::CySolverDense(
33
35
std::memcpy (this ->y_stored_ptr , y_in_ptr, sizeof (double ) * this ->num_y );
34
36
// Calculate step
35
37
this ->step = this ->t_now - this ->t_old ;
38
+
39
+ // Make a strong reference to the python class (if this dense output was built using the python hooks).
40
+ if (cython_extension_class_instance)
41
+ {
42
+ // TODO: Do we need to decref this at some point? During CySolver's deconstruction?
43
+ Py_XINCREF (this ->cython_extension_class_instance );
44
+ }
45
+
36
46
}
37
47
38
- void CySolverDense::call (double t_interp, double * y_intepret )
48
+ void CySolverDense::call (double t_interp, double * y_interp_ptr )
39
49
{
40
50
double step_factor = (t_interp - this ->t_old ) / this ->step ;
41
51
@@ -65,7 +75,7 @@ void CySolverDense::call(double t_interp, double* y_intepret)
65
75
// Finally multiply by step
66
76
temp_double *= this ->step ;
67
77
68
- y_intepret [y_i] = this ->y_stored_ptr [y_i] + temp_double;
78
+ y_interp_ptr [y_i] = this ->y_stored_ptr [y_i] + temp_double;
69
79
}
70
80
break ;
71
81
@@ -90,7 +100,7 @@ void CySolverDense::call(double t_interp, double* y_intepret)
90
100
// Finally multiply by step
91
101
temp_double *= this ->step ;
92
102
93
- y_intepret [y_i] = this ->y_stored_ptr [y_i] + temp_double;
103
+ y_interp_ptr [y_i] = this ->y_stored_ptr [y_i] + temp_double;
94
104
}
95
105
break ;
96
106
@@ -127,13 +137,13 @@ void CySolverDense::call(double t_interp, double* y_intepret)
127
137
temp_double += this ->Q_ptr [Q_stride + 6 ];
128
138
temp_double *= step_factor;
129
139
130
- y_intepret [y_i] = this ->y_stored_ptr [y_i] + temp_double;
140
+ y_interp_ptr [y_i] = this ->y_stored_ptr [y_i] + temp_double;
131
141
}
132
142
break ;
133
143
134
144
[[unlikely]] default :
135
145
// Don't know the model. Just return the input.
136
- std::memcpy (y_intepret , this ->y_stored_ptr , sizeof (double ) * this ->num_y );
146
+ std::memcpy (y_interp_ptr , this ->y_stored_ptr , sizeof (double ) * this ->num_y );
137
147
break ;
138
148
}
139
149
@@ -151,31 +161,33 @@ void CySolverDense::call(double t_interp, double* y_intepret)
151
161
// y array
152
162
double y_tmp[Y_LIMIT];
153
163
double * y_tmp_ptr = &y_tmp[0 ];
154
- memcpy (y_tmp_ptr, this ->cysolver_y_now_ptr , this ->num_y );
164
+ memcpy (y_tmp_ptr, this ->cysolver_y_now_ptr , sizeof ( double ) * this ->num_y );
155
165
// dy array
156
166
double dy_tmp[DY_LIMIT];
157
167
double * dy_tmp_ptr = &dy_tmp[0 ];
158
- memcpy (dy_tmp_ptr, this ->cysolver_dy_now_ptr , num_dy);
168
+ memcpy (dy_tmp_ptr, this ->cysolver_dy_now_ptr , sizeof ( double ) * num_dy);
159
169
// t
160
170
double t_tmp = cysolver_t_now_ptr[0 ];
161
171
162
172
// Load new values into t and y
163
- memcpy (this ->cysolver_y_now_ptr , y_intepret, this ->num_y );
173
+ memcpy (this ->cysolver_y_now_ptr , y_interp_ptr, sizeof ( double ) * this ->num_y );
164
174
cysolver_t_now_ptr[0 ] = t_interp;
165
175
166
176
// Call diffeq to update dy_now pointer
177
+ printf (" DEBUG!! About to call diffeq from dense...\n " );
167
178
this ->cysolver_diffeq_ptr (this ->cysolver_instance_ptr );
179
+ printf (" DEBUG!! After diffeq call\n " );
168
180
169
- // Capture extra output and add to the y_intepret array
181
+ // Capture extra output and add to the y_interp_ptr array
170
182
// We already have y interpolated from above so start at num_y
171
183
for (size_t i = this ->num_y ; i < num_dy; i++)
172
184
{
173
- y_intepret [i] = this ->cysolver_dy_now_ptr [i];
185
+ y_interp_ptr [i] = this ->cysolver_dy_now_ptr [i];
174
186
}
175
187
176
188
// Reset CySolver state to what it was before
177
189
cysolver_t_now_ptr[0 ] = t_tmp;
178
- memcpy (this ->cysolver_y_now_ptr , y_tmp_ptr, num_dy );
179
- memcpy (this ->cysolver_dy_now_ptr , dy_tmp_ptr, num_dy);
190
+ memcpy (this ->cysolver_y_now_ptr , y_tmp_ptr, sizeof ( double ) * num_y );
191
+ memcpy (this ->cysolver_dy_now_ptr , dy_tmp_ptr, sizeof ( double ) * num_dy);
180
192
}
181
193
}
0 commit comments