@@ -471,7 +471,7 @@ def _transform_pred_proba(
471471 y_pred_proba_array = y_pred_proba
472472 else :
473473 y_pred_proba_stacked = np .stack (
474- y_pred_proba , # type: ignore
474+ y_pred_proba ,
475475 axis = 0
476476 )[:, :, 1 ]
477477 y_pred_proba_array = np .moveaxis (y_pred_proba_stacked , 0 , - 1 )
@@ -669,7 +669,10 @@ def predict(
669669 self .n_obs = len (self .risks )
670670 self .r_hat = self .risks .mean (axis = 0 )
671671 self .valid_index = ltt_procedure (
672- self .r_hat , alpha_np , cast (float , delta ), self .n_obs
672+ np .expand_dims (self .r_hat , axis = 0 ),
673+ np .expand_dims (alpha_np , axis = 0 ),
674+ cast (float , delta ),
675+ np .expand_dims (np .array ([self .n_obs ]), axis = 0 )
673676 )
674677 self ._check_valid_index (alpha_np )
675678 self .lambdas_star , self .r_star = find_lambda_control_star (
@@ -865,16 +868,20 @@ class BinaryClassificationController:
865868 predict_proba method of a fitted binary classifier.
866869 Its output signature must be of shape (len(X), 2)
867870
868- risk : BinaryClassificationRisk
871+ risk : Union[ BinaryClassificationRisk, List[BinaryClassificationRisk]]
869872 The risk or performance metric to control.
870873 Valid options:
871874
872875 - An existing risk defined in `mapie.risk_control` (e.g. precision, recall,
873876 accuracy, false_positive_rate)
874877 - A custom instance of BinaryClassificationRisk object
875878
876- target_level : float
879+ Can be a list of risks in the case of multi risk control.
880+
881+ target_level : Union[float, List[float]]
877882 The maximum risk level (or minimum performance level). Must be between 0 and 1.
883+ Can be a list of target levels in the case of multi risk control (length should
884+ match the length of the risks list).
878885
879886 confidence_level : float, default=0.9
880887 The confidence level with which the risk (or performance) is controlled.
@@ -950,18 +957,19 @@ class BinaryClassificationController:
950957 def __init__ (
951958 self ,
952959 predict_function : Callable [[ArrayLike ], NDArray ],
953- risk : BinaryClassificationRisk ,
954- target_level : float ,
960+ risk : Union [ BinaryClassificationRisk , List [ BinaryClassificationRisk ]] ,
961+ target_level : Union [ float , List [ float ]] ,
955962 confidence_level : float = 0.9 ,
956963 best_predict_param_choice : Union [
957964 Literal ["auto" ], BinaryClassificationRisk ] = "auto" ,
958965 ):
966+ self .is_multi_risk = self ._check_if_multi_risk_control (risk , target_level )
959967 self ._predict_function = predict_function
960- self ._risk = risk
961- if self . _risk . higher_is_better :
962- self . _alpha = 1 - target_level
963- else :
964- self ._alpha = target_level
968+ self ._risk = risk if isinstance ( risk , list ) else [ risk ]
969+ target_level_list = (
970+ target_level if isinstance ( target_level , list ) else [ target_level ]
971+ )
972+ self ._alpha = self . _convert_target_level_to_alpha ( target_level_list )
965973 self ._delta = 1 - confidence_level
966974
967975 self ._best_predict_param_choice = self ._set_best_predict_param_choice (
@@ -1006,20 +1014,16 @@ def calibrate( # pragma: no cover
10061014 self ._predict_params
10071015 )
10081016
1009- risks_and_eff_sizes = self ._get_risks_and_effective_sample_sizes_per_param (
1017+ risk_values , eff_sample_sizes = self ._get_risk_values_and_eff_sample_sizes (
10101018 y_calibrate_ ,
10111019 predictions_per_param ,
10121020 self ._risk
10131021 )
1014-
1015- risks_per_param = risks_and_eff_sizes [:, 0 ]
1016- eff_sample_sizes_per_param = risks_and_eff_sizes [:, 1 ]
1017-
10181022 valid_params_index = ltt_procedure (
1019- risks_per_param ,
1020- np .array ([ self ._alpha ] ),
1023+ risk_values ,
1024+ np .expand_dims ( self ._alpha , axis = 1 ),
10211025 self ._delta ,
1022- eff_sample_sizes_per_param ,
1026+ eff_sample_sizes ,
10231027 True ,
10241028 )[0 ]
10251029
@@ -1072,16 +1076,20 @@ def _set_best_predict_param_choice(
10721076 Literal ["auto" ], BinaryClassificationRisk ] = "auto" ,
10731077 ) -> BinaryClassificationRisk :
10741078 if best_predict_param_choice == "auto" :
1075- try :
1076- return self ._best_predict_param_choice_map [
1077- self ._risk
1078- ]
1079- except KeyError :
1080- raise ValueError (
1081- "When best_predict_param_choice is 'auto', "
1082- "risk must be one of the risks defined in mapie.risk_control"
1083- "(e.g. precision, accuracy, false_positive_rate)."
1084- )
1079+ if self .is_multi_risk :
1080+ # when multi risk, we minimize the first risk in the list
1081+ return self ._risk [0 ]
1082+ else :
1083+ try :
1084+ return self ._best_predict_param_choice_map [
1085+ self ._risk [0 ]
1086+ ]
1087+ except KeyError :
1088+ raise ValueError (
1089+ "When best_predict_param_choice is 'auto', "
1090+ "risk must be one of the risks defined in mapie.risk_control"
1091+ "(e.g. precision, accuracy, false_positive_rate)."
1092+ )
10851093 else :
10861094 return best_predict_param_choice
10871095
@@ -1099,29 +1107,37 @@ def _set_best_predict_param(
10991107 predictions_per_param : NDArray ,
11001108 valid_params_index : List [Any ],
11011109 ):
1102- secondary_risks_per_param = \
1103- self ._get_risks_and_effective_sample_sizes_per_param (
1110+ secondary_risks_per_param , _ = self ._get_risk_values_and_eff_sample_sizes (
11041111 y_calibrate_ ,
11051112 predictions_per_param [valid_params_index ],
1106- self ._best_predict_param_choice
1107- )[:, 0 ]
1113+ [ self ._best_predict_param_choice ]
1114+ )
11081115
11091116 self .best_predict_param = self .valid_predict_params [
11101117 np .argmin (secondary_risks_per_param )
11111118 ]
11121119
11131120 @staticmethod
1114- def _get_risks_and_effective_sample_sizes_per_param (
1121+ def _get_risk_values_and_eff_sample_sizes (
11151122 y_true : NDArray ,
11161123 predictions_per_param : NDArray ,
1117- risk : BinaryClassificationRisk ,
1118- ) -> NDArray :
1119- return np .array (
1120- [risk .get_value_and_effective_sample_size (
1121- y_true ,
1122- predictions
1123- ) for predictions in predictions_per_param ]
1124- )
1124+ risks : List [BinaryClassificationRisk ],
1125+ ) -> Tuple [NDArray , NDArray ]:
1126+ """
1127+ Compute the values of risks and effective sample sizes for multiple risks
1128+ and for multiple parameter values.
1129+ Returns arrays with shape (n_risks, n_params).
1130+ """
1131+ risks_values_and_eff_sizes = np .array ([
1132+ [risk .get_value_and_effective_sample_size (y_true , predictions )
1133+ for predictions in predictions_per_param ]
1134+ for risk in risks
1135+ ])
1136+
1137+ risk_values = risks_values_and_eff_sizes [:, :, 0 ]
1138+ effective_sample_sizes = risks_values_and_eff_sizes [:, :, 1 ]
1139+
1140+ return risk_values , effective_sample_sizes
11251141
11261142 def _get_predictions_per_param (self , X : ArrayLike , params : NDArray ) -> NDArray :
11271143 try :
@@ -1148,3 +1164,42 @@ def _get_predictions_per_param(self, X: ArrayLike, params: NDArray) -> NDArray:
11481164 else :
11491165 raise
11501166 return (predictions_proba [:, np .newaxis ] >= params ).T .astype (int )
1167+
1168+ def _convert_target_level_to_alpha (self , target_level : List [float ]) -> NDArray :
1169+ alpha = []
1170+ for risk , target in zip (self ._risk , target_level ):
1171+ if risk .higher_is_better :
1172+ alpha .append (1 - target )
1173+ else :
1174+ alpha .append (target )
1175+ return np .array (alpha )
1176+
1177+ @staticmethod
1178+ def _check_if_multi_risk_control (
1179+ risk : Union [BinaryClassificationRisk , List [BinaryClassificationRisk ]],
1180+ target_level : Union [float , List [float ]],
1181+ ) -> bool :
1182+ """
1183+ Check if we are in a multi risk setting and if inputs types are correct.
1184+ """
1185+ if (
1186+ isinstance (risk , list ) and isinstance (target_level , list )
1187+ and len (risk ) == len (target_level )
1188+ and len (risk ) > 0
1189+ ):
1190+ if len (risk ) == 1 :
1191+ return False
1192+ else :
1193+ return True
1194+ elif (
1195+ isinstance (risk , BinaryClassificationRisk )
1196+ and isinstance (target_level , float )
1197+ ):
1198+ return False
1199+ else :
1200+ raise ValueError (
1201+ "If you provide a list of risks, you must provide "
1202+ "a list of target levels of the same length and vice versa. "
1203+ "If you provide a single BinaryClassificationRisk risk, "
1204+ "you must provide a single float target level."
1205+ )
0 commit comments