Skip to content

Commit 3af04e9

Browse files
authored
Add Erf-based Gelu fusion rule (#2495)
Add Erf-based Gelu fusion rule --------- Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent fe152d4 commit 3af04e9

File tree

2 files changed

+52
-4
lines changed

2 files changed

+52
-4
lines changed

onnxscript/rewriter/ort_fusions/gelu.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66

77
from onnxscript.rewriter import _fusion_utils, pattern
88

9-
_sqrt_two_over_pi = math.sqrt(2.0 / math.pi)
9+
_SQRT_TWO_OVER_PI = math.sqrt(2.0 / math.pi)
10+
_SQRT_TWO = math.sqrt(2.0)
1011

1112

1213
class GeluTanhFusion(pattern.RewriteRuleClassBase):
@@ -16,7 +17,7 @@ def pattern(self, op, x):
1617
t2 = op.Mul(0.044715, t1)
1718
t3 = op.Add(x, t2)
1819

19-
t4 = op.Mul(_sqrt_two_over_pi, t3)
20+
t4 = op.Mul(_SQRT_TWO_OVER_PI, t3)
2021
t5 = op.Tanh(t4)
2122
t6 = op.Add(t5, 1)
2223
t7 = op.Mul(0.5, t6)
@@ -27,9 +28,23 @@ def rewrite(self, op, x):
2728
return op.FastGelu(x, _domain="com.microsoft")
2829

2930

30-
_rule = GeluTanhFusion.rule()
31+
class GeluErfFusion(pattern.RewriteRuleClassBase):
    """Fuses the erf-based GELU subgraph into a single com.microsoft Gelu op."""

    def pattern(self, op, x):
        # Match GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))).
        # Operand order is kept exactly as common exporters emit it,
        # since the rewriter matches on graph structure.
        scaled = op.Div(x, _SQRT_TWO)
        erf_out = op.Erf(scaled)
        shifted = op.Add(erf_out, 1.0)
        gated = op.Mul(x, shifted)
        return op.Mul(gated, 0.5)

    def rewrite(self, op, x):
        # Replace the whole matched subgraph with the contrib-domain Gelu kernel.
        return op.Gelu(x, _domain="com.microsoft")
3144

32-
gelu_rules = pattern.RewriteRuleSet([_rule])
45+
_tanh_rule = GeluTanhFusion.rule()
46+
_erf_rule = GeluErfFusion.rule()
3347

48+
gelu_rules = pattern.RewriteRuleSet([_tanh_rule, _erf_rule])
3449

3550
fuse_gelu = _fusion_utils.apply_fusion_rules(gelu_rules)

onnxscript/rewriter/ort_fusions/gelu_test.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,39 @@ def gelu_model(x):
5252
optimized_output = test_utils.ort_run("Optimized", model, input)
5353
test_utils.assert_allclose(original_output, optimized_output)
5454

55+
def test_gelu_erf_fusion(self):
    """End-to-end check: the erf-based GELU subgraph fuses to one Gelu node.

    Builds GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))) as an onnxscript model,
    runs the fusion, and verifies both the graph shape and numerical parity
    against the unfused model (via ORT runs on both).
    """
    _sqrt_two = math.sqrt(2.0)

    @script()
    def gelu_erf_model(x):
        # GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
        t1 = op.Div(x, _sqrt_two)
        t2 = op.Erf(t1)
        t3 = op.Add(t2, 1.0)
        t4 = op.Mul(x, t3)
        result = op.Mul(t4, 0.5)
        return result

    model_proto = gelu_erf_model.to_model_proto(
        input_types=[FLOAT[10]], output_types=[FLOAT[10]]
    )
    model = ir.serde.deserialize_model(model_proto)

    # Eliminate redundant CastLike ops:
    optimize(model)

    input = {"x": np.random.randn(10).astype(np.float32)}
    original_output = test_utils.ort_run("Original", model, input)

    fuse_gelu(model)
    remove_unused_nodes(model)

    # After fusion the graph must collapse to exactly one node, of type Gelu.
    self.assertEqual(len(model.graph), 1)
    self.assertEqual(model.graph.node(0).op_type, "Gelu")

    # Fused model must be numerically equivalent to the original.
    optimized_output = test_utils.ort_run("Optimized", model, input)
    test_utils.assert_allclose(original_output, optimized_output)
87+
5588

5689
if __name__ == "__main__":
5790
unittest.main()

0 commit comments

Comments (0)