Skip to content

Commit c7d5786

Browse files
gramalingam, Copilot, github-advanced-security[bot]
authored
Test SDPA fusion via MHA (#2366)
Implements SDPA (introduced by our fusions) via MHA (in a subset of cases), so that the fused model can be run and tested using ORT. Not yet addressed: use of KV cache and 3D vs. 4D Q/K/V formats (these will be addressed when the MHA fusion rules are cleaned up next). Also fixes some copy-paste errors in the SDPA test cases and makes the test-case naming scheme more uniform, which helps with the pytest test-selection filter (-k). --------- Signed-off-by: Ganesan Ramalingam <[email protected]> Co-authored-by: Copilot <[email protected]> Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
1 parent 99323bf commit c7d5786

File tree

2 files changed

+105
-24
lines changed

2 files changed

+105
-24
lines changed

onnxscript/rewriter/ort_fusions/sdpa_test.py

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
from onnxscript import script
1717
from onnxscript.onnx_opset import opset18 as op
1818
from onnxscript.onnx_types import FLOAT
19+
from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose, ort_run
1920
from onnxscript.rewriter.ort_fusions.sdpa import fuse_sdpa
21+
from onnxscript.rewriter.ort_fusions.sdpa_via_mha import replace_sdpa_by_mha
2022

2123
B = 2 # batch size
2224
N = 4 # number of heads
@@ -190,7 +192,7 @@ def _masked_post_mul_sdpa_script(query, key, value, mask):
190192

191193

192194
@script()
193-
def _custom_scale_pre_div_sdpa_script(query, key, value, mask):
195+
def _masked_custom_scale_pre_div_sdpa_script(query, key, value, mask):
194196
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
195197
divisor = op.Constant(value_float=SQRT_CUSTOM_DIV_SCALE_FACTOR)
196198
scaled_query = op.Div(query, divisor)
@@ -203,7 +205,7 @@ def _custom_scale_pre_div_sdpa_script(query, key, value, mask):
203205

204206

205207
@script()
206-
def _custom_scale_pre_mul_sdpa_script(query, key, value, mask):
208+
def _masked_custom_scale_pre_mul_sdpa_script(query, key, value, mask):
207209
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
208210
multiplier = op.Constant(value_float=SQRT_CUSTOM_MUL_SCALE_FACTOR)
209211
scaled_query = op.Mul(query, multiplier)
@@ -216,7 +218,7 @@ def _custom_scale_pre_mul_sdpa_script(query, key, value, mask):
216218

217219

218220
@script()
219-
def _custom_scale_post_div_sdpa_script(query, key, value, mask):
221+
def _masked_custom_scale_post_div_sdpa_script(query, key, value, mask):
220222
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
221223
divisor = op.Constant(value_float=CUSTOM_DIV_SCALE_FACTOR)
222224
attn_score = op.MatMul(query, key_transposed)
@@ -228,7 +230,7 @@ def _custom_scale_post_div_sdpa_script(query, key, value, mask):
228230

229231

230232
@script()
231-
def _custom_scale_post_mul_sdpa_script(query, key, value, mask):
233+
def _masked_custom_scale_post_mul_sdpa_script(query, key, value, mask):
232234
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
233235
multiplier = op.Constant(value_float=CUSTOM_MUL_SCALE_FACTOR)
234236
attn_score = op.MatMul(query, key_transposed)
@@ -240,15 +242,19 @@ def _custom_scale_post_mul_sdpa_script(query, key, value, mask):
240242

241243

242244
class SDPATestCase:
243-
def __init__(self, script_func):
245+
def __init__(self, script_func, *, with_mask):
244246
self.script_func = script_func
247+
self.with_mask = with_mask
245248

246249
def get_onnx_model(self):
247250
if not hasattr(self, "_onnx_model"):
248251
qkv_type = FLOAT[B, N, S, H]
249252
mask_type = FLOAT[B, N, S, S]
253+
input_types = [qkv_type, qkv_type, qkv_type]
254+
if self.with_mask:
255+
input_types.append(mask_type)
250256
model_proto = self.script_func.to_model_proto(
251-
input_types=[qkv_type, qkv_type, qkv_type, mask_type], output_types=[qkv_type]
257+
input_types=input_types, output_types=[qkv_type]
252258
)
253259
self._onnx_model = ir.serde.deserialize_model(model_proto)
254260
return self._onnx_model
@@ -259,8 +265,9 @@ def get_ort_inputs(self):
259265
"query": numpy.random.rand(B, N, S, H).astype(numpy.float32),
260266
"key": numpy.random.rand(B, N, S, H).astype(numpy.float32),
261267
"value": numpy.random.rand(B, N, S, H).astype(numpy.float32),
262-
"mask": numpy.random.rand(B, N, S, S).astype(numpy.float32),
263268
}
269+
if self.with_mask:
270+
inputs["mask"] = numpy.random.rand(B, N, S, S).astype(numpy.float32)
264271
self._ort_inputs = inputs
265272
return self._ort_inputs
266273

@@ -296,35 +303,35 @@ def get_ort_inputs(self):
296303
class TestSDPAFusion(unittest.TestCase):
297304
@parameterized.parameterized.expand(
298305
[
299-
("unmasked_pre_div", _unmasked_pre_div_sdpa_script),
300-
("unmasked_pre_mul", _unmasked_pre_mul_sdpa_script),
301-
("unmasked_post_div", _unmasked_post_div_sdpa_script),
302-
("unmasked_post_mul", _unmasked_post_mul_sdpa_script),
303-
("pre_div", _masked_pre_div_sdpa_script),
304-
("pre_mul", _masked_pre_mul_sdpa_script),
305-
("post_div", _masked_post_div_sdpa_script),
306-
("post_mul", _masked_post_mul_sdpa_script),
306+
("pre_div", _unmasked_pre_div_sdpa_script),
307+
("pre_mul", _unmasked_pre_mul_sdpa_script),
308+
("post_div", _unmasked_post_div_sdpa_script),
309+
("post_mul", _unmasked_post_mul_sdpa_script),
310+
("masked_pre_div", _masked_pre_div_sdpa_script),
311+
("masked_pre_mul", _masked_pre_mul_sdpa_script),
312+
("masked_post_div", _masked_post_div_sdpa_script),
313+
("masked_post_mul", _masked_post_mul_sdpa_script),
307314
("custom_scale_post_mul", _custom_scale_post_mul_sdpa_script),
308315
("custom_scale_post_div", _custom_scale_post_div_sdpa_script),
309316
("custom_scale_pre_mul", _custom_scale_pre_mul_sdpa_script),
310317
("custom_scale_pre_div", _custom_scale_pre_div_sdpa_script),
311-
("custom_scale_post_mul_masked", _custom_scale_post_mul_sdpa_script),
312-
("custom_scale_post_div_masked", _custom_scale_post_div_sdpa_script),
313-
("custom_scale_pre_mul_masked", _custom_scale_pre_mul_sdpa_script),
314-
("custom_scale_pre_div_masked", _custom_scale_pre_div_sdpa_script),
318+
("masked_custom_scale_post_mul", _masked_custom_scale_post_mul_sdpa_script),
319+
("masked_custom_scale_post_div", _masked_custom_scale_post_div_sdpa_script),
320+
("masked_custom_scale_pre_mul", _masked_custom_scale_pre_mul_sdpa_script),
321+
("masked_custom_scale_pre_div", _masked_custom_scale_pre_div_sdpa_script),
315322
(
316323
"_custom_multi_scale_pre_mul_sdpa_script",
317324
_custom_multi_scale_pre_mul_sdpa_script,
318325
),
319326
]
320327
)
321328
def test_sdpa_fusion(self, name, script_func):
322-
test_case = SDPATestCase(script_func)
329+
test_case = SDPATestCase(script_func, with_mask="masked" in name)
323330
model = test_case.get_onnx_model()
324331
onnxscript.optimizer.optimize(model)
325332

326-
# inputs = test_case.get_ort_inputs()
327-
# original_outputs = ort_run("original", model, inputs)
333+
inputs = test_case.get_ort_inputs()
334+
original_outputs = ort_run("original", model, inputs)
328335

329336
count = fuse_sdpa(model, debug=True)
330337
self.assertGreater(count, 0)
@@ -347,8 +354,12 @@ def test_sdpa_fusion(self, name, script_func):
347354
# of scale_factor (is =default_scaling_factor)
348355
self.assertIsNone(sdpa_node.attributes.get("scale"))
349356

350-
# new_outputs = ort_run("optimized", model, inputs)
351-
# assert_allclose(new_outputs, original_outputs)
357+
replace_sdpa_by_mha(model, debug=True)
358+
359+
self.assertNotIn("SDPA", [n.op_type for n in model.graph])
360+
361+
new_outputs = ort_run("optimized", model, inputs)
362+
assert_allclose(new_outputs, original_outputs)
352363

353364
def test_invalid_sdpa_fusion_value_batch_dim(self):
354365
test_case = InvalidSDPATestCase(_masked_pre_mul_sdpa_script)
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from __future__ import annotations
4+
5+
from typing import Union
6+
7+
import onnxscript.ir as ir
8+
from onnxscript.rewriter import _fusion_utils, pattern
9+
10+
# A tensor dimension as used by the shape checks below: either a concrete
# integer or a symbolic dimension.
Dim = Union[int, ir.SymbolicDim]
11+
12+
13+
class SDPAImplementation(pattern.RewriteRuleClassBase):
    """Rewrite a fused ``SDPA`` node as a ``com.microsoft`` ``MultiHeadAttention``.

    The matched ``SDPA`` node (domain ``ai.onnxruntime.fusion``) takes 4D
    inputs, named in ``check`` as query ``[B, H, S, Dh]``, transposed key
    ``[B, H, Dh, Skv]`` and value ``[B, H, Skv, Dv]``.  The rewrite moves the
    head axis next to the feature axis, flattens both into one axis to get 3D
    tensors, calls ``MultiHeadAttention`` and converts the result back to the
    4D head-major layout.

    The rule only applies when the number of heads ``H`` is a concrete
    integer, because it is emitted as the ``num_heads`` attribute.
    """

    def pattern(self, op, query, key_transposed, value):
        """Match an ``SDPA`` node with at least query, transposed key and value."""
        return op.SDPA(
            query,
            key_transposed,
            value,
            _allow_other_inputs=True,  # Mask is optional
            _outputs=["sdpa_output"],
            _domain="ai.onnxruntime.fusion",
        )

    def check(self, context, query, key_transposed, value, sdpa_output):
        """Validate input shapes and record the (static) number of heads.

        Returns:
            ``True`` if the head dimension bound from the shapes is a concrete
            ``int``; ``False`` otherwise.
        """
        bindings: dict[str, Dim] = {}
        # NOTE(review): the results of check_shape are not inspected here;
        # this presumably relies on it signalling a mismatch by raising --
        # confirm against _fusion_utils.
        _fusion_utils.check_shape(bindings, query, ["B", "H", "S", "Dh"])
        _fusion_utils.check_shape(bindings, key_transposed, ["B", "H", "Dh", "Skv"])
        _fusion_utils.check_shape(bindings, value, ["B", "H", "Skv", "Dv"])

        self._num_heads = bindings["H"]
        if not isinstance(self._num_heads, int):
            # A symbolic head count cannot be used as the num_heads attribute.
            return False
        self._use_mask_broadcast = True  # TODO: optimize to avoid broadcast if not needed
        # The head count is known to be an int at this point, so the second
        # isinstance test the original performed here was redundant.
        return True

    def rewrite(self, op, query, key_transposed, value, sdpa_output):
        """Build the ``MultiHeadAttention``-based replacement for the SDPA node.

        Returns:
            The replacement output value, in the same 4D head-major layout as
            the original SDPA output.
        """
        sdpa_node = sdpa_output.producer()
        # Forward the SDPA node's "scale" attribute (None when it is absent).
        scale = sdpa_node.attributes.get("scale", None)

        # Reshape targets: 0 copies the corresponding input dimension and -1
        # infers the remaining one.
        to_3d_shape = op.Constant(value_ints=[0, 0, -1])
        to_4d_shape = op.Constant(value_ints=[0, 0, self._num_heads, -1])

        # [B, H, S, Dh] -> [B, S, H, Dh] -> [B, S, H*Dh]
        query_3d = op.Reshape(op.Transpose(query, perm=[0, 2, 1, 3]), to_3d_shape)
        # [B, H, Dh, Skv] -> [B, Skv, H, Dh] -> [B, Skv, H*Dh]; this also
        # undoes the transposition already applied to the key.
        key_3d = op.Reshape(op.Transpose(key_transposed, perm=[0, 3, 1, 2]), to_3d_shape)
        # [B, H, Skv, Dv] -> [B, Skv, H, Dv] -> [B, Skv, H*Dv]
        value_3d = op.Reshape(op.Transpose(value, perm=[0, 2, 1, 3]), to_3d_shape)

        inputs = [query_3d, key_3d, value_3d]
        if len(sdpa_node.inputs) > 3:
            mask = sdpa_node.inputs[3]

            if self._use_mask_broadcast:
                # Expand the mask against shape [1, 1, S, 1] so the result has
                # rank >= 4 with the query sequence length on axis 2.
                # Presumably MultiHeadAttention needs this explicit form --
                # TODO confirm against the ORT operator specification.
                one = op.Constant(value_ints=[1])
                query_length = op.Shape(query, start=2, end=3)
                shape_11S1 = op.Concat(one, one, query_length, one, axis=0)
                mask = op.Expand(mask, shape_11S1)

            # The mask becomes the sixth input; the fourth and fifth inputs
            # are left empty.
            inputs.extend([None, None, mask])

        output = op.MultiHeadAttention(
            *inputs,
            num_heads=self._num_heads,
            scale=scale,
            _domain="com.microsoft",
        )
        # [B, S, H*Dv] -> [B, S, H, Dv] -> [B, H, S, Dv]
        output_4d = op.Reshape(output, to_4d_shape)
        output = op.Transpose(output_4d, perm=[0, 2, 1, 3])
        return output
67+
68+
# Rule set containing the single SDPA -> MultiHeadAttention rewrite rule.
_rules = pattern.RewriteRuleSet([SDPAImplementation.rule()])
69+
70+
# Public entry point built from the rule set above; the accompanying test
# invokes it as replace_sdpa_by_mha(model, debug=True).
replace_sdpa_by_mha = _fusion_utils.apply_fusion_rules(_rules)

0 commit comments

Comments
 (0)