Skip to content

Commit 0285efb

Browse files
committed
Finish up
1 parent cf7a325 commit 0285efb

File tree

3 files changed

+48
-42
lines changed

3 files changed

+48
-42
lines changed

pandas/core/frame.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,8 +339,7 @@
339339
* inner: use intersection of keys from both frames, similar to a SQL inner
340340
join; preserve the order of the left keys.
341341
* left_semi: Filter for rows in the left that have a match on the right;
342-
preserve the order of the left keys. Doesn't support `left_index`, `right_index`,
343-
`indicator` or `validate`.
342+
preserve the order of the left keys.
344343
345344
.. versionadded:: 3.0
346345
* cross: creates the cartesian product from both frames, preserves the order

pandas/core/reshape/merge.py

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,6 +1054,7 @@ def _validate_how(
10541054
"right",
10551055
"inner",
10561056
"outer",
1057+
"left_semi",
10571058
"left_anti",
10581059
"right_anti",
10591060
"cross",
@@ -1403,7 +1404,11 @@ def _get_join_info(
14031404
left_ax = self.left.index
14041405
right_ax = self.right.index
14051406

1406-
if self.left_index and self.right_index and self.how != "asof":
1407+
if (
1408+
self.left_index
1409+
and self.right_index
1410+
and self.how not in ("asof", "left_semi")
1411+
):
14071412
join_index, left_indexer, right_indexer = left_ax.join(
14081413
right_ax, how=self.how, return_indexers=True, sort=self.sort
14091414
)
@@ -1647,15 +1652,7 @@ def _get_merge_keys(
16471652
k = cast(Hashable, k)
16481653
left_keys.append(left._get_label_or_level_values(k))
16491654
join_names.append(k)
1650-
if isinstance(self.right.index, MultiIndex):
1651-
right_keys = [
1652-
lev._values.take(lev_codes)
1653-
for lev, lev_codes in zip(
1654-
self.right.index.levels, self.right.index.codes
1655-
)
1656-
]
1657-
else:
1658-
right_keys = [self.right.index._values]
1655+
right_keys = self._unpack_index_as_join_key(self.right.index)
16591656
elif _any(self.right_on):
16601657
for k in self.right_on:
16611658
k = extract_array(k, extract_numpy=True)
@@ -1669,18 +1666,23 @@ def _get_merge_keys(
16691666
k = cast(Hashable, k)
16701667
right_keys.append(right._get_label_or_level_values(k))
16711668
join_names.append(k)
1672-
if isinstance(self.left.index, MultiIndex):
1673-
left_keys = [
1674-
lev._values.take(lev_codes)
1675-
for lev, lev_codes in zip(
1676-
self.left.index.levels, self.left.index.codes
1677-
)
1678-
]
1679-
else:
1680-
left_keys = [self.left.index._values]
1669+
left_keys = self._unpack_index_as_join_key(self.left.index)
1670+
elif self.how == "left_semi":
1671+
left_keys = self._unpack_index_as_join_key(self.left.index)
1672+
right_keys = self._unpack_index_as_join_key(self.right.index)
16811673

16821674
return left_keys, right_keys, join_names, left_drop, right_drop
16831675

1676+
def _unpack_index_as_join_key(self, index: Index) -> list[ArrayLike]:
1677+
if isinstance(index, MultiIndex):
1678+
keys = [
1679+
lev._values.take(lev_codes)
1680+
for lev, lev_codes in zip(index.levels, index.codes)
1681+
]
1682+
else:
1683+
keys = [index._values]
1684+
return keys
1685+
16841686
@final
16851687
def _maybe_coerce_merge_keys(self) -> None:
16861688
# we have valid merges but we may have to further
@@ -2241,15 +2243,8 @@ def _convert_to_multiindex(index: Index) -> MultiIndex:
22412243

22422244
class _SemiMergeOperation(_MergeOperation):
22432245
def __init__(self, *args, **kwargs):
2244-
if kwargs.get("validate", None):
2245-
raise NotImplementedError("validate is not supported for semi-join.")
2246-
22472246
super().__init__(*args, **kwargs)
2248-
if self.left_index or self.right_index:
2249-
raise NotImplementedError(
2250-
"left_index or right_index are not supported for semi-join."
2251-
)
2252-
elif self.indicator:
2247+
if self.indicator:
22532248
raise NotImplementedError("indicator is not supported for semi-join.")
22542249
elif self.sort:
22552250
raise NotImplementedError(
@@ -2273,7 +2268,7 @@ def _reindex_and_concat(
22732268
left_indexer: npt.NDArray[np.intp] | None,
22742269
right_indexer: npt.NDArray[np.intp] | None,
22752270
) -> DataFrame:
2276-
left = self.left[:]
2271+
left = self.left
22772272

22782273
if left_indexer is not None and not is_range_indexer(left_indexer, len(left)):
22792274
lmgr = left._mgr.take(left_indexer, axis=1, verify=False)
@@ -2956,7 +2951,7 @@ def _factorize_keys(
29562951
lk_data, rk_data = lk, rk # type: ignore[assignment]
29572952
lk_mask, rk_mask = None, None
29582953

2959-
hash_join_available = how == "inner" and not sort
2954+
hash_join_available = how == "inner" and not sort and lk.dtype.kind in "iufbO"
29602955
if hash_join_available:
29612956
rlab = rizer.factorize(rk_data, mask=rk_mask)
29622957
if rizer.get_count() == len(rlab):

pandas/tests/reshape/merge/test_semi.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,21 @@
1010
"vals_left, vals_right, dtype",
1111
[
1212
([1, 2, 3], [1, 2], "int64"),
13+
([1.5, 2.5, 3.5], [1.5, 2.5], "float64"),
14+
([True, True, False], [True, True], "bool"),
1315
(["a", "b", "c"], ["a", "b"], "object"),
1416
pytest.param(
1517
["a", "b", "c"],
1618
["a", "b"],
1719
"string[pyarrow]",
1820
marks=td.skip_if_no("pyarrow"),
1921
),
22+
pytest.param(
23+
["a", "b", "c"],
24+
["a", "b"],
25+
"str",
26+
marks=td.skip_if_no("pyarrow"),
27+
),
2028
],
2129
)
2230
def test_left_semi(vals_left, vals_right, dtype):
@@ -28,6 +36,21 @@ def test_left_semi(vals_left, vals_right, dtype):
2836
result = left.merge(right, how="left_semi")
2937
tm.assert_frame_equal(result, expected)
3038

39+
result = left.set_index("a").merge(
40+
right.set_index("a"), how="left_semi", left_index=True, right_index=True
41+
)
42+
tm.assert_frame_equal(result, expected.set_index("a"))
43+
44+
result = left.set_index("a").merge(
45+
right, how="left_semi", left_index=True, right_on="a"
46+
)
47+
tm.assert_frame_equal(result, expected.set_index("a"))
48+
49+
result = left.merge(
50+
right.set_index("a"), how="left_semi", right_index=True, left_on="a"
51+
)
52+
tm.assert_frame_equal(result, expected)
53+
3154
right = pd.DataFrame({"d": vals_right, "c": 1})
3255
result = left.merge(right, how="left_semi", left_on="a", right_on="d")
3356
tm.assert_frame_equal(result, expected)
@@ -40,17 +63,6 @@ def test_left_semi(vals_left, vals_right, dtype):
4063
def test_left_semi_invalid():
4164
left = pd.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]})
4265
right = pd.DataFrame({"a": [1, 2], "c": 1})
43-
44-
msg = "left_index or right_index are not supported for semi-join."
45-
with pytest.raises(NotImplementedError, match=msg):
46-
left.merge(right, how="left_semi", left_index=True, right_on="a")
47-
with pytest.raises(NotImplementedError, match=msg):
48-
left.merge(right, how="left_semi", right_index=True, left_on="a")
49-
50-
msg = "validate is not supported for semi-join."
51-
with pytest.raises(NotImplementedError, match=msg):
52-
left.merge(right, how="left_semi", validate="one_to_one")
53-
5466
msg = "indicator is not supported for semi-join."
5567
with pytest.raises(NotImplementedError, match=msg):
5668
left.merge(right, how="left_semi", indicator=True)

0 commit comments

Comments
 (0)