|
1 | | -from typing import Optional, Union, List |
| 1 | +import copy |
| 2 | +from typing import Optional, Union, List, cast |
2 | 3 | from typing_extensions import Self |
3 | 4 |
|
4 | 5 | import numpy as np |
|
15 | 16 | ) |
16 | 17 | from mapie_v1._utils import transform_confidence_level_to_alpha_list, \ |
17 | 18 | check_method_not_naive, check_cv_not_string, hash_X_y, \ |
18 | | - check_if_X_y_different_from_fit |
| 19 | + check_if_X_y_different_from_fit, make_intervals_single_if_single_alpha |
19 | 20 |
|
20 | 21 |
|
21 | 22 | class SplitConformalRegressor: |
@@ -213,10 +214,10 @@ def predict_set( |
213 | 214 | allow_infinite_bounds=allow_infinite_bounds |
214 | 215 | ) |
215 | 216 |
|
216 | | - if len(self._alphas) == 1: |
217 | | - intervals = intervals[:, :, 0] |
218 | | - |
219 | | - return intervals |
| 217 | + return make_intervals_single_if_single_alpha( |
| 218 | + intervals, |
| 219 | + self._alphas |
| 220 | + ) |
220 | 221 |
|
221 | 222 | def predict( |
222 | 223 | self, |
@@ -344,7 +345,8 @@ def __init__( |
344 | 345 | confidence_level |
345 | 346 | ) |
346 | 347 |
|
347 | | - self.hashed_X_y: int = 0 |
| 348 | + self._hashed_X_y: int = 0 |
| 349 | + self._sample_weight: Optional[NDArray] = None |
348 | 350 |
|
349 | 351 | def fit( |
350 | 352 | self, |
@@ -372,13 +374,21 @@ def fit( |
372 | 374 | Self |
373 | 375 | The fitted CrossConformalRegressor instance. |
374 | 376 | """ |
375 | | - self.hashed_X_y = hash_X_y(X, y) |
| 377 | + self._hashed_X_y = hash_X_y(X, y) |
376 | 378 |
|
377 | | - X, y, sample_weight, groups = self._mapie_regressor.init_fit( |
378 | | - X, y, fit_params=fit_params |
| 379 | + if fit_params: |
| 380 | + fit_params_ = copy.deepcopy(fit_params) |
| 381 | + self._sample_weight = fit_params_.pop("sample_weight", None) |
| 382 | + else: |
| 383 | + fit_params_ = {} |
| 384 | + |
| 385 | + X, y, self._sample_weight, groups = self._mapie_regressor.init_fit( |
| 386 | + X, y, self._sample_weight, fit_params=fit_params_ |
379 | 387 | ) |
380 | | - self._mapie_regressor.fit_estimator(X, y, sample_weight, groups) |
381 | 388 |
|
| 389 | + self._mapie_regressor.fit_estimator( |
| 390 | + X, y, self._sample_weight |
| 391 | + ) |
382 | 392 | return self |
383 | 393 |
|
384 | 394 | def conformalize( |
@@ -416,11 +426,15 @@ def conformalize( |
416 | 426 | Self |
417 | 427 | The conformalized SplitConformalRegressor instance. |
418 | 428 | """ |
419 | | - check_if_X_y_different_from_fit(X, y, self.hashed_X_y) |
| 429 | + check_if_X_y_different_from_fit(X, y, self._hashed_X_y) |
| 430 | + groups = cast(Optional[NDArray], groups) |
| 431 | + if not predict_params: |
| 432 | + predict_params = {} |
420 | 433 |
|
421 | 434 | self._mapie_regressor.conformalize( |
422 | 435 | X, |
423 | 436 | y, |
| 437 | + sample_weight=self._sample_weight, |
424 | 438 | groups=groups, |
425 | 439 | predict_params=predict_params |
426 | 440 | ) |
@@ -455,7 +469,19 @@ def predict_set( |
455 | 469 | `(n_samples, 2, n_confidence_levels)` if `confidence_level` is a |
456 | 470 | list of floats. |
457 | 471 | """ |
458 | | - pass |
| 472 | + # TODO: factorize this function once the v0 backend is updated with |
| 473 | + # correct param names |
| 474 | + _, intervals = self._mapie_regressor.predict( |
| 475 | + X, |
| 476 | + alpha=self._alphas, |
| 477 | + optimize_beta=minimize_interval_width, |
| 478 | + allow_infinite_bounds=allow_infinite_bounds |
| 479 | + ) |
| 480 | + |
| 481 | + return make_intervals_single_if_single_alpha( |
| 482 | + intervals, |
| 483 | + self._alphas |
| 484 | + ) |
459 | 485 |
|
460 | 486 | def predict( |
461 | 487 | self, |
@@ -485,7 +511,14 @@ def predict( |
485 | 511 | NDArray |
486 | 512 | Array of point predictions, with shape `(n_samples,)`. |
487 | 513 | """ |
488 | | - pass |
| 514 | + if not aggregation_method: |
| 515 | + ensemble = False |
| 516 | + else: |
| 517 | + ensemble = True |
| 518 | + self._mapie_regressor._check_agg_function(aggregation_method) |
| 519 | + self._mapie_regressor.agg_function = aggregation_method |
| 520 | + |
| 521 | + return self._mapie_regressor.predict(X, alpha=None, ensemble=ensemble) |
489 | 522 |
|
490 | 523 |
|
491 | 524 | class JackknifeAfterBootstrapRegressor: |
|
0 commit comments