-
Notifications
You must be signed in to change notification settings - Fork 1.3k
MAINT compatibility with sklearn 1.6 #1104
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
ed7e6fc
3d25e47
eaa6873
92924eb
176b614
fa206e4
c1514dc
acb8234
2453ca1
ef735f4
1629b06
7c91d5d
5d12d07
e74293a
a33b9f8
7878bb2
8903b00
468f925
c457b4a
762fa48
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,10 +26,10 @@ | |
from ..utils import Substitution, check_sampling_strategy, check_target_type | ||
from ..utils._docstring import _n_jobs_docstring, _random_state_docstring | ||
from ..utils._param_validation import HasMethods, Interval, StrOptions | ||
from ..utils.fixes import _fit_context | ||
from ..utils.fixes import _fit_context, validate_data | ||
from ._common import _bagging_parameter_constraints, _estimator_has | ||
|
||
sklearn_version = parse_version(sklearn.__version__) | ||
sklearn_version = parse_version(parse_version(sklearn.__version__).base_version) | ||
|
||
|
||
@Substitution( | ||
|
@@ -382,12 +382,17 @@ def decision_function(self, X): | |
check_is_fitted(self) | ||
|
||
# Check data | ||
X = self._validate_data( | ||
X, | ||
if sklearn_version < parse_version("1.6"): | ||
kwargs = {"force_all_finite": False} | ||
else: | ||
kwargs = {"ensure_all_finite": False} | ||
|
||
X = validate_data( | ||
self, | ||
X=X, | ||
accept_sparse=["csr", "csc"], | ||
dtype=None, | ||
force_all_finite=False, | ||
reset=False, | ||
**kwargs | ||
) | ||
|
||
# Parallel loop | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
"""Common tests""" | ||
|
||
# Authors: Guillaume Lemaitre <[email protected]> | ||
# Christos Aridas | ||
# License: MIT | ||
|
@@ -10,8 +11,7 @@ | |
import pytest | ||
from sklearn.base import clone | ||
from sklearn.exceptions import ConvergenceWarning | ||
from sklearn.utils._testing import SkipTest, ignore_warnings, set_random_state | ||
from sklearn.utils.estimator_checks import _construct_instance, _get_check_estimator_ids | ||
from sklearn.utils._testing import ignore_warnings | ||
from sklearn.utils.estimator_checks import ( | ||
parametrize_with_checks as parametrize_with_checks_sklearn, | ||
) | ||
|
@@ -25,6 +25,10 @@ | |
parametrize_with_checks, | ||
) | ||
from imblearn.utils.testing import all_estimators | ||
from imblearn.utils._test_common.instance_generator import ( | ||
_get_check_estimator_ids, | ||
_tested_estimators, | ||
) | ||
|
||
|
||
@pytest.mark.parametrize("name, Estimator", all_estimators()) | ||
|
@@ -34,22 +38,6 @@ def test_all_estimator_no_base_class(name, Estimator): | |
assert not name.lower().startswith("base"), msg | ||
|
||
|
||
def _tested_estimators(): | ||
for name, Estimator in all_estimators(): | ||
try: | ||
estimator = _construct_instance(Estimator) | ||
set_random_state(estimator) | ||
except SkipTest: | ||
continue | ||
|
||
if isinstance(estimator, NearMiss): | ||
# For NearMiss, let's check the three algorithms | ||
for version in (1, 2, 3): | ||
yield clone(estimator).set_params(version=version) | ||
else: | ||
yield estimator | ||
|
||
|
||
@parametrize_with_checks_sklearn(list(_tested_estimators())) | ||
def test_estimators_compatibility_sklearn(estimator, check, request): | ||
_set_checking_parameters(estimator) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from dataclasses import dataclass | ||
|
||
import sklearn | ||
from sklearn.utils.fixes import parse_version | ||
|
||
sklearn_version = parse_version(parse_version(sklearn.__version__).base_version) | ||
|
||
if sklearn_version >= parse_version("1.6"): | ||
from sklearn.utils._tags import InputTags | ||
|
||
|
||
@dataclass | ||
class InputTags(InputTags): | ||
dataframe: bool = True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd probably do a
try / except
withimport Tags
kinda thing though. With a note that "remove this after 1.6 becomes minimum"There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not the pattern that we document. But actually, I don't know why we are telling to use
available_if
. Right now, we could have both_more_tags
and__sklearn_tags__
define.The
available_if
would make sense if you want to avoid a deprecation warning of an error that we could raise in the future if you define both_more_tags
and__sklearn_tags__
. This is something that I mentioned here: scikit-learn/scikit-learn#30257