Skip to content

Commit 385d2e6

Browse files
authored
[Rewriter(matmul_add_to_gemm)]: check shapes (#2528)
As we need to check the rank of input shapes, we need to ensure that input shapes are not None before checking their rank. Used `_ir_utils.has_rank` to handle that.
1 parent 93e428d commit 385d2e6

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

onnxscript/rewriter/matmul_add_to_gemm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import abc
1111
from typing import ClassVar
1212

13+
from onnxscript.rewriter import _ir_utils
1314
from onnxscript.rewriter._basics import MatchResult
1415
from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet
1516

@@ -30,7 +31,7 @@ def check(self, context, input_a, input_b, **_):
3031
del context # Not used
3132
check_result = MatchResult()
3233
# Rank of input_a and input_b must be 2
33-
if len(input_a.shape) != 2 or len(input_b.shape) != 2:
34+
if not (_ir_utils.has_rank(input_a, 2) and _ir_utils.has_rank(input_b, 2)):
3435
return check_result.fail("Rank of input_a and input_b must be 2")
3536
return check_result
3637

0 commit comments

Comments
 (0)