Skip to content

Commit

Permalink
Merge pull request numpy#28228 from ngoldbaum/float-string-nan-cast
Browse files Browse the repository at this point in the history
BUG: handle case when StringDType na_object is nan in float to string cast
  • Loading branch information
ngoldbaum authored Jan 27, 2025
2 parents 32a6b53 + d33cea1 commit cb8b623
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 1 deletion.
21 changes: 21 additions & 0 deletions numpy/_core/src/multiarray/stringdtype/casts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1208,14 +1208,35 @@ float_to_string(
PyArray_StringDTypeObject *descr =
(PyArray_StringDTypeObject *)context->descriptors[1];
npy_string_allocator *allocator = NpyString_acquire_allocator(descr);
// borrowed reference
PyObject *na_object = descr->na_object;

while (N--) {
PyObject *scalar_val = PyArray_Scalar(in, float_descr, NULL);
if (descr->has_nan_na) {
// check for case when scalar_val is the na_object and store a null string
int na_cmp = na_eq_cmp(scalar_val, na_object);
if (na_cmp < 0) {
Py_DECREF(scalar_val);
goto fail;
}
if (na_cmp) {
Py_DECREF(scalar_val);
if (NpyString_pack_null(allocator, (npy_packed_static_string *)out) < 0) {
PyErr_SetString(PyExc_MemoryError,
"Failed to pack null string during float "
"to string cast");
goto fail;
}
goto next_step;
}
}
// steals reference to scalar_val
if (pyobj_to_string(scalar_val, out, allocator) == -1) {
goto fail;
}

next_step:
in += in_stride;
out += out_stride;
}
Expand Down
2 changes: 1 addition & 1 deletion numpy/_core/src/multiarray/stringdtype/dtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ new_stringdtype_instance(PyObject *na_object, int coerce)
return NULL;
}

static int
NPY_NO_EXPORT int
na_eq_cmp(PyObject *a, PyObject *b) {
if (a == b) {
// catches None and other singletons like Pandas.NA
Expand Down
3 changes: 3 additions & 0 deletions numpy/_core/src/multiarray/stringdtype/dtype.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ _eq_comparison(int scoerce, int ocoerce, PyObject *sna, PyObject *ona);
NPY_NO_EXPORT int
stringdtype_compatible_na(PyObject *na1, PyObject *na2, PyObject **out_na);

NPY_NO_EXPORT int
na_eq_cmp(PyObject *a, PyObject *b);

#ifdef __cplusplus
}
#endif
Expand Down
15 changes: 15 additions & 0 deletions numpy/_core/tests/test_stringdtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,21 @@ def test_float_casts(typename):
assert_array_equal(eres, res)


def test_float_nan_cast_na_object():
# gh-28157
dt = np.dtypes.StringDType(na_object=np.nan)
arr1 = np.full((1,), fill_value=np.nan, dtype=dt)
arr2 = np.full_like(arr1, fill_value=np.nan)

assert arr1.item() is np.nan
assert arr2.item() is np.nan

inp = [1.2, 2.3, np.nan]
arr = np.array(inp).astype(dt)
assert arr[2] is np.nan
assert arr[0] == '1.2'


@pytest.mark.parametrize(
"typename",
[
Expand Down

0 comments on commit cb8b623

Please sign in to comment.