Skip to content

Commit 985a5d7

Browse files
committed
TEST: Add integration test for quantile
1 parent 55d9196 commit 985a5d7

File tree

1 file changed

+28
-13
lines changed

1 file changed

+28
-13
lines changed

mapie_v1/integration_tests/tests/test_regression.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from sklearn.linear_model import LinearRegression
1010
from sklearn.ensemble import RandomForestRegressor
1111
from sklearn.linear_model import QuantileRegressor
12+
from lightgbm import LGBMRegressor
1213

1314
from mapie.subsample import Subsample
1415
from mapie._typing import ArrayLike
@@ -33,14 +34,14 @@
3334

3435

3536
X, y_signed = make_regression(
36-
n_samples=50,
37+
n_samples=100,
3738
n_features=10,
3839
noise=1.0,
3940
random_state=RANDOM_STATE
4041
)
4142
y = np.abs(y_signed)
4243
sample_weight = RandomState(RANDOM_STATE).random(len(X))
43-
groups = [0] * 10 + [1] * 10 + [2] * 10 + [3] * 10 + [4] * 10
44+
groups = [0] * 20 + [1] * 20 + [2] * 20 + [3] * 20 + [4] * 20
4445
positive_predictor = TransformedTargetRegressor(
4546
regressor=LinearRegression(),
4647
func=lambda y_: np.log(y_ + 1),
@@ -271,7 +272,15 @@ def test_intervals_and_predictions_exact_equality_cross(params_cross):
271272
alpha=0.0,
272273
)
273274

274-
prefit_models = []
275+
lgbm_models = []
276+
lgbm_alpha = 0.1
277+
for alpha_ in [lgbm_alpha / 2, (1 - (lgbm_alpha / 2)), 0.5]:
278+
estimator_ = LGBMRegressor(
279+
objective='quantile',
280+
alpha=alpha_,
281+
)
282+
lgbm_models.append(estimator_)
283+
275284

276285
@pytest.mark.parametrize("params_jackknife", params_test_cases_jackknife)
277286
def test_intervals_and_predictions_exact_equality_jackknife(params_jackknife):
@@ -302,8 +311,8 @@ def test_intervals_and_predictions_exact_equality_jackknife(params_jackknife):
302311
},
303312
{
304313
"v0": {
305-
"estimator": prefit_models,
306-
"alpha": [0.5, 0.5],
314+
"estimator": lgbm_models,
315+
"alpha": lgbm_alpha,
307316
"cv": "prefit",
308317
"method": "quantile",
309318
"calib_size": 0.3,
@@ -312,8 +321,8 @@ def test_intervals_and_predictions_exact_equality_jackknife(params_jackknife):
312321
"random_state": RANDOM_STATE,
313322
},
314323
"v1": {
315-
"estimator": prefit_models,
316-
"confidence_level": [0.5, 0.5],
324+
"estimator": lgbm_models,
325+
"confidence_level": 1-lgbm_alpha,
317326
"prefit": True,
318327
"test_size": 0.3,
319328
"fit_params": {"sample_weight": sample_weight},
@@ -324,17 +333,19 @@ def test_intervals_and_predictions_exact_equality_jackknife(params_jackknife):
324333
{
325334
"v0": {
326335
"estimator": split_model,
327-
"alpha": 0.1,
336+
"alpha": 0.5,
328337
"cv": "split",
329338
"method": "quantile",
330339
"calib_size": 0.3,
340+
"allow_infinite_bounds": True,
331341
"random_state": RANDOM_STATE,
332342
},
333343
"v1": {
334344
"estimator": split_model,
335-
"confidence_level": 0.9,
345+
"confidence_level": 0.5,
336346
"prefit": False,
337347
"test_size": 0.3,
348+
"allow_infinite_bounds": True,
338349
"random_state": RANDOM_STATE,
339350
},
340351
},
@@ -358,19 +369,23 @@ def test_intervals_and_predictions_exact_equality_jackknife(params_jackknife):
358369
]
359370

360371

361-
@pytest.mark.parametrize("params_quantile", params_test_cases_jackknife)
372+
@pytest.mark.parametrize("params_quantile", params_test_cases_quantile)
362373
def test_intervals_and_predictions_exact_equality_quantile(params_quantile):
363374
v0_params = params_quantile["v0"]
364375
v1_params = params_quantile["v1"]
365376

377+
test_size = v1_params["test_size"] if "test_size" in v1_params else None
378+
prefit = ("prefit" in v1_params) and v1_params["prefit"]
379+
366380
v0, v1 = select_models_by_strategy("quantile")
367381
compare_model_predictions_and_intervals(model_v0=v0,
368382
model_v1=v1,
369-
X=X_split,
370-
y=y_split,
383+
X=X,
384+
y=y,
371385
v0_params=v0_params,
372386
v1_params=v1_params,
373-
test_size=v1_params["test_size"],
387+
test_size=test_size,
388+
prefit=prefit,
374389
random_state=RANDOM_STATE)
375390

376391

0 commit comments

Comments (0)