Skip to content

Commit 8089bc7

Browse files
authored
Add RMS Normalization rule variant (#2638)
Add RMS Normalization rule variant to support different order of multiplying by scale. Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent 75b3d42 commit 8089bc7

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

onnxscript/rewriter/ort_fusions/rms_normalization.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@
3131

3232

3333
class RmsNormFusion(pattern.RewriteRuleClassBase):
34+
def __init__(self, name: str, _mul_order: bool):
35+
super().__init__(name)
36+
self._mul_order = _mul_order
37+
3438
def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype):
3539
x = pattern.OrValue([op.Cast(x, to=compute_dtype), x])
3640
x_square = op.Pow(x, 2.0)
@@ -42,7 +46,11 @@ def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype):
4246
normalized = pattern.OrValue([op.Cast(normalized, to=target_dtype), normalized])
4347
# To support float16, we need to ensure the scale is casted or not.
4448
scale = pattern.OrValue([op.Cast(scale, to=compute_dtype), scale])
45-
return op.Mul(scale, normalized)
49+
# Workaround: can't use OrValue for final (returned) value
50+
if self._mul_order:
51+
return op.Mul(normalized, scale)
52+
else:
53+
return op.Mul(scale, normalized)
4654

4755
def check(
4856
self, op, x, scale, epsilon, compute_dtype, target_dtype, **_
@@ -77,8 +85,10 @@ def rewrite(self, op, x, scale, epsilon, **_):
7785
)
7886

7987

80-
_rule = RmsNormFusion.rule()
81-
rms_normalization_rules = [_rule]
88+
_rule1 = RmsNormFusion.rule("RmsNormFusion1", _mul_order=False)
89+
_rule2 = RmsNormFusion.rule("RmsNormFusion2", _mul_order=True)
90+
91+
rms_normalization_rules = [_rule1, _rule2]
8292
rms_normalization_ruleset = pattern.RewriteRuleSet(rms_normalization_rules)
8393

8494

0 commit comments

Comments
 (0)