Skip to content

Commit c1bcc53

Browse files
committed
add quick start risk control
1 parent d8057c6 commit c1bcc53

File tree

3 files changed

+102
-1
lines changed

3 files changed

+102
-1
lines changed

doc/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
:caption: Control prediction errors
2525

2626
theoretical_description_risk_control
27+
examples_risk_control/1-quickstart/plot_risk_control_binary_classification
2728
examples_risk_control/index
2829
external_risk_control_package
2930

doc/quick_start.rst

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,10 @@ Here, we generate one-dimensional noisy data that we fit with a MLPRegressor: `U
4040
3. Classification
4141
=======================
4242

43-
Similarly, it's possible to do the same for a basic classification problem: `Use MAPIE to plot prediction sets <https://mapie.readthedocs.io/en/stable/examples_classification/1-quickstart/plot_quickstart_classification.html>`_
43+
Similarly, it's possible to do the same for a basic classification problem: `Use MAPIE to plot prediction sets <https://mapie.readthedocs.io/en/stable/examples_classification/1-quickstart/plot_quickstart_classification.html>`_
44+
45+
46+
4. Risk Control
47+
=======================
48+
49+
MAPIE implements risk control methods for multilabel classification (in particular, image segmentation) and binary classification: `Use MAPIE to control risk for a binary classifier <https://mapie.readthedocs.io/en/stable/examples_risk_control/1-quickstart/plot_risk_control_binary_classification.html>`_
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
"""
2+
=================================================
3+
Use MAPIE to control risk for a binary classifier
4+
=================================================
5+
6+
In this example, we explain how to do risk control for binary classification with MAPIE.
7+
8+
"""
9+
10+
import numpy as np
11+
import matplotlib.pyplot as plt
12+
from sklearn.datasets import make_circles
13+
from sklearn.svm import SVC
14+
from sklearn.model_selection import FixedThresholdClassifier
15+
from sklearn.metrics import precision_score
16+
from sklearn.inspection import DecisionBoundaryDisplay
17+
18+
from mapie.risk_control import BinaryClassificationController, precision
19+
from mapie.utils import train_conformalize_test_split
20+
21+
RANDOM_STATE = 1
22+
23+
##############################################################################
24+
# Let us first load the dataset and fit an SVC on the training data.
25+
26+
X, y = make_circles(n_samples=3000, noise=0.3, factor=0.3, random_state=RANDOM_STATE)
27+
(X_train, X_calib, X_test,
28+
y_train, y_calib, y_test) = train_conformalize_test_split(
29+
X, y, train_size=0.8, conformalize_size=0.1, test_size=0.1,
30+
random_state=RANDOM_STATE)
31+
32+
clf = SVC(probability=True, random_state=RANDOM_STATE)
33+
clf.fit(X_train, y_train)
34+
35+
##############################################################################
36+
# Next, we initialize a :class:`~mapie.risk_control.BinaryClassificationController` using the probability estimation function from the fitted estimator: ``clf.predict_proba``, a risk function (here the precision), a target risk level, and a confidence level. Then we use the calibration data to compute statistically guaranteed thresholds using a risk control method.
37+
38+
target_precision = 0.8
39+
bcc = BinaryClassificationController(
40+
clf.predict_proba, precision, target_level=target_precision, confidence_level=0.9)
41+
bcc.calibrate(X_calib, y_calib)
42+
43+
print(f'{len(bcc.valid_predict_params)} valid thresholds found. The best one is {bcc.best_predict_param:.3f}.')
44+
45+
46+
##############################################################################
47+
# In the plot below, we visualize how the threshold values impact precision, and what thresholds have been computed as statistically guaranteed.
48+
49+
proba_positive_class = clf.predict_proba(X_calib)[:, 1]
50+
51+
tested_thresholds = bcc._predict_params
52+
precisions = np.full(len(tested_thresholds), np.inf)
53+
for i, threshold in enumerate(tested_thresholds):
54+
y_pred = (proba_positive_class >= threshold).astype(int)
55+
precisions[i] = precision_score(y_calib, y_pred)
56+
57+
valid_thresholds_indices = np.array([t in bcc.valid_predict_params for t in tested_thresholds])
58+
best_threshold_index = np.where(tested_thresholds == bcc.best_predict_param)[0][0]
59+
60+
plt.figure()
61+
plt.scatter(tested_thresholds[valid_thresholds_indices], precisions[valid_thresholds_indices], c='tab:green', label='Valid thresholds')
62+
plt.scatter(tested_thresholds[~valid_thresholds_indices], precisions[~valid_thresholds_indices], c='tab:red', label='Invalid thresholds')
63+
plt.scatter(tested_thresholds[best_threshold_index], precisions[best_threshold_index], c='tab:green', label='Best threshold', marker='*', edgecolors='k', s=300)
64+
plt.axhline(target_precision, color='tab:gray', linestyle='--')
65+
plt.text(0, target_precision+0.02, 'Target precision', color='tab:gray', fontstyle='italic')
66+
plt.xlabel('Threshold', labelpad=15)
67+
plt.ylabel('Precision')
68+
plt.legend()
69+
plt.show()
70+
71+
##############################################################################
72+
# Contrary to the naive way of computing a threshold to satisfy a precision target on calibration data, risk control provides statistical guarantees on unseen data. Besides computing a set of valid thresholds, :class:`~mapie.risk_control.BinaryClassificationController` also outputs the best one, which in the case of precision is the threshold that, among all valid ones, maximizes recall.
73+
#
74+
# In the figure above, the highest threshold values are considered invalid due to the small number of observations used to compute the precision, following the Learn then Test procedure. In the most extreme case, no observation is available, which causes the precision value to be ill-defined and set to 0.
75+
#
76+
# After obtaining the best threshold, we can use the ``predict`` function of :class:`~mapie.risk_control.BinaryClassificationController` for future predictions, or use scikit-learn's ``FixedThresholdClassifier`` as a wrapper to benefit from functionalities like easily plotting the decision boundary as seen below.
77+
78+
y_pred = bcc.predict(X_test)
79+
80+
clf_threshold = FixedThresholdClassifier(clf, threshold=bcc.best_predict_param)
81+
clf_threshold.fit(X_train, y_train) # necessary for plotting, alternatively you can use sklearn.frozen.FrozenEstimator
82+
83+
disp = DecisionBoundaryDisplay.from_estimator(clf_threshold, X_test, response_method="predict", cmap=plt.cm.coolwarm)
84+
85+
plt.scatter(X_test[y_test==0, 0], X_test[y_test==0, 1], edgecolors='k', c='tab:blue', alpha=0.5, label='"negative" class')
86+
plt.scatter(X_test[y_test==1, 0], X_test[y_test==1, 1], edgecolors='k', c='tab:red', alpha=0.5, label='"positive" class')
87+
plt.title("Decision Boundary of FixedThresholdClassifier")
88+
plt.xlabel("Feature 1")
89+
plt.ylabel("Feature 2")
90+
plt.legend()
91+
plt.show()
92+
93+
##############################################################################
94+
# Different risk functions have been implemented, such as precision and recall, but you can also implement your own custom function using :class:`~mapie.risk_control.BinaryClassificationRisk`.

0 commit comments

Comments
 (0)