We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 7407431 commit 700bb1aCopy full SHA for 700bb1a
onnxscript/rewriter/ort_fusions/rms_normalization.py
@@ -40,6 +40,8 @@ def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype):
40
reciprocal_rms = op.Reciprocal(rms)
41
normalized = op.Mul(x, reciprocal_rms)
42
normalized = pattern.OrValue([op.Cast(normalized, to=target_dtype), normalized])
43
+ # To support float16, we need to ensure the scale is casted or not.
44
+ scale = pattern.OrValue([op.Cast(scale, to=compute_dtype), scale])
45
return op.Mul(scale, normalized)
46
47
def check(
0 commit comments