23 changes: 22 additions & 1 deletion imblearn/base.py
@@ -7,15 +7,25 @@
from abc import ABCMeta, abstractmethod

import numpy as np
import sklearn
from sklearn.base import BaseEstimator, OneToOneFeatureMixin
from sklearn.preprocessing import label_binarize
from sklearn.utils.metaestimators import available_if
from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.fixes import parse_version

from .utils import check_sampling_strategy, check_target_type
from .utils.fixes import validate_data
from .utils._param_validation import validate_parameter_constraints
from .utils._validation import ArraysTransformer


def check_version(estimator):
# `estimator` is required because `available_if` calls the check with the
# instance; only the installed scikit-learn version matters here.
return parse_version(
parse_version(sklearn.__version__).base_version
) >= parse_version("1.6")
Member:
I'd probably do a try / except around importing Tags instead, with a note that this should be removed once 1.6 becomes the minimum supported version.

Member Author:
This is not the pattern that we document. But actually, I don't know why we recommend using available_if: right now, we could have both _more_tags and __sklearn_tags__ defined.

The available_if would make sense if you want to avoid a deprecation warning or an error that we could raise in the future when both _more_tags and __sklearn_tags__ are defined. This is something I mentioned here: scikit-learn/scikit-learn#30257
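
A rough sketch of the try / except pattern discussed in this thread; the SKLEARN_SUPPORTS_TAGS flag and the public sklearn.utils.Tags import are assumptions based on the comments above, not code from this PR:

# Hypothetical alternative to the version check in check_version().
# TODO: remove once scikit-learn >= 1.6 is the minimum supported version.
try:
    from sklearn.utils import Tags  # noqa: F401  # public only in scikit-learn >= 1.6

    SKLEARN_SUPPORTS_TAGS = True
except ImportError:
    SKLEARN_SUPPORTS_TAGS = False


def check_version(estimator):
    # `available_if` calls the check with the instance, which is not needed here
    return SKLEARN_SUPPORTS_TAGS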



class _ParamsValidationMixin:
"""Mixin class to validate parameters."""

@@ -147,7 +157,7 @@ def _check_X_y(self, X, y, accept_sparse=None):
if accept_sparse is None:
accept_sparse = ["csr", "csc"]
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
X, y = self._validate_data(X, y, reset=True, accept_sparse=accept_sparse)
X, y = validate_data(self, X=X, y=y, reset=True, accept_sparse=accept_sparse)
return X, y, binarize_y

def fit(self, X, y):
@@ -196,9 +206,20 @@ def fit_resample(self, X, y):
self._validate_params()
return super().fit_resample(X, y)

@available_if(check_version)
def _more_tags(self):
return {"X_types": ["2darray", "sparse", "dataframe"]}

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()

from .utils._tags import InputTags
tags.input_tags = InputTags()
tags.input_tags.two_d_array = True
tags.input_tags.sparse = True
tags.input_tags.dataframe = True
return tags


def _identity(X, y):
return X, y
15 changes: 10 additions & 5 deletions imblearn/ensemble/_bagging.py
@@ -26,10 +26,10 @@
from ..utils import Substitution, check_sampling_strategy, check_target_type
from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
from ..utils._param_validation import HasMethods, Interval, StrOptions
from ..utils.fixes import _fit_context
from ..utils.fixes import _fit_context, validate_data
from ._common import _bagging_parameter_constraints, _estimator_has

sklearn_version = parse_version(sklearn.__version__)
sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)


@Substitution(
@@ -382,12 +382,17 @@ def decision_function(self, X):
check_is_fitted(self)

# Check data
X = self._validate_data(
X,
if sklearn_version < parse_version("1.6"):
kwargs = {"force_all_finite": False}
else:
kwargs = {"ensure_all_finite": False}
Member:
this logic can go in your fixes.validate_data

X = validate_data(
self,
X=X,
accept_sparse=["csr", "csc"],
dtype=None,
force_all_finite=False,
reset=False,
**kwargs
)

# Parallel loop
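
Following up on the inline comment above ("this logic can go in your fixes.validate_data"), here is a minimal sketch of how the imblearn.utils.fixes.validate_data helper could absorb the keyword switch. The real helper is not shown in this diff, so the wrapper below is an assumption:

import sklearn
from sklearn.utils.fixes import parse_version

sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)

if sklearn_version >= parse_version("1.6"):
    from sklearn.utils.validation import validate_data as _sklearn_validate_data

    def validate_data(estimator, /, X="no_validation", y="no_validation", **kwargs):
        # scikit-learn >= 1.6 renamed `force_all_finite` to `ensure_all_finite`
        if "force_all_finite" in kwargs:
            kwargs["ensure_all_finite"] = kwargs.pop("force_all_finite")
        return _sklearn_validate_data(estimator, X=X, y=y, **kwargs)

else:

    def validate_data(estimator, /, X="no_validation", y="no_validation", **kwargs):
        # older releases only understand `force_all_finite`
        if "ensure_all_finite" in kwargs:
            kwargs["force_all_finite"] = kwargs.pop("ensure_all_finite")
        return estimator._validate_data(X=X, y=y, **kwargs)

With such a wrapper, callers could pass either spelling unconditionally and let the fixes module translate it.
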
20 changes: 13 additions & 7 deletions imblearn/ensemble/_easy_ensemble.py
@@ -14,7 +14,6 @@
from sklearn.ensemble import AdaBoostClassifier, BaggingClassifier
from sklearn.ensemble._bagging import _parallel_decision_function
from sklearn.ensemble._base import _partition_estimators
from sklearn.utils._tags import _safe_tags
from sklearn.utils.fixes import parse_version
from sklearn.utils.metaestimators import available_if
from sklearn.utils.parallel import Parallel, delayed
@@ -27,11 +26,11 @@
from ..utils import Substitution, check_sampling_strategy, check_target_type
from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
from ..utils._param_validation import Interval, StrOptions
from ..utils.fixes import _fit_context
from ..utils.fixes import _fit_context, get_tags, validate_data
from ._common import _bagging_parameter_constraints, _estimator_has

MAX_INT = np.iinfo(np.int32).max
sklearn_version = parse_version(sklearn.__version__)
sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)


@Substitution(
@@ -311,12 +310,17 @@ def decision_function(self, X):
check_is_fitted(self)

# Check data
X = self._validate_data(
X,
if sklearn_version < parse_version("1.6"):
kwargs = {"force_all_finite": False}
else:
kwargs = {"ensure_all_finite": False}
X = validate_data(
self,
X=X,
accept_sparse=["csr", "csc"],
dtype=None,
force_all_finite=False,
reset=False,
**kwargs,
)

# Parallel loop
@@ -351,4 +355,6 @@ def _get_estimator(self):

# TODO: remove when minimum supported version of scikit-learn is 1.5
def _more_tags(self):
return {"allow_nan": _safe_tags(self._get_estimator(), "allow_nan")}
# This code is only reached with scikit-learn < 1.6; there, `get_tags` is an
# alias of `_safe_tags`, which takes a tag name and returns that tag's value.
return {"allow_nan": get_tags(self._get_estimator(), "allow_nan")}
22 changes: 13 additions & 9 deletions imblearn/ensemble/_forest.py
@@ -35,11 +35,11 @@
from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
from ..utils._param_validation import Hidden, Interval, StrOptions
from ..utils._validation import check_sampling_strategy
from ..utils.fixes import _fit_context
from ..utils.fixes import _fit_context, validate_data
from ._common import _random_forest_classifier_parameter_constraints

MAX_INT = np.iinfo(np.int32).max
sklearn_version = parse_version(sklearn.__version__)
sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)


def _local_parallel_build_trees(
@@ -597,21 +597,25 @@ def fit(self, X, y, sample_weight=None):
# TODO: remove when the minimum supported version of scikit-learn will be 1.4
# Support for missing values
if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
force_all_finite = False
if sklearn_version >= parse_version("1.6"):
kwargs = {"ensure_all_finite": False}
else:
kwargs = {"force_all_finite": False}
else:
force_all_finite = True
kwargs = {"force_all_finite": False}

X, y = self._validate_data(
X,
y,
X, y = validate_data(
self,
X=X,
y=y,
multi_output=True,
accept_sparse="csc",
dtype=DTYPE,
force_all_finite=force_all_finite,
**kwargs,
)

# TODO: remove when the minimum supported version of scikit-learn will be 1.4
if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
if sklearn_version >= parse_version("1.4"):
# _compute_missing_values_in_feature_mask checks if X has missing values and
# will raise an error if the underlying tree base estimator can't handle
# missing values. Only the criterion is required to determine if the tree
7 changes: 4 additions & 3 deletions imblearn/metrics/pairwise.py
@@ -14,6 +14,7 @@

from ..base import _ParamsValidationMixin
from ..utils._param_validation import StrOptions
from ..utils.fixes import validate_data


class ValueDifferenceMetric(_ParamsValidationMixin, BaseEstimator):
@@ -148,7 +149,7 @@ def fit(self, X, y):
"""
self._validate_params()
check_consistent_length(X, y)
X, y = self._validate_data(X, y, reset=True, dtype=np.int32)
X, y = validate_data(self, X=X, y=y, reset=True, dtype=np.int32)

if isinstance(self.n_categories, str) and self.n_categories == "auto":
# categories are expected to be encoded from 0 to n_categories - 1
@@ -207,11 +208,11 @@ def pairwise(self, X, Y=None):
The VDM pairwise distance.
"""
check_is_fitted(self)
X = self._validate_data(X, reset=False, dtype=np.int32)
X = validate_data(self, X=X, reset=False, dtype=np.int32)
n_samples_X = X.shape[0]

if Y is not None:
Y = self._validate_data(Y, reset=False, dtype=np.int32)
Y = validate_data(self, Y=Y, reset=False, dtype=np.int32)
n_samples_Y = Y.shape[0]
else:
n_samples_Y = n_samples_X
5 changes: 3 additions & 2 deletions imblearn/over_sampling/_random_over_sampler.py
@@ -15,6 +15,7 @@
from ..utils import Substitution, check_target_type
from ..utils._docstring import _random_state_docstring
from ..utils._param_validation import Interval
from ..utils.fixes import _check_n_features, _check_feature_names
from ..utils._validation import _check_X
from .base import BaseOverSampler

@@ -156,8 +157,8 @@ def __init__(
def _check_X_y(self, X, y):
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
X = _check_X(X)
self._check_n_features(X, reset=True)
self._check_feature_names(X, reset=True)
_check_n_features(self, X, reset=True)
_check_feature_names(self, X, reset=True)
return X, y, binarize_y

def _fit_resample(self, X, y):
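
For reference, the _check_n_features and _check_feature_names helpers used in _check_X_y above come from ..utils.fixes and are not shown in this diff. A hedged sketch of what they could look like, assuming scikit-learn 1.6 exposes these as functions in sklearn.utils.validation while older releases only have the BaseEstimator methods:

import sklearn
from sklearn.utils.fixes import parse_version

sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)

if sklearn_version >= parse_version("1.6"):
    from sklearn.utils.validation import _check_feature_names, _check_n_features
else:

    def _check_n_features(estimator, X, *, reset):
        # delegate to the pre-1.6 method on the estimator
        return estimator._check_n_features(X, reset=reset)

    def _check_feature_names(estimator, X, *, reset):
        return estimator._check_feature_names(X, reset=reset)
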
13 changes: 7 additions & 6 deletions imblearn/over_sampling/_smote/base.py
@@ -32,7 +32,7 @@
from ...utils._docstring import _n_jobs_docstring, _random_state_docstring
from ...utils._param_validation import HasMethods, Interval, StrOptions
from ...utils._validation import _check_X
from ...utils.fixes import _is_pandas_df, _mode
from ...utils.fixes import (
    _check_feature_names,
    _check_n_features,
    _is_pandas_df,
    _mode,
    validate_data,
)
from ..base import BaseOverSampler

sklearn_version = parse_version(sklearn.__version__).base_version
@@ -601,8 +601,8 @@ def _check_X_y(self, X, y):
"""
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
X = _check_X(X)
self._check_n_features(X, reset=True)
self._check_feature_names(X, reset=True)
_check_n_features(self, X, reset=True)
_check_feature_names(self, X, reset=True)
return X, y, binarize_y

def _validate_column_types(self, X):
Expand Down Expand Up @@ -963,9 +963,10 @@ def __init__(
def _check_X_y(self, X, y):
"""Check should accept strings and not sparse matrices."""
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
X, y = self._validate_data(
X,
y,
X, y = validate_data(
self,
X=X,
y=y,
reset=True,
dtype=None,
accept_sparse=["csr", "csc"],
24 changes: 6 additions & 18 deletions imblearn/tests/test_common.py
@@ -1,4 +1,5 @@
"""Common tests"""

# Authors: Guillaume Lemaitre <[email protected]>
# Christos Aridas
# License: MIT
@@ -10,8 +11,7 @@
import pytest
from sklearn.base import clone
from sklearn.exceptions import ConvergenceWarning
from sklearn.utils._testing import SkipTest, ignore_warnings, set_random_state
from sklearn.utils.estimator_checks import _construct_instance, _get_check_estimator_ids
from sklearn.utils._testing import ignore_warnings
from sklearn.utils.estimator_checks import (
parametrize_with_checks as parametrize_with_checks_sklearn,
)
@@ -25,6 +25,10 @@
parametrize_with_checks,
)
from imblearn.utils.testing import all_estimators
from imblearn.utils._test_common.instance_generator import (
_get_check_estimator_ids,
_tested_estimators,
)


@pytest.mark.parametrize("name, Estimator", all_estimators())
@@ -34,22 +38,6 @@ def test_all_estimator_no_base_class(name, Estimator):
assert not name.lower().startswith("base"), msg


def _tested_estimators():
for name, Estimator in all_estimators():
try:
estimator = _construct_instance(Estimator)
set_random_state(estimator)
except SkipTest:
continue

if isinstance(estimator, NearMiss):
# For NearMiss, let's check the three algorithms
for version in (1, 2, 3):
yield clone(estimator).set_params(version=version)
else:
yield estimator


@parametrize_with_checks_sklearn(list(_tested_estimators()))
def test_estimators_compatibility_sklearn(estimator, check, request):
_set_checking_parameters(estimator)
@@ -8,6 +8,7 @@
from sklearn.utils import _safe_indexing, check_random_state

from ...utils import Substitution, check_target_type
from ...utils.fixes import _check_n_features, _check_feature_names
from ...utils._docstring import _random_state_docstring
from ...utils._validation import _check_X
from ..base import BaseUnderSampler
Expand Down Expand Up @@ -99,8 +100,8 @@ def __init__(
def _check_X_y(self, X, y):
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
X = _check_X(X)
self._check_n_features(X, reset=True)
self._check_feature_names(X, reset=True)
_check_n_features(self, X, reset=True)
_check_feature_names(self, X, reset=True)
return X, y, binarize_y

def _fit_resample(self, X, y):
Expand Down
13 changes: 13 additions & 0 deletions imblearn/utils/_tags.py
@@ -0,0 +1,13 @@
from dataclasses import dataclass

import sklearn
from sklearn.utils.fixes import parse_version

sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)

if sklearn_version >= parse_version("1.6"):
from sklearn.utils._tags import InputTags
Member:
I think a try / except around the import (from the public path) makes more sense; it tests for the exact thing you care about instead of using the version as a proxy.

Member Author:
The Tags are not available in a public path if I'm not wrong.

Member Author:
Sorry, I checked the wrong __init__.py :)


@dataclass
class InputTags(InputTags):
dataframe: bool = True
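
As written, the module above only defines InputTags when scikit-learn >= 1.6 is installed. A self-contained variant with a fallback for older versions could look like the sketch below; the fallback branch and the renamed import are assumptions, not part of this diff:

from dataclasses import dataclass

import sklearn
from sklearn.utils.fixes import parse_version

sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)

if sklearn_version >= parse_version("1.6"):
    from sklearn.utils._tags import InputTags as _SklearnInputTags

    @dataclass
    class InputTags(_SklearnInputTags):
        # extend the upstream dataclass with the dataframe flag used by imblearn
        dataframe: bool = True

else:

    @dataclass
    class InputTags:
        # minimal stand-in exposing only the fields imblearn sets in __sklearn_tags__
        two_d_array: bool = True
        sparse: bool = False
        dataframe: bool = True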