Skip to content

Commit 7407431

Browse files
Add reproduction test case for incorrect slice rewrite and add potential fix (#2478)
This adds a reproduction of the rule introduced in f42c2bb leading to an incorrect rewrite of the graph. The original rule does not consider the step parameter, which can influence the result of a `Slice` to be the identity even when input and output shape are equivalent. The potential fix seems to be to not apply the rule on `step != 1`, therefore the second commit adds this to the original rule implementation. --------- Co-authored-by: Justin Chu <[email protected]>
1 parent 41c5dd5 commit 7407431

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

onnxscript/rewriter/collapse_slices.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import logging
66

77
from onnxscript import ir
8+
from onnxscript.rewriter._ir_utils import is_singleton_value
89
from onnxscript.rewriter._rewrite_rule import RewriteRule, RewriteRuleSet
910

1011
logger = logging.getLogger(__name__)
@@ -76,10 +77,14 @@ def _potential_redundant_slice(op, data, starts, ends, axes, steps):
7677
return op.Slice(data, starts, ends, axes, steps, _outputs=["slice_output"])
7778

7879

79-
def _same_shape(op, data: ir.Value, slice_output: ir.Value, **_):
80+
def _same_shape(op, data: ir.Value, slice_output: ir.Value, steps: ir.Value, **_):
8081
"""Check if the shape of the slice output is the same as the data."""
8182
if data.shape is None or slice_output.shape is None:
8283
return False
84+
85+
if not is_singleton_value(steps, 1):
86+
return False
87+
8388
return data.shape == slice_output.shape
8489

8590

onnxscript/rewriter/collapse_slices_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,22 @@ def test_slice_equal_dynamic_shape(self):
100100
model = ir.serde.deserialize_model(model_proto)
101101
count = collapse_slices.rules.apply_to_model(model)
102102
self.assertEqual(count, 1)
103+
104+
def test_slice_equal_dynamic_shape_but_step_reverse(self):
105+
model_proto = onnx.parser.parse_model(
106+
f"""
107+
<ir_version: 7, opset_import: [ "" : 17]>
108+
agraph (float[L, M, N] data) => (float[L, M, N] output)
109+
{{
110+
starts = Constant<value: tensor = int64[1] {{0}}>()
111+
ends = Constant<value: tensor = int64[1] {{{9}}}>()
112+
axes = Constant<value: tensor = int64[1] {{0}}>()
113+
steps = Constant<value: tensor = int64[1] {{-1}}>()
114+
output = Slice (data, starts, ends, axes, steps)
115+
}}
116+
"""
117+
)
118+
model = ir.serde.deserialize_model(model_proto)
119+
count = collapse_slices.rules.apply_to_model(model)
120+
# Should not change the output shape if we did not use the default step of 1
121+
self.assertEqual(count, 0)

0 commit comments

Comments
 (0)