Skip to content

Commit 05b23be

Browse files
committed
user_model_check : add check function
1 parent aa4f423 commit 05b23be

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

mapie/utils.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1642,3 +1642,49 @@ def check_is_fitted(obj):
16421642
"""Check that _is_fitted attribute is True"""
16431643
if not getattr(obj, "is_fitted", False):
16441644
raise NotFittedError(f"{obj.__class__.__name__} is not fitted yet. ")
1645+
1646+
1647+
FIT_INDICATORS = [
1648+
"n_features_in_",
1649+
"classes_",
1650+
"coef_",
1651+
"feature_names_in_",
1652+
"tree_",
1653+
"estimators_",
1654+
]
1655+
1656+
1657+
def check_user_model_is_fitted(estimator):
1658+
"""
1659+
Check whether a user-provided estimator is fitted.
1660+
1661+
Logic:
1662+
1. Raise error if no typical fit-related attributes are present.
1663+
2. If `n_features_in_` exists try a minimal predict-probe. Else, assume fitted but warn.
1664+
"""
1665+
present_attrs = [attr for attr in FIT_INDICATORS if hasattr(estimator, attr)]
1666+
1667+
if not present_attrs:
1668+
raise NotFittedError(
1669+
"Estimator does not appear fitted. "
1670+
f"Missing expected attributes: {FIT_INDICATORS}."
1671+
)
1672+
1673+
if hasattr(estimator, "n_features_in_"):
1674+
try:
1675+
estimator.predict(np.zeros((1, estimator.n_features_in_)))
1676+
return True
1677+
except Exception as err:
1678+
warnings.warn(
1679+
f"Estimator has `n_features_in_` but failed a minimal prediction test "
1680+
f"(shape={(1, estimator.n_features_in_)}). Error: {err}",
1681+
UserWarning,
1682+
)
1683+
return True
1684+
1685+
warnings.warn(
1686+
f"Estimator exposes fitted-like attributes {present_attrs} but lacks "
1687+
"`n_features_in_`.",
1688+
UserWarning,
1689+
)
1690+
return True

0 commit comments

Comments
 (0)