diff --git a/causalml/match.py b/causalml/match.py index b076a5f4..91cc64c0 100644 --- a/causalml/match.py +++ b/causalml/match.py @@ -88,8 +88,7 @@ class NearestNeighborMatch: Attributes: caliper (float): threshold to be considered as a match. replace (bool): whether to match with replacement or not - ratio (int): ratio of control / treatment to be matched. used only if - replace=True. + ratio (int): ratio of control / treatment to be matched. shuffle (bool): whether to shuffle the treatment group data before matching random_state (numpy.random.RandomState or int): RandomState or an int @@ -112,6 +111,7 @@ def __init__( Args: caliper (float): threshold to be considered as a match. replace (bool): whether to match with replacement or not + ratio (int): ratio of control / treatment to be matched. shuffle (bool): whether to shuffle the treatment group data before matching or not random_state (numpy.random.RandomState or int): RandomState or an @@ -200,11 +200,15 @@ def match(self, data, treatment_col, score_cols): control.loc[control.unmatched, score_col] - treatment.loc[t_idx, score_col] ) - c_idx_min = dist.idxmin() - if dist[c_idx_min] <= sdcal: - t_idx_matched.append(t_idx) - c_idx_matched.append(c_idx_min) - control.loc[c_idx_min, "unmatched"] = False + # Gets self.ratio lowest dists + c_np_idx_list = np.argpartition(dist, self.ratio)[: self.ratio] + c_idx_list = dist.index[c_np_idx_list] + for i, c_idx in enumerate(c_idx_list): + if dist[c_idx] <= sdcal: + if i == 0: + t_idx_matched.append(t_idx) + c_idx_matched.append(c_idx) + control.loc[c_idx, "unmatched"] = False return data.loc[ np.concatenate([np.array(t_idx_matched), np.array(c_idx_matched)]) diff --git a/tests/test_match.py b/tests/test_match.py index 9df30211..4933ee10 100644 --- a/tests/test_match.py +++ b/tests/test_match.py @@ -34,10 +34,18 @@ def _generate_data(): yield _generate_data +def test_nearest_neighbor_match_ratio_2(generate_unmatched_data): + df, features = generate_unmatched_data() + + psm = NearestNeighborMatch(replace=False, ratio=2, random_state=RANDOM_SEED) + matched = psm.match(data=df, treatment_col=TREATMENT_COL, score_cols=[SCORE_COL]) + assert sum(matched[TREATMENT_COL] == 0) == 2 * sum(matched[TREATMENT_COL] != 0) + + def test_nearest_neighbor_match_by_group(generate_unmatched_data): df, features = generate_unmatched_data() - psm = NearestNeighborMatch(replace=False, ratio=1.0, random_state=RANDOM_SEED) + psm = NearestNeighborMatch(replace=False, ratio=1, random_state=RANDOM_SEED) matched = psm.match_by_group( data=df,