11from __future__ import annotations
2- from typing import Optional , Union , Dict , Tuple , Type
2+ from typing import Optional , Union , Dict , Type
33
44import numpy as np
55import pytest
99from sklearn .linear_model import LinearRegression
1010from sklearn .ensemble import RandomForestRegressor
1111from sklearn .linear_model import QuantileRegressor
12- from lightgbm import LGBMRegressor
12+ from sklearn . ensemble import GradientBoostingRegressor
1313
1414from mapie .subsample import Subsample
1515from mapie ._typing import ArrayLike
@@ -109,16 +109,17 @@ def test_intervals_and_predictions_exact_equality_split(
109109 "random_state" : RANDOM_STATE ,
110110 }
111111
112- v0 , v1 = select_models_by_strategy (cv )
113- compare_model_predictions_and_intervals (model_v0 = v0 ,
114- model_v1 = v1 ,
115- X = X_split ,
116- y = y_split ,
117- v0_params = v0_params ,
118- v1_params = v1_params ,
119- test_size = test_size ,
120- random_state = RANDOM_STATE ,
121- prefit = prefit )
112+ compare_model_predictions_and_intervals (
113+ model_v0 = MapieRegressorV0 ,
114+ model_v1 = SplitConformalRegressor ,
115+ X = X_split ,
116+ y = y_split ,
117+ v0_params = v0_params ,
118+ v1_params = v1_params ,
119+ test_size = test_size ,
120+ prefit = prefit ,
121+ random_state = RANDOM_STATE ,
122+ )
122123
123124
124125params_test_cases_cross = [
@@ -185,11 +186,16 @@ def test_intervals_and_predictions_exact_equality_split(
185186
186187@pytest .mark .parametrize ("params_cross" , params_test_cases_cross )
187188def test_intervals_and_predictions_exact_equality_cross (params_cross ):
188- v0_params = params_cross ["v0" ]
189- v1_params = params_cross ["v1" ]
190189
191- v0 , v1 = select_models_by_strategy ("cross" )
192- compare_model_predictions_and_intervals (v0 , v1 , X , y , v0_params , v1_params )
190+ compare_model_predictions_and_intervals (
191+ model_v0 = MapieRegressorV0 ,
192+ model_v1 = CrossConformalRegressor ,
193+ X = X ,
194+ y = y ,
195+ v0_params = params_cross ["v0" ],
196+ v1_params = params_cross ["v1" ],
197+ random_state = RANDOM_STATE ,
198+ )
193199
194200
195201params_test_cases_jackknife = [
@@ -268,28 +274,37 @@ def test_intervals_and_predictions_exact_equality_cross(params_cross):
268274]
269275
270276
277+ @pytest .mark .parametrize ("params_jackknife" , params_test_cases_jackknife )
278+ def test_intervals_and_predictions_exact_equality_jackknife (params_jackknife ):
279+
280+ compare_model_predictions_and_intervals (
281+ model_v0 = MapieRegressorV0 ,
282+ model_v1 = JackknifeAfterBootstrapRegressor ,
283+ X = X ,
284+ y = y ,
285+ v0_params = params_jackknife ["v0" ],
286+ v1_params = params_jackknife ["v1" ],
287+ random_state = RANDOM_STATE ,
288+ )
289+
290+
271291split_model = QuantileRegressor (
272292 solver = "highs-ds" ,
273293 alpha = 0.0 ,
274294 )
275295
276- lgbm_models = []
277- lgbm_alpha = 0.1
278- for alpha_ in [lgbm_alpha / 2 , (1 - (lgbm_alpha / 2 )), 0.5 ]:
279- estimator_ = LGBMRegressor (
280- objective = 'quantile' ,
296+ gbr_models = []
297+ gbr_alpha = 0.1
298+
299+ for alpha_ in [gbr_alpha / 2 , (1 - (gbr_alpha / 2 )), 0.5 ]:
300+ estimator_ = GradientBoostingRegressor (
301+ loss = 'quantile' ,
281302 alpha = alpha_ ,
303+ n_estimators = 100 ,
304+ learning_rate = 0.1 ,
305+ max_depth = 3
282306 )
283- lgbm_models .append (estimator_ )
284-
285-
286- @pytest .mark .parametrize ("params_jackknife" , params_test_cases_jackknife )
287- def test_intervals_and_predictions_exact_equality_jackknife (params_jackknife ):
288- v0_params = params_jackknife ["v0" ]
289- v1_params = params_jackknife ["v1" ]
290-
291- v0 , v1 = select_models_by_strategy ("jackknife" )
292- compare_model_predictions_and_intervals (v0 , v1 , X , y , v0_params , v1_params )
307+ gbr_models .append (estimator_ )
293308
294309
295310params_test_cases_quantile = [
@@ -312,8 +327,7 @@ def test_intervals_and_predictions_exact_equality_jackknife(params_jackknife):
312327 },
313328 {
314329 "v0" : {
315- "estimator" : lgbm_models ,
316- "alpha" : lgbm_alpha ,
330+ "estimator" : gbr_models ,
317331 "cv" : "prefit" ,
318332 "method" : "quantile" ,
319333 "calib_size" : 0.3 ,
@@ -322,8 +336,7 @@ def test_intervals_and_predictions_exact_equality_jackknife(params_jackknife):
322336 "random_state" : RANDOM_STATE ,
323337 },
324338 "v1" : {
325- "estimator" : lgbm_models ,
326- "confidence_level" : 1 - lgbm_alpha ,
339+ "estimator" : gbr_models ,
327340 "prefit" : True ,
328341 "test_size" : 0.3 ,
329342 "fit_params" : {"sample_weight" : sample_weight },
@@ -378,58 +391,17 @@ def test_intervals_and_predictions_exact_equality_quantile(params_quantile):
378391 test_size = v1_params ["test_size" ] if "test_size" in v1_params else None
379392 prefit = ("prefit" in v1_params ) and v1_params ["prefit" ]
380393
381- v0 , v1 = select_models_by_strategy ("quantile" )
382- compare_model_predictions_and_intervals (model_v0 = v0 ,
383- model_v1 = v1 ,
384- X = X ,
385- y = y ,
386- v0_params = v0_params ,
387- v1_params = v1_params ,
388- test_size = test_size ,
389- prefit = prefit ,
390- random_state = RANDOM_STATE )
391-
392-
393- def select_models_by_strategy (
394- strategy_key : str
395- ) -> Tuple [
396- Type [Union [MapieRegressorV0 , MapieQuantileRegressorV0 ]],
397- Type [Union [
398- SplitConformalRegressor ,
399- CrossConformalRegressor ,
400- JackknifeAfterBootstrapRegressor ,
401- ConformalizedQuantileRegressor
402- ]]
403- ]:
404-
405- model_v0 : Type [Union [MapieRegressorV0 , MapieQuantileRegressorV0 ]]
406- model_v1 : Type [Union [
407- SplitConformalRegressor ,
408- CrossConformalRegressor ,
409- JackknifeAfterBootstrapRegressor ,
410- ConformalizedQuantileRegressor
411- ]]
412-
413- if strategy_key in ["split" , "prefit" ]:
414- model_v1 = SplitConformalRegressor
415- model_v0 = MapieRegressorV0
416-
417- elif strategy_key == "cross" :
418- model_v1 = CrossConformalRegressor
419- model_v0 = MapieRegressorV0
420-
421- elif strategy_key == "jackknife" :
422- model_v1 = JackknifeAfterBootstrapRegressor
423- model_v0 = MapieRegressorV0
424-
425- elif strategy_key == "quantile" :
426- model_v1 = ConformalizedQuantileRegressor
427- model_v0 = MapieQuantileRegressorV0
428-
429- else :
430- raise ValueError (f"Unknown strategy key: { strategy_key } " )
431-
432- return model_v0 , model_v1
394+ compare_model_predictions_and_intervals (
395+ model_v0 = MapieQuantileRegressorV0 ,
396+ model_v1 = ConformalizedQuantileRegressor ,
397+ X = X ,
398+ y = y ,
399+ v0_params = v0_params ,
400+ v1_params = v1_params ,
401+ test_size = test_size ,
402+ prefit = prefit ,
403+ random_state = RANDOM_STATE ,
404+ )
433405
434406
435407def compare_model_predictions_and_intervals (
@@ -486,6 +458,9 @@ def compare_model_predictions_and_intervals(
486458 v1 .conformalize (X_conf , y_conf , ** v1_conformalize_params )
487459
488460 v0_predict_params = filter_params (v0 .predict , v0_params )
461+ if 'alpha' in v0_init_params :
462+ v0_predict_params .pop ('alpha' )
463+
489464 v1_predict_params = filter_params (v1 .predict , v1_params )
490465 v1_predict_set_params = filter_params (v1 .predict_set , v1_params )
491466
0 commit comments