diff --git a/.gitignore b/.gitignore index f787972b1..4d3384e64 100644 --- a/.gitignore +++ b/.gitignore @@ -13,7 +13,7 @@ doc/_build/ doc/examples_classification/ doc/examples_regression/ doc/examples_calibration/ -doc/examples_multilabel_classification/ +doc/examples_risk_control/ doc/examples_mondrian/ doc/auto_examples/ doc/modules/generated/ diff --git a/HISTORY.rst b/HISTORY.rst index 29408cc31..eb3409a93 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -16,6 +16,7 @@ History * MAPIE now supports Python versions up to the latest release (currently 3.13) * Change `prefit` default value to `True` in split methods' docstrings to remain consistent with the implementation * Fix issue 699 to replace `TimeSeriesRegressor.partial_fit` with `TimeSeriesRegressor.update` +* Revert incorrect renaming of calibration to conformalization in risk_control.py 1.0.1 (2025-05-22) ------------------ diff --git a/doc/Makefile b/doc/Makefile index 841011bd2..ba1723db0 100644 --- a/doc/Makefile +++ b/doc/Makefile @@ -50,7 +50,7 @@ clean: -rm -rf $(BUILDDIR)/* -rm -rf examples_regression/ -rm -rf examples_classification/ - -rm -rf examples_multilabel_classification/ + -rm -rf examples_risk_control/ -rm -rf examples_calibration/ -rm -rf examples_mondrian/ -rm -rf generated/* diff --git a/doc/conf.py b/doc/conf.py index 78cee8a31..eacd46e6e 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -321,14 +321,14 @@ "examples_dirs": [ "../examples/regression", "../examples/classification", - "../examples/multilabel_classification", + "../examples/risk_control", "../examples/calibration", "../examples/mondrian", ], "gallery_dirs": [ "examples_regression", "examples_classification", - "examples_multilabel_classification", + "examples_risk_control", "examples_calibration", "examples_mondrian", ], diff --git a/doc/index.rst b/doc/index.rst index 2807c04bd..1d2881cb0 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -24,7 +24,8 @@ :caption: Control prediction errors theoretical_description_risk_control - examples_multilabel_classification/1-quickstart/plot_tutorial_risk_control + examples_risk_control/1-quickstart/plot_risk_control_binary_classification + examples_risk_control/index external_risk_control_package .. toctree:: diff --git a/doc/quick_start.rst b/doc/quick_start.rst index 9794a4000..995d68157 100644 --- a/doc/quick_start.rst +++ b/doc/quick_start.rst @@ -40,4 +40,10 @@ Here, we generate one-dimensional noisy data that we fit with a MLPRegressor: `U 3. Classification ======================= -Similarly, it's possible to do the same for a basic classification problem: `Use MAPIE to plot prediction sets `_ \ No newline at end of file +Similarly, it's possible to do the same for a basic classification problem: `Use MAPIE to plot prediction sets `_ + + +4. Risk Control +======================= + +MAPIE implements risk control methods for multilabel classification (in particular, image segmentation) and binary classification: `Use MAPIE to control risk for a binary classifier `_ \ No newline at end of file diff --git a/doc/theoretical_description_risk_control.rst b/doc/theoretical_description_risk_control.rst index 76629d311..6c8cbe4c4 100644 --- a/doc/theoretical_description_risk_control.rst +++ b/doc/theoretical_description_risk_control.rst @@ -13,26 +13,43 @@ Getting started with risk control in MAPIE Overview ======== +This section provides an overview of risk control in MAPIE. For those unfamiliar with the concept of risk control, the next section provides an introduction to the topic. 
+ Three methods of risk control have been implemented in MAPIE so far: **Risk-Controlling Prediction Sets** (RCPS) [1], **Conformal Risk Control** (CRC) [2] and **Learn Then Test** (LTT) [3]. -The difference between these methods is the way the conformity scores are computed. -As of now, MAPIE supports risk control for two machine learning tasks: **binary classification**, as well as **multi-label classification** (including applications like image segmentation). +As of now, MAPIE supports risk control for two machine learning tasks: **binary classification**, as well as **multi-label classification** (in particular, applications like image segmentation). The table below details the available methods for each task: +.. |br| raw:: html + +   <br>
+ .. list-table:: Available risk control methods in MAPIE for each ML task :header-rows: 1 - * - Risk control method - - Binary classification - - Multi-label classification (image segmentation) + * - Risk control |br| method + - Type of |br| control + - Assumption |br| on the data + - Non-monotonic |br| risks + - Binary |br| classification + - Multi-label |br| classification * - RCPS + - Probability + - i.i.d. + - ❌ - ❌ - ✅ * - CRC + - Expectation + - Exchangeable + - ❌ - ❌ - ✅ * - LTT + - Probability + - i.i.d + - ✅ - ✅ - ✅ @@ -41,7 +58,7 @@ In MAPIE for multi-label classification, CRC and RCPS are used for recall contro 1. What is risk control? ======================== -Before diving into risk control, let's take the simple example of a binary classification model, which separates the incoming data into the two classes thanks to its threshold: predictions above it are classified as 1, and those below as 0. Suppose we want to find a threshold that guarantees that our model achieves a certain level of precision. A naive, yet straightforward approach to do this is to evaluate how precision varies with different threshold values on a validation dataset. By plotting this relationship (see plot below), we can identify the range of thresholds that meet our desired precision requirement (green zone on the graph). +Before diving into risk control, let's take the simple example of a binary classification model, which separates the incoming data into two classes. Predicted probabilities above a given threshold (e.g., 0.5) correspond to predicting the "positive" class and probabilities below correspond to the "negative" class. Suppose we want to find a threshold that guarantees that our model achieves a certain level of precision. A naive, yet straightforward approach to do this is to evaluate how precision varies with different threshold values on a validation dataset. By plotting this relationship (see plot below), we can identify the range of thresholds that meet our desired precision requirement (green zone on the graph). .. image:: images/example_without_risk_control.png :width: 600 @@ -54,7 +71,7 @@ So far, so good. But here is the catch: while the chosen threshold effectively k Risk control is the science of adjusting a model's parameter, typically denoted :math:`\lambda`, so that a given risk stays below a desired level with high probability on unseen data. Note that here, the term *risk* is used to describe an undesirable outcome of the model (e.g., type I error): therefore, it is a value we want to minimize, and in our case, keep under a certain level. Also note that risk control can easily be applied to metrics we want to maximize (e.g., precision), simply by controlling the complement (e.g., 1-precision). -The strength of risk control lies in the statistical guarantees it provides on unseen data. Unlike the naive method presented earlier, it determines a value of :math:`\lambda` that ensures the risk is controlled *beyond* the training data. +The strength of risk control lies in the statistical guarantees it provides on unseen data. Unlike the naive method presented earlier, it determines a value of :math:`\lambda` that ensures the risk is controlled *beyond* the validation data. Applying risk control to the previous example would allow us to get a new — albeit narrower — range of thresholds (blue zone on the graph) that are **statistically guaranteed**. 
@@ -66,7 +83,7 @@ This guarantee is critical in a wide range of use cases (especially in high-stak — -To express risk control in mathematical terms, we denote by R the risk we want to control, and introduce the following two parameters: +To express risk control in mathematical terms, we denote by :math:`R` the risk we want to control, and introduce the following two parameters: - :math:`\alpha`: the target level below which we want the risk to remain, as shown in the figure below; @@ -76,13 +93,13 @@ To express risk control in mathematical terms, we denote by R the risk we want t - :math:`\delta`: the confidence level associated with the risk control. -In other words, the risk is said to be controlled if :math:`R \leq \alpha` with probability at least :math:`1 - \delta`. +In other words, the risk is said to be controlled if :math:`R \leq \alpha` with probability at least :math:`1 - \delta`, where the probability is over the randomness in the sampling of the dataset. The three risk control methods implemented in MAPIE — RCPS, CRC and LTT — rely on different assumptions, and offer slightly different guarantees: - **CRC** requires the data to be **exchangeable**, and gives a guarantee on the **expectation of the risk**: :math:`\mathbb{E}(R) \leq \alpha`; -- **RCPS** and **LTT** both impose stricter assumptions, requiring the data to be **independent and identically distributed** (i.i.d.), which implies exchangeability. The guarantee they provide is on the **probability that the risk does not exceed :math:`\alpha`**: :math:`\mathbb{P}(R \leq \alpha) \geq 1 - \delta`. +- **RCPS** and **LTT** both impose stricter assumptions, requiring the data to be **independent and identically distributed** (i.i.d.), which implies exchangeability. The guarantee they provide is on the **probability that the risk does not exceed** :math:`\boldsymbol{\alpha}`: :math:`\mathbb{P}(R \leq \alpha) \geq 1 - \delta`. .. image:: images/risk_distribution.png :width: 600 @@ -94,12 +111,13 @@ The plot above gives a visual representation of the difference between the two t - The risk is controlled in probability (RCPS/LTT) if at least a proportion :math:`1 - \delta` of its distribution over unseen data is below :math:`\alpha`. -Note that at the opposite of the other two methods, LTT allows to control any non-monotonic risk. +Note that, contrary to the other two methods, LTT makes it possible to control any non-monotonic risk. The following section provides a detailed overview of each method. 2. Theoretical description ========================== + 2.1 Risk-Controlling Prediction Sets ------------------------------------ 2.1.1 General settings @@ -234,7 +252,7 @@ We are going to present the Learn Then Test framework that allows the user to co This method has been introduced in article [3]. The settings here are the same as for RCPS and CRC; we just need to introduce some new parameters: -- Let :math:`\Lambda` be a discretized for our :math:`\lambda`, meaning that :math:`\Lambda = \{\lambda_1, ..., \lambda_n\}`. +- Let :math:`\Lambda` be a discretized set for our :math:`\lambda`, meaning that :math:`\Lambda = \{\lambda_1, ..., \lambda_n\}`. - Let :math:`p_\lambda` be a valid p-value for the null hypothesis :math:`\mathbb{H}_j: R(\lambda_j)>\alpha`. @@ -250,7 +268,7 @@ In order to find all the parameters :math:`\lambda` that satisfy the above condi :math:`\{(x_1, y_1), \dots, (x_n, y_n)\}`.
- For each :math:`\lambda_j` in a discrete set :math:`\Lambda = \{\lambda_1, \lambda_2,\dots, \lambda_n\}`, we associate the null hypothesis - :math:`\mathcal{H}_j: R(\lambda_j) > \alpha`, as rejecting the hypothesis corresponds to selecting :math:`\lambda_j` as a point where risk the risk + :math:`\mathcal{H}_j: R(\lambda_j) > \alpha`, as rejecting the hypothesis corresponds to selecting :math:`\lambda_j` as a point where the risk is controlled. - For each null hypothesis, we compute a valid p-value :math:`p_{\lambda_j}` using a concentration inequality. Here we choose to compute the Hoeffding-Bentkus p-value @@ -259,6 +277,7 @@ In order to find all the parameters :math:`\lambda` that satisfy the above condi - Return :math:`\hat{\Lambda} = \mathcal{A}(\{p_j\}_{j\in\{1,\dots,\lvert \Lambda \rvert\}})`, where :math:`\mathcal{A}` is an algorithm that controls the family-wise error rate (FWER), for example, the Bonferroni correction. +Note that a notebook testing theoretical guarantees of risk control in binary classification using a random classifier and synthetic data is available here: `theoretical_validity_tests.ipynb `__. References ========== diff --git a/doc/v1_release_notes.rst b/doc/v1_release_notes.rst index 41ae6aa08..0a946a1e3 100644 --- a/doc/v1_release_notes.rst +++ b/doc/v1_release_notes.rst @@ -263,8 +263,6 @@ Risk control The ``MapieMultiLabelClassifier`` class has been renamed ``PrecisionRecallController``. -The parameter ``calib_size`` from the ``fit`` method has been renamed ``conformalize_size``. - Calibration ^^^^^^^^^^^^^ diff --git a/examples/multilabel_classification/README.rst b/examples/multilabel_classification/README.rst deleted file mode 100644 index 1f8a7b3fc..000000000 --- a/examples/multilabel_classification/README.rst +++ /dev/null @@ -1,4 +0,0 @@ -.. _general_examples: - -General examples -================ \ No newline at end of file diff --git a/examples/multilabel_classification/1-quickstart/README.rst b/examples/risk_control/1-quickstart/README.rst similarity index 56% rename from examples/multilabel_classification/1-quickstart/README.rst rename to examples/risk_control/1-quickstart/README.rst index 65aaf6366..2970a4ef1 100644 --- a/examples/multilabel_classification/1-quickstart/README.rst +++ b/examples/risk_control/1-quickstart/README.rst @@ -1,6 +1,6 @@ -.. _multilabel_classification_examples_1: +.. _risk_control_examples_1: 1. Quickstart examples ---------------------- -The following examples present the main functionalities of MAPIE through basic quickstart regression problems. \ No newline at end of file +The following examples present the main functionalities of MAPIE through basic quickstart risk control problems. \ No newline at end of file diff --git a/examples/risk_control/1-quickstart/plot_risk_control_binary_classification.py b/examples/risk_control/1-quickstart/plot_risk_control_binary_classification.py new file mode 100644 index 000000000..216594119 --- /dev/null +++ b/examples/risk_control/1-quickstart/plot_risk_control_binary_classification.py @@ -0,0 +1,149 @@ +""" +========================================================= +Use MAPIE to control the precision of a binary classifier +========================================================= + +In this example, we explain how to do risk control for binary classification with MAPIE.
+ +""" + +import numpy as np +import matplotlib.pyplot as plt +from sklearn.datasets import make_circles +from sklearn.svm import SVC +from sklearn.model_selection import FixedThresholdClassifier +from sklearn.metrics import precision_score +from sklearn.inspection import DecisionBoundaryDisplay + +from mapie.risk_control import BinaryClassificationController, precision +from mapie.utils import train_conformalize_test_split + +RANDOM_STATE = 1 + +############################################################################## +# Let us first load the dataset and fit an SVC on the training data. + +X, y = make_circles(n_samples=3000, noise=0.3, factor=0.3, random_state=RANDOM_STATE) +(X_train, X_calib, X_test, + y_train, y_calib, y_test) = train_conformalize_test_split( + X, y, + train_size=0.8, conformalize_size=0.1, test_size=0.1, + random_state=RANDOM_STATE + ) + +clf = SVC(probability=True, random_state=RANDOM_STATE) +clf.fit(X_train, y_train) + +############################################################################## +# Next, we initialize a :class:`~mapie.risk_control.BinaryClassificationController` +# using the probability estimation function from the fitted estimator: +# ``clf.predict_proba``, a risk function (here the precision), a target risk level, and +# a confidence level. Then we use the calibration data to compute statistically +# guaranteed thresholds using a risk control method. + +target_precision = 0.8 +confidence_level = 0.9 +bcc = BinaryClassificationController( + clf.predict_proba, + precision, target_level=target_precision, + confidence_level=confidence_level + ) +bcc.calibrate(X_calib, y_calib) + +print(f'{len(bcc.valid_predict_params)} thresholds found that guarantee a precision of ' + f'at least {target_precision} with a confidence of {confidence_level}.\n' + 'Among those, the one that maximizes the secondary objective (recall here) is: ' + f'{bcc.best_predict_param:.3f}.') + + +############################################################################## +# In the plot below, we visualize how the threshold values impact precision, and what +# thresholds have been computed as statistically guaranteed. 
+ +proba_positive_class = clf.predict_proba(X_calib)[:, 1] + +tested_thresholds = bcc._predict_params +precisions = np.full(len(tested_thresholds), np.inf) +for i, threshold in enumerate(tested_thresholds): + y_pred = (proba_positive_class >= threshold).astype(int) + precisions[i] = precision_score(y_calib, y_pred) + +valid_thresholds_indices = np.array( + [t in bcc.valid_predict_params for t in tested_thresholds]) +best_threshold_index = np.where( + tested_thresholds == bcc.best_predict_param)[0][0] + +plt.figure() +plt.scatter( + tested_thresholds[valid_thresholds_indices], precisions[valid_thresholds_indices], + c='tab:green', label='Valid thresholds' + ) +plt.scatter( + tested_thresholds[~valid_thresholds_indices], precisions[~valid_thresholds_indices], + c='tab:red', label='Invalid thresholds' + ) +plt.scatter( + tested_thresholds[best_threshold_index], precisions[best_threshold_index], + c='tab:green', label='Best threshold', marker='*', edgecolors='k', s=300 + ) +plt.axhline(target_precision, color='tab:gray', linestyle='--') +plt.text( + 0.7, target_precision+0.02, 'Target precision', color='tab:gray', fontstyle='italic' +) +plt.xlabel('Threshold') +plt.ylabel('Precision') +plt.legend() +plt.show() + +############################################################################## +# Contrary to the naive way of computing a threshold to satisfy a precision target on +# calibration data, risk control provides statistical guarantees on unseen data. +# In the plot above, we can see that not all thresholds corresponding to a precision +# higher than the target are valid. This is due to the uncertainty inherent to the +# finite size of the calibration set, which risk control takes into account. +# +# In particular, the highest threshold values are considered invalid due to the +# small number of observations used to compute the precision, following the Learn then +# Test procedure. In the most extreme case, no observation is available, which causes +# the precision value to be ill-defined and set to 0. +# +# Besides computing a set of valid thresholds, +# :class:`~mapie.risk_control.BinaryClassificationController` also outputs the "best" +# one, which is the valid threshold that maximizes a secondary objective +# (recall here). +# +# After obtaining the best threshold, we can use the ``predict`` function of +# :class:`~mapie.risk_control.BinaryClassificationController` for future predictions, +# or use scikit-learn's ``FixedThresholdClassifier`` as a wrapper to benefit +# from functionalities like easily plotting the decision boundary as seen below.
+ +y_pred = bcc.predict(X_test) + +clf_threshold = FixedThresholdClassifier(clf, threshold=bcc.best_predict_param) +clf_threshold.fit(X_train, y_train) +# .fit necessary for plotting, alternatively you can use sklearn.frozen.FrozenEstimator + + +disp = DecisionBoundaryDisplay.from_estimator( + clf_threshold, X_test, response_method="predict", cmap=plt.cm.coolwarm + ) + +plt.scatter( + X_test[y_test == 0, 0], X_test[y_test == 0, 1], + edgecolors='k', c='tab:blue', alpha=0.5, label='"negative" class' + ) +plt.scatter( + X_test[y_test == 1, 0], X_test[y_test == 1, 1], + edgecolors='k', c='tab:red', alpha=0.5, label='"positive" class' + ) +plt.title("Decision Boundary of FixedThresholdClassifier") +plt.xlabel("Feature 1") +plt.ylabel("Feature 2") +plt.legend() +plt.show() + +############################################################################## +# Different risk functions have been implemented, such as precision and recall, but you +# can also implement your own custom function using +# :class:`~mapie.risk_control.BinaryClassificationRisk` and choose your own +# secondary objective. diff --git a/examples/risk_control/2-advanced-analysis/README.rst b/examples/risk_control/2-advanced-analysis/README.rst new file mode 100644 index 000000000..2179cbdbd --- /dev/null +++ b/examples/risk_control/2-advanced-analysis/README.rst @@ -0,0 +1,6 @@ +.. _risk_control_examples_2: + +2. Advanced analysis +-------------------- + +The following examples use MAPIE to explore more complex risk control problems. \ No newline at end of file diff --git a/examples/multilabel_classification/1-quickstart/plot_tutorial_risk_control.py b/examples/risk_control/2-advanced-analysis/plot_tutorial_risk_control.py similarity index 100% rename from examples/multilabel_classification/1-quickstart/plot_tutorial_risk_control.py rename to examples/risk_control/2-advanced-analysis/plot_tutorial_risk_control.py diff --git a/examples/risk_control/README.rst b/examples/risk_control/README.rst new file mode 100644 index 000000000..f5f00e9f5 --- /dev/null +++ b/examples/risk_control/README.rst @@ -0,0 +1,6 @@ +.. _risk_control_examples: + +All risk control examples +========================= + +The following is a collection of notebooks demonstrating how to use MAPIE for risk control. \ No newline at end of file diff --git a/mapie/risk_control.py b/mapie/risk_control.py index a8fb9f190..a46e80730 100644 --- a/mapie/risk_control.py +++ b/mapie/risk_control.py @@ -362,7 +362,7 @@ def _check_estimator( Warning If estimator is None, warns about the split of the - data between train and conformalization + data between train and calibration """ if (estimator is None) and (not _refit): raise ValueError( @@ -374,19 +374,19 @@ def _check_estimator( estimator = MultiOutputClassifier( LogisticRegression() ) - X_train, X_conf, y_train, y_conf = train_test_split( - X, - y, - test_size=self.conformalize_size, - random_state=self.random_state, + X_train, X_calib, y_train, y_calib = train_test_split( + X, + y, + test_size=self.calib_size, + random_state=self.random_state, ) estimator.fit(X_train, y_train) warnings.warn( "WARNING: To avoid overfitting, X has been split" - + "into X_train and X_conf. The conformalization will only" - + "be done on X_conf" + + " into X_train and X_calib. 
The calibration will only" + + " be done on X_calib" ) - return estimator, X_conf, y_conf + return estimator, X_calib, y_calib if isinstance(estimator, Pipeline): est = estimator[-1] @@ -589,7 +589,7 @@ def fit( self, X: ArrayLike, y: ArrayLike, - conformalize_size: Optional[float] = .3 + calib_size: Optional[float] = .3 ) -> PrecisionRecallController: """ Fit the base estimator or use the fitted base estimator. @@ -602,8 +602,8 @@ def fit( y: NDArray of shape (n_samples, n_classes) Training labels. - conformalize_size: Optional[float] - Size of the conformalization dataset with respect to X if the + calib_size: Optional[float] + Size of the calibration dataset with respect to X if the given model is ``None`` need to fit a LogisticRegression. By default .3 @@ -613,7 +613,7 @@ def fit( PrecisionRecallController The model itself. """ - self.conformalize_size = conformalize_size + self.calib_size = calib_size return self.partial_fit(X, y, _refit=True) def predict( @@ -696,7 +696,7 @@ def predict( ) self._check_valid_index(alpha_np) self.lambdas_star, self.r_star = find_lambda_control_star( - self.r_hat, self.valid_index, self.lambdas + self.r_hat, self.valid_index, self.lambdas ) y_pred_proba_array = ( y_pred_proba_array >