Skip to content

Commit 211ce77

Browse files
ENH: fix and finish CrossConformalRegressor implementation
1 parent 540bb92 commit 211ce77

File tree

2 files changed

+57
-15
lines changed

2 files changed

+57
-15
lines changed

mapie_v1/_utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Union, List
33

44
from numpy import array
5-
from numpy.typing import ArrayLike
5+
from mapie._typing import ArrayLike, NDArray
66
from sklearn.model_selection import BaseCrossValidator
77

88

@@ -47,3 +47,12 @@ def check_if_X_y_different_from_fit(
4747
warnings.warn(
4848
"You have to use the same X and y in .fit and .conformalize"
4949
)
50+
51+
52+
def make_intervals_single_if_single_alpha(
53+
intervals: NDArray,
54+
alphas: List[float]
55+
) -> NDArray:
56+
if len(alphas) == 1:
57+
return intervals[:, :, 0]
58+
return intervals

mapie_v1/regression.py

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Optional, Union, List
1+
import copy
2+
from typing import Optional, Union, List, cast
23
from typing_extensions import Self
34

45
import numpy as np
@@ -15,7 +16,7 @@
1516
)
1617
from mapie_v1._utils import transform_confidence_level_to_alpha_list, \
1718
check_method_not_naive, check_cv_not_string, hash_X_y, \
18-
check_if_X_y_different_from_fit
19+
check_if_X_y_different_from_fit, make_intervals_single_if_single_alpha
1920

2021

2122
class SplitConformalRegressor:
@@ -213,10 +214,10 @@ def predict_set(
213214
allow_infinite_bounds=allow_infinite_bounds
214215
)
215216

216-
if len(self._alphas) == 1:
217-
intervals = intervals[:, :, 0]
218-
219-
return intervals
217+
return make_intervals_single_if_single_alpha(
218+
intervals,
219+
self._alphas
220+
)
220221

221222
def predict(
222223
self,
@@ -344,7 +345,8 @@ def __init__(
344345
confidence_level
345346
)
346347

347-
self.hashed_X_y: int = 0
348+
self._hashed_X_y: int = 0
349+
self._sample_weight: Optional[NDArray] = None
348350

349351
def fit(
350352
self,
@@ -372,13 +374,21 @@ def fit(
372374
Self
373375
The fitted CrossConformalRegressor instance.
374376
"""
375-
self.hashed_X_y = hash_X_y(X, y)
377+
self._hashed_X_y = hash_X_y(X, y)
376378

377-
X, y, sample_weight, groups = self._mapie_regressor.init_fit(
378-
X, y, fit_params=fit_params
379+
if fit_params:
380+
fit_params_ = copy.deepcopy(fit_params)
381+
self._sample_weight = fit_params_.pop("sample_weight", None)
382+
else:
383+
fit_params_ = {}
384+
385+
X, y, self._sample_weight, groups = self._mapie_regressor.init_fit(
386+
X, y, self._sample_weight, fit_params=fit_params_
379387
)
380-
self._mapie_regressor.fit_estimator(X, y, sample_weight, groups)
381388

389+
self._mapie_regressor.fit_estimator(
390+
X, y, self._sample_weight
391+
)
382392
return self
383393

384394
def conformalize(
@@ -416,11 +426,15 @@ def conformalize(
416426
Self
417427
The conformalized SplitConformalRegressor instance.
418428
"""
419-
check_if_X_y_different_from_fit(X, y, self.hashed_X_y)
429+
check_if_X_y_different_from_fit(X, y, self._hashed_X_y)
430+
groups = cast(Optional[NDArray], groups)
431+
if not predict_params:
432+
predict_params = {}
420433

421434
self._mapie_regressor.conformalize(
422435
X,
423436
y,
437+
sample_weight=self._sample_weight,
424438
groups=groups,
425439
predict_params=predict_params
426440
)
@@ -455,7 +469,19 @@ def predict_set(
455469
`(n_samples, 2, n_confidence_levels)` if `confidence_level` is a
456470
list of floats.
457471
"""
458-
pass
472+
# TODO: factorize this function once the v0 backend is updated with
473+
# correct param names
474+
_, intervals = self._mapie_regressor.predict(
475+
X,
476+
alpha=self._alphas,
477+
optimize_beta=minimize_interval_width,
478+
allow_infinite_bounds=allow_infinite_bounds
479+
)
480+
481+
return make_intervals_single_if_single_alpha(
482+
intervals,
483+
self._alphas
484+
)
459485

460486
def predict(
461487
self,
@@ -485,7 +511,14 @@ def predict(
485511
NDArray
486512
Array of point predictions, with shape `(n_samples,)`.
487513
"""
488-
pass
514+
if not aggregation_method:
515+
ensemble = False
516+
else:
517+
ensemble = True
518+
self._mapie_regressor._check_agg_function(aggregation_method)
519+
self._mapie_regressor.agg_function = aggregation_method
520+
521+
return self._mapie_regressor.predict(X, alpha=None, ensemble=ensemble)
489522

490523

491524
class JackknifeAfterBootstrapRegressor:

0 commit comments

Comments
 (0)