Skip to content

Commit bd4d643

Browse files
committed
Merge branch 'master' into 52/array-deepcopy-module-import
2 parents b51f1a5 + 33bb90b commit bd4d643

File tree

2 files changed

+30
-6
lines changed

2 files changed

+30
-6
lines changed

src/_arraykit.c

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ AK_ResolveDTypeIter(PyObject *dtypes)
120120
{
121121
PyObject *iterator = PyObject_GetIter(dtypes);
122122
if (iterator == NULL) {
123+
// No need to set exception here. GetIter already sets TypeError
123124
return NULL;
124125
}
125126
PyArray_Descr *resolved = NULL;
@@ -147,6 +148,10 @@ AK_ResolveDTypeIter(PyObject *dtypes)
147148
}
148149
}
149150
Py_DECREF(iterator);
151+
if (!resolved) {
152+
// this could happen if this function gets an empty tuple
153+
PyErr_SetString(PyExc_ValueError, "iterable passed to resolve dtypes is empty");
154+
}
150155
return resolved;
151156
}
152157

@@ -193,11 +198,12 @@ AK_ArrayDeepCopy(PyObject* m, PyArrayObject *array, PyObject *memo)
193198
array,
194199
dtype,
195200
NPY_ARRAY_ENSURECOPY);
196-
if (memo) {
197-
if (!array_new || PyDict_SetItem(memo, id, array_new)) {
198-
Py_XDECREF(array_new);
199-
goto error;
200-
}
201+
if (!array_new) {
202+
goto error;
203+
}
204+
if (memo && PyDict_SetItem(memo, id, array_new)) {
205+
Py_DECREF(array_new);
206+
goto error;
201207
}
202208
}
203209
// set immutable
@@ -771,7 +777,11 @@ static PyMethodDef arraykit_methods[] = {
771777
};
772778

773779
static struct PyModuleDef arraykit_module = {
774-
PyModuleDef_HEAD_INIT, "_arraykit", NULL, -1, arraykit_methods,
780+
PyModuleDef_HEAD_INIT,
781+
.m_name = "_arraykit",
782+
.m_doc = NULL,
783+
.m_size = -1,
784+
.m_methods = arraykit_methods,
775785
};
776786

777787
PyObject *

test/test_util.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import pytest
12
import collections
23
import datetime
34
import unittest
@@ -117,6 +118,14 @@ def test_resolve_dtype_iter_a(self) -> None:
117118
self.assertEqual(resolve_dtype_iter((a3.dtype, a5.dtype)).kind, 'U')
118119
self.assertEqual(resolve_dtype_iter((a3.dtype, a5.dtype)).itemsize, 40)
119120

121+
with pytest.raises(TypeError):
122+
resolve_dtype_iter((a3.dtype, int))
123+
124+
self.assertEqual(resolve_dtype_iter((a1.dtype,)), a1.dtype)
125+
126+
with pytest.raises(ValueError):
127+
resolve_dtype_iter(())
128+
120129
#---------------------------------------------------------------------------
121130

122131
def test_shape_filter_a(self) -> None:
@@ -134,6 +143,11 @@ def test_shape_filter_a(self) -> None:
134143
with self.assertRaises(NotImplementedError):
135144
shape_filter(a1.reshape(1,2,5))
136145

146+
with self.assertRaises(NotImplementedError):
147+
# zero dimension
148+
shape_filter(np.array(1))
149+
150+
137151
#---------------------------------------------------------------------------
138152

139153
def test_column_2d_filter_a(self) -> None:

0 commit comments

Comments
 (0)