Skip to content
2 changes: 1 addition & 1 deletion examples/ensemble/plot_comparison_ensemble_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@

from imblearn.ensemble import EasyEnsembleClassifier, RUSBoostClassifier

estimator = AdaBoostClassifier(n_estimators=10, algorithm="SAMME")
estimator = AdaBoostClassifier(n_estimators=10)
eec = EasyEnsembleClassifier(n_estimators=10, estimator=estimator)
eec.fit(X_train, y_train)
y_pred_eec = eec.predict(X_test)
Expand Down
28 changes: 24 additions & 4 deletions imblearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,17 @@
from abc import ABCMeta, abstractmethod

import numpy as np
import sklearn
from sklearn.base import BaseEstimator, OneToOneFeatureMixin
from sklearn.preprocessing import label_binarize
from sklearn.utils.metaestimators import available_if
from sklearn.utils.multiclass import check_classification_targets

from .utils import check_sampling_strategy, check_target_type
from .utils.fixes import check_version_package, validate_data
from .utils._param_validation import validate_parameter_constraints
from .utils._validation import ArraysTransformer


class _ParamsValidationMixin:
"""Mixin class to validate parameters."""

Expand All @@ -35,7 +37,7 @@ class attribute, which is a dictionary `param_name: list of constraints`. See
)


class SamplerMixin(_ParamsValidationMixin, BaseEstimator, metaclass=ABCMeta):
class SamplerMixin(_ParamsValidationMixin, metaclass=ABCMeta):
"""Mixin class for samplers with abstract method.

Warning: This class should not be used directly. Use the derive classes
Expand Down Expand Up @@ -133,7 +135,7 @@ def _fit_resample(self, X, y):
pass


class BaseSampler(SamplerMixin, OneToOneFeatureMixin):
class BaseSampler(SamplerMixin, OneToOneFeatureMixin, BaseEstimator):
"""Base class for sampling algorithms.

Warning: This class should not be used directly. Use the derive classes
Expand All @@ -147,7 +149,7 @@ def _check_X_y(self, X, y, accept_sparse=None):
if accept_sparse is None:
accept_sparse = ["csr", "csc"]
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
X, y = self._validate_data(X, y, reset=True, accept_sparse=accept_sparse)
X, y = validate_data(self, X=X, y=y, reset=True, accept_sparse=accept_sparse)
return X, y, binarize_y

def fit(self, X, y):
Expand Down Expand Up @@ -196,9 +198,27 @@ def fit_resample(self, X, y):
self._validate_params()
return super().fit_resample(X, y)

@available_if(check_version_package("sklearn", "<", "1.6"))
def _more_tags(self):
return {"X_types": ["2darray", "sparse", "dataframe"]}

@available_if(check_version_package("sklearn", ">=", "1.6"))
def __sklearn_tags__(self):
from .utils._tags import Tags, SamplerTags, TargetTags, InputTags
tags = Tags(
estimator_type="sampler",
target_tags=TargetTags(required=True),
transformer_tags=None,
regressor_tags=None,
classifier_tags=None,
sampler_tags=SamplerTags(),
)
tags.input_tags = InputTags()
tags.input_tags.two_d_array = True
tags.input_tags.sparse = True
tags.input_tags.dataframe = True
return tags


def _identity(X, y):
return X, y
Expand Down
12 changes: 7 additions & 5 deletions imblearn/ensemble/_bagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
from ..utils import Substitution, check_sampling_strategy, check_target_type
from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
from ..utils._param_validation import HasMethods, Interval, StrOptions
from ..utils.fixes import _fit_context
from ..utils.fixes import _fit_context, check_version_package, validate_data
from ._common import _bagging_parameter_constraints, _estimator_has

sklearn_version = parse_version(sklearn.__version__)
sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)


@Substitution(
Expand Down Expand Up @@ -382,12 +382,13 @@ def decision_function(self, X):
check_is_fitted(self)

# Check data
X = self._validate_data(
X,
X = validate_data(
self,
X=X,
accept_sparse=["csr", "csc"],
dtype=None,
force_all_finite=False,
reset=False,
ensure_all_finite=False,
)

# Parallel loop
Expand Down Expand Up @@ -415,6 +416,7 @@ def base_estimator_(self):
)
raise error

@available_if(check_version_package("sklearn", "<", "1.6"))
def _more_tags(self):
tags = super()._more_tags()
tags_key = "_xfail_checks"
Expand Down
26 changes: 18 additions & 8 deletions imblearn/ensemble/_easy_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from sklearn.ensemble import AdaBoostClassifier, BaggingClassifier
from sklearn.ensemble._bagging import _parallel_decision_function
from sklearn.ensemble._base import _partition_estimators
from sklearn.utils._tags import _safe_tags
from sklearn.utils.fixes import parse_version
from sklearn.utils.metaestimators import available_if
from sklearn.utils.parallel import Parallel, delayed
Expand All @@ -27,11 +26,11 @@
from ..utils import Substitution, check_sampling_strategy, check_target_type
from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
from ..utils._param_validation import Interval, StrOptions
from ..utils.fixes import _fit_context
from ..utils.fixes import _fit_context, check_version_package, get_tags, validate_data
from ._common import _bagging_parameter_constraints, _estimator_has

MAX_INT = np.iinfo(np.int32).max
sklearn_version = parse_version(sklearn.__version__)
sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)


@Substitution(
Expand Down Expand Up @@ -311,12 +310,13 @@ def decision_function(self, X):
check_is_fitted(self)

# Check data
X = self._validate_data(
X,
X = validate_data(
self,
X=X,
accept_sparse=["csr", "csc"],
dtype=None,
force_all_finite=False,
reset=False,
ensure_all_finite=False,
)

# Parallel loop
Expand Down Expand Up @@ -346,9 +346,19 @@ def base_estimator_(self):

def _get_estimator(self):
if self.estimator is None:
return AdaBoostClassifier(algorithm="SAMME")
if parse_version("1.4") <= sklearn_version < parse_version("1.6"):
return AdaBoostClassifier(algorithm="SAMME")
else:
return AdaBoostClassifier()
return self.estimator

# TODO: remove when minimum supported version of scikit-learn is 1.5
@available_if(check_version_package("sklearn", "<", "1.6"))
def _more_tags(self):
return {"allow_nan": _safe_tags(self._get_estimator(), "allow_nan")}
return {"allow_nan": get_tags(self._get_estimator())["allow_nan"]}

@available_if(check_version_package("sklearn", ">=", "1.6"))
def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.input_tags.allow_nan = get_tags(self._get_estimator()).input_tags.allow_nan
return tags
43 changes: 28 additions & 15 deletions imblearn/ensemble/_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numbers
from copy import deepcopy
from dataclasses import is_dataclass
from warnings import warn

import numpy as np
Expand All @@ -24,6 +25,7 @@
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils import _safe_indexing, check_random_state
from sklearn.utils.fixes import parse_version
from sklearn.utils.metaestimators import available_if
from sklearn.utils.multiclass import type_of_target
from sklearn.utils.parallel import Parallel, delayed
from sklearn.utils.validation import _check_sample_weight
Expand All @@ -35,11 +37,11 @@
from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
from ..utils._param_validation import Hidden, Interval, StrOptions
from ..utils._validation import check_sampling_strategy
from ..utils.fixes import _fit_context
from ..utils.fixes import _fit_context, check_version_package, get_tags, validate_data
from ._common import _random_forest_classifier_parameter_constraints

MAX_INT = np.iinfo(np.int32).max
sklearn_version = parse_version(sklearn.__version__)
sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)


def _local_parallel_build_trees(
Expand Down Expand Up @@ -77,7 +79,7 @@ def _local_parallel_build_trees(
"bootstrap": bootstrap,
}

if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
if sklearn_version >= parse_version("1.4"):
# TODO: remove when the minimum supported version of scikit-learn will be 1.4
# support for missing values
params_parallel_build_trees["missing_values_in_feature_mask"] = (
Expand Down Expand Up @@ -474,7 +476,7 @@ def __init__(
"max_samples": max_samples,
}
# TODO: remove when the minimum supported version of scikit-learn will be 1.4
if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
if sklearn_version >= parse_version("1.4"):
# use scikit-learn support for monotonic constraints
params_random_forest["monotonic_cst"] = monotonic_cst
else:
Expand Down Expand Up @@ -594,24 +596,25 @@ def fit(self, X, y, sample_weight=None):
if issparse(y):
raise ValueError("sparse multilabel-indicator for y is not supported.")

# TODO: remove when the minimum supported version of scipy will be 1.4
# Support for missing values
if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
force_all_finite = False
# TODO (1.6): simplify because we will only have dataclass tags
tags = get_tags(self)
if is_dataclass(tags):
ensure_all_finite = not tags.input_tags.allow_nan
else:
force_all_finite = True
ensure_all_finite = not tags.get("allow_nan", False)

X, y = self._validate_data(
X,
y,
X, y = validate_data(
self,
X=X,
y=y,
multi_output=True,
accept_sparse="csc",
dtype=DTYPE,
force_all_finite=force_all_finite,
ensure_all_finite=ensure_all_finite,
)

# TODO: remove when the minimum supported version of scikit-learn will be 1.4
if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
if sklearn_version >= parse_version("1.4"):
# _compute_missing_values_in_feature_mask checks if X has missing values and
# will raise an error if the underlying tree base estimator can't handle
# missing values. Only the criterion is required to determine if the tree
Expand Down Expand Up @@ -880,5 +883,15 @@ def _compute_oob_predictions(self, X, y):

return oob_pred

@available_if(check_version_package("sklearn", "<", "1.6"))
def _more_tags(self):
return {"multioutput": False, "multilabel": False}
allow_nan = sklearn_version >= parse_version("1.4")
return {"multioutput": False, "multilabel": False, "allow_nan": allow_nan}

@available_if(check_version_package("sklearn", ">=", "1.6"))
def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.target_tags.multi_output = False
tags.classifier_tags.multi_label = False
tags.input_tags.allow_nan = sklearn_version >= parse_version("1.4")
return tags
34 changes: 26 additions & 8 deletions imblearn/ensemble/_weight_boosting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import numbers
import warnings
from copy import deepcopy

import numpy as np
Expand All @@ -10,6 +11,7 @@
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils import _safe_indexing
from sklearn.utils.fixes import parse_version
from sklearn.utils.metaestimators import available_if
from sklearn.utils.validation import has_fit_parameter

from ..base import _ParamsValidationMixin
Expand All @@ -18,8 +20,8 @@
from ..under_sampling.base import BaseUnderSampler
from ..utils import Substitution, check_target_type
from ..utils._docstring import _random_state_docstring
from ..utils._param_validation import Interval, StrOptions
from ..utils.fixes import _fit_context
from ..utils._param_validation import Hidden, Interval, StrOptions
from ..utils.fixes import _fit_context, check_version_package
from ._common import _adaboost_classifier_parameter_constraints

sklearn_version = parse_version(sklearn.__version__)
Expand Down Expand Up @@ -58,16 +60,15 @@ class RUSBoostClassifier(_ParamsValidationMixin, AdaBoostClassifier):
``learning_rate``. There is a trade-off between ``learning_rate`` and
``n_estimators``.

algorithm : {{'SAMME', 'SAMME.R'}}, default='SAMME.R'
algorithm : {{'SAMME', 'SAMME.R'}}, default='deprecated'
If 'SAMME.R' then use the SAMME.R real boosting algorithm.
``base_estimator`` must support calculation of class probabilities.
If 'SAMME' then use the SAMME discrete boosting algorithm.
The SAMME.R algorithm typically converges faster than SAMME,
achieving a lower test error with fewer boosting iterations.

.. deprecated:: 0.12
`"SAMME.R"` is deprecated and will be removed in version 0.14.
'"SAMME"' will become the default.
`algorithm` is deprecated in 0.12 and will be removed 0.14.

{sampling_strategy}

Expand Down Expand Up @@ -109,7 +110,7 @@ class RUSBoostClassifier(_ParamsValidationMixin, AdaBoostClassifier):
ensemble.

feature_importances_ : ndarray of shape (n_features,)
The feature importances if supported by the ``base_estimator``.
The feature importances if supported by the ``estimator``.

n_features_in_ : int
Number of features in the input dataset.
Expand Down Expand Up @@ -167,6 +168,10 @@ class RUSBoostClassifier(_ParamsValidationMixin, AdaBoostClassifier):

_parameter_constraints.update(
{
"algorithm": [
StrOptions({"SAMME", "SAMME.R"}),
Hidden(StrOptions({"deprecated"})),
],
"sampling_strategy": [
Interval(numbers.Real, 0, 1, closed="right"),
StrOptions({"auto", "majority", "not minority", "not majority", "all"}),
Expand All @@ -186,17 +191,17 @@ def __init__(
*,
n_estimators=50,
learning_rate=1.0,
algorithm="SAMME.R",
algorithm="deprecated",
sampling_strategy="auto",
replacement=False,
random_state=None,
):
super().__init__(
n_estimators=n_estimators,
learning_rate=learning_rate,
algorithm=algorithm,
random_state=random_state,
)
self.algorithm = algorithm
self.estimator = estimator
self.sampling_strategy = sampling_strategy
self.replacement = replacement
Expand Down Expand Up @@ -394,3 +399,16 @@ def _boost_discrete(self, iboost, X, y, sample_weight, random_state):
sample_weight *= np.exp(estimator_weight * incorrect * (sample_weight > 0))

return sample_weight, estimator_weight, estimator_error

def _boost(self, iboost, X, y, sample_weight, random_state):
if self.algorithm != "deprecated":
warnings.warn(
"`algorithm` parameter is deprecated in 0.12 and will be removed in "
"0.14. In the future, the SAMME algorithm will always be used.",
FutureWarning,
)
if self.algorithm == "SAMME.R":
return self._boost_real(iboost, X, y, sample_weight, random_state)

else: # elif self.algorithm == "SAMME":
return self._boost_discrete(iboost, X, y, sample_weight, random_state)
Loading
Loading