Skip to content

Commit 4547bd2

Browse files
committed
[Rewriter]: introduce remove_optional_bias
Removes the optional bias input from Conv, ConvTranspose, Gemm and QLinearConv operations when the bias is all zeros.
1 parent 75b3d42 commit 4547bd2

File tree

4 files changed

+372
-0
lines changed

4 files changed

+372
-0
lines changed

onnxscript/rewriter/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
_min_max_to_clip,
4242
_no_op,
4343
_redundant_scatter_nd,
44+
_remove_optional_bias,
4445
)
4546

4647
_ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model)
@@ -55,6 +56,7 @@
5556
*_redundant_scatter_nd.rules,
5657
*_fuse_pad_into_conv.rules,
5758
*_fuse_batchnorm.rules,
59+
*_remove_optional_bias.rules,
5860
)
5961

6062

onnxscript/rewriter/rules/common/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@
3434
"normalize_pad_format_conv_integer_rule",
3535
"normalize_pad_format_conv_rule",
3636
"one_reshape_matmul_reshape_rule",
37+
"remove_optional_bias_from_conv_rule",
38+
"remove_optional_bias_from_conv_transpose_rule",
39+
"remove_optional_bias_from_gemm_rule",
40+
"remove_optional_bias_from_qlinear_conv_rule",
3741
"reshape_reshape_rule",
3842
"slice_split_rule",
3943
"squeeze_reshape_1d_rule",
@@ -121,3 +125,9 @@
121125
no_op_dynamic_scatter_nd_rule,
122126
no_op_static_scatter_nd_rule,
123127
)
128+
from onnxscript.rewriter.rules.common._remove_optional_bias import (
129+
remove_optional_bias_from_conv_rule,
130+
remove_optional_bias_from_conv_transpose_rule,
131+
remove_optional_bias_from_gemm_rule,
132+
remove_optional_bias_from_qlinear_conv_rule,
133+
)
onnxscript/rewriter/rules/common/_remove_optional_bias.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Remove optional bias when it is all zero from Conv, ConvTranspose, Gemm and QLinearConv operations."""
4+
5+
from __future__ import annotations
6+
7+
from typing import ClassVar
8+
9+
import numpy as np
10+
11+
from onnxscript import ir
12+
from onnxscript.rewriter._basics import MatchResult
13+
from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet
14+
15+
16+
class _RemoveOptionalBias(RewriteRuleClassBase):
    """Shared check/rewrite logic for rules that drop an all-zero bias input.

    Subclasses assign ``op_type`` and supply a ``pattern`` that binds the bias
    to ``b`` (as the last operator input) and names the operator output ``out``.
    """

    # Operator name used when the node is re-emitted; assigned by each subclass.
    op_type: ClassVar[str]

    def rewrite(self, op: ir.tape.Tape, out: ir.Value, **_) -> ir.Value:
        """Re-emit the matched node without its last input.

        Args:
            op: Tape used to record the replacement node.
            out: Output value of the matched node.

        Returns:
            The output of the new node, created with the same operator type and
            attributes but with the trailing (bias) input dropped.
        """
        matched_node = out.producer()
        # The patterns always place the bias last, so slicing off the final
        # input removes exactly the bias.
        inputs_without_bias = matched_node.inputs[:-1]
        return op.op(
            self.op_type,
            inputs=inputs_without_bias,
            attributes=matched_node.attributes,
        )

    def check(self, context, b: ir.Value, **_) -> MatchResult:
        """Decide whether the matched bias can be removed.

        The rewrite is allowed only when the bias is a constant value (i.e.,
        provided by a Constant node or an initializer) and every element is zero.

        Args:
            context: Match context (unused).
            b: The value matched as the bias input.

        Returns:
            MatchResult:
                Success if the pattern should be replaced, Failure (with the
                reason attached) otherwise.
        """
        del context  # Unused
        result = MatchResult()

        # The bias must be statically known to be inspected.
        constant_bias = ir.convenience.get_const_tensor(b)
        if constant_bias is None:
            return result.fail("Bias is not a constant/initializer.")

        # Any non-zero element means the bias contributes to the output.
        if np.any(constant_bias.numpy() != 0.0):
            return result.fail("Bias is not all zeros.")

        return result
50+
51+
52+
class RemoveOptionalBiasFromConv(_RemoveOptionalBias):
    """Rule class targeting ``Conv`` nodes that carry an explicit bias input."""

    op_type: ClassVar[str] = "Conv"

    def pattern(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value:
        """Match ``Conv(x, w, b)`` and name its output ``out``."""
        conv_with_bias = op.Conv(x, w, b, _outputs=["out"])
        return conv_with_bias
59+
60+
61+
class RemoveOptionalBiasFromConvTranspose(_RemoveOptionalBias):
    """Rule class targeting ``ConvTranspose`` nodes that carry an explicit bias input."""

    op_type: ClassVar[str] = "ConvTranspose"

    def pattern(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value:
        """Match ``ConvTranspose(x, w, b)`` and name its output ``out``."""
        conv_transpose_with_bias = op.ConvTranspose(x, w, b, _outputs=["out"])
        return conv_transpose_with_bias
68+
69+
70+
class RemoveOptionalBiasFromQLinearConv(_RemoveOptionalBias):
    """Rule class targeting ``QLinearConv`` nodes that carry an explicit bias input."""

    op_type: ClassVar[str] = "QLinearConv"

    def pattern(
        self,
        op: ir.tape.Tape,
        x,
        x_scale,
        x_zero_point,
        w,
        w_scale,
        w_zero_point,
        y_scale,
        y_zero_point,
        b: ir.Value,
    ) -> ir.Value:
        """Match a ``QLinearConv`` with all nine inputs and name its output ``out``.

        The bias ``b`` is the ninth and last input.
        """
        # The eight inputs that precede the bias, in operator order.
        leading_inputs = (
            x,
            x_scale,
            x_zero_point,
            w,
            w_scale,
            w_zero_point,
            y_scale,
            y_zero_point,
        )
        return op.QLinearConv(*leading_inputs, b, _outputs=["out"])
100+
101+
102+
class RemoveOptionalBiasFromGemm(_RemoveOptionalBias):
    """Rule class targeting ``Gemm`` nodes that carry an explicit third (bias) input."""

    # NOTE(review): the ONNX spec appears to make Gemm's third input optional
    # only from opset 11 onwards -- confirm this rule is never applied to models
    # that import an older default opset.
    op_type: ClassVar[str] = "Gemm"

    def pattern(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value:
        """Match ``Gemm(x, w, b)`` and name its output ``out``."""
        gemm_with_bias = op.Gemm(x, w, b, _outputs=["out"])
        return gemm_with_bias
109+
110+
111+
# One ready-to-use rule instance per supported operator.
remove_optional_bias_from_conv_rule = RemoveOptionalBiasFromConv().rule()
remove_optional_bias_from_conv_transpose_rule = RemoveOptionalBiasFromConvTranspose().rule()
remove_optional_bias_from_qlinear_conv_rule = RemoveOptionalBiasFromQLinearConv().rule()
remove_optional_bias_from_gemm_rule = RemoveOptionalBiasFromGemm().rule()

# Rule set bundling all four rules so they can be applied together.
rules = RewriteRuleSet(
    [
        remove_optional_bias_from_conv_rule,
        remove_optional_bias_from_conv_transpose_rule,
        remove_optional_bias_from_qlinear_conv_rule,
        remove_optional_bias_from_gemm_rule,
    ]
)
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
import unittest
5+
6+
import numpy as np
7+
import onnx
8+
import onnx_ir as ir
9+
from onnx_ir.passes.common import onnx_checker
10+
11+
from onnxscript.rewriter import MatchingTracer, MatchStatus, RewriteRule, testing
12+
from onnxscript.rewriter.rules.common import _remove_optional_bias
13+
from onnxscript.rewriter.rules.common._remove_optional_bias import (
14+
remove_optional_bias_from_conv_rule,
15+
remove_optional_bias_from_conv_transpose_rule,
16+
remove_optional_bias_from_gemm_rule,
17+
remove_optional_bias_from_qlinear_conv_rule,
18+
)
19+
20+
21+
class _RemoveOptionalBiasTestBase(unittest.TestCase):
    """Common model builders and assertions for the remove-optional-bias rule tests."""

    @property
    def rng(self):
        """Return a new generator seeded with a fixed value.

        A fresh generator is created on every access, so each ``self.rng`` use
        starts from the same random stream.
        """
        return np.random.default_rng(20251016)

    def clone_model(self, model: ir.Model) -> ir.Model:
        """Return a copy of ``model`` obtained through a proto round trip."""
        return ir.from_proto(ir.to_proto(model))

    def _random_input(self, input_shape: tuple, input_dtype) -> np.ndarray:
        """Draw a random array of ``input_shape`` with element type ``input_dtype``.

        Integer dtypes are sampled over their full value range; every other
        dtype is sampled uniformly from [0, 1) and then cast.
        """
        dtype = np.dtype(input_dtype)
        if np.issubdtype(dtype, np.integer):
            # Values in [0, 1) all truncate to 0 when cast to an integer type,
            # which would feed the models an all-zero input. Sample integers
            # directly instead.
            bounds = np.iinfo(dtype)
            return self.rng.integers(
                bounds.min, bounds.max, size=input_shape, dtype=dtype, endpoint=True
            )
        return self.rng.random(input_shape).astype(dtype)

    def _get_test_model(
        self,
        op_type: str,
        input_shape: ir.Shape,
        weight_shape: ir.Shape,
        zero_bias: bool,
        attributes=None,
    ):
        """Build and validate a single-node float model ``Y = op_type(X, W, B)``.

        Args:
            op_type: Operator to instantiate.
            input_shape: Shape of the graph input ``X``.
            weight_shape: Shape of the random weight initializer ``W``.
            zero_bias: When true ``B`` is all zeros, otherwise it is random.
            attributes: Optional attributes forwarded to the node.

        Returns:
            The model, after it has been run through the ONNX checker pass.
        """
        tape = ir.tape.Tape()

        # The bias length is taken from the second weight dimension for
        # ConvTranspose and from the first one for every other operator.
        bias_axis = 1 if op_type == "ConvTranspose" else 0
        bias_shape = weight_shape[bias_axis]
        output_shape = ir.Shape(("?",) * input_shape.rank())

        x = ir.val("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))

        weight = self.rng.uniform(-0.5, 0.5, weight_shape).astype(np.float32)
        w = tape.initializer(ir.tensor(weight, name="W"))

        bias = (
            np.zeros(bias_shape, dtype=np.float32)
            if zero_bias
            else self.rng.uniform(-0.5, 0.5, bias_shape).astype(np.float32)
        )
        b = tape.initializer(ir.tensor(bias, name="B"))

        y = tape.op(
            op_type,
            inputs=[x, w, b],
            attributes=attributes,
            output=ir.val("Y", shape=output_shape, type=ir.TensorType(ir.DataType.FLOAT)),
        )

        # Build the model
        graph = ir.Graph(
            inputs=[x],
            outputs=[y],
            nodes=tape.nodes,
            initializers=tape.initializers,
            opset_imports={"": 20},
            name="test_model",
        )
        ir_model = ir.Model(graph, ir_version=10)
        onnx_checker.CheckerPass(True)(ir_model)
        return ir_model

    def run_test(
        self,
        base_model: ir.Model,
        input_shape: tuple,
        input_dtype=np.float32,
    ):
        """Apply the rule set to a copy of ``base_model`` and verify the rewrite.

        Asserts that exactly one rewrite happened, that the first node lost one
        input, that both models agree numerically on a random input, and that
        the rewritten model passes the full ONNX checker.

        Args:
            base_model: Model expected to contain one removable bias.
            input_shape: Shape of the random input fed to both models.
            input_dtype: Element type of the random input.
        """
        updated_model = self.clone_model(base_model)
        count = _remove_optional_bias.rules.apply_to_model(updated_model)

        # Check rule is applied
        self.assertEqual(count, 1)

        # Check number of inputs is reduced
        original_input_count = len(base_model.graph[0].inputs)
        self.assertEqual(len(updated_model.graph[0].inputs), original_input_count - 1)

        # Prepare inputs
        inputs = (self._random_input(input_shape, input_dtype),)

        # Check inference
        testing.assert_numerically_equal(base_model, updated_model, inputs)

        # Validate serialized model
        output_model_proto = ir.serde.serialize_model(updated_model)
        onnx.checker.check_model(output_model_proto, full_check=True)

    def run_failed_condition_test(
        self,
        base_model: ir.Model,
        rewrite_rule: RewriteRule,
        expected_message: str,
    ):
        """Verify that ``rewrite_rule`` matches but is rejected by its condition.

        Args:
            base_model: Model on which the rule must not fire.
            rewrite_rule: The single rule under test.
            expected_message: Regular expression the failure reason must match.
        """
        onnx_checker.CheckerPass(True)(base_model)

        updated_model = self.clone_model(base_model)
        tracer = MatchingTracer()
        count = rewrite_rule.apply_to_model(updated_model, tracer=tracer)

        # Check that the model is unchanged
        self.assertEqual(count, 0)

        # Check that the error message is the expected one
        tracer_match = tracer.best_matches_map[rewrite_rule][0]
        self.assertEqual(tracer_match.status.value, MatchStatus.CONDITION_FAILED)
        self.assertRegex(tracer_match.match_result.reason, expected_message)
121+
122+
123+
class RemoveOptionalBiasGemmTest(_RemoveOptionalBiasTestBase):
    """Gemm cases for the remove-optional-bias rule."""

    def test_successful_remove_optional_bias_gemm(self):
        """A Gemm with an all-zero bias goes through the success checks of ``run_test``."""
        x_shape = (512, 256)
        model = self._get_test_model(
            op_type="Gemm",
            input_shape=ir.Shape(x_shape),
            weight_shape=ir.Shape((64, 256)),
            zero_bias=True,
            attributes={"transB": 1},
        )
        self.run_test(model, x_shape)

    def test_fail_remove_optional_bias_gemm(self):
        """A Gemm with a random bias must be rejected by the rule's condition."""
        x_shape = (512, 256)
        model = self._get_test_model(
            op_type="Gemm",
            input_shape=ir.Shape(x_shape),
            weight_shape=ir.Shape((64, 256)),
            zero_bias=False,
            attributes={"transB": 1},
        )
        expected_reason = "Bias is not all zeros."
        self.run_failed_condition_test(
            model, remove_optional_bias_from_gemm_rule, expected_reason
        )
147+
148+
149+
# NOTE: "Gonv" looks like a typo for "Conv"; the class name is kept unchanged so
# that existing test selectors keep working.
class RemoveOptionalBiasGonvTest(_RemoveOptionalBiasTestBase):
    """Conv cases for the remove-optional-bias rule."""

    def test_successful_remove_optional_bias_conv(self):
        """A strided Conv with an all-zero bias goes through the success checks of ``run_test``."""
        x_shape = (1, 3, 32, 32)
        model = self._get_test_model(
            op_type="Conv",
            input_shape=ir.Shape(x_shape),
            weight_shape=ir.Shape((16, 3, 3, 3)),
            zero_bias=True,
            attributes={"strides": (2, 2)},
        )
        self.run_test(model, x_shape)

    def test_fail_remove_optional_bias_conv(self):
        """A Conv with a random bias must be rejected by the rule's condition."""
        x_shape = (1, 3, 32, 32)
        model = self._get_test_model(
            op_type="Conv",
            input_shape=ir.Shape(x_shape),
            weight_shape=ir.Shape((16, 3, 3, 3)),
            zero_bias=False,
        )
        expected_reason = "Bias is not all zeros."
        self.run_failed_condition_test(
            model, remove_optional_bias_from_conv_rule, expected_reason
        )
172+
173+
174+
# NOTE: "Gonv" looks like a typo for "Conv"; the class name is kept unchanged so
# that existing test selectors keep working.
class RemoveOptionalBiasGonvTransposeTest(_RemoveOptionalBiasTestBase):
    """ConvTranspose cases for the remove-optional-bias rule."""

    def test_successful_remove_optional_bias_conv_transpose(self):
        """A ConvTranspose with an all-zero bias goes through the success checks of ``run_test``."""
        x_shape = (1, 3, 32, 32)
        model = self._get_test_model(
            op_type="ConvTranspose",
            input_shape=ir.Shape(x_shape),
            weight_shape=ir.Shape((3, 16, 3, 3)),
            zero_bias=True,
        )
        self.run_test(model, x_shape)

    def test_fail_remove_optional_bias_conv_transpose(self):
        """A ConvTranspose with a random bias must be rejected by the rule's condition."""
        x_shape = (1, 3, 32, 32)
        model = self._get_test_model(
            op_type="ConvTranspose",
            input_shape=ir.Shape(x_shape),
            weight_shape=ir.Shape((3, 16, 3, 3)),
            zero_bias=False,
        )
        expected_reason = "Bias is not all zeros."
        self.run_failed_condition_test(
            model, remove_optional_bias_from_conv_transpose_rule, expected_reason
        )
196+
197+
198+
class RemoveOptionalBiasQLinearConvTest(_RemoveOptionalBiasTestBase):
    """QLinearConv cases for the remove-optional-bias rule."""

    def _get_test_model(self, zero_bias):
        """Build and validate a single-node uint8 ``QLinearConv`` model.

        Args:
            zero_bias: When true the int32 bias ``B`` is all zeros, otherwise it
                is filled with random values.

        Returns:
            The model, after it has been run through the ONNX checker pass.
        """
        if zero_bias:
            bias = np.zeros((16,), dtype=np.int32)
        else:
            bias = self.rng.uniform(-5, 5, (16,)).astype(np.int32)

        # Draw the weights directly as uint8. Casting negative floats to an
        # unsigned integer type is out of range, gives platform-dependent
        # values, and can raise NumPy's "invalid value encountered in cast"
        # RuntimeWarning.
        weight = self.rng.integers(0, 256, size=(16, 3, 3, 3), dtype=np.uint8)

        w = ir.tensor(weight, name="W")
        b = ir.tensor(bias, name="B")

        model = ir.from_onnx_text(
            """
            < ir_version: 10, opset_import: ["" : 20] >
            test_model (uint8[N, 3, 32, 32] X) => (uint8 [N, ?, ?, ?] Y)
            <uint8[16, 3, 3, 3] W, int32[16] B, float x_scale = {1.5}, uint8 x_zero_point = {123},
            float w_scale = {1.5}, uint8 w_zero_point = {123},
            float y_scale = {1.5}, uint8 y_zero_point = {123}>
            {
                Y = QLinearConv(X, x_scale, x_zero_point, W, w_scale, w_zero_point, y_scale, y_zero_point, B)
            }
            """,
            initializers=[w, b],
        )
        onnx_checker.CheckerPass(True)(model)
        return model

    def test_successful_remove_optional_bias_qlinear_conv(self):
        """A QLinearConv with an all-zero bias goes through the success checks of ``run_test``."""
        input_shape = (1, 3, 32, 32)
        base_model = self._get_test_model(zero_bias=True)
        self.run_test(base_model, input_shape, np.uint8)

    def test_fail_remove_optional_bias_qlinear_conv(self):
        """A QLinearConv with a random bias must be rejected by the rule's condition."""
        base_model = self._get_test_model(zero_bias=False)
        self.run_failed_condition_test(
            base_model, remove_optional_bias_from_qlinear_conv_rule, "Bias is not all zeros."
        )
234+
235+
236+
# Run the unittest runner when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)