Skip to content

Commit

Permalink
fix feature warning column transform in a dataframe
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-vignal committed Oct 8, 2024
1 parent e972df3 commit 1c1a2bb
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 11 deletions.
12 changes: 4 additions & 8 deletions shapash/explainer/smart_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1127,14 +1127,10 @@ def interactions_plot(
interaction_values = interaction_values * 2

# add break line to X label if necessary
max_len_by_row = max([round(50 / self._explainer.features_desc[feature_values1.columns.values[0]]), 8])
feature_values1.iloc[:, 0] = feature_values1.iloc[:, 0].apply(
add_line_break,
args=(
max_len_by_row,
120,
),
)
max_len_by_row = max([round(50 / self.explainer.features_desc[feature_values1.columns.values[0]]), 8])
args = (max_len_by_row, 120)
feature_values_str = feature_values1.iloc[:, 0].apply(add_line_break, args=args)
feature_values1 = pd.DataFrame({feature_values1.columns[0]: feature_values_str})

# selecting the best plot : Scatter, Violin?
if col_value_count1 > violin_maxf:
Expand Down
6 changes: 4 additions & 2 deletions shapash/plots/plot_contribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ def plot_scatter(

# add break line to X label if necessary
args = (max_len_by_row, 120)
feature_values.iloc[:, 0] = feature_values.iloc[:, 0].apply(add_line_break, args=args)
feature_values_str = feature_values.iloc[:, 0].apply(add_line_break, args=args)
feature_values = pd.DataFrame({column_name: feature_values_str})

if pred is not None:
hv_text = [f"Id: {x}<br />Predict: {y}" for x, y in zip(feature_values.index, pred.values.flatten())]
Expand Down Expand Up @@ -270,7 +271,8 @@ def plot_violin(

# add break line to X label if necessary
args = (max_len_by_row, 120)
feature_values.iloc[:, 0] = feature_values.iloc[:, 0].apply(add_line_break, args=args)
feature_values_str = feature_values.iloc[:, 0].apply(add_line_break, args=args)
feature_values = pd.DataFrame({column_name: feature_values_str})

contributions = contributions.loc[feature_values.index]
if pred is not None:
Expand Down
2 changes: 1 addition & 1 deletion shapash/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def add_line_break(text, nbchar, maxlen=150):
new_string = "".join(sum(zip(input_word, final_sep + [""]), ())[:-1]) + last_char
return new_string
else:
return text
return str(text)


def truncate_str(text, maxlen=40):
Expand Down

0 comments on commit 1c1a2bb

Please sign in to comment.