Skip to content

Commit 9ac784e

Browse files
authored
Merge input and output shape when removing identity
1 parent 3a26097 commit 9ac784e

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

onnxscript/optimizer/_constant_folding.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,7 @@ def identity(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
608608
input = node.inputs[0]
609609
output = node.outputs[0]
610610
if input is not None and output is not None:
611+
input.shape = _merge_shapes(input.shape, output.shape)
611612
state.set_sym_value(output, input)
612613
return None
613614

0 commit comments

Comments
 (0)