Commit dfc95ae

Binary risk control - multi risk (#762)

* test_check_risks_targets_same_len working
* __init__ and _set_best_predict_param_choice handle multi risk
* update and rename _get_risks_and_effective_sample_sizes_per_param for multi risk
* update _set_best_predict_param for multi risk
* fix alpha
* self._risk is now always a list
* lists of len 1 = mono risk
* test updated to new behavior for binary: r_hat and n_obs now always have 2 dimensions (even in mono risk)
* simplify ltt for binary
* keep returning lists of lists for ltt
* update ltt_procedure and its calls and tests for multi risk
* ltt now fails when bad shape of inputs for multi risk
* add unit tests ltt multi risk
* ensure compatibility with python<3.10
1 parent a152f01 commit dfc95ae

4 files changed: 238 additions & 61 deletions

mapie/control_risk/ltt.py

Lines changed: 25 additions & 9 deletions
@@ -1,5 +1,5 @@
 import warnings
-from typing import Any, List, Tuple, Union
+from typing import Any, List, Tuple

 import numpy as np

@@ -12,7 +12,7 @@ def ltt_procedure(
     r_hat: NDArray,
     alpha_np: NDArray,
     delta: float,
-    n_obs: Union[int, NDArray],
+    n_obs: NDArray,
     binary: bool = False,
 ) -> List[List[Any]]:
     """
@@ -24,28 +24,36 @@ def ltt_procedure(
     - Apply a family wise error rate algorithm, here Bonferonni correction
     - Return the index lambdas that give you the control at alpha level

+    Note that in the case of multi-risk, the arrays r_hat, alpha_np, and n_obs
+    should have the same length for the first dimension which corresponds
+    to the number of risks. In the case of a single risk, the length should be 1.
+
     Parameters
     ----------
-    r_hat: NDArray of shape (n_lambdas, ).
+    r_hat: NDArray of shape (n_risks, n_lambdas).
         Empirical risk with respect to the lambdas.
         Here lambdas are thresholds that impact decision-making,
         therefore empirical risk.

-    alpha_np: NDArray of shape (n_alpha, ).
+    alpha_np: NDArray of shape (n_risks, n_alpha).
         Contains the different alphas control level.
         The empirical risk should be less than alpha with
         probability 1-delta.
+        Note: MAPIE 1.2 does not support multiple risks and multiple alphas
+        simultaneously.
+        For PrecisionRecallController, the shape should be (1, n_alpha).
+        For BinaryClassificationController, the shape should be (n_risks, 1).

     delta: float.
         Probability of not controlling empirical risk.
         Correspond to proportion of failure we don't
         want to exceed.

-    n_obs: Union[int, NDArray]
+    n_obs: NDArray of shape (n_risks, n_lambdas).
         Correspond to the number of observations used to compute the risk.
         In the case of a conditional loss, n_obs must be the
         number of effective observations used to compute the empirical risk
-        for each lambda, hence of shape (n_lambdas, ).
+        for each lambda.

     binary: bool, default=False
         Must be True if the loss associated to the risk is binary.
@@ -62,11 +70,19 @@ def ltt_procedure(
        M. I., & Lei, L. (2021). Learn then test:
        "Calibrating predictive algorithms to achieve risk control".
     """
-    p_values = compute_hoeffding_bentkus_p_value(r_hat, n_obs, alpha_np, binary)
+    if not (r_hat.shape[0] == n_obs.shape[0] == alpha_np.shape[0]):
+        raise ValueError(
+            "r_hat, n_obs, and alpha_np must have the same length."
+        )
+    p_values = np.array([
+        compute_hoeffding_bentkus_p_value(r_hat_i, n_obs_i, alpha_np_i, binary)
+        for r_hat_i, n_obs_i, alpha_np_i in zip(r_hat, n_obs, alpha_np)
+    ])
+    p_values = p_values.max(axis=0)  # take max over risks (no effect if mono risk)
     N = len(p_values)
     valid_index = []
-    for i in range(len(alpha_np)):
-        l_index = np.where(p_values[:, i] <= delta/N)[0].tolist()
+    for i in range(alpha_np.shape[1]):
+        l_index = np.nonzero(p_values[:, i] <= delta/N)[0].tolist()
         valid_index.append(l_index)
     return valid_index
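
Note (illustrative, not part of the diff): a minimal sketch of calling the updated ltt_procedure with two risks, using the shapes documented above. The numerical values are synthetic; the import path follows this file's location.

import numpy as np
from mapie.control_risk.ltt import ltt_procedure

# Two risks evaluated on the same grid of 4 lambdas (synthetic values).
r_hat = np.array([
    [0.30, 0.20, 0.10, 0.05],   # empirical values of risk 1 per lambda
    [0.25, 0.15, 0.12, 0.08],   # empirical values of risk 2 per lambda
])
n_obs = np.full((2, 4), 500)         # effective sample sizes, shape (n_risks, n_lambdas)
alpha_np = np.array([[0.2], [0.2]])  # one alpha per risk, shape (n_risks, 1)

valid_index = ltt_procedure(r_hat, alpha_np, delta=0.1, n_obs=n_obs, binary=True)
print(valid_index)  # per alpha column, the lambda indices that pass the corrected test
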

mapie/risk_control.py

Lines changed: 97 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -471,7 +471,7 @@ def _transform_pred_proba(
             y_pred_proba_array = y_pred_proba
         else:
             y_pred_proba_stacked = np.stack(
-                y_pred_proba,  # type: ignore
+                y_pred_proba,
                 axis=0
             )[:, :, 1]
             y_pred_proba_array = np.moveaxis(y_pred_proba_stacked, 0, -1)
@@ -669,7 +669,10 @@ def predict(
         self.n_obs = len(self.risks)
         self.r_hat = self.risks.mean(axis=0)
         self.valid_index = ltt_procedure(
-            self.r_hat, alpha_np, cast(float, delta), self.n_obs
+            np.expand_dims(self.r_hat, axis=0),
+            np.expand_dims(alpha_np, axis=0),
+            cast(float, delta),
+            np.expand_dims(np.array([self.n_obs]), axis=0)
         )
         self._check_valid_index(alpha_np)
         self.lambdas_star, self.r_star = find_lambda_control_star(
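
Note (illustrative, not part of the diff): ltt_procedure now always expects a leading risk axis, so the mono-risk PrecisionRecallController path simply promotes its arrays to 2-D, mirroring the call above.

import numpy as np

r_hat = np.array([0.30, 0.20, 0.10])   # shape (n_lambdas,) in the mono-risk case
alpha_np = np.array([0.1, 0.2])        # shape (n_alpha,)
n_obs = 500                            # scalar number of observations

r_hat_2d = np.expand_dims(r_hat, axis=0)              # shape (1, n_lambdas)
alpha_2d = np.expand_dims(alpha_np, axis=0)           # shape (1, n_alpha)
n_obs_2d = np.expand_dims(np.array([n_obs]), axis=0)  # shape (1, 1), matching the call above
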
@@ -865,16 +868,20 @@ class BinaryClassificationController:
         predict_proba method of a fitted binary classifier.
         Its output signature must be of shape (len(X), 2)

-    risk : BinaryClassificationRisk
+    risk : Union[BinaryClassificationRisk, List[BinaryClassificationRisk]]
         The risk or performance metric to control.
         Valid options:

         - An existing risk defined in `mapie.risk_control` (e.g. precision, recall,
           accuracy, false_positive_rate)
         - A custom instance of BinaryClassificationRisk object

-    target_level : float
+        Can be a list of risks in the case of multi risk control.
+
+    target_level : Union[float, List[float]]
         The maximum risk level (or minimum performance level). Must be between 0 and 1.
+        Can be a list of target levels in the case of multi risk control (length should
+        match the length of the risks list).

     confidence_level : float, default=0.9
         The confidence level with which the risk (or performance) is controlled.
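
Note (illustrative, not part of the diff): a usage sketch of the multi-risk API described in this docstring. The dataset, the classifier, and the exact calibrate signature are assumptions, not taken from the diff.

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

from mapie.risk_control import (
    BinaryClassificationController,
    false_positive_rate,
    recall,
)

X, y = make_classification(n_samples=2000, random_state=0)
X_train, X_calib, y_train, y_calib = train_test_split(X, y, random_state=0)
clf = LogisticRegression().fit(X_train, y_train)

# Control two risks jointly: recall of at least 0.8 and a false positive rate
# of at most 0.3, each with confidence 0.9.
controller = BinaryClassificationController(
    predict_function=clf.predict_proba,
    risk=[recall, false_positive_rate],
    target_level=[0.8, 0.3],
    confidence_level=0.9,
)
controller.calibrate(X_calib, y_calib)  # assumed calibrate(X, y) signature
print(controller.best_predict_param)    # best threshold found during calibration
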
@@ -950,18 +957,19 @@ class BinaryClassificationController:
     def __init__(
         self,
         predict_function: Callable[[ArrayLike], NDArray],
-        risk: BinaryClassificationRisk,
-        target_level: float,
+        risk: Union[BinaryClassificationRisk, List[BinaryClassificationRisk]],
+        target_level: Union[float, List[float]],
         confidence_level: float = 0.9,
         best_predict_param_choice: Union[
             Literal["auto"], BinaryClassificationRisk] = "auto",
     ):
+        self.is_multi_risk = self._check_if_multi_risk_control(risk, target_level)
         self._predict_function = predict_function
-        self._risk = risk
-        if self._risk.higher_is_better:
-            self._alpha = 1 - target_level
-        else:
-            self._alpha = target_level
+        self._risk = risk if isinstance(risk, list) else [risk]
+        target_level_list = (
+            target_level if isinstance(target_level, list) else [target_level]
+        )
+        self._alpha = self._convert_target_level_to_alpha(target_level_list)
         self._delta = 1 - confidence_level

         self._best_predict_param_choice = self._set_best_predict_param_choice(
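
Note (illustrative, not part of the diff): the target levels are converted to one risk level per risk, mirroring _convert_target_level_to_alpha defined further down; the higher_is_better flags below are assumptions for recall and false_positive_rate.

# "Higher is better" metrics (e.g. recall) are converted with alpha = 1 - target,
# "lower is better" metrics (e.g. false_positive_rate) keep alpha = target.
targets = [0.8, 0.1]
higher_is_better = [True, False]  # assumed flags for (recall, false_positive_rate)
alphas = [1 - t if hib else t for t, hib in zip(targets, higher_is_better)]
print(alphas)  # approximately [0.2, 0.1]
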
@@ -1006,20 +1014,16 @@ def calibrate( # pragma: no cover
             self._predict_params
         )

-        risks_and_eff_sizes = self._get_risks_and_effective_sample_sizes_per_param(
+        risk_values, eff_sample_sizes = self._get_risk_values_and_eff_sample_sizes(
             y_calibrate_,
             predictions_per_param,
             self._risk
         )
-
-        risks_per_param = risks_and_eff_sizes[:, 0]
-        eff_sample_sizes_per_param = risks_and_eff_sizes[:, 1]
-
         valid_params_index = ltt_procedure(
-            risks_per_param,
-            np.array([self._alpha]),
+            risk_values,
+            np.expand_dims(self._alpha, axis=1),
             self._delta,
-            eff_sample_sizes_per_param,
+            eff_sample_sizes,
             True,
         )[0]
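
Note (illustrative, not part of the diff): in this call the arrays follow the shapes documented in ltt.py, e.g. for two risks and 100 candidate thresholds.

import numpy as np

n_risks, n_params = 2, 100
risk_values = np.zeros((n_risks, n_params))       # empirical risk per risk and per threshold
eff_sample_sizes = np.zeros((n_risks, n_params))  # effective sample size per risk and per threshold
alpha = np.array([0.2, 0.3])                      # self._alpha holds one value per risk

alpha_2d = np.expand_dims(alpha, axis=1)          # shape (n_risks, 1): one alpha column per risk
assert alpha_2d.shape == (2, 1)
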

@@ -1072,16 +1076,20 @@ def _set_best_predict_param_choice(
             Literal["auto"], BinaryClassificationRisk] = "auto",
     ) -> BinaryClassificationRisk:
         if best_predict_param_choice == "auto":
-            try:
-                return self._best_predict_param_choice_map[
-                    self._risk
-                ]
-            except KeyError:
-                raise ValueError(
-                    "When best_predict_param_choice is 'auto', "
-                    "risk must be one of the risks defined in mapie.risk_control"
-                    "(e.g. precision, accuracy, false_positive_rate)."
-                )
+            if self.is_multi_risk:
+                # when multi risk, we minimize the first risk in the list
+                return self._risk[0]
+            else:
+                try:
+                    return self._best_predict_param_choice_map[
+                        self._risk[0]
+                    ]
+                except KeyError:
+                    raise ValueError(
+                        "When best_predict_param_choice is 'auto', "
+                        "risk must be one of the risks defined in mapie.risk_control"
+                        "(e.g. precision, accuracy, false_positive_rate)."
+                    )
         else:
             return best_predict_param_choice
@@ -1099,29 +1107,37 @@ def _set_best_predict_param(
         predictions_per_param: NDArray,
         valid_params_index: List[Any],
     ):
-        secondary_risks_per_param = \
-            self._get_risks_and_effective_sample_sizes_per_param(
+        secondary_risks_per_param, _ = self._get_risk_values_and_eff_sample_sizes(
             y_calibrate_,
             predictions_per_param[valid_params_index],
-            self._best_predict_param_choice
-            )[:, 0]
+            [self._best_predict_param_choice]
+        )

         self.best_predict_param = self.valid_predict_params[
             np.argmin(secondary_risks_per_param)
         ]

     @staticmethod
-    def _get_risks_and_effective_sample_sizes_per_param(
+    def _get_risk_values_and_eff_sample_sizes(
         y_true: NDArray,
         predictions_per_param: NDArray,
-        risk: BinaryClassificationRisk,
-    ) -> NDArray:
-        return np.array(
-            [risk.get_value_and_effective_sample_size(
-                y_true,
-                predictions
-            ) for predictions in predictions_per_param]
-        )
+        risks: List[BinaryClassificationRisk],
+    ) -> Tuple[NDArray, NDArray]:
+        """
+        Compute the values of risks and effective sample sizes for multiple risks
+        and for multiple parameter values.
+        Returns arrays with shape (n_risks, n_params).
+        """
+        risks_values_and_eff_sizes = np.array([
+            [risk.get_value_and_effective_sample_size(y_true, predictions)
+             for predictions in predictions_per_param]
+            for risk in risks
+        ])
+
+        risk_values = risks_values_and_eff_sizes[:, :, 0]
+        effective_sample_sizes = risks_values_and_eff_sizes[:, :, 1]
+
+        return risk_values, effective_sample_sizes

     def _get_predictions_per_param(self, X: ArrayLike, params: NDArray) -> NDArray:
         try:
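
Note (illustrative, not part of the diff): a shape-only sketch of what the new helper computes, using a dummy stand-in for BinaryClassificationRisk.

import numpy as np

class DummyRisk:
    """Stand-in for BinaryClassificationRisk: returns (risk value, effective sample size)."""
    def get_value_and_effective_sample_size(self, y_true, predictions):
        return float(np.mean(predictions != y_true)), len(y_true)

y_true = np.array([0, 1, 1, 0])
predictions_per_param = np.array([[0, 1, 1, 1],    # predictions for threshold 1
                                  [0, 0, 1, 0]])   # predictions for threshold 2
risks = [DummyRisk(), DummyRisk()]

stacked = np.array([
    [risk.get_value_and_effective_sample_size(y_true, predictions)
     for predictions in predictions_per_param]
    for risk in risks
])                                  # shape (n_risks, n_params, 2)
risk_values = stacked[:, :, 0]      # shape (n_risks, n_params)
eff_sizes = stacked[:, :, 1]        # shape (n_risks, n_params)
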
@@ -1148,3 +1164,42 @@ def _get_predictions_per_param(self, X: ArrayLike, params: NDArray) -> NDArray:
         else:
             raise
         return (predictions_proba[:, np.newaxis] >= params).T.astype(int)
+
+    def _convert_target_level_to_alpha(self, target_level: List[float]) -> NDArray:
+        alpha = []
+        for risk, target in zip(self._risk, target_level):
+            if risk.higher_is_better:
+                alpha.append(1 - target)
+            else:
+                alpha.append(target)
+        return np.array(alpha)
+
+    @staticmethod
+    def _check_if_multi_risk_control(
+        risk: Union[BinaryClassificationRisk, List[BinaryClassificationRisk]],
+        target_level: Union[float, List[float]],
+    ) -> bool:
+        """
+        Check if we are in a multi risk setting and if inputs types are correct.
+        """
+        if (
+            isinstance(risk, list) and isinstance(target_level, list)
+            and len(risk) == len(target_level)
+            and len(risk) > 0
+        ):
+            if len(risk) == 1:
+                return False
+            else:
+                return True
+        elif (
+            isinstance(risk, BinaryClassificationRisk)
+            and isinstance(target_level, float)
+        ):
+            return False
+        else:
+            raise ValueError(
+                "If you provide a list of risks, you must provide "
+                "a list of target levels of the same length and vice versa. "
+                "If you provide a single BinaryClassificationRisk risk, "
+                "you must provide a single float target level."
+            )
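
Note (illustrative, not part of the diff): a sketch of the new input validation; mismatched lengths of risks and target levels now raise a ValueError at construction time, before anything else runs.

from mapie.risk_control import (
    BinaryClassificationController,
    false_positive_rate,
    recall,
)

try:
    BinaryClassificationController(
        predict_function=lambda X: None,   # placeholder, never called in this failing case
        risk=[recall, false_positive_rate],
        target_level=[0.8],                # one target for two risks: rejected
    )
except ValueError as exc:
    print(exc)
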
