Skip to content

Commit 8837191

Browse files
committed
add predict_proba functionality
1 parent e8b4acf commit 8837191

File tree

5 files changed

+64
-3
lines changed

5 files changed

+64
-3
lines changed

ml_garden/core/data_container.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import logging
77
import os
88
import sys
9-
from typing import Any, Optional, Union
9+
from typing import Any, Dict, Optional, Union
1010

1111
import dill as pickle
1212
import pandas as pd
@@ -675,6 +675,32 @@ def predictions(self, value: pd.Series):
675675
"""
676676
self["predictions"] = value
677677

678+
@property
679+
def predict_proba(self) -> pd.DataFrame:
680+
"""
681+
Get the prediction probabilities from the DataContainer.
682+
683+
Returns
684+
-------
685+
pd.DataFrame
686+
The prediction probabilities stored in the DataContainer.
687+
For binary and multiclass classification, returns a DataFrame with a column for each class.
688+
"""
689+
return self["predict_proba"]
690+
691+
@predict_proba.setter
692+
def predict_proba(self, value: pd.DataFrame):
693+
"""
694+
Set the prediction probabilities in the DataContainer.
695+
696+
Parameters
697+
----------
698+
value : pd.DataFrame
699+
The prediction probabilities to be stored in the DataContainer.
700+
Should be a DataFrame with a column for each class.
701+
"""
702+
self["predict_proba"] = value
703+
678704
@property
679705
def explainer(self) -> BaseExplainer:
680706
"""

ml_garden/core/model.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,19 @@ def fit(
2424
@abstractmethod
2525
def predict(self, X: pd.DataFrame) -> pd.Series:
2626
"""Abstract method for making predictions."""
27+
28+
def predict_proba(self, X: pd.DataFrame) -> pd.DataFrame:
29+
"""
30+
Predict class probabilities with the trained model.
31+
32+
Parameters
33+
----------
34+
X : pd.DataFrame
35+
Features to make probability predictions on.
36+
37+
Returns
38+
-------
39+
pd.DataFrame
40+
Predicted class probabilities for the input features.
41+
"""
42+
pass

ml_garden/core/steps/fit_model.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import numpy as np
66
import optuna
7+
import pandas as pd
78
from sklearn.metrics import (
89
accuracy_score,
910
f1_score,
@@ -300,4 +301,14 @@ def predict(self, data: DataContainer) -> DataContainer:
300301
self.logger.info(f"Predicting with {self.model_class.__name__} model")
301302
data.flow[data.prediction_column] = data.model.predict(data.X_prediction)
302303
data.predictions = data.flow[data.prediction_column]
304+
305+
# If the task is classification, also get the prediction probabilities
306+
if data.task == Task.CLASSIFICATION:
307+
proba_df = data.model.predict_proba(data.X_prediction)
308+
proba_df.columns = [f"proba_{col}" for col in proba_df.columns]
309+
310+
# Concatenate the probabilities DataFrame with the existing DataFrame
311+
data.flow = pd.concat([data.flow, proba_df], axis=1)
312+
data.predict_proba = proba_df
313+
303314
return data

ml_garden/implementation/tabular/autogluon/model.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,10 @@ def fit(self, X: pd.DataFrame, y: pd.Series, eval_set=None, verbose=True) -> Non
3333

3434
def predict(self, X: pd.DataFrame) -> pd.Series:
3535
predictions = self.model.predict(X)
36-
return predictions
36+
return pd.Series(predictions, index=X.index)
37+
38+
def predict_proba(self, X: pd.DataFrame) -> pd.DataFrame:
39+
if self.model.problem_type == "regression":
40+
raise ValueError("predict_proba is not available for regression tasks.")
41+
probabilities = self.model.predict_proba(X)
42+
return pd.DataFrame(probabilities, index=X.index)

ml_garden/implementation/tabular/xgboost/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from typing import List
23

34
import pandas as pd
45
import xgboost as xgb
@@ -75,4 +76,5 @@ def predict_proba(self, X: pd.DataFrame) -> pd.DataFrame:
7576
pd.DataFrame
7677
Predicted class probabilities for the input features.
7778
"""
78-
return self.model.predict_proba(X)
79+
proba = self.model.predict_proba(X)
80+
return pd.DataFrame(proba, columns=self.model.classes_)

0 commit comments

Comments
 (0)