Description
For the following line in the model definition:
x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
onnxscript translates the aten.unbind op into Sequence ops.
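To make the target concrete, here is a small NumPy sketch. NumPy stands in for the PyTorch and ONNX ops, and the function names are hypothetical. It checks that the line above, which pairs up interleaved elements of the last axis, gives the same result as one equal-sized split along the new last axis followed by a squeeze of that axis. That Split + Squeeze form is what a rewrite could emit in place of the Sequence ops.

```python
import numpy as np


def unflatten_unbind(hidden_states):
    # Mirrors hidden_states.unflatten(-1, (-1, 2)).unbind(-1):
    # reshape the last axis into (-1, 2), then take each slice of the new last axis.
    paired = hidden_states.reshape(*hidden_states.shape[:-1], -1, 2)
    return tuple(paired[..., i] for i in range(paired.shape[-1]))


def split_then_squeeze(hidden_states):
    # Static-shape alternative: one split into size-1 chunks along the new
    # last axis (like ONNX Split), then drop that axis on each chunk (like Squeeze).
    paired = hidden_states.reshape(*hidden_states.shape[:-1], -1, 2)
    chunks = np.split(paired, paired.shape[-1], axis=-1)
    return tuple(np.squeeze(chunk, axis=-1) for chunk in chunks)
```

Because the unbound axis has size 2, fixed by the `unflatten` call, the number of outputs is known when the graph is built, so nothing here needs a sequence type.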

When shapes are static, I think it should be possible to translate this to a regular Split op instead. Would it be possible to add a rewrite pattern for this case?
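The matching logic such a rewrite might need can be sketched on a hypothetical toy IR. This is not onnxscript's rewriter API: `Node`, `rewrite_unbind_to_split` and `_try_rewrite` are made-up names for illustration. The sketch assumes the unbind shows up as `SplitToSequence` with no `split` input and `keepdims=0` (per the ONNX spec, that form splits into size-1 chunks and drops the split axis), and that the sequence is read only by `SequenceAt` nodes with constant positions. For brevity, `SequenceAt`'s position and `Squeeze`'s axes are modeled as attributes, although in ONNX both are inputs (axes since opset 13). `Split`'s `num_outputs` attribute requires opset 18.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    # Hypothetical toy IR node: values are referred to by name.
    op_type: str
    inputs: list
    outputs: list
    attrs: dict = field(default_factory=dict)


def _try_rewrite(node, nodes, shapes):
    # Only the "no split input" form, which yields chunks of size 1.
    if node.op_type != "SplitToSequence" or len(node.inputs) != 1:
        return None
    if node.attrs.get("keepdims", 1) != 0:
        return None  # unbind semantics require the split axis to be dropped
    x = node.inputs[0]
    shape = shapes.get(x)
    if not shape:
        return None  # unknown rank
    axis = node.attrs.get("axis", 0) % len(shape)  # normalize negative axes
    n = shape[axis]
    if not isinstance(n, int):
        return None  # dynamic size on the split axis: output count unknown
    seq = node.outputs[0]
    users = [u for u in nodes if seq in u.inputs]
    if not users:
        return None
    for u in users:
        pos = u.attrs.get("position")
        if u.op_type != "SequenceAt" or not isinstance(pos, int) or not -n <= pos < n:
            return None  # sequence is used in a way this sketch cannot rewrite
    chunks = [f"{seq}_chunk{i}" for i in range(n)]
    new_nodes = [Node("Split", [x], chunks, {"axis": axis, "num_outputs": n})]
    for u in users:
        # Each SequenceAt becomes a Squeeze of the matching Split output,
        # keeping the original output name so downstream nodes are unchanged.
        chunk = chunks[u.attrs["position"] % n]
        new_nodes.append(Node("Squeeze", [chunk], u.outputs, {"axes": [axis]}))
    return new_nodes, users


def rewrite_unbind_to_split(nodes, shapes):
    """Replace SplitToSequence + SequenceAt with Split + Squeeze when static.

    nodes: list of Node in topological order.
    shapes: value name -> tuple of dims (int if static, None if dynamic).
    """
    result, folded = [], set()
    for node in nodes:
        if id(node) in folded:
            continue  # SequenceAt already replaced by a Squeeze
        rewritten = _try_rewrite(node, nodes, shapes)
        if rewritten is None:
            result.append(node)
        else:
            new_nodes, users = rewritten
            result.extend(new_nodes)
            folded.update(id(u) for u in users)
    return result
```

For the Wan2.2 line, the input to `SplitToSequence` would have a static last axis of size 2, so the pass emits one two-output `Split` and two `Squeeze` nodes. If that axis were dynamic, the graph would be left as-is.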
The motivating use case comes from the Wan2.2 model definition: https://github.com/huggingface/diffusers/blob/9ae5b6299d3b1a7b0378dc77c3c69baf521587d2/src/diffusers/models/transformers/transformer_wan.py#L109
justinchuby