3131
3232
3333class RmsNormFusion (pattern .RewriteRuleClassBase ):
34+ def __init__ (self , name : str , _mul_order : bool ):
35+ super ().__init__ (name )
36+ self ._mul_order = _mul_order
37+
3438 def pattern (self , op , x , scale , epsilon , compute_dtype , target_dtype ):
3539 x = pattern .OrValue ([op .Cast (x , to = compute_dtype ), x ])
3640 x_square = op .Pow (x , 2.0 )
@@ -42,7 +46,11 @@ def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype):
4246 normalized = pattern .OrValue ([op .Cast (normalized , to = target_dtype ), normalized ])
4448 # To support float16, match the scale both with and without a Cast to compute_dtype.
4448 scale = pattern .OrValue ([op .Cast (scale , to = compute_dtype ), scale ])
45- return op .Mul (scale , normalized )
49+ # Workaround: OrValue cannot be used for the final (returned) value, so the Mul operand order is selected by self._mul_order (one order per rule instance).
50+ if self ._mul_order :
51+ return op .Mul (normalized , scale )
52+ else :
53+ return op .Mul (scale , normalized )
4654
4755 def check (
4856 self , op , x , scale , epsilon , compute_dtype , target_dtype , ** _
@@ -77,8 +85,10 @@ def rewrite(self, op, x, scale, epsilon, **_):
7785 )
7886
7987
80- _rule = RmsNormFusion .rule ()
81- rms_normalization_rules = [_rule ]
88+ _rule1 = RmsNormFusion .rule ("RmsNormFusion1" , _mul_order = False )
89+ _rule2 = RmsNormFusion .rule ("RmsNormFusion2" , _mul_order = True )
90+
91+ rms_normalization_rules = [_rule1 , _rule2 ]
8292rms_normalization_ruleset = pattern .RewriteRuleSet (rms_normalization_rules )
8393
8494
0 commit comments