Commit 3c92ed2

improve example
1 parent d3c6e28 commit 3c92ed2

1 file changed: +26 additions, -18 deletions

examples/risk_control/2-advanced-analysis/plot_risk_control_multi_parameter_binary_classification.py
@@ -3,11 +3,11 @@
 Use MAPIE to control risk of a binary classifier with multiple prediction parameters
 ====================================================================================

-AI is a powerful tool for mail sorting (for example between spam and urgent mails).
-However, because algorithms are not perfects,it sometimes requires manual verification.
-Thus one would like to be able to controlthe amount of mail sent to human validation.
-One way to do so is to define a multi-parameters prediction function based on a
-classifier predicted scores. This would allow to define a rule for mail checking,
+AI is a powerful tool for email sorting (for example between spam and urgent emails).
+However, because algorithms are not perfect, manual verification is sometimes required.
+Thus one would like to be able to control the amount of emails sent to human validation.
+One way to do so is to define a multi-parameter prediction function based on a
+classifier's predicted scores. This would allow defining a rule for email checking,
 which could be adapted by varying the prediction parameters.

 In this example, we explain how to do risk control for binary classification relying
@@ -26,8 +26,8 @@
 RANDOM_STATE = 1

 ##############################################################################
-# First, load the dataset and then split it into training, calibration
-# (for conformalization), and test sets.
+# First, load the dataset and then split it into training, calibration,
+# and test sets.

 X, y = make_circles(n_samples=5000, noise=0.3, factor=0.3, random_state=RANDOM_STATE)
 (X_train, X_calib, X_test, y_train, y_calib, y_test) = train_conformalize_test_split(
@@ -40,7 +40,7 @@
 )

 # Plot the three datasets to visualize the distribution of the two classes. We can
-# assume that the feature space represents some embedding of e-mails.
+# assume that the feature space represents some embedding of emails.
 fig, axes = plt.subplots(1, 3, figsize=(18, 6))
 titles = ["Training Data", "Calibration Data", "Test Data"]
 datasets = [(X_train, y_train), (X_calib, y_calib), (X_test, y_test)]
@@ -94,15 +94,20 @@


 #############################################################################
-# Third define a multiparameter prediciton function
+# Third define a multi-parameter prediction function. For an email to be sent
+# to human verification, we want the predicted score of the positive class to be
+# between two thresholds `lambda_1` and `lambda_2`. High (respectively low) values of
+# the score correspond to high confidence that the email is a spam (respectively not a spam).
+# Therefore, emails with intermediate scores are the ones for which the classifier
+# is the least certain, and we want these emails to be verified by a human.
 def send_to_human(X, lambda_1, lambda_2):
     y_score = clf.predict_proba(X)[:, 1]
     return (lambda_1 <= y_score) & (y_score < lambda_2)


 #############################################################################
 # From the previous function, we know we have a constraint
-# lambda_1 <= lambda_2. We can generate a set of values to explore respecting
+# `lambda_1` <= `lambda_2`. We can generate a set of values to explore respecting
 # this constraint.

 to_explore = []
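Note on the hunk above: the loop that fills `to_explore` lies outside this diff. A minimal sketch of how such a constrained grid could be built is shown below; the 0.1 step size is an assumption (it matches the 10x10 heatmap at the end of the file, but is not part of this commit).

import numpy as np

# Hypothetical sketch: enumerate (lambda_1, lambda_2) pairs on a 0.1 grid and
# keep only those satisfying the constraint lambda_1 <= lambda_2.
to_explore = []
for lambda_1 in np.arange(0.0, 1.0, 0.1):
    for lambda_2 in np.arange(0.0, 1.0, 0.1):
        if lambda_1 <= lambda_2:
            to_explore.append((lambda_1, lambda_2))
to_explore = np.array(to_explore)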
@@ -116,8 +121,9 @@ def send_to_human(X, lambda_1, lambda_2):
 to_explore = np.array(to_explore)

 #############################################################################
-# As we want to control the proportion of mail to be verified by a human.
-# We need to define a specific :class:`BinaryClassificationRisk`
+# Because we want to control the proportion of emails to be verified by a human,
+# we need to define a specific :class:`BinaryClassificationRisk` which represents
+# the fraction of samples predicted as positive (i.e., sent to human verification).

 prop_positive = BinaryClassificationRisk(
     risk_occurrence=lambda y_true, y_pred: y_pred,
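For intuition on the risk defined above: the occurrence function returns `y_pred` itself, so the empirical risk over a dataset is simply the fraction of samples the rule flags for human verification. A tiny standalone illustration with toy values (not MAPIE code):

import numpy as np

# 1 = email sent to human verification, 0 = handled automatically.
y_pred = np.array([1, 0, 0, 1, 0])
# The controlled quantity is the mean of these indicators.
print(y_pred.mean())  # 0.4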
@@ -129,7 +135,7 @@ def send_to_human(X, lambda_1, lambda_2):
 # Finally, we initialize a :class:`~mapie.risk_control.BinaryClassificationController`
 # using our custom function ``send_to_human``, our custom risk ``prop_positive``,
 # a target risk level (0.2), and a confidence level (0.9). Then we use the calibration
-# data to compute statistically guaranteed thresholds using a multiparameter control
+# data to compute statistically guaranteed thresholds using a multi-parameter control
 # method.

 target_level = 0.2
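The instantiation of the controller sits outside this hunk. Based on the comment above, it plausibly resembles the sketch below; the argument names are assumptions inferred from the prose, not verified against the MAPIE API.

from mapie.risk_control import BinaryClassificationController

# Hedged sketch: argument names are assumed, not taken from this diff.
bcc = BinaryClassificationController(
    predict_function=send_to_human,     # custom multi-parameter prediction rule
    risk=prop_positive,                 # proportion of emails sent to a human
    target_level=target_level,          # 0.2
    confidence_level=confidence_level,  # 0.9
)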
@@ -146,9 +152,9 @@ def send_to_human(X, lambda_1, lambda_2):
 bcc.calibrate(X_calib, y_calib)

 print(
-    f"Multiple parameters : {len(bcc.valid_predict_params)} "
-    f"thresholds found that guarantee a precision of at least {target_level}\n"
-    f"and a recall of at least {target_level} with a confidence of {confidence_level}."
+    f"{len(bcc.valid_predict_params)} multi-dimensional parameters "
+    f"found that guarantee a proportion of emails sent to verification\n"
+    f"of at most {target_level} with a confidence of {confidence_level}."
 )

 #######################################################################
@@ -158,10 +164,12 @@ def send_to_human(X, lambda_1, lambda_2):
     col = valid_params[1] * 10
     matrix[int(row), int(col)] = 1

-fig, ax = plt.subplots(figsize=(16, 12))
+fig, ax = plt.subplots(figsize=(6, 6))
 im = ax.imshow(matrix, cmap="inferno")
 ax.set_xticks(range(10), labels=(np.array(range(10)) / 10))
 ax.set_yticks(range(10), labels=(np.array(range(10)) / 10))
-ax.set_title("Validated parameters")
+ax.set_xlabel(r"lambda_2")
+ax.set_ylabel(r"lambda_1")
+ax.set_title("Valid parameters")
 fig.tight_layout()
 plt.show()
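For context, the lines immediately above this hunk (not shown in the diff) presumably initialize the grid and iterate over the valid parameters before plotting. A hedged reconstruction, assuming `matrix` is a 10x10 array of zeros indexed by the two thresholds:

import numpy as np

# Hypothetical reconstruction of the lines just above this hunk:
# mark each valid (lambda_1, lambda_2) pair in a 10x10 grid before plotting it.
matrix = np.zeros((10, 10))
for valid_params in bcc.valid_predict_params:
    row = valid_params[0] * 10
    col = valid_params[1] * 10
    matrix[int(row), int(col)] = 1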
