Skip to content

Commit 1ca3207

Browse files
rolshovenLuca Rolshoven
authored andcommitted
Fixed order of input parameters for onnx export
1 parent 4ebee43 commit 1ca3207

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

src/setfit/exporters/onnx.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,12 @@ def export_onnx_setfit_model(setfit_model: OnnxSetFitModel, inputs, output_path,
8787
for output_name in output_names:
8888
dynamic_axes_output[output_name] = {0: "batch_size"}
8989

90-
# Move inputs to the right device
90+
# Move inputs to the right device and put them in the right order
91+
forward_params = tuple(signature(setfit_model.model_body.forward).parameters.keys()) # keys of ordered dict are ordered
92+
ordered_kwargs = sorted(inputs.items(), key=lambda param: forward_params.index(param[0]))
93+
odered_params = [param_value for (_, param_value) in ordered_kwargs]
9194
target = setfit_model.model_body.device
92-
args = tuple(value.to(target) for value in inputs.values())
95+
args = tuple(value.to(target) for value in odered_params)
9396

9497
setfit_model.eval()
9598
with torch.no_grad():

0 commit comments

Comments
 (0)