Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-vignal committed Oct 10, 2024
1 parent 3afcb93 commit 3b74aae
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
12 changes: 8 additions & 4 deletions shapash/explainer/smart_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,6 @@ def contribution_plot(
--------
>>> xpl.plot.contribution_plot(0)
"""

if self._explainer._case == "classification":
label_num, _, label_value = self._explainer.check_label_name(label)

Expand Down Expand Up @@ -505,8 +504,13 @@ def contribution_plot(
else:
feature_values = self._explainer.x_init.loc[list_ind, col_name]

if self.explainer.x_init[col_name].dtype == "bool":
feature_values = feature_values.astype(int)
if isinstance(col_name, list):
for el in col_name:
if feature_values[el].dtype == "bool":
feature_values[el] = feature_values[el].astype(int)
else:
if feature_values.dtype == "bool":
feature_values = feature_values.astype(int)

if col_is_group:
feature_values = project_feature_values_1d(
Expand Down Expand Up @@ -1130,7 +1134,7 @@ def interactions_plot(
interaction_values = interaction_values * 2

# add break line to X label if necessary
max_len_by_row = max([round(50 / self.explainer.features_desc[feature_values1.columns.values[0]]), 8])
max_len_by_row = max([round(50 / self._explainer.features_desc[feature_values1.columns.values[0]]), 8])
args = (max_len_by_row, 120)
feature_values_str = feature_values1.iloc[:, 0].apply(add_line_break, args=args)
feature_values1 = pd.DataFrame({feature_values1.columns[0]: feature_values_str})
Expand Down
6 changes: 3 additions & 3 deletions tests/unit_tests/explainer/test_smart_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1860,11 +1860,11 @@ def test_interactions_plot_3(self):

output = smart_explainer.plot.interactions_plot(col2, col1, violin_maxf=0)

assert np.array_equal(output.data[0].x, [34.0])
assert np.array_equal(output.data[0].x, ["34.0"])
assert np.array_equal(output.data[0].y, [-1.4])
assert output.data[0].name == "PhD"

assert np.array_equal(output.data[1].x, [27.0])
assert np.array_equal(output.data[1].x, ["27.0"])
assert np.array_equal(output.data[1].y, [-0.2])
assert output.data[1].name == "Master"

Expand Down Expand Up @@ -1893,7 +1893,7 @@ def test_interactions_plot_4(self):

output = smart_explainer.plot.interactions_plot(col1, col2, violin_maxf=0)

assert np.array_equal(output.data[0].x, [520, 12800])
assert np.array_equal(output.data[0].x, ["520.0", "12800.0"])
assert np.array_equal(output.data[0].y, [-1.4, -0.2])
assert np.array_equal(output.data[0].marker.color, [34.0, 27.0])

Expand Down

0 comments on commit 3b74aae

Please sign in to comment.