Skip to content

Commit

Permalink
Plotting improvements - correlation and permutation importance
Browse files Browse the repository at this point in the history
  • Loading branch information
adityasevak123ga committed Dec 1, 2023
1 parent d0e5054 commit 3def600
Showing 1 changed file with 21 additions and 15 deletions.
36 changes: 21 additions & 15 deletions uncoverml/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,11 +680,13 @@ def plot_feature_correlation_matrix(config: Config, x_all):
corr_df = pd.DataFrame(x_all)
corr_df.columns = features
sns.heatmap(corr_df.corr(),
vmin=-1, vmax=1, annot=True,
vmin=-1, vmax=1, annot=False,
square=True, linewidths=.5, cbar_kws={"shrink": .5},
cmap='BrBG',
cmap='BrBG', xticklabels=True, yticklabels=True
)
fig.suptitle('Feature Correlations')
plt.xticks(fontsize=5)
plt.yticks(fontsize=5)
fig.tight_layout()
save_path = Path(config.output_dir).joinpath(config.name + "_feature_correlation.png") \
.as_posix()
Expand Down Expand Up @@ -772,20 +774,24 @@ def plot_permutation_feature_importance(model, x_all, targets_all, conf: Config,
score)).as_posix()
df_picv.to_csv(csv, index=False)

x = np.arange(len(df_picv.index))
width = 0.35
fig, ax = plt.subplots()
ax.barh(x - width / 2, df_picv['weight'].values, width, label='Weight')
ax.barh(x + width / 2, df_picv['std'].values, width, label='Std')
ax.set_ylabel('Covariate')
ax.set_title('Permutation Feature Importance Weight and Std')
ax.set_xticks(x)
num_cov = np.arange(len(feature_names))
ax.set_yticks(num_cov)
ax.set_yticklabels(feature_names)
ax.set_xlabel('Score')
ax.legend()
# x = np.arange(len(df_picv.index))
# width = 0.35
# fig, ax = plt.subplots()
# ax.barh(x - width / 2, df_picv['weight'].values, width, label='Weight')
# ax.barh(x + width / 2, df_picv['std'].values, width, label='Std')
# ax.set_ylabel('Covariate')
# ax.set_title('Permutation Feature Importance Weight and Std')
# ax.set_xticks(x)
# num_cov = np.arange(len(feature_names))
# ax.set_yticks(num_cov)
# ax.set_yticklabels(feature_names)
# ax.set_xlabel('Score')
# ax.legend()

fig, ax = plt.subplots()
sns.barplot(data=df_picv, x='weight', y='feature', orient='h')
fig.suptitle('Permutation Importance')
plt.yticks(fontsize=5)
fig.tight_layout()
save_path = Path(conf.output_dir).joinpath(conf.name + "_feature_importance_bars_{}.png".format(score))\
.as_posix()
Expand Down

0 comments on commit 3def600

Please sign in to comment.