Skip to content

Commit 73d6134

Browse files
authored
Introduce layer-norm fusion (#2492)
Introduce layer-norm fusion rules, along with a couple of test cases. This is just the first version. TO DO: * We need improved infrastructure for ONNX fusions to handle opset dependence. LayerNorm exists in ONNX since opset 17. For now the fusion rule exists, but it is not automatically called yet (but users can invoke it themselves). * If users want to use opsets < 17, this could be done as an ORT fusion using ORT contrib op LayerNorm. --------- Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent fde4802 commit 73d6134

File tree

3 files changed

+278
-7
lines changed

3 files changed

+278
-7
lines changed
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from __future__ import annotations
4+
5+
import onnx_ir as ir
6+
7+
from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern
8+
9+
"""
10+
Layer Normalization fusion optimization.
11+
12+
This module contains rewrite rules for fusing Layer Normalization patterns into the
13+
ONNX LayerNormalization operator.
14+
15+
Layer Normalization performs normalization over the last D dimensions as specified by the axis.
16+
The computation follows: Y = scale * (X - mean) / sqrt(variance + epsilon) + bias
17+
18+
Key points for the fusion optimization:
19+
* Following restrictions from opset 17 LayerNormalization:
20+
* Input, scale, and bias must be of same type T in {float16, bfloat16, float, double}
21+
* The normalization can be done in a different precision than the input type (bfloat16 or float),
22+
which is also the precision of the output mean/invstddev
23+
"""
24+
25+
# input types permitted by LayerNormalization op (ONNX Opset 17)
26+
LAYER_NORM_INPUT_TYPES = frozenset(
27+
[
28+
ir.DataType.FLOAT,
29+
ir.DataType.FLOAT16,
30+
ir.DataType.BFLOAT16,
31+
ir.DataType.DOUBLE,
32+
]
33+
)
34+
35+
# Compute types permitted by LayerNormalization op (ONNX Opset 17), aka stash_type.
36+
LAYER_NORM_COMPUTE_TYPES = frozenset([ir.DataType.FLOAT, ir.DataType.DOUBLE])
37+
38+
39+
class LayerNormFusion(pattern.RewriteRuleClassBase):
40+
"""Fuse LayerNorm pattern into LayerNormalization op."""
41+
42+
def pattern(self, op, x, scale, epsilon):
43+
# Compute mean: Mean = ReduceMean(X, axes=normalized_axes)
44+
# TODO: support axes attribute too
45+
mean = op.ReduceMean(x, [-1], keepdims=1)
46+
47+
# Compute deviation: D = Sub(X, Mean)
48+
deviation = op.Sub(x, mean)
49+
50+
# Compute squared deviation: DD = Mul(D, D)
51+
deviation_squared = pattern.OrValue(
52+
[
53+
op.Mul(deviation, deviation),
54+
op.Pow(deviation, 2),
55+
]
56+
)
57+
58+
# Compute variance: Var = ReduceMean(DD, axes=normalized_axes)
59+
variance = op.ReduceMean(deviation_squared, [-1], keepdims=1)
60+
61+
# Add epsilon: VarEps = Add(Var, epsilon)
62+
variance_plus_epsilon = op.Add(variance, epsilon)
63+
64+
# Compute standard deviation: StdDev = Sqrt(VarEps)
65+
std_dev = op.Sqrt(variance_plus_epsilon)
66+
67+
# Compute reciprocal: InvStdDev = Reciprocal(StdDev)
68+
# Normalize: Normalized = Mul(D, InvStdDev)
69+
70+
inv_std_dev = op.Reciprocal(std_dev)
71+
normalized = pattern.OrValue(
72+
[op.Mul(deviation, inv_std_dev), op.Div(deviation, std_dev)]
73+
)
74+
75+
# Scale: NormalizedScaled = Mul(Normalized, Scale)
76+
normalized_scaled = op.Mul(normalized, scale)
77+
78+
return normalized_scaled
79+
80+
def check(self, context, x, epsilon, **_) -> pattern.MatchResult: # type: ignore[name-defined]
81+
"""Check if the pattern matches conditions for use of LayerNormalization op."""
82+
check_result = pattern.MatchResult()
83+
84+
# Type validation:
85+
if x.dtype not in LAYER_NORM_COMPUTE_TYPES:
86+
return check_result.fail("Input is not a float type.", x)
87+
self._stash_type = x.dtype
88+
89+
# Check that epsilon is a scalar constant
90+
epsilon_value = _ir_utils.get_singleton_value(epsilon)
91+
if epsilon_value is None:
92+
return check_result.fail("Epsilon is not a constant scalar.", epsilon)
93+
# Epsilon is guaranteed to be same type as x (float or double, in this pattern)
94+
self._epsilon = float(epsilon_value)
95+
96+
return check_result
97+
98+
def rewrite(self, op, x, scale, epsilon, **_):
99+
return op.LayerNormalization(
100+
x,
101+
scale,
102+
axis=-1,
103+
epsilon=self._epsilon,
104+
stash_type=self._stash_type,
105+
)
106+
107+
108+
class LayerNormBiasFusion(pattern.RewriteRuleClassBase):
109+
"""Fuse LayerNorm => Add into LayerNorm with bias."""
110+
111+
def pattern(self, op, x, scale, bias):
112+
return op.LayerNormalization(x, scale, _outputs=["normalized"]) + bias
113+
114+
def rewrite(self, op, x, scale, bias, normalized):
115+
layernorm_node = normalized.producer()
116+
attributes = layernorm_node.attributes
117+
num_outputs = len(layernorm_node.outputs)
118+
return op.LayerNormalization(x, scale, bias, _outputs=num_outputs, **attributes)
119+
120+
121+
# Create rules for both with and without bias
122+
_layer_norm_rule = LayerNormFusion.rule()
123+
_layer_norm_with_bias_rule = LayerNormBiasFusion.rule()
124+
125+
layer_normalization_rules = [_layer_norm_rule, _layer_norm_with_bias_rule]
126+
layer_normalization_ruleset = pattern.RewriteRuleSet(layer_normalization_rules)
127+
128+
fuse_layer_normalization = _fusion_utils.apply_fusion_rules(layer_normalization_ruleset)
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
import unittest
5+
6+
import onnx_ir as ir
7+
8+
import onnxscript
9+
import onnxscript.optimizer
10+
import onnxscript.rewriter.testing
11+
from onnxscript import FLOAT, OnnxFunction, script
12+
from onnxscript import opset18 as op
13+
from onnxscript.rewriter.onnx_fusions._layer_norm import fuse_layer_normalization
14+
15+
16+
@script()
17+
def _test_layer_norm_without_bias(x: FLOAT[2, 4, 8], scale: FLOAT[8]) -> FLOAT[2, 4, 8]:
18+
"""LayerNorm pattern without bias."""
19+
# Compute mean: Mean = ReduceMean(X, axes=normalized_axes)
20+
mean = op.ReduceMean(x, [-1], keepdims=1)
21+
22+
# Compute deviation: D = Sub(X, Mean)
23+
deviation = op.Sub(x, mean)
24+
25+
# Compute squared deviation: DD = Mul(D, D)
26+
deviation_squared = op.Mul(deviation, deviation)
27+
28+
# Compute variance: Var = ReduceMean(DD, axes=normalized_axes)
29+
variance = op.ReduceMean(deviation_squared, [-1], keepdims=1)
30+
31+
# Add epsilon: VarEps = Add(Var, epsilon)
32+
epsilon = op.Constant(value_float=1e-5)
33+
variance_plus_epsilon = op.Add(variance, epsilon)
34+
35+
# Compute standard deviation: StdDev = Sqrt(VarEps)
36+
std_dev = op.Sqrt(variance_plus_epsilon)
37+
38+
# Compute reciprocal: InvStdDev = Reciprocal(StdDev)
39+
inv_std_dev = op.Reciprocal(std_dev)
40+
41+
# Normalize: Normalized = Mul(D, InvStdDev)
42+
normalized = op.Mul(deviation, inv_std_dev)
43+
44+
# Scale: NormalizedScaled = Mul(Normalized, Scale)
45+
normalized_scaled = op.Mul(normalized, scale)
46+
47+
return normalized_scaled
48+
49+
50+
@script()
51+
def _test_layer_norm_with_bias(
52+
x: FLOAT[2, 4, 8], scale: FLOAT[8], bias: FLOAT[8]
53+
) -> FLOAT[2, 4, 8]:
54+
"""LayerNorm pattern with bias."""
55+
# Compute mean: Mean = ReduceMean(X, axes=normalized_axes)
56+
mean = op.ReduceMean(x, [-1], keepdims=1)
57+
58+
# Compute deviation: D = Sub(X, Mean)
59+
deviation = op.Sub(x, mean)
60+
61+
# Compute squared deviation: DD = Mul(D, D)
62+
deviation_squared = op.Mul(deviation, deviation)
63+
64+
# Compute variance: Var = ReduceMean(DD, axes=normalized_axes)
65+
variance = op.ReduceMean(deviation_squared, [-1], keepdims=1)
66+
67+
# Add epsilon: VarEps = Add(Var, epsilon)
68+
epsilon = op.Constant(value_float=1e-5)
69+
variance_plus_epsilon = op.Add(variance, epsilon)
70+
71+
# Compute standard deviation: StdDev = Sqrt(VarEps)
72+
std_dev = op.Sqrt(variance_plus_epsilon)
73+
74+
# Compute reciprocal: InvStdDev = Reciprocal(StdDev)
75+
inv_std_dev = op.Reciprocal(std_dev)
76+
77+
# Normalize: Normalized = Mul(D, InvStdDev)
78+
normalized = op.Mul(deviation, inv_std_dev)
79+
80+
# Scale: NormalizedScaled = Mul(Normalized, Scale)
81+
normalized_scaled = op.Mul(normalized, scale)
82+
83+
# Add bias: Y = Add(NormalizedScaled, B)
84+
result = op.Add(normalized_scaled, bias)
85+
86+
return result
87+
88+
89+
class LayerNormFusionTest(unittest.TestCase):
90+
def _check(self, test_script: OnnxFunction):
91+
"""Helper method to run a fusion test scenario."""
92+
model_proto = test_script.to_model_proto()
93+
# Create test inputs
94+
input_data = onnxscript.rewriter.testing.generate_random_inputs(model_proto)
95+
96+
model = ir.serde.deserialize_model(model_proto)
97+
fuse_layer_normalization(model)
98+
99+
onnxscript.optimizer.remove_unused_nodes(model)
100+
101+
# Check that a LayerNormalization node was created
102+
self.assertEqual(["LayerNormalization"], [n.op_type for n in model.graph])
103+
104+
fused_model_proto = ir.serde.serialize_model(model)
105+
106+
onnxscript.rewriter.testing.assert_numerically_equal(
107+
model_proto, fused_model_proto, input_data
108+
)
109+
110+
def test_layer_norm_fusion_without_bias(self):
111+
"""Test LayerNorm fusion without bias."""
112+
self._check(_test_layer_norm_without_bias)
113+
114+
def test_layer_norm_fusion_with_bias(self):
115+
"""Test LayerNorm fusion with bias."""
116+
self._check(_test_layer_norm_with_bias)
117+
118+
119+
if __name__ == "__main__":
120+
unittest.main()

onnxscript/rewriter/testing.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,28 @@
1111
from onnxscript import ir
1212

1313

14+
def generate_random_inputs(model: onnx.ModelProto) -> dict[str, Any]:
15+
feeds: dict[str, Any] = {}
16+
for input in model.graph.input:
17+
input_type = input.type.tensor_type
18+
shape = tuple(input_type.shape.dim)
19+
if not all(hasattr(d, "dim_value") for d in shape):
20+
raise ValueError(f"Input {input.name} has dynamic shape dimensions.")
21+
shape = tuple(d.dim_value for d in shape)
22+
if input_type.elem_type == onnx.TensorProto.FLOAT:
23+
if shape:
24+
feeds[input.name] = np.random.randn(*shape).astype(np.float32)
25+
else:
26+
feeds[input.name] = np.random.randn(1).astype(np.float32)
27+
else:
28+
raise ValueError(f"Not implemented for input type {input_type.elem_type}")
29+
return feeds
30+
31+
1432
def assert_numerically_equal(
1533
original_model_proto: onnx.ModelProto | ir.Model,
1634
rewritten_model_proto: onnx.ModelProto | ir.Model,
17-
args: tuple[Any, ...],
35+
args: tuple[Any, ...] | dict[str, Any],
1836
ort_optimization_level: ort.GraphOptimizationLevel = ort.GraphOptimizationLevel.ORT_ENABLE_ALL,
1937
rtol: float = 1,
2038
atol: float = 1e-3,
@@ -35,9 +53,17 @@ def assert_numerically_equal(
3553
if isinstance(rewritten_model_proto, ir.Model):
3654
rewritten_model_proto = ir.serde.serialize_model(rewritten_model_proto)
3755

38-
original_proto_ort_inputs = {
39-
k.name: v for k, v in zip(original_model_proto.graph.input, args)
40-
}
56+
if isinstance(args, dict):
57+
original_proto_ort_inputs = args
58+
the_rewritten_proto_ort_inputs = args
59+
else:
60+
original_proto_ort_inputs = {
61+
k.name: v for k, v in zip(original_model_proto.graph.input, args)
62+
}
63+
the_rewritten_proto_ort_inputs = {
64+
k.name: v for k, v in zip(rewritten_model_proto.graph.input, args)
65+
}
66+
4167
original_proto_ort_inference_session = _ort_session_initializer(
4268
original_model_proto.SerializeToString(), ort_optimization_level
4369
)
@@ -47,9 +73,6 @@ def assert_numerically_equal(
4773
None, original_proto_ort_inputs, run_options=run_options
4874
)
4975

50-
the_rewritten_proto_ort_inputs = {
51-
k.name: v for k, v in zip(rewritten_model_proto.graph.input, args)
52-
}
5376
the_rewritten_proto_ort_inference_session = _ort_session_initializer(
5477
rewritten_model_proto.SerializeToString(), ort_optimization_level
5578
)

0 commit comments

Comments
 (0)