Skip to content

Commit 86ca3e1

Browse files
DOC: add groups in .conformalize for CrossConformalRegression, improve docstrings (#543)
1 parent eba5939 commit 86ca3e1

File tree

1 file changed

+52
-88
lines changed

1 file changed

+52
-88
lines changed

mapie_v1/regression.py

Lines changed: 52 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,7 @@ def conformalize(
165165
Returns
166166
-------
167167
Self
168-
The SplitConformalRegressor instance with updated prediction
169-
intervals.
168+
The conformalized SplitConformalRegressor instance.
170169
"""
171170
predict_params = {} if predict_params is None else predict_params
172171
self.mapie_regressor.fit(X_conf,
@@ -241,8 +240,9 @@ class CrossConformalRegressor:
241240
"""
242241
A conformal regression model using cross-conformal prediction to generate
243242
prediction intervals with statistical guarantees. This method involves
244-
cross-validation with conformity scoring across multiple folds to determine
245-
prediction intervals around point predictions from a base regressor.
243+
computing conformity scores across multiple folds in a cross-validation
244+
fashion to determine prediction intervals around point predictions from a
245+
base regressor.
246246
247247
Parameters
248248
----------
@@ -265,17 +265,19 @@ class CrossConformalRegressor:
265265
provided.
266266
267267
method : str, default="plus"
268-
The method used for cross-conformal prediction. Options are:
268+
The method used to compute prediction intervals. Options are:
269269
- "base": Based on the conformity scores from each fold.
270-
- "plus": Based on the conformity scores from each fold plus
271-
the testing prediction.
272-
- "minmax": Based on the minimum and maximum conformity scores
273-
from each fold.
270+
- "plus": Based on the conformity scores from each fold and
271+
the test set predictions.
272+
- "minmax": Based on the conformity scores from each fold and
273+
the test set predictions, using the minimum and maximum among
274+
the fold models.
274275
275276
cv : Union[int, BaseCrossValidator], default=5
276-
The cross-validation splitting strategy. If an integer is passed, it is
277-
the number of folds for `KFold` cross-validation. Alternatively, a
278-
specific cross-validation splitter from scikit-learn can be provided.
277+
The cross-validation strategy used to compute conformity scores. If an
278+
integer is passed, it is the number of folds for `KFold`
279+
cross-validation. Alternatively, a BaseCrossValidator from scikit-learn
280+
can be provided. Valid options:
279281
TODO : reference here the valid options,
280282
once the list has been created during the implementation
281283
@@ -290,25 +292,6 @@ class CrossConformalRegressor:
290292
A seed or random state instance to ensure reproducibility in any random
291293
operations within the regressor.
292294
293-
Methods
294-
-------
295-
fit(X_train, y_train, fit_params=None) -> Self
296-
Fits the base estimator to the training data using cross-validation.
297-
298-
conformalize(X_conf, y_conf, predict_params=None) -> Self
299-
Calibrates the model using cross-validation and updates the prediction
300-
intervals based on conformity errors observed across folds.
301-
302-
predict(X, aggregation_method=None) -> NDArray
303-
Generates point predictions for the input data `X` using the specified
304-
aggregation method across the cross-validation folds.
305-
306-
predict_set(X, minimize_interval_width=False, allow_infinite_bounds=False)
307-
-> NDArray
308-
Generates prediction intervals for the input data `X` based on the
309-
conformity score and confidence level, adjusted to achieve the desired
310-
coverage probability.
311-
312295
Returns
313296
-------
314297
NDArray
@@ -317,18 +300,12 @@ class CrossConformalRegressor:
317300
- `(n_samples, 2, n_confidence_levels)` if `confidence_level`
318301
is a list of floats.
319302
320-
Notes
321-
-----
322-
Cross-conformal prediction provides enhanced robustness through the
323-
aggregation of multiple conformal scores across cross-validation folds,
324-
potentially yielding tighter intervals with reliable coverage guarantees.
325-
326303
Examples
327304
--------
328305
>>> regressor = CrossConformalRegressor(
329306
... estimator=LinearRegression(), confidence_level=0.95, cv=10)
330-
>>> regressor.fit(X_train, y_train)
331-
>>> regressor.conformalize(X_conf, y_conf)
307+
>>> regressor.fit(X, y)
308+
>>> regressor.conformalize(X, y)
332309
>>> intervals = regressor.predict_set(X_test)
333310
"""
334311

@@ -347,20 +324,20 @@ def __init__(
347324

348325
def fit(
349326
self,
350-
X_train: ArrayLike,
351-
y_train: ArrayLike,
327+
X: ArrayLike,
328+
y: ArrayLike,
352329
fit_params: Optional[dict] = None,
353330
) -> Self:
354331
"""
355-
Fits the base estimator to the training data using cross-validation.
332+
Fits the base estimator using the entire dataset provided.
356333
357334
Parameters
358335
----------
359-
X_train : ArrayLike
360-
Training data features.
336+
X : ArrayLike
337+
Features
361338
362-
y_train : ArrayLike
363-
Training data targets.
339+
y : ArrayLike
340+
Targets
364341
365342
fit_params : Optional[dict], default=None
366343
Additional parameters to pass to the `fit` method
@@ -375,23 +352,29 @@ def fit(
375352

376353
def conformalize(
377354
self,
378-
X_conf: ArrayLike,
379-
y_conf: ArrayLike,
355+
X: ArrayLike,
356+
y: ArrayLike,
357+
groups: Optional[ArrayLike] = None,
380358
predict_params: Optional[dict] = None,
381359
) -> Self:
382360
"""
383-
Calibrates the fitted model using cross-validation conformal folds.
384-
This step analyzes conformity scores across multiple cross-validation
385-
folds and adjusts the prediction intervals based on conformity errors
386-
and specified confidence levels.
361+
Computes conformity scores in a cross-conformal fashion, so that
362+
prediction intervals can be computed later on.
387363
388364
Parameters
389365
----------
390-
X_conf : ArrayLike
366+
X : ArrayLike
391367
Features for generating conformity scores across folds.
368+
Must be the same X used in .fit
392369
393-
y_conf : ArrayLike
370+
y : ArrayLike
394371
Target values for generating conformity scores across folds.
372+
Must be the same y used in .fit
373+
374+
groups : Optional[ArrayLike] of shape (n_samples,)
375+
Group labels for the samples used while splitting the dataset into
376+
train/conformity set.
377+
By default ``None``.
395378
396379
predict_params : Optional[dict], default=None
397380
Additional parameters for generating predictions
@@ -400,8 +383,7 @@ def conformalize(
400383
Returns
401384
-------
402385
Self
403-
The CrossConformalRegressor instance with calibrated prediction
404-
intervals based on cross-validated conformity scores.
386+
The conformalized CrossConformalRegressor instance.
405387
"""
406388
pass
407389

@@ -412,8 +394,8 @@ def predict_set(
412394
allow_infinite_bounds: bool = False,
413395
) -> NDArray:
414396
"""
415-
Generates prediction intervals for the input data `X` based on the
416-
calibrated model and cross-conformal prediction framework.
397+
Generates prediction intervals for the input data `X` based on
398+
conformity scores and confidence level(s).
417399
418400
Parameters
419401
----------
@@ -430,10 +412,10 @@ def predict_set(
430412
Returns
431413
-------
432414
NDArray
433-
Prediction intervals with shape
434-
- `(n_samples, 2)` if `confidence_level`is a single float,
435-
- `(n_samples, 2, n_confidence_levels)` if multiple confidence
436-
levels are specified.
415+
An array containing the prediction intervals with shape
416+
`(n_samples, 2)` if `confidence_level` is a single float, or
417+
`(n_samples, 2, n_confidence_levels)` if `confidence_level` is a
418+
list of floats.
437419
"""
438420
pass
439421

@@ -443,9 +425,10 @@ def predict(
443425
aggregation_method: Optional[str] = None,
444426
) -> NDArray:
445427
"""
446-
Generates point predictions for the input data `X` using the
447-
fitted model. Optionally aggregates predictions across cross-
448-
validation folds models.
428+
Generates point predictions for the input data `X`:
429+
- using the model fitted on the entire dataset
430+
- or if aggregation_method is provided, aggregating predictions from
431+
the models fitted on each fold
449432
450433
Parameters
451434
----------
@@ -518,25 +501,6 @@ class JackknifeAfterBootstrapRegressor:
518501
A seed or random state instance to ensure reproducibility in any random
519502
operations within the regressor.
520503
521-
Methods
522-
-------
523-
fit(X_train, y_train, fit_params=None) -> Self
524-
Fits the base estimator to the training data and initializes internal
525-
parameters required for the jackknife-after-bootstrap process.
526-
527-
conformalize(X_conf, y_conf, predict_params=None) -> Self
528-
Calibrates the model on provided data using the
529-
jackknife-after-bootstrap approach, adjusting the prediction intervals
530-
based on the observed conformity scores.
531-
532-
predict(X, aggregation_method="mean") -> NDArray
533-
Generates point predictions for the input data `X` using the specified
534-
aggregation method over bootstrap samples.
535-
536-
predict_set(X, allow_infinite_bounds=False) -> NDArray
537-
Generates prediction intervals for the input data `X` based on the
538-
calibrated jackknife-after-bootstrap predictions.
539-
540504
Returns
541505
-------
542506
NDArray
@@ -568,19 +532,19 @@ def __init__(
568532

569533
def fit(
570534
self,
571-
X_train: ArrayLike,
572-
y_train: ArrayLike,
535+
X: ArrayLike,
536+
y: ArrayLike,
573537
fit_params: Optional[dict] = None,
574538
) -> Self:
575539
"""
576540
Fits the base estimator to the training data.
577541
578542
Parameters
579543
----------
580-
X_train : ArrayLike
544+
X : ArrayLike
581545
Training data features.
582546
583-
y_train : ArrayLike
547+
y : ArrayLike
584548
Training data targets.
585549
586550
fit_params : Optional[dict], default=None

0 commit comments

Comments
 (0)