Commit 700bb1a

[ort_fusion] Support fp16 in rms_norm fusion (#2491)
RMSNorm involves a compute_type and a target_type: the computation runs in compute_type, and the result is cast back to target_type afterward. A typical example is the RMSNorm class found in LLMs, e.g. in GPT-OSS: https://github.com/huggingface/transformers/blob/52c6c1bb6e27ca87c4faede34a4c2a7404c17c4d/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L54 Therefore, the fusion pattern must take op.Cast into account.
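The compute/target dtype split described above can be sketched with a minimal numpy example (an illustration of the idea, not the onnxscript or transformers code; the `rms_norm` function and its signature are assumptions):

```python
import numpy as np

def rms_norm(x, scale, eps=1e-6):
    """RMSNorm with separate compute and target dtypes (illustrative sketch)."""
    target_dtype = x.dtype            # e.g. float16
    h = x.astype(np.float32)          # compute_type: upcast for numerical stability
    rms = np.sqrt(np.mean(h * h, axis=-1, keepdims=True) + eps)
    normalized = h / rms
    # Cast back to the target dtype before applying the scale, mirroring the
    # Cast node that the fusion pattern must now match.
    return scale * normalized.astype(target_dtype)
```

Whether the scale multiplication happens in fp16 or after an explicit Cast of the scale to the compute dtype varies between model implementations, which is why the pattern matches both branches.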
1 parent 7407431 commit 700bb1a

File tree: 1 file changed (+2, -0 lines)

onnxscript/rewriter/ort_fusions/rms_normalization.py

Lines changed: 2 additions & 0 deletions
@@ -40,6 +40,8 @@ def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype):
         reciprocal_rms = op.Reciprocal(rms)
         normalized = op.Mul(x, reciprocal_rms)
         normalized = pattern.OrValue([op.Cast(normalized, to=target_dtype), normalized])
+        # To support float16, the scale may or may not be cast to the compute dtype.
+        scale = pattern.OrValue([op.Cast(scale, to=compute_dtype), scale])
         return op.Mul(scale, normalized)

     def check(
