Skip to content

Commit 37d869e

Browse files
authored
Merge pull request #42 from InvestmentSystems/39/array-deepcopy-kwarg
`array_deepcopy` interface to accept kwargs, memo dict optional
2 parents 80b0ea9 + b96c14e commit 37d869e

File tree

2 files changed

+62
-28
lines changed

2 files changed

+62
-28
lines changed

src/_arraykit.c

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -148,22 +148,25 @@ AK_ResolveDTypeIter(PyObject *dtypes)
148148
return resolved;
149149
}
150150

151-
// Numpy implementation: https://github.com/numpy/numpy/blob/a14c41264855e44ebd6187d7541b5b8d59bb32cb/numpy/core/src/multiarray/methods.c#L1557
151+
// Perform a deepcopy on an array, using an optional memo dictionary, and specialized for immutable arrays. Related NumPy implementation: https://github.com/numpy/numpy/blob/a14c41264855e44ebd6187d7541b5b8d59bb32cb/numpy/core/src/multiarray/methods.c#L1557
152152
PyObject*
153153
AK_ArrayDeepCopy(PyArrayObject *array, PyObject *memo)
154154
{
155155
PyObject *id = PyLong_FromVoidPtr((PyObject*)array);
156156
if (!id) {
157157
return NULL;
158158
}
159-
PyObject *found = PyDict_GetItemWithError(memo, id);
160-
if (found) { // found will be NULL if not in dict
161-
Py_INCREF(found); // got a borrowed ref, increment first
162-
Py_DECREF(id);
163-
return found;
164-
}
165-
else if (PyErr_Occurred()) {
166-
goto error;
159+
160+
if (memo) {
161+
PyObject *found = PyDict_GetItemWithError(memo, id);
162+
if (found) { // found will be NULL if not in dict
163+
Py_INCREF(found); // got a borrowed ref, increment first
164+
Py_DECREF(id);
165+
return found;
166+
}
167+
else if (PyErr_Occurred()) {
168+
goto error;
169+
}
167170
}
168171

169172
// if dtype is object, call deepcopy with memo
@@ -187,14 +190,17 @@ AK_ArrayDeepCopy(PyArrayObject *array, PyObject *memo)
187190
}
188191
}
189192
else {
193+
// if not an object dtype, we will force a copy (even if this is an immutable array) so as to not hold on to any references
190194
Py_INCREF(dtype); // PyArray_FromArray steals a reference
191195
array_new = PyArray_FromArray(
192196
array,
193197
dtype,
194198
NPY_ARRAY_ENSURECOPY);
195-
if (!array_new || PyDict_SetItem(memo, id, array_new)) {
196-
Py_XDECREF(array_new);
197-
goto error;
199+
if (memo) {
200+
if (!array_new || PyDict_SetItem(memo, id, array_new)) {
201+
Py_XDECREF(array_new);
202+
goto error;
203+
}
198204
}
199205
}
200206
// set immutable
@@ -311,20 +317,25 @@ row_1d_filter(PyObject *Py_UNUSED(m), PyObject *a)
311317
//------------------------------------------------------------------------------
312318
// array utility
313319

314-
// Specialized array deepcopy that stores immutable arrays in memo dict.
320+
static char *array_deepcopy_kwarg_names[] = {
321+
"array",
322+
"memo",
323+
NULL
324+
};
325+
326+
// Specialized array deepcopy that stores immutable arrays in an optional memo dict that can be provided with kwargs.
315327
static PyObject *
316-
array_deepcopy(PyObject *Py_UNUSED(m), PyObject *args)
328+
array_deepcopy(PyObject *Py_UNUSED(m), PyObject *args, PyObject *kwargs)
317329
{
318-
PyObject *array, *memo;
319-
if (!PyArg_UnpackTuple(args, "array_deepcopy", 2, 2, &array, &memo)) {
330+
PyObject *array;
331+
PyObject *memo = NULL;
332+
if (!PyArg_ParseTupleAndKeywords(args, kwargs,
333+
"O|O!:array_deepcopy", array_deepcopy_kwarg_names,
334+
&array,
335+
&PyDict_Type, &memo)) {
320336
return NULL;
321337
}
322338
AK_CHECK_NUMPY_ARRAY(array);
323-
if (!PyDict_CheckExact(memo)) {
324-
PyErr_Format(PyExc_TypeError, "expected a dict (got %s)",
325-
Py_TYPE(memo)->tp_name);
326-
return NULL;
327-
}
328339
return AK_ArrayDeepCopy((PyArrayObject*)array, memo);
329340
}
330341

@@ -623,7 +634,10 @@ static PyMethodDef arraykit_methods[] = {
623634
{"column_2d_filter", column_2d_filter, METH_O, NULL},
624635
{"column_1d_filter", column_1d_filter, METH_O, NULL},
625636
{"row_1d_filter", row_1d_filter, METH_O, NULL},
626-
{"array_deepcopy", array_deepcopy, METH_VARARGS, NULL},
637+
{"array_deepcopy",
638+
(PyCFunction)array_deepcopy,
639+
METH_VARARGS | METH_KEYWORDS,
640+
NULL},
627641
{"resolve_dtype", resolve_dtype, METH_VARARGS, NULL},
628642
{"resolve_dtype_iter", resolve_dtype_iter, METH_O, NULL},
629643
{NULL},

test/test_util.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def test_array_deepcopy_a1(self) -> None:
177177
memo = {}
178178
a2 = array_deepcopy(a1, memo)
179179

180-
self.assertNotEqual(id(a1), id(a2))
180+
self.assertIsNot(a1, a2)
181181
self.assertNotEqual(mloc(a1), mloc(a2))
182182
self.assertFalse(a2.flags.writeable)
183183
self.assertEqual(a1.dtype, a2.dtype)
@@ -187,7 +187,7 @@ def test_array_deepcopy_a2(self) -> None:
187187
memo = {}
188188
a2 = array_deepcopy(a1, memo)
189189

190-
self.assertNotEqual(id(a1), id(a2))
190+
self.assertIsNot(a1, a2)
191191
self.assertNotEqual(mloc(a1), mloc(a2))
192192
self.assertIn(id(a1), memo)
193193
self.assertEqual(memo[id(a1)].tolist(), a2.tolist())
@@ -208,23 +208,43 @@ def test_array_deepcopy_c1(self) -> None:
208208
a1 = np.array((None, 'foo', True, mutable))
209209
a2 = array_deepcopy(a1, memo)
210210

211-
self.assertNotEqual(id(a1), id(a2))
211+
self.assertIsNot(a1, a2)
212212
self.assertNotEqual(mloc(a1), mloc(a2))
213-
self.assertNotEqual(id(a1[3]), id(a2[3]))
213+
self.assertIsNot(a1[3], a2[3])
214214
self.assertFalse(a2.flags.writeable)
215215

216216
def test_array_deepcopy_c2(self) -> None:
217217
memo = {}
218218
mutable = [np.nan]
219219
a1 = np.array((None, 'foo', True, mutable))
220220
a2 = array_deepcopy(a1, memo)
221-
self.assertNotEqual(id(a1), id(a2))
221+
self.assertIsNot(a1, a2)
222222
self.assertNotEqual(mloc(a1), mloc(a2))
223-
self.assertNotEqual(id(a1[3]), id(a2[3]))
223+
self.assertIsNot(a1[3], a2[3])
224224
self.assertFalse(a2.flags.writeable)
225225
self.assertIn(id(a1), memo)
226226

227227

228+
def test_array_deepcopy_d(self) -> None:
229+
memo = {}
230+
mutable = [3, 4, 5]
231+
a1 = np.array((None, 'foo', True, mutable))
232+
a2 = array_deepcopy(a1, memo=memo)
233+
self.assertIsNot(a1, a2)
234+
self.assertTrue(id(mutable) in memo)
235+
236+
def test_array_deepcopy_e(self) -> None:
237+
a1 = np.array((3, 4, 5))
238+
with self.assertRaises(TypeError):
239+
# memo argument must be a dictionary
240+
a2 = array_deepcopy(a1, memo=None)
241+
242+
def test_array_deepcopy_f(self) -> None:
243+
a1 = np.array((3, 4, 5))
244+
a2 = array_deepcopy(a1)
245+
self.assertNotEqual(id(a1), id(a2))
246+
247+
228248
if __name__ == '__main__':
229249
unittest.main()
230250

0 commit comments

Comments
 (0)