Skip to content

Commit 4f4b108

Browse files
authored
BUG: fix error when unstacking with sort=False when the index contains NA (#62334)
1 parent 0c87e2d commit 4f4b108

File tree

3 files changed

+66
-8
lines changed

3 files changed

+66
-8
lines changed

doc/source/whatsnew/v3.0.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,6 +1083,7 @@ Reshaping
10831083
- Bug in :meth:`DataFrame.join` when a :class:`DataFrame` with a :class:`MultiIndex` would raise an ``AssertionError`` when :attr:`MultiIndex.names` contained ``None``. (:issue:`58721`)
10841084
- Bug in :meth:`DataFrame.merge` where merging on a column containing only ``NaN`` values resulted in an out-of-bounds array access (:issue:`59421`)
10851085
- Bug in :meth:`DataFrame.unstack` producing incorrect results when ``sort=False`` (:issue:`54987`, :issue:`55516`)
1086+
- Bug in :meth:`DataFrame.unstack` raising an error when ``sort=False`` and the index contains ``NaN`` (:issue:`61221`)
10861087
- Bug in :meth:`DataFrame.merge` when merging two :class:`DataFrame` on ``intc`` or ``uintc`` types on Windows (:issue:`60091`, :issue:`58713`)
10871088
- Bug in :meth:`DataFrame.pivot_table` incorrectly subaggregating results when called without an ``index`` argument (:issue:`58722`)
10881089
- Bug in :meth:`DataFrame.pivot_table` incorrectly ignoring the ``values`` argument when also supplied to the ``index`` or ``columns`` parameters (:issue:`57876`, :issue:`61292`)

pandas/core/reshape/reshape.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,11 @@ def __init__(
128128

129129
self.level = self.index._get_level_number(level)
130130

131-
# when index includes `nan`, need to lift levels/strides by 1
132-
self.lift = 1 if -1 in self.index.codes[self.level] else 0
131+
# `nan` values have code `-1`; when sorting, we lift the codes by 1 so
132+
# that `nan` is assigned index 0
133+
self.has_nan = -1 in self.index.codes[self.level]
134+
should_lift = self.has_nan and self.sort
135+
self.lift = 1 if should_lift else 0
133136

134137
# Note: the "pop" below alters these in-place.
135138
self.new_index_levels = list(self.index.levels)
@@ -138,8 +141,16 @@ def __init__(
138141
self.removed_name = self.new_index_names.pop(self.level)
139142
self.removed_level = self.new_index_levels.pop(self.level)
140143
self.removed_level_full = index.levels[self.level]
144+
self.unique_nan_index: int = -1
141145
if not self.sort:
142-
unique_codes = unique(self.index.codes[self.level])
146+
unique_codes: np.ndarray = unique(self.index.codes[self.level])
147+
if self.has_nan:
148+
# drop nan codes, because they are not represented in level
149+
nan_mask = unique_codes == -1
150+
151+
unique_codes = unique_codes[~nan_mask]
152+
self.unique_nan_index = np.flatnonzero(nan_mask)[0]
153+
143154
self.removed_level = self.removed_level.take(unique_codes)
144155
self.removed_level_full = self.removed_level_full.take(unique_codes)
145156

@@ -210,7 +221,7 @@ def _make_selectors(self) -> None:
210221
ngroups = len(obs_ids)
211222

212223
comp_index = ensure_platform_int(comp_index)
213-
stride = self.index.levshape[self.level] + self.lift
224+
stride = self.index.levshape[self.level] + self.has_nan
214225
self.full_shape = ngroups, stride
215226

216227
selector = self.sorted_labels[-1] + stride * comp_index + self.lift
@@ -362,13 +373,13 @@ def get_new_values(self, values, fill_value=None):
362373

363374
def get_new_columns(self, value_columns: Index | None):
364375
if value_columns is None:
365-
if self.lift == 0:
376+
if not self.has_nan:
366377
return self.removed_level._rename(name=self.removed_name)
367378

368379
lev = self.removed_level.insert(0, item=self.removed_level._na_value)
369380
return lev.rename(self.removed_name)
370381

371-
stride = len(self.removed_level) + self.lift
382+
stride = len(self.removed_level) + self.has_nan
372383
width = len(value_columns)
373384
propagator = np.repeat(np.arange(width), stride)
374385

@@ -401,12 +412,21 @@ def _repeater(self) -> np.ndarray:
401412
if len(self.removed_level_full) != len(self.removed_level):
402413
# In this case, we remap the new codes to the original level:
403414
repeater = self.removed_level_full.get_indexer(self.removed_level)
404-
if self.lift:
415+
if self.has_nan:
416+
# insert nan index at first position
405417
repeater = np.insert(repeater, 0, -1)
406418
else:
407419
# Otherwise, we just use each level item exactly once:
408-
stride = len(self.removed_level) + self.lift
420+
stride = len(self.removed_level) + self.has_nan
409421
repeater = np.arange(stride) - self.lift
422+
if self.has_nan and not self.sort:
423+
assert self.unique_nan_index > -1, (
424+
"`unique_nan_index` not properly initialized"
425+
)
426+
# assign -1 at the position where nan appears among the unique codes.
427+
repeater[self.unique_nan_index] = -1
428+
# compensate for the removed index level
429+
repeater[self.unique_nan_index + 1 :] -= 1
410430

411431
return repeater
412432

pandas/tests/frame/test_stack_unstack.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1386,6 +1386,43 @@ def test_unstack_sort_false(frame_or_series, dtype):
13861386
tm.assert_frame_equal(result, expected)
13871387

13881388

1389+
@pytest.mark.parametrize(
1390+
"levels2, expected_columns",
1391+
[
1392+
(
1393+
[None, 1, 2, 3],
1394+
[("value", np.nan), ("value", 1), ("value", 2), ("value", 3)],
1395+
),
1396+
(
1397+
[1, None, 2, 3],
1398+
[("value", 1), ("value", np.nan), ("value", 2), ("value", 3)],
1399+
),
1400+
(
1401+
[1, 2, None, 3],
1402+
[("value", 1), ("value", 2), ("value", np.nan), ("value", 3)],
1403+
),
1404+
(
1405+
[1, 2, 3, None],
1406+
[("value", 1), ("value", 2), ("value", 3), ("value", np.nan)],
1407+
),
1408+
],
1409+
ids=["nan=first", "nan=second", "nan=third", "nan=last"],
1410+
)
1411+
def test_unstack_sort_false_nan(levels2, expected_columns):
1412+
# GH#61221
1413+
levels1 = ["b", "a"]
1414+
index = MultiIndex.from_product([levels1, levels2], names=["level1", "level2"])
1415+
df = DataFrame({"value": [0, 1, 2, 3, 4, 5, 6, 7]}, index=index)
1416+
result = df.unstack(level="level2", sort=False)
1417+
expected_data = [[0, 4], [1, 5], [2, 6], [3, 7]]
1418+
expected = DataFrame(
1419+
dict(zip(expected_columns, expected_data)),
1420+
index=Index(["b", "a"], name="level1"),
1421+
columns=MultiIndex.from_tuples(expected_columns, names=[None, "level2"]),
1422+
)
1423+
tm.assert_frame_equal(result, expected)
1424+
1425+
13891426
def test_unstack_fill_frame_object():
13901427
# GH12815 Test unstacking with object.
13911428
data = Series(["a", "b", "c", "a"], dtype="object")

0 commit comments

Comments
 (0)