@@ -150,9 +150,9 @@ AK_ResolveDTypeIter(PyObject *dtypes)
150
150
return resolved ;
151
151
}
152
152
153
- // Perform a deepcopy on an array, using an optional memo dictionary, and specialized to depend on immutable arrays.
153
+ // Perform a deepcopy on an array, using an optional memo dictionary, and specialized to depend on immutable arrays. This depends on the module object to get the deepcopy method.
154
154
PyObject *
155
- AK_ArrayDeepCopy (PyArrayObject * array , PyObject * memo )
155
+ AK_ArrayDeepCopy (PyObject * m , PyArrayObject * array , PyObject * memo )
156
156
{
157
157
PyObject * id = PyLong_FromVoidPtr ((PyObject * )array );
158
158
if (!id ) {
@@ -176,12 +176,7 @@ AK_ArrayDeepCopy(PyArrayObject *array, PyObject *memo)
176
176
PyArray_Descr * dtype = PyArray_DESCR (array ); // borrowed ref
177
177
178
178
if (PyDataType_ISOBJECT (dtype )) {
179
- PyObject * copy = PyImport_ImportModule ("copy" );
180
- if (!copy ) {
181
- goto error ;
182
- }
183
- PyObject * deepcopy = PyObject_GetAttrString (copy , "deepcopy" );
184
- Py_DECREF (copy );
179
+ PyObject * deepcopy = PyObject_GetAttrString (m , "deepcopy" );
185
180
if (!deepcopy ) {
186
181
goto error ;
187
182
}
@@ -327,7 +322,7 @@ static char *array_deepcopy_kwarg_names[] = {
327
322
328
323
// Specialized array deepcopy that stores immutable arrays in an optional memo dict that can be provided with kwargs.
329
324
static PyObject *
330
- array_deepcopy (PyObject * Py_UNUSED ( m ) , PyObject * args , PyObject * kwargs )
325
+ array_deepcopy (PyObject * m , PyObject * args , PyObject * kwargs )
331
326
{
332
327
PyObject * array ;
333
328
PyObject * memo = NULL ;
@@ -338,7 +333,7 @@ array_deepcopy(PyObject *Py_UNUSED(m), PyObject *args, PyObject *kwargs)
338
333
return NULL ;
339
334
}
340
335
AK_CHECK_NUMPY_ARRAY (array );
341
- return AK_ArrayDeepCopy ((PyArrayObject * )array , memo );
336
+ return AK_ArrayDeepCopy (m , (PyArrayObject * )array , memo );
342
337
}
343
338
344
339
//------------------------------------------------------------------------------
@@ -784,11 +779,26 @@ PyInit__arraykit(void)
784
779
{
785
780
import_array ();
786
781
PyObject * m = PyModule_Create (& arraykit_module );
782
+
783
+ PyObject * copy = PyImport_ImportModule ("copy" );
784
+ if (!copy ) {
785
+ Py_DECREF (m );
786
+ return NULL ;
787
+ }
788
+ PyObject * deepcopy = PyObject_GetAttrString (copy , "deepcopy" );
789
+ Py_DECREF (copy );
790
+ if (!deepcopy ) {
791
+ Py_DECREF (m );
792
+ return NULL ;
793
+ }
794
+
787
795
if (!m ||
788
796
PyModule_AddStringConstant (m , "__version__" , Py_STRINGIFY (AK_VERSION )) ||
789
797
PyType_Ready (& ArrayGOType ) ||
790
- PyModule_AddObject (m , "ArrayGO" , (PyObject * ) & ArrayGOType ))
798
+ PyModule_AddObject (m , "ArrayGO" , (PyObject * ) & ArrayGOType ) ||
799
+ PyModule_AddObject (m , "deepcopy" , deepcopy ))
791
800
{
801
+ Py_DECREF (deepcopy );
792
802
Py_XDECREF (m );
793
803
return NULL ;
794
804
}
0 commit comments