|
7 | 7 | from abc import ABCMeta, abstractmethod
|
8 | 8 |
|
9 | 9 | import numpy as np
|
| 10 | +import sklearn |
10 | 11 | from sklearn.base import BaseEstimator, OneToOneFeatureMixin
|
11 | 12 | from sklearn.preprocessing import label_binarize
|
| 13 | +from sklearn.utils.metaestimators import available_if |
12 | 14 | from sklearn.utils.multiclass import check_classification_targets
|
| 15 | +from sklearn.utils.fixes import parse_version |
13 | 16 |
|
14 | 17 | from .utils import check_sampling_strategy, check_target_type
|
| 18 | +from .utils.fixes import validate_data |
15 | 19 | from .utils._param_validation import validate_parameter_constraints
|
| 20 | +from .utils._tags import InputTags |
16 | 21 | from .utils._validation import ArraysTransformer
|
17 | 22 |
|
18 | 23 |
|
| 24 | +def check_version(): |
| 25 | + return parse_version( |
| 26 | + parse_version(sklearn.__version__).base_version |
| 27 | + ) >= parse_version("1.6") |
| 28 | + |
| 29 | + |
19 | 30 | class _ParamsValidationMixin:
|
20 | 31 | """Mixin class to validate parameters."""
|
21 | 32 |
|
@@ -147,7 +158,7 @@ def _check_X_y(self, X, y, accept_sparse=None):
|
147 | 158 | if accept_sparse is None:
|
148 | 159 | accept_sparse = ["csr", "csc"]
|
149 | 160 | y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
|
150 |
| - X, y = self._validate_data(X, y, reset=True, accept_sparse=accept_sparse) |
| 161 | + X, y = validate_data(self, X=X, y=y, reset=True, accept_sparse=accept_sparse) |
151 | 162 | return X, y, binarize_y
|
152 | 163 |
|
153 | 164 | def fit(self, X, y):
|
@@ -196,9 +207,18 @@ def fit_resample(self, X, y):
|
196 | 207 | self._validate_params()
|
197 | 208 | return super().fit_resample(X, y)
|
198 | 209 |
|
| 210 | + @available_if(check_version) |
199 | 211 | def _more_tags(self):
|
200 | 212 | return {"X_types": ["2darray", "sparse", "dataframe"]}
|
201 | 213 |
|
| 214 | + def __sklearn_tags__(self): |
| 215 | + tags = super().__sklearn_tags__() |
| 216 | + tags.input_tags = InputTags() |
| 217 | + tags.input_tags.two_d_array = True |
| 218 | + tags.input_tags.sparse = True |
| 219 | + tags.input_tags.dataframe = True |
| 220 | + return tags |
| 221 | + |
202 | 222 |
|
203 | 223 | def _identity(X, y):
|
204 | 224 | return X, y
|
|
0 commit comments