@@ -1054,6 +1054,7 @@ def _validate_how(
10541054 "right" ,
10551055 "inner" ,
10561056 "outer" ,
1057+ "left_semi" ,
10571058 "left_anti" ,
10581059 "right_anti" ,
10591060 "cross" ,
@@ -1403,7 +1404,11 @@ def _get_join_info(
14031404 left_ax = self .left .index
14041405 right_ax = self .right .index
14051406
1406- if self .left_index and self .right_index and self .how != "asof" :
1407+ if (
1408+ self .left_index
1409+ and self .right_index
1410+ and self .how not in ("asof" , "left_semi" )
1411+ ):
14071412 join_index , left_indexer , right_indexer = left_ax .join (
14081413 right_ax , how = self .how , return_indexers = True , sort = self .sort
14091414 )
@@ -1647,15 +1652,7 @@ def _get_merge_keys(
16471652 k = cast (Hashable , k )
16481653 left_keys .append (left ._get_label_or_level_values (k ))
16491654 join_names .append (k )
1650- if isinstance (self .right .index , MultiIndex ):
1651- right_keys = [
1652- lev ._values .take (lev_codes )
1653- for lev , lev_codes in zip (
1654- self .right .index .levels , self .right .index .codes
1655- )
1656- ]
1657- else :
1658- right_keys = [self .right .index ._values ]
1655+ right_keys = self ._unpack_index_as_join_key (self .right .index )
16591656 elif _any (self .right_on ):
16601657 for k in self .right_on :
16611658 k = extract_array (k , extract_numpy = True )
@@ -1669,18 +1666,23 @@ def _get_merge_keys(
16691666 k = cast (Hashable , k )
16701667 right_keys .append (right ._get_label_or_level_values (k ))
16711668 join_names .append (k )
1672- if isinstance (self .left .index , MultiIndex ):
1673- left_keys = [
1674- lev ._values .take (lev_codes )
1675- for lev , lev_codes in zip (
1676- self .left .index .levels , self .left .index .codes
1677- )
1678- ]
1679- else :
1680- left_keys = [self .left .index ._values ]
1669+ left_keys = self ._unpack_index_as_join_key (self .left .index )
1670+ elif self .how == "left_semi" :
1671+ left_keys = self ._unpack_index_as_join_key (self .left .index )
1672+ right_keys = self ._unpack_index_as_join_key (self .right .index )
16811673
16821674 return left_keys , right_keys , join_names , left_drop , right_drop
16831675
def _unpack_index_as_join_key(self, index: Index) -> list[ArrayLike]:
    """
    Expand an index into the list of key arrays used for joining.

    Parameters
    ----------
    index : Index
        The index to join on.  May be a flat ``Index`` or a ``MultiIndex``.

    Returns
    -------
    list of ArrayLike
        One array per index level, each with ``len(index)`` elements.  A
        flat index yields a single-element list holding its values.

    Notes
    -----
    A ``MultiIndex`` stores missing entries as code ``-1``.  Taking the
    level values positionally with those codes would wrap ``-1`` around
    to the *last* level value, turning a missing key into a real one and
    matching the wrong rows.  The levels are therefore materialised via
    ``MultiIndex._get_level_values``, which fills ``-1`` with the level's
    NA value.  When no code is ``-1`` the values and dtype are the same
    as a plain positional take.
    """
    if not isinstance(index, MultiIndex):
        return [index._values]

    # Address levels by position (not by name) so that integer level names
    # cannot be confused with level numbers.
    return [
        index._get_level_values(level_number)._values
        for level_number in range(index.nlevels)
    ]
1685+
16841686 @final
16851687 def _maybe_coerce_merge_keys (self ) -> None :
16861688 # we have valid merges but we may have to further
@@ -2241,15 +2243,8 @@ def _convert_to_multiindex(index: Index) -> MultiIndex:
22412243
22422244class _SemiMergeOperation (_MergeOperation ):
22432245 def __init__ (self , * args , ** kwargs ):
2244- if kwargs .get ("validate" , None ):
2245- raise NotImplementedError ("validate is not supported for semi-join." )
2246-
22472246 super ().__init__ (* args , ** kwargs )
2248- if self .left_index or self .right_index :
2249- raise NotImplementedError (
2250- "left_index or right_index are not supported for semi-join."
2251- )
2252- elif self .indicator :
2247+ if self .indicator :
22532248 raise NotImplementedError ("indicator is not supported for semi-join." )
22542249 elif self .sort :
22552250 raise NotImplementedError (
@@ -2273,7 +2268,7 @@ def _reindex_and_concat(
22732268 left_indexer : npt .NDArray [np .intp ] | None ,
22742269 right_indexer : npt .NDArray [np .intp ] | None ,
22752270 ) -> DataFrame :
2276- left = self .left [:]
2271+ left = self .left
22772272
22782273 if left_indexer is not None and not is_range_indexer (left_indexer , len (left )):
22792274 lmgr = left ._mgr .take (left_indexer , axis = 1 , verify = False )
@@ -2956,7 +2951,7 @@ def _factorize_keys(
29562951 lk_data , rk_data = lk , rk # type: ignore[assignment]
29572952 lk_mask , rk_mask = None , None
29582953
2959- hash_join_available = how == "inner" and not sort
2954+ hash_join_available = how == "inner" and not sort and lk . dtype . kind in "iufbO"
29602955 if hash_join_available :
29612956 rlab = rizer .factorize (rk_data , mask = rk_mask )
29622957 if rizer .get_count () == len (rlab ):
0 commit comments