diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py
index fc000dc176..fe530159c7 100644
--- a/onnxscript/rewriter/__init__.py
+++ b/onnxscript/rewriter/__init__.py
@@ -19,8 +19,8 @@
 ]
 
 import onnx
-import onnx_ir.passes.common as common_passes
+import onnxscript.ir.passes.common as common_passes
 
 from onnxscript import ir
 from onnxscript.rewriter import pattern
 from onnxscript.rewriter._basics import MatchContext, MatchingTracer, MatchResult, MatchStatus
@@ -41,6 +41,7 @@
     _min_max_to_clip,
     _no_op,
     _redundant_scatter_nd,
+    _remove_zero_bias,
 )
 
 _ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model)
@@ -55,6 +56,7 @@
     *_redundant_scatter_nd.rules,
     *_fuse_pad_into_conv.rules,
     *_fuse_batchnorm.rules,
+    *_remove_zero_bias.rules,
 )
diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py
index ea7af31b3e..98d597f6cb 100644
--- a/onnxscript/rewriter/ort_fusions/_core.py
+++ b/onnxscript/rewriter/ort_fusions/_core.py
@@ -33,7 +33,12 @@
     fuse_skip_layer_normalization,
     fuse_skip_rms_normalization,
 )
-from onnxscript.rewriter.rules.common import _gemm_to_matmul_add
+from onnxscript.rewriter.rules.common import (
+    _fuse_batchnorm,
+    _fuse_pad_into_conv,
+    _gemm_to_matmul_add,
+    _remove_zero_bias,
+)
 
 ORT_PATTERN_REWRITE_RULES = [
     *softmax.rules.rules,
@@ -41,6 +46,10 @@
     # NOTE: group normalization merge silu should be applied after instance to group normalization
     # *group_normalization_merge_silu.rules.rules,
     *fused_matmul_rule_sets.fused_matmul_rule_sets(),
+    # Conv-related cleanup rules that enable better ORT optimization
+    *_fuse_batchnorm.rules.rules,
+    *_fuse_pad_into_conv.rules.rules,
+    *_remove_zero_bias.rules.rules,
 ]
diff --git a/onnxscript/rewriter/rules/common/__init__.py b/onnxscript/rewriter/rules/common/__init__.py
index 14ed3587f3..e5d577aa5e 100644
--- a/onnxscript/rewriter/rules/common/__init__.py
+++ b/onnxscript/rewriter/rules/common/__init__.py
@@ -34,6 +34,10 @@
     "normalize_pad_format_conv_integer_rule",
     "normalize_pad_format_conv_rule",
     "one_reshape_matmul_reshape_rule",
+    "remove_zero_bias_from_conv_rule",
+    "remove_zero_bias_from_conv_transpose_rule",
+    "remove_zero_bias_from_gemm_rule",
+    "remove_zero_bias_from_qlinear_conv_rule",
     "reshape_reshape_rule",
     "slice_split_rule",
     "squeeze_reshape_1d_rule",
@@ -121,3 +125,9 @@
     no_op_dynamic_scatter_nd_rule,
     no_op_static_scatter_nd_rule,
 )
+from onnxscript.rewriter.rules.common._remove_zero_bias import (
+    remove_zero_bias_from_conv_rule,
+    remove_zero_bias_from_conv_transpose_rule,
+    remove_zero_bias_from_gemm_rule,
+    remove_zero_bias_from_qlinear_conv_rule,
+)
diff --git a/onnxscript/rewriter/rules/common/_remove_zero_bias.py b/onnxscript/rewriter/rules/common/_remove_zero_bias.py
new file mode 100644
index 0000000000..901e643ac1
--- /dev/null
+++ b/onnxscript/rewriter/rules/common/_remove_zero_bias.py
@@ -0,0 +1,124 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Remove optional bias when it is all zero from Conv and related operations.""" + +from __future__ import annotations + +from typing import ClassVar + +import numpy as np + +from onnxscript import ir +from onnxscript.ir import convenience +from onnxscript.rewriter._basics import MatchResult +from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet + + +class _RemoveZeroBiasBase(RewriteRuleClassBase): + """Base class for removing zero bias from operations.""" + + def rewrite(self, op: ir.tape.Tape, out: ir.Value, **_) -> ir.Value: + """Remove the bias input from the operation.""" + node = out.producer() + + return op.op( + self.op_type, + inputs=node.inputs[:-1], + attributes=node.attributes, + ) + + def _check_bias_is_zero(self, bias_value: ir.Value) -> MatchResult: + """Check if the bias value is present and is all zeros.""" + check_result = MatchResult() + + # Check if bias is a constant/initializer + bias_tensor = convenience.get_const_tensor(bias_value) + if bias_tensor is None: + return check_result.fail("Bias is not a constant/initializer.") + + # Check if bias is all zeros + bias_array = bias_tensor.numpy() + if not np.allclose(bias_array, 0.0, atol=1e-8): + return check_result.fail("Bias is not all zeros.") + + return check_result + + def check(self, context, x: ir.Value, w: ir.Value, b: ir.Value, **_) -> MatchResult: + """Check if the bias is present and is all zeros.""" + del context # Unused + return self._check_bias_is_zero(b) + + +class RemoveZeroBiasFromConv(_RemoveZeroBiasBase): + """Remove zero bias from Conv operations.""" + + op_type: ClassVar = "Conv" + + def pattern(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value: + return op.Conv(x, w, b, _outputs=["out"]) + + +class RemoveZeroBiasFromConvTranspose(_RemoveZeroBiasBase): + """Remove zero bias from ConvTranspose operations.""" + + op_type: ClassVar = "ConvTranspose" + + def pattern(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value: + return op.ConvTranspose(x, w, b, _outputs=["out"]) + + +class RemoveZeroBiasFromQLinearConv(_RemoveZeroBiasBase): + """Remove zero bias from QLinearConv operations.""" + + op_type: ClassVar = "QLinearConv" + + def pattern( + self, + op: ir.tape.Tape, + x, + x_scale, + x_zero_point, + w, + w_scale, + w_zero_point, + y_scale, + y_zero_point, + b: ir.Value, + ) -> ir.Value: + return op.QLinearConv( + x, + x_scale, + x_zero_point, + w, + w_scale, + w_zero_point, + y_scale, + y_zero_point, + b, + _outputs=["out"], + ) + + +class RemoveZeroBiasFromGemm(_RemoveZeroBiasBase): + """Remove zero bias from Gemm operations.""" + + op_type: ClassVar = "Gemm" + + def pattern(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value: + return op.Gemm(x, w, b, _outputs=["out"]) + + +# Create rule instances +remove_zero_bias_from_conv_rule = RemoveZeroBiasFromConv().rule() +remove_zero_bias_from_conv_transpose_rule = RemoveZeroBiasFromConvTranspose().rule() +remove_zero_bias_from_qlinear_conv_rule = RemoveZeroBiasFromQLinearConv().rule() +remove_zero_bias_from_gemm_rule = RemoveZeroBiasFromGemm().rule() + +rules = RewriteRuleSet( + [ + remove_zero_bias_from_conv_rule, + remove_zero_bias_from_conv_transpose_rule, + remove_zero_bias_from_qlinear_conv_rule, + remove_zero_bias_from_gemm_rule, + ] +) diff --git a/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py b/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py new file mode 100644 index 0000000000..71b279f4f4 --- /dev/null +++ 
@@ -0,0 +1,447 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Tests for removing zero bias from Conv and related operations."""
+
+import unittest
+from typing import Optional
+
+import onnx_ir as ir
+
+from onnxscript.rewriter import testing
+from onnxscript.rewriter.rules.common._remove_zero_bias import (
+    remove_zero_bias_from_conv_rule,
+    remove_zero_bias_from_conv_transpose_rule,
+    remove_zero_bias_from_gemm_rule,
+    remove_zero_bias_from_qlinear_conv_rule,
+)
+
+
+def _apply_rule_and_check_optimization(
+    model: ir.Model,
+    rule,
+    expected_count: int,
+    target_op_type: str,
+    expected_inputs_after: int,
+    expected_attributes: Optional[dict] = None,
+) -> None:
+    """Helper function to test bias removal rules."""
+    # Make a copy for comparison
+    original_model = ir.from_proto(ir.to_proto(model))
+
+    # Get original attributes for comparison
+    original_target_node = None
+    for node in original_model.graph:
+        if node.op_type == target_op_type:
+            original_target_node = node
+            break
+
+    # Apply the rule
+    count = rule.apply_to_model(model)
+
+    # Check that the rule was applied the expected number of times
+    assert count == expected_count, f"Expected {expected_count} applications, got {count}"
+
+    # Check that the target node has the expected number of inputs
+    target_node = None
+    for node in model.graph:
+        if node.op_type == target_op_type:
+            target_node = node
+            break
+
+    assert target_node is not None, f"{target_op_type} node not found"
+    assert len(target_node.inputs) == expected_inputs_after, (
+        f"Expected {expected_inputs_after} inputs after optimization, "
+        f"got {len(target_node.inputs)}"
+    )
+
+    # Check that attributes are preserved if the rule was applied
+    if expected_count > 0 and original_target_node is not None:
+        # All original attributes should be preserved
+        for attr_name, attr_value in original_target_node.attributes.items():
+            assert attr_name in target_node.attributes, f"Attribute {attr_name} was lost"
+            original_value = attr_value.value
+            new_value = target_node.attributes[attr_name].value
+            # Convert both to same type for comparison to handle list vs tuple differences
+            if isinstance(original_value, (list, tuple)) and isinstance(
+                new_value, (list, tuple)
+            ):
+                assert list(original_value) == list(new_value), (
+                    f"Attribute {attr_name} value changed from {original_value} to {new_value}"
+                )
+            else:
+                assert new_value == original_value, (
+                    f"Attribute {attr_name} value changed from {original_value} to {new_value}"
+                )
+
+    # Check specific expected attributes if provided
+    if expected_attributes:
+        for attr_name, expected_value in expected_attributes.items():
+            assert attr_name in target_node.attributes, (
+                f"Expected attribute {attr_name} not found"
+            )
+            actual_attr = target_node.attributes[attr_name]
+            actual_value = actual_attr.value
+            # Convert both to same type for comparison to handle list vs tuple differences
+            if isinstance(actual_value, (list, tuple)) and isinstance(
+                expected_value, (list, tuple)
+            ):
+                assert list(actual_value) == list(expected_value), (
+                    f"Expected attribute {attr_name} to be {expected_value}, got {actual_value}"
+                )
+            else:
+                assert actual_value == expected_value, (
+                    f"Expected attribute {attr_name} to be {expected_value}, got {actual_value}"
+                )
+
+    # Compare outputs to ensure correctness (only for supported input types)
+    if expected_count > 0:
+        try:
+            # Generate random inputs for the model using the existing testing utility
+            original_model_proto = ir.to_proto(original_model)
+            inputs = testing.generate_random_inputs(original_model_proto)
+            testing.assert_numerically_equal(original_model, model, inputs)
+        except ValueError as e:
+            if "Not implemented for input type" in str(e):
+                # Skip numerical comparison for unsupported input types;
+                # the structural checks above are sufficient for these cases.
+                pass
+            else:
+                raise
+
+
+class RemoveZeroBiasTest(unittest.TestCase):
+    """Test class for remove zero bias rules."""
+
+    def test_remove_zero_bias_from_conv(self):
+        """Test that zero bias is removed from Conv operations."""
+        # Create a simple Conv with zero bias using ONNX text format
+        model = ir.from_onnx_text(
+            """
+            <ir_version: 10, opset_import: ["" : 20]>
+            agraph (float[1, 2, 4, 4] x) => (float[1, 2, 2, 2] y)
+            {
+                weight = Constant <value = float[2, 2, 3, 3] {
+                    1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+                    1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0
+                }> ()
+                bias = Constant <value = float[2] {0.0, 0.0}> ()
+                y = Conv(x, weight, bias)
+            }
+            """
+        )
+
+        _apply_rule_and_check_optimization(
+            model,
+            remove_zero_bias_from_conv_rule,
+            expected_count=1,
+            target_op_type="Conv",
+            expected_inputs_after=2,
+        )
+
+    def test_conv_with_non_zero_bias_unchanged(self):
+        """Test that Conv with non-zero bias is not modified."""
+        # Create a Conv with non-zero bias using ONNX text format
+        model = ir.from_onnx_text(
+            """
+            <ir_version: 10, opset_import: ["" : 20]>
+            agraph (float[1, 2, 4, 4] x) => (float[1, 2, 2, 2] y)
+            {
+                weight = Constant <value = float[2, 2, 3, 3] {
+                    1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+                    1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0
+                }> ()
+                bias = Constant <value = float[2] {1.0, -1.0}> ()
+                y = Conv(x, weight, bias)
+            }
+            """
+        )
+
+        _apply_rule_and_check_optimization(
+            model,
+            remove_zero_bias_from_conv_rule,
+            expected_count=0,
+            target_op_type="Conv",
+            expected_inputs_after=3,
+        )
+
+    def test_remove_zero_bias_from_conv_transpose(self):
+        """Test that zero bias is removed from ConvTranspose operations."""
+        # Create a ConvTranspose with zero bias using ONNX text format
+        model = ir.from_onnx_text(
+            """
+            <ir_version: 10, opset_import: ["" : 20]>
+            agraph (float[1, 2, 2, 2] x) => (float[1, 2, 4, 4] y)
+            {
+                weight = Constant <value = float[2, 2, 3, 3] {
+                    1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+                    1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0
+                }> ()
+                bias = Constant <value = float[2] {0.0, 0.0}> ()
+                y = ConvTranspose(x, weight, bias)
+            }
+            """
+        )
+
+        _apply_rule_and_check_optimization(
+            model,
+            remove_zero_bias_from_conv_transpose_rule,
+            expected_count=1,
+            target_op_type="ConvTranspose",
+            expected_inputs_after=2,
+        )
+
+    def test_conv_transpose_with_non_zero_bias_unchanged(self):
+        """Test that ConvTranspose with non-zero bias is not modified."""
+        # Create a ConvTranspose with non-zero bias using ONNX text format
+        model = ir.from_onnx_text(
+            """
+            <ir_version: 10, opset_import: ["" : 20]>
+            agraph (float[1, 2, 2, 2] x) => (float[1, 2, 4, 4] y)
+            {
+                weight = Constant <value = float[2, 2, 3, 3] {
+                    1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+                    1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0
+                }> ()
+                bias = Constant <value = float[2] {1.0, -1.0}> ()
+                y = ConvTranspose(x, weight, bias)
+            }
+            """
+        )
+
+        _apply_rule_and_check_optimization(
+            model,
+            remove_zero_bias_from_conv_transpose_rule,
+            expected_count=0,
+            target_op_type="ConvTranspose",
+            expected_inputs_after=3,
+        )
+
+    def test_remove_zero_bias_from_gemm(self):
+        """Test that zero bias is removed from Gemm operations."""
+        # Create a Gemm with zero bias using ONNX text format
+        model = ir.from_onnx_text(
+            """
+            <ir_version: 10, opset_import: ["" : 20]>
+            agraph (float[2, 3] a) => (float[2, 4] y)
+            {
+                b = Constant <value = float[3, 4] {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}> ()
+                c = Constant <value = float[4] {0.0, 0.0, 0.0, 0.0}> ()
+                y = Gemm(a, b, c)
+            }
+            """
+        )
+
+        _apply_rule_and_check_optimization(
+            model,
+            remove_zero_bias_from_gemm_rule,
+            expected_count=1,
+            target_op_type="Gemm",
+            expected_inputs_after=2,
+        )
+
+    def test_gemm_with_non_zero_bias_unchanged(self):
+        """Test that Gemm with non-zero bias is not modified."""
+        # Create a Gemm with non-zero bias using ONNX text format
+        model = ir.from_onnx_text(
+            """
+            <ir_version: 10, opset_import: ["" : 20]>
+            agraph (float[2, 3] a) => (float[2, 4] y)
+            {
+                b = Constant <value = float[3, 4] {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}> ()
+                c = Constant <value = float[4] {1.0, 2.0, 3.0, 4.0}> ()
+                y = Gemm(a, b, c)
+            }
+            """
+        )
+
+        _apply_rule_and_check_optimization(
+            model,
+            remove_zero_bias_from_gemm_rule,
+            expected_count=0,
+            target_op_type="Gemm",
+            expected_inputs_after=3,
+        )
+
+    def test_remove_zero_bias_from_qlinear_conv(self):
+        """Test that zero bias is removed from QLinearConv operations."""
+        # Create a QLinearConv with zero bias using ONNX text format
+        model = ir.from_onnx_text(
+            """
+            <ir_version: 10, opset_import: ["" : 20]>
+            agraph (uint8[1, 2, 4, 4] x) => (uint8[1, 2, 2, 2] y)
+            {
+                x_scale = Constant <value = float {0.1}> ()
+                x_zero_point = Constant <value = uint8 {0}> ()
+                weight = Constant <value = uint8[2, 2, 3, 3] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}> ()
+                w_scale = Constant <value = float {0.1}> ()
+                w_zero_point = Constant <value = uint8 {0}> ()
+                y_scale = Constant <value = float {0.1}> ()
+                y_zero_point = Constant <value = uint8 {0}> ()
+                bias = Constant <value = int32[2] {0, 0}> ()
+                y = QLinearConv(x, x_scale, x_zero_point, weight, w_scale, w_zero_point, y_scale, y_zero_point, bias)
+            }
+            """
+        )
+
+        _apply_rule_and_check_optimization(
+            model,
+            remove_zero_bias_from_qlinear_conv_rule,
+            expected_count=1,
+            target_op_type="QLinearConv",
+            expected_inputs_after=8,
+        )
+
+    def test_remove_zero_bias_from_conv_with_attributes(self):
+        """Test that zero bias is removed from Conv operations and attributes are preserved."""
+        # Create a Conv with zero bias and various attributes using ONNX text format
+        model = ir.from_onnx_text(
+            """
+            <ir_version: 10, opset_import: ["" : 20]>
+            agraph (float[1, 2, 6, 6] x) => (float[1, 2, 2, 2] y)
+            {
+                weight = Constant <value = float[2, 2, 3, 3] {
+                    1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+                    1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0
+                }> ()
+                bias = Constant <value = float[2] {0.0, 0.0}> ()
+                y = Conv <dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [0, 0, 0, 0], strides = [2, 2]> (x, weight, bias)
+            }
+            """
+        )
+
+        expected_attributes = {
+            "dilations": [1, 1],
+            "group": 1,
+            "kernel_shape": [3, 3],
+            "pads": [0, 0, 0, 0],
+            "strides": [2, 2],
+        }
+
+        _apply_rule_and_check_optimization(
+            model,
+            remove_zero_bias_from_conv_rule,
+            expected_count=1,
+            target_op_type="Conv",
+            expected_inputs_after=2,
+            expected_attributes=expected_attributes,
+        )
+
+    def test_remove_zero_bias_from_conv_transpose_with_attributes(self):
+        """Test that zero bias is removed from ConvTranspose operations and attributes are preserved."""
+        # Create a ConvTranspose with zero bias and various attributes using ONNX text format
+        model = ir.from_onnx_text(
+            """
+            <ir_version: 10, opset_import: ["" : 20]>
+            agraph (float[1, 2, 2, 2] x) => (float[1, 2, 5, 5] y)
+            {
+                weight = Constant <value = float[2, 2, 3, 3] {
+                    1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+                    1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0
+                }> ()
+                bias = Constant <value = float[2] {0.0, 0.0}> ()
+                y = ConvTranspose <dilations = [1, 1], group = 1, kernel_shape = [3, 3], output_padding = [0, 0], pads = [0, 0, 0, 0], strides = [2, 2]> (x, weight, bias)
+            }
+            """
+        )
+
+        expected_attributes = {
+            "dilations": [1, 1],
+            "group": 1,
+            "kernel_shape": [3, 3],
+            "output_padding": [0, 0],
+            "pads": [0, 0, 0, 0],
+            "strides": [2, 2],
+        }
+
+        _apply_rule_and_check_optimization(
+            model,
+            remove_zero_bias_from_conv_transpose_rule,
+            expected_count=1,
+            target_op_type="ConvTranspose",
+            expected_inputs_after=2,
+            expected_attributes=expected_attributes,
+        )
+
+    def test_remove_zero_bias_from_gemm_with_attributes(self):
+        """Test that zero bias is removed from Gemm operations and attributes are preserved."""
+        # Create a Gemm with zero bias and various attributes using ONNX text format
+        model = ir.from_onnx_text(
+            """
+            <ir_version: 10, opset_import: ["" : 20]>
+            agraph (float[2, 3] a) => (float[2, 4] y)
+            {
+                b = Constant <value = float[4, 3] {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}> ()
+                c = Constant <value = float[4] {0.0, 0.0, 0.0, 0.0}> ()
+                y = Gemm <alpha = 2.0, beta = 1.0, transA = 0, transB = 1> (a, b, c)
+            }
+            """
+        )
+
+        expected_attributes = {
+            "alpha": 2.0,
+            "beta": 1.0,
+            "transA": 0,
+            "transB": 1,
+        }
+
+        _apply_rule_and_check_optimization(
+            model,
+            remove_zero_bias_from_gemm_rule,
+            expected_count=1,
+            target_op_type="Gemm",
+            expected_inputs_after=2,
+            expected_attributes=expected_attributes,
+        )
+
+    def test_remove_zero_bias_from_qlinear_conv_with_attributes(self):
+        """Test that zero bias is removed from QLinearConv operations and attributes are preserved."""
+        # Create a QLinearConv with zero bias and various attributes using ONNX text format
+        model = ir.from_onnx_text(
+            """
+            <ir_version: 10, opset_import: ["" : 20]>
+            agraph (uint8[1, 2, 6, 6] x) => (uint8[1, 2, 2, 2] y)
+            {
+                x_scale = Constant <value = float {0.1}> ()
+                x_zero_point = Constant <value = uint8 {0}> ()
+                weight = Constant <value = uint8[2, 2, 3, 3] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}> ()
+                w_scale = Constant <value = float {0.1}> ()
+                w_zero_point = Constant <value = uint8 {0}> ()
+                y_scale = Constant <value = float {0.1}> ()
+                y_zero_point = Constant <value = uint8 {0}> ()
+                bias = Constant <value = int32[2] {0, 0}> ()
+                y = QLinearConv <dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [0, 0, 0, 0], strides = [2, 2]> (x, x_scale, x_zero_point, weight, w_scale, w_zero_point, y_scale, y_zero_point, bias)
+            }
+            """
+        )
+
+        expected_attributes = {
+            "dilations": [1, 1],
+            "group": 1,
+            "kernel_shape": [3, 3],
+            "pads": [0, 0, 0, 0],
+            "strides": [2, 2],
+        }
+
+        _apply_rule_and_check_optimization(
+            model,
+            remove_zero_bias_from_qlinear_conv_rule,
+            expected_count=1,
+            target_op_type="QLinearConv",
+            expected_inputs_after=8,
+            expected_attributes=expected_attributes,
+        )
+
+    def test_qlinear_conv_with_non_zero_bias_unchanged(self):
+        """Test that QLinearConv with non-zero bias is not modified."""
+        # Create a QLinearConv with non-zero bias using ONNX text format
+        model = ir.from_onnx_text(
+            """
+            <ir_version: 10, opset_import: ["" : 20]>
+            agraph (uint8[1, 2, 4, 4] x) => (uint8[1, 2, 2, 2] y)
+            {
+                x_scale = Constant <value = float {0.1}> ()
+                x_zero_point = Constant <value = uint8 {0}> ()
+                weight = Constant <value = uint8[2, 2, 3, 3] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}> ()
+                w_scale = Constant <value = float {0.1}> ()
+                w_zero_point = Constant <value = uint8 {0}> ()
+                y_scale = Constant <value = float {0.1}> ()
+                y_zero_point = Constant <value = uint8 {0}> ()
+                bias = Constant <value = int32[2] {100, -100}> ()
+                y = QLinearConv(x, x_scale, x_zero_point, weight, w_scale, w_zero_point, y_scale, y_zero_point, bias)
+            }
+            """
+        )
+
+        _apply_rule_and_check_optimization(
+            model,
+            remove_zero_bias_from_qlinear_conv_rule,
+            expected_count=0,
+            target_op_type="QLinearConv",
+            expected_inputs_after=9,
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tools/ir/model_zoo_test/model_zoo_test.py b/tools/ir/model_zoo_test/model_zoo_test.py
index 82d7a54026..c80068919f 100644
--- a/tools/ir/model_zoo_test/model_zoo_test.py
+++ b/tools/ir/model_zoo_test/model_zoo_test.py
@@ -26,7 +26,7 @@
 from onnxscript import ir
 
 
-def test_model(model_info: hub.ModelInfo) -> float:
+def validate_model(model_info: hub.ModelInfo) -> float:
     model_name = model_info.model
     with tempfile.TemporaryDirectory() as temp_dir, contextlib.redirect_stdout(None):
         # For parallel testing, this must be in a separate process because hub.set_dir
@@ -58,7 +58,7 @@
     model_path = model_info.model_path
     message = f"\n----Testing: {model_name} @ {model_path}----"
    try:
-        time_passed = test_model(model_info)
+        time_passed = validate_model(model_info)
         message += green(f"\n[PASS]: {model_name} roundtrip test passed.")
     except Exception as e:  # pylint: disable=broad-exception-caught
         time_passed = -1
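For reviewers who want to exercise the new rule set outside the unit tests, a minimal usage sketch follows. The toy Gemm model is illustrative only (not taken from this PR); it assumes the `_remove_zero_bias` module path and `rules` rule set introduced above.

```python
# Minimal sketch: apply the zero-bias removal rules to a standalone model.
# The model text is a made-up example; any graph with a constant all-zero
# bias on Conv/ConvTranspose/QLinearConv/Gemm is handled the same way.
import onnx_ir as ir

from onnxscript.rewriter.rules.common import _remove_zero_bias

model = ir.from_onnx_text(
    """
    <ir_version: 10, opset_import: ["" : 20]>
    agraph (float[2, 3] a) => (float[2, 4] y)
    {
        b = Constant <value = float[3, 4] {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}> ()
        c = Constant <value = float[4] {0.0, 0.0, 0.0, 0.0}> ()
        y = Gemm(a, b, c)
    }
    """
)

# apply_to_model reports how many times the rules fired; here the constant
# all-zero bias c is detected and Gemm(a, b, c) becomes Gemm(a, b).
count = _remove_zero_bias.rules.apply_to_model(model)
assert count == 1
```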