Skip to content

Commit d0b64b6

Browse files
committed
__init__ and _set_best_predict_param_choice handle multi risk
1 parent 83893ff commit d0b64b6

File tree

1 file changed

+49
-45
lines changed

1 file changed

+49
-45
lines changed

mapie/risk_control.py

Lines changed: 49 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -984,14 +984,10 @@ def __init__(
984984
best_predict_param_choice: Union[
985985
Literal["auto"], BinaryClassificationRisk] = "auto",
986986
):
987-
self._check_risks_targets_same_len(risk, target_level)
988-
987+
self.is_multi_risk = self._check_if_multi_risk_control(risk, target_level)
989988
self._predict_function = predict_function
990989
self._risk = risk
991-
if self._risk.higher_is_better:
992-
self._alpha = 1 - target_level
993-
else:
994-
self._alpha = target_level
990+
self._alpha = self._convert_target_level_to_alpha(target_level)
995991
self._delta = 1 - confidence_level
996992

997993
self._best_predict_param_choice = self._set_best_predict_param_choice(
@@ -1003,20 +999,6 @@ def __init__(
1003999
self.valid_predict_params: NDArray = np.array([])
10041000
self.best_predict_param: Optional[float] = None
10051001

1006-
def convert_target_level_to_alpha(self, target_level):
1007-
if isinstance(target_level, float):
1008-
if self._risk.higher_is_better:
1009-
self._alpha = 1 - target_level
1010-
else:
1011-
self._alpha = target_level
1012-
else:
1013-
self._alpha = []
1014-
for risk, target in zip(self._risk, target_level):
1015-
if risk.higher_is_better:
1016-
self._alpha.append(1 - target)
1017-
else:
1018-
self._alpha.append(target)
1019-
10201002
# All subfunctions are unit-tested. To avoid having to write
10211003
# tests just to make sure those subfunctions are called,
10221004
# we don't include .calibrate in the coverage report
@@ -1114,16 +1096,20 @@ def _set_best_predict_param_choice(
11141096
Literal["auto"], BinaryClassificationRisk] = "auto",
11151097
) -> BinaryClassificationRisk:
11161098
if best_predict_param_choice == "auto":
1117-
try:
1118-
return self._best_predict_param_choice_map[
1119-
self._risk
1120-
]
1121-
except KeyError:
1122-
raise ValueError(
1123-
"When best_predict_param_choice is 'auto', "
1124-
"risk must be one of the risks defined in mapie.risk_control"
1125-
"(e.g. precision, accuracy, false_positive_rate)."
1126-
)
1099+
if self.is_multi_risk:
1100+
# when multi risk, we minimize the first risk in the list
1101+
return self._risk[0] # type: ignore
1102+
else:
1103+
try:
1104+
return self._best_predict_param_choice_map[
1105+
self._risk # type: ignore
1106+
]
1107+
except KeyError:
1108+
raise ValueError(
1109+
"When best_predict_param_choice is 'auto', "
1110+
"risk must be one of the risks defined in mapie.risk_control"
1111+
"(e.g. precision, accuracy, false_positive_rate)."
1112+
)
11271113
else:
11281114
return best_predict_param_choice
11291115

@@ -1191,25 +1177,43 @@ def _get_predictions_per_param(self, X: ArrayLike, params: NDArray) -> NDArray:
11911177
raise
11921178
return (predictions_proba[:, np.newaxis] >= params).T.astype(int)
11931179

1180+
def _convert_target_level_to_alpha(
    self, target_level: Union[float, List[float]]
) -> Union[float, List[float]]:
    """
    Convert the target level(s) into the corresponding alpha value(s).

    A risk whose ``higher_is_better`` attribute is truthy is converted with
    ``1 - target``; any other risk keeps its target unchanged.

    Reads ``self.is_multi_risk`` and ``self._risk``, so both must be set
    before this method is called.

    Parameters
    ----------
    target_level : float or list of float
        A single target when ``self.is_multi_risk`` is false; otherwise one
        target per risk of ``self._risk``, paired by position.

    Returns
    -------
    float or list of float
        A single alpha in the single-risk case, or a list with one alpha
        per (risk, target) pair in the multi-risk case.
    """

    def _to_alpha(risk, target):
        # Single place where the "higher is better" convention is applied.
        return 1 - target if risk.higher_is_better else target

    if self.is_multi_risk:
        # zip pairs risks and targets by position and stops at the shorter
        # of the two sequences.
        return [
            _to_alpha(risk, target)
            for risk, target in zip(self._risk, target_level)
        ]
    return _to_alpha(self._risk, target_level)
1194+
11941195
@staticmethod
def _check_if_multi_risk_control(  # TODO what about lists of len 1
    risk: Union[BinaryClassificationRisk, List[BinaryClassificationRisk]],
    target_level: Union[float, List[float]],
) -> bool:
    """
    Check if we are in a multi risk setting and if input types are correct.

    Parameters
    ----------
    risk : BinaryClassificationRisk or list of BinaryClassificationRisk
        A single risk, or a list of risks.
    target_level : float or list of float
        A single target level, or a list of target levels.

    Returns
    -------
    bool
        True when ``risk`` and ``target_level`` are both non-empty lists of
        the same length (a one-element list counts as multi risk); False
        when ``risk`` is a single BinaryClassificationRisk and
        ``target_level`` is a float.

    Raises
    ------
    ValueError
        If both lists are empty, or for any other combination: a list
        paired with a scalar, lists of different lengths, or a target
        level that is not a ``float`` instance (an ``int`` is rejected).
    """
    if isinstance(risk, list) and isinstance(target_level, list):
        if len(risk) == len(target_level):
            if not risk:
                # Two empty lists have matching lengths, but there is no
                # risk to control; fail here with a clear message rather
                # than later when the first risk of the list is accessed.
                raise ValueError(
                    "If you provide a list of risks, "
                    "it must contain at least one risk."
                )
            return True
    elif (
        isinstance(risk, BinaryClassificationRisk)
        and isinstance(target_level, float)
    ):
        return False
    # Everything else is a type or length mismatch.
    raise ValueError(
        "If you provide a list of risks, you must provide "
        "a list of target levels of the same length and vice versa. "
        "If you provide a single BinaryClassificationRisk risk, "
        "you must provide a single float target level."
    )

0 commit comments

Comments
 (0)