Skip to content

Commit

Permalink
Merge pull request #19 from achamma723/multi_class
Browse files Browse the repository at this point in the history
[Fix]: BBI in multi_classification setting (2)
  • Loading branch information
achamma723 authored Jun 3, 2024
2 parents 84de98a + 3149c8e commit c8282c1
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions hidimstat/BBI.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,7 +833,8 @@ def compute_importance(self, X=None, y=None):
proc_col=p_col,
index_i=ind_fold + 1,
group_stacking=self.group_stacking,
sub_groups=[self.list_cols, self.sub_groups],
sub_groups=[self.list_cols,
self.sub_groups],
list_seeds=list_seeds_imp,
Perm=self.Perm,
output_dim=output_dim,
Expand All @@ -847,10 +848,13 @@ def compute_importance(self, X=None, y=None):
score_imp_l.append(score_cur[0])
else:
if self.prob_type in ("classification", "binary"):
score = roc_auc_score(y[ind_fold], self.org_pred[ind_fold])
nonzero_cols = np.where(y[ind_fold].any(axis=0))[0]
score = roc_auc_score(y[ind_fold][:, nonzero_cols],
self.org_pred[ind_fold][:, nonzero_cols])
else:
score = (
mean_absolute_error(y[ind_fold], self.org_pred[ind_fold]),
mean_absolute_error(y[ind_fold],
self.org_pred[ind_fold]),
r2_score(y[ind_fold], self.org_pred[ind_fold]),
)
score_imp_l.append(score)
Expand Down

0 comments on commit c8282c1

Please sign in to comment.