16
16
from onnxscript import script
17
17
from onnxscript .onnx_opset import opset18 as op
18
18
from onnxscript .onnx_types import FLOAT
19
+ from onnxscript .rewriter .ort_fusions ._test_utils import assert_allclose , ort_run
19
20
from onnxscript .rewriter .ort_fusions .sdpa import fuse_sdpa
21
+ from onnxscript .rewriter .ort_fusions .sdpa_via_mha import replace_sdpa_by_mha
20
22
21
23
B = 2 # batch size
22
24
N = 4 # number of heads
@@ -190,7 +192,7 @@ def _masked_post_mul_sdpa_script(query, key, value, mask):
190
192
191
193
192
194
@script ()
193
- def _custom_scale_pre_div_sdpa_script (query , key , value , mask ):
195
+ def _masked_custom_scale_pre_div_sdpa_script (query , key , value , mask ):
194
196
key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
195
197
divisor = op .Constant (value_float = SQRT_CUSTOM_DIV_SCALE_FACTOR )
196
198
scaled_query = op .Div (query , divisor )
@@ -203,7 +205,7 @@ def _custom_scale_pre_div_sdpa_script(query, key, value, mask):
203
205
204
206
205
207
@script ()
206
- def _custom_scale_pre_mul_sdpa_script (query , key , value , mask ):
208
+ def _masked_custom_scale_pre_mul_sdpa_script (query , key , value , mask ):
207
209
key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
208
210
multiplier = op .Constant (value_float = SQRT_CUSTOM_MUL_SCALE_FACTOR )
209
211
scaled_query = op .Mul (query , multiplier )
@@ -216,7 +218,7 @@ def _custom_scale_pre_mul_sdpa_script(query, key, value, mask):
216
218
217
219
218
220
@script ()
219
- def _custom_scale_post_div_sdpa_script (query , key , value , mask ):
221
+ def _masked_custom_scale_post_div_sdpa_script (query , key , value , mask ):
220
222
key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
221
223
divisor = op .Constant (value_float = CUSTOM_DIV_SCALE_FACTOR )
222
224
attn_score = op .MatMul (query , key_transposed )
@@ -228,7 +230,7 @@ def _custom_scale_post_div_sdpa_script(query, key, value, mask):
228
230
229
231
230
232
@script ()
231
- def _custom_scale_post_mul_sdpa_script (query , key , value , mask ):
233
+ def _masked_custom_scale_post_mul_sdpa_script (query , key , value , mask ):
232
234
key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
233
235
multiplier = op .Constant (value_float = CUSTOM_MUL_SCALE_FACTOR )
234
236
attn_score = op .MatMul (query , key_transposed )
@@ -240,15 +242,19 @@ def _custom_scale_post_mul_sdpa_script(query, key, value, mask):
240
242
241
243
242
244
class SDPATestCase :
243
- def __init__ (self , script_func ):
245
+ def __init__ (self , script_func , * , with_mask ):
244
246
self .script_func = script_func
247
+ self .with_mask = with_mask
245
248
246
249
def get_onnx_model (self ):
247
250
if not hasattr (self , "_onnx_model" ):
248
251
qkv_type = FLOAT [B , N , S , H ]
249
252
mask_type = FLOAT [B , N , S , S ]
253
+ input_types = [qkv_type , qkv_type , qkv_type ]
254
+ if self .with_mask :
255
+ input_types .append (mask_type )
250
256
model_proto = self .script_func .to_model_proto (
251
- input_types = [ qkv_type , qkv_type , qkv_type , mask_type ] , output_types = [qkv_type ]
257
+ input_types = input_types , output_types = [qkv_type ]
252
258
)
253
259
self ._onnx_model = ir .serde .deserialize_model (model_proto )
254
260
return self ._onnx_model
@@ -259,8 +265,9 @@ def get_ort_inputs(self):
259
265
"query" : numpy .random .rand (B , N , S , H ).astype (numpy .float32 ),
260
266
"key" : numpy .random .rand (B , N , S , H ).astype (numpy .float32 ),
261
267
"value" : numpy .random .rand (B , N , S , H ).astype (numpy .float32 ),
262
- "mask" : numpy .random .rand (B , N , S , S ).astype (numpy .float32 ),
263
268
}
269
+ if self .with_mask :
270
+ inputs ["mask" ] = numpy .random .rand (B , N , S , S ).astype (numpy .float32 )
264
271
self ._ort_inputs = inputs
265
272
return self ._ort_inputs
266
273
@@ -296,35 +303,35 @@ def get_ort_inputs(self):
296
303
class TestSDPAFusion (unittest .TestCase ):
297
304
@parameterized .parameterized .expand (
298
305
[
299
- ("unmasked_pre_div " , _unmasked_pre_div_sdpa_script ),
300
- ("unmasked_pre_mul " , _unmasked_pre_mul_sdpa_script ),
301
- ("unmasked_post_div " , _unmasked_post_div_sdpa_script ),
302
- ("unmasked_post_mul " , _unmasked_post_mul_sdpa_script ),
303
- ("pre_div " , _masked_pre_div_sdpa_script ),
304
- ("pre_mul " , _masked_pre_mul_sdpa_script ),
305
- ("post_div " , _masked_post_div_sdpa_script ),
306
- ("post_mul " , _masked_post_mul_sdpa_script ),
306
+ ("pre_div " , _unmasked_pre_div_sdpa_script ),
307
+ ("pre_mul " , _unmasked_pre_mul_sdpa_script ),
308
+ ("post_div " , _unmasked_post_div_sdpa_script ),
309
+ ("post_mul " , _unmasked_post_mul_sdpa_script ),
310
+ ("masked_pre_div " , _masked_pre_div_sdpa_script ),
311
+ ("masked_pre_mul " , _masked_pre_mul_sdpa_script ),
312
+ ("masked_post_div " , _masked_post_div_sdpa_script ),
313
+ ("masked_post_mul " , _masked_post_mul_sdpa_script ),
307
314
("custom_scale_post_mul" , _custom_scale_post_mul_sdpa_script ),
308
315
("custom_scale_post_div" , _custom_scale_post_div_sdpa_script ),
309
316
("custom_scale_pre_mul" , _custom_scale_pre_mul_sdpa_script ),
310
317
("custom_scale_pre_div" , _custom_scale_pre_div_sdpa_script ),
311
- ("custom_scale_post_mul_masked " , _custom_scale_post_mul_sdpa_script ),
312
- ("custom_scale_post_div_masked " , _custom_scale_post_div_sdpa_script ),
313
- ("custom_scale_pre_mul_masked " , _custom_scale_pre_mul_sdpa_script ),
314
- ("custom_scale_pre_div_masked " , _custom_scale_pre_div_sdpa_script ),
318
+ ("masked_custom_scale_post_mul " , _masked_custom_scale_post_mul_sdpa_script ),
319
+ ("masked_custom_scale_post_div " , _masked_custom_scale_post_div_sdpa_script ),
320
+ ("masked_custom_scale_pre_mul " , _masked_custom_scale_pre_mul_sdpa_script ),
321
+ ("masked_custom_scale_pre_div " , _masked_custom_scale_pre_div_sdpa_script ),
315
322
(
316
323
"_custom_multi_scale_pre_mul_sdpa_script" ,
317
324
_custom_multi_scale_pre_mul_sdpa_script ,
318
325
),
319
326
]
320
327
)
321
328
def test_sdpa_fusion (self , name , script_func ):
322
- test_case = SDPATestCase (script_func )
329
+ test_case = SDPATestCase (script_func , with_mask = "masked" in name )
323
330
model = test_case .get_onnx_model ()
324
331
onnxscript .optimizer .optimize (model )
325
332
326
- # inputs = test_case.get_ort_inputs()
327
- # original_outputs = ort_run("original", model, inputs)
333
+ inputs = test_case .get_ort_inputs ()
334
+ original_outputs = ort_run ("original" , model , inputs )
328
335
329
336
count = fuse_sdpa (model , debug = True )
330
337
self .assertGreater (count , 0 )
@@ -347,8 +354,12 @@ def test_sdpa_fusion(self, name, script_func):
347
354
# of scale_factor (is =default_scaling_factor)
348
355
self .assertIsNone (sdpa_node .attributes .get ("scale" ))
349
356
350
- # new_outputs = ort_run("optimized", model, inputs)
351
- # assert_allclose(new_outputs, original_outputs)
357
+ replace_sdpa_by_mha (model , debug = True )
358
+
359
+ self .assertNotIn ("SDPA" , [n .op_type for n in model .graph ])
360
+
361
+ new_outputs = ort_run ("optimized" , model , inputs )
362
+ assert_allclose (new_outputs , original_outputs )
352
363
353
364
def test_invalid_sdpa_fusion_value_batch_dim (self ):
354
365
test_case = InvalidSDPATestCase (_masked_pre_mul_sdpa_script )
0 commit comments