diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 11f26b814..dfcd5c35b 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8776,6 +8776,12 @@ def aten_type_as(self: TTensor, other: TTensor2) -> TTensor2: def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]: """unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]""" + if isinstance(self.shape[dim], int) and not version_utils.torch_older_than("2.7"): + # We can create a definitive split op if the input shape is static + # Only torch>=2.7 supports correctly generating the correct number of outputs for Split + outputs = op.Split(self, axis=dim, num_outputs=self.shape[dim]) + return [op.Squeeze(out, [dim]) for out in outputs] + return op.SplitToSequence(self, axis=dim, keepdims=False) diff --git a/tests/function_libs/torch_lib/ops_test.py b/tests/function_libs/torch_lib/ops_test.py index 7ba6f9d37..45875043e 100644 --- a/tests/function_libs/torch_lib/ops_test.py +++ b/tests/function_libs/torch_lib/ops_test.py @@ -39,6 +39,7 @@ from torch.utils import _pytree as pytree import onnxscript +from onnxscript._internal import version_utils from tests.function_libs.torch_lib import ( error_reproduction, ops_test_common, @@ -200,7 +201,7 @@ def run_test_output_match( reference_torch_outputs, _ = pytree.tree_flatten(torch_output) if ( op.name.startswith("split") - or op.name.startswith("unbind") + or (op.name.startswith("unbind") and version_utils.torch_older_than("2.7")) or op.name in {"atleast_1d_Sequence", "atleast_2d_Sequence", "atleast_3d_Sequence"} ): diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index ff4a68d2f..1b998b1d2 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1478,6 +1478,7 @@ def _where_input_wrangler( reason="fixme: SplitToSequence op inference failed. https://github.com/microsoft/onnxruntime/issues/16006", ) .xfail( + enabled_if=version_utils.torch_older_than("2.7"), dtypes=(torch.bool,), reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", ),