1010from numpy .random import RandomState
1111from numpy .typing import ArrayLike , NDArray
1212from sklearn .datasets import make_regression
13- from sklearn .linear_model import LinearRegression
13+ from sklearn .linear_model import LinearRegression , LogisticRegression
1414from sklearn .model_selection import BaseCrossValidator , KFold , LeaveOneOut , ShuffleSplit
1515from sklearn .utils .validation import check_is_fitted as sk_check_is_fitted
1616
4949 _transform_confidence_level_to_alpha ,
5050 _transform_confidence_level_to_alpha_list ,
5151 check_is_fitted ,
52+ check_user_model_is_fitted ,
5253 train_conformalize_test_split ,
5354)
5455
@@ -897,3 +898,59 @@ def test_check_is_fitted_passes_after_fit():
897898 model = DummyModel ()
898899 model .is_fitted = True
899900 check_is_fitted (model )
901+
902+
903+ def test_check_user_model_is_fitted_unfitted ():
904+ model = DummyModel ()
905+ with pytest .raises (NotFittedError ):
906+ check_user_model_is_fitted (model )
907+
908+
909+ def test_check_user_model_is_fitted_raises_for_unfitted_model ():
910+ model = LinearRegression ()
911+ with pytest .raises (NotFittedError ):
912+ check_user_model_is_fitted (model )
913+
914+
915+ @pytest .mark .parametrize ("Model" , [LinearRegression , LogisticRegression ])
916+ def test_check_user_model_is_fitted_sklearn_models (Model ):
917+ """Check that sklearn classifiers and regressors pass."""
918+ X = np .random .randn (20 , 4 )
919+ y = (
920+ (np .random .randn (20 ) > 0 ).astype (int )
921+ if Model is LogisticRegression
922+ else np .random .randn (20 )
923+ )
924+ model = Model ().fit (X , y )
925+ assert check_user_model_is_fitted (model ) is True
926+
927+
928+ class DummyFittedNoFeatures :
929+ """A fake estimator that mimics a fitted model but without n_features_in_."""
930+
931+ def __init__ (self ):
932+ self .coef_ = np .array ([1.0 ])
933+
934+ def predict (self , X ):
935+ return np .array ([0.0 ])
936+
937+
938+ def test_check_user_model_fitted_no_n_features_in ():
939+ model = DummyFittedNoFeatures ()
940+ with pytest .warns (UserWarning ):
941+ assert check_user_model_is_fitted (model ) is True
942+
943+
944+ class PartiallyFitted :
945+ def __init__ (self ):
946+ self .coef_ = np .array ([1 , 2 , 3 ])
947+
948+
949+ def test_check_user_model_is_fitted_partial_fit_warning ():
950+ """
951+ Test that a partially fitted user model triggers a UserWarning
952+ but still returns True from check_user_model_is_fitted.
953+ """
954+ model = PartiallyFitted ()
955+ with pytest .warns (UserWarning ):
956+ assert check_user_model_is_fitted (model ) is True
0 commit comments