Skip to content

Commit 5857eb1

Browse files
committed
user_model_check : add tests
1 parent 05b23be commit 5857eb1

File tree

1 file changed

+58
-1
lines changed

1 file changed

+58
-1
lines changed

mapie/tests/test_utils.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from numpy.random import RandomState
1111
from numpy.typing import ArrayLike, NDArray
1212
from sklearn.datasets import make_regression
13-
from sklearn.linear_model import LinearRegression
13+
from sklearn.linear_model import LinearRegression, LogisticRegression
1414
from sklearn.model_selection import BaseCrossValidator, KFold, LeaveOneOut, ShuffleSplit
1515
from sklearn.utils.validation import check_is_fitted as sk_check_is_fitted
1616

@@ -49,6 +49,7 @@
4949
_transform_confidence_level_to_alpha,
5050
_transform_confidence_level_to_alpha_list,
5151
check_is_fitted,
52+
check_user_model_is_fitted,
5253
train_conformalize_test_split,
5354
)
5455

@@ -897,3 +898,59 @@ def test_check_is_fitted_passes_after_fit():
897898
model = DummyModel()
898899
model.is_fitted = True
899900
check_is_fitted(model)
901+
902+
903+
def test_check_user_model_is_fitted_unfitted():
904+
model = DummyModel()
905+
with pytest.raises(NotFittedError):
906+
check_user_model_is_fitted(model)
907+
908+
909+
def test_check_user_model_is_fitted_raises_for_unfitted_model():
910+
model = LinearRegression()
911+
with pytest.raises(NotFittedError):
912+
check_user_model_is_fitted(model)
913+
914+
915+
@pytest.mark.parametrize("Model", [LinearRegression, LogisticRegression])
916+
def test_check_user_model_is_fitted_sklearn_models(Model):
917+
"""Check that sklearn classifiers and regressors pass."""
918+
X = np.random.randn(20, 4)
919+
y = (
920+
(np.random.randn(20) > 0).astype(int)
921+
if Model is LogisticRegression
922+
else np.random.randn(20)
923+
)
924+
model = Model().fit(X, y)
925+
assert check_user_model_is_fitted(model) is True
926+
927+
928+
class DummyFittedNoFeatures:
929+
"""A fake estimator that mimics a fitted model but without n_features_in_."""
930+
931+
def __init__(self):
932+
self.coef_ = np.array([1.0])
933+
934+
def predict(self, X):
935+
return np.array([0.0])
936+
937+
938+
def test_check_user_model_fitted_no_n_features_in():
939+
model = DummyFittedNoFeatures()
940+
with pytest.warns(UserWarning):
941+
assert check_user_model_is_fitted(model) is True
942+
943+
944+
class PartiallyFitted:
945+
def __init__(self):
946+
self.coef_ = np.array([1, 2, 3])
947+
948+
949+
def test_check_user_model_is_fitted_partial_fit_warning():
950+
"""
951+
Test that a partially fitted user model triggers a UserWarning
952+
but still returns True from check_user_model_is_fitted.
953+
"""
954+
model = PartiallyFitted()
955+
with pytest.warns(UserWarning):
956+
assert check_user_model_is_fitted(model) is True

0 commit comments

Comments
 (0)