Skip to content

Commit

Permalink
Not displaying interaction plot by default;limit plotly version >6.0.0
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-vignal committed Feb 11, 2025
1 parent daedb63 commit 96fa8f3
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 36 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "shapash"
version = "2.7.6"
version = "2.7.7"
authors = [
{name = "Yann Golhen"},
{name = "Sebastien Bidault"},
Expand All @@ -29,7 +29,7 @@ classifiers = [
"Operating System :: OS Independent",
]
dependencies = [
"plotly>=5.0.0",
"plotly>=5.0.0,<6.0.0",
"matplotlib>=3.2.0",
"numpy>1.18.0,<2",
"pandas>=2.1.0",
Expand Down
5 changes: 5 additions & 0 deletions shapash/explainer/smart_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1206,6 +1206,7 @@ def generate_report(
notebook_path=None,
kernel_name=None,
max_points=200,
display_interaction_plot=False,
nb_top_interactions=5,
):
"""
Expand Down Expand Up @@ -1251,6 +1252,9 @@ def generate_report(
by default.
max_points : int, optional
number of maximum points in the contribution plot
display_interaction_plot: bool, optional
Whether to display the interaction plot. This can be computationally expensive,
so it is set to False by default to optimize performance.
nb_top_interactions : int
Number of top interactions to display.
Examples
Expand Down Expand Up @@ -1305,6 +1309,7 @@ def generate_report(
title_description=title_description,
metrics=metrics,
max_points=max_points,
display_interaction_plot=display_interaction_plot,
nb_top_interactions=nb_top_interactions,
),
notebook_path=notebook_path,
Expand Down
74 changes: 40 additions & 34 deletions shapash/report/project_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,11 @@ def __init__(
else:
self.max_points = 200

if "display_interaction_plot" in self.config.keys():
self.display_interaction_plot = config["display_interaction_plot"]
else:
self.display_interaction_plot = False

if "nb_top_interactions" in self.config.keys():
self.nb_top_interactions = config["nb_top_interactions"]
else:
Expand Down Expand Up @@ -427,46 +432,47 @@ def display_model_explainability(self):
)

# Interaction Plot
explain_contrib_data_interaction = list()
list_ind, _ = self.explainer.plot._select_indices_interactions_plot(
selection=None, max_points=self.max_points
)
interaction_values = self.explainer.get_interaction_values(selection=list_ind)
sorted_top_features_indices = compute_sorted_variables_interactions_list_indices(interaction_values)
indices_to_plot = sorted_top_features_indices[: self.nb_top_interactions]
if self.display_interaction_plot:
explain_contrib_data_interaction = list()
list_ind, _ = self.explainer.plot._select_indices_interactions_plot(
selection=None, max_points=self.max_points
)
interaction_values = self.explainer.get_interaction_values(selection=list_ind)
sorted_top_features_indices = compute_sorted_variables_interactions_list_indices(interaction_values)
indices_to_plot = sorted_top_features_indices[: self.nb_top_interactions]

for i, ids in enumerate(indices_to_plot):
id0, id1 = ids
for i, ids in enumerate(indices_to_plot):
id0, id1 = ids

fig_one_interaction = self.explainer.plot.interactions_plot(
col1=self.explainer.columns_dict[id0],
col2=self.explainer.columns_dict[id1],
max_points=self.max_points,
)
fig_one_interaction = self.explainer.plot.interactions_plot(
col1=self.explainer.columns_dict[id0],
col2=self.explainer.columns_dict[id1],
max_points=self.max_points,
)

explain_contrib_data_interaction.append(
explain_contrib_data_interaction.append(
{
"feature_index": i,
"name": self.explainer.columns_dict[id0] + " / " + self.explainer.columns_dict[id1],
"description": self.explainer.features_dict[self.explainer.columns_dict[id0]]
+ " / "
+ self.explainer.features_dict[self.explainer.columns_dict[id1]],
"plot": plotly.io.to_html(fig_one_interaction, include_plotlyjs=False, full_html=False),
}
)

# Aggregating the data
explain_data.append(
{
"feature_index": i,
"name": self.explainer.columns_dict[id0] + " / " + self.explainer.columns_dict[id1],
"description": self.explainer.features_dict[self.explainer.columns_dict[id0]]
+ " / "
+ self.explainer.features_dict[self.explainer.columns_dict[id1]],
"plot": plotly.io.to_html(fig_one_interaction, include_plotlyjs=False, full_html=False),
"index": index_label,
"name": label_value,
"feature_importance_plot": plotly.io.to_html(
fig_features_importance, include_plotlyjs=False, full_html=False
),
"features": explain_contrib_data,
"features_interaction": explain_contrib_data_interaction,
}
)

# Aggregating the data
explain_data.append(
{
"index": index_label,
"name": label_value,
"feature_importance_plot": plotly.io.to_html(
fig_features_importance, include_plotlyjs=False, full_html=False
),
"features": explain_contrib_data,
"features_interaction": explain_contrib_data_interaction,
}
)
print_html(explainability_template.render(labels=explain_data))
print_md("---")

Expand Down

0 comments on commit 96fa8f3

Please sign in to comment.