From 6cfd6fd36e0448cd90b0bcde130ba859712d4427 Mon Sep 17 00:00:00 2001
From: Vineet Kumar
Date: Wed, 10 Sep 2025 01:09:09 +0530
Subject: [PATCH 01/13] [Rewriter] Implement zero bias removal for Conv
 operations and related rules

---
 onnxscript/rewriter/__init__.py               |   6 +-
 onnxscript/rewriter/ort_fusions/_core.py      |  11 +-
 onnxscript/rewriter/rules/common/__init__.py  |  10 +
 .../rules/common/_remove_zero_bias.py         | 203 ++++++++++++++++++
 .../rules/common/_remove_zero_bias_test.py    |  87 ++++++++
 tools/ir/model_zoo_test/model_zoo_test.py     |   4 +-
 6 files changed, 317 insertions(+), 4 deletions(-)
 create mode 100644 onnxscript/rewriter/rules/common/_remove_zero_bias.py
 create mode 100644 onnxscript/rewriter/rules/common/_remove_zero_bias_test.py

diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py
index 232750af78..fe530159c7 100644
--- a/onnxscript/rewriter/__init__.py
+++ b/onnxscript/rewriter/__init__.py
@@ -19,8 +19,8 @@
 ]

 import onnx
-import onnx_ir.passes.common as common_passes
+import onnxscript.ir.passes.common as common_passes

 from onnxscript import ir
 from onnxscript.rewriter import pattern
 from onnxscript.rewriter._basics import MatchContext, MatchingTracer, MatchResult, MatchStatus
@@ -35,11 +35,13 @@
     _broadcast_to_matmul,
     _cast_constant_of_shape,
     _collapse_slices,
+    _fuse_batchnorm,
     _fuse_pad_into_conv,
     _fuse_relus_clips,
     _min_max_to_clip,
     _no_op,
     _redundant_scatter_nd,
+    _remove_zero_bias,
 )

 _ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model)
@@ -53,6 +55,8 @@
     *_basic_rules.basic_optimization_rules(),
     *_redundant_scatter_nd.rules,
     *_fuse_pad_into_conv.rules,
+    *_fuse_batchnorm.rules,
+    *_remove_zero_bias.rules,
 )

diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py
index 8f3c7c463a..bb27f499d0 100644
--- a/onnxscript/rewriter/ort_fusions/_core.py
+++ b/onnxscript/rewriter/ort_fusions/_core.py
@@ -33,7 +33,12 @@
     fuse_skip_layer_normalization,
     fuse_skip_rms_normalization,
 )
-from onnxscript.rewriter.rules.common import _gemm_to_matmul_add
+from onnxscript.rewriter.rules.common import (
+    _fuse_batchnorm,
+    _fuse_pad_into_conv,
+    _gemm_to_matmul_add,
+    _remove_zero_bias,
+)

 ORT_PATTERN_REWRITE_RULES = [
     *softmax.rules.rules,
@@ -41,6 +46,10 @@
     # NOTE: group normalization merge silu should be applied after instance to group normalization
     # *group_normalization_merge_silu.rules.rules,
     *fused_matmul_rule_sets.fused_matmul_rule_sets(),
+    # Add Conv fusion rules for better ORT optimization
+    *_fuse_batchnorm.rules.rules,
+    *_fuse_pad_into_conv.rules.rules,
+    *_remove_zero_bias.rules.rules,
 ]

diff --git a/onnxscript/rewriter/rules/common/__init__.py b/onnxscript/rewriter/rules/common/__init__.py
index 0b01bade72..43570420fd 100644
--- a/onnxscript/rewriter/rules/common/__init__.py
+++ b/onnxscript/rewriter/rules/common/__init__.py
@@ -31,6 +31,10 @@
     "normalize_pad_format_conv_integer_rule",
     "normalize_pad_format_conv_rule",
     "one_reshape_matmul_reshape_rule",
+    "remove_zero_bias_from_conv_rule",
+    "remove_zero_bias_from_conv_transpose_rule",
+    "remove_zero_bias_from_qlinear_conv_rule",
+    "remove_zero_bias_from_gemm_rule",
     "reshape_reshape_rule",
     "slice_split_rule",
     "squeeze_reshape_1d_rule",
@@ -113,3 +117,9 @@
     no_op_dynamic_scatter_nd_rule,
     no_op_static_scatter_nd_rule,
 )
+from onnxscript.rewriter.rules.common._remove_zero_bias import (
+    remove_zero_bias_from_conv_rule,
+    remove_zero_bias_from_conv_transpose_rule,
+    remove_zero_bias_from_gemm_rule,
+    remove_zero_bias_from_qlinear_conv_rule,
+)

diff --git a/onnxscript/rewriter/rules/common/_remove_zero_bias.py b/onnxscript/rewriter/rules/common/_remove_zero_bias.py
new file mode 100644
index 0000000000..1480fcea1f
--- /dev/null
+++ b/onnxscript/rewriter/rules/common/_remove_zero_bias.py
@@ -0,0 +1,203 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Remove optional bias when it is all zero from Conv and related operations."""
+
+from __future__ import annotations
+
+import numpy as np
+
+from onnxscript import ir
+from onnxscript.rewriter._basics import MatchResult
+from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet
+
+
+class _RemoveZeroBiasBase(RewriteRuleClassBase):
+    """Base class for removing zero bias from operations."""
+
+    def __init__(self, op_type: str):
+        super().__init__(remove_nodes=False)
+        self.op_type = op_type
+
+    def rewrite(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value:
+        """Remove the bias input from the operation."""
+        return op.op(
+            self.op_type,
+            inputs=[x, w],  # Remove bias input
+        )
+
+    def check(self, context, x: ir.Value, w: ir.Value, b: ir.Value, **_) -> MatchResult:
+        """Check if the bias is present and is all zeros."""
+        del context  # Unused
+        check_result = MatchResult()
+
+        # Check if bias is a constant/initializer
+        if b.const_value is None:
+            return check_result.fail("Bias is not a constant/initializer.")
+
+        # Check if bias is all zeros
+        bias_array = b.const_value.numpy()
+        if not np.allclose(bias_array, 0.0, atol=1e-8):
+            return check_result.fail("Bias is not all zeros.")
+
+        return check_result
+
+
+class RemoveZeroBiasFromConv(_RemoveZeroBiasBase):
+    """Remove zero bias from Conv operations."""
+
+    def __init__(self):
+        super().__init__("Conv")
+
+    def pattern(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value:
+        return op.Conv(x, w, b, _outputs=["conv_out"])
+
+    def check(self, context, x: ir.Value, w: ir.Value, b: ir.Value, conv_out: ir.Value, **_) -> MatchResult:
+        """Check if the bias is present and is all zeros."""
+        del context  # Unused
+        check_result = MatchResult()
+
+        # Check if bias is a constant/initializer
+        if b.const_value is None:
+            return check_result.fail("Bias is not a constant/initializer.")
+
+        # Check if bias is all zeros
+        bias_array = b.const_value.numpy()
+        if not np.allclose(bias_array, 0.0, atol=1e-8):
+            return check_result.fail("Bias is not all zeros.")
+
+        return check_result
+
+    def rewrite(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value, conv_out: ir.Value) -> ir.Value:
+        """Remove the bias input from the operation."""
+        # Get the Conv node that produced conv_out to access its attributes
+        conv_node = conv_out.producer()
+
+        # Create new Conv with preserved attributes but without bias
+        return op.op(
+            "Conv",
+            inputs=[x, w],  # Remove bias input
+            attributes=conv_node.attributes,
+            domain=conv_node.domain,
+        )
+
+
+class RemoveZeroBiasFromConvTranspose(_RemoveZeroBiasBase):
+    """Remove zero bias from ConvTranspose operations."""
+
+    def __init__(self):
+        super().__init__("ConvTranspose")
+
+    def pattern(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value:
+        return op.ConvTranspose(x, w, b, _allow_other_inputs=False, _outputs=["conv_out"])
+
+    def rewrite(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value, conv_out: ir.Value) -> ir.Value:
+        """Remove the bias input from the operation."""
+        # Get the ConvTranspose node that produced conv_out to access its attributes
+        conv_node = conv_out.producer()
+
+        # Create new ConvTranspose with preserved attributes but without bias
+        return op.op(
+            "ConvTranspose",
+            inputs=[x, w],  # Remove bias input
+            attributes=conv_node.attributes,
+            domain=conv_node.domain,
+        )
+
+
+class RemoveZeroBiasFromQLinearConv(_RemoveZeroBiasBase):
+    """Remove zero bias from QLinearConv operations."""
+
+    def __init__(self):
+        super().__init__("QLinearConv")
+
+    def pattern(self, op: ir.tape.Tape, x, x_scale, x_zero_point, w, w_scale, w_zero_point,
+                y_scale, y_zero_point, b: ir.Value) -> ir.Value:
+        return op.QLinearConv(
+            x, x_scale, x_zero_point, w, w_scale, w_zero_point,
+            y_scale, y_zero_point, b, _allow_other_inputs=False, _outputs=["conv_out"]
+        )
+
+    def check(self, context, x, x_scale, x_zero_point, w, w_scale, w_zero_point,
+              y_scale, y_zero_point, b: ir.Value, conv_out: ir.Value, **_) -> MatchResult:
+        """Check if the bias (b) is present and is all zeros."""
+        del context  # Unused
+        check_result = MatchResult()
+
+        # Check if bias is a constant/initializer
+        if b.const_value is None:
+            return check_result.fail("Bias is not a constant/initializer.")
+
+        # Check if bias is all zeros
+        bias_array = b.const_value.numpy()
+        if not np.allclose(bias_array, 0.0, atol=1e-8):
+            return check_result.fail("Bias is not all zeros.")
+
+        return check_result
+
+    def rewrite(self, op: ir.tape.Tape, x, x_scale, x_zero_point, w, w_scale, w_zero_point,
+                y_scale, y_zero_point, b: ir.Value, conv_out: ir.Value) -> ir.Value:
+        """Remove the bias input from the operation."""
+        # Get the QLinearConv node that produced conv_out to access its attributes
+        conv_node = conv_out.producer()
+
+        # Create new QLinearConv with preserved attributes but without bias
+        return op.op(
+            "QLinearConv",
+            inputs=[x, x_scale, x_zero_point, w, w_scale, w_zero_point,
+                    y_scale, y_zero_point],  # Remove bias input
+            attributes=conv_node.attributes,
+            domain=conv_node.domain,
+        )
+
+
+class RemoveZeroBiasFromGemm(_RemoveZeroBiasBase):
+    """Remove zero bias from Gemm operations."""
+
+    def __init__(self):
+        super().__init__("Gemm")
+
+    def pattern(self, op: ir.tape.Tape, a: ir.Value, b: ir.Value, c: ir.Value) -> ir.Value:
+        return op.Gemm(a, b, c, _allow_other_inputs=False, _outputs=["gemm_out"])
+
+    def check(self, context, a: ir.Value, b: ir.Value, c: ir.Value, gemm_out: ir.Value, **_) -> MatchResult:
+        """Check if the bias (c) is present and is all zeros."""
+        del context  # Unused
+        check_result = MatchResult()
+
+        # Check if bias is a constant/initializer
+        if c.const_value is None:
+            return check_result.fail("Bias is not a constant/initializer.")
+
+        # Check if bias is all zeros
+        bias_array = c.const_value.numpy()
+        if not np.allclose(bias_array, 0.0, atol=1e-8):
+            return check_result.fail("Bias is not all zeros.")
+
+        return check_result
+
+    def rewrite(self, op: ir.tape.Tape, a: ir.Value, b: ir.Value, c: ir.Value, gemm_out: ir.Value) -> ir.Value:
+        """Remove the bias input from the operation."""
+        # Get the Gemm node that produced gemm_out to access its attributes
+        gemm_node = gemm_out.producer()
+
+        # Create new Gemm with preserved attributes but without bias
+        return op.op(
+            "Gemm",
+            inputs=[a, b],  # Remove bias input
+            attributes=gemm_node.attributes,
+            domain=gemm_node.domain,
+        )
+
+
+# Create rule instances
+remove_zero_bias_from_conv_rule = RemoveZeroBiasFromConv().rule()
+remove_zero_bias_from_conv_transpose_rule = RemoveZeroBiasFromConvTranspose().rule()
+remove_zero_bias_from_qlinear_conv_rule = RemoveZeroBiasFromQLinearConv().rule()
+remove_zero_bias_from_gemm_rule = RemoveZeroBiasFromGemm().rule()
+
+rules = RewriteRuleSet([
+    remove_zero_bias_from_conv_rule,
+    remove_zero_bias_from_conv_transpose_rule,
+    remove_zero_bias_from_qlinear_conv_rule,
+    remove_zero_bias_from_gemm_rule,
+])
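Taken together, the rules above rewrite, for example, Conv(x, w, zero_bias) into Conv(x, w). A minimal sketch of driving the rule set end to end follows; the model text and tensor values are illustrative rather than taken from this patch, and it assumes only the `rules` object defined above plus the same parse/deserialize APIs the tests below use:

    import onnx.parser

    from onnxscript import ir
    from onnxscript.rewriter.rules.common import _remove_zero_bias

    # A 4x4 input convolved with a 3x3 kernel; the bias initializer is all zeros.
    model_proto = onnx.parser.parse_model(
        """
        <ir_version: 8, opset_import: ["" : 18]>
        agraph (float[1, 1, 4, 4] x) => (float[1, 1, 2, 2] y)
        {
            w = Constant <value = float[1, 1, 3, 3] {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}> ()
            b = Constant <value = float[1] {0.0}> ()
            y = Conv (x, w, b)
        }
        """
    )
    model = ir.serde.deserialize_model(model_proto)
    # Apply the whole rule set; per the tests below, the Conv should lose its bias input.
    count = _remove_zero_bias.rules.apply_to_model(model)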
diff --git a/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py b/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py
new file mode 100644
index 0000000000..4ae60d71b8
--- /dev/null
+++ b/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py
@@ -0,0 +1,87 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Tests for removing zero bias from Conv and related operations."""
+
+import onnx
+import onnx.parser
+import onnx_ir as ir
+
+from onnxscript.rewriter.rules.common._remove_zero_bias import (
+    remove_zero_bias_from_conv_rule,
+)
+
+
+def test_remove_zero_bias_from_conv():
+    """Test that zero bias is removed from Conv operations."""
+    # Create a simple Conv with zero bias using ONNX parser
+    model_proto = onnx.parser.parse_model(
+        """
+
+        agraph (float[1, 2, 4, 4] x) => (float[1, 2, 2, 2] y)
+        {
+            weight = Constant ()
+            bias = Constant ()
+            y = Conv(x, weight, bias)
+        }
+        """
+    )
+
+    # Convert to IR model
+    model = ir.serde.deserialize_model(model_proto)
+
+    # Apply the rule
+    count = remove_zero_bias_from_conv_rule.apply_to_model(model)
+
+    # Check that the rule was applied
+    assert count == 1, f"Expected 1 application, got {count}"
+
+    # Check that bias input was removed
+    conv_node = None
+    for node in model.graph:
+        if node.op_type == "Conv":
+            conv_node = node
+            break
+
+    assert conv_node is not None, "Conv node not found"
+    assert len(conv_node.inputs) == 2, f"Expected 2 inputs after optimization, got {len(conv_node.inputs)}"
+
+
+def test_conv_with_non_zero_bias_unchanged():
+    """Test that Conv with non-zero bias is not modified."""
+    # Create a Conv with non-zero bias using ONNX parser
+    model_proto = onnx.parser.parse_model(
+        """
+
+        agraph (float[1, 2, 4, 4] x) => (float[1, 2, 2, 2] y)
+        {
+            weight = Constant ()
+            bias = Constant ()
+            y = Conv(x, weight, bias)
+        }
+        """
+    )
+
+    # Convert to IR model
+    model = ir.serde.deserialize_model(model_proto)
+
+    # Apply the rule
+    count = remove_zero_bias_from_conv_rule.apply_to_model(model)
+
+    # Check that the rule was NOT applied
+    assert count == 0, f"Expected 0 applications, got {count}"
+
+    # Check that bias input is still present
+    conv_node = None
+    for node in model.graph:
+        if node.op_type == "Conv":
+            conv_node = node
+            break
+
+    assert conv_node is not None, "Conv node not found"
+    assert len(conv_node.inputs) == 3, f"Expected 3 inputs, got {len(conv_node.inputs)}"
+
+
+if __name__ == "__main__":
+    test_remove_zero_bias_from_conv()
+    test_conv_with_non_zero_bias_unchanged()
+    print("All tests passed!")

diff --git a/tools/ir/model_zoo_test/model_zoo_test.py b/tools/ir/model_zoo_test/model_zoo_test.py
index 82d7a54026..c80068919f 100644
--- a/tools/ir/model_zoo_test/model_zoo_test.py
+++ b/tools/ir/model_zoo_test/model_zoo_test.py
@@ -26,7 +26,7 @@
 from onnxscript import ir


-def test_model(model_info: hub.ModelInfo) -> float:
+def validate_model(model_info: hub.ModelInfo) -> float:
     model_name = model_info.model
     with tempfile.TemporaryDirectory() as temp_dir, contextlib.redirect_stdout(None):
         # For parallel testing, this must be in a separate process because hub.set_dir
@@ -58,7 +58,7 @@ def run_one_test(model_info: hub.ModelInfo) -> tuple[str, str | None]:
     model_path = model_info.model_path
     message = f"\n----Testing: {model_name} @ {model_path}----"
     try:
-        time_passed = test_model(model_info)
+        time_passed = validate_model(model_info)
         message += green(f"\n[PASS]: {model_name} roundtrip test passed.")
     except Exception as e:  # pylint: disable=broad-exception-caught
         time_passed = -1
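A note on the check used throughout this series: the zero test is tolerance-based, `np.allclose(bias_array, 0.0, atol=1e-8)`, so biases that are merely numerically negligible are also treated as removable. A quick illustration in plain NumPy (independent of the patch itself):

    import numpy as np

    # Exact zeros and sub-tolerance values both count as "zero bias" here.
    np.allclose(np.array([0.0, 0.0]), 0.0, atol=1e-8)   # True  -> bias removed
    np.allclose(np.array([0.0, 1e-9]), 0.0, atol=1e-8)  # True  -> bias removed
    np.allclose(np.array([0.0, 1e-3]), 0.0, atol=1e-8)  # False -> node left unchanged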
time_passed = test_model(model_info) + time_passed = validate_model(model_info) message += green(f"\n[PASS]: {model_name} roundtrip test passed.") except Exception as e: # pylint: disable=broad-exception-caught time_passed = -1 From 3742e7bad89216a67a2c23d2c79faf67af1c83ff Mon Sep 17 00:00:00 2001 From: Vineet Kumar Date: Wed, 10 Sep 2025 23:27:20 +0530 Subject: [PATCH 02/13] [Rewriter] Enhance zero bias removal for Conv, ConvTranspose, Gemm, and QLinearConv operations with additional tests --- .../rules/common/_remove_zero_bias.py | 90 ++----- .../rules/common/_remove_zero_bias_test.py | 231 ++++++++++++++++++ 2 files changed, 255 insertions(+), 66 deletions(-) diff --git a/onnxscript/rewriter/rules/common/_remove_zero_bias.py b/onnxscript/rewriter/rules/common/_remove_zero_bias.py index 1480fcea1f..15ddf46340 100644 --- a/onnxscript/rewriter/rules/common/_remove_zero_bias.py +++ b/onnxscript/rewriter/rules/common/_remove_zero_bias.py @@ -4,9 +4,12 @@ from __future__ import annotations +from typing import ClassVar + import numpy as np from onnxscript import ir +from onnxscript.ir import convenience from onnxscript.rewriter._basics import MatchResult from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet @@ -14,9 +17,7 @@ class _RemoveZeroBiasBase(RewriteRuleClassBase): """Base class for removing zero bias from operations.""" - def __init__(self, op_type: str): - super().__init__(remove_nodes=False) - self.op_type = op_type + op_type: ClassVar def rewrite(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value: """Remove the bias input from the operation.""" @@ -25,48 +26,36 @@ def rewrite(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir inputs=[x, w], # Remove bias input ) - def check(self, context, x: ir.Value, w: ir.Value, b: ir.Value, **_) -> MatchResult: - """Check if the bias is present and is all zeros.""" - del context # Unused + def _check_bias_is_zero(self, bias_value: ir.Value) -> MatchResult: + """Check if the bias value is present and is all zeros.""" check_result = MatchResult() # Check if bias is a constant/initializer - if b.const_value is None: + bias_tensor = convenience.get_const_tensor(bias_value) + if bias_tensor is None: return check_result.fail("Bias is not a constant/initializer.") # Check if bias is all zeros - bias_array = b.const_value.numpy() + bias_array = bias_tensor.numpy() if not np.allclose(bias_array, 0.0, atol=1e-8): return check_result.fail("Bias is not all zeros.") return check_result + def check(self, context, x: ir.Value, w: ir.Value, b: ir.Value, **_) -> MatchResult: + """Check if the bias is present and is all zeros.""" + del context # Unused + return self._check_bias_is_zero(b) + class RemoveZeroBiasFromConv(_RemoveZeroBiasBase): """Remove zero bias from Conv operations.""" - def __init__(self): - super().__init__("Conv") + op_type: ClassVar = "Conv" def pattern(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value: return op.Conv(x, w, b, _outputs=["conv_out"]) - def check(self, context, x: ir.Value, w: ir.Value, b: ir.Value, conv_out: ir.Value, **_) -> MatchResult: - """Check if the bias is present and is all zeros.""" - del context # Unused - check_result = MatchResult() - - # Check if bias is a constant/initializer - if b.const_value is None: - return check_result.fail("Bias is not a constant/initializer.") - - # Check if bias is all zeros - bias_array = b.const_value.numpy() - if not np.allclose(bias_array, 0.0, atol=1e-8): - return check_result.fail("Bias 
is not all zeros.") - - return check_result - def rewrite(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value, conv_out: ir.Value) -> ir.Value: """Remove the bias input from the operation.""" # Get the Conv node that produced conv_out to access its attributes @@ -84,11 +73,10 @@ def rewrite(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value, conv_ class RemoveZeroBiasFromConvTranspose(_RemoveZeroBiasBase): """Remove zero bias from ConvTranspose operations.""" - def __init__(self): - super().__init__("ConvTranspose") + op_type: ClassVar = "ConvTranspose" def pattern(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value: - return op.ConvTranspose(x, w, b, _allow_other_inputs=False, _outputs=["conv_out"]) + return op.ConvTranspose(x, w, b, _outputs=["conv_out"]) def rewrite(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value, conv_out: ir.Value) -> ir.Value: """Remove the bias input from the operation.""" @@ -107,33 +95,15 @@ def rewrite(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value, conv_ class RemoveZeroBiasFromQLinearConv(_RemoveZeroBiasBase): """Remove zero bias from QLinearConv operations.""" - def __init__(self): - super().__init__("QLinearConv") + op_type: ClassVar = "QLinearConv" def pattern(self, op: ir.tape.Tape, x, x_scale, x_zero_point, w, w_scale, w_zero_point, y_scale, y_zero_point, b: ir.Value) -> ir.Value: return op.QLinearConv( x, x_scale, x_zero_point, w, w_scale, w_zero_point, - y_scale, y_zero_point, b, _allow_other_inputs=False, _outputs=["conv_out"] + y_scale, y_zero_point, b, _outputs=["conv_out"] ) - def check(self, context, x, x_scale, x_zero_point, w, w_scale, w_zero_point, - y_scale, y_zero_point, b: ir.Value, conv_out: ir.Value, **_) -> MatchResult: - """Check if the bias (b) is present and is all zeros.""" - del context # Unused - check_result = MatchResult() - - # Check if bias is a constant/initializer - if b.const_value is None: - return check_result.fail("Bias is not a constant/initializer.") - - # Check if bias is all zeros - bias_array = b.const_value.numpy() - if not np.allclose(bias_array, 0.0, atol=1e-8): - return check_result.fail("Bias is not all zeros.") - - return check_result - def rewrite(self, op: ir.tape.Tape, x, x_scale, x_zero_point, w, w_scale, w_zero_point, y_scale, y_zero_point, b: ir.Value, conv_out: ir.Value) -> ir.Value: """Remove the bias input from the operation.""" @@ -153,27 +123,15 @@ def rewrite(self, op: ir.tape.Tape, x, x_scale, x_zero_point, w, w_scale, w_zero class RemoveZeroBiasFromGemm(_RemoveZeroBiasBase): """Remove zero bias from Gemm operations.""" - def __init__(self): - super().__init__("Gemm") + op_type: ClassVar = "Gemm" def pattern(self, op: ir.tape.Tape, a: ir.Value, b: ir.Value, c: ir.Value) -> ir.Value: - return op.Gemm(a, b, c, _allow_other_inputs=False, _outputs=["gemm_out"]) + return op.Gemm(a, b, c, _outputs=["gemm_out"]) - def check(self, context, a: ir.Value, b: ir.Value, c: ir.Value, gemm_out: ir.Value, **_) -> MatchResult: - """Check if the bias (c) is present and is all zeros.""" + def check(self, context, a: ir.Value, b: ir.Value, c: ir.Value, **_) -> MatchResult: + """Check if the bias (c parameter) is present and is all zeros.""" del context # Unused - check_result = MatchResult() - - # Check if bias is a constant/initializer - if c.const_value is None: - return check_result.fail("Bias is not a constant/initializer.") - - # Check if bias is all zeros - bias_array = c.const_value.numpy() - if not np.allclose(bias_array, 0.0, atol=1e-8): - 
return check_result.fail("Bias is not all zeros.") - - return check_result + return self._check_bias_is_zero(c) def rewrite(self, op: ir.tape.Tape, a: ir.Value, b: ir.Value, c: ir.Value, gemm_out: ir.Value) -> ir.Value: """Remove the bias input from the operation.""" diff --git a/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py b/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py index 4ae60d71b8..7983f110e8 100644 --- a/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py +++ b/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py @@ -8,6 +8,9 @@ from onnxscript.rewriter.rules.common._remove_zero_bias import ( remove_zero_bias_from_conv_rule, + remove_zero_bias_from_conv_transpose_rule, + remove_zero_bias_from_gemm_rule, + remove_zero_bias_from_qlinear_conv_rule, ) @@ -81,7 +84,235 @@ def test_conv_with_non_zero_bias_unchanged(): assert len(conv_node.inputs) == 3, f"Expected 3 inputs, got {len(conv_node.inputs)}" +def test_remove_zero_bias_from_conv_transpose(): + """Test that zero bias is removed from ConvTranspose operations.""" + # Create a ConvTranspose with zero bias using ONNX parser + model_proto = onnx.parser.parse_model( + """ + + agraph (float[1, 2, 2, 2] x) => (float[1, 2, 4, 4] y) + { + weight = Constant () + bias = Constant () + y = ConvTranspose(x, weight, bias) + } + """ + ) + + # Convert to IR model + model = ir.serde.deserialize_model(model_proto) + + # Apply the rule + count = remove_zero_bias_from_conv_transpose_rule.apply_to_model(model) + + # Check that the rule was applied + assert count == 1, f"Expected 1 application, got {count}" + + # Check that bias input was removed + conv_node = None + for node in model.graph: + if node.op_type == "ConvTranspose": + conv_node = node + break + + assert conv_node is not None, "ConvTranspose node not found" + assert len(conv_node.inputs) == 2, f"Expected 2 inputs after optimization, got {len(conv_node.inputs)}" + + +def test_conv_transpose_with_non_zero_bias_unchanged(): + """Test that ConvTranspose with non-zero bias is not modified.""" + # Create a ConvTranspose with non-zero bias using ONNX parser + model_proto = onnx.parser.parse_model( + """ + + agraph (float[1, 2, 2, 2] x) => (float[1, 2, 4, 4] y) + { + weight = Constant () + bias = Constant () + y = ConvTranspose(x, weight, bias) + } + """ + ) + + # Convert to IR model + model = ir.serde.deserialize_model(model_proto) + + # Apply the rule + count = remove_zero_bias_from_conv_transpose_rule.apply_to_model(model) + + # Check that the rule was NOT applied + assert count == 0, f"Expected 0 applications, got {count}" + + # Check that bias input is still present + conv_node = None + for node in model.graph: + if node.op_type == "ConvTranspose": + conv_node = node + break + + assert conv_node is not None, "ConvTranspose node not found" + assert len(conv_node.inputs) == 3, f"Expected 3 inputs, got {len(conv_node.inputs)}" + + +def test_remove_zero_bias_from_gemm(): + """Test that zero bias is removed from Gemm operations.""" + # Create a Gemm with zero bias using ONNX parser + model_proto = onnx.parser.parse_model( + """ + + agraph (float[2, 3] a) => (float[2, 4] y) + { + b = Constant () + c = Constant () + y = Gemm(a, b, c) + } + """ + ) + + # Convert to IR model + model = ir.serde.deserialize_model(model_proto) + + # Apply the rule + count = remove_zero_bias_from_gemm_rule.apply_to_model(model) + + # Check that the rule was applied + assert count == 1, f"Expected 1 application, got {count}" + + # Check that bias input was removed + gemm_node = None + 
for node in model.graph: + if node.op_type == "Gemm": + gemm_node = node + break + + assert gemm_node is not None, "Gemm node not found" + assert len(gemm_node.inputs) == 2, f"Expected 2 inputs after optimization, got {len(gemm_node.inputs)}" + + +def test_gemm_with_non_zero_bias_unchanged(): + """Test that Gemm with non-zero bias is not modified.""" + # Create a Gemm with non-zero bias using ONNX parser + model_proto = onnx.parser.parse_model( + """ + + agraph (float[2, 3] a) => (float[2, 4] y) + { + b = Constant () + c = Constant () + y = Gemm(a, b, c) + } + """ + ) + + # Convert to IR model + model = ir.serde.deserialize_model(model_proto) + + # Apply the rule + count = remove_zero_bias_from_gemm_rule.apply_to_model(model) + + # Check that the rule was NOT applied + assert count == 0, f"Expected 0 applications, got {count}" + + # Check that bias input is still present + gemm_node = None + for node in model.graph: + if node.op_type == "Gemm": + gemm_node = node + break + + assert gemm_node is not None, "Gemm node not found" + assert len(gemm_node.inputs) == 3, f"Expected 3 inputs, got {len(gemm_node.inputs)}" + + +def test_remove_zero_bias_from_qlinear_conv(): + """Test that zero bias is removed from QLinearConv operations.""" + # Create a QLinearConv with zero bias using ONNX parser + model_proto = onnx.parser.parse_model( + """ + + agraph (uint8[1, 2, 4, 4] x) => (uint8[1, 2, 2, 2] y) + { + x_scale = Constant () + x_zero_point = Constant () + weight = Constant () + w_scale = Constant () + w_zero_point = Constant () + y_scale = Constant () + y_zero_point = Constant () + bias = Constant () + y = QLinearConv(x, x_scale, x_zero_point, weight, w_scale, w_zero_point, y_scale, y_zero_point, bias) + } + """ + ) + + # Convert to IR model + model = ir.serde.deserialize_model(model_proto) + + # Apply the rule + count = remove_zero_bias_from_qlinear_conv_rule.apply_to_model(model) + + # Check that the rule was applied + assert count == 1, f"Expected 1 application, got {count}" + + # Check that bias input was removed + qconv_node = None + for node in model.graph: + if node.op_type == "QLinearConv": + qconv_node = node + break + + assert qconv_node is not None, "QLinearConv node not found" + assert len(qconv_node.inputs) == 8, f"Expected 8 inputs after optimization, got {len(qconv_node.inputs)}" + + +def test_qlinear_conv_with_non_zero_bias_unchanged(): + """Test that QLinearConv with non-zero bias is not modified.""" + # Create a QLinearConv with non-zero bias using ONNX parser + model_proto = onnx.parser.parse_model( + """ + + agraph (uint8[1, 2, 4, 4] x) => (uint8[1, 2, 2, 2] y) + { + x_scale = Constant () + x_zero_point = Constant () + weight = Constant () + w_scale = Constant () + w_zero_point = Constant () + y_scale = Constant () + y_zero_point = Constant () + bias = Constant () + y = QLinearConv(x, x_scale, x_zero_point, weight, w_scale, w_zero_point, y_scale, y_zero_point, bias) + } + """ + ) + + # Convert to IR model + model = ir.serde.deserialize_model(model_proto) + + # Apply the rule + count = remove_zero_bias_from_qlinear_conv_rule.apply_to_model(model) + + # Check that the rule was NOT applied + assert count == 0, f"Expected 0 applications, got {count}" + + # Check that bias input is still present + qconv_node = None + for node in model.graph: + if node.op_type == "QLinearConv": + qconv_node = node + break + + assert qconv_node is not None, "QLinearConv node not found" + assert len(qconv_node.inputs) == 9, f"Expected 9 inputs, got {len(qconv_node.inputs)}" + + if __name__ == 
"__main__": test_remove_zero_bias_from_conv() test_conv_with_non_zero_bias_unchanged() + test_remove_zero_bias_from_conv_transpose() + test_conv_transpose_with_non_zero_bias_unchanged() + test_remove_zero_bias_from_gemm() + test_gemm_with_non_zero_bias_unchanged() + test_remove_zero_bias_from_qlinear_conv() + test_qlinear_conv_with_non_zero_bias_unchanged() print("All tests passed!") From 8bfa65f94b52718b72f2a9d354f253406d4c0554 Mon Sep 17 00:00:00 2001 From: Vineet Kumar Date: Thu, 11 Sep 2025 03:24:52 +0530 Subject: [PATCH 03/13] Refactor zero bias removal tests to use helper function and improve structure --- .../rules/common/_remove_zero_bias.py | 2 - .../rules/common/_remove_zero_bias_test.py | 499 ++++++++---------- 2 files changed, 207 insertions(+), 294 deletions(-) diff --git a/onnxscript/rewriter/rules/common/_remove_zero_bias.py b/onnxscript/rewriter/rules/common/_remove_zero_bias.py index 15ddf46340..1e17bbb9a5 100644 --- a/onnxscript/rewriter/rules/common/_remove_zero_bias.py +++ b/onnxscript/rewriter/rules/common/_remove_zero_bias.py @@ -17,8 +17,6 @@ class _RemoveZeroBiasBase(RewriteRuleClassBase): """Base class for removing zero bias from operations.""" - op_type: ClassVar - def rewrite(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value: """Remove the bias input from the operation.""" return op.op( diff --git a/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py b/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py index 7983f110e8..984cb9cb8b 100644 --- a/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py +++ b/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py @@ -2,10 +2,11 @@ # Licensed under the MIT License. """Tests for removing zero bias from Conv and related operations.""" -import onnx -import onnx.parser +import unittest + import onnx_ir as ir +from onnxscript.rewriter import testing from onnxscript.rewriter.rules.common._remove_zero_bias import ( remove_zero_bias_from_conv_rule, remove_zero_bias_from_conv_transpose_rule, @@ -14,305 +15,219 @@ ) -def test_remove_zero_bias_from_conv(): - """Test that zero bias is removed from Conv operations.""" - # Create a simple Conv with zero bias using ONNX parser - model_proto = onnx.parser.parse_model( - """ - - agraph (float[1, 2, 4, 4] x) => (float[1, 2, 2, 2] y) - { - weight = Constant () - bias = Constant () - y = Conv(x, weight, bias) - } - """ - ) - - # Convert to IR model - model = ir.serde.deserialize_model(model_proto) - - # Apply the rule - count = remove_zero_bias_from_conv_rule.apply_to_model(model) - - # Check that the rule was applied - assert count == 1, f"Expected 1 application, got {count}" - - # Check that bias input was removed - conv_node = None - for node in model.graph: - if node.op_type == "Conv": - conv_node = node - break - - assert conv_node is not None, "Conv node not found" - assert len(conv_node.inputs) == 2, f"Expected 2 inputs after optimization, got {len(conv_node.inputs)}" - - -def test_conv_with_non_zero_bias_unchanged(): - """Test that Conv with non-zero bias is not modified.""" - # Create a Conv with non-zero bias using ONNX parser - model_proto = onnx.parser.parse_model( - """ - - agraph (float[1, 2, 4, 4] x) => (float[1, 2, 2, 2] y) - { - weight = Constant () - bias = Constant () - y = Conv(x, weight, bias) - } - """ - ) - - # Convert to IR model - model = ir.serde.deserialize_model(model_proto) - - # Apply the rule - count = remove_zero_bias_from_conv_rule.apply_to_model(model) - - # Check that the rule was NOT applied - assert count 
== 0, f"Expected 0 applications, got {count}" - - # Check that bias input is still present - conv_node = None - for node in model.graph: - if node.op_type == "Conv": - conv_node = node - break - - assert conv_node is not None, "Conv node not found" - assert len(conv_node.inputs) == 3, f"Expected 3 inputs, got {len(conv_node.inputs)}" - - -def test_remove_zero_bias_from_conv_transpose(): - """Test that zero bias is removed from ConvTranspose operations.""" - # Create a ConvTranspose with zero bias using ONNX parser - model_proto = onnx.parser.parse_model( - """ - - agraph (float[1, 2, 2, 2] x) => (float[1, 2, 4, 4] y) - { - weight = Constant () - bias = Constant () - y = ConvTranspose(x, weight, bias) - } - """ - ) - - # Convert to IR model - model = ir.serde.deserialize_model(model_proto) - - # Apply the rule - count = remove_zero_bias_from_conv_transpose_rule.apply_to_model(model) - - # Check that the rule was applied - assert count == 1, f"Expected 1 application, got {count}" - - # Check that bias input was removed - conv_node = None - for node in model.graph: - if node.op_type == "ConvTranspose": - conv_node = node - break - - assert conv_node is not None, "ConvTranspose node not found" - assert len(conv_node.inputs) == 2, f"Expected 2 inputs after optimization, got {len(conv_node.inputs)}" - - -def test_conv_transpose_with_non_zero_bias_unchanged(): - """Test that ConvTranspose with non-zero bias is not modified.""" - # Create a ConvTranspose with non-zero bias using ONNX parser - model_proto = onnx.parser.parse_model( - """ - - agraph (float[1, 2, 2, 2] x) => (float[1, 2, 4, 4] y) - { - weight = Constant () - bias = Constant () - y = ConvTranspose(x, weight, bias) - } - """ - ) - - # Convert to IR model - model = ir.serde.deserialize_model(model_proto) - - # Apply the rule - count = remove_zero_bias_from_conv_transpose_rule.apply_to_model(model) - - # Check that the rule was NOT applied - assert count == 0, f"Expected 0 applications, got {count}" - - # Check that bias input is still present - conv_node = None - for node in model.graph: - if node.op_type == "ConvTranspose": - conv_node = node - break - - assert conv_node is not None, "ConvTranspose node not found" - assert len(conv_node.inputs) == 3, f"Expected 3 inputs, got {len(conv_node.inputs)}" - - -def test_remove_zero_bias_from_gemm(): - """Test that zero bias is removed from Gemm operations.""" - # Create a Gemm with zero bias using ONNX parser - model_proto = onnx.parser.parse_model( - """ - - agraph (float[2, 3] a) => (float[2, 4] y) - { - b = Constant () - c = Constant () - y = Gemm(a, b, c) - } - """ - ) - - # Convert to IR model - model = ir.serde.deserialize_model(model_proto) +def _apply_rule_and_check_optimization( + model: ir.Model, + rule, + expected_count: int, + target_op_type: str, + expected_inputs_after: int, +) -> None: + """Helper function to test bias removal rules.""" + # Make a copy for comparison + original_model = ir.from_proto(ir.to_proto(model)) # Apply the rule - count = remove_zero_bias_from_gemm_rule.apply_to_model(model) + count = rule.apply_to_model(model) - # Check that the rule was applied - assert count == 1, f"Expected 1 application, got {count}" + # Check that the rule was applied the expected number of times + assert count == expected_count, f"Expected {expected_count} applications, got {count}" - # Check that bias input was removed - gemm_node = None + # Check that the target node has the expected number of inputs + target_node = None for node in model.graph: - if node.op_type == "Gemm": - 
gemm_node = node + if node.op_type == target_op_type: + target_node = node break - assert gemm_node is not None, "Gemm node not found" - assert len(gemm_node.inputs) == 2, f"Expected 2 inputs after optimization, got {len(gemm_node.inputs)}" - - -def test_gemm_with_non_zero_bias_unchanged(): - """Test that Gemm with non-zero bias is not modified.""" - # Create a Gemm with non-zero bias using ONNX parser - model_proto = onnx.parser.parse_model( - """ - - agraph (float[2, 3] a) => (float[2, 4] y) - { - b = Constant () - c = Constant () - y = Gemm(a, b, c) - } - """ + assert target_node is not None, f"{target_op_type} node not found" + assert len(target_node.inputs) == expected_inputs_after, ( + f"Expected {expected_inputs_after} inputs after optimization, " + f"got {len(target_node.inputs)}" ) - # Convert to IR model - model = ir.serde.deserialize_model(model_proto) - - # Apply the rule - count = remove_zero_bias_from_gemm_rule.apply_to_model(model) - - # Check that the rule was NOT applied - assert count == 0, f"Expected 0 applications, got {count}" - - # Check that bias input is still present - gemm_node = None - for node in model.graph: - if node.op_type == "Gemm": - gemm_node = node - break - - assert gemm_node is not None, "Gemm node not found" - assert len(gemm_node.inputs) == 3, f"Expected 3 inputs, got {len(gemm_node.inputs)}" - - -def test_remove_zero_bias_from_qlinear_conv(): - """Test that zero bias is removed from QLinearConv operations.""" - # Create a QLinearConv with zero bias using ONNX parser - model_proto = onnx.parser.parse_model( - """ - - agraph (uint8[1, 2, 4, 4] x) => (uint8[1, 2, 2, 2] y) - { - x_scale = Constant () - x_zero_point = Constant () - weight = Constant () - w_scale = Constant () - w_zero_point = Constant () - y_scale = Constant () - y_zero_point = Constant () - bias = Constant () - y = QLinearConv(x, x_scale, x_zero_point, weight, w_scale, w_zero_point, y_scale, y_zero_point, bias) - } - """ - ) - - # Convert to IR model - model = ir.serde.deserialize_model(model_proto) - - # Apply the rule - count = remove_zero_bias_from_qlinear_conv_rule.apply_to_model(model) - - # Check that the rule was applied - assert count == 1, f"Expected 1 application, got {count}" - - # Check that bias input was removed - qconv_node = None - for node in model.graph: - if node.op_type == "QLinearConv": - qconv_node = node - break - - assert qconv_node is not None, "QLinearConv node not found" - assert len(qconv_node.inputs) == 8, f"Expected 8 inputs after optimization, got {len(qconv_node.inputs)}" - - -def test_qlinear_conv_with_non_zero_bias_unchanged(): - """Test that QLinearConv with non-zero bias is not modified.""" - # Create a QLinearConv with non-zero bias using ONNX parser - model_proto = onnx.parser.parse_model( - """ - - agraph (uint8[1, 2, 4, 4] x) => (uint8[1, 2, 2, 2] y) - { - x_scale = Constant () - x_zero_point = Constant () - weight = Constant () - w_scale = Constant () - w_zero_point = Constant () - y_scale = Constant () - y_zero_point = Constant () - bias = Constant () - y = QLinearConv(x, x_scale, x_zero_point, weight, w_scale, w_zero_point, y_scale, y_zero_point, bias) - } - """ - ) - - # Convert to IR model - model = ir.serde.deserialize_model(model_proto) - - # Apply the rule - count = remove_zero_bias_from_qlinear_conv_rule.apply_to_model(model) - - # Check that the rule was NOT applied - assert count == 0, f"Expected 0 applications, got {count}" - - # Check that bias input is still present - qconv_node = None - for node in model.graph: - if node.op_type == 
"QLinearConv": - qconv_node = node - break - - assert qconv_node is not None, "QLinearConv node not found" - assert len(qconv_node.inputs) == 9, f"Expected 9 inputs, got {len(qconv_node.inputs)}" + # Compare outputs to ensure correctness (only for supported input types) + if expected_count > 0: + try: + # Generate random inputs for the model using the existing testing utility + original_model_proto = ir.to_proto(original_model) + inputs = testing.generate_random_inputs(original_model_proto) + testing.assert_numerically_equal(original_model, model, inputs) + except ValueError as e: + if "Not implemented for input type" in str(e): + # Skip numerical comparison for unsupported input types + # The structural checks above are sufficient for these cases + pass + else: + raise + + +class RemoveZeroBiasTest(unittest.TestCase): + """Test class for remove zero bias rules.""" + + def test_remove_zero_bias_from_conv(self): + """Test that zero bias is removed from Conv operations.""" + # Create a simple Conv with zero bias using ONNX text format + model = ir.from_onnx_text( + """ + + agraph (float[1, 2, 4, 4] x) => (float[1, 2, 2, 2] y) + { + weight = Constant () + bias = Constant () + y = Conv(x, weight, bias) + } + """ + ) + + _apply_rule_and_check_optimization( + model, remove_zero_bias_from_conv_rule, expected_count=1, target_op_type="Conv", expected_inputs_after=2 + ) + + def test_conv_with_non_zero_bias_unchanged(self): + """Test that Conv with non-zero bias is not modified.""" + # Create a Conv with non-zero bias using ONNX text format + model = ir.from_onnx_text( + """ + + agraph (float[1, 2, 4, 4] x) => (float[1, 2, 2, 2] y) + { + weight = Constant () + bias = Constant () + y = Conv(x, weight, bias) + } + """ + ) + + _apply_rule_and_check_optimization( + model, remove_zero_bias_from_conv_rule, expected_count=0, target_op_type="Conv", expected_inputs_after=3 + ) + + def test_remove_zero_bias_from_conv_transpose(self): + """Test that zero bias is removed from ConvTranspose operations.""" + # Create a ConvTranspose with zero bias using ONNX text format + model = ir.from_onnx_text( + """ + + agraph (float[1, 2, 2, 2] x) => (float[1, 2, 4, 4] y) + { + weight = Constant () + bias = Constant () + y = ConvTranspose(x, weight, bias) + } + """ + ) + + _apply_rule_and_check_optimization( + model, remove_zero_bias_from_conv_transpose_rule, expected_count=1, target_op_type="ConvTranspose", expected_inputs_after=2 + ) + + def test_conv_transpose_with_non_zero_bias_unchanged(self): + """Test that ConvTranspose with non-zero bias is not modified.""" + # Create a ConvTranspose with non-zero bias using ONNX text format + model = ir.from_onnx_text( + """ + + agraph (float[1, 2, 2, 2] x) => (float[1, 2, 4, 4] y) + { + weight = Constant () + bias = Constant () + y = ConvTranspose(x, weight, bias) + } + """ + ) + + _apply_rule_and_check_optimization( + model, remove_zero_bias_from_conv_transpose_rule, expected_count=0, target_op_type="ConvTranspose", expected_inputs_after=3 + ) + + def test_remove_zero_bias_from_gemm(self): + """Test that zero bias is removed from Gemm operations.""" + # Create a Gemm with zero bias using ONNX text format + model = ir.from_onnx_text( + """ + + agraph (float[2, 3] a) => (float[2, 4] y) + { + b = Constant () + c = Constant () + y = Gemm(a, b, c) + } + """ + ) + + _apply_rule_and_check_optimization( + model, remove_zero_bias_from_gemm_rule, expected_count=1, target_op_type="Gemm", expected_inputs_after=2 + ) + + def test_gemm_with_non_zero_bias_unchanged(self): + """Test that Gemm 
with non-zero bias is not modified.""" + # Create a Gemm with non-zero bias using ONNX text format + model = ir.from_onnx_text( + """ + + agraph (float[2, 3] a) => (float[2, 4] y) + { + b = Constant () + c = Constant () + y = Gemm(a, b, c) + } + """ + ) + + _apply_rule_and_check_optimization( + model, remove_zero_bias_from_gemm_rule, expected_count=0, target_op_type="Gemm", expected_inputs_after=3 + ) + + def test_remove_zero_bias_from_qlinear_conv(self): + """Test that zero bias is removed from QLinearConv operations.""" + # Create a QLinearConv with zero bias using ONNX text format + model = ir.from_onnx_text( + """ + + agraph (uint8[1, 2, 4, 4] x) => (uint8[1, 2, 2, 2] y) + { + x_scale = Constant () + x_zero_point = Constant () + weight = Constant () + w_scale = Constant () + w_zero_point = Constant () + y_scale = Constant () + y_zero_point = Constant () + bias = Constant () + y = QLinearConv(x, x_scale, x_zero_point, weight, w_scale, w_zero_point, y_scale, y_zero_point, bias) + } + """ + ) + + _apply_rule_and_check_optimization( + model, remove_zero_bias_from_qlinear_conv_rule, expected_count=1, target_op_type="QLinearConv", expected_inputs_after=8 + ) + + def test_qlinear_conv_with_non_zero_bias_unchanged(self): + """Test that QLinearConv with non-zero bias is not modified.""" + # Create a QLinearConv with non-zero bias using ONNX text format + model = ir.from_onnx_text( + """ + + agraph (uint8[1, 2, 4, 4] x) => (uint8[1, 2, 2, 2] y) + { + x_scale = Constant () + x_zero_point = Constant () + weight = Constant () + w_scale = Constant () + w_zero_point = Constant () + y_scale = Constant () + y_zero_point = Constant () + bias = Constant () + y = QLinearConv(x, x_scale, x_zero_point, weight, w_scale, w_zero_point, y_scale, y_zero_point, bias) + } + """ + ) + + _apply_rule_and_check_optimization( + model, remove_zero_bias_from_qlinear_conv_rule, expected_count=0, target_op_type="QLinearConv", expected_inputs_after=9 + ) if __name__ == "__main__": - test_remove_zero_bias_from_conv() - test_conv_with_non_zero_bias_unchanged() - test_remove_zero_bias_from_conv_transpose() - test_conv_transpose_with_non_zero_bias_unchanged() - test_remove_zero_bias_from_gemm() - test_gemm_with_non_zero_bias_unchanged() - test_remove_zero_bias_from_qlinear_conv() - test_qlinear_conv_with_non_zero_bias_unchanged() - print("All tests passed!") + unittest.main() From b93a56c35bf5c1bfae710ccc2d3b316722e0497d Mon Sep 17 00:00:00 2001 From: Vineet Kumar Date: Thu, 11 Sep 2025 23:25:48 +0530 Subject: [PATCH 04/13] Refactor test cases for zero bias removal to improve readability and maintainability --- .../rules/common/_remove_zero_bias_test.py | 48 +++++++++++++++---- 1 file changed, 40 insertions(+), 8 deletions(-) diff --git a/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py b/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py index 984cb9cb8b..38b28085e6 100644 --- a/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py +++ b/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py @@ -80,7 +80,11 @@ def test_remove_zero_bias_from_conv(self): ) _apply_rule_and_check_optimization( - model, remove_zero_bias_from_conv_rule, expected_count=1, target_op_type="Conv", expected_inputs_after=2 + model, + remove_zero_bias_from_conv_rule, + expected_count=1, + target_op_type="Conv", + expected_inputs_after=2, ) def test_conv_with_non_zero_bias_unchanged(self): @@ -99,7 +103,11 @@ def test_conv_with_non_zero_bias_unchanged(self): ) _apply_rule_and_check_optimization( - model, 
remove_zero_bias_from_conv_rule, expected_count=0, target_op_type="Conv", expected_inputs_after=3 + model, + remove_zero_bias_from_conv_rule, + expected_count=0, + target_op_type="Conv", + expected_inputs_after=3, ) def test_remove_zero_bias_from_conv_transpose(self): @@ -118,7 +126,11 @@ def test_remove_zero_bias_from_conv_transpose(self): ) _apply_rule_and_check_optimization( - model, remove_zero_bias_from_conv_transpose_rule, expected_count=1, target_op_type="ConvTranspose", expected_inputs_after=2 + model, + remove_zero_bias_from_conv_transpose_rule, + expected_count=1, + target_op_type="ConvTranspose", + expected_inputs_after=2, ) def test_conv_transpose_with_non_zero_bias_unchanged(self): @@ -137,7 +149,11 @@ def test_conv_transpose_with_non_zero_bias_unchanged(self): ) _apply_rule_and_check_optimization( - model, remove_zero_bias_from_conv_transpose_rule, expected_count=0, target_op_type="ConvTranspose", expected_inputs_after=3 + model, + remove_zero_bias_from_conv_transpose_rule, + expected_count=0, + target_op_type="ConvTranspose", + expected_inputs_after=3, ) def test_remove_zero_bias_from_gemm(self): @@ -156,7 +172,11 @@ def test_remove_zero_bias_from_gemm(self): ) _apply_rule_and_check_optimization( - model, remove_zero_bias_from_gemm_rule, expected_count=1, target_op_type="Gemm", expected_inputs_after=2 + model, + remove_zero_bias_from_gemm_rule, + expected_count=1, + target_op_type="Gemm", + expected_inputs_after=2, ) def test_gemm_with_non_zero_bias_unchanged(self): @@ -175,7 +195,11 @@ def test_gemm_with_non_zero_bias_unchanged(self): ) _apply_rule_and_check_optimization( - model, remove_zero_bias_from_gemm_rule, expected_count=0, target_op_type="Gemm", expected_inputs_after=3 + model, + remove_zero_bias_from_gemm_rule, + expected_count=0, + target_op_type="Gemm", + expected_inputs_after=3, ) def test_remove_zero_bias_from_qlinear_conv(self): @@ -200,7 +224,11 @@ def test_remove_zero_bias_from_qlinear_conv(self): ) _apply_rule_and_check_optimization( - model, remove_zero_bias_from_qlinear_conv_rule, expected_count=1, target_op_type="QLinearConv", expected_inputs_after=8 + model, + remove_zero_bias_from_qlinear_conv_rule, + expected_count=1, + target_op_type="QLinearConv", + expected_inputs_after=8, ) def test_qlinear_conv_with_non_zero_bias_unchanged(self): @@ -225,7 +253,11 @@ def test_qlinear_conv_with_non_zero_bias_unchanged(self): ) _apply_rule_and_check_optimization( - model, remove_zero_bias_from_qlinear_conv_rule, expected_count=0, target_op_type="QLinearConv", expected_inputs_after=9 + model, + remove_zero_bias_from_qlinear_conv_rule, + expected_count=0, + target_op_type="QLinearConv", + expected_inputs_after=9, ) From e32262539a0012dcf7dd8ce5409b24ec20f3ad07 Mon Sep 17 00:00:00 2001 From: Vineet Kumar Date: Thu, 11 Sep 2025 23:25:56 +0530 Subject: [PATCH 05/13] Remove duplicate import of _fuse_batchnorm in rewriter module --- onnxscript/rewriter/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index 16be97adca..fe530159c7 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -36,7 +36,6 @@ _cast_constant_of_shape, _collapse_slices, _fuse_batchnorm, - _fuse_batchnorm, _fuse_pad_into_conv, _fuse_relus_clips, _min_max_to_clip, From 121360e967e7af7cb2fe3ef4149a42a79a548406 Mon Sep 17 00:00:00 2001 From: Vineet Kumar Date: Thu, 11 Sep 2025 23:26:08 +0530 Subject: [PATCH 06/13] Refactor zero bias removal logic to streamline input handling and enhance clarity in 
Conv, ConvTranspose, QLinearConv, and Gemm operations --- .../rules/common/_remove_zero_bias.py | 112 +++++++----------- 1 file changed, 43 insertions(+), 69 deletions(-) diff --git a/onnxscript/rewriter/rules/common/_remove_zero_bias.py b/onnxscript/rewriter/rules/common/_remove_zero_bias.py index 1e17bbb9a5..fcc7e66216 100644 --- a/onnxscript/rewriter/rules/common/_remove_zero_bias.py +++ b/onnxscript/rewriter/rules/common/_remove_zero_bias.py @@ -17,11 +17,18 @@ class _RemoveZeroBiasBase(RewriteRuleClassBase): """Base class for removing zero bias from operations.""" - def rewrite(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value: + def rewrite(self, op: ir.tape.Tape, out: ir.Value, **_) -> ir.Value: """Remove the bias input from the operation.""" + node = out.producer() + + original_inputs = list(node.inputs) + inputs_without_bias = original_inputs[:-1] + return op.op( self.op_type, - inputs=[x, w], # Remove bias input + inputs=inputs_without_bias, + attributes=node.attributes, + domain=node.domain, ) def _check_bias_is_zero(self, bias_value: ir.Value) -> MatchResult: @@ -52,20 +59,7 @@ class RemoveZeroBiasFromConv(_RemoveZeroBiasBase): op_type: ClassVar = "Conv" def pattern(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value: - return op.Conv(x, w, b, _outputs=["conv_out"]) - - def rewrite(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value, conv_out: ir.Value) -> ir.Value: - """Remove the bias input from the operation.""" - # Get the Conv node that produced conv_out to access its attributes - conv_node = conv_out.producer() - - # Create new Conv with preserved attributes but without bias - return op.op( - "Conv", - inputs=[x, w], # Remove bias input - attributes=conv_node.attributes, - domain=conv_node.domain, - ) + return op.Conv(x, w, b, _outputs=["out"]) class RemoveZeroBiasFromConvTranspose(_RemoveZeroBiasBase): @@ -74,20 +68,7 @@ class RemoveZeroBiasFromConvTranspose(_RemoveZeroBiasBase): op_type: ClassVar = "ConvTranspose" def pattern(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value: - return op.ConvTranspose(x, w, b, _outputs=["conv_out"]) - - def rewrite(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value, conv_out: ir.Value) -> ir.Value: - """Remove the bias input from the operation.""" - # Get the ConvTranspose node that produced conv_out to access its attributes - conv_node = conv_out.producer() - - # Create new ConvTranspose with preserved attributes but without bias - return op.op( - "ConvTranspose", - inputs=[x, w], # Remove bias input - attributes=conv_node.attributes, - domain=conv_node.domain, - ) + return op.ConvTranspose(x, w, b, _outputs=["out"]) class RemoveZeroBiasFromQLinearConv(_RemoveZeroBiasBase): @@ -95,26 +76,30 @@ class RemoveZeroBiasFromQLinearConv(_RemoveZeroBiasBase): op_type: ClassVar = "QLinearConv" - def pattern(self, op: ir.tape.Tape, x, x_scale, x_zero_point, w, w_scale, w_zero_point, - y_scale, y_zero_point, b: ir.Value) -> ir.Value: + def pattern( + self, + op: ir.tape.Tape, + x, + x_scale, + x_zero_point, + w, + w_scale, + w_zero_point, + y_scale, + y_zero_point, + b: ir.Value, + ) -> ir.Value: return op.QLinearConv( - x, x_scale, x_zero_point, w, w_scale, w_zero_point, - y_scale, y_zero_point, b, _outputs=["conv_out"] - ) - - def rewrite(self, op: ir.tape.Tape, x, x_scale, x_zero_point, w, w_scale, w_zero_point, - y_scale, y_zero_point, b: ir.Value, conv_out: ir.Value) -> ir.Value: - """Remove the bias input from the operation.""" - # Get the 
QLinearConv node that produced conv_out to access its attributes
-        conv_node = conv_out.producer()
-
-        # Create new QLinearConv with preserved attributes but without bias
-        return op.op(
-            "QLinearConv",
-            inputs=[x, x_scale, x_zero_point, w, w_scale, w_zero_point,
-                    y_scale, y_zero_point],  # Remove bias input
-            attributes=conv_node.attributes,
-            domain=conv_node.domain,
+            x,
+            x_scale,
+            x_zero_point,
+            w,
+            w_scale,
+            w_zero_point,
+            y_scale,
+            y_zero_point,
+            b,
+            _outputs=["out"],
         )
@@ -124,26 +109,13 @@ class RemoveZeroBiasFromGemm(_RemoveZeroBiasBase):
 
     op_type: ClassVar = "Gemm"
 
     def pattern(self, op: ir.tape.Tape, a: ir.Value, b: ir.Value, c: ir.Value) -> ir.Value:
-        return op.Gemm(a, b, c, _outputs=["gemm_out"])
+        return op.Gemm(a, b, c, _outputs=["out"])
 
     def check(self, context, a: ir.Value, b: ir.Value, c: ir.Value, **_) -> MatchResult:
         """Check if the bias (c parameter) is present and is all zeros."""
         del context  # Unused
         return self._check_bias_is_zero(c)
 
-    def rewrite(self, op: ir.tape.Tape, a: ir.Value, b: ir.Value, c: ir.Value, gemm_out: ir.Value) -> ir.Value:
-        """Remove the bias input from the operation."""
-        # Get the Gemm node that produced gemm_out to access its attributes
-        gemm_node = gemm_out.producer()
-
-        # Create new Gemm with preserved attributes but without bias
-        return op.op(
-            "Gemm",
-            inputs=[a, b],  # Remove bias input
-            attributes=gemm_node.attributes,
-            domain=gemm_node.domain,
-        )
-
 
 # Create rule instances
 remove_zero_bias_from_conv_rule = RemoveZeroBiasFromConv().rule()
@@ -151,9 +123,11 @@ def rewrite(self, op: ir.tape.Tape, a: ir.Value, b: ir.Value, c: ir.Value, gemm_
 remove_zero_bias_from_qlinear_conv_rule = RemoveZeroBiasFromQLinearConv().rule()
 remove_zero_bias_from_gemm_rule = RemoveZeroBiasFromGemm().rule()
 
-rules = RewriteRuleSet([
-    remove_zero_bias_from_conv_rule,
-    remove_zero_bias_from_conv_transpose_rule,
-    remove_zero_bias_from_qlinear_conv_rule,
-    remove_zero_bias_from_gemm_rule,
-])
+rules = RewriteRuleSet(
+    [
+        remove_zero_bias_from_conv_rule,
+        remove_zero_bias_from_conv_transpose_rule,
+        remove_zero_bias_from_qlinear_conv_rule,
+        remove_zero_bias_from_gemm_rule,
+    ]
+)

From ee7dafa627e78f4ae2ac29bdded8e5a5d23f8c49 Mon Sep 17 00:00:00 2001
From: Vineet Kumar
Date: Sun, 14 Sep 2025 16:17:05 +0530
Subject: [PATCH 07/13] Refactor Gemm operation pattern and check method to
 align with zero bias removal logic

---
 onnxscript/rewriter/rules/common/_remove_zero_bias.py | 9 ++-------
 1 file changed, 2 insertions(+), 7 deletions(-)

diff --git a/onnxscript/rewriter/rules/common/_remove_zero_bias.py b/onnxscript/rewriter/rules/common/_remove_zero_bias.py
index fcc7e66216..ccf5838776 100644
--- a/onnxscript/rewriter/rules/common/_remove_zero_bias.py
+++ b/onnxscript/rewriter/rules/common/_remove_zero_bias.py
@@ -108,13 +108,8 @@ class RemoveZeroBiasFromGemm(_RemoveZeroBiasBase):
 
     op_type: ClassVar = "Gemm"
 
-    def pattern(self, op: ir.tape.Tape, a: ir.Value, b: ir.Value, c: ir.Value) -> ir.Value:
-        return op.Gemm(a, b, c, _outputs=["out"])
-
-    def check(self, context, a: ir.Value, b: ir.Value, c: ir.Value, **_) -> MatchResult:
-        """Check if the bias (c parameter) is present and is all zeros."""
-        del context  # Unused
-        return self._check_bias_is_zero(c)
+    def pattern(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value:
+        return op.Gemm(x, w, b, _outputs=["out"])
 
 
 # Create rule instances

From 8ef6c41d308bdea02ec99700e54d546740791cec Mon Sep 17 00:00:00 2001
From: Vineet Kumar
Date: Sun, 14 Sep 2025 16:40:09 +0530
Subject: [PATCH 08/13] Enhance zero bias removal logic to filter bias
 parameters and preserve attributes in Conv, ConvTranspose, Gemm, and
 QLinearConv operations

---
 .../rules/common/_remove_zero_bias.py      |   9 +-
 .../rules/common/_remove_zero_bias_test.py | 166 ++++++++++++++++++
 2 files changed, 172 insertions(+), 3 deletions(-)

diff --git a/onnxscript/rewriter/rules/common/_remove_zero_bias.py b/onnxscript/rewriter/rules/common/_remove_zero_bias.py
index ccf5838776..5915b48e80 100644
--- a/onnxscript/rewriter/rules/common/_remove_zero_bias.py
+++ b/onnxscript/rewriter/rules/common/_remove_zero_bias.py
@@ -21,12 +21,15 @@ def rewrite(self, op: ir.tape.Tape, out: ir.Value, **_) -> ir.Value:
         """Remove the bias input from the operation."""
         node = out.producer()
-        original_inputs = list(node.inputs)
-        inputs_without_bias = original_inputs[:-1]
+        # Filter out the bias parameter and keep all other inputs
+        inputs = []
+        for param_name, param_value in _.items():
+            if param_name != "b":  # 'b' is the bias parameter
+                inputs.append(param_value)
 
         return op.op(
             self.op_type,
-            inputs=inputs_without_bias,
+            inputs=inputs,
             attributes=node.attributes,
             domain=node.domain,
         )

diff --git a/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py b/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py
index 38b28085e6..f67dbeb000 100644
--- a/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py
+++ b/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py
@@ -3,6 +3,7 @@
 """Tests for removing zero bias from Conv and related operations."""
 
 import unittest
+from typing import Optional
 
 import onnx_ir as ir
 
@@ -21,11 +22,19 @@ def _apply_rule_and_check_optimization(
     expected_count: int,
     target_op_type: str,
     expected_inputs_after: int,
+    expected_attributes: Optional[dict] = None,
 ) -> None:
     """Helper function to test bias removal rules."""
     # Make a copy for comparison
     original_model = ir.from_proto(ir.to_proto(model))
 
+    # Get original attributes for comparison
+    original_target_node = None
+    for node in original_model.graph:
+        if node.op_type == target_op_type:
+            original_target_node = node
+            break
+
     # Apply the rule
     count = rule.apply_to_model(model)
 
@@ -45,6 +54,29 @@ def _apply_rule_and_check_optimization(
         f"got {len(target_node.inputs)}"
     )
 
+    # Check that attributes are preserved if the rule was applied
+    if expected_count > 0 and original_target_node is not None:
+        # All original attributes should be preserved
+        for attr_name, attr_value in original_target_node.attributes.items():
+            assert attr_name in target_node.attributes, f"Attribute {attr_name} was lost"
+            original_value = attr_value.value
+            new_value = target_node.attributes[attr_name].value
+            assert new_value == original_value, (
+                f"Attribute {attr_name} value changed from {original_value} to {new_value}"
+            )
+
+        # Check specific expected attributes if provided
+        if expected_attributes:
+            for attr_name, expected_value in expected_attributes.items():
+                assert attr_name in target_node.attributes, (
+                    f"Expected attribute {attr_name} not found"
+                )
+                actual_attr = target_node.attributes[attr_name]
+                actual_value = actual_attr.value
+                assert actual_value == expected_value, (
+                    f"Expected attribute {attr_name} to be {expected_value}, got {actual_value}"
+                )
+
     # Compare outputs to ensure correctness (only for supported input types)
     if expected_count > 0:
         try:
@@ -231,6 +263,140 @@ def test_remove_zero_bias_from_qlinear_conv(self):
             expected_inputs_after=8,
         )
 
+    def test_remove_zero_bias_from_conv_with_attributes(self):
+        """Test that zero bias is removed from Conv operations and attributes are preserved."""
+        # Create a Conv with zero bias and various attributes using ONNX text format
+        model = ir.from_onnx_text(
+            """
+
+            agraph (float[1, 2, 6, 6] x) => (float[1, 2, 2, 2] y)
+            {
+                weight = Constant ()
+                bias = Constant ()
+                y = Conv <dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [0, 0, 0, 0], strides = [2, 2]> (x, weight, bias)
+            }
+            """
+        )
+
+        expected_attributes = {
+            "dilations": [1, 1],
+            "group": 1,
+            "kernel_shape": [3, 3],
+            "pads": [0, 0, 0, 0],
+            "strides": [2, 2],
+        }
+
+        _apply_rule_and_check_optimization(
+            model,
+            remove_zero_bias_from_conv_rule,
+            expected_count=1,
+            target_op_type="Conv",
+            expected_inputs_after=2,
+            expected_attributes=expected_attributes,
+        )
+
+    def test_remove_zero_bias_from_conv_transpose_with_attributes(self):
+        """Test that zero bias is removed from ConvTranspose operations and attributes are preserved."""
+        # Create a ConvTranspose with zero bias and various attributes using ONNX text format
+        model = ir.from_onnx_text(
+            """
+
+            agraph (float[1, 2, 2, 2] x) => (float[1, 2, 6, 6] y)
+            {
+                weight = Constant ()
+                bias = Constant ()
+                y = ConvTranspose <dilations = [1, 1], group = 1, kernel_shape = [3, 3], output_padding = [0, 0], pads = [0, 0, 0, 0], strides = [2, 2]> (x, weight, bias)
+            }
+            """
+        )
+
+        expected_attributes = {
+            "dilations": [1, 1],
+            "group": 1,
+            "kernel_shape": [3, 3],
+            "output_padding": [0, 0],
+            "pads": [0, 0, 0, 0],
+            "strides": [2, 2],
+        }
+
+        _apply_rule_and_check_optimization(
+            model,
+            remove_zero_bias_from_conv_transpose_rule,
+            expected_count=1,
+            target_op_type="ConvTranspose",
+            expected_inputs_after=2,
+            expected_attributes=expected_attributes,
+        )
+
+    def test_remove_zero_bias_from_gemm_with_attributes(self):
+        """Test that zero bias is removed from Gemm operations and attributes are preserved."""
+        # Create a Gemm with zero bias and various attributes using ONNX text format
+        model = ir.from_onnx_text(
+            """
+
+            agraph (float[2, 3] a) => (float[2, 4] y)
+            {
+                b = Constant ()
+                c = Constant ()
+                y = Gemm <alpha = 2.0, beta = 1.0, transA = 0, transB = 1> (a, b, c)
+            }
+            """
+        )
+
+        expected_attributes = {
+            "alpha": 2.0,
+            "beta": 1.0,
+            "transA": 0,
+            "transB": 1,
+        }
+
+        _apply_rule_and_check_optimization(
+            model,
+            remove_zero_bias_from_gemm_rule,
+            expected_count=1,
+            target_op_type="Gemm",
+            expected_inputs_after=2,
+            expected_attributes=expected_attributes,
+        )
+
+    def test_remove_zero_bias_from_qlinear_conv_with_attributes(self):
+        """Test that zero bias is removed from QLinearConv operations and attributes are preserved."""
+        # Create a QLinearConv with zero bias and various attributes using ONNX text format
+        model = ir.from_onnx_text(
+            """
+
+            agraph (uint8[1, 2, 6, 6] x) => (uint8[1, 2, 2, 2] y)
+            {
+                x_scale = Constant ()
+                x_zero_point = Constant ()
+                weight = Constant ()
+                w_scale = Constant ()
+                w_zero_point = Constant ()
+                y_scale = Constant ()
+                y_zero_point = Constant ()
+                bias = Constant ()
+                y = QLinearConv <dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [0, 0, 0, 0], strides = [2, 2]> (x, x_scale, x_zero_point, weight, w_scale, w_zero_point, y_scale, y_zero_point, bias)
+            }
+            """
+        )
+
+        expected_attributes = {
+            "dilations": [1, 1],
+            "group": 1,
+            "kernel_shape": [3, 3],
+            "pads": [0, 0, 0, 0],
+            "strides": [2, 2],
+        }
+
+        _apply_rule_and_check_optimization(
+            model,
+            remove_zero_bias_from_qlinear_conv_rule,
+            expected_count=1,
+            target_op_type="QLinearConv",
+            expected_inputs_after=8,
+            expected_attributes=expected_attributes,
+        )
+
     def test_qlinear_conv_with_non_zero_bias_unchanged(self):
         """Test that QLinearConv with non-zero bias is not modified."""
         # Create a QLinearConv with non-zero bias using ONNX text format

From 2b9dda49e8b6fc9c82142f3a1b6c54cff0c625fc Mon Sep 17 00:00:00 2001
From: Vineet Kumar
Date: Sun, 14 Sep 2025 21:47:11 +0530
Subject: [PATCH 09/13] Refactor bias removal logic to use operation inputs
 directly, improving efficiency in _RemoveZeroBiasBase class

---
 onnxscript/rewriter/rules/common/_remove_zero_bias.py | 8 +-------
 1 file changed, 1 insertion(+), 7 deletions(-)

diff --git a/onnxscript/rewriter/rules/common/_remove_zero_bias.py b/onnxscript/rewriter/rules/common/_remove_zero_bias.py
index 5915b48e80..26d3ecab52 100644
--- a/onnxscript/rewriter/rules/common/_remove_zero_bias.py
+++ b/onnxscript/rewriter/rules/common/_remove_zero_bias.py
@@ -21,15 +21,9 @@ def rewrite(self, op: ir.tape.Tape, out: ir.Value, **_) -> ir.Value:
         """Remove the bias input from the operation."""
         node = out.producer()
 
-        # Filter out the bias parameter and keep all other inputs
-        inputs = []
-        for param_name, param_value in _.items():
-            if param_name != "b":  # 'b' is the bias parameter
-                inputs.append(param_value)
-
         return op.op(
             self.op_type,
-            inputs=inputs,
+            inputs=node.inputs[:-1],
             attributes=node.attributes,
             domain=node.domain,
         )

From 153b4e718e70913fffa9e7fc00953e4bd8b12103 Mon Sep 17 00:00:00 2001
From: Vineet Kumar
Date: Mon, 15 Sep 2025 22:59:56 +0530
Subject: [PATCH 10/13] Remove redundant domain argument from the rewritten
 op call in _RemoveZeroBiasBase class

---
 onnxscript/rewriter/rules/common/_remove_zero_bias.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/onnxscript/rewriter/rules/common/_remove_zero_bias.py b/onnxscript/rewriter/rules/common/_remove_zero_bias.py
index 26d3ecab52..901e643ac1 100644
--- a/onnxscript/rewriter/rules/common/_remove_zero_bias.py
+++ b/onnxscript/rewriter/rules/common/_remove_zero_bias.py
@@ -25,7 +25,6 @@ def rewrite(self, op: ir.tape.Tape, out: ir.Value, **_) -> ir.Value:
             self.op_type,
             inputs=node.inputs[:-1],
             attributes=node.attributes,
-            domain=node.domain,
         )
 
     def _check_bias_is_zero(self, bias_value: ir.Value) -> MatchResult:

From 86de85f324d9eca074fea9d9704173ddb7d1a296 Mon Sep 17 00:00:00 2001
From: Vineet Kumar
Date: Sun, 21 Sep 2025 14:41:22 +0530
Subject: [PATCH 11/13] Refactor IR value creation in tests to use `ir.Value`
 for consistency and improved clarity

---
 .../bfloat16_utils/bfloat16_converter_test.py | 17 ++++++-----
 .../rules/common/_basic_rules_test.py         | 18 ++++++++----
 .../rules/common/_fuse_pad_into_conv_test.py  |  8 +++---
 .../rules/common/_matmul_add_to_gemm_test.py  | 14 ++++++----
 .../rules/common/_remove_zero_bias_test.py    | 28 +++++++++++++++----
 5 files changed, 56 insertions(+), 29 deletions(-)

diff --git a/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py b/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py
index a64d6e6023..c401f94a15 100644
--- a/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py
+++ b/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py
@@ -6,20 +6,23 @@
 import onnx
 import onnx.checker
 import onnx.shape_inference
+import onnx_ir as ir
 import onnxruntime
 
-from onnxscript import ir
 from onnxscript.rewriter.onnxruntime.bfloat16_utils import bfloat16_converter
 
 
 class Bfloat16ConversionTest(unittest.TestCase):
     def setUp(self) -> None:
-        self.v0 = ir.val(name="v0", shape=ir.Shape([2, 3, 4]))
-        self.v0.dtype = ir.DataType.BFLOAT16
-        self.v1 = ir.val(name="v1", shape=ir.Shape([2, 3, 4]))
-        self.v1.dtype = ir.DataType.BFLOAT16
-        self.v2 = ir.val(name="v2", shape=ir.Shape([2, 3, 4]))
-        self.v2.dtype = ir.DataType.BFLOAT16
+        self.v0 = ir.Value(
+            name="v0", shape=ir.Shape([2, 3, 4]), type=ir.TensorType(ir.DataType.BFLOAT16)
+        )
+        self.v1 = ir.Value(
+            name="v1", shape=ir.Shape([2, 3, 4]), type=ir.TensorType(ir.DataType.BFLOAT16)
+        )
+        self.v2 = ir.Value(
+            name="v2", shape=ir.Shape([2, 3, 4]), type=ir.TensorType(ir.DataType.BFLOAT16)
+        )
         self.add_node = ir.Node("", "Add", inputs=(self.v0, self.v1), num_outputs=1)
         self.add_node.outputs[0].dtype = ir.DataType.BFLOAT16

diff --git a/onnxscript/rewriter/rules/common/_basic_rules_test.py b/onnxscript/rewriter/rules/common/_basic_rules_test.py
index 7d4e9d9b33..44db0a55ff 100644
--- a/onnxscript/rewriter/rules/common/_basic_rules_test.py
+++ b/onnxscript/rewriter/rules/common/_basic_rules_test.py
@@ -8,11 +8,11 @@
 import numpy as np
 import onnx
 import onnx.reference
+import onnx_ir as ir
 import parameterized
 
 import onnxscript
 import onnxscript.onnx_types as ot
-from onnxscript import ir
 from onnxscript.onnx_opset import opset18
 from onnxscript.rewriter import MatchingTracer, testing
 from onnxscript.rewriter import pattern as orp
@@ -421,14 +421,18 @@ def _convert_shape(shape, name):
         if isinstance(shape, np.ndarray):
             shape = tape.initializer(ir.Tensor(shape, name=name))
         elif isinstance(shape, (list, tuple)):
-            shape = ir.val(name, ir.DataType.INT64, ir.Shape(shape))
+            shape = ir.Value(
+                name=name, type=ir.TensorType(ir.DataType.INT64), shape=ir.Shape(shape)
+            )
             tape.graph_like.inputs.append(shape)
         else:
             raise TypeError(f"Unsupported type {type(shape)} for shape.")
         return shape
 
-    x = ir.val("X", ir.DataType.FLOAT, ir.Shape(input_shape))
-    y = ir.val("Y", ir.DataType.FLOAT)
+    x = ir.Value(
+        name="X", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape(input_shape)
+    )
+    y = ir.Value(name="Y", type=ir.TensorType(ir.DataType.FLOAT))
     tape = ir.tape.Tape(ir.Graph([x], [y], nodes=[], opset_imports={"": 20}))
 
     # Build the graph.
@@ -554,8 +558,10 @@ def test_unsupported_reshape_reshape(self, shape2, error_msg):
 class Flatten2ReshapeTest(unittest.TestCase):
     @staticmethod
     def create_model(input_shape, axis=1):
-        x = ir.val("X", ir.DataType.FLOAT, ir.Shape(input_shape))
-        y = ir.val("Y", ir.DataType.FLOAT)
+        x = ir.Value(
+            name="X", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape(input_shape)
+        )
+        y = ir.Value(name="Y", type=ir.TensorType(ir.DataType.FLOAT))
         tape = ir.tape.Tape(ir.Graph([x], [y], nodes=[], opset_imports={"": 20}))
 
         # Build the graph.
diff --git a/onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py b/onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py
index ded57fe023..22b66b3c9a 100644
--- a/onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py
+++ b/onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py
@@ -61,13 +61,13 @@ def build_model(
 
         # Register operations in the tape
         idtype = ir.DataType.UINT8 if op_type == "ConvInteger" else ir.DataType.FLOAT
-        x = ir.val("X", shape=input_shape, type=ir.TensorType(idtype))
+        x = ir.Value(name="X", shape=input_shape, type=ir.TensorType(idtype))
         y = tape.op("Pad", inputs=[x, *pad_inputs], attributes=pad_attributes)
         y = tape.op(
             op_type,
             inputs=[y, self.get_conv_weights(weight_shape, tape)],
             attributes=conv_attributes,
-            output=ir.val("Y", shape=output_shape, type=ir.TensorType(x.dtype)),
+            output=ir.Value(name="Y", shape=output_shape, type=ir.TensorType(x.dtype)),
         )
         if op_type == "ConvInteger":
             y.dtype = ir.DataType.INT32
@@ -290,12 +290,12 @@ def build_model(
             raise ValueError(f"Unsupported type for pad input ({x}): {type(x)}.")
 
         # Register operations in the tape
-        x = ir.val("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))
+        x = ir.Value(name="X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))
         y = tape.op(
             "Conv",
             inputs=[x, *conv_inputs],
             attributes=conv_attributes,
-            output=ir.val("Y", shape=output_shape, type=x.type),
+            output=ir.Value(name="Y", shape=output_shape, type=x.type),
         )
 
         # Build the model

diff --git a/onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py b/onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py
index 4c643801fc..edfb0f57ac 100644
--- a/onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py
+++ b/onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py
@@ -5,10 +5,10 @@
 
 import numpy as np
 import onnx
+import onnx_ir as ir
 from onnx_ir.passes.common import onnx_checker, shape_inference
 from parameterized import parameterized
 
-from onnxscript import ir
 from onnxscript.rewriter import MatchingTracer, MatchStatus, testing
 from onnxscript.rewriter.rules.common import _matmul_add_to_gemm
 
@@ -46,10 +46,10 @@ def get_test_model(
         bias_shape = weight_shape[0] if transB else weight_shape[-1]
         output_shape = ir.Shape(("?",) * input_shape.rank())
 
-        x = ir.val("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))
+        x = ir.Value(name="X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))
 
         if weight_as_inputs:
-            w = ir.val("W", shape=weight_shape, type=ir.TensorType(ir.DataType.FLOAT))
+            w = ir.Value(name="W", shape=weight_shape, type=ir.TensorType(ir.DataType.FLOAT))
             inputs.append(w)
         else:
             w = ir.tensor(
@@ -58,8 +58,8 @@ def get_test_model(
             w = tape.initializer(w)
 
         if bias_as_inputs:
-            b = ir.val(
-                "B", shape=ir.Shape([bias_shape]), type=ir.TensorType(ir.DataType.FLOAT)
+            b = ir.Value(
+                name="B", shape=ir.Shape([bias_shape]), type=ir.TensorType(ir.DataType.FLOAT)
             )
             inputs.append(b)
         else:
@@ -77,7 +77,9 @@ def get_test_model(
         y = tape.op(
             "Add",
             inputs=[y, b],
-            output=ir.val("Y", shape=output_shape, type=ir.TensorType(ir.DataType.FLOAT)),
+            output=ir.Value(
+                name="Y", shape=output_shape, type=ir.TensorType(ir.DataType.FLOAT)
+            ),
         )
 
         # Build the model

diff --git a/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py b/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py
index f67dbeb000..71b279f4f4 100644
--- a/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py
+++ b/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py
@@ -61,9 +61,17 @@ def _apply_rule_and_check_optimization(
             assert attr_name in target_node.attributes, f"Attribute {attr_name} was lost"
             original_value = attr_value.value
             new_value = target_node.attributes[attr_name].value
-            assert new_value == original_value, (
-                f"Attribute {attr_name} value changed from {original_value} to {new_value}"
-            )
+            # Convert both to same type for comparison to handle list vs tuple differences
+            if isinstance(original_value, (list, tuple)) and isinstance(
+                new_value, (list, tuple)
+            ):
+                assert list(original_value) == list(new_value), (
+                    f"Attribute {attr_name} value changed from {original_value} to {new_value}"
+                )
+            else:
+                assert new_value == original_value, (
+                    f"Attribute {attr_name} value changed from {original_value} to {new_value}"
+                )
 
         # Check specific expected attributes if provided
         if expected_attributes:
@@ -73,9 +81,17 @@ def _apply_rule_and_check_optimization(
                 )
                 actual_attr = target_node.attributes[attr_name]
                 actual_value = actual_attr.value
-                assert actual_value == expected_value, (
-                    f"Expected attribute {attr_name} to be {expected_value}, got {actual_value}"
-                )
+                # Convert both to same type for comparison to handle list vs tuple differences
+                if isinstance(actual_value, (list, tuple)) and isinstance(
+                    expected_value, (list, tuple)
+                ):
+                    assert list(actual_value) == list(expected_value), (
+                        f"Expected attribute {attr_name} to be {expected_value}, got {actual_value}"
+                    )
+                else:
+                    assert actual_value == expected_value, (
+                        f"Expected attribute {attr_name} to be {expected_value}, got {actual_value}"
+                    )
 
     # Compare outputs to ensure correctness (only for supported input types)
     if expected_count > 0:

From d62eafb08863e18bb9e775f87280cae4ca706d23 Mon Sep 17 00:00:00 2001
From: Vineet Kumar
Date: Sun, 21 Sep 2025 23:43:51 +0530
Subject: [PATCH 12/13] Revert "Refactor IR value creation in tests to use
 `ir.Value` for consistency and improved clarity"

This reverts commit 86de85f324d9eca074fea9d9704173ddb7d1a296.

---
 .../bfloat16_utils/bfloat16_converter_test.py | 17 +++++------
 .../rules/common/_basic_rules_test.py         | 18 ++++--------
 .../rules/common/_fuse_pad_into_conv_test.py  |  8 +++---
 .../rules/common/_matmul_add_to_gemm_test.py  | 14 ++++------
 .../rules/common/_remove_zero_bias_test.py    | 28 ++++---------------
 5 files changed, 29 insertions(+), 56 deletions(-)

diff --git a/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py b/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py
index c401f94a15..a64d6e6023 100644
--- a/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py
+++ b/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py
@@ -6,23 +6,20 @@
 import onnx
 import onnx.checker
 import onnx.shape_inference
-import onnx_ir as ir
 import onnxruntime
 
+from onnxscript import ir
 from onnxscript.rewriter.onnxruntime.bfloat16_utils import bfloat16_converter
 
 
 class Bfloat16ConversionTest(unittest.TestCase):
     def setUp(self) -> None:
-        self.v0 = ir.Value(
-            name="v0", shape=ir.Shape([2, 3, 4]), type=ir.TensorType(ir.DataType.BFLOAT16)
-        )
-        self.v1 = ir.Value(
-            name="v1", shape=ir.Shape([2, 3, 4]), type=ir.TensorType(ir.DataType.BFLOAT16)
-        )
-        self.v2 = ir.Value(
-            name="v2", shape=ir.Shape([2, 3, 4]), type=ir.TensorType(ir.DataType.BFLOAT16)
-        )
+        self.v0 = ir.val(name="v0", shape=ir.Shape([2, 3, 4]))
+        self.v0.dtype = ir.DataType.BFLOAT16
+        self.v1 = ir.val(name="v1", shape=ir.Shape([2, 3, 4]))
+        self.v1.dtype = ir.DataType.BFLOAT16
+        self.v2 = ir.val(name="v2", shape=ir.Shape([2, 3, 4]))
+        self.v2.dtype = ir.DataType.BFLOAT16
         self.add_node = ir.Node("", "Add", inputs=(self.v0, self.v1), num_outputs=1)
         self.add_node.outputs[0].dtype = ir.DataType.BFLOAT16

diff --git a/onnxscript/rewriter/rules/common/_basic_rules_test.py b/onnxscript/rewriter/rules/common/_basic_rules_test.py
index 44db0a55ff..7d4e9d9b33 100644
--- a/onnxscript/rewriter/rules/common/_basic_rules_test.py
+++ b/onnxscript/rewriter/rules/common/_basic_rules_test.py
@@ -8,11 +8,11 @@
 import numpy as np
 import onnx
 import onnx.reference
-import onnx_ir as ir
 import parameterized
 
 import onnxscript
 import onnxscript.onnx_types as ot
+from onnxscript import ir
 from onnxscript.onnx_opset import opset18
 from onnxscript.rewriter import MatchingTracer, testing
 from onnxscript.rewriter import pattern as orp
@@ -421,18 +421,14 @@ def _convert_shape(shape, name):
         if isinstance(shape, np.ndarray):
             shape = tape.initializer(ir.Tensor(shape, name=name))
         elif isinstance(shape, (list, tuple)):
-            shape = ir.Value(
-                name=name, type=ir.TensorType(ir.DataType.INT64), shape=ir.Shape(shape)
-            )
+            shape = ir.val(name, ir.DataType.INT64, ir.Shape(shape))
             tape.graph_like.inputs.append(shape)
         else:
             raise TypeError(f"Unsupported type {type(shape)} for shape.")
         return shape
 
-    x = ir.Value(
-        name="X", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape(input_shape)
-    )
-    y = ir.Value(name="Y", type=ir.TensorType(ir.DataType.FLOAT))
+    x = ir.val("X", ir.DataType.FLOAT, ir.Shape(input_shape))
+    y = ir.val("Y", ir.DataType.FLOAT)
     tape = ir.tape.Tape(ir.Graph([x], [y], nodes=[], opset_imports={"": 20}))
 
     # Build the graph.
@@ -558,10 +554,8 @@ def test_unsupported_reshape_reshape(self, shape2, error_msg):
 class Flatten2ReshapeTest(unittest.TestCase):
     @staticmethod
     def create_model(input_shape, axis=1):
-        x = ir.Value(
-            name="X", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape(input_shape)
-        )
-        y = ir.Value(name="Y", type=ir.TensorType(ir.DataType.FLOAT))
+        x = ir.val("X", ir.DataType.FLOAT, ir.Shape(input_shape))
+        y = ir.val("Y", ir.DataType.FLOAT)
         tape = ir.tape.Tape(ir.Graph([x], [y], nodes=[], opset_imports={"": 20}))
 
         # Build the graph.

diff --git a/onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py b/onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py
index 22b66b3c9a..ded57fe023 100644
--- a/onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py
+++ b/onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py
@@ -61,13 +61,13 @@ def build_model(
 
         # Register operations in the tape
         idtype = ir.DataType.UINT8 if op_type == "ConvInteger" else ir.DataType.FLOAT
-        x = ir.Value(name="X", shape=input_shape, type=ir.TensorType(idtype))
+        x = ir.val("X", shape=input_shape, type=ir.TensorType(idtype))
         y = tape.op("Pad", inputs=[x, *pad_inputs], attributes=pad_attributes)
         y = tape.op(
             op_type,
             inputs=[y, self.get_conv_weights(weight_shape, tape)],
             attributes=conv_attributes,
-            output=ir.Value(name="Y", shape=output_shape, type=ir.TensorType(x.dtype)),
+            output=ir.val("Y", shape=output_shape, type=ir.TensorType(x.dtype)),
         )
         if op_type == "ConvInteger":
             y.dtype = ir.DataType.INT32
@@ -290,12 +290,12 @@ def build_model(
             raise ValueError(f"Unsupported type for pad input ({x}): {type(x)}.")
 
         # Register operations in the tape
-        x = ir.Value(name="X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))
+        x = ir.val("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))
         y = tape.op(
             "Conv",
             inputs=[x, *conv_inputs],
             attributes=conv_attributes,
-            output=ir.Value(name="Y", shape=output_shape, type=x.type),
+            output=ir.val("Y", shape=output_shape, type=x.type),
         )
 
         # Build the model

diff --git a/onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py b/onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py
index edfb0f57ac..4c643801fc 100644
--- a/onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py
+++ b/onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py
@@ -5,10 +5,10 @@
 
 import numpy as np
 import onnx
-import onnx_ir as ir
 from onnx_ir.passes.common import onnx_checker, shape_inference
 from parameterized import parameterized
 
+from onnxscript import ir
 from onnxscript.rewriter import MatchingTracer, MatchStatus, testing
 from onnxscript.rewriter.rules.common import _matmul_add_to_gemm
 
@@ -46,10 +46,10 @@ def get_test_model(
         bias_shape = weight_shape[0] if transB else weight_shape[-1]
         output_shape = ir.Shape(("?",) * input_shape.rank())
 
-        x = ir.Value(name="X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))
+        x = ir.val("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))
 
         if weight_as_inputs:
-            w = ir.Value(name="W", shape=weight_shape, type=ir.TensorType(ir.DataType.FLOAT))
+            w = ir.val("W", shape=weight_shape, type=ir.TensorType(ir.DataType.FLOAT))
             inputs.append(w)
         else:
             w = ir.tensor(
@@ -58,8 +58,8 @@ def get_test_model(
             w = tape.initializer(w)
 
         if bias_as_inputs:
-            b = ir.Value(
-                name="B", shape=ir.Shape([bias_shape]), type=ir.TensorType(ir.DataType.FLOAT)
+            b = ir.val(
+                "B", shape=ir.Shape([bias_shape]), type=ir.TensorType(ir.DataType.FLOAT)
             )
             inputs.append(b)
         else:
@@ -77,9 +77,7 @@ def get_test_model(
         y = tape.op(
             "Add",
             inputs=[y, b],
-            output=ir.Value(
-                name="Y", shape=output_shape, type=ir.TensorType(ir.DataType.FLOAT)
-            ),
+            output=ir.val("Y", shape=output_shape, type=ir.TensorType(ir.DataType.FLOAT)),
         )
 
         # Build the model

diff --git a/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py b/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py
index 71b279f4f4..f67dbeb000 100644
--- a/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py
+++ b/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py
@@ -61,17 +61,9 @@ def _apply_rule_and_check_optimization(
             assert attr_name in target_node.attributes, f"Attribute {attr_name} was lost"
             original_value = attr_value.value
             new_value = target_node.attributes[attr_name].value
-            # Convert both to same type for comparison to handle list vs tuple differences
-            if isinstance(original_value, (list, tuple)) and isinstance(
-                new_value, (list, tuple)
-            ):
-                assert list(original_value) == list(new_value), (
-                    f"Attribute {attr_name} value changed from {original_value} to {new_value}"
-                )
-            else:
-                assert new_value == original_value, (
-                    f"Attribute {attr_name} value changed from {original_value} to {new_value}"
-                )
+            assert new_value == original_value, (
+                f"Attribute {attr_name} value changed from {original_value} to {new_value}"
+            )
 
         # Check specific expected attributes if provided
         if expected_attributes:
@@ -81,17 +73,9 @@ def _apply_rule_and_check_optimization(
                 )
                 actual_attr = target_node.attributes[attr_name]
                 actual_value = actual_attr.value
-                # Convert both to same type for comparison to handle list vs tuple differences
-                if isinstance(actual_value, (list, tuple)) and isinstance(
-                    expected_value, (list, tuple)
-                ):
-                    assert list(actual_value) == list(expected_value), (
-                        f"Expected attribute {attr_name} to be {expected_value}, got {actual_value}"
-                    )
-                else:
-                    assert actual_value == expected_value, (
-                        f"Expected attribute {attr_name} to be {expected_value}, got {actual_value}"
-                    )
+                assert actual_value == expected_value, (
+                    f"Expected attribute {attr_name} to be {expected_value}, got {actual_value}"
+                )
 
     # Compare outputs to ensure correctness (only for supported input types)
     if expected_count > 0:

From d4f73dd2b28b42bd73fc53e39c75e86b0cceb415 Mon Sep 17 00:00:00 2001
From: Vineet Kumar
Date: Sun, 21 Sep 2025 23:44:39 +0530
Subject: [PATCH 13/13] Enhance attribute comparison in optimization tests to
 handle list vs tuple differences

---
 .../rules/common/_remove_zero_bias_test.py | 28 +++++++++++++++----
 1 file changed, 22 insertions(+), 6 deletions(-)

diff --git a/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py b/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py
index f67dbeb000..71b279f4f4 100644
--- a/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py
+++ b/onnxscript/rewriter/rules/common/_remove_zero_bias_test.py
@@ -61,9 +61,17 @@ def _apply_rule_and_check_optimization(
             assert attr_name in target_node.attributes, f"Attribute {attr_name} was lost"
             original_value = attr_value.value
             new_value = target_node.attributes[attr_name].value
-            assert new_value == original_value, (
-                f"Attribute {attr_name} value changed from {original_value} to {new_value}"
-            )
+            # Convert both to same type for comparison to handle list vs tuple differences
+            if isinstance(original_value, (list, tuple)) and isinstance(
+                new_value, (list, tuple)
+            ):
+                assert list(original_value) == list(new_value), (
+                    f"Attribute {attr_name} value changed from {original_value} to {new_value}"
+                )
+            else:
+                assert new_value == original_value, (
+                    f"Attribute {attr_name} value changed from {original_value} to {new_value}"
+                )
 
         # Check specific expected attributes if provided
         if expected_attributes:
@@ -73,9 +81,17 @@ def _apply_rule_and_check_optimization(
                 )
                 actual_attr = target_node.attributes[attr_name]
                 actual_value = actual_attr.value
-                assert actual_value == expected_value, (
-                    f"Expected attribute {attr_name} to be {expected_value}, got {actual_value}"
-                )
+                # Convert both to same type for comparison to handle list vs tuple differences
+                if isinstance(actual_value, (list, tuple)) and isinstance(
+                    expected_value, (list, tuple)
+                ):
+                    assert list(actual_value) == list(expected_value), (
+                        f"Expected attribute {attr_name} to be {expected_value}, got {actual_value}"
+                    )
+                else:
+                    assert actual_value == expected_value, (
+                        f"Expected attribute {attr_name} to be {expected_value}, got {actual_value}"
+                    )
 
     # Compare outputs to ensure correctness (only for supported input types)
     if expected_count > 0: