-
Notifications
You must be signed in to change notification settings - Fork 128
Refine doc risk control #758
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Valentin-Laurent
merged 15 commits into
binary-risk-control
from
refine-doc-risk-control
Sep 22, 2025
+223
−41
Merged
Changes from 5 commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
9ec2d65
better doc organisation: rename multilabel_classification as risk_con…
allglc d8057c6
make clarifications, improve the overview table, and fix typos
allglc e088e85
add quick start risk control
allglc d1636b8
Revert incorrect renaming of calibration to conformalization in risk_…
allglc fcbff66
add link to notebook theoretical validity risk control
allglc 0610bb1
Update examples/risk_control/1-quickstart/plot_risk_control_binary_cl…
Valentin-Laurent 8ae2c47
DOC - BinaryClassificationController docstrings
Valentin-Laurent 709d4e2
DOC - BinaryClassificationRisk docstring, + make some attributes private
Valentin-Laurent ca8f178
DOC - Fix docstrings formatting, add classes to the API page in ReadT…
Valentin-Laurent 3026f13
clarifications of explanations and formatting
allglc 28dbc3c
fix trailing whitespace
allglc 373f25f
change position of notebook link
allglc ea6d0e3
DOC & MTN - Fix docstrings, add an exception handling if users passes…
Valentin-Laurent 08c8bf4
FIX - Fix wrong risk value with higher_is_better risks when undefined
Valentin-Laurent 00cad03
Merge branch 'binary-risk-control' into refine-doc-risk-control
allglc File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
4 changes: 2 additions & 2 deletions
4
...el_classification/1-quickstart/README.rst → ...ples/risk_control/1-quickstart/README.rst
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
.. _multilabel_classification_examples_1: | ||
.. _risk_control_examples_1: | ||
|
||
1. Quickstart examples | ||
---------------------- | ||
|
||
The following examples present the main functionalities of MAPIE through basic quickstart regression problems. | ||
The following examples present the main functionalities of MAPIE through basic quickstart risk control problems. |
126 changes: 126 additions & 0 deletions
126
examples/risk_control/1-quickstart/plot_risk_control_binary_classification.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
""" | ||
================================================= | ||
Use MAPIE to control risk for a binary classifier | ||
================================================= | ||
|
||
In this example, we explain how to do risk control for binary classification with MAPIE. | ||
|
||
""" | ||
|
||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
from sklearn.datasets import make_circles | ||
from sklearn.svm import SVC | ||
from sklearn.model_selection import FixedThresholdClassifier | ||
from sklearn.metrics import precision_score | ||
from sklearn.inspection import DecisionBoundaryDisplay | ||
|
||
from mapie.risk_control import BinaryClassificationController, precision | ||
from mapie.utils import train_conformalize_test_split | ||
|
||
RANDOM_STATE = 1 | ||
|
||
############################################################################## | ||
# Let us first load the dataset and fit an SVC on the training data. | ||
|
||
X, y = make_circles(n_samples=3000, noise=0.3, | ||
factor=0.3, random_state=RANDOM_STATE) | ||
(X_train, X_calib, X_test, | ||
y_train, y_calib, y_test) = train_conformalize_test_split( | ||
X, y, train_size=0.8, conformalize_size=0.1, test_size=0.1, | ||
random_state=RANDOM_STATE) | ||
|
||
clf = SVC(probability=True, random_state=RANDOM_STATE) | ||
clf.fit(X_train, y_train) | ||
|
||
############################################################################## | ||
# Next, we initialize a :class:`~mapie.risk_control.BinaryClassificationController` | ||
# using the probability estimation function from the fitted estimator: | ||
# ``clf.predict_proba``, a risk function (here the precision), a target risk level, and | ||
# a confidence level. Then we use the calibration data to compute statistically | ||
# guaranteed thresholds using a risk control method. | ||
|
||
target_precision = 0.8 | ||
bcc = BinaryClassificationController( | ||
clf.predict_proba, precision, target_level=target_precision, confidence_level=0.9) | ||
bcc.calibrate(X_calib, y_calib) | ||
|
||
print(f'{len(bcc.valid_predict_params)} valid thresholds found. ' | ||
f'The best one is {bcc.best_predict_param:.3f}.') | ||
|
||
|
||
############################################################################## | ||
# In the plot below, we visualize how the threshold values impact precision, and what | ||
# thresholds have been computed as statistically guaranteed. | ||
|
||
proba_positive_class = clf.predict_proba(X_calib)[:, 1] | ||
|
||
tested_thresholds = bcc._predict_params | ||
precisions = np.full(len(tested_thresholds), np.inf) | ||
for i, threshold in enumerate(tested_thresholds): | ||
y_pred = (proba_positive_class >= threshold).astype(int) | ||
precisions[i] = precision_score(y_calib, y_pred) | ||
|
||
valid_thresholds_indices = np.array( | ||
[t in bcc.valid_predict_params for t in tested_thresholds]) | ||
best_threshold_index = np.where( | ||
tested_thresholds == bcc.best_predict_param)[0][0] | ||
|
||
plt.figure() | ||
plt.scatter(tested_thresholds[valid_thresholds_indices], | ||
precisions[valid_thresholds_indices], c='tab:green', | ||
label='Valid thresholds') | ||
plt.scatter(tested_thresholds[~valid_thresholds_indices], | ||
precisions[~valid_thresholds_indices], c='tab:red', | ||
label='Invalid thresholds') | ||
plt.scatter(tested_thresholds[best_threshold_index], precisions[best_threshold_index], | ||
c='tab:green', label='Best threshold', marker='*', edgecolors='k', s=300) | ||
plt.axhline(target_precision, color='tab:gray', linestyle='--') | ||
plt.text(0, target_precision+0.02, 'Target precision', | ||
color='tab:gray', fontstyle='italic') | ||
plt.xlabel('Threshold', labelpad=15) | ||
plt.ylabel('Precision') | ||
plt.legend() | ||
plt.show() | ||
|
||
############################################################################## | ||
# Contrary to the naive way of computing a threshold to satisfy a precision target on | ||
# calibration data, risk control provides statistical guarantees on unseen data. | ||
# Besides computing a set of valid thresholds, | ||
# :class:`~mapie.risk_control.BinaryClassificationController` also outputs the best | ||
# one, which in the case of precision is the threshold that, among all valid ones, | ||
# maximizes recall. | ||
# | ||
# In the figure above, the highest threshold values are considered invalid due to the | ||
# small number of observations used to compute the precision, following the Learn then | ||
# Test procedure. In the most extreme case, no observation is available, which causes | ||
# the precision value to be ill-defined and set to 0. | ||
# | ||
# After obtaining the best threshold, we can use the ``predict`` function of | ||
# :class:`~mapie.risk_control.BinaryClassificationController` for future predictions, | ||
# or use scikit-learn's ``FixedThresholdClassifier`` as a wrapper to benefit | ||
# from functionalities like easily plotting the decision boundary as seen below. | ||
|
||
y_pred = bcc.predict(X_test) | ||
|
||
clf_threshold = FixedThresholdClassifier(clf, threshold=bcc.best_predict_param) | ||
# necessary for plotting, alternatively you can use sklearn.frozen.FrozenEstimator | ||
clf_threshold.fit(X_train, y_train) | ||
|
||
disp = DecisionBoundaryDisplay.from_estimator( | ||
clf_threshold, X_test, response_method="predict", cmap=plt.cm.coolwarm) | ||
|
||
plt.scatter(X_test[y_test == 0, 0], X_test[y_test == 0, 1], | ||
edgecolors='k', c='tab:blue', alpha=0.5, label='"negative" class') | ||
plt.scatter(X_test[y_test == 1, 0], X_test[y_test == 1, 1], | ||
edgecolors='k', c='tab:red', alpha=0.5, label='"positive" class') | ||
plt.title("Decision Boundary of FixedThresholdClassifier") | ||
plt.xlabel("Feature 1") | ||
plt.ylabel("Feature 2") | ||
plt.legend() | ||
plt.show() | ||
|
||
############################################################################## | ||
# Different risk functions have been implemented, such as precision and recall, but you | ||
# can also implement your own custom function using | ||
# :class:`~mapie.risk_control.BinaryClassificationRisk`. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
.. _risk_control_examples_2: | ||
|
||
2. Advanced analysis | ||
-------------------- | ||
|
||
The following examples use MAPIE for discussing more complex risk control problems. |
File renamed without changes.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
.. _risk_control_examples: | ||
|
||
All risk control examples | ||
========================= | ||
|
||
Following is a collection of notebooks demonstrating how to use MAPIE for risk control. |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.