From cd9e498a7d5a619b4c0e1b87c68ec68d409a9117 Mon Sep 17 00:00:00 2001 From: Hassan Maissoro Date: Wed, 29 Oct 2025 17:05:55 +0100 Subject: [PATCH 01/11] ENH: BCC multi-risk example --- .../plot_multi_risk_control_binary_classification.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 examples/risk_control/2-advanced-analysis/plot_multi_risk_control_binary_classification.py diff --git a/examples/risk_control/2-advanced-analysis/plot_multi_risk_control_binary_classification.py b/examples/risk_control/2-advanced-analysis/plot_multi_risk_control_binary_classification.py new file mode 100644 index 000000000..e69de29bb From 5317d5fa160ec52603b72dbab1bf03f6e4b00860 Mon Sep 17 00:00:00 2001 From: Hassan Maissoro Date: Thu, 30 Oct 2025 10:22:51 +0100 Subject: [PATCH 02/11] init --- ...ulti_risk_control_binary_classification.py | 187 ++++++++++++++++++ 1 file changed, 187 insertions(+) diff --git a/examples/risk_control/2-advanced-analysis/plot_multi_risk_control_binary_classification.py b/examples/risk_control/2-advanced-analysis/plot_multi_risk_control_binary_classification.py index e69de29bb..67a25f427 100644 --- a/examples/risk_control/2-advanced-analysis/plot_multi_risk_control_binary_classification.py +++ b/examples/risk_control/2-advanced-analysis/plot_multi_risk_control_binary_classification.py @@ -0,0 +1,187 @@ +""" +======================================================================== +Use MAPIE to control multiple performance metrics of a binary classifier +======================================================================== + +In this example, we explain how to do multi-risk control for binary classification with MAPIE. + +""" + +import matplotlib.pyplot as plt +import numpy as np +from sklearn.datasets import make_circles +from sklearn.metrics import precision_score, recall_score +from sklearn.neural_network import MLPClassifier + +from mapie.risk_control import BinaryClassificationController +from mapie.utils import train_conformalize_test_split + +RANDOM_STATE = 1 + +############################################################################## +# Fist, load the dataset and then split it into training, calibration +# (for conformalization), and test sets. + +X, y = make_circles(n_samples=5000, noise=0.3, factor=0.3, random_state=RANDOM_STATE) +(X_train, X_calib, X_test, y_train, y_calib, y_test) = train_conformalize_test_split( + X, + y, + train_size=0.8, + conformalize_size=0.1, + test_size=0.1, + random_state=RANDOM_STATE, +) + +# Plot the three datasets to visualize the distribution of the two classes. 
+fig, axes = plt.subplots(1, 3, figsize=(18, 6)) +titles = ["Training Data", "Calibration Data", "Test Data"] +datasets = [(X_train, y_train), (X_calib, y_calib), (X_test, y_test)] + +for i, (ax, (X_data, y_data), title) in enumerate(zip(axes, datasets, titles)): + ax.scatter( + X_data[y_data == 0, 0], + X_data[y_data == 0, 1], + edgecolors="k", + c="tab:blue", + label='"negative" class', + alpha=0.5, + ) + ax.scatter( + X_data[y_data == 1, 0], + X_data[y_data == 1, 1], + edgecolors="k", + c="tab:red", + label='"positive" class', + alpha=0.5, + ) + ax.set_title(title, fontsize=18) + ax.set_xlabel("Feature 1", fontsize=16) + ax.tick_params(labelsize=14) + + if i == 0: + ax.set_ylabel("Feature 2", fontsize=16) + else: + ax.set_ylabel("") + ax.set_yticks([]) + +handles, labels = axes[0].get_legend_handles_labels() +fig.legend( + handles, + labels, + loc="lower center", + bbox_to_anchor=(0.5, -0.01), + ncol=2, + fontsize=16, +) + +plt.suptitle("Visualization of Train, Calibration, and Test Sets", fontsize=22) +plt.tight_layout(rect=[0, 0.05, 1, 0.95]) +plt.show() + +############################################################################## +# Second, fit a Multi-layer Perceptron classifier on the training data. + +clf = MLPClassifier(max_iter=150, random_state=RANDOM_STATE) +clf.fit(X_train, y_train) + +############################################################################## +# Next, we initialize a :class:`~mapie.risk_control.BinaryClassificationController` +# using the probability estimation function from the fitted estimator: +# ``clf.predict_proba``, a list risk or performance metric (here, ["precision", "recall"]), +# a list target risk level, and a single confidence level. Then we use the calibration data +# to compute statistically guaranteed thresholds using a risk control method. +# +# Different risks or performance metrics have been implemented, such as precision, +# recall and accuracy, but you can also implement your own custom function using +# :class:`~mapie.risk_control.BinaryClassificationRisk` and choose your own +# secondary objective. + +target_precision = 0.7 +target_recall = 0.75 +confidence_level = 0.9 + +bcc = BinaryClassificationController( + predict_function=clf.predict_proba, + risk=["precision", "recall"], + target_level=[target_precision, target_recall], + confidence_level=confidence_level, + best_predict_param_choice="precision", +) +bcc.calibrate(X_calib, y_calib) + +print( + f"{len(bcc.valid_predict_params)} thresholds found that guarantee a precision of " + f"at least {target_precision} with a confidence of {confidence_level}.\n" + "Among those, the one that maximizes the precision (passed in `best_predict_param_choice`) is: " + f"{bcc.best_predict_param:.3f}." +) + + +############################################################################## +# In the plot below, we visualize how the threshold values impact precision and recall, +# and what thresholds have been computed as statistically guaranteed. 
+ +proba_positive_class = clf.predict_proba(X_calib)[:, 1] + +tested_thresholds = bcc._predict_params +precisions = np.full(len(tested_thresholds), np.inf) +recalls = np.full(len(tested_thresholds), np.inf) +for i, threshold in enumerate(tested_thresholds): + y_pred = (proba_positive_class >= threshold).astype(int) + precisions[i] = precision_score(y_calib, y_pred) + recalls[i] = recall_score(y_calib, y_pred) + +valid_thresholds_indices = np.array( + [t in bcc.valid_predict_params for t in tested_thresholds] +) +best_threshold_index = np.where(tested_thresholds == bcc.best_predict_param)[0][0] + +plt.figure(figsize=(8, 6)) +plt.scatter( + tested_thresholds[valid_thresholds_indices], + precisions[valid_thresholds_indices], + c="tab:green", marker="o", label="Precision at Valid Thresholds" +) +plt.scatter( + tested_thresholds[valid_thresholds_indices], + recalls[valid_thresholds_indices], + marker="p", facecolors="none", edgecolors="tab:green", + label="Recall at Valid Thresholds" +) + +plt.scatter( + tested_thresholds[~valid_thresholds_indices], + precisions[~valid_thresholds_indices], + c="tab:red", marker="o", label="Precision at Invalid Thresholds" +) +plt.scatter( + tested_thresholds[~valid_thresholds_indices], + recalls[~valid_thresholds_indices], + marker="p", + facecolors="none", + edgecolors="tab:blue", + label="Recall at Invalid Thresholds", +) +plt.scatter( + tested_thresholds[best_threshold_index], + precisions[best_threshold_index], + c="tab:green", marker="*", edgecolors="k", s=300, label="Best threshold" +) +plt.axhline(target_precision, color="tab:gray", linestyle="--") +plt.text( + 0.8, + target_precision + 0.02, + "Target precision", + color="tab:gray", + fontstyle="italic", +) +plt.axhline(target_recall, color="magenta", linestyle=":") +plt.text( + 0.0, target_recall + 0.02, "Target recall", color="magenta", fontstyle="italic" +) +plt.xlabel("Threshold") +plt.ylabel("Metric value") +plt.title("Precision and Recall by Threshold") +plt.legend() +plt.tight_layout() +plt.show() From fe26fa0e62c7f8367840eacf6016428d4dd605b6 Mon Sep 17 00:00:00 2001 From: Hassan Maissoro Date: Thu, 30 Oct 2025 19:26:31 +0100 Subject: [PATCH 03/11] bcc multi risk tuto --- ...ulti_risk_control_binary_classification.py | 362 ++++++++++++++---- 1 file changed, 291 insertions(+), 71 deletions(-) diff --git a/examples/risk_control/2-advanced-analysis/plot_multi_risk_control_binary_classification.py b/examples/risk_control/2-advanced-analysis/plot_multi_risk_control_binary_classification.py index 67a25f427..31c55d836 100644 --- a/examples/risk_control/2-advanced-analysis/plot_multi_risk_control_binary_classification.py +++ b/examples/risk_control/2-advanced-analysis/plot_multi_risk_control_binary_classification.py @@ -1,16 +1,20 @@ """ -======================================================================== -Use MAPIE to control multiple performance metrics of a binary classifier -======================================================================== +========================================================== +Use MAPIE to control multiple risks of a binary classifier +========================================================== In this example, we explain how to do multi-risk control for binary classification with MAPIE. 
""" +import warnings + import matplotlib.pyplot as plt import numpy as np from sklearn.datasets import make_circles +from sklearn.inspection import DecisionBoundaryDisplay from sklearn.metrics import precision_score, recall_score +from sklearn.model_selection import FixedThresholdClassifier from sklearn.neural_network import MLPClassifier from mapie.risk_control import BinaryClassificationController @@ -89,99 +93,315 @@ # using the probability estimation function from the fitted estimator: # ``clf.predict_proba``, a list risk or performance metric (here, ["precision", "recall"]), # a list target risk level, and a single confidence level. Then we use the calibration data -# to compute statistically guaranteed thresholds using a risk control method. +# to compute statistically guaranteed thresholds using a multi-risk control method. # # Different risks or performance metrics have been implemented, such as precision, -# recall and accuracy, but you can also implement your own custom function using +# recall and accuracy, but you can also implement your own custom functions using # :class:`~mapie.risk_control.BinaryClassificationRisk` and choose your own -# secondary objective. +# secondary objective (passed in ``best_predict_param_choice``) +# +# Note that if the secondary objective is not specified, the first risk in the list is used +# as the secondary objective by default. Here, we choose "recall" as the secondary objective. +# +# Here we consider the list of risks ["precision", "recall"] and choose "recall" as the secondary +# objective. Furthermore, we consider two scenarios according to different target levels +# for precision and recall. + +############################################################################## +# The following table summarizes the configuration of both scenarios: +# +# +-------------------------------+------------------------+------------------------+ +# | **Parameter** | **Scenario 1** | **Scenario 2** | +# +-------------------------------+------------------------+------------------------+ +# | **List of lisks** | ["precision", "recall"]| ["precision", "recall"]| +# +-------------------------------+------------------------+------------------------+ +# | **List of target levels** | [0.75, 0.70] | [0.85, 0.80] | +# +-------------------------------+------------------------+------------------------+ +# | **Confidence level** | 0.9 | 0.9 | +# +-------------------------------+------------------------+------------------------+ +# | **Best predict param choice** | "recall" | "recall" | +# +-------------------------------+------------------------+------------------------+ +# +# Both scenarios use the same list of risks and best parameter choice, +# but with different target levels for precision and recall. +# +# For each scenario, we first fit two single-risk controllers, followed by a multi-risk controller. +# The objective is to illustrate that, even when single-risk controllers find valid thresholds for both risks, +# the multi-risk controller may not find any threshold that satisfies both simultaneously +# with statistical guarantees. 
-target_precision = 0.7 -target_recall = 0.75 -confidence_level = 0.9 -bcc = BinaryClassificationController( +############################################################################## + +# Scenario 1: +target_levels_1 = [0.75, 0.70] +confidence_level_1 = 0.9 + +# Cas mono risk +bcc_precision_1 = BinaryClassificationController( + predict_function=clf.predict_proba, + risk="precision", + target_level=target_levels_1[0], + confidence_level=confidence_level_1, + best_predict_param_choice="recall", +) +bcc_precision_1.calibrate(X_calib, y_calib) + +bcc_recall_1 = BinaryClassificationController( + predict_function=clf.predict_proba, + risk="recall", + target_level=target_levels_1[1], + confidence_level=confidence_level_1, + best_predict_param_choice="recall", +) +bcc_recall_1.calibrate(X_calib, y_calib) + +# Cas multi risk +bcc_1 = BinaryClassificationController( predict_function=clf.predict_proba, risk=["precision", "recall"], - target_level=[target_precision, target_recall], - confidence_level=confidence_level, - best_predict_param_choice="precision", + target_level=target_levels_1, + confidence_level=confidence_level_1, + best_predict_param_choice="recall", ) -bcc.calibrate(X_calib, y_calib) +bcc_1.calibrate(X_calib, y_calib) print( - f"{len(bcc.valid_predict_params)} thresholds found that guarantee a precision of " - f"at least {target_precision} with a confidence of {confidence_level}.\n" - "Among those, the one that maximizes the precision (passed in `best_predict_param_choice`) is: " - f"{bcc.best_predict_param:.3f}." + f"Scenario 1 - Multiple risks : {len(bcc_1.valid_predict_params)} " + "thresholds found that guarantee a precision of " + f"at least {target_levels_1[0]} and a recall of at least {target_levels_1[1]} " + f"with a confidence of {confidence_level_1}.\n" + "Among those, the one that maximizes the secondary objective " + "(here, recall, passed in `best_predict_param_choice`) is: " + f"{bcc_1.best_predict_param:.3f}.\n" ) - ############################################################################## -# In the plot below, we visualize how the threshold values impact precision and recall, -# and what thresholds have been computed as statistically guaranteed. 
- -proba_positive_class = clf.predict_proba(X_calib)[:, 1] -tested_thresholds = bcc._predict_params -precisions = np.full(len(tested_thresholds), np.inf) -recalls = np.full(len(tested_thresholds), np.inf) -for i, threshold in enumerate(tested_thresholds): - y_pred = (proba_positive_class >= threshold).astype(int) - precisions[i] = precision_score(y_calib, y_pred) - recalls[i] = recall_score(y_calib, y_pred) +# Scenario 2: +target_levels_2 = [0.85, 0.8] +confidence_level_2 = 0.9 -valid_thresholds_indices = np.array( - [t in bcc.valid_predict_params for t in tested_thresholds] +# Cas mono risk +bcc_precision_2 = BinaryClassificationController( + predict_function=clf.predict_proba, + risk="precision", + target_level=target_levels_2[0], + confidence_level=confidence_level_2, + best_predict_param_choice="recall", ) -best_threshold_index = np.where(tested_thresholds == bcc.best_predict_param)[0][0] +bcc_precision_2.calibrate(X_calib, y_calib) -plt.figure(figsize=(8, 6)) -plt.scatter( - tested_thresholds[valid_thresholds_indices], - precisions[valid_thresholds_indices], - c="tab:green", marker="o", label="Precision at Valid Thresholds" +bcc_recall_2 = BinaryClassificationController( + predict_function=clf.predict_proba, + risk="recall", + target_level=target_levels_2[1], + confidence_level=confidence_level_2, + best_predict_param_choice="recall", ) -plt.scatter( - tested_thresholds[valid_thresholds_indices], - recalls[valid_thresholds_indices], - marker="p", facecolors="none", edgecolors="tab:green", - label="Recall at Valid Thresholds" +bcc_recall_2.calibrate(X_calib, y_calib) + +# Cas multi risk +bcc_2 = BinaryClassificationController( + predict_function=clf.predict_proba, + risk=["precision", "recall"], + target_level=target_levels_2, + confidence_level=confidence_level_2, + best_predict_param_choice="recall", ) +bcc_2.calibrate(X_calib, y_calib) +with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") # Capture all warnings + bcc_2.calibrate(X_calib, y_calib) -plt.scatter( - tested_thresholds[~valid_thresholds_indices], - precisions[~valid_thresholds_indices], - c="tab:red", marker="o", label="Precision at Invalid Thresholds" + if w: # If any warnings were raised + print(f"Scenario 2 - Multiple risks : {w[0].message}") + +############################################################################## +# In the plot below, we visualize how the threshold values impact precision +# and recall, and what thresholds have been computed as statistically guaranteed. 
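+##############################################################################
+# Before plotting, note that when no jointly valid threshold exists (as the
+# warning above indicates for Scenario 2), it is safer to inspect
+# ``valid_predict_params`` before calling ``predict``. A minimal defensive
+# sketch, written to handle both a ``None`` and an empty result:
+
+n_valid_2 = (
+    0 if bcc_2.valid_predict_params is None else len(bcc_2.valid_predict_params)
+)
+if n_valid_2 == 0:
+    print(
+        "Scenario 2: no threshold jointly controls precision and recall; "
+        "consider relaxing the targets or enlarging the calibration set."
+    )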
+ +proba_positive_class = clf.predict_proba(X_calib)[:, 1] +scenarios = [ + { + "name": "Scenario 1 - Mono Risk", + "bcc": bcc_precision_1, + "target_levels": [target_levels_1[0], target_levels_1[1]], + }, + {"name": "Scenario 1 - Multi Risk", "bcc": bcc_1, "target_levels": target_levels_1}, + { + "name": "Scenario 2 - Mono Risk", + "bcc": bcc_precision_2, + "target_levels": [target_levels_2[0], target_levels_2[1]], + }, + {"name": "Scenario 2 - Multi Risk", "bcc": bcc_2, "target_levels": target_levels_2}, +] + +fig, axes = plt.subplots(2, 2, figsize=(16, 12), sharey=True) +axes = axes.flatten() + +for ax, scenario in zip(axes, scenarios): + bcc = scenario["bcc"] + target_precision, target_recall = scenario["target_levels"] + + tested_thresholds = bcc._predict_params + precisions = np.array( + [ + precision_score(y_calib, (proba_positive_class >= t).astype(int)) + for t in tested_thresholds + ] + ) + recalls = np.array( + [ + recall_score(y_calib, (proba_positive_class >= t).astype(int)) + for t in tested_thresholds + ] + ) + + if bcc.valid_predict_params is not None and len(bcc.valid_predict_params) > 0: + valid_indices = np.array( + [t in bcc.valid_predict_params for t in tested_thresholds] + ) + ax.scatter( + tested_thresholds[valid_indices], + precisions[valid_indices], + color="tab:green", + marker="o", + label="Precision at valid thresholds", + ) + ax.scatter( + tested_thresholds[valid_indices], + recalls[valid_indices], + marker="p", + facecolors="none", + edgecolors="tab:green", + label="Recall at valid thresholds", + ) + else: + valid_indices = np.array([False] * len(tested_thresholds)) + + invalid_indices = ~valid_indices + ax.scatter( + tested_thresholds[invalid_indices], + precisions[invalid_indices], + color="tab:red", + marker="o", + label="Precision at invalid thresholds", + ) + ax.scatter( + tested_thresholds[invalid_indices], + recalls[invalid_indices], + marker="p", + facecolors="none", + edgecolors="tab:orange", + label="Recall at invalid thresholds", + ) + + if bcc.best_predict_param in tested_thresholds: + best_index = np.where(tested_thresholds == bcc.best_predict_param)[0][0] + ax.scatter( + tested_thresholds[best_index], + precisions[best_index], + color="tab:green", + marker="*", + edgecolors="k", + s=200, + label="Best threshold", + ) + + ax.axhline(target_precision, color="tab:gray", linestyle="--") + ax.text( + 0.8, + target_precision + 0.02, + "Target precision", + color="tab:gray", + fontstyle="italic", + fontsize=14, + ) + ax.axhline(target_recall, color="tab:blue", linestyle=":") + ax.text( + 0.0, + target_recall + 0.02, + "Target recall", + color="tab:blue", + fontstyle="italic", + fontsize=14, + ) + + ax.set_xlabel("Threshold", fontsize=14) + ax.set_ylabel("Performance metric Value", fontsize=14) + ax.set_title(scenario["name"], fontsize=16) + ax.legend(fontsize=16) + +plt.xticks(fontsize=14) +plt.yticks(fontsize=14) +plt.suptitle("Precision and recall by threshold for all scenarios", fontsize=18) +plt.tight_layout(rect=[0, 0, 1, 0.95]) +plt.show() + + +############################################################################## +# Contrary to the naive way of computing a threshold to satisfy a precision and +# a recall targets on calibration data, risk control provides statistical guarantees +# on unseen data. In the plot above, we can see that not all thresholds corresponding +# to a precision (resp. recall) higher (resp. lower) than the target are valid. 
+# This is due to the uncertainty inherent to the finite size of the calibration set, +# which risk control takes into account. +# +# In particular, for instance, for precision, the highest threshold values are considered +# invalid due to the small number of observations used to compute the precision, +# following the Learn Then Test procedure. In the most extreme case, no observation +# is available, which causes the precision value to be ill-defined and set to 0. +# +# In scenario 1, both the mono-risk controllers and the multi-risk controller found +# valid thresholds that satisfy the precision and recall targets individually and jointly. +# The jointly valid thresholds found by the multi-risk controller are shown as green markers in the plot. +# In scenario 2, although valid thresholds are found individually for precision and recall +# by the mono-risk controllers, the multi-risk controller cannot find any threshold +# that satisfies both targets simultaneously. + +# For Scenario 1 - Multi-risk only: +# Besides computing a set of valid thresholds, +# :class:`~mapie.risk_control.BinaryClassificationController` also outputs the "best" +# one, which is the valid threshold that maximizes a secondary objective +# (recall here). +# +# After obtaining the best threshold, we can use the ``predict`` function of +# :class:`~mapie.risk_control.BinaryClassificationController` for future predictions, +# or use scikit-learn's ``FixedThresholdClassifier`` as a wrapper to benefit +# from functionalities like easily plotting the decision boundary as seen below. + +y_pred = bcc_1.predict(X_test) + +clf_threshold = FixedThresholdClassifier(clf, threshold=bcc_1.best_predict_param) +clf_threshold.fit(X_train, y_train) +# .fit necessary for plotting, alternatively you can use sklearn.frozen.FrozenEstimator + +disp = DecisionBoundaryDisplay.from_estimator( + clf_threshold, X_test, response_method="predict", cmap=plt.cm.coolwarm ) + plt.scatter( - tested_thresholds[~valid_thresholds_indices], - recalls[~valid_thresholds_indices], - marker="p", - facecolors="none", - edgecolors="tab:blue", - label="Recall at Invalid Thresholds", + X_test[y_test == 0, 0], + X_test[y_test == 0, 1], + edgecolors="k", + c="tab:blue", + alpha=0.5, + label='"negative" class', ) plt.scatter( - tested_thresholds[best_threshold_index], - precisions[best_threshold_index], - c="tab:green", marker="*", edgecolors="k", s=300, label="Best threshold" -) -plt.axhline(target_precision, color="tab:gray", linestyle="--") -plt.text( - 0.8, - target_precision + 0.02, - "Target precision", - color="tab:gray", - fontstyle="italic", + X_test[y_test == 1, 0], + X_test[y_test == 1, 1], + edgecolors="k", + c="tab:red", + alpha=0.5, + label='"positive" class', ) -plt.axhline(target_recall, color="magenta", linestyle=":") -plt.text( - 0.0, target_recall + 0.02, "Target recall", color="magenta", fontstyle="italic" +plt.title( + "Decision Boundary of FixedThresholdClassifier for the Scenario 1 - Multi Risk", + fontsize=10, ) -plt.xlabel("Threshold") -plt.ylabel("Metric value") -plt.title("Precision and Recall by Threshold") +plt.xlabel("Feature 1") +plt.ylabel("Feature 2") plt.legend() -plt.tight_layout() plt.show() From 3a89c3bfac3195f57a77a302be77f699e4ef9f41 Mon Sep 17 00:00:00 2001 From: Hassan Maissoro Date: Thu, 30 Oct 2025 20:51:33 +0100 Subject: [PATCH 04/11] correct figure --- ...ulti_risk_control_binary_classification.py | 137 +++++++++++------- 1 file changed, 82 insertions(+), 55 deletions(-) diff --git 
a/examples/risk_control/2-advanced-analysis/plot_multi_risk_control_binary_classification.py b/examples/risk_control/2-advanced-analysis/plot_multi_risk_control_binary_classification.py index 31c55d836..9dd11987f 100644 --- a/examples/risk_control/2-advanced-analysis/plot_multi_risk_control_binary_classification.py +++ b/examples/risk_control/2-advanced-analysis/plot_multi_risk_control_binary_classification.py @@ -125,10 +125,13 @@ # Both scenarios use the same list of risks and best parameter choice, # but with different target levels for precision and recall. # -# For each scenario, we first fit two single-risk controllers, followed by a multi-risk controller. -# The objective is to illustrate that, even when single-risk controllers find valid thresholds for both risks, +# For each scenario, we first fit two mono-risk controllers, followed by a multi-risk controller. +# The objective is to illustrate that, even when mono-risk controllers find valid thresholds for both risks, # the multi-risk controller may not find any threshold that satisfies both simultaneously # with statistical guarantees. +# +# Note that in the mono-risk case, the best predict parameter is left as "auto". +# See :class:`~mapie.risk_control.BinaryClassificationController` documentation for more details. ############################################################################## @@ -143,7 +146,7 @@ risk="precision", target_level=target_levels_1[0], confidence_level=confidence_level_1, - best_predict_param_choice="recall", + best_predict_param_choice="auto", ) bcc_precision_1.calibrate(X_calib, y_calib) @@ -152,7 +155,7 @@ risk="recall", target_level=target_levels_1[1], confidence_level=confidence_level_1, - best_predict_param_choice="recall", + best_predict_param_choice="auto", ) bcc_recall_1.calibrate(X_calib, y_calib) @@ -188,7 +191,7 @@ risk="precision", target_level=target_levels_2[0], confidence_level=confidence_level_2, - best_predict_param_choice="recall", + best_predict_param_choice="auto", ) bcc_precision_2.calibrate(X_calib, y_calib) @@ -197,7 +200,7 @@ risk="recall", target_level=target_levels_2[1], confidence_level=confidence_level_2, - best_predict_param_choice="recall", + best_predict_param_choice="auto", ) bcc_recall_2.calibrate(X_calib, y_calib) @@ -225,13 +228,13 @@ scenarios = [ { "name": "Scenario 1 - Mono Risk", - "bcc": bcc_precision_1, + "bcc": [bcc_precision_1, bcc_recall_1], "target_levels": [target_levels_1[0], target_levels_1[1]], }, {"name": "Scenario 1 - Multi Risk", "bcc": bcc_1, "target_levels": target_levels_1}, { "name": "Scenario 2 - Mono Risk", - "bcc": bcc_precision_2, + "bcc": [bcc_precision_2, bcc_recall_2], "target_levels": [target_levels_2[0], target_levels_2[1]], }, {"name": "Scenario 2 - Multi Risk", "bcc": bcc_2, "target_levels": target_levels_2}, @@ -241,72 +244,96 @@ axes = axes.flatten() for ax, scenario in zip(axes, scenarios): - bcc = scenario["bcc"] - target_precision, target_recall = scenario["target_levels"] - - tested_thresholds = bcc._predict_params - precisions = np.array( - [ - precision_score(y_calib, (proba_positive_class >= t).astype(int)) - for t in tested_thresholds - ] - ) - recalls = np.array( - [ - recall_score(y_calib, (proba_positive_class >= t).astype(int)) - for t in tested_thresholds - ] - ) - - if bcc.valid_predict_params is not None and len(bcc.valid_predict_params) > 0: - valid_indices = np.array( - [t in bcc.valid_predict_params for t in tested_thresholds] - ) - ax.scatter( - tested_thresholds[valid_indices], - precisions[valid_indices], - 
color="tab:green", - marker="o", - label="Precision at valid thresholds", - ) - ax.scatter( - tested_thresholds[valid_indices], - recalls[valid_indices], - marker="p", - facecolors="none", - edgecolors="tab:green", - label="Recall at valid thresholds", - ) + if isinstance(scenario["bcc"], list): + bcc_precision, bcc_recall = scenario["bcc"] + target_precision, target_recall = scenario["target_levels"] + tested_thresholds = bcc_precision._predict_params + bccs = {"precision": bcc_precision, "recall": bcc_recall} else: - valid_indices = np.array([False] * len(tested_thresholds)) + bcc = scenario["bcc"] + target_precision, target_recall = scenario["target_levels"] + tested_thresholds = bcc._predict_params + bccs = {"precision": bcc, "recall": bcc} + + metrics = { + "precision": np.array( + [ + precision_score(y_calib, (proba_positive_class >= t).astype(int)) + for t in tested_thresholds + ] + ), + "recall": np.array( + [ + recall_score(y_calib, (proba_positive_class >= t).astype(int)) + for t in tested_thresholds + ] + ), + } + + valid_indices = {} + best_indices = {} + for key, controller in bccs.items(): + valid = controller.valid_predict_params + if valid is None: + valid = [] + valid = np.array(valid).tolist() + valid_indices[key] = np.array([t in valid for t in tested_thresholds]) + best_indices[key] = ( + np.where(tested_thresholds == controller.best_predict_param)[0][0] + if controller.best_predict_param in tested_thresholds + else None + ) - invalid_indices = ~valid_indices ax.scatter( - tested_thresholds[invalid_indices], - precisions[invalid_indices], + tested_thresholds[valid_indices["precision"]], + metrics["precision"][valid_indices["precision"]], + color="tab:green", + marker="o", + label="Precision at valid thresholds", + ) + ax.scatter( + tested_thresholds[valid_indices["recall"]], + metrics["recall"][valid_indices["recall"]], + marker="p", + facecolors="none", + edgecolors="tab:green", + label="Recall at valid thresholds", + ) + ax.scatter( + tested_thresholds[~valid_indices["precision"]], + metrics["precision"][~valid_indices["precision"]], color="tab:red", marker="o", label="Precision at invalid thresholds", ) ax.scatter( - tested_thresholds[invalid_indices], - recalls[invalid_indices], + tested_thresholds[~valid_indices["recall"]], + metrics["recall"][~valid_indices["recall"]], marker="p", facecolors="none", edgecolors="tab:orange", label="Recall at invalid thresholds", ) - if bcc.best_predict_param in tested_thresholds: - best_index = np.where(tested_thresholds == bcc.best_predict_param)[0][0] + if best_indices["precision"] is not None: ax.scatter( - tested_thresholds[best_index], - precisions[best_index], + tested_thresholds[best_indices["precision"]], + metrics["precision"][best_indices["precision"]], color="tab:green", marker="*", edgecolors="k", s=200, - label="Best threshold", + label="Precision best threshold", + ) + if best_indices["recall"] is not None: + ax.scatter( + tested_thresholds[best_indices["recall"]], + metrics["recall"][best_indices["recall"]], + color="tab:blue", + marker="*", + edgecolors="k", + s=200, + label="Recall best threshold", ) ax.axhline(target_precision, color="tab:gray", linestyle="--") From 6332b38dc3b9ba16ed072a5361b158fe757f49a0 Mon Sep 17 00:00:00 2001 From: Hassan Maissoro Date: Fri, 31 Oct 2025 09:59:10 +0100 Subject: [PATCH 05/11] correct typo --- .../plot_multi_risk_control_binary_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/examples/risk_control/2-advanced-analysis/plot_multi_risk_control_binary_classification.py b/examples/risk_control/2-advanced-analysis/plot_multi_risk_control_binary_classification.py index 9dd11987f..a43a86ba3 100644 --- a/examples/risk_control/2-advanced-analysis/plot_multi_risk_control_binary_classification.py +++ b/examples/risk_control/2-advanced-analysis/plot_multi_risk_control_binary_classification.py @@ -113,7 +113,7 @@ # +-------------------------------+------------------------+------------------------+ # | **Parameter** | **Scenario 1** | **Scenario 2** | # +-------------------------------+------------------------+------------------------+ -# | **List of lisks** | ["precision", "recall"]| ["precision", "recall"]| +# | **List of risks** | ["precision", "recall"]| ["precision", "recall"]| # +-------------------------------+------------------------+------------------------+ # | **List of target levels** | [0.75, 0.70] | [0.85, 0.80] | # +-------------------------------+------------------------+------------------------+ From a7b4e4e3b633d2e93140916f6c8b97bd02d56ade Mon Sep 17 00:00:00 2001 From: Hassan Maissoro Date: Mon, 3 Nov 2025 11:50:34 +0100 Subject: [PATCH 06/11] correct quickstart typo --- .../plot_risk_control_binary_classification.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/risk_control/1-quickstart/plot_risk_control_binary_classification.py b/examples/risk_control/1-quickstart/plot_risk_control_binary_classification.py index 7f733c2dc..fe75f20bd 100644 --- a/examples/risk_control/1-quickstart/plot_risk_control_binary_classification.py +++ b/examples/risk_control/1-quickstart/plot_risk_control_binary_classification.py @@ -7,13 +7,13 @@ """ -import numpy as np import matplotlib.pyplot as plt +import numpy as np from sklearn.datasets import make_circles -from sklearn.neural_network import MLPClassifier -from sklearn.model_selection import FixedThresholdClassifier -from sklearn.metrics import precision_score from sklearn.inspection import DecisionBoundaryDisplay +from sklearn.metrics import precision_score +from sklearn.model_selection import FixedThresholdClassifier +from sklearn.neural_network import MLPClassifier from mapie.risk_control import BinaryClassificationController from mapie.utils import train_conformalize_test_split @@ -21,7 +21,7 @@ RANDOM_STATE = 1 ############################################################################## -# Fist, load the dataset and then split it into training, calibration +# First, load the dataset and then split it into training, calibration # (for conformalization), and test sets. X, y = make_circles(n_samples=5000, noise=0.3, factor=0.3, random_state=RANDOM_STATE) @@ -172,7 +172,7 @@ # Contrary to the naive way of computing a threshold to satisfy a precision target on # calibration data, risk control provides statistical guarantees on unseen data. # In the plot above, we can see that not all thresholds corresponding to a precision -# higher that the target are valid. This is due to the uncertainty inherent to the +# higher then the target are valid. This is due to the uncertainty inherent to the # finite size of the calibration set, which risk control takes into account. 
# # In particular, the highest threshold values are considered invalid due to the From 9eab5d8a11c2d8bd4644298344c3fe42e1b5a092 Mon Sep 17 00:00:00 2001 From: Hassan Maissoro Date: Mon, 3 Nov 2025 12:11:31 +0100 Subject: [PATCH 07/11] correct typo --- .../1-quickstart/plot_risk_control_binary_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/risk_control/1-quickstart/plot_risk_control_binary_classification.py b/examples/risk_control/1-quickstart/plot_risk_control_binary_classification.py index fe75f20bd..b9f31baad 100644 --- a/examples/risk_control/1-quickstart/plot_risk_control_binary_classification.py +++ b/examples/risk_control/1-quickstart/plot_risk_control_binary_classification.py @@ -172,7 +172,7 @@ # Contrary to the naive way of computing a threshold to satisfy a precision target on # calibration data, risk control provides statistical guarantees on unseen data. # In the plot above, we can see that not all thresholds corresponding to a precision -# higher then the target are valid. This is due to the uncertainty inherent to the +# higher than the target are valid. This is due to the uncertainty inherent to the # finite size of the calibration set, which risk control takes into account. # # In particular, the highest threshold values are considered invalid due to the From c7bd3b020365e52eca4de2115e6edb372c94e687 Mon Sep 17 00:00:00 2001 From: Hassan Maissoro Date: Mon, 3 Nov 2025 13:54:33 +0100 Subject: [PATCH 08/11] correct typos and improve graphs --- ...ulti_risk_control_binary_classification.py | 103 +++++++++++------- 1 file changed, 61 insertions(+), 42 deletions(-) diff --git a/examples/risk_control/2-advanced-analysis/plot_multi_risk_control_binary_classification.py b/examples/risk_control/2-advanced-analysis/plot_multi_risk_control_binary_classification.py index a43a86ba3..389a07736 100644 --- a/examples/risk_control/2-advanced-analysis/plot_multi_risk_control_binary_classification.py +++ b/examples/risk_control/2-advanced-analysis/plot_multi_risk_control_binary_classification.py @@ -6,7 +6,7 @@ In this example, we explain how to do multi-risk control for binary classification with MAPIE. """ - +# %% import warnings import matplotlib.pyplot as plt @@ -23,7 +23,7 @@ RANDOM_STATE = 1 ############################################################################## -# Fist, load the dataset and then split it into training, calibration +# First, load the dataset and then split it into training, calibration # (for conformalization), and test sets. X, y = make_circles(n_samples=5000, noise=0.3, factor=0.3, random_state=RANDOM_STATE) @@ -91,7 +91,7 @@ ############################################################################## # Next, we initialize a :class:`~mapie.risk_control.BinaryClassificationController` # using the probability estimation function from the fitted estimator: -# ``clf.predict_proba``, a list risk or performance metric (here, ["precision", "recall"]), +# ``clf.predict_proba``, a list of risk or performance metric (here, ["precision", "recall"]), # a list target risk level, and a single confidence level. Then we use the calibration data # to compute statistically guaranteed thresholds using a multi-risk control method. 
# @@ -140,7 +140,7 @@ target_levels_1 = [0.75, 0.70] confidence_level_1 = 0.9 -# Cas mono risk +# Mono risk case bcc_precision_1 = BinaryClassificationController( predict_function=clf.predict_proba, risk="precision", @@ -159,7 +159,7 @@ ) bcc_recall_1.calibrate(X_calib, y_calib) -# Cas multi risk +# Multi risk case bcc_1 = BinaryClassificationController( predict_function=clf.predict_proba, risk=["precision", "recall"], @@ -171,11 +171,10 @@ print( f"Scenario 1 - Multiple risks : {len(bcc_1.valid_predict_params)} " - "thresholds found that guarantee a precision of " - f"at least {target_levels_1[0]} and a recall of at least {target_levels_1[1]} " - f"with a confidence of {confidence_level_1}.\n" - "Among those, the one that maximizes the secondary objective " - "(here, recall, passed in `best_predict_param_choice`) is: " + f"thresholds found that guarantee a precision of at least {target_levels_1[0]}\n" + f"and a recall of at least {target_levels_1[1]} with a confidence of {confidence_level_1}." + "Among those, the one that maximizes\n" + "the secondary objective (here, recall, passed in `best_predict_param_choice`) is: " f"{bcc_1.best_predict_param:.3f}.\n" ) @@ -212,7 +211,6 @@ confidence_level=confidence_level_2, best_predict_param_choice="recall", ) -bcc_2.calibrate(X_calib, y_calib) with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") # Capture all warnings bcc_2.calibrate(X_calib, y_calib) @@ -227,17 +225,17 @@ proba_positive_class = clf.predict_proba(X_calib)[:, 1] scenarios = [ { - "name": "Scenario 1 - Mono Risk", + "name": "Scenario 1 - Mono risk", "bcc": [bcc_precision_1, bcc_recall_1], "target_levels": [target_levels_1[0], target_levels_1[1]], }, - {"name": "Scenario 1 - Multi Risk", "bcc": bcc_1, "target_levels": target_levels_1}, + {"name": "Scenario 1 - Multi risk", "bcc": bcc_1, "target_levels": target_levels_1}, { - "name": "Scenario 2 - Mono Risk", + "name": "Scenario 2 - Mono risk", "bcc": [bcc_precision_2, bcc_recall_2], "target_levels": [target_levels_2[0], target_levels_2[1]], }, - {"name": "Scenario 2 - Multi Risk", "bcc": bcc_2, "target_levels": target_levels_2}, + {"name": "Scenario 2 - Multi risk", "bcc": bcc_2, "target_levels": target_levels_2}, ] fig, axes = plt.subplots(2, 2, figsize=(16, 12), sharey=True) @@ -311,30 +309,51 @@ metrics["recall"][~valid_indices["recall"]], marker="p", facecolors="none", - edgecolors="tab:orange", + edgecolors="tab:red", label="Recall at invalid thresholds", ) if best_indices["precision"] is not None: - ax.scatter( - tested_thresholds[best_indices["precision"]], - metrics["precision"][best_indices["precision"]], - color="tab:green", - marker="*", - edgecolors="k", - s=200, - label="Precision best threshold", - ) + if scenario["name"] == "Scenario 1 - Multi risk": + ax.scatter( + tested_thresholds[best_indices["precision"]], + metrics["precision"][best_indices["precision"]], + color="tab:green", + marker="*", + edgecolors="k", + s=200, + label="Multi risk best threshold", + ) + else: + ax.scatter( + tested_thresholds[best_indices["precision"]], + metrics["precision"][best_indices["precision"]], + color="tab:green", + marker="*", + edgecolors="k", + s=200, + label="Precision best threshold", + ) if best_indices["recall"] is not None: - ax.scatter( - tested_thresholds[best_indices["recall"]], - metrics["recall"][best_indices["recall"]], - color="tab:blue", - marker="*", - edgecolors="k", - s=200, - label="Recall best threshold", - ) + if scenario["name"] == "Scenario 1 - Multi risk": + ax.scatter( + 
tested_thresholds[best_indices["recall"]], + metrics["recall"][best_indices["recall"]], + color="tab:green", + marker="*", + edgecolors="k", + s=200, + ) + else: + ax.scatter( + tested_thresholds[best_indices["recall"]], + metrics["recall"][best_indices["recall"]], + color="tab:blue", + marker="*", + edgecolors="k", + s=200, + label="Recall best threshold", + ) ax.axhline(target_precision, color="tab:gray", linestyle="--") ax.text( @@ -347,21 +366,19 @@ ) ax.axhline(target_recall, color="tab:blue", linestyle=":") ax.text( - 0.0, - target_recall + 0.02, + 0.4, + target_recall - 0.045, "Target recall", color="tab:blue", fontstyle="italic", fontsize=14, ) - - ax.set_xlabel("Threshold", fontsize=14) - ax.set_ylabel("Performance metric Value", fontsize=14) + ax.tick_params(axis="x", labelsize=16) + ax.tick_params(axis="y", labelsize=16) + ax.set_xlabel("Threshold", fontsize=16) + ax.set_ylabel("Performance metric value", fontsize=16) ax.set_title(scenario["name"], fontsize=16) ax.legend(fontsize=16) - -plt.xticks(fontsize=14) -plt.yticks(fontsize=14) plt.suptitle("Precision and recall by threshold for all scenarios", fontsize=18) plt.tight_layout(rect=[0, 0, 1, 0.95]) plt.show() @@ -432,3 +449,5 @@ plt.ylabel("Feature 2") plt.legend() plt.show() + +# %% From 0b13be019d73ace980f2e6d70b437de151479c26 Mon Sep 17 00:00:00 2001 From: Hassan Maissoro Date: Mon, 3 Nov 2025 14:00:31 +0100 Subject: [PATCH 09/11] update --- ...lot_multi_risk_control_binary_classification.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/examples/risk_control/2-advanced-analysis/plot_multi_risk_control_binary_classification.py b/examples/risk_control/2-advanced-analysis/plot_multi_risk_control_binary_classification.py index 389a07736..83d570d34 100644 --- a/examples/risk_control/2-advanced-analysis/plot_multi_risk_control_binary_classification.py +++ b/examples/risk_control/2-advanced-analysis/plot_multi_risk_control_binary_classification.py @@ -379,24 +379,12 @@ ax.set_ylabel("Performance metric value", fontsize=16) ax.set_title(scenario["name"], fontsize=16) ax.legend(fontsize=16) -plt.suptitle("Precision and recall by threshold for all scenarios", fontsize=18) +plt.suptitle("Precision and recall by threshold for all scenarios", fontsize=22) plt.tight_layout(rect=[0, 0, 1, 0.95]) plt.show() ############################################################################## -# Contrary to the naive way of computing a threshold to satisfy a precision and -# a recall targets on calibration data, risk control provides statistical guarantees -# on unseen data. In the plot above, we can see that not all thresholds corresponding -# to a precision (resp. recall) higher (resp. lower) than the target are valid. -# This is due to the uncertainty inherent to the finite size of the calibration set, -# which risk control takes into account. -# -# In particular, for instance, for precision, the highest threshold values are considered -# invalid due to the small number of observations used to compute the precision, -# following the Learn Then Test procedure. In the most extreme case, no observation -# is available, which causes the precision value to be ill-defined and set to 0. -# # In scenario 1, both the mono-risk controllers and the multi-risk controller found # valid thresholds that satisfy the precision and recall targets individually and jointly. # The jointly valid thresholds found by the multi-risk controller are shown as green markers in the plot. 
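##############################################################################
# To make the guarantee concrete, the sketch below evaluates the Scenario 1
# multi-risk controller on the held-out test set. This is illustrative only:
# the guarantee is probabilistic (it holds with probability ``confidence_level``
# over calibration sets), so a single test evaluation can occasionally fall
# short of the targets.

y_pred_test_1 = bcc_1.predict(X_test)
print(
    f"Test precision: {precision_score(y_test, y_pred_test_1):.3f} "
    f"(target {target_levels_1[0]}), "
    f"test recall: {recall_score(y_test, y_pred_test_1):.3f} "
    f"(target {target_levels_1[1]})"
)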
From 38ec278ef64b6c311df637888bfec937b1074fc2 Mon Sep 17 00:00:00 2001 From: Hassan Maissoro Date: Mon, 3 Nov 2025 14:12:23 +0100 Subject: [PATCH 10/11] update bcc docstring --- mapie/risk_control/binary_classification.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mapie/risk_control/binary_classification.py b/mapie/risk_control/binary_classification.py index fe80e72a0..fc73defd7 100644 --- a/mapie/risk_control/binary_classification.py +++ b/mapie/risk_control/binary_classification.py @@ -78,7 +78,8 @@ class BinaryClassificationController: (or maximize) a secondary objective. Valid options: - - "auto" (default) + - "auto" (default). For mono risk defined in mapie.risk_control, an automatic choice is made. + For multi risk, we use the first risk in the list. - An existing risk defined in `mapie.risk_control` accessible through its string equivalent: "precision", "recall", "accuracy", or "fpr" for false positive rate. From 38213f46d869135d4e452d2ecc24e659b848dcac Mon Sep 17 00:00:00 2001 From: Hassan Maissoro Date: Mon, 3 Nov 2025 14:17:05 +0100 Subject: [PATCH 11/11] remove cell run --- .../plot_multi_risk_control_binary_classification.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/risk_control/2-advanced-analysis/plot_multi_risk_control_binary_classification.py b/examples/risk_control/2-advanced-analysis/plot_multi_risk_control_binary_classification.py index 83d570d34..c5e73ca79 100644 --- a/examples/risk_control/2-advanced-analysis/plot_multi_risk_control_binary_classification.py +++ b/examples/risk_control/2-advanced-analysis/plot_multi_risk_control_binary_classification.py @@ -6,7 +6,7 @@ In this example, we explain how to do multi-risk control for binary classification with MAPIE. """ -# %% + import warnings import matplotlib.pyplot as plt @@ -437,5 +437,3 @@ plt.ylabel("Feature 2") plt.legend() plt.show() - -# %%
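##############################################################################
# The example's comment mentions ``sklearn.frozen.FrozenEstimator`` as an
# alternative to the extra ``fit`` call. A minimal sketch of that variant for
# the Scenario 1 threshold, assuming scikit-learn >= 1.6 (where
# ``FrozenEstimator`` is available): wrapping the already fitted classifier
# turns the wrapper's ``fit`` into a no-op, so the MLP is not re-trained.

from sklearn.frozen import FrozenEstimator

clf_threshold_frozen = FixedThresholdClassifier(
    FrozenEstimator(clf), threshold=bcc_1.best_predict_param
)
clf_threshold_frozen.fit(X_train, y_train)  # no-op for the frozen MLP
y_pred_frozen = clf_threshold_frozen.predict(X_test)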