Skip to content

Commit fffeb87

Browse files
ENH: implement checks for CrossConformalRegression, make utils.py non-public
1 parent 55d2794 commit fffeb87

File tree

5 files changed

+78
-25
lines changed

5 files changed

+78
-25
lines changed

mapie_v1/_utils.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import warnings
2+
from typing import Union, List
3+
4+
from numpy import array
5+
from numpy.typing import ArrayLike
6+
from sklearn.model_selection import BaseCrossValidator
7+
8+
9+
def transform_confidence_level_to_alpha_list(
10+
confidence_level: Union[float, List[float]]
11+
) -> List[float]:
12+
if isinstance(confidence_level, float):
13+
confidence_levels = [confidence_level]
14+
else:
15+
confidence_levels = confidence_level
16+
return [1 - level for level in confidence_levels]
17+
18+
19+
def check_method_not_naive(method: str) -> None:
20+
if method == "naive":
21+
raise ValueError(
22+
'"naive" method not available in MAPIE >= v1'
23+
)
24+
25+
26+
def check_cv_not_string(cv: Union[int, str, BaseCrossValidator]):
27+
if isinstance(cv, str):
28+
raise ValueError(
29+
"'cv' string options not available in MAPIE >= v1"
30+
)
31+
32+
33+
def hash_X_y(X: ArrayLike, y: ArrayLike) -> int:
34+
# Known issues:
35+
# - the hash calculated with `hash` changes between Python processes
36+
# - two arrays with the same content but different shapes will all have
37+
# the same hash because .tobytes() ignores shape
38+
return hash(array(X).tobytes() + array(y).tobytes())
39+
40+
41+
def check_if_X_y_different_from_fit(
42+
X: ArrayLike,
43+
y: ArrayLike,
44+
previous_X_y_hash: int
45+
) -> None:
46+
if hash_X_y(X, y) != previous_X_y_hash:
47+
warnings.warn(
48+
"You have to use the same X and y in .fit and .conformalize"
49+
)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from . import REGRESSION_CONFORMITY_SCORES_STRING_MAP
44

55

6-
def check_and_select_split_conformity_score(
6+
def check_and_select_regression_conformity_score(
77
conformity_score: Union[str, BaseRegressionScore]
88
) -> BaseRegressionScore:
99
if isinstance(conformity_score, BaseRegressionScore):

mapie_v1/integration_tests/tests/test_regression.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515

1616
from mapiev0.regression import MapieRegressor as MapieRegressorV0 # noqa
1717

18-
from mapie_v1.conformity_scores.utils import \
19-
check_and_select_split_conformity_score
18+
from mapie_v1.conformity_scores._utils import \
19+
check_and_select_regression_conformity_score
2020
from mapie_v1.integration_tests.utils import (filter_params,
2121
train_test_split_shuffle)
2222
from sklearn.model_selection import LeaveOneOut, GroupKFold
@@ -66,7 +66,7 @@ def test_exact_interval_equality_split(
6666
v0_params = {
6767
"estimator": estimator,
6868
"method": method,
69-
"conformity_score": check_and_select_split_conformity_score(
69+
"conformity_score": check_and_select_regression_conformity_score(
7070
conformity_score
7171
),
7272
"alpha": 1 - confidence_level,

mapie_v1/regression.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@
1010
from mapie.conformity_scores import BaseRegressionScore
1111
from mapie.regression import MapieRegressor
1212
from mapie.utils import check_estimator_fit_predict
13-
from mapie_v1.conformity_scores.utils import (
14-
check_and_select_split_conformity_score,
13+
from mapie_v1.conformity_scores._utils import (
14+
check_and_select_regression_conformity_score,
1515
)
16-
from mapie_v1.utils import transform_confidence_level_to_alpha_list
16+
from mapie_v1._utils import transform_confidence_level_to_alpha_list, \
17+
check_method_not_naive, check_cv_not_string, hash_X_y, \
18+
check_if_X_y_different_from_fit
1719

1820

1921
class SplitConformalRegressor:
@@ -87,7 +89,7 @@ def __init__(
8789
check_estimator_fit_predict(estimator)
8890
self._estimator = estimator
8991
self._prefit = prefit
90-
self._conformity_score = check_and_select_split_conformity_score(
92+
self._conformity_score = check_and_select_regression_conformity_score(
9193
conformity_score)
9294

9395
# Note to developers: to implement this v1 class without touching the
@@ -323,20 +325,27 @@ def __init__(
323325
verbose: int = 0,
324326
random_state: Optional[Union[int, np.random.RandomState]] = None
325327
) -> None:
326-
self.mapie_regressor = MapieRegressor(
327-
estimator=self.estimator,
328+
check_method_not_naive(method)
329+
check_cv_not_string(cv)
330+
331+
self._mapie_regressor = MapieRegressor(
332+
estimator=estimator,
328333
method=method,
329334
cv=cv,
330335
n_jobs=n_jobs,
331336
verbose=verbose,
332-
conformity_score=self.conformity_score,
337+
conformity_score=check_and_select_regression_conformity_score(
338+
conformity_score
339+
),
333340
random_state=random_state,
334341
)
335342

336343
self._alphas = transform_confidence_level_to_alpha_list(
337344
confidence_level
338345
)
339346

347+
self.hashed_X_y: int = 0
348+
340349
def fit(
341350
self,
342351
X: ArrayLike,
@@ -363,10 +372,14 @@ def fit(
363372
Self
364373
The fitted CrossConformalRegressor instance.
365374
"""
366-
X, y, sample_weight, groups = self.init_fit(
375+
self.hashed_X_y = hash_X_y(X, y)
376+
377+
X, y, sample_weight, groups = self._mapie_regressor.init_fit(
367378
X, y, fit_params=fit_params
368379
)
369-
self.mapie_regressor.fit_estimator(X, y, sample_weight, groups)
380+
self._mapie_regressor.fit_estimator(X, y, sample_weight, groups)
381+
382+
return self
370383

371384
def conformalize(
372385
self,
@@ -403,7 +416,9 @@ def conformalize(
403416
Self
404417
The conformalized SplitConformalRegressor instance.
405418
"""
406-
self.mapie_regressor.conformalize(
419+
check_if_X_y_different_from_fit(X, y, self.hashed_X_y)
420+
421+
self._mapie_regressor.conformalize(
407422
X,
408423
y,
409424
groups=groups,

mapie_v1/utils.py

Lines changed: 0 additions & 11 deletions
This file was deleted.

0 commit comments

Comments
 (0)