@@ -138,6 +138,7 @@ def test_exact_interval_equality_split(
138138 "ensemble" : True ,
139139 "method" : "base" ,
140140 "sample_weight" : sample_weight ,
141+ "random_state" : RANDOM_STATE ,
141142 },
142143 "v1" : {
143144 "confidence_level" : 0.8 ,
@@ -146,24 +147,27 @@ def test_exact_interval_equality_split(
146147 "aggregation_method" : "median" ,
147148 "method" : "base" ,
148149 "fit_params" : {"sample_weight" : sample_weight },
150+ "random_state" : RANDOM_STATE ,
149151 }
150152 },
151153 {
152154 "v0" : {
153155 "estimator" : positive_predictor ,
154- "alpha" : 0.5 ,
156+ "alpha" : [ 0.5 , 0.5 ] ,
155157 "conformity_score" : GammaConformityScore (),
156158 "cv" : LeaveOneOut (),
157159 "method" : "plus" ,
158160 "optimize_beta" : True ,
161+ "random_state" : RANDOM_STATE ,
159162 },
160163 "v1" : {
161164 "estimator" : positive_predictor ,
162- "confidence_level" : 0.5 ,
165+ "confidence_level" : [ 0.5 , 0.5 ] ,
163166 "conformity_score" : "gamma" ,
164167 "cv" : LeaveOneOut (),
165168 "method" : "plus" ,
166169 "minimize_interval_width" : True ,
170+ "random_state" : RANDOM_STATE ,
167171 }
168172 },
169173 {
@@ -173,12 +177,14 @@ def test_exact_interval_equality_split(
173177 "groups" : groups ,
174178 "method" : "minmax" ,
175179 "allow_infinite_bounds" : True ,
180+ "random_state" : RANDOM_STATE ,
176181 },
177182 "v1" : {
178183 "cv" : GroupKFold (),
179184 "groups" : groups ,
180185 "method" : "minmax" ,
181186 "allow_infinite_bounds" : True ,
187+ "random_state" : RANDOM_STATE ,
182188 }
183189 },
184190]
@@ -209,12 +215,14 @@ def test_intervals_and_predictions_exact_equality_cross(params_cross):
209215 v1_predict_set_params = filter_params (v1 .predict_set , v1_params )
210216
211217 v0_preds , v0_pred_intervals = v0 .predict (X_cross , ** v0_predict_params )
212- v0_pred_intervals = v0_pred_intervals [:, :, 0 ]
218+
213219 v1_pred_intervals = v1 .predict_set (X_cross , ** v1_predict_set_params )
220+ if v1_pred_intervals .ndim == 2 :
221+ v1_pred_intervals = np .expand_dims (v1_pred_intervals , axis = 2 )
214222 v1_preds = v1 .predict (X_cross , ** v1_predict_params )
215223
216- assert np .equal (v0_preds , v1_preds )
217- assert np .equal (v0_pred_intervals , v1_pred_intervals )
224+ np .testing . assert_array_equal (v0_preds , v1_preds )
225+ np .testing . assert_array_equal (v0_pred_intervals , v1_pred_intervals )
218226
219227
220228def initialize_models (
0 commit comments