@@ -78,13 +78,11 @@ def pattern(
78
78
value_BSDkv ,
79
79
past_key ,
80
80
past_value ,
81
- input_ids ,
82
- past_seq_length ,
83
- total_seq_length ,
81
+ position_ids_q ,
82
+ position_ids_k ,
84
83
cos ,
85
84
sin ,
86
- some_kv_cache ,
87
- shape_B111 ,
85
+ mask ,
88
86
):
89
87
# Reshape query from (B, S, D) to (B, S, H, D/H)
90
88
query_BSHDh = op .Reshape (query_BSD , pattern .ANY_VALUE , _outputs = ["query_BSHDh" ])
@@ -101,10 +99,6 @@ def pattern(
101
99
# Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H)
102
100
value_BHkvSDh = op .Transpose (value_BSHkvDh , perm = [0 , 2 , 1 , 3 ])
103
101
104
- position_ids = op .Range (past_seq_length , total_seq_length , 1 )
105
- position_ids_q = op .Unsqueeze (position_ids , [0 ])
106
- position_ids_k = op .Unsqueeze (position_ids , [0 ])
107
-
108
102
query_BHSDh_rope = op .RotaryEmbedding (
109
103
query_BHSDh ,
110
104
position_ids_q ,
@@ -141,15 +135,13 @@ def pattern(
141
135
value_seq_BHkvGTDh , pattern .ANY_VALUE , _outputs = ["value_seq_BHTDh" ]
142
136
)
143
137
144
- mask = causal_mask_pattern (op , input_ids , some_kv_cache , shape_B111 )
145
-
146
- key_seq_BHDhT = op .Transpose (key_seq_BHTDh , perm = [0 , 1 , 3 , 2 ])
147
138
attention_BHSDh = op .SDPA (
148
139
query_BHSDh_rope ,
149
- key_seq_BHDhT ,
140
+ key_seq_BHTDh ,
150
141
value_seq_BHTDh ,
151
142
mask ,
152
- _domain = "ai.onnxruntime.fusion" ,
143
+ key_format = "BHSd" ,
144
+ _domain = "ai.onnxruntime._fusion" ,
153
145
)
154
146
155
147
# Transpose attention back to (B, S, H, D/H)
@@ -209,8 +201,8 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
209
201
# Rotary embedding attributes
210
202
query_rotary_attributes = query_BHSDh_rope .producer ().attributes
211
203
key_rotary_attributes = key_BHkvSDh_rope .producer ().attributes
212
- query_interleaved = query_rotary_attributes .get ("interleaved" , 0 )
213
- key_interleaved = key_rotary_attributes .get ("interleaved" , 0 )
204
+ query_interleaved = query_rotary_attributes .get_int ("interleaved" , 0 )
205
+ key_interleaved = key_rotary_attributes .get_int ("interleaved" , 0 )
214
206
if query_interleaved != key_interleaved :
215
207
return pattern .MatchResult ().fail (
216
208
"Rotary embedding interleaved attribute mismatch" ,
@@ -228,42 +220,104 @@ def rewrite(
228
220
value_BSDkv ,
229
221
past_key ,
230
222
past_value ,
231
- total_seq_length ,
223
+ position_ids_q ,
224
+ position_ids_k ,
232
225
cos ,
233
226
sin ,
227
+ mask ,
234
228
** _ ,
235
229
):
236
- total_seq_length_int32 = op .Cast (total_seq_length , to = ir .DataType .INT32 )
237
- one_0D = op .Constant (value_int = 1 )
238
- one_0D_int32 = op .Cast (one_0D , to = ir .DataType .INT32 )
239
- seqlens_k_0D = op .Sub (total_seq_length_int32 , one_0D_int32 )
240
- zero_1D = op .Constant (value_int = 0 , dtype = ir .DataType .INT64 , shape = [1 ])
241
- seqlens_k = op .Unsqueeze (seqlens_k_0D , zero_1D )
242
-
243
- return op .GroupQueryAttention (
230
+ return op .GQA (
231
+ mask ,
232
+ position_ids_k ,
233
+ position_ids_q ,
244
234
query_BSD ,
245
235
key_BSDkv ,
246
236
value_BSDkv ,
247
237
past_key ,
248
238
past_value ,
249
- seqlens_k ,
250
- total_seq_length_int32 ,
239
+ None , # seqlens_k,
240
+ None , # total_seq_length_int32,
251
241
cos ,
252
242
sin ,
253
- # mask, # TODO: this is not a valid input for GQA
254
243
num_heads = self .num_heads ,
255
244
kv_num_heads = self .kv_num_heads ,
256
245
do_rotary = 1 ,
257
246
rotary_interleaved = self ._interleaved ,
258
247
# skipped optional attributes: local_window_size, scale, smooth_softmax, softcap
259
- _domain = "com.microsoft " ,
248
+ _domain = "ai.onnxruntime._fusion " ,
260
249
_outputs = 3 ,
261
250
)
262
251
263
252
264
- _rule1 = GroupQueryAttention .rule ()
253
class GQACausalMask(pattern.RewriteRuleClassBase):
    """Turn an intermediate fused ``GQA`` node into ``com.microsoft`` GroupQueryAttention.

    The rule matches a ``GQA`` op in the ``ai.onnxruntime._fusion`` domain whose
    first three inputs are:

    * a mask produced by ``causal_mask_pattern``;
    * key position ids and query position ids, both computed as
      ``Unsqueeze(Range(past_seq_length, total_seq_length, 1), [0])``.

    On a match it emits a ``GroupQueryAttention`` op in the ``com.microsoft``
    domain that drops the mask and position-id inputs and instead receives
    ``seqlens_k`` (``total_seq_length - 1``) and ``total_seq_length`` as int32.
    """

    def __init__(self):
        # NOTE(review): remove_nodes=False presumably leaves the matched nodes
        # in the graph for later cleanup -- confirm against the rewriter docs.
        super().__init__("GQACausalMask", remove_nodes=False)

    def pattern(
        self,
        op,
        mask,
        input_ids,
        some_kv_cache,
        shape_B111,
        past_seq_length,
        total_seq_length,
    ):
        """Describe the sub-graph to match.

        The ``mask`` parameter is immediately rebound to the causal-mask
        sub-pattern, so the GQA node is matched on that sub-pattern's output.
        Inputs after the first three are not constrained
        (``_allow_other_inputs=True``).
        """
        mask = causal_mask_pattern(op, input_ids, some_kv_cache, shape_B111)

        # Query and key share one position-id range; each gets its own Unsqueeze.
        position_ids = op.Range(past_seq_length, total_seq_length, 1)
        position_ids_q = op.Unsqueeze(position_ids, [0])
        position_ids_k = op.Unsqueeze(position_ids, [0])

        return op.GQA(
            mask,
            position_ids_k,
            position_ids_q,
            _allow_other_inputs=True,
            _domain="ai.onnxruntime._fusion",
            _outputs=["attn_output", "key_seq", "value_seq"],
        )

    def rewrite(
        self,
        op,
        total_seq_length,
        attn_output,
        **_,
    ):
        """Build the replacement ``GroupQueryAttention`` node.

        Args:
            op: Builder used to create the replacement nodes.
            total_seq_length: Value matched as the upper bound of the
                position-id ``Range``.
            attn_output: First output of the matched ``GQA`` node; used to
                reach the node itself and reuse its inputs and attributes.

        Returns:
            The three outputs of the new ``GroupQueryAttention`` op.

        Raises:
            AssertionError: If the matched ``GQA`` node does not have exactly
                12 inputs.
        """
        gqa_node = attn_output.producer()
        node_inputs = gqa_node.inputs
        input_count = len(node_inputs)
        assert input_count == 12, f"Expected 12 inputs for GQA node, got {input_count}"

        # Inputs 0-2 are mask / position_ids_k / position_ids_q (see pattern).
        # Inputs 8-9 are skipped: they are the seqlens_k / total_seq_length
        # slots, which are rebuilt below.
        query, key, value, past_key, past_value = node_inputs[3:8]
        cos, sin = node_inputs[10:12]

        # Construct total_seq_length_int32 and seqlens_k (= total_seq_length - 1,
        # reshaped from a scalar to a 1-D tensor).
        total_seq_length_int32 = op.Cast(total_seq_length, to=ir.DataType.INT32)
        one_0D = op.Constant(value_int=1)
        one_0D_int32 = op.Cast(one_0D, to=ir.DataType.INT32)
        seqlens_k_0D = op.Sub(total_seq_length_int32, one_0D_int32)
        zero_1D = op.Constant(value_int=0, dtype=ir.DataType.INT64, shape=[1])
        seqlens_k = op.Unsqueeze(seqlens_k_0D, zero_1D)

        # All attributes of the intermediate node carry over unchanged.
        return op.GroupQueryAttention(
            query,
            key,
            value,
            past_key,
            past_value,
            seqlens_k,
            total_seq_length_int32,
            cos,
            sin,
            **gqa_node.attributes,
            _domain="com.microsoft",
            _outputs=3,
        )
+ )
316
+
265
317
266
- gqa_rules = pattern .RewriteRuleSet ([_rule1 ])
318
# Two-stage fusion: the basic rule rewrites the attention sub-graph into an
# intermediate GQA op in the "ai.onnxruntime._fusion" domain; the causal-mask
# rule matches that intermediate op and emits the final com.microsoft
# GroupQueryAttention op.
_basic_gqa_rule = GroupQueryAttention.rule()
_gqa_causal_mask_rule = GQACausalMask.rule()

gqa_rules = pattern.RewriteRuleSet(
    [
        _basic_gqa_rule,
        _gqa_causal_mask_rule,
    ]
)


# Entry point that applies the GQA rule set through the shared fusion helper.
fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules)
0 commit comments