diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 952381166..42bfd5512 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -73,5 +73,7 @@ jobs: run: make type-check - name: Test and coverage with pytest run: make coverage + - name: Test notebooks + run: make notebook-tests - name: Code coverage run: codecov diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 24cf43ce5..320d4e39d 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -81,6 +81,13 @@ The coverage should absolutely be 100%. $ make coverage +If your modification affects notebook-based experiments in ``mapie/tests/notebooks/``, +you should also ensure that all notebooks run successfully. + +.. code-block:: sh + + $ make notebook-tests + Documenting your change ----------------------- diff --git a/Makefile b/Makefile index 00b6121dd..9fc6999bd 100644 --- a/Makefile +++ b/Makefile @@ -22,7 +22,6 @@ coverage: --no-cov-on-fail \ --doctest-modules - ### Checks that are run in ReadTheDocs CI ### doc: $(MAKE) html -C doc @@ -39,7 +38,7 @@ all-checks: $(MAKE) coverage tests: - pytest -vs --doctest-modules mapie + pytest -vs --doctest-modules mapie --ignore=mapie/tests/notebooks clean-doc: $(MAKE) clean -C doc @@ -55,3 +54,8 @@ clean: rm -rf **__pycache__ $(MAKE) clean-build $(MAKE) clean-doc + +# Run all notebooks located in mapie/tests/notebooks/ +notebook-tests: + @echo "Executing all notebooks in mapie/tests/notebooks/..." + python mapie/tests/notebooks/_run_notebooks.py diff --git a/environment.ci.yml b/environment.ci.yml index 15c13cec1..606590fe0 100644 --- a/environment.ci.yml +++ b/environment.ci.yml @@ -8,3 +8,5 @@ dependencies: - mypy<1.15 - pandas - pytest-cov + - jupyter + - nbclient diff --git a/mapie/tests/notebooks/__init__.py b/mapie/tests/notebooks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mapie/tests/notebooks/_run_notebooks.py b/mapie/tests/notebooks/_run_notebooks.py new file mode 100644 index 000000000..19396c132 --- /dev/null +++ b/mapie/tests/notebooks/_run_notebooks.py @@ -0,0 +1,24 @@ +# pragma: no cover +import sys +from pathlib import Path + +from nbclient import NotebookClient +from nbformat import read + +NOTEBOOKS_DIR = Path(__file__).parent +notebooks = NOTEBOOKS_DIR.rglob("*.ipynb") + +if __name__ == "__main__": # pragma: no cover + for nb_path in notebooks: + print(f"Running {nb_path} ...") + try: + with nb_path.open() as f: + nb = read(f, as_version=4) + client = NotebookClient(nb, timeout=3600) + client.execute() + except Exception as e: + print(f"Notebook {nb_path} failed:\n{e}") + sys.exit(1) + + print("\nAll notebooks executed successfully. 100% OK.\n") + sys.exit(0) diff --git a/mapie/tests/notebooks/risk_control/theoretical_validity_tests.ipynb b/mapie/tests/notebooks/risk_control/theoretical_validity_tests.ipynb new file mode 100644 index 000000000..76aec016d --- /dev/null +++ b/mapie/tests/notebooks/risk_control/theoretical_validity_tests.ipynb @@ -0,0 +1,845 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "ed592eb3f8989aa8", + "metadata": {}, + "source": [ + "# Binary classification risk control - Theoretical tests to validate implementation" + ] + }, + { + "cell_type": "markdown", + "id": "aa3340c3", + "metadata": {}, + "source": [ + "# 1. The case of a random classifier" + ] + }, + { + "cell_type": "markdown", + "id": "8c1746b673c148dd", + "metadata": {}, + "source": [ + "## 1.1. Protocol description\n", + "We test the theoretical guarantees of risk control in binary classification by using a random classifier and synthetic data. The aim is to evaluate the effectiveness of the BinaryClassificationController in maintaining a predefined risk level under different conditions.\n", + "\n", + "Each test case looks at a combination of parameters, for which we repeat the experiment `n_repeat` times. The model is the same for all experiments (basically a random classifier), but the data is different each time.\n", + "\n", + "Each experiment consists of the following:\n", + " - We calibrate a BinaryClassificationController. It gives us the list of lambda values that control the risk according to LTT.\n", + " - Because we know that the model is random, we know the theoretical risk associated with each lambda value. So we are able to check if the lambda values given by LTT actually control the risk. If not, we count 1 \"error\". Note that *each* lambda value should control the risk, not just one of them.\n", + "\n", + "After n_repeat experiments, we compute the proportion of errors, that should be less than delta (1 - confidence_level).\n", + "\n", + "## 1.2. Results\n", + "The risk is controlled in all the test cases. Overall, LTT seems very conservative (to achieve a high percentage of errors, we need to lower the confidence level significantly (0.01) and use only one threshold to avoid the Bonferroni effect). But this is likely due to the model being random, and thus having a lot of variance. It would be interesting to see how this evolves with a better model." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "9b1422ae620955fd", + "metadata": { + "ExecuteTime": { + "end_time": "2025-09-15T16:21:19.107147Z", + "start_time": "2025-09-15T16:21:19.071278Z" + } + }, + "outputs": [], + "source": [ + "%reload_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "faeb2f47a92dbf35", + "metadata": { + "ExecuteTime": { + "end_time": "2025-09-15T16:21:20.596927Z", + "start_time": "2025-09-15T16:21:19.127705Z" + } + }, + "outputs": [], + "source": [ + "from sklearn.datasets import make_classification\n", + "from sklearn.dummy import check_random_state\n", + "from sklearn.metrics import precision_score, recall_score, accuracy_score\n", + "import numpy as np\n", + "from mapie.risk_control import BinaryClassificationController, precision, accuracy, recall\n", + "from itertools import product\n", + "from decimal import Decimal" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "eefafd6d1697fb9c", + "metadata": { + "ExecuteTime": { + "end_time": "2025-09-15T16:21:20.802766Z", + "start_time": "2025-09-15T16:21:20.652168Z" + } + }, + "outputs": [], + "source": [ + "# Using sklearn.dummy.DummyClassifier would be cleaner\n", + "class RandomClassifier:\n", + " def __init__(self, seed=42, threshold=0.5):\n", + " self.seed = seed\n", + " self.threshold = threshold\n", + "\n", + " def _get_prob(self, x):\n", + " local_seed = hash((x, self.seed)) % (2**32)\n", + " rng = np.random.RandomState(local_seed)\n", + " return np.round(rng.rand(), 2)\n", + "\n", + " def predict_proba(self, X):\n", + " probs = np.array([self._get_prob(x) for x in X])\n", + " return np.vstack([1 - probs, probs]).T\n", + "\n", + " def predict(self, X):\n", + " probs = self.predict_proba(X)[:, 1]\n", + " return (probs >= self.threshold).astype(int)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "1fdffae392bb7a65", + "metadata": { + "ExecuteTime": { + "end_time": "2025-09-15T16:21:54.095452Z", + "start_time": "2025-09-15T16:21:20.810388Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "N=2000, risk['name']='precision', len(predict_params)=100, target_level=0.1, confidence_level=0.8\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.2\n", + "Mean number of valid thresholds found per iteration: 100\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='precision', len(predict_params)=100, target_level=0.1, confidence_level=0.2\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.8\n", + "Mean number of valid thresholds found per iteration: 100\n", + "Valid experiment\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/hmaissoro/Projects/MAPIE/mapie/risk_control/binary_classification.py:303: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "N=2000, risk['name']='precision', len(predict_params)=100, target_level=0.9, confidence_level=0.8\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.2\n", + "Mean number of valid thresholds found per iteration: 0\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='precision', len(predict_params)=100, target_level=0.9, confidence_level=0.2\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.8\n", + "Mean number of valid thresholds found per iteration: 0\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='precision', len(predict_params)=1, target_level=0.1, confidence_level=0.8\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.2\n", + "Mean number of valid thresholds found per iteration: 1\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='precision', len(predict_params)=1, target_level=0.1, confidence_level=0.2\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.8\n", + "Mean number of valid thresholds found per iteration: 1\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='precision', len(predict_params)=1, target_level=0.9, confidence_level=0.8\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.2\n", + "Mean number of valid thresholds found per iteration: 0\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='precision', len(predict_params)=1, target_level=0.9, confidence_level=0.2\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.8\n", + "Mean number of valid thresholds found per iteration: 0\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='recall', len(predict_params)=100, target_level=0.1, confidence_level=0.8\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.2\n", + "Mean number of valid thresholds found per iteration: 88\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='recall', len(predict_params)=100, target_level=0.1, confidence_level=0.2\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.8\n", + "Mean number of valid thresholds found per iteration: 89\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='recall', len(predict_params)=100, target_level=0.9, confidence_level=0.8\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.2\n", + "Mean number of valid thresholds found per iteration: 8\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='recall', len(predict_params)=100, target_level=0.9, confidence_level=0.2\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.8\n", + "Mean number of valid thresholds found per iteration: 9\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='recall', len(predict_params)=1, target_level=0.1, confidence_level=0.8\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.2\n", + "Mean number of valid thresholds found per iteration: 1\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='recall', len(predict_params)=1, target_level=0.1, confidence_level=0.2\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.8\n", + "Mean number of valid thresholds found per iteration: 1\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='recall', len(predict_params)=1, target_level=0.9, confidence_level=0.8\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.2\n", + "Mean number of valid thresholds found per iteration: 0\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='recall', len(predict_params)=1, target_level=0.9, confidence_level=0.2\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.8\n", + "Mean number of valid thresholds found per iteration: 0\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='accuracy', len(predict_params)=100, target_level=0.1, confidence_level=0.8\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.2\n", + "Mean number of valid thresholds found per iteration: 100\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='accuracy', len(predict_params)=100, target_level=0.1, confidence_level=0.2\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.8\n", + "Mean number of valid thresholds found per iteration: 100\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='accuracy', len(predict_params)=100, target_level=0.9, confidence_level=0.8\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.2\n", + "Mean number of valid thresholds found per iteration: 0\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='accuracy', len(predict_params)=100, target_level=0.9, confidence_level=0.2\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.8\n", + "Mean number of valid thresholds found per iteration: 0\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='accuracy', len(predict_params)=1, target_level=0.1, confidence_level=0.8\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.2\n", + "Mean number of valid thresholds found per iteration: 1\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='accuracy', len(predict_params)=1, target_level=0.1, confidence_level=0.2\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.8\n", + "Mean number of valid thresholds found per iteration: 1\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='accuracy', len(predict_params)=1, target_level=0.9, confidence_level=0.8\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.2\n", + "Mean number of valid thresholds found per iteration: 0\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='accuracy', len(predict_params)=1, target_level=0.9, confidence_level=0.2\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.8\n", + "Mean number of valid thresholds found per iteration: 0\n", + "Valid experiment\n", + "\n", + "\n", + "\n", + "\n", + "All good!\n" + ] + } + ], + "source": [ + "N = 2000 # size of the calibration set\n", + "risk = [\n", + " {\"name\": \"precision\", \"risk\": precision},\n", + " {\"name\": \"recall\", \"risk\": recall},\n", + " {\"name\": \"accuracy\", \"risk\": accuracy},\n", + "]\n", + "predict_params = [np.linspace(0, 0.99, 100), np.empty(1)]\n", + "target_level = [0.1, 0.9]\n", + "confidence_level = [0.8, 0.2]\n", + "\n", + "n_repeats = 100\n", + "invalid_experiment = False\n", + "\n", + "for combination in product(risk, predict_params, target_level, confidence_level):\n", + " risk, predict_params, target_level, confidence_level = combination\n", + " if len(predict_params) == 1:\n", + " predict_params = np.array([np.random.choice(np.linspace(0, 0.9, 10))]) # random threshold\n", + " alpha = float(Decimal(\"1\") - Decimal(str(target_level))) # to avoid floating point issues\n", + " delta = float(Decimal(\"1\") - Decimal(str(confidence_level))) # to avoid floating point issues\n", + "\n", + " clf = RandomClassifier()\n", + " nb_errors = 0 # number of iterations where the risk is not controlled (i.e., not all the valid thresholds found by LTT are actually valid)\n", + " total_nb_valid_params = 0\n", + "\n", + " for _ in range(n_repeats):\n", + "\n", + " X_calibrate, y_calibrate = make_classification(\n", + " n_samples=N,\n", + " n_features=1,\n", + " n_informative=1,\n", + " n_redundant=0,\n", + " n_repeated=0,\n", + " n_classes=2,\n", + " n_clusters_per_class=1,\n", + " weights=[0.5, 0.5],\n", + " flip_y=0,\n", + " random_state=None\n", + " )\n", + " X_calibrate = X_calibrate.squeeze()\n", + "\n", + " controller = BinaryClassificationController(\n", + " predict_function=clf.predict_proba,\n", + " risk=risk[\"risk\"],\n", + " target_level=target_level,\n", + " confidence_level=confidence_level,\n", + " )\n", + " controller._predict_params = predict_params\n", + " controller.calibrate(X_calibrate, y_calibrate)\n", + " valid_parameters = controller.valid_predict_params\n", + " total_nb_valid_params += len(valid_parameters)\n", + "\n", + " # In the following, we check that all the valid thresholds found by LTT actually control the risk.\n", + " # Instead of sampling a large test set, we use the fact that we know the theoretical risk of a random classifier.\n", + " # The calculations here are valid only for a balanced data generator.\n", + " if risk[\"risk\"] == precision or risk[\"risk\"] == accuracy:\n", + " if target_level > 0.5 and len(valid_parameters) >= 1:\n", + " nb_errors += 1\n", + " elif risk[\"risk\"] == recall:\n", + " if any(x > alpha for x in valid_parameters) and len(valid_parameters) >= 1:\n", + " nb_errors += 1\n", + "\n", + " print(f\"\\n{N=}, {risk['name']=}, {len(predict_params)=}, {target_level=}, {confidence_level=}\")\n", + "\n", + " print(f\"Proportion of times the risk is not controlled: {nb_errors/n_repeats}\")\n", + " print(f\"Delta: {delta}\")\n", + " print(f\"Mean number of valid thresholds found per iteration: {int(np.round(total_nb_valid_params/n_repeats))}\")\n", + "\n", + " if nb_errors/n_repeats <= delta:\n", + " print(\"Valid experiment\")\n", + " else:\n", + " print(\"Invalid experiment\")\n", + " invalid_experiment = True\n", + "\n", + "print(\"\\n\\n\\n\")\n", + "if invalid_experiment:\n", + " print(\"Some experiments failed.\")\n", + "else:\n", + " print(\"All good!\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "4c3f437f0b2897a1", + "metadata": { + "ExecuteTime": { + "end_time": "2025-09-15T16:22:49.189292Z", + "start_time": "2025-09-15T16:22:48.871875Z" + } + }, + "outputs": [], + "source": [ + "assert not invalid_experiment" + ] + }, + { + "cell_type": "markdown", + "id": "6745fa4e", + "metadata": {}, + "source": [ + "# 2. The case of a logistic classifier" + ] + }, + { + "cell_type": "markdown", + "id": "f099c00a", + "metadata": {}, + "source": [ + "## 2.1. Protocol description\n", + "We use the same protocol as described above, with the difference that we employ a logistic classifier and synthetic data generated from a logistic data generator.\n", + "\n", + "Another difference is that, with this model, we do not know the theoretical risk associated with each value of lambda.\n", + "Therefore, we generate a sample of test data to estimate the risk for each lambda value.\n", + "\n", + "## 2.2. Results\n", + "The risk is controlled in all test cases. Overall, LTT appears to be very conservative. It would be interesting to observe how this behavior evolves with a more accurate or complex model." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "b07c3a91", + "metadata": {}, + "outputs": [], + "source": [ + "# Define a simple logistic classifier\n", + "class LogisticClassifier:\n", + " \"\"\"Deterministic sigmoid-based binary classifier.\"\"\"\n", + "\n", + " def __init__(self, scale=2.0, threshold=0.5):\n", + " self.scale = scale\n", + " self.threshold = threshold\n", + "\n", + " def _get_prob(self, x):\n", + " \"\"\"Probability of class 1 for input x.\"\"\"\n", + " return 1 / (1 + np.exp(-self.scale * x))\n", + "\n", + " def predict_proba(self, X):\n", + " \"\"\"Return probabilities [p(y=0), p(y=1)] for each sample in X.\"\"\"\n", + " probs = np.array([self._get_prob(x) for x in X])\n", + " return np.vstack([1 - probs, probs]).T\n", + "\n", + " def predict(self, X):\n", + " \"\"\"Return predicted class labels based on threshold.\"\"\"\n", + " probs = self.predict_proba(X)[:, 1]\n", + " return (probs >= self.threshold).astype(int)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "2591d8b4", + "metadata": {}, + "outputs": [], + "source": [ + "# Function to generate logistic data\n", + "def make_logistic_data(n_samples=200, scale=2.0, random_state=None):\n", + " rng = check_random_state(random_state)\n", + " X = rng.uniform(-3, 3, size=n_samples)\n", + " probs = 1 / (1 + np.exp(- scale * X))\n", + " y = rng.binomial(1, probs)\n", + " return X, y" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "5f71a565", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "N=2000, risk['name']='precision', len(predict_params)=100, target_level=0.1, confidence_level=0.8\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.2\n", + "Mean number of valid thresholds found per iteration: 100\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='precision', len(predict_params)=100, target_level=0.1, confidence_level=0.2\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.8\n", + "Mean number of valid thresholds found per iteration: 100\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='precision', len(predict_params)=100, target_level=0.9, confidence_level=0.8\n", + "Proportion of times the risk is not controlled: 0.01\n", + "Delta: 0.2\n", + "Mean number of valid thresholds found per iteration: 31\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='precision', len(predict_params)=100, target_level=0.9, confidence_level=0.2\n", + "Proportion of times the risk is not controlled: 0.02\n", + "Delta: 0.8\n", + "Mean number of valid thresholds found per iteration: 32\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='precision', len(predict_params)=1, target_level=0.1, confidence_level=0.8\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.2\n", + "Mean number of valid thresholds found per iteration: 1\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='precision', len(predict_params)=1, target_level=0.1, confidence_level=0.2\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.8\n", + "Mean number of valid thresholds found per iteration: 1\n", + "Valid experiment\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/hmaissoro/Projects/MAPIE/mapie/risk_control/binary_classification.py:303: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", + " warnings.warn(\n", + "/Users/hmaissoro/Projects/MAPIE/mapie/risk_control/binary_classification.py:303: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", + " warnings.warn(\n", + "/Users/hmaissoro/Projects/MAPIE/mapie/risk_control/binary_classification.py:303: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", + " warnings.warn(\n", + "/Users/hmaissoro/Projects/MAPIE/mapie/risk_control/binary_classification.py:303: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", + " warnings.warn(\n", + "/Users/hmaissoro/Projects/MAPIE/mapie/risk_control/binary_classification.py:303: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", + " warnings.warn(\n", + "/Users/hmaissoro/Projects/MAPIE/mapie/risk_control/binary_classification.py:303: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", + " warnings.warn(\n", + "/Users/hmaissoro/Projects/MAPIE/mapie/risk_control/binary_classification.py:303: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", + " warnings.warn(\n", + "/Users/hmaissoro/Projects/MAPIE/mapie/risk_control/binary_classification.py:303: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", + " warnings.warn(\n", + "/Users/hmaissoro/Projects/MAPIE/mapie/risk_control/binary_classification.py:303: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", + " warnings.warn(\n", + "/Users/hmaissoro/Projects/MAPIE/mapie/risk_control/binary_classification.py:303: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", + " warnings.warn(\n", + "/Users/hmaissoro/Projects/MAPIE/mapie/risk_control/binary_classification.py:303: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", + " warnings.warn(\n", + "/Users/hmaissoro/Projects/MAPIE/mapie/risk_control/binary_classification.py:303: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", + " warnings.warn(\n", + "/Users/hmaissoro/Projects/MAPIE/mapie/risk_control/binary_classification.py:303: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", + " warnings.warn(\n", + "/Users/hmaissoro/Projects/MAPIE/mapie/risk_control/binary_classification.py:303: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", + " warnings.warn(\n", + "/Users/hmaissoro/Projects/MAPIE/mapie/risk_control/binary_classification.py:303: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", + " warnings.warn(\n", + "/Users/hmaissoro/Projects/MAPIE/mapie/risk_control/binary_classification.py:303: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", + " warnings.warn(\n", + "/Users/hmaissoro/Projects/MAPIE/mapie/risk_control/binary_classification.py:303: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", + " warnings.warn(\n", + "/Users/hmaissoro/Projects/MAPIE/mapie/risk_control/binary_classification.py:303: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", + " warnings.warn(\n", + "/Users/hmaissoro/Projects/MAPIE/mapie/risk_control/binary_classification.py:303: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", + " warnings.warn(\n", + "/Users/hmaissoro/Projects/MAPIE/mapie/risk_control/binary_classification.py:303: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", + " warnings.warn(\n", + "/Users/hmaissoro/Projects/MAPIE/mapie/risk_control/binary_classification.py:303: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", + " warnings.warn(\n", + "/Users/hmaissoro/Projects/MAPIE/mapie/risk_control/binary_classification.py:303: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", + " warnings.warn(\n", + "/Users/hmaissoro/Projects/MAPIE/mapie/risk_control/binary_classification.py:303: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", + " warnings.warn(\n", + "/Users/hmaissoro/Projects/MAPIE/mapie/risk_control/binary_classification.py:303: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", + " warnings.warn(\n", + "/Users/hmaissoro/Projects/MAPIE/mapie/risk_control/binary_classification.py:303: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", + " warnings.warn(\n", + "/Users/hmaissoro/Projects/MAPIE/mapie/risk_control/binary_classification.py:303: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", + " warnings.warn(\n", + "/Users/hmaissoro/Projects/MAPIE/mapie/risk_control/binary_classification.py:303: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "N=2000, risk['name']='precision', len(predict_params)=1, target_level=0.9, confidence_level=0.8\n", + "Proportion of times the risk is not controlled: 0.09\n", + "Delta: 0.2\n", + "Mean number of valid thresholds found per iteration: 1\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='precision', len(predict_params)=1, target_level=0.9, confidence_level=0.2\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.8\n", + "Mean number of valid thresholds found per iteration: 0\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='recall', len(predict_params)=100, target_level=0.1, confidence_level=0.8\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.2\n", + "Mean number of valid thresholds found per iteration: 100\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='recall', len(predict_params)=100, target_level=0.1, confidence_level=0.2\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.8\n", + "Mean number of valid thresholds found per iteration: 100\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='recall', len(predict_params)=100, target_level=0.9, confidence_level=0.8\n", + "Proportion of times the risk is not controlled: 0.02\n", + "Delta: 0.2\n", + "Mean number of valid thresholds found per iteration: 36\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='recall', len(predict_params)=100, target_level=0.9, confidence_level=0.2\n", + "Proportion of times the risk is not controlled: 0.02\n", + "Delta: 0.8\n", + "Mean number of valid thresholds found per iteration: 38\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='recall', len(predict_params)=1, target_level=0.1, confidence_level=0.8\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.2\n", + "Mean number of valid thresholds found per iteration: 1\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='recall', len(predict_params)=1, target_level=0.1, confidence_level=0.2\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.8\n", + "Mean number of valid thresholds found per iteration: 1\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='recall', len(predict_params)=1, target_level=0.9, confidence_level=0.8\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.2\n", + "Mean number of valid thresholds found per iteration: 1\n", + "Valid experiment\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/hmaissoro/Projects/MAPIE/mapie/risk_control/binary_classification.py:303: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "N=2000, risk['name']='recall', len(predict_params)=1, target_level=0.9, confidence_level=0.2\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.8\n", + "Mean number of valid thresholds found per iteration: 0\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='accuracy', len(predict_params)=100, target_level=0.1, confidence_level=0.8\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.2\n", + "Mean number of valid thresholds found per iteration: 100\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='accuracy', len(predict_params)=100, target_level=0.1, confidence_level=0.2\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.8\n", + "Mean number of valid thresholds found per iteration: 100\n", + "Valid experiment\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/hmaissoro/Projects/MAPIE/mapie/risk_control/binary_classification.py:303: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "N=2000, risk['name']='accuracy', len(predict_params)=100, target_level=0.9, confidence_level=0.8\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.2\n", + "Mean number of valid thresholds found per iteration: 0\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='accuracy', len(predict_params)=100, target_level=0.9, confidence_level=0.2\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.8\n", + "Mean number of valid thresholds found per iteration: 0\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='accuracy', len(predict_params)=1, target_level=0.1, confidence_level=0.8\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.2\n", + "Mean number of valid thresholds found per iteration: 1\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='accuracy', len(predict_params)=1, target_level=0.1, confidence_level=0.2\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.8\n", + "Mean number of valid thresholds found per iteration: 1\n", + "Valid experiment\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/hmaissoro/Projects/MAPIE/mapie/risk_control/binary_classification.py:303: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "N=2000, risk['name']='accuracy', len(predict_params)=1, target_level=0.9, confidence_level=0.8\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.2\n", + "Mean number of valid thresholds found per iteration: 0\n", + "Valid experiment\n", + "\n", + "N=2000, risk['name']='accuracy', len(predict_params)=1, target_level=0.9, confidence_level=0.2\n", + "Proportion of times the risk is not controlled: 0.0\n", + "Delta: 0.8\n", + "Mean number of valid thresholds found per iteration: 0\n", + "Valid experiment\n", + "\n", + "\n", + "\n", + "\n", + "All good!\n" + ] + } + ], + "source": [ + "N = 2000\n", + "risk = [\n", + " {\"name\": \"precision\", \"risk\": precision},\n", + " {\"name\": \"recall\", \"risk\": recall},\n", + " {\"name\": \"accuracy\", \"risk\": accuracy},\n", + "]\n", + "predict_params = [np.linspace(0, 0.99, 100), np.empty(1)]\n", + "target_level = [0.1, 0.9]\n", + "confidence_level = [0.8, 0.2]\n", + "\n", + "n_repeats = 100\n", + "invalid_experiment = False\n", + "\n", + "for combination in product(risk, predict_params, target_level, confidence_level):\n", + " risk, predict_params, target_level, confidence_level = combination\n", + " if len(predict_params) == 1:\n", + " predict_params = np.array([np.random.choice(np.linspace(0, 0.9, 10))]) # random threshold\n", + " alpha = float(Decimal(\"1\") - Decimal(str(target_level))) # to avoid floating point issues\n", + " delta = float(Decimal(\"1\") - Decimal(str(confidence_level))) # to avoid floating point issues\n", + "\n", + " clf = LogisticClassifier(scale=2.0, threshold=0.5)\n", + " nb_errors = 0 # number of iterations where the risk is not controlled (i.e., not all the valid thresholds found by LTT are actually valid)\n", + " total_nb_valid_params = 0\n", + "\n", + " for _ in range(n_repeats):\n", + "\n", + " X_calibrate, y_calibrate = make_logistic_data(n_samples=N, scale=2.0, random_state=None)\n", + "\n", + " controller = BinaryClassificationController(\n", + " predict_function=clf.predict_proba,\n", + " risk=risk[\"risk\"],\n", + " target_level=target_level,\n", + " confidence_level=confidence_level,\n", + " )\n", + " controller._predict_params = predict_params\n", + " controller = controller.calibrate(X_calibrate, y_calibrate)\n", + " valid_parameters = controller.valid_predict_params\n", + " total_nb_valid_params += len(valid_parameters)\n", + "\n", + " # In the following, we check that all the valid thresholds found by LTT actually control the risk.\n", + " # We sample a large test set and estimate the risk for each valid_parameters using the logistic classifier.\n", + " X_test, y_test = make_logistic_data(n_samples=N, scale=2.0, random_state=None)\n", + " probs = clf.predict_proba(X_test)[:, 1]\n", + " \n", + " # If no valid parameters found, risk is not controlled\n", + " if len(valid_parameters) >= 1:\n", + " for lambda_ in valid_parameters:\n", + " y_pred = (probs >= lambda_).astype(int)\n", + "\n", + " if risk[\"risk\"] == precision:\n", + " empirical_metric = precision_score(y_test, y_pred, zero_division=0)\n", + " elif risk[\"risk\"] == recall:\n", + " empirical_metric = recall_score(y_test, y_pred, zero_division=0)\n", + " elif risk[\"risk\"] == accuracy:\n", + " empirical_metric = accuracy_score(y_test, y_pred)\n", + "\n", + " # Check if the risk control fails\n", + " if risk[\"risk\"].higher_is_better:\n", + " if empirical_metric <= target_level:\n", + " nb_errors += 1\n", + " break \n", + " else:\n", + " if empirical_metric > target_level:\n", + " nb_errors += 1\n", + " break\n", + "\n", + " print(f\"\\n{N=}, {risk['name']=}, {len(predict_params)=}, {target_level=}, {confidence_level=}\")\n", + "\n", + " print(f\"Proportion of times the risk is not controlled: {nb_errors/n_repeats}\")\n", + " print(f\"Delta: {delta}\")\n", + " print(f\"Mean number of valid thresholds found per iteration: {int(np.round(total_nb_valid_params/n_repeats))}\")\n", + "\n", + " if nb_errors/n_repeats <= delta:\n", + " print(\"Valid experiment\")\n", + " else:\n", + " print(\"Invalid experiment\")\n", + " invalid_experiment = True\n", + "\n", + "print(\"\\n\\n\\n\")\n", + "if invalid_experiment:\n", + " print(\"Some experiments failed.\")\n", + "else:\n", + " print(\"All good!\")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "2c0c81fe", + "metadata": {}, + "outputs": [], + "source": [ + "assert not invalid_experiment" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": ".venv-dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.18" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/risk_control/theoretical_validity_tests.ipynb b/notebooks/risk_control/theoretical_validity_tests.ipynb deleted file mode 100644 index a6d705414..000000000 --- a/notebooks/risk_control/theoretical_validity_tests.ipynb +++ /dev/null @@ -1,542 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "ed592eb3f8989aa8", - "metadata": {}, - "source": [ - "# Binary classification risk control - Theoretical tests to validate implementation" - ] - }, - { - "cell_type": "markdown", - "id": "8c1746b673c148dd", - "metadata": {}, - "source": [ - "# Protocol description\n", - "Testing theoretical guarantees of risk control in binary classification using a random classifier and synthetic data.\n", - "\n", - "Each test case looks at a combination of parameters, for which we repeat the experiment `n_repeat` times. The model is the same for all experiments (basically a random classifier), but the data is different each time.\n", - "\n", - "Each experiment consists of the following:\n", - " - We calibrate a BinaryClassificationController. It gives us the list of lambda values that control the risk according to LTT.\n", - " - Because we know that the model is random, we know the theoretical risk associated with each lambda value. So we are able to check if the lambda values given by LTT actually control the risk. If not, we count 1 \"error\". Note that *each* lambda value should control the risk, not just one of them.\n", - "\n", - "After n_repeat experiments, we compute the proportion of errors, that should be less than delta (1 - confidence_level).\n", - "\n", - "# Results\n", - "The risk is controlled in all the test cases. Overall, LTT seems very conservative (to achieve a high percentage of errors, we need to lower the confidence level significantly (0.01) and use only one threshold to avoid the Bonferroni effect). But this is likely due to the model being random, and thus having a lot of variance. It would be interesting to see how this evolves with a better model." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "9b1422ae620955fd", - "metadata": { - "ExecuteTime": { - "end_time": "2025-09-15T16:21:19.107147Z", - "start_time": "2025-09-15T16:21:19.071278Z" - } - }, - "outputs": [], - "source": [ - "%reload_ext autoreload\n", - "%autoreload 2" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "faeb2f47a92dbf35", - "metadata": { - "ExecuteTime": { - "end_time": "2025-09-15T16:21:20.596927Z", - "start_time": "2025-09-15T16:21:19.127705Z" - } - }, - "outputs": [], - "source": [ - "from sklearn.datasets import make_classification\n", - "import numpy as np\n", - "from mapie.risk_control import BinaryClassificationController, precision, accuracy, recall\n", - "from itertools import product\n", - "from decimal import Decimal" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "eefafd6d1697fb9c", - "metadata": { - "ExecuteTime": { - "end_time": "2025-09-15T16:21:20.802766Z", - "start_time": "2025-09-15T16:21:20.652168Z" - } - }, - "outputs": [], - "source": [ - "# Using sklearn.dummy.DummyClassifier would be cleaner\n", - "class RandomClassifier:\n", - " def __init__(self, seed=42, threshold=0.5):\n", - " self.seed = seed\n", - " self.threshold = threshold\n", - "\n", - " def _get_prob(self, x):\n", - " local_seed = hash((x, self.seed)) % (2**32)\n", - " rng = np.random.RandomState(local_seed)\n", - " return np.round(rng.rand(), 2)\n", - "\n", - " def predict_proba(self, X):\n", - " probs = np.array([self._get_prob(x) for x in X])\n", - " return np.vstack([1 - probs, probs]).T\n", - "\n", - " def predict(self, X):\n", - " probs = self.predict_proba(X)[:, 1]\n", - " return (probs >= self.threshold).astype(int)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1fdffae392bb7a65", - "metadata": { - "ExecuteTime": { - "end_time": "2025-09-15T16:21:54.095452Z", - "start_time": "2025-09-15T16:21:20.810388Z" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "N=100, risk['name']='precision', len(predict_params)=100, target_level=0.1, confidence_level=0.8\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.2\n", - "Mean number of valid thresholds found per iteration: 90\n", - "Valid experiment\n", - "\n", - "N=100, risk['name']='precision', len(predict_params)=100, target_level=0.1, confidence_level=0.2\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.8\n", - "Mean number of valid thresholds found per iteration: 92\n", - "Valid experiment\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/vlaurent/code/pro/MAPIE/MAPIE/mapie/risk_control.py:891: UserWarning: No predict parameters were found to control the risk at the given target and confidence levels. Try using a larger calibration set or a better model.\n", - " warnings.warn(\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "N=100, risk['name']='precision', len(predict_params)=100, target_level=0.9, confidence_level=0.8\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.2\n", - "Mean number of valid thresholds found per iteration: 0\n", - "Valid experiment\n", - "\n", - "N=100, risk['name']='precision', len(predict_params)=100, target_level=0.9, confidence_level=0.2\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.8\n", - "Mean number of valid thresholds found per iteration: 0\n", - "Valid experiment\n", - "\n", - "N=100, risk['name']='precision', len(predict_params)=1, target_level=0.1, confidence_level=0.8\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.2\n", - "Mean number of valid thresholds found per iteration: 1\n", - "Valid experiment\n", - "\n", - "N=100, risk['name']='precision', len(predict_params)=1, target_level=0.1, confidence_level=0.2\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.8\n", - "Mean number of valid thresholds found per iteration: 1\n", - "Valid experiment\n", - "\n", - "N=100, risk['name']='precision', len(predict_params)=1, target_level=0.9, confidence_level=0.8\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.2\n", - "Mean number of valid thresholds found per iteration: 0\n", - "Valid experiment\n", - "\n", - "N=100, risk['name']='precision', len(predict_params)=1, target_level=0.9, confidence_level=0.2\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.8\n", - "Mean number of valid thresholds found per iteration: 0\n", - "Valid experiment\n", - "\n", - "N=100, risk['name']='recall', len(predict_params)=100, target_level=0.1, confidence_level=0.8\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.2\n", - "Mean number of valid thresholds found per iteration: 74\n", - "Valid experiment\n", - "\n", - "N=100, risk['name']='recall', len(predict_params)=100, target_level=0.1, confidence_level=0.2\n", - "Proportion of times the risk is not controlled: 0.01\n", - "Delta: 0.8\n", - "Mean number of valid thresholds found per iteration: 78\n", - "Valid experiment\n", - "\n", - "N=100, risk['name']='recall', len(predict_params)=100, target_level=0.9, confidence_level=0.8\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.2\n", - "Mean number of valid thresholds found per iteration: 0\n", - "Valid experiment\n", - "\n", - "N=100, risk['name']='recall', len(predict_params)=100, target_level=0.9, confidence_level=0.2\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.8\n", - "Mean number of valid thresholds found per iteration: 3\n", - "Valid experiment\n", - "\n", - "N=100, risk['name']='recall', len(predict_params)=1, target_level=0.1, confidence_level=0.8\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.2\n", - "Mean number of valid thresholds found per iteration: 1\n", - "Valid experiment\n", - "\n", - "N=100, risk['name']='recall', len(predict_params)=1, target_level=0.1, confidence_level=0.2\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.8\n", - "Mean number of valid thresholds found per iteration: 1\n", - "Valid experiment\n", - "\n", - "N=100, risk['name']='recall', len(predict_params)=1, target_level=0.9, confidence_level=0.8\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.2\n", - "Mean number of valid thresholds found per iteration: 0\n", - "Valid experiment\n", - "\n", - "N=100, risk['name']='recall', len(predict_params)=1, target_level=0.9, confidence_level=0.2\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.8\n", - "Mean number of valid thresholds found per iteration: 0\n", - "Valid experiment\n", - "\n", - "N=100, risk['name']='accuracy', len(predict_params)=100, target_level=0.1, confidence_level=0.8\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.2\n", - "Mean number of valid thresholds found per iteration: 100\n", - "Valid experiment\n", - "\n", - "N=100, risk['name']='accuracy', len(predict_params)=100, target_level=0.1, confidence_level=0.2\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.8\n", - "Mean number of valid thresholds found per iteration: 100\n", - "Valid experiment\n", - "\n", - "N=100, risk['name']='accuracy', len(predict_params)=100, target_level=0.9, confidence_level=0.8\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.2\n", - "Mean number of valid thresholds found per iteration: 0\n", - "Valid experiment\n", - "\n", - "N=100, risk['name']='accuracy', len(predict_params)=100, target_level=0.9, confidence_level=0.2\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.8\n", - "Mean number of valid thresholds found per iteration: 0\n", - "Valid experiment\n", - "\n", - "N=100, risk['name']='accuracy', len(predict_params)=1, target_level=0.1, confidence_level=0.8\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.2\n", - "Mean number of valid thresholds found per iteration: 1\n", - "Valid experiment\n", - "\n", - "N=100, risk['name']='accuracy', len(predict_params)=1, target_level=0.1, confidence_level=0.2\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.8\n", - "Mean number of valid thresholds found per iteration: 1\n", - "Valid experiment\n", - "\n", - "N=100, risk['name']='accuracy', len(predict_params)=1, target_level=0.9, confidence_level=0.8\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.2\n", - "Mean number of valid thresholds found per iteration: 0\n", - "Valid experiment\n", - "\n", - "N=100, risk['name']='accuracy', len(predict_params)=1, target_level=0.9, confidence_level=0.2\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.8\n", - "Mean number of valid thresholds found per iteration: 0\n", - "Valid experiment\n", - "\n", - "N=5, risk['name']='precision', len(predict_params)=100, target_level=0.1, confidence_level=0.8\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.2\n", - "Mean number of valid thresholds found per iteration: 0\n", - "Valid experiment\n", - "\n", - "N=5, risk['name']='precision', len(predict_params)=100, target_level=0.1, confidence_level=0.2\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.8\n", - "Mean number of valid thresholds found per iteration: 0\n", - "Valid experiment\n", - "\n", - "N=5, risk['name']='precision', len(predict_params)=100, target_level=0.9, confidence_level=0.8\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.2\n", - "Mean number of valid thresholds found per iteration: 0\n", - "Valid experiment\n", - "\n", - "N=5, risk['name']='precision', len(predict_params)=100, target_level=0.9, confidence_level=0.2\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.8\n", - "Mean number of valid thresholds found per iteration: 0\n", - "Valid experiment\n", - "\n", - "N=5, risk['name']='precision', len(predict_params)=1, target_level=0.1, confidence_level=0.8\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.2\n", - "Mean number of valid thresholds found per iteration: 0\n", - "Valid experiment\n", - "\n", - "N=5, risk['name']='precision', len(predict_params)=1, target_level=0.1, confidence_level=0.2\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.8\n", - "Mean number of valid thresholds found per iteration: 1\n", - "Valid experiment\n", - "\n", - "N=5, risk['name']='precision', len(predict_params)=1, target_level=0.9, confidence_level=0.8\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.2\n", - "Mean number of valid thresholds found per iteration: 0\n", - "Valid experiment\n", - "\n", - "N=5, risk['name']='precision', len(predict_params)=1, target_level=0.9, confidence_level=0.2\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.8\n", - "Mean number of valid thresholds found per iteration: 0\n", - "Valid experiment\n", - "\n", - "N=5, risk['name']='recall', len(predict_params)=100, target_level=0.1, confidence_level=0.8\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.2\n", - "Mean number of valid thresholds found per iteration: 0\n", - "Valid experiment\n", - "\n", - "N=5, risk['name']='recall', len(predict_params)=100, target_level=0.1, confidence_level=0.2\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.8\n", - "Mean number of valid thresholds found per iteration: 0\n", - "Valid experiment\n", - "\n", - "N=5, risk['name']='recall', len(predict_params)=100, target_level=0.9, confidence_level=0.8\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.2\n", - "Mean number of valid thresholds found per iteration: 0\n", - "Valid experiment\n", - "\n", - "N=5, risk['name']='recall', len(predict_params)=100, target_level=0.9, confidence_level=0.2\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.8\n", - "Mean number of valid thresholds found per iteration: 0\n", - "Valid experiment\n", - "\n", - "N=5, risk['name']='recall', len(predict_params)=1, target_level=0.1, confidence_level=0.8\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.2\n", - "Mean number of valid thresholds found per iteration: 1\n", - "Valid experiment\n", - "\n", - "N=5, risk['name']='recall', len(predict_params)=1, target_level=0.1, confidence_level=0.2\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.8\n", - "Mean number of valid thresholds found per iteration: 1\n", - "Valid experiment\n", - "\n", - "N=5, risk['name']='recall', len(predict_params)=1, target_level=0.9, confidence_level=0.8\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.2\n", - "Mean number of valid thresholds found per iteration: 0\n", - "Valid experiment\n", - "\n", - "N=5, risk['name']='recall', len(predict_params)=1, target_level=0.9, confidence_level=0.2\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.8\n", - "Mean number of valid thresholds found per iteration: 0\n", - "Valid experiment\n", - "\n", - "N=5, risk['name']='accuracy', len(predict_params)=100, target_level=0.1, confidence_level=0.8\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.2\n", - "Mean number of valid thresholds found per iteration: 16\n", - "Valid experiment\n", - "\n", - "N=5, risk['name']='accuracy', len(predict_params)=100, target_level=0.1, confidence_level=0.2\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.8\n", - "Mean number of valid thresholds found per iteration: 15\n", - "Valid experiment\n", - "\n", - "N=5, risk['name']='accuracy', len(predict_params)=100, target_level=0.9, confidence_level=0.8\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.2\n", - "Mean number of valid thresholds found per iteration: 0\n", - "Valid experiment\n", - "\n", - "N=5, risk['name']='accuracy', len(predict_params)=100, target_level=0.9, confidence_level=0.2\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.8\n", - "Mean number of valid thresholds found per iteration: 0\n", - "Valid experiment\n", - "\n", - "N=5, risk['name']='accuracy', len(predict_params)=1, target_level=0.1, confidence_level=0.8\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.2\n", - "Mean number of valid thresholds found per iteration: 1\n", - "Valid experiment\n", - "\n", - "N=5, risk['name']='accuracy', len(predict_params)=1, target_level=0.1, confidence_level=0.2\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.8\n", - "Mean number of valid thresholds found per iteration: 1\n", - "Valid experiment\n", - "\n", - "N=5, risk['name']='accuracy', len(predict_params)=1, target_level=0.9, confidence_level=0.8\n", - "Proportion of times the risk is not controlled: 0.0\n", - "Delta: 0.2\n", - "Mean number of valid thresholds found per iteration: 0\n", - "Valid experiment\n", - "\n", - "N=5, risk['name']='accuracy', len(predict_params)=1, target_level=0.9, confidence_level=0.2\n", - "Proportion of times the risk is not controlled: 0.02\n", - "Delta: 0.8\n", - "Mean number of valid thresholds found per iteration: 0\n", - "Valid experiment\n", - "\n", - "\n", - "\n", - "\n", - "All good!\n" - ] - } - ], - "source": [ - "N = [100, 5] # size of the calibration set\n", - "risk = [\n", - " {\"name\": \"precision\", \"risk\": precision},\n", - " {\"name\": \"recall\", \"risk\": recall},\n", - " {\"name\": \"accuracy\", \"risk\": accuracy},\n", - "]\n", - "predict_params = [np.linspace(0, 0.99, 100), np.array([0.5])]\n", - "target_level = [0.1, 0.9]\n", - "confidence_level = [0.8, 0.2]\n", - "\n", - "n_repeats = 100\n", - "invalid_experiment = False\n", - "\n", - "for combination in product(N, risk, predict_params, target_level, confidence_level):\n", - " N, risk, predict_params, target_level, confidence_level = combination\n", - " alpha = float(Decimal(\"1\") - Decimal(str(target_level))) # to avoid floating point issues\n", - " delta = float(Decimal(\"1\") - Decimal(str(confidence_level))) # to avoid floating point issues\n", - "\n", - " clf = RandomClassifier()\n", - " nb_errors = 0 # number of iterations where the risk is not controlled (i.e., not all the valid thresholds found by LTT are actually valid)\n", - " total_nb_valid_params = 0\n", - "\n", - " for _ in range(n_repeats):\n", - "\n", - " X_calibrate, y_calibrate = make_classification(\n", - " n_samples=N,\n", - " n_features=1,\n", - " n_informative=1,\n", - " n_redundant=0,\n", - " n_repeated=0,\n", - " n_classes=2,\n", - " n_clusters_per_class=1,\n", - " weights=[0.5, 0.5],\n", - " flip_y=0,\n", - " random_state=None\n", - " )\n", - " X_calibrate = X_calibrate.squeeze()\n", - "\n", - " controller = BinaryClassificationController(\n", - " predict_function=clf.predict_proba,\n", - " risk=risk[\"risk\"],\n", - " target_level=target_level,\n", - " confidence_level=confidence_level,\n", - " )\n", - " controller._predict_params = predict_params\n", - " controller.calibrate(X_calibrate, y_calibrate)\n", - " valid_parameters = controller.valid_predict_params\n", - " total_nb_valid_params += len(valid_parameters)\n", - "\n", - " # In the following, we check that all the valid thresholds found by LTT actually control the risk.\n", - " # Instead of sampling a large test set, we use the fact that we know the theoretical risk of a random classifier.\n", - " # The calculations here are valid only for a balanced data generator.\n", - " if risk[\"risk\"] == precision or risk[\"risk\"] == accuracy:\n", - " if target_level > 0.5 and len(valid_parameters) >= 1:\n", - " nb_errors += 1\n", - " elif risk[\"risk\"] == recall:\n", - " if any(x > alpha for x in valid_parameters) and len(valid_parameters) >= 1:\n", - " nb_errors += 1\n", - "\n", - " print(f\"\\n{N=}, {risk['name']=}, {len(predict_params)=}, {target_level=}, {confidence_level=}\")\n", - "\n", - " print(f\"Proportion of times the risk is not controlled: {nb_errors/n_repeats}\")\n", - " print(f\"Delta: {delta}\")\n", - " print(f\"Mean number of valid thresholds found per iteration: {int(np.round(total_nb_valid_params/n_repeats))}\")\n", - "\n", - " if nb_errors/n_repeats <= delta:\n", - " print(\"Valid experiment\")\n", - " else:\n", - " print(\"Invalid experiment\")\n", - " invalid_experiment = True\n", - "\n", - "print(\"\\n\\n\\n\")\n", - "if invalid_experiment:\n", - " print(\"Some experiments failed.\")\n", - "else:\n", - " print(\"All good!\")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "4c3f437f0b2897a1", - "metadata": { - "ExecuteTime": { - "end_time": "2025-09-15T16:22:49.189292Z", - "start_time": "2025-09-15T16:22:48.871875Z" - } - }, - "outputs": [], - "source": [ - "assert not invalid_experiment" - ] - } - ], - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "display_name": ".venv-dev", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.18" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/requirements.dev.txt b/requirements.dev.txt index adfeeebca..b053640e7 100644 --- a/requirements.dev.txt +++ b/requirements.dev.txt @@ -9,4 +9,5 @@ pytest pytest-cov scikit-learn twine -wheel \ No newline at end of file +wheel +nbclient \ No newline at end of file