6
6
7
7
from onnxscript .rewriter import _fusion_utils , pattern
8
8
9
- _sqrt_two_over_pi = math .sqrt (2.0 / math .pi )
9
+ _SQRT_TWO_OVER_PI = math .sqrt (2.0 / math .pi )
10
+ _SQRT_TWO = math .sqrt (2.0 )
10
11
11
12
12
13
class GeluTanhFusion (pattern .RewriteRuleClassBase ):
@@ -16,7 +17,7 @@ def pattern(self, op, x):
16
17
t2 = op .Mul (0.044715 , t1 )
17
18
t3 = op .Add (x , t2 )
18
19
19
- t4 = op .Mul (_sqrt_two_over_pi , t3 )
20
+ t4 = op .Mul (_SQRT_TWO_OVER_PI , t3 )
20
21
t5 = op .Tanh (t4 )
21
22
t6 = op .Add (t5 , 1 )
22
23
t7 = op .Mul (0.5 , t6 )
@@ -27,9 +28,23 @@ def rewrite(self, op, x):
27
28
return op .FastGelu (x , _domain = "com.microsoft" )
28
29
29
30
30
- _rule = GeluTanhFusion .rule ()
31
+ class GeluErfFusion (pattern .RewriteRuleClassBase ):
32
+ def pattern (self , op , x ):
33
+ # GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
34
+ t1 = op .Div (x , _SQRT_TWO )
35
+ t2 = op .Erf (t1 )
36
+ t3 = op .Add (t2 , 1.0 )
37
+ t4 = op .Mul (x , t3 )
38
+ result = op .Mul (t4 , 0.5 )
39
+ return result
40
+
41
+ def rewrite (self , op , x ):
42
+ return op .Gelu (x , _domain = "com.microsoft" )
43
+
31
44
32
- gelu_rules = pattern .RewriteRuleSet ([_rule ])
45
+ _tanh_rule = GeluTanhFusion .rule ()
46
+ _erf_rule = GeluErfFusion .rule ()
33
47
48
+ gelu_rules = pattern .RewriteRuleSet ([_tanh_rule , _erf_rule ])
34
49
35
50
fuse_gelu = _fusion_utils .apply_fusion_rules (gelu_rules )
0 commit comments