11from typing import Optional , Tuple , Union , cast
22
33import numpy as np
4- from sklearn .utils import check_random_state
5- from sklearn .preprocessing import label_binarize
4+ from numpy .typing import ArrayLike , NDArray
65from sklearn .model_selection import BaseCrossValidator
6+ from sklearn .preprocessing import label_binarize
7+ from sklearn .utils import check_random_state
78
9+ from mapie ._machine_precision import EPSILON
810from mapie .conformity_scores .sets .naive import NaiveConformityScore
911from mapie .conformity_scores .sets .utils import check_include_last_label
10-
11- from mapie ._machine_precision import EPSILON
12- from numpy .typing import ArrayLike , NDArray
1312from mapie .utils import _compute_quantiles
1413
1514
@@ -30,7 +29,7 @@ class APSConformityScore(NaiveConformityScore):
3029 classes: Optional[ArrayLike]
3130 Names of the classes.
3231
33- random_state: Optional[Union[int, RandomState]]
32+ random_state: Optional[Union[int, np.random. RandomState]]
3433 Pseudo random number generator state.
3534
3635 quantiles_: ArrayLike of shape (n_alpha)
@@ -65,7 +64,7 @@ def get_predictions(
6564 Predicted probabilities from the estimator.
6665
6766 cv: Optional[Union[int, str, BaseCrossValidator]]
68- Cross-validation strategy used by the estimator.
67+ Cross-validation strategy used by the estimator (not used here) .
6968
7069 agg_scores: Optional[str]
7170 Method to aggregate the scores from the base estimators.
@@ -94,20 +93,20 @@ def get_true_label_cumsum_proba(
9493
9594 Parameters
9695 ----------
97- y: NDArray of shape (n_samples, )
96+ y: ArrayLike of shape (n_samples, )
9897 Array with the labels.
9998
10099 y_pred_proba: NDArray of shape (n_samples, n_classes)
101100 Predictions of the model.
102101
103- classes: NDArray of shape (n_classes, )
102+ classes: ArrayLike of shape (n_classes, )
104103 Array with the classes.
105104
106105 Returns
107106 -------
108107 Tuple[NDArray, NDArray] of shapes (n_samples, 1) and (n_samples, ).
109108 The first element is the cumsum probability of the true label.
110- The second is the sorted position of the true label.
109+ The second is the 1-based rank of the true label in the sorted probabilities .
111110 """
112111 y_true = label_binarize (y = y , classes = classes )
113112 index_sorted = np .fliplr (np .argsort (y_pred_proba , axis = 1 ))
@@ -136,7 +135,7 @@ def get_conformity_scores(
136135 y_pred: NDArray of shape (n_samples,)
137136 Predicted target values.
138137
139- y_enc: NDArray of shape (n_samples,)
138+ y_enc: Optional[ NDArray] of shape (n_samples,)
140139 Target values as normalized encodings.
141140
142141 Returns
@@ -225,8 +224,8 @@ def _compute_v_parameter(
225224 y_pred_proba_last: NDArray of shape (n_samples, 1, n_alpha)
226225 Last included probability.
227226
228- predicition_sets : NDArray of shape (n_samples, n_alpha)
229- Prediction sets.
227+ prediction_sets : NDArray of shape (n_samples, n_alpha)
228+ Prediction sets (not used here) .
230229
231230 Returns
232231 --------
@@ -328,7 +327,7 @@ def get_prediction_sets(
328327
329328 alpha_np: NDArray of shape (n_alpha,)
330329 NDArray of floats between 0 and 1, representing the uncertainty
331- of the confidence interval.
330+ of the confidence interval (not used here) .
332331
333332 cv: Optional[Union[int, str, BaseCrossValidator]]
334333 Cross-validation strategy used by the estimator.
0 commit comments