|
| 1 | +# Copyright (c) Microsoft Corporation. |
| 2 | +# Licensed under the MIT License. |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +import onnx_ir as ir |
| 6 | + |
| 7 | +from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern |
| 8 | + |
| 9 | +""" |
| 10 | +Layer Normalization fusion optimization. |
| 11 | +
|
| 12 | +This module contains rewrite rules for fusing Layer Normalization patterns into the |
| 13 | +ONNX LayerNormalization operator. |
| 14 | +
|
| 15 | +Layer Normalization performs normalization over the last D dimensions as specified by the axis. |
| 16 | +The computation follows: Y = scale * (X - mean) / sqrt(variance + epsilon) + bias |
| 17 | +
|
| 18 | +Key points for the fusion optimization: |
| 19 | +* Following restrictions from opset 17 LayerNormalization: |
| 20 | +* Input, scale, and bias must be of same type T in {float16, bfloat16, float, double} |
| 21 | +* The normalization can be done in a different precision than the input type (bfloat16 or float), |
| 22 | +which is also the precision of the output mean/invstddev |
| 23 | +""" |
| 24 | + |
| 25 | +# input types permitted by LayerNormalization op (ONNX Opset 17) |
| 26 | +LAYER_NORM_INPUT_TYPES = frozenset( |
| 27 | + [ |
| 28 | + ir.DataType.FLOAT, |
| 29 | + ir.DataType.FLOAT16, |
| 30 | + ir.DataType.BFLOAT16, |
| 31 | + ir.DataType.DOUBLE, |
| 32 | + ] |
| 33 | +) |
| 34 | + |
| 35 | +# Compute types permitted by LayerNormalization op (ONNX Opset 17), aka stash_type. |
| 36 | +LAYER_NORM_COMPUTE_TYPES = frozenset([ir.DataType.FLOAT, ir.DataType.DOUBLE]) |
| 37 | + |
| 38 | + |
| 39 | +class LayerNormFusion(pattern.RewriteRuleClassBase): |
| 40 | + """Fuse LayerNorm pattern into LayerNormalization op.""" |
| 41 | + |
| 42 | + def pattern(self, op, x, scale, epsilon): |
| 43 | + # Compute mean: Mean = ReduceMean(X, axes=normalized_axes) |
| 44 | + # TODO: support axes attribute too |
| 45 | + mean = op.ReduceMean(x, [-1], keepdims=1) |
| 46 | + |
| 47 | + # Compute deviation: D = Sub(X, Mean) |
| 48 | + deviation = op.Sub(x, mean) |
| 49 | + |
| 50 | + # Compute squared deviation: DD = Mul(D, D) |
| 51 | + deviation_squared = pattern.OrValue( |
| 52 | + [ |
| 53 | + op.Mul(deviation, deviation), |
| 54 | + op.Pow(deviation, 2), |
| 55 | + ] |
| 56 | + ) |
| 57 | + |
| 58 | + # Compute variance: Var = ReduceMean(DD, axes=normalized_axes) |
| 59 | + variance = op.ReduceMean(deviation_squared, [-1], keepdims=1) |
| 60 | + |
| 61 | + # Add epsilon: VarEps = Add(Var, epsilon) |
| 62 | + variance_plus_epsilon = op.Add(variance, epsilon) |
| 63 | + |
| 64 | + # Compute standard deviation: StdDev = Sqrt(VarEps) |
| 65 | + std_dev = op.Sqrt(variance_plus_epsilon) |
| 66 | + |
| 67 | + # Compute reciprocal: InvStdDev = Reciprocal(StdDev) |
| 68 | + # Normalize: Normalized = Mul(D, InvStdDev) |
| 69 | + |
| 70 | + inv_std_dev = op.Reciprocal(std_dev) |
| 71 | + normalized = pattern.OrValue( |
| 72 | + [op.Mul(deviation, inv_std_dev), op.Div(deviation, std_dev)] |
| 73 | + ) |
| 74 | + |
| 75 | + # Scale: NormalizedScaled = Mul(Normalized, Scale) |
| 76 | + normalized_scaled = op.Mul(normalized, scale) |
| 77 | + |
| 78 | + return normalized_scaled |
| 79 | + |
| 80 | + def check(self, context, x, epsilon, **_) -> pattern.MatchResult: # type: ignore[name-defined] |
| 81 | + """Check if the pattern matches conditions for use of LayerNormalization op.""" |
| 82 | + check_result = pattern.MatchResult() |
| 83 | + |
| 84 | + # Type validation: |
| 85 | + if x.dtype not in LAYER_NORM_COMPUTE_TYPES: |
| 86 | + return check_result.fail("Input is not a float type.", x) |
| 87 | + self._stash_type = x.dtype |
| 88 | + |
| 89 | + # Check that epsilon is a scalar constant |
| 90 | + epsilon_value = _ir_utils.get_singleton_value(epsilon) |
| 91 | + if epsilon_value is None: |
| 92 | + return check_result.fail("Epsilon is not a constant scalar.", epsilon) |
| 93 | + # Epsilon is guaranteed to be same type as x (float or double, in this pattern) |
| 94 | + self._epsilon = float(epsilon_value) |
| 95 | + |
| 96 | + return check_result |
| 97 | + |
| 98 | + def rewrite(self, op, x, scale, epsilon, **_): |
| 99 | + return op.LayerNormalization( |
| 100 | + x, |
| 101 | + scale, |
| 102 | + axis=-1, |
| 103 | + epsilon=self._epsilon, |
| 104 | + stash_type=self._stash_type, |
| 105 | + ) |
| 106 | + |
| 107 | + |
| 108 | +class LayerNormBiasFusion(pattern.RewriteRuleClassBase): |
| 109 | + """Fuse LayerNorm => Add into LayerNorm with bias.""" |
| 110 | + |
| 111 | + def pattern(self, op, x, scale, bias): |
| 112 | + return op.LayerNormalization(x, scale, _outputs=["normalized"]) + bias |
| 113 | + |
| 114 | + def rewrite(self, op, x, scale, bias, normalized): |
| 115 | + layernorm_node = normalized.producer() |
| 116 | + attributes = layernorm_node.attributes |
| 117 | + num_outputs = len(layernorm_node.outputs) |
| 118 | + return op.LayerNormalization(x, scale, bias, _outputs=num_outputs, **attributes) |
| 119 | + |
| 120 | + |
| 121 | +# Create rules for both with and without bias |
| 122 | +_layer_norm_rule = LayerNormFusion.rule() |
| 123 | +_layer_norm_with_bias_rule = LayerNormBiasFusion.rule() |
| 124 | + |
| 125 | +layer_normalization_rules = [_layer_norm_rule, _layer_norm_with_bias_rule] |
| 126 | +layer_normalization_ruleset = pattern.RewriteRuleSet(layer_normalization_rules) |
| 127 | + |
| 128 | +fuse_layer_normalization = _fusion_utils.apply_fusion_rules(layer_normalization_ruleset) |
0 commit comments