Conversation

@justinchuby justinchuby (Collaborator) commented Oct 2, 2025

Add static shape handling to aten_unbind function.

Fix #2596

@justinchuby justinchuby added the module: torchlib Related to the torch/aten function lib in development label Oct 2, 2025
@justinchuby justinchuby requested a review from Copilot October 2, 2025 15:56
@Copilot Copilot AI (Contributor) left a comment


Pull Request Overview

This PR adds static shape handling to the aten_unbind function so that a simpler, more efficient graph is produced when the size of the split dimension is known at export time.

  • Adds a static shape optimization path that uses Split + Squeeze operations instead of SplitToSequence
  • Maintains backward compatibility by falling back to the original dynamic implementation (see the sketch below)
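
For orientation, here is a minimal sketch of the combined dispatch described above, reconstructed from the snippet under review; the function signature, the `op` opset handle, and the exact dynamic fallback are assumptions rather than verbatim torchlib code:

```python
from onnxscript import opset18 as op  # assumed opset handle; torchlib's actual import may differ


def aten_unbind(self, dim: int = 0):
    if isinstance(self.shape[dim], int):
        # Static path: the size along `dim` is known, so emit a Split with a
        # fixed number of outputs and squeeze `dim` away from each piece.
        outputs = op.Split(self, axis=dim, num_outputs=self.shape[dim])
        return [op.Squeeze(out, [dim]) for out in outputs]
    # Dynamic fallback: the original sequence-based lowering.
    return op.SplitToSequence(self, axis=dim, keepdims=False)
```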

codecov bot commented Oct 2, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 69.95%. Comparing base (897345d) to head (99d0441).
⚠️ Report is 10 commits behind head on main.
✅ All tests successful. No failed tests found.

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #2597   +/-   ##
=======================================
  Coverage   69.95%   69.95%           
=======================================
  Files         222      222           
  Lines       26311    26314    +3     
  Branches     2604     2605    +1     
=======================================
+ Hits        18406    18409    +3     
  Misses       6993     6993           
  Partials      912      912           


if isinstance(self.shape[dim], int):
    # We can create a definitive split op if the input shape is static
    outputs = op.Split(self, axis=dim, num_outputs=self.shape[dim])
    return [op.Squeeze(out, [dim]) for out in outputs]
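
As context (not part of the PR), a small NumPy check illustrating why Split followed by Squeeze reproduces unbind semantics when the size along the split axis is static; the array, dim, and variable names are illustrative only:

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)
dim = 1

# "Split" into shape[dim] pieces along dim, then "Squeeze" that axis from each
# piece, mirroring the Split + Squeeze lowering above.
pieces = np.split(x, x.shape[dim], axis=dim)
static_unbind = [np.squeeze(p, axis=dim) for p in pieces]

# Reference semantics of unbind: one slice per index along dim, with dim removed.
reference = [x[:, i, :] for i in range(x.shape[dim])]

assert all(np.array_equal(a, b) for a, b in zip(static_unbind, reference))
```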
Contributor

I feel like it makes more sense to rewrite it? Although this PR probably works, it adds another dimension of complexity to torchlib (covering both the static and dynamic cases). Maybe let torchlib stay as dynamic as possible, and we can optimize afterwards.

Collaborator Author

How easy is it to create the optimization rules? I am fine either way


Looks like SplitToSequence(self, axis=dim, keepdims=False) should be generically rewritable to this subgraph whenever the size along the split axis is known. Such a rule could also cover more cases with other ops encountered in the future (see the sketch below).
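
To make that concrete, here is a hedged sketch, built with onnx.helper, of the replacement subgraph such a rewrite could emit; the function name, prefix scheme, and node naming are hypothetical, and this is not the onnxscript rewriter API:

```python
import onnx
from onnx import helper


def build_static_unbind_nodes(input_name: str, dim: int, n: int, prefix: str):
    """Build Split + Squeeze nodes that could replace
    SplitToSequence(axis=dim, keepdims=0) when n, the size along dim, is static."""
    axes_name = f"{prefix}_axes"
    split_outputs = [f"{prefix}_split_{i}" for i in range(n)]
    nodes = [
        # Axes constant shared by every Squeeze (Squeeze takes axes as an input since opset 13).
        helper.make_node(
            "Constant", [], [axes_name],
            value=helper.make_tensor(axes_name, onnx.TensorProto.INT64, [1], [dim]),
        ),
        # Fixed-arity Split along the known axis (num_outputs requires opset 18).
        helper.make_node("Split", [input_name], split_outputs, axis=dim, num_outputs=n),
    ]
    squeezed = []
    for i, split_out in enumerate(split_outputs):
        out_name = f"{prefix}_out_{i}"
        nodes.append(helper.make_node("Squeeze", [split_out, axes_name], [out_name]))
        squeezed.append(out_name)
    return nodes, squeezed
```

To keep the graph type-correct, a real rewrite would either wrap the squeezed outputs in a SequenceConstruct (preserving the sequence-typed value) or also rewrite downstream consumers such as SequenceAt to use the individual tensors directly.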

Collaborator Author

@gramalingam for suggestions on the rewrite rule. Is this related to #2581 ?

Member

From a computational point of view, it is always better to generate the preferred graph directly rather than produce a graph that then needs to be rewritten. Matching a pattern takes time.

@titaiwangms titaiwangms (Contributor) commented Oct 6, 2025

But my understanding is that the original implementation is not incorrect. It's only that the op is not preferred because of the backend implementation, which fits the category of rewrite rules. We can surely say it's more convenient to address it this way (this PR), but I would prefer an established/explicit rule for when/what we should support in torchlib and under what conditions we add rewrite rules/constant folding. Otherwise, it's just scattered around. And if we do want it done this way, should we also consider upstreaming other optimizations that currently happen downstream because of static shapes?

Collaborator Author

I kinda think we should do both. (1) We can weigh the complexity of the torchlib implementation; the complexity of the current implementation feels low relative to the immediate benefits it brings. If the graph can be significantly simplified because we know some shapes are static, I think we should pursue that in torchlib. For micro-optimizations like this, I agree we should be more careful and decide on a case-by-case basis. (2) We should still have a rule that simplifies this, so that our tooling can handle SplitToSequence in general.

@justinchuby justinchuby added the merge at lgtm Reviewers can merge when they approve label Oct 6, 2025
@justinchuby justinchuby added this to the 0.5.4 milestone Oct 6, 2025
@justinchuby (Collaborator Author)

I updated tests and improved compatibility with torch<2.7

@titaiwangms titaiwangms removed their request for review October 6, 2025 20:08
@titaiwangms titaiwangms (Contributor) left a comment

Thanks for the clarification.

@justinchuby (Collaborator Author)

Merging for now. Please let me know if further changes are needed.

@justinchuby justinchuby merged commit 075fc4d into main Oct 7, 2025
32 checks passed
@justinchuby justinchuby deleted the justinchu/static-unbind branch October 7, 2025 21:39
@github-project-automation github-project-automation bot moved this from Todo to Done in ONNX Script Review Board Oct 7, 2025
Successfully merging this pull request may close these issues.

aten.unbind exports to Sequence ops