Skip to content

Commit 8f87f3e

Browse files
authored
Merge pull request #92 from static-frame/91/dt64-seg-fault
segmentation fault on bad datetime64 input
2 parents 7b3dbca + b981b3a commit 8f87f3e

File tree

2 files changed

+38
-19
lines changed

2 files changed

+38
-19
lines changed

src/_arraykit.c

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1924,25 +1924,35 @@ AK_CPL_to_array_bytes(AK_CodePointLine* cpl, PyArray_Descr* dtype)
19241924
return array;
19251925
}
19261926

1927-
// If we cannot direclty convert bytes to values, create a bytes array and then use PyArray_CastToType to use numpy to interpet it as a new a array.
1927+
// If we cannot directly convert bytes to values in a pre-loaded array, we can create a bytes or unicode array and then use PyArray_CastToType to have numpy interpret it as a new array and handle conversions. Note that we can use bytes for a smaller memory load if we are confident that the values are not unicode. This is a safe assumption for complex. For datetime64, we have to use Unicode to get errors on malformed inputs: using bytes causes a seg fault with these interfaces (the same is not observed when calling astype on a bytes array in Python).
19281928
static inline PyObject*
1929-
AK_CPL_to_array_via_cast(AK_CodePointLine* cpl, PyArray_Descr* dtype)
1929+
AK_CPL_to_array_via_cast(AK_CodePointLine* cpl,
1930+
PyArray_Descr* dtype,
1931+
int type_inter)
19301932
{
1931-
PyArray_Descr *dtype_bytes = PyArray_DescrNewFromType(NPY_STRING);
1932-
if (dtype_bytes == NULL) {
1933+
PyArray_Descr* dtype_inter; // interchange array
1934+
PyObject* array_inter = NULL;
1935+
1936+
dtype_inter = PyArray_DescrNewFromType(type_inter);
1937+
if (dtype_inter == NULL) {
19331938
Py_DECREF(dtype);
19341939
return NULL;
19351940
}
1936-
PyObject* array_bytes = AK_CPL_to_array_bytes(cpl, dtype_bytes);
1937-
if (array_bytes == NULL) {
1938-
Py_DECREF(dtype);
1939-
// dtype_bytes stolen even if array creation failed
1941+
if (type_inter == NPY_STRING) {
1942+
array_inter = AK_CPL_to_array_bytes(cpl, dtype_inter);
1943+
}
1944+
else if (type_inter == NPY_UNICODE) {
1945+
array_inter = AK_CPL_to_array_unicode(cpl, dtype_inter);
1946+
}
1947+
1948+
if (array_inter == NULL) {
1949+
Py_DECREF(dtype); // dtype_inter ref already stolen
19401950
return NULL;
19411951
}
1942-
PyObject *array = PyArray_CastToType((PyArrayObject*)array_bytes, dtype, 0);
1943-
Py_DECREF(array_bytes);
1944-
if (array == NULL) {
1945-
// expected array to steal dtype reference
1952+
1953+
PyObject *array = PyArray_CastToType((PyArrayObject*)array_inter, dtype, 0);
1954+
Py_DECREF(array_inter);
1955+
if (array == NULL) { // dtype ref already stolen
19461956
return NULL;
19471957
}
19481958
PyArray_CLEARFLAGS((PyArrayObject *)array, NPY_ARRAY_WRITEABLE);
@@ -1983,10 +1993,11 @@ AK_CPL_ToArray(AK_CodePointLine* cpl, PyArray_Descr* dtype, char tsep, char decc
19831993
return AK_CPL_to_array_int(cpl, dtype, tsep);
19841994
}
19851995
else if (PyDataType_ISDATETIME(dtype)) {
1986-
return AK_CPL_to_array_via_cast(cpl, dtype);
1996+
return AK_CPL_to_array_via_cast(cpl, dtype, NPY_UNICODE);
19871997
}
19881998
else if (PyDataType_ISCOMPLEX(dtype)) {
1989-
return AK_CPL_to_array_via_cast(cpl, dtype); // no tsep, decc as using NumPy cast
1999+
// no tsep, decc as using NumPy cast
2000+
return AK_CPL_to_array_via_cast(cpl, dtype, NPY_STRING);
19902001
}
19912002

19922003
PyErr_Format(PyExc_NotImplementedError, "No handling for %R", dtype);

test/test_delimited_to_arrays.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -370,10 +370,14 @@ def test_iterable_str_to_array_1d_complex_4(self) -> None:
370370
self.assertEqual(a1.dtype, np.dtype(complex))
371371
self.assertEqual(a1.tolist(), [complex('-0+infj'), (0j)])
372372

373-
# NOTE: this causes a seg fault
374-
# def test_iterable_str_to_array_1d_d4(self) -> None:
375-
# with self.assertRaises(ValueError):
376-
# a1 = iterable_str_to_array_1d(['-2+1.2j', '1.5+-4.2j'], complex)
373+
def test_iterable_str_to_array_1d_complex_5(self) -> None:
374+
with self.assertRaises(ValueError):
375+
a1 = iterable_str_to_array_1d(['-2+1.2j', '1.5+-4.2j'], complex)
376+
377+
def test_iterable_str_to_array_1d_complex_6(self) -> None:
378+
# NOTE: malformed complex values raise an exception as expected
379+
with self.assertRaises(ValueError):
380+
a1 = iterable_str_to_array_1d(['-2+1.2asdfj', '1.5wer4.2j'], complex)
377381

378382
#---------------------------------------------------------------------------
379383

@@ -389,12 +393,16 @@ def test_iterable_str_to_array_1d_dt64_2(self) -> None:
389393
self.assertFalse(a1.flags.writeable)
390394
self.assertEqual(a1.tolist(), [datetime.date(2020, 1, 1), datetime.date(2020, 2, 1)])
391395

392-
def test_iterable_str_to_array_1d_dt64_2(self) -> None:
396+
def test_iterable_str_to_array_1d_dt64_3(self) -> None:
393397
a1 = iterable_str_to_array_1d(['2020-01-01', '2020-02-01'], np.datetime64)
394398
self.assertEqual(a1.dtype, np.dtype('<M8[D]'))
395399
self.assertFalse(a1.flags.writeable)
396400
self.assertEqual(a1.tolist(), [datetime.date(2020, 1, 1), datetime.date(2020, 2, 1)])
397401

402+
def test_iterable_str_to_array_1d_dt64_4(self) -> None:
403+
with self.assertRaises(ValueError):
404+
_ = iterable_str_to_array_1d(['202.30', '202.20'], 'datetime64[D]')
405+
398406
#---------------------------------------------------------------------------
399407

400408
def test_iterable_str_to_array_1d_parse_1(self) -> None:

0 commit comments

Comments
 (0)