Skip to content

Commit 96ab58c

Browse files
TESTS: fix and improve v1 integration tests
1 parent fffeb87 commit 96ab58c

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

mapie_v1/integration_tests/tests/test_regression.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def test_exact_interval_equality_split(
138138
"ensemble": True,
139139
"method": "base",
140140
"sample_weight": sample_weight,
141+
"random_state": RANDOM_STATE,
141142
},
142143
"v1": {
143144
"confidence_level": 0.8,
@@ -146,24 +147,27 @@ def test_exact_interval_equality_split(
146147
"aggregation_method": "median",
147148
"method": "base",
148149
"fit_params": {"sample_weight": sample_weight},
150+
"random_state": RANDOM_STATE,
149151
}
150152
},
151153
{
152154
"v0": {
153155
"estimator": positive_predictor,
154-
"alpha": 0.5,
156+
"alpha": [0.5, 0.5],
155157
"conformity_score": GammaConformityScore(),
156158
"cv": LeaveOneOut(),
157159
"method": "plus",
158160
"optimize_beta": True,
161+
"random_state": RANDOM_STATE,
159162
},
160163
"v1": {
161164
"estimator": positive_predictor,
162-
"confidence_level": 0.5,
165+
"confidence_level": [0.5, 0.5],
163166
"conformity_score": "gamma",
164167
"cv": LeaveOneOut(),
165168
"method": "plus",
166169
"minimize_interval_width": True,
170+
"random_state": RANDOM_STATE,
167171
}
168172
},
169173
{
@@ -173,12 +177,14 @@ def test_exact_interval_equality_split(
173177
"groups": groups,
174178
"method": "minmax",
175179
"allow_infinite_bounds": True,
180+
"random_state": RANDOM_STATE,
176181
},
177182
"v1": {
178183
"cv": GroupKFold(),
179184
"groups": groups,
180185
"method": "minmax",
181186
"allow_infinite_bounds": True,
187+
"random_state": RANDOM_STATE,
182188
}
183189
},
184190
]
@@ -209,12 +215,14 @@ def test_intervals_and_predictions_exact_equality_cross(params_cross):
209215
v1_predict_set_params = filter_params(v1.predict_set, v1_params)
210216

211217
v0_preds, v0_pred_intervals = v0.predict(X_cross, **v0_predict_params)
212-
v0_pred_intervals = v0_pred_intervals[:, :, 0]
218+
213219
v1_pred_intervals = v1.predict_set(X_cross, **v1_predict_set_params)
220+
if v1_pred_intervals.ndim == 2:
221+
v1_pred_intervals = np.expand_dims(v1_pred_intervals, axis=2)
214222
v1_preds = v1.predict(X_cross, **v1_predict_params)
215223

216-
assert np.equal(v0_preds, v1_preds)
217-
assert np.equal(v0_pred_intervals, v1_pred_intervals)
224+
np.testing.assert_array_equal(v0_preds, v1_preds)
225+
np.testing.assert_array_equal(v0_pred_intervals, v1_pred_intervals)
218226

219227

220228
def initialize_models(

0 commit comments

Comments
 (0)