Skip to content

Commit

Permalink
Merge pull request #18 from achamma723/multi_class
Browse files (browse the repository at this point in the history)
[FIX]: BBI in Multi-Class settings
  • Loading branch information
achamma723 authored Jun 2, 2024
2 parents 40ac401 + 9e20546 commit 84de98a
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
19 changes: 13 additions & 6 deletions hidimstat/BBI.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def __init__(
self.random_state = random_state
self.X_test = [None] * max(self.k_fold, 1)
self.y_test = [None] * max(self.k_fold, 1)
self.y_train = [None] * max(self.k_fold, 1)
self.org_pred = [None] * max(self.k_fold, 1)
self.pred_scores = [None] * max(self.k_fold, 1)
self.X_nominal = [None] * max(self.k_fold, 1)
Expand Down Expand Up @@ -449,6 +450,7 @@ def fit(self, X, y=None):

self.X_test[ind_fold] = X_test.copy()
self.y_test[ind_fold] = y_test.copy()
self.y_train[ind_fold] = y_train.copy()

# Find the list of optimal sub-models to be used in the
# following steps (Default estimator)
Expand All @@ -459,6 +461,7 @@ def fit(self, X, y=None):
self.list_estimators[ind_fold] = copy(self.estimator)

else:
self.y_train = y.copy()
if not self.apply_ridge:
if self.coffeine_transformer is not None:
X = self.coffeine_transformers[0].fit_transform(
Expand Down Expand Up @@ -549,7 +552,12 @@ def func(x):

self.estimator.fit(X_train_scaled, y_train_curr)

list_loss.append(self.loss(y_valid_curr, func(X_valid_scaled)))
if self.prob_type == "classification":
list_loss.append(self.loss(y_valid_curr,
func(X_valid_scaled)[:, np.unique(y_valid_curr)]))
else:
list_loss.append(self.loss(y_valid_curr,
func(X_valid_scaled)))

ind_min = np.argmin(list_loss)
best_hyper = list_hyper[ind_min]
Expand Down Expand Up @@ -761,11 +769,10 @@ def compute_importance(self, X=None, y=None):
)[y_col]
else:
if self.prob_type in ("classification", "binary"):
y[ind_fold] = (
OneHotEncoder(handle_unknown="ignore")
.fit_transform(y[ind_fold].reshape(-1, 1))
.toarray()
)
one_hot = (OneHotEncoder(handle_unknown="ignore")
.fit(self.y_train[ind_fold].reshape(-1, 1)))
y[ind_fold] = (one_hot.transform(y[ind_fold]
.reshape(-1, 1)).toarray())
if self.com_imp:
if not self.conditional:
self.pred_scores[ind_fold], score_cur = list(
Expand Down
7 changes: 5 additions & 2 deletions hidimstat/compute_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,8 @@ def joblib_compute_conditional(
)

if prob_type in ("classification", "binary"):
score = roc_auc_score(y_test, org_pred)
nonzero_cols = np.where(y_test.any(axis=0))[0]
score = roc_auc_score(y_test[:, nonzero_cols], org_pred[:, nonzero_cols])
else:
score = (
mean_absolute_error(y_test, org_pred),
Expand Down Expand Up @@ -710,7 +711,9 @@ def joblib_compute_permutation(

res = (y_test - pred_i) ** 2 - (y_test - org_pred) ** 2
else:
score = roc_auc_score(y_test, org_pred)
nonzero_cols = np.where(y_test.any(axis=0))[0]
score = roc_auc_score(y_test[:, nonzero_cols],
org_pred[:, nonzero_cols])
if type_predictor == "DNN":
pred_i = estimator.predict_proba(current_X_test_list, scale=False)
else:
Expand Down

0 comments on commit 84de98a

Please sign in to comment.