Skip to content

Commit c8a6e57

Browse files
committed
Merge branch 'master' into 97/dta-cleanup
2 parents 0fbc8a1 + 382d15c commit c8a6e57

File tree

2 files changed

+26
-5
lines changed

2 files changed

+26
-5
lines changed

src/_arraykit.c

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ AK_ResolveDTypeIter(PyObject *dtypes)
162162
return resolved;
163163
}
164164

165-
// Perform a deepcopy on an array, using an optional memo dictionary, and specialized to depend on immutable arrays. This depends on the module object to get the deepcopy method.
165+
// Perform a deepcopy on an array, using an optional memo dictionary, and specialized to depend on immutable arrays. This depends on the module object to get the deepcopy method. The `memo` object can be None.
166166
PyObject*
167167
AK_ArrayDeepCopy(PyObject* m, PyArrayObject *array, PyObject *memo)
168168
{
@@ -186,6 +186,7 @@ AK_ArrayDeepCopy(PyObject* m, PyArrayObject *array, PyObject *memo)
186186
PyArray_Descr *dtype = PyArray_DESCR(array); // borrowed ref
187187

188188
if (PyDataType_ISOBJECT(dtype)) {
189+
// we store the deepcopy function on this module for faster lookup here
189190
PyObject *deepcopy = PyObject_GetAttrString(m, "deepcopy");
190191
if (!deepcopy) {
191192
goto error;
@@ -3310,11 +3311,20 @@ array_deepcopy(PyObject *m, PyObject *args, PyObject *kwargs)
33103311
PyObject *array;
33113312
PyObject *memo = NULL;
33123313
if (!PyArg_ParseTupleAndKeywords(args, kwargs,
3313-
"O|O!:array_deepcopy", array_deepcopy_kwarg_names,
3314+
"O|O:array_deepcopy", array_deepcopy_kwarg_names,
33143315
&array,
3315-
&PyDict_Type, &memo)) {
3316+
&memo)) {
33163317
return NULL;
33173318
}
3319+
if ((memo == NULL) || (memo == Py_None)) {
3320+
memo = NULL;
3321+
}
3322+
else {
3323+
if (!PyDict_Check(memo)) {
3324+
PyErr_SetString(PyExc_TypeError, "memo must be a dict or None");
3325+
return NULL;
3326+
}
3327+
}
33183328
AK_CHECK_NUMPY_ARRAY(array);
33193329
return AK_ArrayDeepCopy(m, (PyArrayObject*)array, memo);
33203330
}

test/test_util.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,14 +268,25 @@ def test_array_deepcopy_d(self) -> None:
268268
def test_array_deepcopy_e(self) -> None:
269269
a1 = np.array((3, 4, 5))
270270
with self.assertRaises(TypeError):
271-
# memo argument must be a dictionary
272-
a2 = array_deepcopy(a1, memo=None)
271+
a2 = array_deepcopy(a1, memo='')
273272

274273
def test_array_deepcopy_f(self) -> None:
275274
a1 = np.array((3, 4, 5))
276275
a2 = array_deepcopy(a1)
277276
self.assertNotEqual(id(a1), id(a2))
278277

278+
def test_array_deepcopy_g(self) -> None:
279+
a1 = np.arange(10)
280+
a2 = array_deepcopy(a1, None)
281+
self.assertNotEqual(mloc(a1), mloc(a2))
282+
283+
def test_array_deepcopy_h(self) -> None:
284+
a1 = np.arange(10)
285+
with self.assertRaises(TypeError):
286+
a2 = array_deepcopy(a1, ())
287+
288+
#---------------------------------------------------------------------------
289+
279290
def test_isna_element_a(self) -> None:
280291
class FloatSubclass(float): pass
281292
class ComplexSubclass(complex): pass

0 commit comments

Comments
 (0)