Skip to content

Commit 075fc4d

Browse files
justinchubyCopilot
andauthored
Simplify aten_unbind when shape is static (#2597)
Add static shape handling to aten_unbind function. Fix #2596 --------- Signed-off-by: Justin Chu <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent 4eaf36d commit 075fc4d

File tree

3 files changed

+9
-1
lines changed

3 files changed

+9
-1
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8799,6 +8799,12 @@ def aten_type_as(self: TTensor, other: TTensor2) -> TTensor2:
87998799
def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]:
88008800
"""unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]"""
88018801

8802+
if isinstance(self.shape[dim], int) and not version_utils.torch_older_than("2.7"):
8803+
# We can create a definitive split op if the input shape is static
8804+
# Only torch>=2.7 supports correctly generating the correct number of outputs for Split
8805+
outputs = op.Split(self, axis=dim, num_outputs=self.shape[dim])
8806+
return [op.Squeeze(out, [dim]) for out in outputs]
8807+
88028808
return op.SplitToSequence(self, axis=dim, keepdims=False)
88038809

88048810

tests/function_libs/torch_lib/ops_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from torch.utils import _pytree as pytree
4040

4141
import onnxscript
42+
from onnxscript._internal import version_utils
4243
from tests.function_libs.torch_lib import (
4344
error_reproduction,
4445
ops_test_common,
@@ -200,7 +201,7 @@ def run_test_output_match(
200201
reference_torch_outputs, _ = pytree.tree_flatten(torch_output)
201202
if (
202203
op.name.startswith("split")
203-
or op.name.startswith("unbind")
204+
or (op.name.startswith("unbind") and version_utils.torch_older_than("2.7"))
204205
or op.name
205206
in {"atleast_1d_Sequence", "atleast_2d_Sequence", "atleast_3d_Sequence"}
206207
):

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1481,6 +1481,7 @@ def _where_input_wrangler(
14811481
reason="fixme: SplitToSequence op inference failed. https://github.com/microsoft/onnxruntime/issues/16006",
14821482
)
14831483
.xfail(
1484+
enabled_if=version_utils.torch_older_than("2.7"),
14841485
dtypes=(torch.bool,),
14851486
reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905",
14861487
),

0 commit comments

Comments
 (0)