Skip to content

Commit 2741e9b

Browse files
authored
Merge pull request #29 from InvestmentSystems/25/array-deepcopy
array_deepcopy
2 parents affc455 + 6afd548 commit 2741e9b

File tree

4 files changed

+170
-3
lines changed

4 files changed

+170
-3
lines changed

arraykit.c

Lines changed: 83 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@
3434
}\
3535
} while (0)
3636

37-
// Placeholder of not implemented functions.
38-
# define AK_NOT_IMPLEMENTED\
37+
// Placeholder of not implemented pathways / debugging.
38+
# define AK_NOT_IMPLEMENTED(msg)\
3939
do {\
40-
PyErr_SetNone(PyExc_NotImplementedError);\
40+
PyErr_SetString(PyExc_NotImplementedError, msg);\
4141
return NULL;\
4242
} while (0)
4343

@@ -148,6 +148,65 @@ AK_ResolveDTypeIter(PyObject *dtypes)
148148
return resolved;
149149
}
150150

151+
// Numpy implementation: https://github.com/numpy/numpy/blob/a14c41264855e44ebd6187d7541b5b8d59bb32cb/numpy/core/src/multiarray/methods.c#L1557
152+
PyObject*
153+
AK_ArrayDeepCopy(PyArrayObject *array, PyObject *memo)
154+
{
155+
PyObject *id = PyLong_FromVoidPtr((PyObject*)array);
156+
if (!id) {
157+
return NULL;
158+
}
159+
PyObject *found = PyDict_GetItemWithError(memo, id);
160+
if (found) { // found will be NULL if not in dict
161+
Py_INCREF(found); // got a borrowed ref, increment first
162+
Py_DECREF(id);
163+
return found;
164+
}
165+
else if (PyErr_Occurred()) {
166+
goto error;
167+
}
168+
169+
// if dtype is object, call deepcopy with memo
170+
PyObject *array_new;
171+
PyArray_Descr *dtype = PyArray_DESCR(array); // borrowed ref
172+
173+
if (PyDataType_ISOBJECT(dtype)) {
174+
PyObject *copy = PyImport_ImportModule("copy");
175+
if (!copy) {
176+
goto error;
177+
}
178+
PyObject *deepcopy = PyObject_GetAttrString(copy, "deepcopy");
179+
Py_DECREF(copy);
180+
if (!deepcopy) {
181+
goto error;
182+
}
183+
array_new = PyObject_CallFunctionObjArgs(deepcopy, array, memo, NULL);
184+
Py_DECREF(deepcopy);
185+
if (!array_new) {
186+
goto error;
187+
}
188+
}
189+
else {
190+
Py_INCREF(dtype); // PyArray_FromArray steals a reference
191+
array_new = PyArray_FromArray(
192+
array,
193+
dtype,
194+
NPY_ARRAY_ENSURECOPY);
195+
if (!array_new || PyDict_SetItem(memo, id, array_new)) {
196+
Py_XDECREF(array_new);
197+
goto error;
198+
}
199+
}
200+
// set immutable
201+
PyArray_CLEARFLAGS((PyArrayObject *)array_new, NPY_ARRAY_WRITEABLE);
202+
Py_DECREF(id);
203+
return array_new;
204+
error:
205+
Py_DECREF(id);
206+
return NULL;
207+
}
208+
209+
151210
//------------------------------------------------------------------------------
152211
// AK module public methods
153212
//------------------------------------------------------------------------------
@@ -249,6 +308,26 @@ row_1d_filter(PyObject *Py_UNUSED(m), PyObject *a)
249308
return a;
250309
}
251310

311+
//------------------------------------------------------------------------------
312+
// array utility
313+
314+
// Specialized array deepcopy that stores immutable arrays in memo dict.
315+
static PyObject *
316+
array_deepcopy(PyObject *Py_UNUSED(m), PyObject *args)
317+
{
318+
PyObject *array, *memo;
319+
if (!PyArg_UnpackTuple(args, "array_deepcopy", 2, 2, &array, &memo)) {
320+
return NULL;
321+
}
322+
AK_CHECK_NUMPY_ARRAY(array);
323+
if (!PyDict_CheckExact(memo)) {
324+
PyErr_Format(PyExc_TypeError, "expected a dict (got %s)",
325+
Py_TYPE(memo)->tp_name);
326+
return NULL;
327+
}
328+
return AK_ArrayDeepCopy((PyArrayObject*)array, memo);
329+
}
330+
252331
//------------------------------------------------------------------------------
253332
// type resolution
254333

@@ -544,6 +623,7 @@ static PyMethodDef arraykit_methods[] = {
544623
{"column_2d_filter", column_2d_filter, METH_O, NULL},
545624
{"column_1d_filter", column_1d_filter, METH_O, NULL},
546625
{"row_1d_filter", row_1d_filter, METH_O, NULL},
626+
{"array_deepcopy", array_deepcopy, METH_VARARGS, NULL},
547627
{"resolve_dtype", resolve_dtype, METH_VARARGS, NULL},
548628
{"resolve_dtype_iter", resolve_dtype_iter, METH_O, NULL},
549629
{NULL},

arraykit.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,6 @@ def shape_filter(__array: np.array) -> np.ndarray: ...
2525
def column_2d_filter(__array: np.array) -> np.ndarray: ...
2626
def column_1d_filter(__array: np.array) -> np.ndarray: ...
2727
def row_1d_filter(__array: np.array) -> np.ndarray: ...
28+
# Both parameters are positional-only (double-underscore prefix): the C
# implementation unpacks a plain argument tuple and accepts no keywords.
def array_deepcopy(__array: np.array, __memo: tp.Dict[int, tp.Any]) -> np.ndarray: ...
2829
def resolve_dtype(__d1: np.dtype, __d2: np.dtype) -> np.dtype: ...
2930
def resolve_dtype_iter(__dtypes: tp.Iterable[np.dtype]) -> np.dtype: ...

performance/main.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from performance.reference.util import row_1d_filter as row_1d_filter_ref
1616
from performance.reference.util import resolve_dtype as resolve_dtype_ref
1717
from performance.reference.util import resolve_dtype_iter as resolve_dtype_iter_ref
18+
from performance.reference.util import array_deepcopy as array_deepcopy_ref
1819

1920
from performance.reference.array_go import ArrayGO as ArrayGOREF
2021

@@ -27,6 +28,7 @@
2728
from arraykit import row_1d_filter as row_1d_filter_ak
2829
from arraykit import resolve_dtype as resolve_dtype_ak
2930
from arraykit import resolve_dtype_iter as resolve_dtype_iter_ak
31+
from arraykit import array_deepcopy as array_deepcopy_ak
3032

3133
from arraykit import ArrayGO as ArrayGOAK
3234

@@ -200,6 +202,33 @@ class ResolveDTypeIterREF(ResolveDTypeIter):
200202
entry = staticmethod(resolve_dtype_iter_ref)
201203

202204

205+
#-------------------------------------------------------------------------------
206+
class ArrayDeepcopy(Perf):
    """Timing scenarios for ``array_deepcopy``.

    ``memo_new`` copies both fixture arrays with a memo dict created inside
    the call; ``memo_shared`` reuses the single memo dict built by ``pre``.
    Subclasses supply ``entry``, the implementation being timed.
    """
    FUNCTIONS = ('memo_new', 'memo_shared')
    NUMBER = 500

    def pre(self):
        """Build the fixtures: an integer array, an object array, and a memo."""
        size = 100_000
        self.array1 = np.arange(size)
        self.array2 = np.full(size, None)
        self.array2[0] = [np.nan] # add a mutable
        self.memo = {}

    def memo_new(self):
        """Copy both arrays using a memo that starts out empty."""
        fresh = {}
        for array in (self.array1, self.array2):
            self.entry(array, fresh)

    def memo_shared(self):
        """Copy both arrays using the memo kept on the instance."""
        for array in (self.array1, self.array2):
            self.entry(array, self.memo)


class ArrayDeepcopyAK(ArrayDeepcopy):
    """Runs the scenarios against the arraykit implementation."""
    entry = staticmethod(array_deepcopy_ak)


class ArrayDeepcopyREF(ArrayDeepcopy):
    """Runs the scenarios against the reference implementation."""
    entry = staticmethod(array_deepcopy_ref)
230+
231+
203232
#-------------------------------------------------------------------------------
204233
class ArrayGOPerf(Perf):
205234
NUMBER = 1000
@@ -240,6 +269,7 @@ def main():
240269

241270
records = [('cls', 'func', 'ak', 'ref', 'ref/ak')]
242271
for cls_perf in Perf.__subclasses__(): # only get one level
272+
print(cls_perf)
243273
cls_map = {}
244274
if match and cls_perf.__name__ not in match:
245275
continue

test/test_util.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from arraykit import row_1d_filter
1111
from arraykit import mloc
1212
from arraykit import immutable_filter
13+
from arraykit import array_deepcopy
1314

1415
from performance.reference.util import mloc as mloc_ref
1516

@@ -167,6 +168,61 @@ def test_row_1d_filter_a(self) -> None:
167168
with self.assertRaises(NotImplementedError):
168169
row_1d_filter(a1.reshape(1,2,5))
169170

171+
#---------------------------------------------------------------------------
172+
173+
def test_array_deepcopy_a1(self) -> None:
    """A copied integer array is a distinct, read-only array of the same dtype."""
    source = np.arange(10)
    copied = array_deepcopy(source, {})

    # Distinct Python object and distinct data buffer.
    self.assertIsNot(source, copied)
    self.assertNotEqual(mloc(source), mloc(copied))
    self.assertFalse(copied.flags.writeable)
    self.assertEqual(source.dtype, copied.dtype)
182+
183+
def test_array_deepcopy_a2(self) -> None:
    """Copying an integer array records the copy in memo under id(source)."""
    source = np.arange(10)
    memo = {}
    copied = array_deepcopy(source, memo)
    key = id(source)

    # Distinct Python object and distinct data buffer.
    self.assertIsNot(source, copied)
    self.assertNotEqual(mloc(source), mloc(copied))
    # The memo entry holds the same values as the returned copy.
    self.assertIn(key, memo)
    self.assertEqual(memo[key].tolist(), copied.tolist())
    self.assertFalse(copied.flags.writeable)
193+
194+
195+
def test_array_deepcopy_b(self) -> None:
    """An array already present in memo is returned without copying its data."""
    source = np.arange(10)
    memo = {}
    memo[id(source)] = source
    result = array_deepcopy(source, memo)

    # Same data location: the memo entry was handed back.
    self.assertEqual(mloc(source), mloc(result))
201+
202+
203+
def test_array_deepcopy_c1(self) -> None:
    """Copying an object array also copies the mutable element it holds."""
    mutable = [np.nan]
    memo = {}
    # dtype=object is explicit: the elements are ragged (three scalars and a
    # list), and NumPy no longer infers an object array from such input.
    a1 = np.array((None, 'foo', True, mutable), dtype=object)
    a2 = array_deepcopy(a1, memo)

    self.assertNotEqual(id(a1), id(a2))
    self.assertNotEqual(mloc(a1), mloc(a2))
    # The list at index 3 must be a new object, not the shared original.
    self.assertNotEqual(id(a1[3]), id(a2[3]))
    self.assertFalse(a2.flags.writeable)
213+
214+
def test_array_deepcopy_c2(self) -> None:
    """Copying an object array deep-copies its elements and fills memo."""
    memo = {}
    mutable = [np.nan]
    # dtype=object is explicit: the elements are ragged (three scalars and a
    # list), and NumPy no longer infers an object array from such input.
    a1 = np.array((None, 'foo', True, mutable), dtype=object)
    a2 = array_deepcopy(a1, memo)
    self.assertNotEqual(id(a1), id(a2))
    self.assertNotEqual(mloc(a1), mloc(a2))
    # The list at index 3 must be a new object, not the shared original.
    self.assertNotEqual(id(a1[3]), id(a2[3]))
    self.assertFalse(a2.flags.writeable)
    # The source array itself is recorded in memo.
    self.assertIn(id(a1), memo)
224+
225+
170226
if __name__ == '__main__':
171227
unittest.main()
172228

0 commit comments

Comments
 (0)