From 3149c8e7928981bfc7f7f4060fdc881dc440c863 Mon Sep 17 00:00:00 2001 From: achamma Date: Mon, 3 Jun 2024 02:11:46 +0200 Subject: [PATCH] Fix BBI multi_class --- hidimstat/BBI.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/hidimstat/BBI.py b/hidimstat/BBI.py index 3db9abc..a5fa9a5 100644 --- a/hidimstat/BBI.py +++ b/hidimstat/BBI.py @@ -833,7 +833,8 @@ def compute_importance(self, X=None, y=None): proc_col=p_col, index_i=ind_fold + 1, group_stacking=self.group_stacking, - sub_groups=[self.list_cols, self.sub_groups], + sub_groups=[self.list_cols, + self.sub_groups], list_seeds=list_seeds_imp, Perm=self.Perm, output_dim=output_dim, @@ -847,10 +848,13 @@ def compute_importance(self, X=None, y=None): score_imp_l.append(score_cur[0]) else: if self.prob_type in ("classification", "binary"): - score = roc_auc_score(y[ind_fold], self.org_pred[ind_fold]) + nonzero_cols = np.where(y[ind_fold].any(axis=0))[0] + score = roc_auc_score(y[ind_fold][:, nonzero_cols], + self.org_pred[ind_fold][:, nonzero_cols]) else: score = ( - mean_absolute_error(y[ind_fold], self.org_pred[ind_fold]), + mean_absolute_error(y[ind_fold], + self.org_pred[ind_fold]), r2_score(y[ind_fold], self.org_pred[ind_fold]), ) score_imp_l.append(score)