-
Notifications
You must be signed in to change notification settings - Fork 86
First version of attention fusion #1986
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 25 commits
Commits
Show all changes
33 commits
Select commit
Hold shift + click to select a range
8de7231
First version
gramalingam a20b903
Add rotary embedding
gramalingam b8f7a08
Remove SDPA
gramalingam 315c94e
Add comment
gramalingam 2219fd3
Remove MHA
gramalingam f77f0e7
Merge branch 'main' into rama/fuse-attn
gramalingam 5ec9d1e
Add rewrite for cos-sin computation
gramalingam 90f0b7b
Merge branch 'rama/fuse-attn' of https://github.com/microsoft/onnx-sc…
gramalingam 1fdc19b
Run lint
gramalingam eb916b8
Add cos sin test
gramalingam d874dbc
Extend rewriter to support node reuse
gramalingam a745039
Minor fixes
gramalingam 17c06c3
Fix concat bug in rotary embedding
gramalingam c7c7c79
Minor cleanup
gramalingam 834815b
Merge branch 'main' into rama/fuse-attn
gramalingam 9a4a58e
Use callable to test callable
gramalingam 766791d
Fix lint issues
gramalingam c7384af
Attention fusion
gramalingam d0254d1
Add support for cached state in rewrite
gramalingam b91166b
Cleanup MHA pattern
gramalingam 205805c
Complete MHA pattern
gramalingam e907f3e
Add MHA fusion test
gramalingam 82f1919
Add validation condition
gramalingam fa3b94d
Run lint
gramalingam 9310b67
Merge with main
gramalingam e0f29e2
Fix merge conflict
gramalingam 41aa177
Fix merge conflict
gramalingam 2688d6e
Merge conflict fix
gramalingam c080f4a
Merge with main
gramalingam 2e6de3d
Merge branch 'main' into rama/fuse-attn-2
gramalingam 889f7d2
Address lint issues
gramalingam a9947dc
Add smollm models to mypy exclusion
gramalingam c5c6588
Rename unused variable in test onnx model
gramalingam File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
from __future__ import annotations | ||
|
||
from onnxscript.rewriter.onnxruntime.xformers.cos_sin_cache import fuse_cos_sin_cache | ||
from onnxscript.rewriter.onnxruntime.xformers.mha import fuse_mha | ||
from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization | ||
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding | ||
from onnxscript.rewriter.onnxruntime.xformers.sdpa import fuse_sdpa | ||
from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import fuse_normalization | ||
|
||
|
||
def fuse_xformers(model): | ||
fuse_rms_normalization(model) | ||
fuse_normalization(model) | ||
fuse_rotary_embedding(model) | ||
fuse_cos_sin_cache(model) | ||
fuse_sdpa(model) | ||
fuse_mha(model) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
from __future__ import annotations | ||
|
||
from typing import Iterable | ||
|
||
import onnxscript.ir as ir | ||
from onnxscript.rewriter import pattern | ||
|
||
""" | ||
The MultiHeadAttention pattern: | ||
|
||
B: Batch size | ||
S: Sequence length | ||
D: input embedding dimension | ||
H: number of heads | ||
d_h: head size (usually, D = H * d_h) | ||
|
||
thus, weights are usually of shape (D, D) and (D, D) and (D, D) | ||
|
||
for each of Q, K, and V, we have the following pattern: | ||
MatMul (Input, W), producing output of shape (B, S, D) | ||
Reshape to produce a matrix of shape (B, S, H, d_h) | ||
Transpose middle two axes to produce a matrix of shape (B, H, S, d_h) | ||
|
||
This is followed by a RotaryEmbedding pattern for Q and K | ||
|
||
The last two axes of the key-embedding are then swapped (using a Reshape/Transpose/Reshape sequence) | ||
|
||
The dot-product attention is then computed using SDPA | ||
|
||
Finally, the output is transposed and reshaped back to (B, S, D) shape | ||
""" | ||
|
||
|
||
def _project_transpose_head(op, input, weight, reshape_var: str): | ||
"""Applied to each of Q, K, and V.""" | ||
projected = op.MatMul(input, weight) | ||
# Reshape from (B, S, D) to (B, S, H, D/H) | ||
reshaped = op.Reshape( | ||
projected, | ||
_allow_other_inputs=True, | ||
_allow_other_attributes=True, | ||
_outputs=[reshape_var], | ||
) | ||
# Transpose from (B, S, H, D/H) to (B, H, S, D/H) | ||
transposed = op.Transpose(reshaped, perm=[0, 2, 1, 3]) | ||
return transposed | ||
|
||
|
||
def _multi_head_attention_pattern( | ||
op, | ||
input, | ||
query_weight, | ||
key_weight, | ||
value_weight, | ||
mask, | ||
cos, | ||
sin, | ||
past_key, | ||
past_value, | ||
position_ids, | ||
): | ||
query = _project_transpose_head(op, input, query_weight, "query_mm_reshaped") | ||
query_rope = op.RotaryEmbedding(query, position_ids, cos, sin, _domain="com.microsoft") | ||
key = _project_transpose_head(op, input, key_weight, "key_mm_reshaped") | ||
key_rope = op.RotaryEmbedding(key, position_ids, cos, sin, _domain="com.microsoft") | ||
key_rope = op.Concat(past_key, key_rope, axis=-2) | ||
# Transpose last two axes of key_rope to compute dot-product via matmul. | ||
key_reshaped = op.Reshape(key_rope, _allow_other_inputs=True, _outputs=["key_reshaped"]) | ||
key_reshaped_transposed = op.Transpose(key_reshaped, perm=[0, 2, 1]) | ||
key_transposed = op.Reshape( | ||
key_reshaped_transposed, _allow_other_inputs=True, _outputs=["key_transposed"] | ||
) | ||
value = _project_transpose_head(op, input, value_weight, "value_mm_reshaped") | ||
value = op.Concat(past_value, value, axis=-2) | ||
attention = op.SDPA( | ||
query_rope, key_transposed, value, mask, _domain="ai.onnxruntime.fusion" | ||
) | ||
# Transpose back to (B, S, H, D/H) | ||
attention_transposed = op.Transpose(attention, perm=[0, 2, 1, 3]) | ||
# Reshape back to (B, S, D) | ||
attention_reshaped = op.Reshape( | ||
attention_transposed, _allow_other_inputs=True, _outputs=["attention_reshaped"] | ||
) | ||
return attention_reshaped, key_rope, value | ||
|
||
|
||
def _check_shape(bindings: dict[str, int], val: ir.Value, shape: Iterable[str]) -> bool: | ||
if val.shape is None: | ||
return False | ||
if val.shape.rank() != len(shape): | ||
return False | ||
for actual, expected in zip(val.shape, shape): | ||
if expected not in bindings: | ||
bindings[expected] = actual | ||
elif actual != bindings[expected]: | ||
return False | ||
return True | ||
|
||
|
||
def _mha_validation( | ||
op, | ||
query_mm_reshaped, | ||
key_mm_reshaped, | ||
value_mm_reshaped, | ||
key_reshaped, | ||
key_transposed, | ||
attention_reshaped, | ||
**_, | ||
): | ||
bindings: dict[str, int] = {} | ||
check = ( | ||
_check_shape(bindings, query_mm_reshaped, ["B", "S", "H", "d_h"]) | ||
and _check_shape(bindings, key_mm_reshaped, ["B", "KVS", "H", "d_h"]) | ||
and _check_shape(bindings, value_mm_reshaped, ["B", "KVS", "H", "d_h"]) | ||
and _check_shape(bindings, key_reshaped, ["B*H", "TS", "d_h"]) | ||
and _check_shape(bindings, key_transposed, ["B", "H", "d_h", "TS"]) | ||
and _check_shape(bindings, attention_reshaped, ["B", "S", "H*d_h"]) | ||
) | ||
if not check: | ||
return False | ||
if bindings["B"] * bindings["H"] != bindings["B*H"]: | ||
return False | ||
if bindings["H"] * bindings["d_h"] != bindings["H*d_h"]: | ||
return False | ||
return True | ||
|
||
|
||
def _multi_head_attention( | ||
op, | ||
input, | ||
query_weight, | ||
key_weight, | ||
value_weight, | ||
mask, | ||
cos, | ||
sin, | ||
past_key, | ||
past_value, | ||
position_ids, | ||
query_mm_reshaped, | ||
**_, | ||
): | ||
num_heads = query_mm_reshaped.shape[2] | ||
query = op.MatMul(input, query_weight) | ||
query_rope = op.RotaryEmbedding(query, position_ids, cos, sin, _domain="com.microsoft") | ||
key = op.MatMul(input, key_weight) | ||
key_rope = op.RotaryEmbedding(key, position_ids, cos, sin, _domain="com.microsoft") | ||
value = op.MatMul(input, value_weight) | ||
tiling_factor = op.Constant(value_ints=[1, num_heads, 1, 1]) | ||
expanded_mask = op.Tile(mask, tiling_factor) | ||
return op.MultiHeadAttention( | ||
query_rope, | ||
key_rope, | ||
value, | ||
None, # bias | ||
None, # key padding mask | ||
expanded_mask, # attention mask/bias | ||
past_key, | ||
past_value, | ||
num_heads=num_heads, | ||
_domain="com.microsoft", | ||
_outputs=3, | ||
) | ||
|
||
|
||
_rule1 = pattern.RewriteRule( | ||
_multi_head_attention_pattern, _multi_head_attention, _mha_validation | ||
) | ||
|
||
|
||
mha_rules = pattern.RewriteRuleSet([_rule1]) | ||
|
||
|
||
def fuse_mha(model: ir.Model) -> int: | ||
count = mha_rules.apply_to_model(model) | ||
print(f"MHA count: {count}") | ||
return count |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
from __future__ import annotations | ||
|
||
import unittest | ||
|
||
import onnxscript.optimizer | ||
import onnxscript.rewriter.onnxruntime.xformers as xformers | ||
from onnxscript.rewriter.onnxruntime.xformers._smollm_2 import TestData | ||
from onnxscript.rewriter.onnxruntime.xformers._test_utils import assert_allclose, ort_run | ||
|
||
|
||
class TestMultiHeadAttention(unittest.TestCase): | ||
def test_smollm(self): | ||
# Generate model | ||
smollm_test = TestData() | ||
model = smollm_test.get_onnx_model() | ||
onnxscript.optimizer.optimize(model) | ||
xformers.fuse_rms_normalization(model) | ||
xformers.fuse_normalization(model) | ||
xformers.fuse_rotary_embedding(model) | ||
xformers.fuse_cos_sin_cache(model) | ||
|
||
# Run model | ||
inputs = smollm_test.get_ort_inputs() | ||
original_outputs = ort_run("original", model, inputs) | ||
|
||
# Fuse SDPA and MHA | ||
sdpa_count = xformers.fuse_sdpa(model) | ||
self.assertGreater(sdpa_count, 0) | ||
mha_count = xformers.fuse_mha(model) | ||
self.assertGreater(mha_count, 0) | ||
|
||
# Run model again | ||
new_outputs = ort_run("optimized", model, inputs) | ||
assert_allclose(new_outputs, original_outputs) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.