diff --git a/cleantext/sklearn.py b/cleantext/sklearn.py index e8b8438..df74974 100644 --- a/cleantext/sklearn.py +++ b/cleantext/sklearn.py @@ -68,7 +68,13 @@ def __init__( self.replace_with_punct = replace_with_punct self.lang = lang - def fit(self, X: Any): + def fit(self, X: Any, y=None): + """ + This method is defined for compatibility. It does nothing. + """ + return self + + def partial_fit(self, X: Any, y=None): """ This method is defined for compatibility. It does nothing. """ @@ -89,3 +95,14 @@ def transform(self, X: Union[List[str], pd.Series]) -> Union[List[str], pd.Serie return X.apply(lambda text: clean(text, **self.get_params())) else: return list(map(lambda text: clean(text, **self.get_params()), X)) + + def get_feature_names_out(self, feature_names_out=None): + """ + For compatibility with scikit-learn Pipeline objects. + This transformer will only return one column, which is ``'Clean Text'``. + Args: + feature_names_out: Defined for compatibility with scikit-learn. + Returns: + list[str]: List with one element (i.e. the cleaned text column). + """ + return ["Clean Text"] diff --git a/tests/test_clean_transformer.py b/tests/test_clean_transformer.py index 3437f97..4beaa09 100644 --- a/tests/test_clean_transformer.py +++ b/tests/test_clean_transformer.py @@ -1,7 +1,6 @@ try: from cleantext.sklearn import CleanTransformer import pandas as pd - import pytest transformer = CleanTransformer() @@ -25,5 +24,12 @@ def test_set_params(): assert transformer.get_params()["no_line_breaks"] assert transformer.get_params()["no_digits"] + def test_get_feature_names_out(): + feature_names = transformer.get_feature_names_out() + assert feature_names == ["Clean Text"] + + def test_fit(): + transformer.fit(["sample1", "sample2"], [0, 1]) + transformer.partial_fit(["sample1", "sample2"]) except ImportError: pass