diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py
index 3db9b1409351..800d52283df8 100644
--- a/python-package/xgboost/spark/core.py
+++ b/python-package/xgboost/spark/core.py
@@ -1239,6 +1239,61 @@ def get_booster(self) -> Booster:
         assert self._xgb_sklearn_model is not None
         return self._xgb_sklearn_model.get_booster()
 
+    @classmethod
+    def _load_model_as_sklearn_model(cls, model_path: str) -> XGBModel:
+        """
+        Subclasses should override this method and
+        return an XGBModel loaded from the given path.
+        """
+        raise NotImplementedError()
+
+    @classmethod
+    def convert_sklearn_model_to_spark_xgb_model(
+        cls,
+        xgb_sklearn_model: XGBModel,
+        training_summary: Optional[XGBoostTrainingSummary] = None,
+    ) -> "_SparkXGBModel":
+        """
+        Convert an xgboost sklearn model into a PySpark XGBoost model.
+        """
+        spark_xgb_model = cls(
+            xgb_sklearn_model=xgb_sklearn_model, training_summary=training_summary
+        )
+        spark_xgb_model._setDefault(
+            device="cpu",
+            use_gpu=False,
+            tree_method="hist",
+        )
+        return spark_xgb_model
+
+    @classmethod
+    def load_model(
+        cls, model_path: str, training_summary: Optional[XGBoostTrainingSummary] = None
+    ) -> "_SparkXGBModel":
+        """Load a model from the specified path and convert it into a Spark XGBoost model.
+
+        The model is loaded from the given file path and then converted into a
+        Spark XGBoost model. The optional training summary can provide additional
+        details about the training process.
+
+        Parameters
+        ----------
+        model_path: str
+            The file path of the saved model to load.
+        training_summary: Optional[XGBoostTrainingSummary], default None
+            An optional summary of the training process, which can be used for further
+            analysis or reference when converting the model.
+
+        Returns
+        -------
+        _SparkXGBModel
+            The converted Spark XGBoost model.
+ """ + xgb_sklearn_model = cls._load_model_as_sklearn_model(model_path) + return cls.convert_sklearn_model_to_spark_xgb_model( + xgb_sklearn_model=xgb_sklearn_model, training_summary=training_summary + ) + def get_feature_importances( self, importance_type: str = "weight" ) -> Dict[str, Union[float, List[float]]]: diff --git a/python-package/xgboost/spark/estimator.py b/python-package/xgboost/spark/estimator.py index 011f7ea0b715..e678965c08d9 100644 --- a/python-package/xgboost/spark/estimator.py +++ b/python-package/xgboost/spark/estimator.py @@ -12,7 +12,7 @@ from pyspark.ml.param.shared import HasProbabilityCol, HasRawPredictionCol from ..collective import Config -from ..sklearn import XGBClassifier, XGBRanker, XGBRegressor +from ..sklearn import XGBClassifier, XGBModel, XGBRanker, XGBRegressor from .core import ( # type: ignore _ClassificationModel, _SparkXGBEstimator, @@ -256,6 +256,12 @@ class SparkXGBRegressorModel(_SparkXGBModel): def _xgb_cls(cls) -> Type[XGBRegressor]: return XGBRegressor + @classmethod + def _load_model_as_sklearn_model(cls, model_path: str) -> XGBModel: + sklearn_model = XGBRegressor() + sklearn_model.load_model(model_path) + return sklearn_model + _set_pyspark_xgb_cls_param_attrs(SparkXGBRegressor, SparkXGBRegressorModel) @@ -452,6 +458,12 @@ class SparkXGBClassifierModel(_ClassificationModel): def _xgb_cls(cls) -> Type[XGBClassifier]: return XGBClassifier + @classmethod + def _load_model_as_sklearn_model(cls, model_path: str) -> XGBModel: + sklearn_model = XGBClassifier() + sklearn_model.load_model(model_path) + return sklearn_model + _set_pyspark_xgb_cls_param_attrs(SparkXGBClassifier, SparkXGBClassifierModel) @@ -639,5 +651,11 @@ class SparkXGBRankerModel(_SparkXGBModel): def _xgb_cls(cls) -> Type[XGBRanker]: return XGBRanker + @classmethod + def _load_model_as_sklearn_model(cls, model_path: str) -> XGBModel: + sklearn_model = XGBRanker() + sklearn_model.load_model(model_path) + return sklearn_model + _set_pyspark_xgb_cls_param_attrs(SparkXGBRanker, SparkXGBRankerModel) diff --git a/tests/test_distributed/test_with_spark/test_spark_local.py b/tests/test_distributed/test_with_spark/test_spark_local.py index 2cdafffaae6e..003672e97a56 100644 --- a/tests/test_distributed/test_with_spark/test_spark_local.py +++ b/tests/test_distributed/test_with_spark/test_spark_local.py @@ -1,5 +1,6 @@ import glob import logging +import os import random import tempfile import uuid @@ -1307,6 +1308,60 @@ def test_classifier_xgb_summary_with_validation( atol=1e-3, ) + def test_convert_sklearn_model_to_spark_xgb_model_classifier( + self, clf_data: ClfData + ) -> None: + X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]]) + y = np.array([0, 1]) + cl1 = xgb.XGBClassifier() + cl1.fit(X, y) + spark_xgb_model = ( + SparkXGBClassifierModel.convert_sklearn_model_to_spark_xgb_model( + xgb_sklearn_model=cl1 + ) + ) + pred_result = spark_xgb_model.transform(clf_data.cls_df_test).collect() + for row in pred_result: + assert np.isclose(row.prediction, row.expected_prediction, atol=1e-3) + np.testing.assert_allclose( + row.probability, row.expected_probability, atol=1e-3 + ) + + def test_convert_sklearn_model_to_spark_xgb_model_regressor( + self, reg_data: RegData + ): + X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]]) + y = np.array([0, 1]) + reg1 = xgb.XGBRegressor() + reg1.fit(X, y) + spark_xgb_model = ( + SparkXGBRegressorModel.convert_sklearn_model_to_spark_xgb_model( + xgb_sklearn_model=reg1 + ) + ) + pred_result = spark_xgb_model.transform(reg_data.reg_df_test).collect() + for row in 
+        for row in pred_result:
+            assert np.isclose(row.prediction, row.expected_prediction, atol=1e-3)
+
+    def test_load_model_as_spark_model(self) -> None:
+        X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
+        y = np.array([0, 1])
+        cl1 = xgb.XGBClassifier()
+        cl1.fit(X, y)
+
+        X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
+        y = np.array([0, 1])
+        reg1 = xgb.XGBRegressor()
+        reg1.fit(X, y)
+        with tempfile.TemporaryDirectory() as tmpdir:
+            clf_model_path = os.path.join(tmpdir, "clf_model.json")
+            cl1.get_booster().save_model(clf_model_path)
+            SparkXGBClassifierModel.load_model(model_path=clf_model_path)
+
+            reg_model_path = os.path.join(tmpdir, "reg_model.json")
+            reg1.get_booster().save_model(reg_model_path)
+            SparkXGBRegressorModel.load_model(model_path=reg_model_path)
+
 
 class XgboostLocalTest(SparkTestCase):
     def setUp(self):
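
Usage sketch (not part of the patch): based on the tests above, the snippet below illustrates how the new classmethods could be used to turn a single-node sklearn model into a distributed Spark model. The running SparkSession, the "/tmp/clf_model.json" path, and the df_test DataFrame with a "features" vector column are assumptions for illustration, not taken from the diff.

    import numpy as np
    import xgboost as xgb
    from xgboost.spark import SparkXGBClassifierModel

    # Train a plain single-node sklearn classifier.
    X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
    y = np.array([0, 1])
    clf = xgb.XGBClassifier()
    clf.fit(X, y)

    # Option 1: wrap the in-memory sklearn model directly.
    spark_model = SparkXGBClassifierModel.convert_sklearn_model_to_spark_xgb_model(
        xgb_sklearn_model=clf
    )

    # Option 2: load a saved booster file as a Spark model (path is illustrative).
    clf.get_booster().save_model("/tmp/clf_model.json")
    spark_model = SparkXGBClassifierModel.load_model(model_path="/tmp/clf_model.json")

    # Distributed inference on a Spark DataFrame with a "features" vector column
    # (df_test is assumed to exist).
    predictions = spark_model.transform(df_test)
    predictions.show()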