Skip to content

Commit

Permalink
black formatting fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
spohngellert-o committed Jul 19, 2024
1 parent b05117e commit b049dea
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 8 deletions.
6 changes: 3 additions & 3 deletions causalml/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class NearestNeighborMatch:
Attributes:
caliper (float): threshold to be considered as a match.
replace (bool): whether to match with replacement or not
ratio (int): ratio of control / treatment to be matched.
ratio (int): ratio of control / treatment to be matched.
shuffle (bool): whether to shuffle the treatment group data before
matching
random_state (numpy.random.RandomState or int): RandomState or an int
Expand All @@ -111,7 +111,7 @@ def __init__(
Args:
caliper (float): threshold to be considered as a match.
replace (bool): whether to match with replacement or not
ratio (int): ratio of control / treatment to be matched.
ratio (int): ratio of control / treatment to be matched.
shuffle (bool): whether to shuffle the treatment group data before
matching or not
random_state (numpy.random.RandomState or int): RandomState or an
Expand Down Expand Up @@ -201,7 +201,7 @@ def match(self, data, treatment_col, score_cols):
- treatment.loc[t_idx, score_col]
)
# Gets self.ratio lowest dists
c_np_idx_list = np.argpartition(dist, self.ratio)[:self.ratio]
c_np_idx_list = np.argpartition(dist, self.ratio)[: self.ratio]
c_idx_list = dist.index[c_np_idx_list]
for i, c_idx in enumerate(c_idx_list):
if dist[c_idx] <= sdcal:
Expand Down
7 changes: 2 additions & 5 deletions tests/test_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,10 @@ def test_nearest_neighbor_match_ratio_2(generate_unmatched_data):
df, features = generate_unmatched_data()

psm = NearestNeighborMatch(replace=False, ratio=2, random_state=RANDOM_SEED)
matched = psm.match(
data=df,
treatment_col=TREATMENT_COL,
score_cols=[SCORE_COL]
)
matched = psm.match(data=df, treatment_col=TREATMENT_COL, score_cols=[SCORE_COL])
assert sum(matched[TREATMENT_COL] == 0) == 2 * sum(matched[TREATMENT_COL] != 0)


def test_nearest_neighbor_match_by_group(generate_unmatched_data):
df, features = generate_unmatched_data()

Expand Down

0 comments on commit b049dea

Please sign in to comment.