Skip to content

Commit eff034b

Browse files
committed
Formatting, unit tests, and ArrayGO updates
1 parent fda80ad commit eff034b

File tree

6 files changed

+147
-61
lines changed

6 files changed

+147
-61
lines changed

arraykit.c

Lines changed: 19 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88

99
// Bug in NumPy < 1.16 (https://github.com/numpy/numpy/pull/12131):
1010
# undef PyDataType_ISBOOL
11-
# define PyDataType_ISBOOL(obj) PyTypeNum_ISBOOL(((PyArray_Descr*)(obj))->type_num)
11+
# define PyDataType_ISBOOL(obj) \
12+
PyTypeNum_ISBOOL(((PyArray_Descr*)(obj))->type_num)
1213

1314
# define AK_CHECK_NUMPY_ARRAY(O) \
1415
if (!PyArray_Check(O)) { \
@@ -26,9 +27,8 @@
2627

2728
typedef struct {
2829
PyObject_VAR_HEAD
29-
PyArrayObject *array;
30+
PyObject *array;
3031
PyObject *list;
31-
PyArray_Descr* dtype;
3232
} ArrayGOObject;
3333

3434
static PyTypeObject ArrayGOType;
@@ -172,17 +172,17 @@ update_array_cache(ArrayGOObject *self)
172172
{
173173
if (self->list) {
174174
if (self->array) {
175-
PyObject *container = PyTuple_Pack(2, (PyObject *)self->array, self->list);
175+
PyObject *container = PyTuple_Pack(2, self->array, self->list);
176176
if (!container) {
177177
return -1;
178178
}
179-
Py_SETREF(self->array, (PyArrayObject *)PyArray_Concatenate(container, 0));
179+
Py_SETREF(self->array, PyArray_Concatenate(container, 0));
180180
Py_DECREF(container);
181181
}
182182
else {
183-
self->array = (PyArrayObject *) PyArray_FROM_OT(self->list, self->dtype->type_num);
183+
self->array = PyArray_FROM_OT(self->list, NPY_OBJECT);
184184
}
185-
PyArray_CLEARFLAGS(self->array, NPY_ARRAY_WRITEABLE);
185+
PyArray_CLEARFLAGS((PyArrayObject *)self->array, NPY_ARRAY_WRITEABLE);
186186
Py_CLEAR(self->list);
187187
}
188188
return 0;
@@ -194,41 +194,30 @@ update_array_cache(ArrayGOObject *self)
194194
static int
195195
ArrayGO_init(ArrayGOObject *self, PyObject *args, PyObject *kwargs)
196196
{
197-
198197
PyObject *temp;
198+
char* argnames[] = {"iterable", "own_iterable", NULL};
199199
PyObject *iterable;
200200
int own_iterable;
201-
int parsed;
202-
203-
char* argnames[] = {"iterable", "dtype", "own_iterable", NULL};
204-
205-
Py_CLEAR(self->dtype);
206-
parsed = PyArg_ParseTupleAndKeywords(
207-
args, kwargs, "O|$O&p:ArrayGO", argnames,
208-
&iterable, PyArray_DescrConverter, &self->dtype, &own_iterable
201+
int parsed = PyArg_ParseTupleAndKeywords(
202+
args, kwargs, "O|$p:ArrayGO", argnames, &iterable, &own_iterable
209203
);
210204
if (!parsed) {
211205
return -1;
212206
}
213-
if (!self->dtype) {
214-
self->dtype = PyArray_DescrFromType(NPY_OBJECT);
215-
}
216207
if (PyArray_Check(iterable)) {
217-
temp = (PyObject *) self->array;
208+
temp = self->array;
218209
if (own_iterable) {
219210
PyArray_CLEARFLAGS((PyArrayObject *) iterable, NPY_ARRAY_WRITEABLE);
220211
Py_INCREF(iterable);
221212
} else {
222213
iterable = (PyObject *)AK_ImmutableFilter((PyArrayObject *) iterable);
223214
}
224-
if (!PyArray_EquivTypes(PyArray_DESCR((PyArrayObject *) iterable), self->dtype)) {
225-
PyErr_Format(
226-
PyExc_TypeError, "bad dtype given to ArrayGO initializer (expected '%S', got '%S')",
227-
PyArray_DESCR((PyArrayObject *) iterable), self->dtype
228-
);
215+
if (!PyDataType_ISOBJECT(PyArray_DESCR((PyArrayObject *)iterable))) {
216+
PyErr_SetString(PyExc_NotImplementedError,
217+
"only object arrays are supported");
229218
return -1;
230219
}
231-
self->array = (PyArrayObject *) iterable;
220+
self->array = iterable;
232221
Py_XDECREF(temp);
233222

234223
} else {
@@ -291,9 +280,7 @@ ArrayGO_copy(ArrayGOObject *self, PyObject *Py_UNUSED(unused))
291280
ArrayGOObject *copy = PyObject_New(ArrayGOObject, &ArrayGOType);
292281
copy->array = self->array;
293282
copy->list = self->list ? PySequence_List(self->list) : NULL;
294-
copy->dtype = self->dtype;
295283
Py_XINCREF(copy->array);
296-
Py_INCREF(copy->dtype);
297284
return (PyObject *)copy;
298285
}
299286

@@ -303,7 +290,7 @@ ArrayGO_iter(ArrayGOObject *self)
303290
if (self->list && update_array_cache(self)) {
304291
return NULL;
305292
}
306-
return PyObject_GetIter((PyObject *)self->array);
293+
return PyObject_GetIter(self->array);
307294
}
308295

309296
static PyObject *
@@ -312,13 +299,13 @@ ArrayGO_mp_subscript(ArrayGOObject *self, PyObject *key)
312299
if (self->list && update_array_cache(self)) {
313300
return NULL;
314301
}
315-
return PyObject_GetItem((PyObject *)self->array, key);
302+
return PyObject_GetItem(self->array, key);
316303
}
317304

318305
static Py_ssize_t
319306
ArrayGO_mp_length(ArrayGOObject *self)
320307
{
321-
return ((self->array ? PyArray_SIZE(self->array) : 0)
308+
return ((self->array ? PyArray_SIZE((PyArrayObject *)self->array) : 0)
322309
+ (self->list ? PyList_Size(self->list) : 0));
323310
}
324311

@@ -329,13 +316,12 @@ ArrayGO_values_getter(ArrayGOObject *self, void* Py_UNUSED(closure))
329316
return NULL;
330317
}
331318
Py_INCREF(self->array);
332-
return (PyObject *)self->array;
319+
return self->array;
333320
}
334321

335322
static void
336323
ArrayGO_dealloc(ArrayGOObject *self)
337324
{
338-
Py_XDECREF(self->dtype);
339325
Py_XDECREF(self->array);
340326
Py_XDECREF(self->list);
341327
Py_TYPE(self)->tp_free((PyObject *)self);

arraykit.pyi

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,24 @@
1-
# pylint: disable = all
2-
31
import typing
42

53
import numpy # type: ignore
64

7-
85
_T = typing.TypeVar('_T')
96

10-
117
class ArrayGO:
128

139
values: numpy.array
14-
15-
def __init__(self, iterable: typing.Iterable[object], *, dtype: object = ..., own_iterable: bool = ...) -> None: ...
16-
10+
def __init__(
11+
self, iterable: typing.Iterable[object], *, own_iterable: bool = ...
12+
) -> None: ...
1713
def __iter__(self) -> typing.Iterator[typing.Any]: ...
18-
1914
def __getitem__(self, __key: object) -> typing.Any: ...
20-
2115
def __len__(self) -> int: ...
22-
2316
def append(self, __value: object) -> None: ...
24-
2517
def copy(self: _T) -> _T: ...
26-
2718
def extend(self, __values: typing.Iterable[object]) -> None: ...
2819

29-
3020
def immutable_filter(__array: numpy.array) -> numpy.array: ...
31-
3221
def mloc(__array: numpy.array) -> int: ...
33-
3422
def name_filter(__name: typing.Hashable) -> typing.Hashable: ...
35-
3623
def resolve_dtype(__d1: numpy.dtype, __d2: numpy.dtype) -> numpy.dtype: ...
37-
3824
def resolve_dtype_iter(__dtypes: typing.Iterable[numpy.dtype]) -> numpy.dtype: ...

setup.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def get_long_description() -> str:
1616
Packages: https://pypi.org/project/arraykit
1717
'''
1818

19+
1920
setup(
2021
name='arraykit',
2122
version=AK_VERSION,
@@ -28,16 +29,16 @@ def get_long_description() -> str:
2829
license='MIT',
2930
# See https://pypi.python.org/pypi?%3Aaction=list_classifiers
3031
classifiers=[
31-
'Development Status :: 3 - Alpha',
32-
'Intended Audience :: Developers',
33-
'Topic :: Software Development',
34-
'License :: OSI Approved :: MIT License',
35-
'Operating System :: MacOS :: MacOS X',
36-
'Operating System :: Microsoft :: Windows',
37-
'Operating System :: POSIX',
38-
'Programming Language :: Python :: 3.6',
39-
'Programming Language :: Python :: 3.7',
40-
'Programming Language :: Python :: 3.8',
32+
'Development Status :: 3 - Alpha',
33+
'Intended Audience :: Developers',
34+
'Topic :: Software Development',
35+
'License :: OSI Approved :: MIT License',
36+
'Operating System :: MacOS :: MacOS X',
37+
'Operating System :: Microsoft :: Windows',
38+
'Operating System :: POSIX',
39+
'Programming Language :: Python :: 3.6',
40+
'Programming Language :: Python :: 3.7',
41+
'Programming Language :: Python :: 3.8',
4142
],
4243
keywords='numpy array',
4344
packages=[],

tasks.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
import invoke
66

7-
#-------------------------------------------------------------------------------
7+
# -------------------------------------------------------------------------------
8+
89

910
@invoke.task
1011
def clean(context):

test_property.py

Whitespace-only changes.

test_unit.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import pytest # type: ignore
2+
import numpy as np # type: ignore
3+
4+
import arraykit as ak
5+
6+
7+
def test_array_init_a() -> None:
8+
with pytest.raises(NotImplementedError):
9+
ak.ArrayGO(np.array((3, 4, 5)))
10+
11+
12+
def test_array_append_a() -> None:
13+
ag1 = ak.ArrayGO(('a', 'b', 'c', 'd'))
14+
assert [*ag1] == ['a', 'b', 'c', 'd']
15+
assert ag1.values.tolist() == ['a', 'b', 'c', 'd']
16+
ag1.append('e')
17+
ag1.extend(('f', 'g'))
18+
assert [*ag1] == ['a', 'b', 'c', 'd', 'e', 'f', 'g']
19+
assert ag1.values.tolist() == ['a', 'b', 'c', 'd', 'e', 'f', 'g']
20+
21+
22+
def test_array_append_b() -> None:
23+
ag1 = ak.ArrayGO(np.array(('a', 'b', 'c', 'd'), object))
24+
assert [*ag1] == ['a', 'b', 'c', 'd']
25+
assert ag1.values.tolist() == ['a', 'b', 'c', 'd']
26+
ag1.append('e')
27+
ag1.extend(('f', 'g'))
28+
assert [*ag1] == ['a', 'b', 'c', 'd', 'e', 'f', 'g']
29+
assert ag1.values.tolist() == ['a', 'b', 'c', 'd', 'e', 'f', 'g']
30+
31+
32+
def test_array_getitem_a() -> None:
33+
a = np.array(('a', 'b', 'c', 'd'), object)
34+
a.flags.writeable = False
35+
ag1 = ak.ArrayGO(a)
36+
# Ensure no copy for immutable:
37+
assert ak.mloc(ag1.values) == ak.mloc(a)
38+
ag1.append('b')
39+
post = ag1[ag1.values == 'b']
40+
assert post.tolist() == ['b', 'b']
41+
assert ag1[[2, 1, 1, 1]].tolist() == ['c', 'b', 'b', 'b']
42+
43+
44+
def test_array_copy_a() -> None:
45+
ag1 = ak.ArrayGO(np.array(('a', 'b', 'c', 'd'), dtype=object))
46+
ag1.append('e')
47+
ag2 = ag1.copy()
48+
ag1.extend(('f', 'g'))
49+
assert ag1.values.tolist() == ['a', 'b', 'c', 'd', 'e', 'f', 'g']
50+
assert ag2.values.tolist() == ['a', 'b', 'c', 'd', 'e']
51+
52+
53+
def test_array_len_a() -> None:
54+
ag1 = ak.ArrayGO(np.array(('a', 'b', 'c', 'd'), object))
55+
ag1.append('e')
56+
assert len(ag1) == 5
57+
58+
59+
def test_resolve_dtype_a() -> None:
60+
a1 = np.array([1, 2, 3])
61+
a2 = np.array([False, True, False])
62+
a3 = np.array(['b', 'c', 'd'])
63+
a4 = np.array([2.3, 3.2])
64+
a5 = np.array(['test', 'test again'], dtype='S')
65+
a6 = np.array([2.3, 5.4], dtype='float32')
66+
assert ak.resolve_dtype(a1.dtype, a1.dtype) == a1.dtype
67+
assert ak.resolve_dtype(a1.dtype, a2.dtype) == np.object_
68+
assert ak.resolve_dtype(a2.dtype, a3.dtype) == np.object_
69+
assert ak.resolve_dtype(a2.dtype, a4.dtype) == np.object_
70+
assert ak.resolve_dtype(a3.dtype, a4.dtype) == np.object_
71+
assert ak.resolve_dtype(a3.dtype, a6.dtype) == np.object_
72+
assert ak.resolve_dtype(a1.dtype, a4.dtype) == np.float64
73+
assert ak.resolve_dtype(a1.dtype, a6.dtype) == np.float64
74+
assert ak.resolve_dtype(a4.dtype, a6.dtype) == np.float64
75+
76+
77+
def test_resolve_dtype_b() -> None:
78+
a1 = np.array('a').dtype
79+
a3 = np.array('aaa').dtype
80+
assert ak.resolve_dtype(a1, a3) == np.dtype(('U', 3))
81+
82+
83+
def test_resolve_dtype_c() -> None:
84+
a1 = np.array(['2019-01', '2019-02'], dtype=np.datetime64)
85+
a2 = np.array(['2019-01-01', '2019-02-01'], dtype=np.datetime64)
86+
a3 = np.array([0, 1], dtype='datetime64[ns]')
87+
a4 = np.array([0, 1])
88+
assert str(ak.resolve_dtype(a1.dtype, a2.dtype)) == 'datetime64[D]'
89+
assert ak.resolve_dtype(a1.dtype, a3.dtype) == np.dtype('<M8[ns]')
90+
assert ak.resolve_dtype(a1.dtype, a4.dtype) == np.dtype('O')
91+
92+
93+
def test_resolve_dtype_iter_a() -> None:
94+
a1 = np.array([1, 2, 3])
95+
a2 = np.array([False, True, False])
96+
a3 = np.array(['b', 'c', 'd'])
97+
a4 = np.array([2.3, 3.2])
98+
a5 = np.array(['test', 'test again'], dtype='S')
99+
a6 = np.array([2.3, 5.4], dtype='float32')
100+
assert ak.resolve_dtype_iter((a1.dtype, a1.dtype)) == a1.dtype
101+
assert ak.resolve_dtype_iter((a2.dtype, a2.dtype)) == a2.dtype
102+
# Boolean with mixed types:
103+
assert ak.resolve_dtype_iter((a2.dtype, a2.dtype, a3.dtype)) == np.object_
104+
assert ak.resolve_dtype_iter((a2.dtype, a2.dtype, a5.dtype)) == np.object_
105+
assert ak.resolve_dtype_iter((a2.dtype, a2.dtype, a6.dtype)) == np.object_
106+
# Numeric types go to float64:
107+
assert ak.resolve_dtype_iter((a1.dtype, a4.dtype, a6.dtype)) == np.float64
108+
# Add in bool or str, goes to object:
109+
assert ak.resolve_dtype_iter((a1.dtype, a4.dtype, a6.dtype, a2.dtype)) == np.object_
110+
assert ak.resolve_dtype_iter((a1.dtype, a4.dtype, a6.dtype, a5.dtype)) == np.object_
111+
# Mixed strings go to the largest:
112+
assert ak.resolve_dtype_iter((a3.dtype, a5.dtype)) == np.dtype('<U10')

0 commit comments

Comments
 (0)