Skip to content

Commit 0fc8977

Browse files
jawadhussein462Valentin-Laurent
authored andcommitted
ENH: improve jackknife docstring (#561)
ENH: improve jackknife docstring
1 parent e931b06 commit 0fc8977

File tree

2 files changed

+34
-11
lines changed

2 files changed

+34
-11
lines changed

mapie_v1/integration_tests/tests/test_regression.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ def test_intervals_and_predictions_exact_equality_cross(params_cross):
209209
"aggregation_method": "median",
210210
"method": "plus",
211211
"fit_params": {"sample_weight": sample_weight},
212+
"ensemble": True,
212213
"random_state": RANDOM_STATE,
213214
},
214215
},
@@ -218,7 +219,6 @@ def test_intervals_and_predictions_exact_equality_cross(params_cross):
218219
"alpha": [0.5, 0.5],
219220
"conformity_score": GammaConformityScore(),
220221
"agg_function": "mean",
221-
"ensemble": True,
222222
"cv": Subsample(n_resamplings=20,
223223
replace=True,
224224
random_state=RANDOM_STATE),
@@ -256,6 +256,7 @@ def test_intervals_and_predictions_exact_equality_cross(params_cross):
256256
),
257257
"method": "minmax",
258258
"aggregation_method": "mean",
259+
"ensemble": True,
259260
"allow_infinite_bounds": True,
260261
"random_state": RANDOM_STATE,
261262
}

mapie_v1/regression.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class SplitConformalRegressor:
4545
The conformity score method used to calculate the conformity error.
4646
Valid options: see keys and values of the dictionnary
4747
:py:const:`mapie_v1.conformity_scores.REGRESSION_CONFORMITY_SCORES_STRING_MAP`.
48-
See: TODO : reference conformity score classes or documentation
48+
See :doc:`theoretical_description_conformity_scores`
4949
5050
A custom score function inheriting from BaseRegressionScore may also
5151
be provided.
@@ -267,7 +267,7 @@ class CrossConformalRegressor:
267267
The conformity score method used to calculate the conformity error.
268268
Valid options: TODO : reference here the valid options, once the list
269269
has been be created during the implementation.
270-
See: TODO : reference conformity score classes or documentation
270+
See :doc:`theoretical_description_conformity_scores`
271271
272272
A custom score function inheriting from BaseRegressionScore may also be
273273
provided.
@@ -561,22 +561,29 @@ class JackknifeAfterBootstrapRegressor:
561561
The conformity score method used to calculate the conformity error.
562562
Valid options: TODO : reference here the valid options, once the list
563563
has been be created during the implementation.
564-
See: TODO : reference conformity score classes or documentation
564+
See :doc:`theoretical_description_conformity_scores`
565565
566566
A custom score function inheriting from BaseRegressionScore may also
567567
be provided.
568568
569569
method : str, default="plus"
570570
The method used for jackknife-after-bootstrap prediction. Options are:
571-
- "base": Based on the conformity scores from each bootstrap sample.
572571
- "plus": Based on the conformity scores from each bootstrap sample and
573572
the testing prediction.
574573
- "minmax": Based on the minimum and maximum conformity scores from
575574
each bootstrap sample.
576575
577-
n_bootstraps : int, default=100
578-
The number of bootstrap resamples to generate for the
579-
jackknife-after-bootstrap procedure.
576+
Note: The "base" method is not mentioned in the conformal inference
577+
literature for Jackknife after bootstrap strategies, hence not provided
578+
here.
579+
580+
resampling : Union[int, Subsample], default=30
581+
Number of bootstrap resamples or an instance of `Subsample` for
582+
custom resampling strategy.
583+
584+
aggregation_method : str, default="mean"
585+
Aggregation method for predictions across bootstrap samples.
586+
Options: ["mean", "median"].
580587
581588
n_jobs : Optional[int], default=None
582589
The number of jobs to run in parallel when applicable.
@@ -599,7 +606,10 @@ class JackknifeAfterBootstrapRegressor:
599606
Examples
600607
--------
601608
>>> regressor = JackknifeAfterBootstrapRegressor(
602-
... estimator=LinearRegression(), confidence_level=0.9, n_bootstraps=8)
609+
... estimator=LinearRegression(),
610+
... confidence_level=0.9,
611+
... resampling=8,
612+
... aggregation_method="mean")
603613
>>> regressor.fit(X_train, y_train)
604614
>>> regressor.conformalize(X_conf, y_conf)
605615
>>> intervals = regressor.predict_set(X_test)
@@ -763,14 +773,18 @@ def predict_set(
763773
X : ArrayLike
764774
Test data for prediction intervals.
765775
776+
minimize_interval_width : bool, default=False
777+
If True, minimizes the width of prediction intervals while
778+
maintaining coverage.
779+
766780
allow_infinite_bounds : bool, default=False
767781
If True, allows intervals to include infinite bounds
768782
if required for coverage.
769783
770784
Returns
771785
-------
772786
NDArray
773-
Prediction intervals of shape `(n_samples, 2)`,
787+
Prediction intervals of shape (n_samples, 2),
774788
with lower and upper bounds for each sample.
775789
"""
776790
_, intervals = self._mapie_regressor.predict(
@@ -788,6 +802,7 @@ def predict_set(
788802
def predict(
789803
self,
790804
X: ArrayLike,
805+
ensemble: bool = False,
791806
) -> NDArray:
792807
"""
793808
Generates point predictions for the input data using the fitted model,
@@ -798,13 +813,20 @@ def predict(
798813
X : ArrayLike
799814
Data features for generating point predictions.
800815
816+
ensemble : bool, default=False
817+
If True, aggregates predictions across models fitted on each
818+
bootstrap samples, this is using the aggregation method defined
819+
during the initialization of the model.
820+
If False, returns predictions from the estimator trained on the
821+
entire dataset.
822+
801823
Returns
802824
-------
803825
NDArray
804826
Array of point predictions, with shape `(n_samples,)`.
805827
"""
806828
predictions = self._mapie_regressor.predict(
807-
X, alpha=None, ensemble=True
829+
X, alpha=None, ensemble=ensemble
808830
)
809831
return cast_point_predictions_to_ndarray(predictions)
810832

0 commit comments

Comments
 (0)