diff --git a/lightgbmlss/model.py b/lightgbmlss/model.py index 3693ee7..f966a87 100644 --- a/lightgbmlss/model.py +++ b/lightgbmlss/model.py @@ -570,15 +570,8 @@ def set_valid_margin(self, valid_sets : list List of tuples containing the train and evaluation set. """ - valid_sets1 = valid_sets[0] - init_score_val1 = (np.ones(shape=(valid_sets1.get_label().shape[0], 1))) * start_values - valid_sets1.set_init_score(init_score_val1.ravel(order="F")) - - valid_sets2 = valid_sets[1] - init_score_val2 = (np.ones(shape=(valid_sets2.get_label().shape[0], 1))) * start_values - valid_sets2.set_init_score(init_score_val2.ravel(order="F")) - - valid_sets = [valid_sets1, valid_sets2] + for valid_set in valid_sets: + self.set_init_score(valid_set) return valid_sets