@@ -45,7 +45,7 @@ class SplitConformalRegressor:
4545 The conformity score method used to calculate the conformity error.
4646 Valid options: see keys and values of the dictionnary
4747:py:const:`mapie_v1.conformity_scores.REGRESSION_CONFORMITY_SCORES_STRING_MAP`.
48- See: TODO : reference conformity score classes or documentation
48+ See :doc:`theoretical_description_conformity_scores`
4949
5050 A custom score function inheriting from BaseRegressionScore may also
5151 be provided.
@@ -267,7 +267,7 @@ class CrossConformalRegressor:
267267 The conformity score method used to calculate the conformity error.
268268 Valid options: TODO : reference here the valid options, once the list
269269 has been be created during the implementation.
270- See: TODO : reference conformity score classes or documentation
270+ See :doc:`theoretical_description_conformity_scores`
271271
272272 A custom score function inheriting from BaseRegressionScore may also be
273273 provided.
@@ -561,22 +561,29 @@ class JackknifeAfterBootstrapRegressor:
561561 The conformity score method used to calculate the conformity error.
562562 Valid options: TODO : reference here the valid options, once the list
563563 has been be created during the implementation.
564- See: TODO : reference conformity score classes or documentation
564+ See :doc:`theoretical_description_conformity_scores`
565565
566566 A custom score function inheriting from BaseRegressionScore may also
567567 be provided.
568568
569569 method : str, default="plus"
570570 The method used for jackknife-after-bootstrap prediction. Options are:
571- - "base": Based on the conformity scores from each bootstrap sample.
572571 - "plus": Based on the conformity scores from each bootstrap sample and
573572 the testing prediction.
574573 - "minmax": Based on the minimum and maximum conformity scores from
575574 each bootstrap sample.
576575
577- n_bootstraps : int, default=100
578- The number of bootstrap resamples to generate for the
579- jackknife-after-bootstrap procedure.
576+ Note: The "base" method is not mentioned in the conformal inference
577+ literature for Jackknife after bootstrap strategies, hence not provided
578+ here.
579+
580+ resampling : Union[int, Subsample], default=30
581+ Number of bootstrap resamples or an instance of `Subsample` for
582+ custom resampling strategy.
583+
584+ aggregation_method : str, default="mean"
585+ Aggregation method for predictions across bootstrap samples.
586+ Options: ["mean", "median"].
580587
581588 n_jobs : Optional[int], default=None
582589 The number of jobs to run in parallel when applicable.
@@ -599,7 +606,10 @@ class JackknifeAfterBootstrapRegressor:
599606 Examples
600607 --------
601608 >>> regressor = JackknifeAfterBootstrapRegressor(
602- ... estimator=LinearRegression(), confidence_level=0.9, n_bootstraps=8)
609+ ... estimator=LinearRegression(),
610+ ... confidence_level=0.9,
611+ ... resampling=8,
612+ ... aggregation_method="mean")
603613 >>> regressor.fit(X_train, y_train)
604614 >>> regressor.conformalize(X_conf, y_conf)
605615 >>> intervals = regressor.predict_set(X_test)
@@ -763,14 +773,18 @@ def predict_set(
763773 X : ArrayLike
764774 Test data for prediction intervals.
765775
776+ minimize_interval_width : bool, default=False
777+ If True, minimizes the width of prediction intervals while
778+ maintaining coverage.
779+
766780 allow_infinite_bounds : bool, default=False
767781 If True, allows intervals to include infinite bounds
768782 if required for coverage.
769783
770784 Returns
771785 -------
772786 NDArray
773- Prediction intervals of shape ` (n_samples, 2)` ,
787+ Prediction intervals of shape (n_samples, 2),
774788 with lower and upper bounds for each sample.
775789 """
776790 _ , intervals = self ._mapie_regressor .predict (
@@ -788,6 +802,7 @@ def predict_set(
788802 def predict (
789803 self ,
790804 X : ArrayLike ,
805+ ensemble : bool = False ,
791806 ) -> NDArray :
792807 """
793808 Generates point predictions for the input data using the fitted model,
@@ -798,13 +813,20 @@ def predict(
798813 X : ArrayLike
799814 Data features for generating point predictions.
800815
816+ ensemble : bool, default=False
817+ If True, aggregates predictions across models fitted on each
818+ bootstrap samples, this is using the aggregation method defined
819+ during the initialization of the model.
820+ If False, returns predictions from the estimator trained on the
821+ entire dataset.
822+
801823 Returns
802824 -------
803825 NDArray
804826 Array of point predictions, with shape `(n_samples,)`.
805827 """
806828 predictions = self ._mapie_regressor .predict (
807- X , alpha = None , ensemble = True
829+ X , alpha = None , ensemble = ensemble
808830 )
809831 return cast_point_predictions_to_ndarray (predictions )
810832
0 commit comments