Skip to content

Commit a3acff2

Browse files
author
a.cherkaoui
committed
load xgboost model as a pyspark model
1 parent 4e24639 commit a3acff2

File tree

3 files changed

+110
-1
lines changed

3 files changed

+110
-1
lines changed

python-package/xgboost/spark/core.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1239,6 +1239,42 @@ def get_booster(self) -> Booster:
12391239
assert self._xgb_sklearn_model is not None
12401240
return self._xgb_sklearn_model.get_booster()
12411241

1242+
@classmethod
def _load_model_as_sklearn_model(cls, model_path: str) -> XGBModel:
    """Load the model saved at ``model_path`` as an sklearn-style XGBoost model.

    This base implementation is only a hook: subclasses should override it
    and return the ``XGBModel`` instance they loaded.

    Parameters
    ----------
    model_path :
        Path of the saved XGBoost model.

    Raises
    ------
    NotImplementedError
        Always, when the method has not been overridden.
    """
    raise NotImplementedError(
        f"{cls.__name__} must override `_load_model_as_sklearn_model`."
    )
1249+
1250+
@classmethod
def convert_sklearn_model_to_spark_xgb_model(
    cls,
    xgb_sklearn_model: XGBModel,
    training_summary: Optional[XGBoostTrainingSummary] = None,
) -> "_SparkXGBModel":
    """Wrap an sklearn-style XGBoost model in a pyspark XGBoost model.

    Parameters
    ----------
    xgb_sklearn_model :
        The sklearn model to wrap.
    training_summary :
        Optional training summary, passed through to the new model unchanged.

    Returns
    -------
    A new instance of ``cls`` constructed around ``xgb_sklearn_model``.
    """
    # Values registered on the new model through ``_setDefault``.
    default_params = {"device": "cpu", "use_gpu": False, "tree_method": "hist"}

    converted = cls(
        xgb_sklearn_model=xgb_sklearn_model, training_summary=training_summary
    )
    converted._setDefault(**default_params)
    return converted
1268+
1269+
@classmethod
def load_model(
    cls, model_path: str, training_summary: Optional[XGBoostTrainingSummary] = None
) -> "_SparkXGBModel":
    """Load a saved XGBoost model and return it as a pyspark XGBoost model.

    The model is first read with ``_load_model_as_sklearn_model`` and the
    result is then handed to ``convert_sklearn_model_to_spark_xgb_model``.

    Parameters
    ----------
    model_path :
        Path of the saved XGBoost model.
    training_summary :
        Optional training summary forwarded to the conversion step.
    """
    return cls.convert_sklearn_model_to_spark_xgb_model(
        xgb_sklearn_model=cls._load_model_as_sklearn_model(model_path),
        training_summary=training_summary,
    )
1277+
12421278
def get_feature_importances(
12431279
self, importance_type: str = "weight"
12441280
) -> Dict[str, Union[float, List[float]]]:

python-package/xgboost/spark/estimator.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from pyspark.ml.param.shared import HasProbabilityCol, HasRawPredictionCol
1313

1414
from ..collective import Config
15-
from ..sklearn import XGBClassifier, XGBRanker, XGBRegressor
15+
from ..sklearn import XGBClassifier, XGBModel, XGBRanker, XGBRegressor
1616
from .core import ( # type: ignore
1717
_ClassificationModel,
1818
_SparkXGBEstimator,
@@ -256,6 +256,12 @@ class SparkXGBRegressorModel(_SparkXGBModel):
256256
def _xgb_cls(cls) -> Type[XGBRegressor]:
257257
return XGBRegressor
258258

259+
@classmethod
def _load_model_as_sklearn_model(cls, model_path: str) -> XGBModel:
    """Load the model saved at ``model_path`` into a fresh ``XGBRegressor``."""
    regressor = XGBRegressor()
    regressor.load_model(model_path)
    return regressor
264+
259265

260266
_set_pyspark_xgb_cls_param_attrs(SparkXGBRegressor, SparkXGBRegressorModel)
261267

@@ -452,6 +458,12 @@ class SparkXGBClassifierModel(_ClassificationModel):
452458
def _xgb_cls(cls) -> Type[XGBClassifier]:
453459
return XGBClassifier
454460

461+
@classmethod
def _load_model_as_sklearn_model(cls, model_path: str) -> XGBModel:
    """Load the model saved at ``model_path`` into a fresh ``XGBClassifier``."""
    classifier = XGBClassifier()
    classifier.load_model(model_path)
    return classifier
466+
455467

456468
_set_pyspark_xgb_cls_param_attrs(SparkXGBClassifier, SparkXGBClassifierModel)
457469

@@ -639,5 +651,11 @@ class SparkXGBRankerModel(_SparkXGBModel):
639651
def _xgb_cls(cls) -> Type[XGBRanker]:
640652
return XGBRanker
641653

654+
@classmethod
def _load_model_as_sklearn_model(cls, model_path: str) -> XGBModel:
    """Load the model saved at ``model_path`` into a fresh ``XGBRanker``."""
    ranker = XGBRanker()
    ranker.load_model(model_path)
    return ranker
659+
642660

643661
_set_pyspark_xgb_cls_param_attrs(SparkXGBRanker, SparkXGBRankerModel)

tests/test_distributed/test_with_spark/test_spark_local.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import glob
22
import logging
3+
import os
34
import random
45
import tempfile
56
import uuid
@@ -1307,6 +1308,60 @@ def test_classifier_xgb_summary_with_validation(
13071308
atol=1e-3,
13081309
)
13091310

1311+
def test_convert_sklearn_model_to_spark_xgb_model_classifier(
    self, clf_data: ClfData
) -> None:
    """Convert a fitted sklearn classifier and check its Spark predictions."""
    features = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
    labels = np.array([0, 1])
    sklearn_clf = xgb.XGBClassifier()
    sklearn_clf.fit(features, labels)

    convert = SparkXGBClassifierModel.convert_sklearn_model_to_spark_xgb_model
    converted = convert(xgb_sklearn_model=sklearn_clf)

    # Every transformed row must match the expectations stored in the fixture.
    for row in converted.transform(clf_data.cls_df_test).collect():
        assert np.isclose(row.prediction, row.expected_prediction, atol=1e-3)
        np.testing.assert_allclose(
            row.probability, row.expected_probability, atol=1e-3
        )
1329+
1330+
def test_convert_sklearn_model_to_spark_xgb_model_regressor(
    self, reg_data: RegData
):
    """Convert a fitted sklearn regressor and check its Spark predictions."""
    features = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
    targets = np.array([0, 1])
    sklearn_reg = xgb.XGBRegressor()
    sklearn_reg.fit(features, targets)

    convert = SparkXGBRegressorModel.convert_sklearn_model_to_spark_xgb_model
    converted = convert(xgb_sklearn_model=sklearn_reg)

    # Every transformed row must match the expectation stored in the fixture.
    for row in converted.transform(reg_data.reg_df_test).collect():
        assert np.isclose(row.prediction, row.expected_prediction, atol=1e-3)
1345+
1346+
def test_load_model_as_spark_model(self) -> None:
    """Save sklearn-fitted boosters as JSON and load them back as Spark models.

    Each ``load_model`` call must return an instance of the Spark model class
    it was invoked on.
    """
    # The same tiny dataset is enough for both estimators.
    X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
    y = np.array([0, 1])

    cl1 = xgb.XGBClassifier()
    cl1.fit(X, y)

    reg1 = xgb.XGBRegressor()
    reg1.fit(X, y)

    with tempfile.TemporaryDirectory() as tmpdir:
        clf_model_path = os.path.join(tmpdir, "clf_model.json")
        cl1.get_booster().save_model(clf_model_path)
        spark_clf_model = SparkXGBClassifierModel.load_model(
            model_path=clf_model_path
        )
        assert isinstance(spark_clf_model, SparkXGBClassifierModel)

        reg_model_path = os.path.join(tmpdir, "reg_model.json")
        reg1.get_booster().save_model(reg_model_path)
        spark_reg_model = SparkXGBRegressorModel.load_model(
            model_path=reg_model_path
        )
        assert isinstance(spark_reg_model, SparkXGBRegressorModel)
1364+
13101365

13111366
class XgboostLocalTest(SparkTestCase):
13121367
def setUp(self):

0 commit comments

Comments
 (0)