Skip to content

Commit 32f2196

Browse files
authored
Rename fusion files (#2476)
Rename fusion files to follow a uniform style and to locate them easily: * fuse_packed_qkv_gqa => gqa_packed_qkv * fuse_mha_bias => mha_bias * Rename corresponding test file also --------- Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent da23d76 commit 32f2196

File tree

4 files changed

+5
-5
lines changed

4 files changed

+5
-5
lines changed

onnxscript/rewriter/ort_fusions/_core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
from onnxscript.rewriter.ort_fusions.bias_gelu import fuse_bias_gelu
1818
from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache
1919
from onnxscript.rewriter.ort_fusions.erfgelu import fuse_erfgelu
20-
from onnxscript.rewriter.ort_fusions.fuse_mha_bias import fuse_mha_bias
21-
from onnxscript.rewriter.ort_fusions.fuse_packed_qkv_gqa import fuse_qkv_gqa
2220
from onnxscript.rewriter.ort_fusions.gelu import fuse_gelu
2321
from onnxscript.rewriter.ort_fusions.gqa import fuse_gqa
22+
from onnxscript.rewriter.ort_fusions.gqa_packed_qkv import fuse_qkv_gqa
2423
from onnxscript.rewriter.ort_fusions.mha import fuse_mha1, fuse_mha2
24+
from onnxscript.rewriter.ort_fusions.mha_bias import fuse_mha_bias
2525
from onnxscript.rewriter.ort_fusions.rms_normalization import fuse_rms_normalization
2626
from onnxscript.rewriter.ort_fusions.rotary_embedding import (
2727
fuse_partial_rotary_embedding,
File renamed without changes.

onnxscript/rewriter/ort_fusions/fuse_packed_qkv_gqa_test.py renamed to onnxscript/rewriter/ort_fusions/gqa_packed_qkv_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from onnxscript import FLOAT, INT32, script
1515
from onnxscript import opset18 as op
1616
from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose
17-
from onnxscript.rewriter.ort_fusions.fuse_packed_qkv_gqa import fuse_qkv_gqa
17+
from onnxscript.rewriter.ort_fusions.gqa_packed_qkv import fuse_qkv_gqa
1818

1919
msft_op = onnxscript.values.Opset("com.microsoft", 1)
2020

onnxscript/rewriter/ort_fusions/fuse_mha_bias.py renamed to onnxscript/rewriter/ort_fusions/mha_bias.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def rewrite(
163163
)
164164

165165

166-
fuse_mha_bias_rules = pattern.RewriteRuleSet([FuseBiasMHA.rule()])
166+
mha_bias_rules = pattern.RewriteRuleSet([FuseBiasMHA.rule()])
167167

168168

169-
fuse_mha_bias = _fusion_utils.apply_fusion_rules(fuse_mha_bias_rules)
169+
fuse_mha_bias = _fusion_utils.apply_fusion_rules(mha_bias_rules)

0 commit comments

Comments
 (0)