
Commit ed7e6fc ("iter"), 1 parent: b56b346

16 files changed: 543 additions & 107 deletions

imblearn/base.py

Lines changed: 21 additions & 1 deletion
@@ -7,15 +7,26 @@
 from abc import ABCMeta, abstractmethod
 
 import numpy as np
+import sklearn
 from sklearn.base import BaseEstimator, OneToOneFeatureMixin
 from sklearn.preprocessing import label_binarize
+from sklearn.utils.metaestimators import available_if
 from sklearn.utils.multiclass import check_classification_targets
+from sklearn.utils.fixes import parse_version
 
 from .utils import check_sampling_strategy, check_target_type
+from .utils.fixes import validate_data
 from .utils._param_validation import validate_parameter_constraints
+from .utils._tags import InputTags
 from .utils._validation import ArraysTransformer
 
 
+def check_version():
+    return parse_version(
+        parse_version(sklearn.__version__).base_version
+    ) >= parse_version("1.6")
+
+
 class _ParamsValidationMixin:
     """Mixin class to validate parameters."""
 
@@ -147,7 +158,7 @@ def _check_X_y(self, X, y, accept_sparse=None):
         if accept_sparse is None:
             accept_sparse = ["csr", "csc"]
         y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
-        X, y = self._validate_data(X, y, reset=True, accept_sparse=accept_sparse)
+        X, y = validate_data(self, X=X, y=y, reset=True, accept_sparse=accept_sparse)
         return X, y, binarize_y
 
     def fit(self, X, y):
@@ -196,9 +207,18 @@ def fit_resample(self, X, y):
         self._validate_params()
         return super().fit_resample(X, y)
 
+    @available_if(check_version)
     def _more_tags(self):
         return {"X_types": ["2darray", "sparse", "dataframe"]}
 
+    def __sklearn_tags__(self):
+        tags = super().__sklearn_tags__()
+        tags.input_tags = InputTags()
+        tags.input_tags.two_d_array = True
+        tags.input_tags.sparse = True
+        tags.input_tags.dataframe = True
+        return tags
+
 
 def _identity(X, y):
     return X, y
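Note: the call sites above switch from the estimator method self._validate_data(...) to a validate_data helper imported from imblearn.utils.fixes, and the legacy _more_tags is gated behind available_if(check_version) while the new __sklearn_tags__ protocol is added alongside it. The fixes module itself is not part of this excerpt, so the following is only a hypothetical sketch of what such a shim could look like; the function name, defaults, and dispatch are assumptions inferred from how the call sites use it.

    # Hypothetical sketch of the validate_data shim assumed to live in
    # imblearn/utils/fixes.py (not shown in this commit excerpt).
    import sklearn
    from sklearn.utils.fixes import parse_version

    sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)

    if sklearn_version >= parse_version("1.6"):
        # scikit-learn >= 1.6 exposes validate_data as a function that takes
        # the estimator as its first argument.
        from sklearn.utils.validation import validate_data
    else:

        def validate_data(estimator, /, X="no_validation", y="no_validation", **kwargs):
            # Older scikit-learn: delegate to the then-private estimator method,
            # which shares the same defaults and keyword arguments.
            return estimator._validate_data(X=X, y=y, **kwargs)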

imblearn/ensemble/_bagging.py

Lines changed: 10 additions & 5 deletions
@@ -26,10 +26,10 @@
 from ..utils import Substitution, check_sampling_strategy, check_target_type
 from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
 from ..utils._param_validation import HasMethods, Interval, StrOptions
-from ..utils.fixes import _fit_context
+from ..utils.fixes import _fit_context, validate_data
 from ._common import _bagging_parameter_constraints, _estimator_has
 
-sklearn_version = parse_version(sklearn.__version__)
+sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)
 
 
 @Substitution(
@@ -382,12 +382,17 @@ def decision_function(self, X):
         check_is_fitted(self)
 
         # Check data
-        X = self._validate_data(
-            X,
+        if sklearn_version < parse_version("1.6"):
+            kwargs = {"force_all_finite": False}
+        else:
+            kwargs = {"ensure_all_finite": False}
+        X = validate_data(
+            self,
+            X=X,
             accept_sparse=["csr", "csc"],
             dtype=None,
-            force_all_finite=False,
             reset=False,
+            **kwargs
         )
 
         # Parallel loop
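Note: scikit-learn 1.6 renamed the force_all_finite argument of its validation helpers to ensure_all_finite, which is why decision_function now builds the keyword dynamically and splats it into the call. The same idea in isolation, shown against plain check_array (an illustrative stand-in, not part of this diff):

    import numpy as np
    import sklearn
    from sklearn.utils import check_array
    from sklearn.utils.fixes import parse_version

    sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)

    # Pick the keyword name supported by the installed scikit-learn.
    if sklearn_version < parse_version("1.6"):
        kwargs = {"force_all_finite": False}   # keyword name before 1.6
    else:
        kwargs = {"ensure_all_finite": False}  # renamed in 1.6

    X = np.array([[1.0, np.nan], [2.0, 3.0]])
    X_checked = check_array(X, dtype=None, **kwargs)  # NaN tolerated on both paths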

imblearn/ensemble/_easy_ensemble.py

Lines changed: 13 additions & 7 deletions
@@ -14,7 +14,6 @@
 from sklearn.ensemble import AdaBoostClassifier, BaggingClassifier
 from sklearn.ensemble._bagging import _parallel_decision_function
 from sklearn.ensemble._base import _partition_estimators
-from sklearn.utils._tags import _safe_tags
 from sklearn.utils.fixes import parse_version
 from sklearn.utils.metaestimators import available_if
 from sklearn.utils.parallel import Parallel, delayed
@@ -27,11 +26,11 @@
 from ..utils import Substitution, check_sampling_strategy, check_target_type
 from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
 from ..utils._param_validation import Interval, StrOptions
-from ..utils.fixes import _fit_context
+from ..utils.fixes import _fit_context, get_tags, validate_data
 from ._common import _bagging_parameter_constraints, _estimator_has
 
 MAX_INT = np.iinfo(np.int32).max
-sklearn_version = parse_version(sklearn.__version__)
+sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)
 
 
 @Substitution(
@@ -311,12 +310,17 @@ def decision_function(self, X):
         check_is_fitted(self)
 
         # Check data
-        X = self._validate_data(
-            X,
+        if sklearn_version < parse_version("1.6"):
+            kwargs = {"force_all_finite": False}
+        else:
+            kwargs = {"ensure_all_finite": False}
+        X = validate_data(
+            self,
+            X=X,
             accept_sparse=["csr", "csc"],
             dtype=None,
-            force_all_finite=False,
             reset=False,
+            **kwargs,
         )
 
         # Parallel loop
@@ -351,4 +355,6 @@ def _get_estimator(self):
 
     # TODO: remove when minimum supported version of scikit-learn is 1.5
     def _more_tags(self):
-        return {"allow_nan": _safe_tags(self._get_estimator(), "allow_nan")}
+        # This code should not be called for scikit-learn >= 1.6
+        # Therefore, get_tags corresponds to _safe_tags that returns a dict
+        return {"allow_nan": get_tags(self._get_estimator(), "allow_nan")}

imblearn/ensemble/_forest.py

Lines changed: 13 additions & 9 deletions
@@ -35,11 +35,11 @@
 from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
 from ..utils._param_validation import Hidden, Interval, StrOptions
 from ..utils._validation import check_sampling_strategy
-from ..utils.fixes import _fit_context
+from ..utils.fixes import _fit_context, validate_data
 from ._common import _random_forest_classifier_parameter_constraints
 
 MAX_INT = np.iinfo(np.int32).max
-sklearn_version = parse_version(sklearn.__version__)
+sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)
 
 
 def _local_parallel_build_trees(
@@ -597,21 +597,25 @@ def fit(self, X, y, sample_weight=None):
         # TODO: remove when the minimum supported version of scipy will be 1.4
         # Support for missing values
         if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
-            force_all_finite = False
+            if sklearn_version >= parse_version("1.6"):
+                kwargs = {"ensure_all_finite": False}
+            else:
+                kwargs = {"force_all_finite": False}
         else:
-            force_all_finite = True
+            kwargs = {"force_all_finite": False}
 
-        X, y = self._validate_data(
-            X,
-            y,
+        X, y = validate_data(
+            self,
+            X=X,
+            y=y,
             multi_output=True,
             accept_sparse="csc",
             dtype=DTYPE,
-            force_all_finite=force_all_finite,
+            **kwargs,
         )
 
         # TODO: remove when the minimum supported version of scikit-learn will be 1.4
-        if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
+        if sklearn_version >= parse_version("1.4"):
             # _compute_missing_values_in_feature_mask checks if X has missing values and
             # will raise an error if the underlying tree base estimator can't handle
             # missing values. Only the criterion is required to determine if the tree

imblearn/metrics/pairwise.py

Lines changed: 4 additions & 3 deletions
@@ -14,6 +14,7 @@
 
 from ..base import _ParamsValidationMixin
 from ..utils._param_validation import StrOptions
+from ..utils.fixes import validate_data
 
 
 class ValueDifferenceMetric(_ParamsValidationMixin, BaseEstimator):
@@ -148,7 +149,7 @@ def fit(self, X, y):
         """
         self._validate_params()
         check_consistent_length(X, y)
-        X, y = self._validate_data(X, y, reset=True, dtype=np.int32)
+        X, y = validate_data(self, X=X, y=y, reset=True, dtype=np.int32)
 
         if isinstance(self.n_categories, str) and self.n_categories == "auto":
             # categories are expected to be encoded from 0 to n_categories - 1
@@ -207,11 +208,11 @@ def pairwise(self, X, Y=None):
             The VDM pairwise distance.
         """
         check_is_fitted(self)
-        X = self._validate_data(X, reset=False, dtype=np.int32)
+        X = validate_data(self, X=X, reset=False, dtype=np.int32)
         n_samples_X = X.shape[0]
 
         if Y is not None:
-            Y = self._validate_data(Y, reset=False, dtype=np.int32)
+            Y = validate_data(self, Y=Y, reset=False, dtype=np.int32)
             n_samples_Y = Y.shape[0]
         else:
             n_samples_Y = n_samples_X
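Note: in both fit and pairwise the meaning of the reset flag is unchanged by the migration: with reset=True the helper records n_features_in_ (and feature_names_in_ when available) on the estimator, and with reset=False it checks new input against what was recorded. A minimal sketch with a toy estimator (a hypothetical class, used only for illustration, relying on the validate_data shim this commit imports):

    import numpy as np
    from sklearn.base import BaseEstimator

    from imblearn.utils.fixes import validate_data  # shim used throughout this commit


    class _ToyEstimator(BaseEstimator):
        """Hypothetical estimator illustrating reset=True vs reset=False."""

        def fit(self, X, y=None):
            validate_data(self, X=X, reset=True)  # records n_features_in_
            return self

        def score_samples(self, X):
            # Raises if X does not have n_features_in_ columns.
            X = validate_data(self, X=X, reset=False)
            return np.zeros(X.shape[0])


    toy = _ToyEstimator().fit(np.ones((4, 3)))
    toy.score_samples(np.ones((2, 3)))  # OK: same number of features as at fit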

imblearn/over_sampling/_random_over_sampler.py

Lines changed: 3 additions & 2 deletions
@@ -15,6 +15,7 @@
 from ..utils import Substitution, check_target_type
 from ..utils._docstring import _random_state_docstring
 from ..utils._param_validation import Interval
+from ..utils.fixes import _check_n_features, _check_feature_names
 from ..utils._validation import _check_X
 from .base import BaseOverSampler
 
@@ -156,8 +157,8 @@ def __init__(
     def _check_X_y(self, X, y):
         y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
         X = _check_X(X)
-        self._check_n_features(X, reset=True)
-        self._check_feature_names(X, reset=True)
+        _check_n_features(self, X, reset=True)
+        _check_feature_names(self, X, reset=True)
         return X, y, binarize_y
 
     def _fit_resample(self, X, y):
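Note: _check_n_features and _check_feature_names were estimator methods before scikit-learn 1.6 and function-style helpers (taking the estimator first) afterwards, which matches the new call style above. Here too the shim lives in imblearn.utils.fixes and is not shown in this excerpt; a sketch under that assumption, with both import paths to be treated as unverified:

    # Hypothetical _check_n_features / _check_feature_names shims assumed in
    # imblearn/utils/fixes.py (not shown in this excerpt).
    import sklearn
    from sklearn.utils.fixes import parse_version

    sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)

    if sklearn_version >= parse_version("1.6"):
        from sklearn.utils.validation import _check_feature_names, _check_n_features
    else:

        def _check_n_features(estimator, X, *, reset):
            return estimator._check_n_features(X, reset=reset)

        def _check_feature_names(estimator, X, *, reset):
            return estimator._check_feature_names(X, reset=reset)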

imblearn/over_sampling/_smote/base.py

Lines changed: 7 additions & 6 deletions
@@ -32,7 +32,7 @@
 from ...utils._docstring import _n_jobs_docstring, _random_state_docstring
 from ...utils._param_validation import HasMethods, Interval, StrOptions
 from ...utils._validation import _check_X
-from ...utils.fixes import _is_pandas_df, _mode
+from ...utils.fixes import _check_n_features, _check_feature_names, _is_pandas_df, _mode, validate_data
 from ..base import BaseOverSampler
 
 sklearn_version = parse_version(sklearn.__version__).base_version
@@ -601,8 +601,8 @@ def _check_X_y(self, X, y):
         """
         y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
         X = _check_X(X)
-        self._check_n_features(X, reset=True)
-        self._check_feature_names(X, reset=True)
+        _check_n_features(self, X, reset=True)
+        _check_feature_names(self, X, reset=True)
         return X, y, binarize_y
 
     def _validate_column_types(self, X):
@@ -963,9 +963,10 @@ def __init__(
     def _check_X_y(self, X, y):
         """Check should accept strings and not sparse matrices."""
         y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
-        X, y = self._validate_data(
-            X,
-            y,
+        X, y = validate_data(
+            self,
+            X=X,
+            y=y,
             reset=True,
             dtype=None,
             accept_sparse=["csr", "csc"],

imblearn/tests/test_common.py

Lines changed: 6 additions & 18 deletions
@@ -1,4 +1,5 @@
 """Common tests"""
+
 # Authors: Guillaume Lemaitre <[email protected]>
 #          Christos Aridas
 # License: MIT
@@ -10,8 +11,7 @@
 import pytest
 from sklearn.base import clone
 from sklearn.exceptions import ConvergenceWarning
-from sklearn.utils._testing import SkipTest, ignore_warnings, set_random_state
-from sklearn.utils.estimator_checks import _construct_instance, _get_check_estimator_ids
+from sklearn.utils._testing import ignore_warnings
 from sklearn.utils.estimator_checks import (
     parametrize_with_checks as parametrize_with_checks_sklearn,
 )
@@ -25,6 +25,10 @@
     parametrize_with_checks,
 )
 from imblearn.utils.testing import all_estimators
+from imblearn.utils._test_common.instance_generator import (
+    _get_check_estimator_ids,
+    _tested_estimators,
+)
 
 
 @pytest.mark.parametrize("name, Estimator", all_estimators())
@@ -34,22 +38,6 @@ def test_all_estimator_no_base_class(name, Estimator):
     assert not name.lower().startswith("base"), msg
 
 
-def _tested_estimators():
-    for name, Estimator in all_estimators():
-        try:
-            estimator = _construct_instance(Estimator)
-            set_random_state(estimator)
-        except SkipTest:
-            continue
-
-        if isinstance(estimator, NearMiss):
-            # For NearMiss, let's check the three algorithms
-            for version in (1, 2, 3):
-                yield clone(estimator).set_params(version=version)
-        else:
-            yield estimator
-
-
 @parametrize_with_checks_sklearn(list(_tested_estimators()))
 def test_estimators_compatibility_sklearn(estimator, check, request):
     _set_checking_parameters(estimator)

imblearn/under_sampling/_prototype_selection/_random_under_sampler.py

Lines changed: 3 additions & 2 deletions
@@ -8,6 +8,7 @@
 from sklearn.utils import _safe_indexing, check_random_state
 
 from ...utils import Substitution, check_target_type
+from ...utils.fixes import _check_n_features, _check_feature_names
 from ...utils._docstring import _random_state_docstring
 from ...utils._validation import _check_X
 from ..base import BaseUnderSampler
@@ -99,8 +100,8 @@ def __init__(
     def _check_X_y(self, X, y):
         y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
         X = _check_X(X)
-        self._check_n_features(X, reset=True)
-        self._check_feature_names(X, reset=True)
+        _check_n_features(self, X, reset=True)
+        _check_feature_names(self, X, reset=True)
         return X, y, binarize_y
 
     def _fit_resample(self, X, y):
def _fit_resample(self, X, y):

imblearn/utils/_tags.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+from dataclasses import dataclass
+
+import sklearn
+from sklearn.utils.fixes import parse_version
+
+sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)
+
+if sklearn_version >= parse_version("1.6"):
+    from sklearn.utils._tags import InputTags
+
+    @dataclass
+    class InputTags(InputTags):
+        dataframe: bool = True
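Note: the new module only redefines InputTags on scikit-learn >= 1.6, adding a dataframe flag on top of the upstream dataclass; __sklearn_tags__ in imblearn/base.py then instantiates it. A short usage sketch (assumes scikit-learn >= 1.6 and that the sampler inherits the __sklearn_tags__ method added above):

    from sklearn.utils import get_tags  # public tag accessor in scikit-learn >= 1.6

    from imblearn.under_sampling import RandomUnderSampler

    tags = get_tags(RandomUnderSampler())
    print(tags.input_tags.sparse)     # True
    print(tags.input_tags.dataframe)  # True: extra flag carried by imblearn's InputTags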
