@@ -984,14 +984,10 @@ def __init__(
984984 best_predict_param_choice : Union [
985985 Literal ["auto" ], BinaryClassificationRisk ] = "auto" ,
986986 ):
987- self ._check_risks_targets_same_len (risk , target_level )
988-
987+ self .is_multi_risk = self ._check_if_multi_risk_control (risk , target_level )
989988 self ._predict_function = predict_function
990989 self ._risk = risk
991- if self ._risk .higher_is_better :
992- self ._alpha = 1 - target_level
993- else :
994- self ._alpha = target_level
990+ self ._alpha = self ._convert_target_level_to_alpha (target_level )
995991 self ._delta = 1 - confidence_level
996992
997993 self ._best_predict_param_choice = self ._set_best_predict_param_choice (
@@ -1003,20 +999,6 @@ def __init__(
1003999 self .valid_predict_params : NDArray = np .array ([])
10041000 self .best_predict_param : Optional [float ] = None
10051001
1006- def convert_target_level_to_alpha (self , target_level ):
1007- if isinstance (target_level , float ):
1008- if self ._risk .higher_is_better :
1009- self ._alpha = 1 - target_level
1010- else :
1011- self ._alpha = target_level
1012- else :
1013- self ._alpha = []
1014- for risk , target in zip (self ._risk , target_level ):
1015- if risk .higher_is_better :
1016- self ._alpha .append (1 - target )
1017- else :
1018- self ._alpha .append (target )
1019-
10201002 # All subfunctions are unit-tested. To avoid having to write
10211003 # tests just to make sure those subfunctions are called,
10221004 # we don't include .calibrate in the coverage report
@@ -1114,16 +1096,20 @@ def _set_best_predict_param_choice(
11141096 Literal ["auto" ], BinaryClassificationRisk ] = "auto" ,
11151097 ) -> BinaryClassificationRisk :
11161098 if best_predict_param_choice == "auto" :
1117- try :
1118- return self ._best_predict_param_choice_map [
1119- self ._risk
1120- ]
1121- except KeyError :
1122- raise ValueError (
1123- "When best_predict_param_choice is 'auto', "
1124- "risk must be one of the risks defined in mapie.risk_control"
1125- "(e.g. precision, accuracy, false_positive_rate)."
1126- )
1099+ if self .is_multi_risk :
1100+ # when multi risk, we minimize the first risk in the list
1101+ return self ._risk [0 ] # type: ignore
1102+ else :
1103+ try :
1104+ return self ._best_predict_param_choice_map [
1105+ self ._risk # type: ignore
1106+ ]
1107+ except KeyError :
1108+ raise ValueError (
1109+ "When best_predict_param_choice is 'auto', "
1110+ "risk must be one of the risks defined in mapie.risk_control"
1111+ "(e.g. precision, accuracy, false_positive_rate)."
1112+ )
11271113 else :
11281114 return best_predict_param_choice
11291115
@@ -1191,25 +1177,43 @@ def _get_predictions_per_param(self, X: ArrayLike, params: NDArray) -> NDArray:
11911177 raise
11921178 return (predictions_proba [:, np .newaxis ] >= params ).T .astype (int )
11931179
1180+ def _convert_target_level_to_alpha (self , target_level ):
1181+ if self .is_multi_risk :
1182+ alpha = []
1183+ for risk , target in zip (self ._risk , target_level ):
1184+ if risk .higher_is_better :
1185+ alpha .append (1 - target )
1186+ else :
1187+ alpha .append (target )
1188+ else :
1189+ if self ._risk .higher_is_better :
1190+ alpha = 1 - target_level
1191+ else :
1192+ alpha = target_level
1193+ return alpha
1194+
11941195 @staticmethod
1195- def _check_risks_targets_same_len ( # TODO what about lists of len 1
1196+ def _check_if_multi_risk_control ( # TODO what about lists of len 1
11961197 risk : Union [BinaryClassificationRisk , List [BinaryClassificationRisk ]],
11971198 target_level : Union [float , List [float ]],
1198- ) -> None :
1199+ ) -> bool :
1200+ """
1201+ Check if we are in a multi risk setting and if inputs types are correct.
1202+ """
11991203 if (
1200- isinstance (risk , list ) and isinstance (target_level , float )
1201- or (
1202- isinstance (risk , BinaryClassificationRisk )
1203- and isinstance (target_level , list )
1204- )
1205- or (
1206- isinstance (risk , list )
1207- and isinstance (target_level , list )
1208- and len (risk ) != len (target_level )
1209- )
1204+ isinstance (risk , list ) and isinstance (target_level , list )
1205+ and len (risk ) == len (target_level )
1206+ ):
1207+ return True
1208+ elif (
1209+ isinstance (risk , BinaryClassificationRisk )
1210+ and isinstance (target_level , float )
12101211 ):
1212+ return False
1213+ else :
12111214 raise ValueError (
1212- "If you provide a list of risks, "
1213- "you must provide a list of target levels of the same length "
1214- "and vice versa."
1215+ "If you provide a list of risks, you must provide "
1216+ "a list of target levels of the same length and vice versa. "
1217+ "If you provide a single BinaryClassificationRisk risk, "
1218+ "you must provide a single float target level."
12151219 )
0 commit comments