Open
Description
Hello, I found a parameter set for which the converted XGBoost model produces predictions that are inconsistent with the original model. Here's a minimal reproduction:
```python
import numpy as np
import onnxruntime as rt
import pandas as pd
from onnxmltools.convert import convert_xgboost
from skl2onnx.common.data_types import FloatTensorType
from xgboost import XGBRegressor

# Tiny toy dataset: one feature, binary label.
df = pd.DataFrame(
    {
        "f1": [1.0, 2.0, 3.0, 4.0, 2.0, 3.0, 1.0, 2.0],
        "label": [1, 0, 1, 0, 1, 1, 0, 1],
    }
)
params = {
    "max_depth": 1,
    "n_estimators": 3,
    "subsample": 0.95,
    "objective": "binary:logistic",
}
model = XGBRegressor(**params)

# The ONNX input name must match the feed name passed to sess.run below.
initial_types = [
    ("f1", FloatTensorType([None, 1])),
]
model.fit(df.drop(columns=["label"]), df["label"])

# Convert the fitted model to ONNX.
onnx_model = convert_xgboost(
    model,
    "XGBoostXGBRegressor",
    initial_types,
    target_opset=13,
)
assert onnx_model is not None and hasattr(onnx_model, "SerializeToString")

# Run the converted model with ONNX Runtime on the training data.
sess = rt.InferenceSession(onnx_model.SerializeToString(), providers=["CPUExecutionProvider"])
onnx_output = sess.run(
    None,
    {
        "f1": df["f1"].values.reshape(-1, 1).astype(np.float32),
    },
)[0]

# Compare against XGBoost's own predictions.
expected_output = model.predict(df.drop(columns=["label"])).reshape(-1, 1).astype(np.float32)
assert np.allclose(
    onnx_output, expected_output, rtol=1e-5, atol=1e-8, equal_nan=True
), f"ONNX output does not match expected values for params: {params}"
```
This fails with:
```
AssertionError: ONNX output does not match expected values for params: {'max_depth': 1, 'n_estimators': 3, 'subsample': 0.95, 'objective': 'binary:logistic'}
E assert False
E + where False = <function allclose at 0x106fe80f0>(array([[0.69600594],\n [0.69600594],\n [0.69600594],\n [0.69600594],\n [0.69600594],\n [0.69600594],\n [0.69600594],\n [0.69600594]], dtype=float32), array([[0.64148873],\n [0.64148873],\n [0.64148873],\n [0.64148873],\n [0.64148873],\n [0.64148873],\n [0.64148873],\n [0.64148873]], dtype=float32), rtol=1e-05, atol=1e-08, equal_nan=True)
```
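To get a sense of how large the mismatch in the failure output is, here is a small sketch that uses only the two values copied from the assertion message (every row is identical). It reports the gap in probability space and, assuming both outputs are the sigmoid of a raw margin as with `binary:logistic`, in log-odds (margin) space. This is a diagnostic aid only and says nothing about the root cause.

```python
import numpy as np

# Values copied from the AssertionError above; all eight rows are identical.
onnx_prob = np.float32(0.69600594)
xgb_prob = np.float32(0.64148873)


def logit(p):
    """Inverse of the sigmoid; assumes p is a binary:logistic probability."""
    p = np.float64(p)
    return float(np.log(p / (1.0 - p)))


# Gap in probability space.
abs_gap = abs(float(onnx_prob) - float(xgb_prob))
# Gap in raw-margin (log-odds) space.
margin_gap = logit(onnx_prob) - logit(xgb_prob)

print(f"probability gap: {abs_gap:.6f}")
print(f"margin (log-odds) gap: {margin_gap:.6f}")
# Same tolerances as the reproduction: the gap is far outside them.
print("allclose:", np.allclose(onnx_prob, xgb_prob, rtol=1e-5, atol=1e-8))
```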
I noticed that removing the `subsample` parameter makes the outputs match, but I'm not sure what that implies.
Versions:
```
xgboost==3.0.2 ; python_version >= "3.11" and python_version < "3.14"
onnx==1.18.0 ; python_version >= "3.11" and python_version < "3.14"
onnxmltools==1.14.0 ; python_version >= "3.11" and python_version < "3.14"
onnxruntime==1.22.0 ; python_version >= "3.11" and python_version < "3.14"
```
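The version list above can also be gathered programmatically with the standard library's `importlib.metadata`, which helps when re-checking the issue in another environment. A minimal sketch; the helper name `collect_versions` is hypothetical and not part of any of the packages involved.

```python
import platform
from importlib.metadata import PackageNotFoundError, version


def collect_versions(packages):
    """Hypothetical helper: map each package name to its installed version, or None."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            # Package is not installed in the current environment.
            found[name] = None
    return found


if __name__ == "__main__":
    print(f"python=={platform.python_version()}")
    pkgs = ["xgboost", "onnx", "onnxmltools", "onnxruntime"]
    for name, ver in collect_versions(pkgs).items():
        print(f"{name}=={ver}" if ver else f"{name}: not installed")
```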