Skip to content

Commit f643450

Browse files
authored
docs: update the docstring for the conformity score methods to align with the typing. (#808)
* update the docstring of sets and BaseClassificationScore : get_predictions, get_conformity_score_quantiles, get_prediction_sets * remove not used args
1 parent 1afa613 commit f643450

File tree

6 files changed

+67
-93
lines changed

6 files changed

+67
-93
lines changed

mapie/conformity_scores/classification.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@
22
from typing import Optional, Union
33

44
import numpy as np
5-
6-
from mapie.conformity_scores.interface import BaseConformityScore
5+
from numpy.typing import ArrayLike, NDArray
76
from sklearn.model_selection import BaseCrossValidator
87

9-
from numpy.typing import ArrayLike, NDArray
8+
from mapie.conformity_scores.interface import BaseConformityScore
109

1110

1211
class BaseClassificationScore(BaseConformityScore, metaclass=ABCMeta):
@@ -17,6 +16,12 @@ class BaseClassificationScore(BaseConformityScore, metaclass=ABCMeta):
1716
1817
Attributes
1918
----------
19+
classes: Optional[ArrayLike]
20+
Names of the classes.
21+
22+
random_state: Optional[Union[int, np.random.RandomState]]
23+
Pseudo random number generator state.
24+
2025
quantiles_: ArrayLike of shape (n_alpha)
2126
The quantiles estimated from ``get_sets`` method.
2227
"""
@@ -41,7 +46,7 @@ def set_external_attributes(
4146
4247
By default ``None``.
4348
44-
random_state: Optional[Union[int, RandomState]]
49+
random_state: Optional[Union[int, np.random.RandomState]]
4550
Pseudo random number generator state.
4651
"""
4752
super().set_external_attributes(**kwargs)
@@ -71,8 +76,11 @@ def get_predictions(
7176
NDArray of floats between ``0`` and ``1``, represents the
7277
uncertainty of the confidence set.
7378
74-
estimator: EnsembleClassifier
75-
Estimator that is fitted to predict y from X.
79+
y_pred_proba: NDArray
80+
Predicted probabilities from the estimator.
81+
82+
cv: Optional[Union[int, str, BaseCrossValidator]]
83+
Cross-validation strategy used by the estimator.
7684
7785
Returns
7886
--------
@@ -102,8 +110,8 @@ def get_conformity_score_quantiles(
102110
NDArray of floats between 0 and 1, representing the uncertainty
103111
of the confidence set.
104112
105-
estimator: EnsembleClassifier
106-
Estimator that is fitted to predict y from X.
113+
cv: Optional[Union[int, str, BaseCrossValidator]]
114+
Cross-validation strategy used by the estimator.
107115
108116
Returns
109117
--------
@@ -138,8 +146,8 @@ def get_prediction_sets(
138146
NDArray of floats between 0 and 1, representing the uncertainty
139147
of the confidence set.
140148
141-
estimator: EnsembleClassifier
142-
Estimator that is fitted to predict y from X.
149+
cv: Optional[Union[int, str, BaseCrossValidator]]
150+
Cross-validation strategy used by the estimator.
143151
144152
Returns
145153
--------
@@ -211,12 +219,6 @@ def predict_set(self, X: NDArray, alpha_np: NDArray, **kwargs):
211219
alpha_np: NDArray of shape (n_alpha, )
212220
Represents the uncertainty of the confidence set to produce.
213221
214-
y_pred_proba: NDArray
215-
Predicted probabilities from the estimator.
216-
217-
cv: Optional[Union[int, str, BaseCrossValidator]]
218-
Cross-validation strategy used by the estimator.
219-
220222
**kwargs: dict
221223
Additional keyword arguments.
222224

mapie/conformity_scores/sets/aps.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
from typing import Optional, Tuple, Union, cast
22

33
import numpy as np
4-
from sklearn.utils import check_random_state
5-
from sklearn.preprocessing import label_binarize
4+
from numpy.typing import ArrayLike, NDArray
65
from sklearn.model_selection import BaseCrossValidator
6+
from sklearn.preprocessing import label_binarize
7+
from sklearn.utils import check_random_state
78

9+
from mapie._machine_precision import EPSILON
810
from mapie.conformity_scores.sets.naive import NaiveConformityScore
911
from mapie.conformity_scores.sets.utils import check_include_last_label
10-
11-
from mapie._machine_precision import EPSILON
12-
from numpy.typing import ArrayLike, NDArray
1312
from mapie.utils import _compute_quantiles
1413

1514

@@ -30,7 +29,7 @@ class APSConformityScore(NaiveConformityScore):
3029
classes: Optional[ArrayLike]
3130
Names of the classes.
3231
33-
random_state: Optional[Union[int, RandomState]]
32+
random_state: Optional[Union[int, np.random.RandomState]]
3433
Pseudo random number generator state.
3534
3635
quantiles_: ArrayLike of shape (n_alpha)
@@ -65,7 +64,7 @@ def get_predictions(
6564
Predicted probabilities from the estimator.
6665
6766
cv: Optional[Union[int, str, BaseCrossValidator]]
68-
Cross-validation strategy used by the estimator.
67+
Cross-validation strategy used by the estimator (not used here).
6968
7069
agg_scores: Optional[str]
7170
Method to aggregate the scores from the base estimators.
@@ -94,20 +93,20 @@ def get_true_label_cumsum_proba(
9493
9594
Parameters
9695
----------
97-
y: NDArray of shape (n_samples, )
96+
y: ArrayLike of shape (n_samples, )
9897
Array with the labels.
9998
10099
y_pred_proba: NDArray of shape (n_samples, n_classes)
101100
Predictions of the model.
102101
103-
classes: NDArray of shape (n_classes, )
102+
classes: ArrayLike of shape (n_classes, )
104103
Array with the classes.
105104
106105
Returns
107106
-------
108107
Tuple[NDArray, NDArray] of shapes (n_samples, 1) and (n_samples, ).
109108
The first element is the cumsum probability of the true label.
110-
The second is the sorted position of the true label.
109+
The second is the 1-based rank of the true label in the sorted probabilities.
111110
"""
112111
y_true = label_binarize(y=y, classes=classes)
113112
index_sorted = np.fliplr(np.argsort(y_pred_proba, axis=1))
@@ -136,7 +135,7 @@ def get_conformity_scores(
136135
y_pred: NDArray of shape (n_samples,)
137136
Predicted target values.
138137
139-
y_enc: NDArray of shape (n_samples,)
138+
y_enc: Optional[NDArray] of shape (n_samples,)
140139
Target values as normalized encodings.
141140
142141
Returns
@@ -225,8 +224,8 @@ def _compute_v_parameter(
225224
y_pred_proba_last: NDArray of shape (n_samples, 1, n_alpha)
226225
Last included probability.
227226
228-
predicition_sets: NDArray of shape (n_samples, n_alpha)
229-
Prediction sets.
227+
prediction_sets: NDArray of shape (n_samples, n_alpha)
228+
Prediction sets (not used here).
230229
231230
Returns
232231
--------
@@ -328,7 +327,7 @@ def get_prediction_sets(
328327
329328
alpha_np: NDArray of shape (n_alpha,)
330329
NDArray of floats between 0 and 1, representing the uncertainty
331-
of the confidence interval.
330+
of the confidence interval (not used here).
332331
333332
cv: Optional[Union[int, str, BaseCrossValidator]]
334333
Cross-validation strategy used by the estimator.

mapie/conformity_scores/sets/lac.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class LACConformityScore(BaseClassificationScore):
2929
classes: Optional[ArrayLike]
3030
Names of the classes.
3131
32-
random_state: Optional[Union[int, RandomState]]
32+
random_state: Optional[Union[int, np.random.RandomState]]
3333
Pseudo random number generator state.
3434
3535
quantiles_: ArrayLike of shape (n_alpha)
@@ -48,7 +48,7 @@ def get_conformity_scores(
4848
Parameters
4949
----------
5050
y: NDArray of shape (n_samples,)
51-
Observed target values.
51+
Observed target values (not used here).
5252
5353
y_pred: NDArray of shape (n_samples,)
5454
Predicted target values.
@@ -94,7 +94,7 @@ def get_predictions(
9494
Predicted probabilities from the estimator.
9595
9696
cv: Optional[Union[int, str, BaseCrossValidator]]
97-
Cross-validation strategy used by the estimator.
97+
Cross-validation strategy used by the estimator (not used here).
9898
9999
agg_scores: Optional[str]
100100
Method to aggregate the scores from the base estimators.

mapie/conformity_scores/sets/naive.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
1-
from typing import Tuple, Union, Optional
1+
from typing import Optional, Tuple, Union
22

33
import numpy as np
4-
5-
from mapie.conformity_scores.classification import BaseClassificationScore
6-
from mapie.conformity_scores.sets.utils import get_last_index_included
4+
from numpy.typing import NDArray
75
from sklearn.model_selection import BaseCrossValidator
86

97
from mapie._machine_precision import EPSILON
10-
from numpy.typing import NDArray
8+
from mapie.conformity_scores.classification import BaseClassificationScore
9+
from mapie.conformity_scores.sets.utils import get_last_index_included
1110

1211

1312
class NaiveConformityScore(BaseClassificationScore):
@@ -20,10 +19,10 @@ class NaiveConformityScore(BaseClassificationScore):
2019
classes: Optional[ArrayLike]
2120
Names of the classes.
2221
23-
random_state: Optional[Union[int, RandomState]]
22+
random_state: Optional[Union[int, np.random.RandomState]]
2423
Pseudo random number generator state.
2524
26-
quantiles_: ArrayLike of shape (n_alpha)
25+
quantiles_: ArrayLike of shape (n_alpha,)
2726
The quantiles estimated from ``get_sets`` method.
2827
"""
2928

@@ -37,7 +36,7 @@ def get_conformity_scores(self, y: NDArray, y_pred: NDArray, **kwargs) -> NDArra
3736
Parameters
3837
----------
3938
y: NDArray of shape (n_samples,)
40-
Observed target values.
39+
Observed target values (not used here).
4140
4241
y_pred: NDArray of shape (n_samples,)
4342
Predicted target values.
@@ -97,11 +96,11 @@ def get_conformity_score_quantiles(
9796
Parameters
9897
-----------
9998
conformity_scores: NDArray of shape (n_samples,)
100-
Conformity scores for each sample.
99+
Conformity scores for each sample (not used here).
101100
102101
alpha_np: NDArray of shape (n_alpha,)
103102
NDArray of floats between 0 and 1, representing the uncertainty
104-
of the confidence interval.
103+
of the confidence interval (not used here).
105104
106105
cv: Optional[Union[int, str, BaseCrossValidator]]
107106
Cross-validation strategy used by the estimator (not used here).
@@ -222,11 +221,11 @@ def get_prediction_sets(
222221
Target prediction.
223222
224223
conformity_scores: NDArray of shape (n_samples,)
225-
Conformity scores for each sample.
224+
Conformity scores for each sample (not used here).
226225
227226
alpha_np: NDArray of shape (n_alpha,)
228227
NDArray of floats between 0 and 1, representing the uncertainty
229-
of the confidence interval.
228+
of the confidence interval (not used here).
230229
231230
cv: Optional[Union[int, str, BaseCrossValidator]]
232231
Cross-validation strategy used by the estimator (not used here).

0 commit comments

Comments
 (0)