From b598bff92a889c120d629cd4c3d2c8872f7a40c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrik=20Bostr=C3=B6m?= Date: Thu, 27 Jun 2024 16:57:06 +0200 Subject: [PATCH] 0.7.0 --- src/crepes/__init__.py | 11 +- src/crepes/base.py | 280 +++++++++++++++++++++-------------------- src/crepes/extras.py | 187 ++++++++++++++++++++++++++- 3 files changed, 341 insertions(+), 137 deletions(-) diff --git a/src/crepes/__init__.py b/src/crepes/__init__.py index e2e04b6..e672bcf 100644 --- a/src/crepes/__init__.py +++ b/src/crepes/__init__.py @@ -1,2 +1,11 @@ -from crepes.base import WrapRegressor, WrapClassifier, ConformalRegressor, ConformalPredictiveSystem, ConformalClassifier, ConformalPredictor, __version__ +from crepes.base import ( + WrapRegressor, + WrapClassifier, + ConformalRegressor, + ConformalPredictiveSystem, + ConformalClassifier, + ConformalPredictor, + __version__ +) + diff --git a/src/crepes/base.py b/src/crepes/base.py index 8c0be67..69f3f27 100644 --- a/src/crepes/base.py +++ b/src/crepes/base.py @@ -14,14 +14,14 @@ """ -__version__ = "0.6.2" +__version__ = "0.7.0" import numpy as np import pandas as pd import time import warnings -from crepes.extras import hinge +from crepes.extras import hinge, MondrianCategorizer warnings.simplefilter("always", UserWarning) @@ -886,6 +886,8 @@ def predict(self, y_hat, sigmas=None, bins=None, if no_result_columns > 0: result = np.zeros((len(y_hat),no_result_columns)) if y is not None: + if isinstance(y, pd.Series): + y = y.values no_prec_result_cols += 1 gammas = np.random.rand(len(y_hat)) if isinstance(y, (int, float, np.integer, np.floating)): @@ -915,7 +917,7 @@ def predict(self, y_hat, sigmas=None, bins=None, (y_hat[i]+bin_alphas[b])