Skip to content

Commit e4406da

Browse files
authored
[enc] Support training continuation with sklearn. (#11605)
1 parent 9261f05 commit e4406da

38 files changed

+465
-199
lines changed

python-package/xgboost/_data_utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ def array_interface_dict(data: np.ndarray) -> ArrayInf:
397397
return cast(ArrayInf, ainf)
398398

399399

400-
def pd_cats_inf( # pylint: disable=too-many-locals
400+
def pd_cat_inf( # pylint: disable=too-many-locals
401401
cats: DfCatAccessor, codes: "pd.Series"
402402
) -> Tuple[Union[StringArray, ArrayInf], ArrayInf, Tuple]:
403403
"""Get the array interface representation of pandas category accessor."""
@@ -665,12 +665,18 @@ def to_arrow(self) -> ArrowCatList:
665665
)
666666
return self._arrow_arrays
667667

668+
def empty(self) -> bool:
669+
"""Returns True if there's no category."""
670+
return self._handle.value is None
671+
668672
def get_handle(self) -> int:
669673
"""Internal method for retrieving the handle."""
670674
assert self._handle.value
671675
return self._handle.value
672676

673677
def __del__(self) -> None:
678+
if self._handle.value is None:
679+
return
674680
self._free()
675681

676682

@@ -718,7 +724,7 @@ class TransformedDf(ABC):
718724

719725
def __init__(self, ref_categories: Optional[Categories], aitfs: AifType) -> None:
720726
self.ref_categories = ref_categories
721-
if ref_categories is not None:
727+
if ref_categories is not None and ref_categories.get_handle() is not None:
722728
aif = ref_categories.get_handle()
723729
self.ref_aif: Optional[int] = aif
724730
else:

python-package/xgboost/compat.py

Lines changed: 101 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,15 @@
55
import logging
66
import sys
77
import types
8-
from typing import Any, Sequence, cast
8+
from typing import TYPE_CHECKING, Any, Sequence, TypeGuard, cast
99

1010
import numpy as np
1111

12-
from ._typing import _T
12+
from ._typing import _T, DataType
13+
14+
if TYPE_CHECKING:
15+
import pandas as pd
16+
import pyarrow as pa
1317

1418
assert sys.version_info[0] == 3, "Python 2 is no longer supported."
1519

@@ -31,17 +35,6 @@ def lazy_isinstance(instance: Any, module: str, name: str) -> bool:
3135
return is_same_module and has_same_name
3236

3337

34-
# pandas
35-
try:
36-
from pandas import DataFrame, Series
37-
38-
PANDAS_INSTALLED = True
39-
except ImportError:
40-
DataFrame = object
41-
Series = object
42-
PANDAS_INSTALLED = False
43-
44-
4538
# sklearn
4639
try:
4740
from sklearn import __version__ as _sklearn_version
@@ -139,6 +132,14 @@ def import_pyarrow() -> types.ModuleType:
139132
return pa
140133

141134

135+
@functools.cache
136+
def import_pandas() -> types.ModuleType:
137+
"""Import pandas with memory cache."""
138+
import pandas as pd
139+
140+
return pd
141+
142+
142143
@functools.cache
143144
def import_polars() -> types.ModuleType:
144145
"""Import polars with memory cache."""
@@ -147,6 +148,14 @@ def import_polars() -> types.ModuleType:
147148
return pl
148149

149150

151+
@functools.cache
152+
def is_pandas_available() -> bool:
153+
"""Check whether the pandas package is available."""
154+
if importlib.util.find_spec("pandas") is None:
155+
return False
156+
return True
157+
158+
150159
try:
151160
import scipy.sparse as scipy_sparse
152161
from scipy.sparse import csr_matrix as scipy_csr
@@ -155,6 +164,84 @@ def import_polars() -> types.ModuleType:
155164
scipy_csr = object
156165

157166

167+
def _is_polars_lazyframe(data: DataType) -> bool:
168+
return lazy_isinstance(data, "polars.lazyframe.frame", "LazyFrame")
169+
170+
171+
def _is_polars_series(data: DataType) -> bool:
172+
return lazy_isinstance(data, "polars.series.series", "Series")
173+
174+
175+
def _is_polars(data: DataType) -> bool:
176+
lf = _is_polars_lazyframe(data)
177+
df = lazy_isinstance(data, "polars.dataframe.frame", "DataFrame")
178+
return lf or df
179+
180+
181+
def _is_arrow(data: DataType) -> TypeGuard["pa.Table"]:
182+
return lazy_isinstance(data, "pyarrow.lib", "Table")
183+
184+
185+
def _is_cudf_df(data: DataType) -> bool:
186+
return lazy_isinstance(data, "cudf.core.dataframe", "DataFrame")
187+
188+
189+
def _is_cudf_ser(data: DataType) -> bool:
190+
return lazy_isinstance(data, "cudf.core.series", "Series")
191+
192+
193+
def _is_cudf_pandas(data: DataType) -> bool:
194+
"""Must go before both pandas and cudf checks."""
195+
return (_is_pandas_df(data) or _is_pandas_series(data)) and lazy_isinstance(
196+
type(data), "cudf.pandas.fast_slow_proxy", "_FastSlowProxyMeta"
197+
)
198+
199+
200+
def _is_pandas_df(data: DataType) -> TypeGuard["pd.DataFrame"]:
201+
return lazy_isinstance(data, "pandas.core.frame", "DataFrame")
202+
203+
204+
def _is_pandas_series(data: DataType) -> TypeGuard["pd.Series"]:
205+
return lazy_isinstance(data, "pandas.core.series", "Series")
206+
207+
208+
def _is_modin_df(data: DataType) -> bool:
209+
return lazy_isinstance(data, "modin.pandas.dataframe", "DataFrame")
210+
211+
212+
def _is_modin_series(data: DataType) -> bool:
213+
return lazy_isinstance(data, "modin.pandas.series", "Series")
214+
215+
216+
def is_dataframe(data: DataType) -> bool:
217+
"""Whether the input is a dataframe. Currently supported dataframes:
218+
219+
- pandas
220+
- cudf
221+
- cudf.pandas
222+
- polars
223+
- pyarrow
224+
- modin
225+
226+
227+
"""
228+
return any(
229+
p(data)
230+
for p in (
231+
_is_polars,
232+
_is_polars_series,
233+
_is_arrow,
234+
_is_cudf_df,
235+
_is_cudf_ser,
236+
_is_cudf_pandas,
237+
_is_pandas_df,
238+
_is_pandas_series,
239+
_is_modin_df,
240+
_is_modin_series,
241+
)
242+
)
243+
244+
158245
def concat(value: Sequence[_T]) -> _T: # pylint: disable=too-many-return-statements
159246
"""Concatenate row-wise."""
160247
if isinstance(value[0], np.ndarray):
@@ -167,7 +254,7 @@ def concat(value: Sequence[_T]) -> _T: # pylint: disable=too-many-return-statem
167254
if scipy_sparse and isinstance(value[0], scipy_sparse.spmatrix):
168255
# other sparse format will be converted to CSR.
169256
return scipy_sparse.vstack(value, format="csr")
170-
if PANDAS_INSTALLED and isinstance(value[0], (DataFrame, Series)):
257+
if _is_pandas_df(value[0]) or _is_pandas_series(value[0]):
171258
from pandas import concat as pd_concat
172259

173260
return pd_concat(value, axis=0)

python-package/xgboost/core.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# pylint: disable=too-many-arguments, too-many-branches, invalid-name
22
# pylint: disable=too-many-lines, too-many-locals
33
"""Core XGBoost Library."""
4+
45
import copy
56
import ctypes
67
import json
@@ -69,15 +70,17 @@
6970
c_bst_ulong,
7071
)
7172
from .compat import (
72-
PANDAS_INSTALLED,
73-
DataFrame,
7473
import_polars,
7574
import_pyarrow,
75+
is_pandas_available,
7676
is_pyarrow_available,
7777
py_str,
7878
)
7979
from .libpath import find_lib_path, is_sphinx_build
8080

81+
if TYPE_CHECKING:
82+
from pandas import DataFrame as PdDataFrame
83+
8184

8285
class XGBoostError(ValueError):
8386
"""Error thrown by xgboost trainer."""
@@ -782,7 +785,7 @@ def _get_categories(
782785
cfn: Callable[[ctypes.c_char_p], int],
783786
feature_names: FeatureNames,
784787
n_features: int,
785-
) -> Optional[ArrowCatList]:
788+
) -> ArrowCatList:
786789
if not is_pyarrow_available():
787790
raise ImportError(
788791
"`pyarrow` is required for exporting categories to arrow arrays."
@@ -797,7 +800,9 @@ def _get_categories(
797800

798801
ret = ctypes.c_char_p()
799802
_check_call(cfn(ret))
800-
assert ret.value is not None
803+
if ret.value is None:
804+
results = [(feature_names[i], None) for i in range(n_features)]
805+
return results
801806

802807
retstr = ret.value.decode() # pylint: disable=no-member
803808
jcats = json.loads(retstr)
@@ -3201,7 +3206,8 @@ def get_score(
32013206
"""Get feature importance of each feature.
32023207
For tree model Importance type can be defined as:
32033208
3204-
* 'weight': the number of times a feature is used to split the data across all trees.
3209+
* 'weight': the number of times a feature is used to split the data across all
3210+
trees.
32053211
* 'gain': the average gain across all splits the feature is used in.
32063212
* 'cover': the average coverage across all splits the feature is used in.
32073213
* 'total_gain': the total gain across all splits the feature is used in.
@@ -3261,7 +3267,7 @@ def get_score(
32613267
return results
32623268

32633269
# pylint: disable=too-many-statements
3264-
def trees_to_dataframe(self, fmap: PathLike = "") -> DataFrame:
3270+
def trees_to_dataframe(self, fmap: PathLike = "") -> "PdDataFrame":
32653271
"""Parse a boosted tree model text dump into a pandas DataFrame structure.
32663272
32673273
This feature is only defined when the decision tree model is chosen as base
@@ -3274,8 +3280,10 @@ def trees_to_dataframe(self, fmap: PathLike = "") -> DataFrame:
32743280
The name of feature map file.
32753281
"""
32763282
# pylint: disable=too-many-locals
3283+
from pandas import DataFrame
3284+
32773285
fmap = os.fspath(os.path.expanduser(fmap))
3278-
if not PANDAS_INSTALLED:
3286+
if not is_pandas_available():
32793287
raise ImportError(
32803288
(
32813289
"pandas must be available to use this method."
@@ -3426,7 +3434,7 @@ def get_split_value_histogram(
34263434
fmap: PathLike = "",
34273435
bins: Optional[int] = None,
34283436
as_pandas: bool = True,
3429-
) -> Union[np.ndarray, DataFrame]:
3437+
) -> Union[np.ndarray, "PdDataFrame"]:
34303438
"""Get split value histogram of a feature
34313439
34323440
Parameters
@@ -3482,9 +3490,11 @@ def get_split_value_histogram(
34823490
"Split value historgam doesn't support categorical split."
34833491
)
34843492

3485-
if as_pandas and PANDAS_INSTALLED:
3493+
if as_pandas and is_pandas_available():
3494+
from pandas import DataFrame
3495+
34863496
return DataFrame(nph_stacked, columns=["SplitValue", "Count"])
3487-
if as_pandas and not PANDAS_INSTALLED:
3497+
if as_pandas and not is_pandas_available():
34883498
warnings.warn(
34893499
"Returning histogram as ndarray"
34903500
" (as_pandas == True, but pandas is not installed).",

python-package/xgboost/dask/__init__.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@
9494
from ..collective import Config as CollConfig
9595
from ..collective import _Args as CollArgs
9696
from ..collective import _ArgVals as CollArgsVals
97-
from ..compat import DataFrame, lazy_isinstance
97+
from ..compat import _is_cudf_df
9898
from ..core import (
9999
Booster,
100100
DMatrix,
@@ -942,7 +942,7 @@ def _maybe_dataframe(
942942
# In older versions of dask, the partition is actually a numpy array when input
943943
# is dataframe.
944944
index = getattr(data, "index", None)
945-
if lazy_isinstance(data, "cudf.core.dataframe", "DataFrame"):
945+
if _is_cudf_df(data):
946946
import cudf
947947

948948
if prediction.size == 0:
@@ -952,10 +952,14 @@ def _maybe_dataframe(
952952
prediction, columns=columns, dtype=numpy.float32, index=index
953953
)
954954
else:
955+
import pandas as pd
956+
955957
if prediction.size == 0:
956-
return DataFrame({}, columns=columns, dtype=numpy.float32, index=index)
958+
return pd.DataFrame(
959+
{}, columns=columns, dtype=numpy.float32, index=index
960+
)
957961

958-
prediction = DataFrame(
962+
prediction = pd.DataFrame(
959963
prediction, columns=columns, dtype=numpy.float32, index=index
960964
)
961965
return prediction

0 commit comments

Comments
 (0)