-
Notifications
You must be signed in to change notification settings - Fork 203
Open
Description
Predictions from a LightGBM model change after converting it to ONNX, saving it, and loading it again. The script below compares this with an XGBoost model put through the same steps, whose predictions are unchanged after loading from ONNX.
Source:
import os
import shutil
from datetime import datetime, timedelta
import numpy as np
import pandas as pd
import plotly
import lightgbm as lgb
import xgboost as xgb
import onnxmltools
import onnxruntime
from skl2onnx.common.data_types import FloatTensorType
# Use plotly as the pandas plotting backend for the .plot() calls in this script.
pd.options.plotting.backend = "plotly"
# Synthetic dataset: 30 days of 5-minute timestamps ending at the current time.
end_date = datetime.now()
start_date = end_date - timedelta(days=30)
date_range = pd.date_range(start=start_date, end=end_date, freq='5min')
df_timestamps = pd.DataFrame(index=date_range)
N = len(df_timestamps)

# Target: 0 everywhere except Monday-Friday (dayofweek <= 4), where the hours
# 08, 12 and 14 are set to 1, 2 and 3 respectively.
used = pd.Series(0, index=date_range)
is_weekday = date_range.dayofweek <= 4
for start_hour, level in ((8, 1), (12, 2), (14, 3)):
    used[is_weekday & (date_range.hour == start_hour)] = level
y = pd.DataFrame({'y': used}, index=date_range)

# Features: sine/cosine encodings of day-of-week (period 7) and hour-of-day
# (period 24).
day_angle = 2 * np.pi * date_range.dayofweek / 7
hour_angle = 2 * np.pi * date_range.hour / 24
X = pd.DataFrame(
    {
        'sin_day_of_week': np.sin(day_angle),
        'cos_day_of_week': np.cos(day_angle),
        'sin_hour_of_day': np.sin(hour_angle),
        'cos_hour_of_day': np.cos(hour_angle),
    },
    index=date_range,
)
# Replace the descriptive column names with positional ones: f0, f1, ...
X.columns = [f'f{position}' for position in range(len(X.columns))]
# Plot the target series and display the resulting figure.
fig = y.plot()
fig.show()
# Train one quantile regressor per library on the same data. Both target the
# 0.95 quantile and share the ensemble size and depth limit.
shared_params = {
    'n_estimators': 500,  # Number of boosting iterations
    'max_depth': 10,  # Maximum tree depth
}
lgb_model = lgb.LGBMRegressor(
    objective='quantile',  # Quantile loss
    alpha=.95,  # Target quantile for the loss
    **shared_params,
)
xgb_model = xgb.XGBRegressor(
    objective='reg:quantileerror',  # Quantile loss
    quantile_alpha=.95,  # Target quantile for the loss
    **shared_params,
)
for estimator in (lgb_model, xgb_model):
    estimator.fit(X, y)
# Both converters get the same input signature: a single float tensor named
# 'float_input' with an unspecified row count (None) and one column per feature.
initial_type = [('float_input', FloatTensorType([None, X.shape[1]]))]
onnx_model_lgmb = onnxmltools.convert_lightgbm(lgb_model, initial_types=initial_type)
onnx_model_xgboost = onnxmltools.convert_xgboost(xgb_model, initial_types=initial_type)

lgmb_path = "tmp/lgbm/"
xgboost_path = "tmp/xgboost/"

# Start from an empty output directory for each model.
for output_dir in (lgmb_path, xgboost_path):
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir, exist_ok=True)

# Write each converted model to <output_dir>model.onnx.
for onnx_model, output_dir in (
    (onnx_model_lgmb, lgmb_path),
    (onnx_model_xgboost, xgboost_path),
):
    onnxmltools.utils.save_model(onnx_model, output_dir + "model.onnx")
# Reference predictions taken straight from the fitted estimators, i.e. without
# going through the saved ONNX files.
lgb_predictions, xgb_predictions = (
    fitted.predict(X) for fitted in (lgb_model, xgb_model)
)

# Put the target and both prediction vectors side by side and plot them.
before_columns = {
    'actual': y['y'],
    'lbgm predictions': lgb_predictions,
    'xgb predictions': xgb_predictions,
}
df = pd.DataFrame(before_columns, index=X.index)
fig = df.plot(title="Before Saving")
fig.show()
# Load the saved ONNX files back with ONNX Runtime.
lgbm_sess = onnxruntime.InferenceSession(lgmb_path + "model.onnx")
xgb_sess = onnxruntime.InferenceSession(xgboost_path + "model.onnx")

# Both sessions are fed the same features, cast to float32 and bound to the
# 'float_input' name chosen at conversion time.
onnx_feed = {'float_input': X.to_numpy().astype(np.float32)}

# Request the output named 'variable' from each session and flatten it into a
# Series aligned with the feature index.
lgbm_output = lgbm_sess.run(output_names=['variable'], input_feed=onnx_feed)[0]
loaded_lgb_predictions = pd.Series(lgbm_output.ravel(), index=X.index)
xgb_output = xgb_sess.run(output_names=['variable'], input_feed=onnx_feed)[0]
loaded_xgb_predictions = pd.Series(xgb_output.ravel(), index=X.index)

# Same side-by-side plot as for the in-memory models, now using ONNX outputs.
after_columns = {
    'actual': y['y'],
    'lbgm predictions': loaded_lgb_predictions,
    'xgb predictions': loaded_xgb_predictions,
}
df = pd.DataFrame(after_columns, index=X.index)
fig = df.plot(title="After Saving")
fig.show()
Metadata
Metadata
Assignees
Labels
No labels