6 changes: 6 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -8776,6 +8776,12 @@ def aten_type_as(self: TTensor, other: TTensor2) -> TTensor2:
def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]:
    """unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]"""

+    if isinstance(self.shape[dim], int) and not version_utils.torch_older_than("2.7"):
+        # We can create a definitive Split op if the input shape is static
+        # Only torch>=2.7 correctly generates the expected number of outputs for Split
+        outputs = op.Split(self, axis=dim, num_outputs=self.shape[dim])
+        return [op.Squeeze(out, [dim]) for out in outputs]
Contributor:
I feel like it makes more sense to rewrite it? Although this PR probably works, it's adding another dimension to torchlib (covering both static and dynamic cases). Maybe let torchlib be as dynamic as possible, and we can optimize it afterwards.

Collaborator Author:

How easy is it to create the optimization rules? I am fine either way.


Looks like SplitToSequence(self, axis=dim, keepdims=False) should be generically rewritable to this subgraph when the size along the split axis is known. This could also cover more cases with other ops as they are encountered in the future.
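For illustration, here is a minimal sketch of the replacement subgraph such a rewrite could emit, built with onnx.helper directly rather than the onnxscript rewriter API; the helper function and tensor names below are made up for the example, and it assumes the size n along the split axis is statically known:

```python
import onnx
from onnx import helper


def split_to_sequence_replacement(input_name: str, dim: int, n: int) -> list[onnx.NodeProto]:
    """Nodes equivalent to SplitToSequence(x, axis=dim, keepdims=False) when the size along dim is n."""
    split_outputs = [f"{input_name}_split_{i}" for i in range(n)]
    axes_name = f"{input_name}_squeeze_axes"
    nodes = [
        # Split the input into n equal chunks along `dim`.
        helper.make_node("Split", [input_name], split_outputs, axis=dim, num_outputs=n),
        # Since opset 13, Squeeze takes its axes as a tensor input.
        helper.make_node(
            "Constant",
            [],
            [axes_name],
            value=helper.make_tensor(axes_name, onnx.TensorProto.INT64, [1], [dim]),
        ),
    ]
    # Drop the now size-1 split dimension from every chunk.
    nodes.extend(
        helper.make_node("Squeeze", [out, axes_name], [f"{input_name}_unbind_{i}"])
        for i, out in enumerate(split_outputs)
    )
    return nodes
```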

Collaborator Author:

@gramalingam for suggestions on the rewrite rule. Is this related to #2581?

Member:

From a computational point of view, it is always better to generate the correct graph directly rather than to produce a graph that then needs to be rewritten. Matching a pattern takes time.

Contributor (@titaiwangms, Oct 6, 2025):

But my understanding is that the original implementation is not incorrect. It's only that the op is not preferred because of the backend implementation, which fits the category of rewrite rules. We can surely say it's more convenient to address it this way (this PR), but I would prefer an established/explicit rule that says when/what we should support in torchlib, and under what conditions we add rewrite rules/constant folding. Otherwise, it's just scattered around. And if we want it done this way, do we also consider upstreaming other optimizations, currently done downstream, that are enabled by static shapes?

Collaborator Author:

I kinda think we should do both. (1) We can measure the complexity of the torchlib implementation. I feel that the complexity of the current implementation is not high for the immediate benefits it brings. If the graph can be significantly simplified because we know some shapes are static, I think we should pursue that in torchlib. When looking at micro-optimizations like this, I agree that we should be more careful and decide on a case-by-case basis. (2) We should still have a rule that simplifies this so that our tooling can handle SplitToSequence generally.


    return op.SplitToSequence(self, axis=dim, keepdims=False)
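As a quick sanity check of the two paths above (a plain PyTorch sketch, not the ONNX ops themselves): splitting into size-1 chunks along dim and squeezing that dimension reproduces torch.unbind, which is the behavior the static-shape Split/Squeeze lowering relies on.

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)
dim = 1

# Static-shape path: Split into shape[dim] chunks of size 1, then Squeeze `dim`.
split_then_squeeze = [chunk.squeeze(dim) for chunk in torch.split(x, 1, dim=dim)]

# Reference semantics of aten::unbind.
unbound = torch.unbind(x, dim=dim)

assert len(split_then_squeeze) == x.shape[dim]
assert all(torch.equal(a, b) for a, b in zip(split_then_squeeze, unbound))
```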


3 changes: 2 additions & 1 deletion tests/function_libs/torch_lib/ops_test.py
@@ -39,6 +39,7 @@
from torch.utils import _pytree as pytree

import onnxscript
+from onnxscript._internal import version_utils
from tests.function_libs.torch_lib import (
    error_reproduction,
    ops_test_common,
@@ -200,7 +201,7 @@ def run_test_output_match(
        reference_torch_outputs, _ = pytree.tree_flatten(torch_output)
        if (
            op.name.startswith("split")
-            or op.name.startswith("unbind")
+            or (op.name.startswith("unbind") and version_utils.torch_older_than("2.7"))
            or op.name
            in {"atleast_1d_Sequence", "atleast_2d_Sequence", "atleast_3d_Sequence"}
        ):
1 change: 1 addition & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
@@ -1478,6 +1478,7 @@ def _where_input_wrangler(
reason="fixme: SplitToSequence op inference failed. https://github.com/microsoft/onnxruntime/issues/16006",
)
.xfail(
enabled_if=version_utils.torch_older_than("2.7"),
dtypes=(torch.bool,),
reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905",
),