Skip to content

Commit 12cbb22

Browse files
authored
Merge pull request #33 from chaburkland/2/dtype_from_element
2/dtype from element
2 parents bff4c46 + bab26e4 commit 12cbb22

File tree

6 files changed

+197
-2
lines changed

6 files changed

+197
-2
lines changed

performance/__main__.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import collections
12
import datetime
23
import timeit
34
import argparse
@@ -13,6 +14,7 @@
1314
from performance.reference.util import row_1d_filter as row_1d_filter_ref
1415
from performance.reference.util import resolve_dtype as resolve_dtype_ref
1516
from performance.reference.util import resolve_dtype_iter as resolve_dtype_iter_ref
17+
from performance.reference.util import dtype_from_element as dtype_from_element_ref
1618
from performance.reference.util import array_deepcopy as array_deepcopy_ref
1719
from performance.reference.util import isna_element as isna_element_ref
1820

@@ -27,6 +29,7 @@
2729
from arraykit import row_1d_filter as row_1d_filter_ak
2830
from arraykit import resolve_dtype as resolve_dtype_ak
2931
from arraykit import resolve_dtype_iter as resolve_dtype_iter_ak
32+
from arraykit import dtype_from_element as dtype_from_element_ak
3033
from arraykit import array_deepcopy as array_deepcopy_ak
3134
from arraykit import isna_element as isna_element_ak
3235

@@ -250,6 +253,45 @@ class ArrayGOPerfREF(ArrayGOPerf):
250253
entry = staticmethod(ArrayGOREF)
251254

252255

256+
#-------------------------------------------------------------------------------
257+
class DtypeFromElementPerf(Perf):
    '''Benchmark ``dtype_from_element`` over a broad mix of element values.

    Subclasses supply the implementation under test through the ``entry`` attribute.
    '''
    # NOTE(review): presumably the repetition count consumed by the Perf runner; confirm against Perf
    NUMBER = 1000

    def pre(self):
        '''Build ``self.values``, the collection of elements passed to ``entry``.'''
        NT = collections.namedtuple('NT', tuple('abc'))

        # NOTE: only canonical NumPy scalar names are used; the aliases np.float_, np.longfloat, np.complex_, np.clongfloat, np.unicode_ and np.object have been removed from NumPy and raise AttributeError there
        self.values = [
            np.longlong(-1), np.int_(-1), np.intc(-1), np.short(-1), np.byte(-1),
            np.ubyte(1), np.ushort(1), np.uintc(1), np.uint(1), np.ulonglong(1),
            np.half(1.0), np.single(1.0), np.double(1.0), np.longdouble(1.0),
            np.csingle(1.0j), np.cdouble(1.0j), np.clongdouble(1.0j),
            # np.str_ appears twice (the second was np.unicode_, the same type) so the workload size is unchanged
            np.bool_(0), np.str_('1'), np.str_('1'), np.void(1),
            # np.object was an alias of the builtin object
            object(), np.datetime64('NaT'), np.timedelta64('NaT'), np.nan,
            12, 12.0, True, None, float('NaN'), object(), (1, 2, 3),
            NT(1, 2, 3), datetime.date(2020, 12, 31), datetime.timedelta(14),
        ]

        # Datetime & Timedelta: one value per unit, timedelta first then datetime
        for precision in ('ns', 'us', 'ms', 's', 'm', 'h', 'D', 'M', 'Y'):
            for ctor in (np.timedelta64, np.datetime64):
                self.values.append(ctor(12, precision))

        # Bytes & str of increasing length
        for size in (1, 8, 16, 32, 64, 128, 256, 512):
            self.values.append(bytes(size))
            self.values.append('x' * size)

    def main(self):
        '''Call ``entry`` on every prepared value, 40 passes over the collection.'''
        for _ in range(40):
            for val in self.values:
                self.entry(val)
287+
288+
class DtypeFromElementPerfAK(DtypeFromElementPerf):
    '''Run the ``dtype_from_element`` benchmark against the arraykit implementation.'''
    entry = staticmethod(dtype_from_element_ak)
290+
291+
class DtypeFromElementPerfREF(DtypeFromElementPerf):
    '''Run the ``dtype_from_element`` benchmark against the pure-Python reference implementation.'''
    entry = staticmethod(dtype_from_element_ref)
293+
294+
253295
#-------------------------------------------------------------------------------
254296
class IsNaElementPerf(Perf):
255297
NUMBER = 1000
@@ -336,7 +378,6 @@ def main():
336378

337379
records = [('cls', 'func', 'ak', 'ref', 'ref/ak')]
338380
for cls_perf in Perf.__subclasses__(): # only get one level
339-
print(cls_perf)
340381
cls_map = {}
341382
if match and cls_perf.__name__ not in match:
342383
continue

performance/reference/util.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,3 +199,20 @@ def isna_element(value: tp.Any) -> bool:
199199
return np.isnat(value) #type: ignore
200200

201201
return value is None
202+
203+
204+
def dtype_from_element(value: tp.Optional[tp.Hashable]) -> np.dtype:
    '''Return the dtype to use when ``value``, an arbitrary hashable, is treated as a single element.

    This exists because ``np.array(value).dtype`` does not return object for a tuple.
    '''
    # NOTE: an identity test only matches the NaN object found on the NumPy root namespace (e.g. a default NaN in a function signature); other NaN instances are resolved by the final fallback
    if value is np.nan:
        return DTYPE_FLOAT_DEFAULT

    # None and tuples are always held as objects
    if value is None or isinstance(value, tuple):
        return DTYPE_OBJECT

    if not hasattr(value, 'dtype'):
        # NOTE: building an array and reading its dtype is faster for NaN than combining isinstance and isnan calls
        return np.array(value).dtype

    # the value already carries its own dtype
    return value.dtype #type: ignore
218+

src/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@
1515
from ._arraykit import resolve_dtype as resolve_dtype
1616
from ._arraykit import resolve_dtype_iter as resolve_dtype_iter
1717
from ._arraykit import isna_element as isna_element
18+
from ._arraykit import dtype_from_element as dtype_from_element

src/__init__.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,5 @@ def array_deepcopy(__array: np.array, memo: tp.Dict[int, tp.Any]) -> np.ndarray:
3131
def resolve_dtype(__d1: np.dtype, __d2: np.dtype) -> np.dtype: ...
3232
def resolve_dtype_iter(__dtypes: tp.Iterable[np.dtype]) -> np.dtype: ...
3333
def isna_element(__value: tp.Any) -> bool: ...
34+
# Return the dtype appropriate for a single (optionally None) hashable element; the implementation is imported from the ``_arraykit`` extension module.
def dtype_from_element(__value: tp.Optional[tp.Hashable]) -> np.dtype: ...
35+

src/_arraykit.c

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,71 @@ resolve_dtype_iter(PyObject *Py_UNUSED(m), PyObject *arg)
365365
//------------------------------------------------------------------------------
366366
// general utility
367367

368+
// Return a new reference to the dtype appropriate for `arg` when it is treated
// as a single element. Returns NULL with an exception set on failure.
static PyObject *
dtype_from_element(PyObject *Py_UNUSED(m), PyObject *arg)
{
    // -------------------------------------------------------------------------
    // 1. Handle fast, exact type checks first.

    // None
    if (arg == Py_None) {
        return (PyObject*)PyArray_DescrFromType(NPY_OBJECT);
    }

    // Float
    if (PyFloat_CheckExact(arg)) {
        return (PyObject*)PyArray_DescrFromType(NPY_DOUBLE);
    }

    // Integers (the exact check excludes bool, which is handled next)
    if (PyLong_CheckExact(arg)) {
        return (PyObject*)PyArray_DescrFromType(NPY_LONG);
    }

    // Bool
    if (PyBool_Check(arg)) {
        return (PyObject*)PyArray_DescrFromType(NPY_BOOL);
    }

    // Complex: without this branch a Python complex would fall through to
    // object, disagreeing with np.array(value).dtype
    if (PyComplex_CheckExact(arg)) {
        return (PyObject*)PyArray_DescrFromType(NPY_CDOUBLE);
    }

    // String: the item size depends on the value, so discover it from the object
    if (PyUnicode_CheckExact(arg)) {
        PyArray_Descr* descr = PyArray_DescrFromType(NPY_UNICODE);
        if (descr == NULL) {
            return NULL;
        }
        PyObject* sized = (PyObject*)PyArray_DescrFromObject(arg, descr);
        Py_DECREF(descr);
        return sized;
    }

    // Bytes: as with strings, the item size depends on the value
    if (PyBytes_CheckExact(arg)) {
        PyArray_Descr* descr = PyArray_DescrFromType(NPY_STRING);
        if (descr == NULL) {
            return NULL;
        }
        PyObject* sized = (PyObject*)PyArray_DescrFromObject(arg, descr);
        Py_DECREF(descr);
        return sized;
    }

    // -------------------------------------------------------------------------
    // 2. Construct dtype (slightly more complicated)

    // Already known: the object exposes its own dtype
    PyObject* dtype = PyObject_GetAttrString(arg, "dtype");
    if (dtype) {
        return dtype;
    }
    // Only a missing attribute means "no dtype"; any other exception raised
    // while reading the attribute is propagated, matching hasattr() semantics
    if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
        return NULL;
    }
    PyErr_Clear();

    // -------------------------------------------------------------------------
    // 3. Handles everything else.
    return (PyObject*)PyArray_DescrFromType(NPY_OBJECT);
}
432+
368433
static PyObject *
369434
isna_element(PyObject *Py_UNUSED(m), PyObject *arg)
370435
{
@@ -706,6 +771,7 @@ static PyMethodDef arraykit_methods[] = {
706771
{"resolve_dtype", resolve_dtype, METH_VARARGS, NULL},
707772
{"resolve_dtype_iter", resolve_dtype_iter, METH_O, NULL},
708773
{"isna_element", isna_element, METH_O, NULL},
774+
{"dtype_from_element", dtype_from_element, METH_O, NULL},
709775
{NULL},
710776
};
711777

test/test_util.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import collections
12
import datetime
23
import unittest
34
import itertools
@@ -14,6 +15,7 @@
1415
from arraykit import immutable_filter
1516
from arraykit import array_deepcopy
1617
from arraykit import isna_element
18+
from arraykit import dtype_from_element
1719

1820
from performance.reference.util import mloc as mloc_ref
1921

@@ -299,7 +301,73 @@ def test_isna_element_false(self) -> None:
299301
self.assertFalse(isna_element(datetime.date(2020, 12, 31)))
300302
self.assertFalse(isna_element(False))
301303

304+
def test_dtype_from_element_core_dtypes(self) -> None:
    '''A default-constructed NumPy scalar of each core type reports that type as its dtype.'''
    # NOTE: canonical names replace the aliases np.float_, np.longfloat, np.complex_ and np.clongfloat, which have been removed from NumPy
    dtypes = [
            np.longlong,
            np.int_,
            np.intc,
            np.short,
            np.byte,
            np.ubyte,
            np.ushort,
            np.uintc,
            np.uint,
            np.ulonglong,
            np.half,
            np.single,
            np.double,
            np.longdouble,
            np.csingle,
            np.cdouble,
            np.clongdouble,
            np.bool_,
    ]
    for dtype in dtypes:
        self.assertEqual(dtype, dtype_from_element(dtype()))
327+
328+
def test_dtype_from_element_str_and_misc_dtypes(self) -> None:
    '''NumPy string, void, NaT and NaN elements, plus a plain object, map to the expected dtypes.'''
    # NOTE: np.unicode_ (same type as np.str_), np.object (the builtin object) and np.float_ have been removed from NumPy; the np.unicode_ case duplicated the np.str_ case and is covered by it
    dtype_obj_pairs = [
            (np.dtype('<U1'), np.str_('1')),
            (np.dtype('V1'), np.void(1)),
            (np.dtype('O'), object()),
            (np.dtype('<M8'), np.datetime64('NaT')),
            (np.dtype('<m8'), np.timedelta64('NaT')),
            (np.double, np.nan),
    ]
    for dtype, obj in dtype_obj_pairs:
        self.assertEqual(dtype, dtype_from_element(obj))
340+
341+
def test_dtype_from_element_obj_dtypes(self) -> None:
    '''Plain Python elements map to the expected dtypes; tuples, namedtuples and other objects map to object.'''
    NT = collections.namedtuple('NT', tuple('abc'))

    # NOTE: np.double replaces the alias np.float_, which has been removed from NumPy
    dtype_obj_pairs = [
            (np.int_, 12),
            (np.double, 12.0),
            (np.bool_, True),
            (np.dtype('O'), None),
            (np.double, float('NaN')),
            (np.dtype('O'), object()),
            (np.dtype('O'), (1, 2, 3)),
            (np.dtype('O'), NT(1, 2, 3)),
            (np.dtype('O'), datetime.date(2020, 12, 31)),
            (np.dtype('O'), datetime.timedelta(14)),
    ]
    for dtype, obj in dtype_obj_pairs:
        self.assertEqual(dtype, dtype_from_element(obj))
358+
359+
def test_dtype_from_element_time_dtypes(self) -> None:
    '''datetime64 and timedelta64 elements keep their unit in the returned dtype.'''
    # Datetime & Timedelta
    precisions = ('ns', 'us', 'ms', 's', 'm', 'h', 'D', 'M', 'Y')
    kind_ctors = (('m', np.timedelta64), ('M', np.datetime64))
    # product iterates in the same order as nesting the ctor loop inside the precision loop
    for precision, (kind, ctor) in itertools.product(precisions, kind_ctors):
        expected = np.dtype(f'<{kind}8[{precision}]')
        self.assertEqual(expected, dtype_from_element(ctor(12, precision)))
365+
366+
def test_dtype_from_element_str_and_bytes_dtypes(self) -> None:
    '''bytes and str elements yield fixed-width dtypes sized to the element's length.'''
    sizes = (1, 8, 16, 32, 64, 128, 256, 512)
    for size in sizes:
        expected_bytes = np.dtype(f'|S{size}')
        self.assertEqual(expected_bytes, dtype_from_element(bytes(size)))

        expected_str = np.dtype(f'<U{size}')
        self.assertEqual(expected_str, dtype_from_element('x' * size))
370+
302371

303372
if __name__ == '__main__':
304373
unittest.main()
305-

0 commit comments

Comments
 (0)