Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions onnxscript/rewriter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
_min_max_to_clip,
_no_op,
_redundant_scatter_nd,
_remove_optional_bias,
)

_ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model)
Expand All @@ -55,6 +56,7 @@
*_redundant_scatter_nd.rules,
*_fuse_pad_into_conv.rules,
*_fuse_batchnorm.rules,
*_remove_optional_bias.rules,
)


Expand Down
10 changes: 10 additions & 0 deletions onnxscript/rewriter/rules/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
"normalize_pad_format_conv_integer_rule",
"normalize_pad_format_conv_rule",
"one_reshape_matmul_reshape_rule",
"remove_optional_bias_from_conv_rule",
"remove_optional_bias_from_conv_transpose_rule",
"remove_optional_bias_from_gemm_rule",
"remove_optional_bias_from_qlinear_conv_rule",
"reshape_reshape_rule",
"slice_split_rule",
"squeeze_reshape_1d_rule",
Expand Down Expand Up @@ -121,3 +125,9 @@
no_op_dynamic_scatter_nd_rule,
no_op_static_scatter_nd_rule,
)
from onnxscript.rewriter.rules.common._remove_optional_bias import (
remove_optional_bias_from_conv_rule,
remove_optional_bias_from_conv_transpose_rule,
remove_optional_bias_from_gemm_rule,
remove_optional_bias_from_qlinear_conv_rule,
)
123 changes: 123 additions & 0 deletions onnxscript/rewriter/rules/common/_remove_optional_bias.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Remove optional bias when it is all zero from Conv, ConvTranspose, Gemm and QLinearConv operations."""

from __future__ import annotations

from typing import ClassVar

import numpy as np

from onnxscript import ir
from onnxscript.rewriter._basics import MatchResult
from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet


class _RemoveOptionalBias(RewriteRuleClassBase):
    """Shared check/rewrite logic for rules that drop an all-zero bias input.

    Subclasses supply ``op_type`` (the operator name re-emitted by
    :meth:`rewrite`) and a ``pattern`` method that binds the bias value to
    ``b`` and the matched node's output to ``out``.
    """

    def check(self, context, b: ir.Value, **_) -> MatchResult:
        """Decide whether the matched bias may be removed.

        The rewrite is only allowed when the bias is a constant value (i.e.,
        provided by Constant nodes or initializers) whose elements are all zero.

        Args:
            context: Match context supplied by the rewriter; not used here.
            b: The bias value bound by the subclass pattern.

        Returns:
            MatchResult:
                Success if the bias can be dropped, otherwise a failure that
                carries the reason.
        """
        del context  # Unused
        verdict = MatchResult()

        # A bias whose contents are not statically known cannot be judged.
        constant_bias = ir.convenience.get_const_tensor(b)
        if constant_bias is None:
            return verdict.fail("Bias is not a constant/initializer.")

        # Dropping the input is only allowed when every element is zero.
        only_zeros = np.equal(constant_bias.numpy(), 0.0).all()
        if only_zeros:
            return verdict
        return verdict.fail("Bias is not all zeros.")

    def rewrite(self, op: ir.tape.Tape, out: ir.Value, **_) -> ir.Value:
        """Re-emit the matched node without its last input.

        Args:
            op: Builder used to create the replacement node.
            out: Output value of the matched node.

        Returns:
            The value produced by ``op.op`` for a node of type ``self.op_type``
            that reuses the matched node's attributes and all of its inputs
            except the last one.
        """
        matched_node = out.producer()
        # The subclass patterns place the bias last, so slicing off the final
        # input removes exactly the bias.
        inputs_without_bias = matched_node.inputs[:-1]
        return op.op(
            self.op_type,
            inputs=inputs_without_bias,
            attributes=matched_node.attributes,
        )


class RemoveOptionalBiasFromConv(_RemoveOptionalBias):
    """Rule class targeting ``Conv`` nodes that carry a zero bias."""

    # Operator name used by the shared rewrite when re-emitting the node.
    op_type: ClassVar[str] = "Conv"

    def pattern(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value:
        """Match ``Conv(x, w, b)`` and label its output ``out``."""
        conv_with_bias = op.Conv(x, w, b, _outputs=["out"])
        return conv_with_bias


class RemoveOptionalBiasFromConvTranspose(_RemoveOptionalBias):
    """Rule class targeting ``ConvTranspose`` nodes that carry a zero bias."""

    # Operator name used by the shared rewrite when re-emitting the node.
    op_type: ClassVar[str] = "ConvTranspose"

    def pattern(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value:
        """Match ``ConvTranspose(x, w, b)`` and label its output ``out``."""
        conv_transpose_with_bias = op.ConvTranspose(x, w, b, _outputs=["out"])
        return conv_transpose_with_bias


class RemoveOptionalBiasFromQLinearConv(_RemoveOptionalBias):
    """Rule class targeting ``QLinearConv`` nodes that carry a zero bias."""

    # Operator name used by the shared rewrite when re-emitting the node.
    op_type: ClassVar[str] = "QLinearConv"

    def pattern(
        self,
        op: ir.tape.Tape,
        x,
        x_scale,
        x_zero_point,
        w,
        w_scale,
        w_zero_point,
        y_scale,
        y_zero_point,
        b: ir.Value,
    ) -> ir.Value:
        """Match a nine-input ``QLinearConv`` whose last input is the bias ``b``.

        The matched node's output is labelled ``out``.
        """
        # Inputs are forwarded in the same order as the parameters, with the
        # bias last so the shared rewrite can strip it with a slice.
        quantized_inputs = (
            x,
            x_scale,
            x_zero_point,
            w,
            w_scale,
            w_zero_point,
            y_scale,
            y_zero_point,
            b,
        )
        return op.QLinearConv(*quantized_inputs, _outputs=["out"])


class RemoveOptionalBiasFromGemm(_RemoveOptionalBias):
    """Rule class targeting ``Gemm`` nodes that carry a zero bias."""

    # Operator name used by the shared rewrite when re-emitting the node.
    op_type: ClassVar[str] = "Gemm"

    def pattern(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value:
        """Match ``Gemm(x, w, b)`` and label its output ``out``."""
        gemm_with_bias = op.Gemm(x, w, b, _outputs=["out"])
        return gemm_with_bias


# One ready-to-use rule instance per supported operator.
remove_optional_bias_from_conv_rule = RemoveOptionalBiasFromConv().rule()
remove_optional_bias_from_conv_transpose_rule = RemoveOptionalBiasFromConvTranspose().rule()
remove_optional_bias_from_qlinear_conv_rule = RemoveOptionalBiasFromQLinearConv().rule()
remove_optional_bias_from_gemm_rule = RemoveOptionalBiasFromGemm().rule()

# Rule set bundling all four zero-bias removal rules.
rules = RewriteRuleSet(
    [
        remove_optional_bias_from_conv_rule,
        remove_optional_bias_from_conv_transpose_rule,
        remove_optional_bias_from_qlinear_conv_rule,
        remove_optional_bias_from_gemm_rule,
    ]
)
237 changes: 237 additions & 0 deletions onnxscript/rewriter/rules/common/_remove_optional_bias_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import unittest

import numpy as np
import onnx
import onnx_ir as ir
from onnx_ir.passes.common import onnx_checker

from onnxscript.rewriter import MatchingTracer, MatchStatus, RewriteRule, testing
from onnxscript.rewriter.rules.common import _remove_optional_bias
from onnxscript.rewriter.rules.common._remove_optional_bias import (
remove_optional_bias_from_conv_rule,
remove_optional_bias_from_conv_transpose_rule,
remove_optional_bias_from_gemm_rule,
remove_optional_bias_from_qlinear_conv_rule,
)


class _RemoveOptionalBiasTestBase(unittest.TestCase):
    """Common model builders and assertions for the zero-bias removal tests."""

    @property
    def rng(self):
        """Return a generator seeded with a fixed value.

        A new generator is built on every access, so each ``self.rng`` use
        starts again from the same seeded state.
        """
        fixed_seed = 20251016
        return np.random.default_rng(fixed_seed)

    def clone_model(self, model: ir.Model) -> ir.Model:
        """Copy ``model`` by converting it to a proto and back."""
        as_proto = ir.to_proto(model)
        return ir.from_proto(as_proto)

    def _get_test_model(
        self,
        op_type: str,
        input_shape: ir.Shape,
        weight_shape: ir.Shape,
        zero_bias: bool,
        attributes=None,
    ):
        """Build a single-node float model ``Y = op_type(X, W, B)``.

        Args:
            op_type: Operator to instantiate.
            input_shape: Shape of the graph input ``X``.
            weight_shape: Shape of the random weight initializer ``W``.
            zero_bias: When true ``B`` is all zeros, otherwise it is random.
            attributes: Optional attributes forwarded to the node.

        Returns:
            The model, after it has been run through the ONNX checker pass.
        """
        tape = ir.tape.Tape()

        # The bias length is read from weight dim 1 for ConvTranspose and
        # from dim 0 for every other operator.
        channel_axis = 1 if op_type == "ConvTranspose" else 0
        bias_shape = weight_shape[channel_axis]
        output_shape = ir.Shape(("?",) * input_shape.rank())

        x = ir.val("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))

        weight_values = self.rng.uniform(-0.5, 0.5, weight_shape).astype(np.float32)
        w = tape.initializer(ir.tensor(weight_values, name="W"))

        bias_values = (
            np.zeros(bias_shape, dtype=np.float32)
            if zero_bias
            else self.rng.uniform(-0.5, 0.5, bias_shape).astype(np.float32)
        )
        b = tape.initializer(ir.tensor(bias_values, name="B"))

        y = tape.op(
            op_type,
            inputs=[x, w, b],
            attributes=attributes,
            output=ir.val("Y", shape=output_shape, type=ir.TensorType(ir.DataType.FLOAT)),
        )

        # Assemble the graph and wrap it in a model.
        graph = ir.Graph(
            inputs=[x],
            outputs=[y],
            nodes=tape.nodes,
            initializers=tape.initializers,
            opset_imports={"": 20},
            name="test_model",
        )
        ir_model = ir.Model(graph, ir_version=10)
        onnx_checker.CheckerPass(True)(ir_model)
        return ir_model

    def run_test(
        self,
        base_model: ir.Model,
        input_shape: tuple,
        input_dtype=np.float32,
    ):
        """Apply the rule set to a copy of ``base_model`` and validate the result.

        Asserts that exactly one rewrite happened, that the first node lost one
        input, that both models agree numerically on a random input, and that
        the rewritten model passes the ONNX checker.
        """
        updated_model = self.clone_model(base_model)
        applied = _remove_optional_bias.rules.apply_to_model(updated_model)

        # Exactly one rewrite is expected.
        self.assertEqual(applied, 1)

        # The first node must have one input fewer than before.
        rewritten_inputs = updated_model.graph[0].inputs
        original_inputs = base_model.graph[0].inputs
        self.assertEqual(len(rewritten_inputs), len(original_inputs) - 1)

        # Both models must produce the same output on the same random input.
        sample = self.rng.random(input_shape).astype(input_dtype)
        inputs = (sample,)
        testing.assert_numerically_equal(base_model, updated_model, inputs)

        # The rewritten model must still serialize to a valid ONNX model.
        output_model_proto = ir.serde.serialize_model(updated_model)
        onnx.checker.check_model(output_model_proto, full_check=True)

    def run_failed_condition_test(
        self,
        base_model: ir.Model,
        rewrite_rule: RewriteRule,
        expected_message: str,
    ):
        """Assert that ``rewrite_rule`` is rejected by its condition check.

        Args:
            base_model: Model the rule is applied to (on a copy).
            rewrite_rule: The single rule under test.
            expected_message: Pattern the recorded failure reason must match
                (passed to ``assertRegex``).
        """
        onnx_checker.CheckerPass(True)(base_model)

        updated_model = self.clone_model(base_model)
        tracer = MatchingTracer()
        applied = rewrite_rule.apply_to_model(updated_model, tracer=tracer)

        # No rewrite may have been applied.
        self.assertEqual(applied, 0)

        # The best recorded match must have failed in the condition check
        # with the expected reason.
        best_match = tracer.best_matches_map[rewrite_rule][0]
        self.assertEqual(best_match.status.value, MatchStatus.CONDITION_FAILED)
        self.assertRegex(best_match.match_result.reason, expected_message)


class RemoveOptionalBiasGemmTest(_RemoveOptionalBiasTestBase):
    """Checks the Gemm rule for a zero bias and for a non-zero bias."""

    def _gemm_model(self, input_shape, zero_bias):
        """Build the Gemm model shared by both tests (``transB=1``, 64x256 weights)."""
        return self._get_test_model(
            op_type="Gemm",
            input_shape=ir.Shape(input_shape),
            weight_shape=ir.Shape((64, 256)),
            zero_bias=zero_bias,
            attributes={"transB": 1},
        )

    def test_successful_remove_optional_bias_gemm(self):
        """An all-zero bias is removed from Gemm."""
        input_shape = (512, 256)
        model = self._gemm_model(input_shape, zero_bias=True)
        self.run_test(model, input_shape)

    def test_fail_remove_optional_bias_gemm(self):
        """A non-zero bias makes the Gemm rule's check fail."""
        input_shape = (512, 256)
        model = self._gemm_model(input_shape, zero_bias=False)
        self.run_failed_condition_test(
            model, remove_optional_bias_from_gemm_rule, "Bias is not all zeros."
        )


# NOTE(review): "Gonv" in the class name looks like a typo for "Conv"; the name
# is kept unchanged so existing test identifiers remain stable.
class RemoveOptionalBiasGonvTest(_RemoveOptionalBiasTestBase):
    """Checks the Conv rule for a zero bias and for a non-zero bias."""

    def _conv_model(self, input_shape, zero_bias, attributes=None):
        """Build a Conv model with 16x3x3x3 weights and optional attributes."""
        return self._get_test_model(
            op_type="Conv",
            input_shape=ir.Shape(input_shape),
            weight_shape=ir.Shape((16, 3, 3, 3)),
            zero_bias=zero_bias,
            attributes=attributes,
        )

    def test_successful_remove_optional_bias_conv(self):
        """An all-zero bias is removed from a strided Conv."""
        input_shape = (1, 3, 32, 32)
        model = self._conv_model(input_shape, zero_bias=True, attributes={"strides": (2, 2)})
        self.run_test(model, input_shape)

    def test_fail_remove_optional_bias_conv(self):
        """A non-zero bias makes the Conv rule's check fail."""
        input_shape = (1, 3, 32, 32)
        model = self._conv_model(input_shape, zero_bias=False)
        self.run_failed_condition_test(
            model, remove_optional_bias_from_conv_rule, "Bias is not all zeros."
        )


# NOTE(review): "Gonv" in the class name looks like a typo for "Conv"; the name
# is kept unchanged so existing test identifiers remain stable.
class RemoveOptionalBiasGonvTransposeTest(_RemoveOptionalBiasTestBase):
    """Checks the ConvTranspose rule for a zero bias and for a non-zero bias."""

    def _conv_transpose_model(self, input_shape, zero_bias):
        """Build the ConvTranspose model shared by both tests (3x16x3x3 weights)."""
        return self._get_test_model(
            op_type="ConvTranspose",
            input_shape=ir.Shape(input_shape),
            weight_shape=ir.Shape((3, 16, 3, 3)),
            zero_bias=zero_bias,
        )

    def test_successful_remove_optional_bias_conv_transpose(self):
        """An all-zero bias is removed from ConvTranspose."""
        input_shape = (1, 3, 32, 32)
        model = self._conv_transpose_model(input_shape, zero_bias=True)
        self.run_test(model, input_shape)

    def test_fail_remove_optional_bias_conv_transpose(self):
        """A non-zero bias makes the ConvTranspose rule's check fail."""
        input_shape = (1, 3, 32, 32)
        model = self._conv_transpose_model(input_shape, zero_bias=False)
        self.run_failed_condition_test(
            model, remove_optional_bias_from_conv_transpose_rule, "Bias is not all zeros."
        )


class RemoveOptionalBiasQLinearConvTest(_RemoveOptionalBiasTestBase):
    """Checks the QLinearConv rule for a zero bias and for a non-zero bias."""

    def _get_qlinear_conv_model(self, zero_bias: bool) -> ir.Model:
        """Build a single-node QLinearConv model with a uint8 weight and int32 bias.

        This helper has its own name instead of overriding the base class
        ``_get_test_model``: the base helper takes a different parameter list,
        so overriding it with this signature is an incompatible override.

        Args:
            zero_bias: When true the bias initializer ``B`` is all zeros,
                otherwise it holds random int32 values.

        Returns:
            The model, after it has been run through the ONNX checker pass.
        """
        if zero_bias:
            bias = np.zeros((16,), dtype=np.int32)
        else:
            bias = self.rng.uniform(-5, 5, (16,)).astype(np.int32)

        # Draw the weights directly as uint8. Casting negative floats to an
        # unsigned integer type gives platform-dependent results.
        weights = self.rng.integers(0, 256, size=(16, 3, 3, 3), dtype=np.uint8)
        w = ir.tensor(weights, name="W")
        b = ir.tensor(bias, name="B")

        model = ir.from_onnx_text(
            """
            < ir_version: 10, opset_import: ["" : 20] >
            test_model (uint8[N, 3, 32, 32] X) => (uint8 [N, ?, ?, ?] Y)
            <uint8[16, 3, 3, 3] W, int32[16] B, float x_scale = {1.5}, uint8 x_zero_point = {123},
              float w_scale = {1.5}, uint8 w_zero_point = {123},
              float y_scale = {1.5}, uint8 y_zero_point = {123}>
            {
                Y = QLinearConv(X, x_scale, x_zero_point, W, w_scale, w_zero_point, y_scale, y_zero_point, B)
            }
            """,
            initializers=[w, b],
        )
        onnx_checker.CheckerPass(True)(model)
        return model

    def test_successful_remove_optional_bias_qlinear_conv(self):
        """An all-zero bias is removed from QLinearConv."""
        input_shape = (1, 3, 32, 32)
        base_model = self._get_qlinear_conv_model(zero_bias=True)
        self.run_test(base_model, input_shape, np.uint8)

    def test_fail_remove_optional_bias_qlinear_conv(self):
        """A non-zero bias makes the QLinearConv rule's check fail."""
        base_model = self._get_qlinear_conv_model(zero_bias=False)
        self.run_failed_condition_test(
            base_model, remove_optional_bias_from_qlinear_conv_rule, "Bias is not all zeros."
        )


# Run the tests in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
Loading