Skip to content

Commit a1db753

Browse files
authored
Add NaN handling in softmax pattern in SDPA fusion (#2593)
Add NaN handling in softmax pattern in SDPA fusion

Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent 7227655 commit a1db753

File tree

2 files changed

+71
-17
lines changed

2 files changed

+71
-17
lines changed

onnxscript/rewriter/ort_fusions/sdpa.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -88,6 +88,9 @@ def pattern(
8888
)
8989

9090
attn_weight = op.Softmax(attn_score, axis=-1)
91+
is_nan = op.IsNaN(attn_weight)
92+
adj_attn_weight = op.Where(is_nan, 0.0, attn_weight)
93+
attn_weight = pattern.OrValue([adj_attn_weight, attn_weight])
9194
attn_output = op.MatMul(attn_weight, value)
9295
return attn_output
9396

onnxscript/rewriter/ort_fusions/sdpa_test.py

Lines changed: 68 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,10 @@ def _unmasked_pre_div_sdpa_script(query, key, value):
4444
scaled_key = op.Div(key_transposed, divisor)
4545
attn_score = op.MatMul(scaled_query, scaled_key)
4646
attn_weight = op.Softmax(attn_score, axis=-1)
47-
attn_output = op.MatMul(attn_weight, value)
47+
is_nan = op.IsNaN(attn_weight)
48+
zero = op.Constant(value_float=0.0)
49+
adj_attn_weight = op.Where(is_nan, zero, attn_weight)
50+
attn_output = op.MatMul(adj_attn_weight, value)
4851
return attn_output
4952

5053

@@ -56,7 +59,10 @@ def _unmasked_pre_mul_sdpa_script(query, key, value):
5659
scaled_key = op.Mul(key_transposed, multiplier)
5760
attn_score = op.MatMul(scaled_query, scaled_key)
5861
attn_weight = op.Softmax(attn_score, axis=-1)
59-
attn_output = op.MatMul(attn_weight, value)
62+
is_nan = op.IsNaN(attn_weight)
63+
zero = op.Constant(value_float=0.0)
64+
adj_attn_weight = op.Where(is_nan, zero, attn_weight)
65+
attn_output = op.MatMul(adj_attn_weight, value)
6066
return attn_output
6167

6268

@@ -67,7 +73,10 @@ def _unmasked_post_div_sdpa_script(query, key, value):
6773
attn_score = op.MatMul(query, key_transposed)
6874
scaled_attn_score = op.Div(attn_score, divisor)
6975
attn_weight = op.Softmax(scaled_attn_score, axis=-1)
70-
attn_output = op.MatMul(attn_weight, value)
76+
is_nan = op.IsNaN(attn_weight)
77+
zero = op.Constant(value_float=0.0)
78+
adj_attn_weight = op.Where(is_nan, zero, attn_weight)
79+
attn_output = op.MatMul(adj_attn_weight, value)
7180
return attn_output
7281

7382

@@ -78,7 +87,10 @@ def _unmasked_post_mul_sdpa_script(query, key, value):
7887
attn_score = op.MatMul(query, key_transposed)
7988
scaled_attn_score = op.Mul(attn_score, multiplier)
8089
attn_weight = op.Softmax(scaled_attn_score, axis=-1)
81-
attn_output = op.MatMul(attn_weight, value)
90+
is_nan = op.IsNaN(attn_weight)
91+
zero = op.Constant(value_float=0.0)
92+
adj_attn_weight = op.Where(is_nan, zero, attn_weight)
93+
attn_output = op.MatMul(adj_attn_weight, value)
8294
return attn_output
8395

8496

@@ -90,7 +102,10 @@ def _custom_scale_pre_div_sdpa_script(query, key, value):
90102
scaled_key = op.Div(key_transposed, divisor)
91103
attn_score = op.MatMul(scaled_query, scaled_key)
92104
attn_weight = op.Softmax(attn_score, axis=-1)
93-
attn_output = op.MatMul(attn_weight, value)
105+
is_nan = op.IsNaN(attn_weight)
106+
zero = op.Constant(value_float=0.0)
107+
adj_attn_weight = op.Where(is_nan, zero, attn_weight)
108+
attn_output = op.MatMul(adj_attn_weight, value)
94109
return attn_output
95110

96111

@@ -102,7 +117,10 @@ def _custom_scale_pre_mul_sdpa_script(query, key, value):
102117
scaled_key = op.Mul(key_transposed, multiplier)
103118
attn_score = op.MatMul(scaled_query, scaled_key)
104119
attn_weight = op.Softmax(attn_score, axis=-1)
105-
attn_output = op.MatMul(attn_weight, value)
120+
is_nan = op.IsNaN(attn_weight)
121+
zero = op.Constant(value_float=0.0)
122+
adj_attn_weight = op.Where(is_nan, zero, attn_weight)
123+
attn_output = op.MatMul(adj_attn_weight, value)
106124
return attn_output
107125

108126

@@ -115,7 +133,10 @@ def _custom_multi_scale_pre_mul_sdpa_script(query, key, value):
115133
scaled_key = op.Mul(key_transposed, multiplier_k)
116134
attn_score = op.MatMul(scaled_query, scaled_key)
117135
attn_weight = op.Softmax(attn_score, axis=-1)
118-
attn_output = op.MatMul(attn_weight, value)
136+
is_nan = op.IsNaN(attn_weight)
137+
zero = op.Constant(value_float=0.0)
138+
adj_attn_weight = op.Where(is_nan, zero, attn_weight)
139+
attn_output = op.MatMul(adj_attn_weight, value)
119140
return attn_output
120141

121142

@@ -126,7 +147,10 @@ def _custom_scale_post_div_sdpa_script(query, key, value):
126147
attn_score = op.MatMul(query, key_transposed)
127148
scaled_attn_score = op.Div(attn_score, divisor)
128149
attn_weight = op.Softmax(scaled_attn_score, axis=-1)
129-
attn_output = op.MatMul(attn_weight, value)
150+
is_nan = op.IsNaN(attn_weight)
151+
zero = op.Constant(value_float=0.0)
152+
adj_attn_weight = op.Where(is_nan, zero, attn_weight)
153+
attn_output = op.MatMul(adj_attn_weight, value)
130154
return attn_output
131155

132156

@@ -137,7 +161,10 @@ def _custom_scale_post_mul_sdpa_script(query, key, value):
137161
attn_score = op.MatMul(query, key_transposed)
138162
scaled_attn_score = op.Mul(attn_score, multiplier)
139163
attn_weight = op.Softmax(scaled_attn_score, axis=-1)
140-
attn_output = op.MatMul(attn_weight, value)
164+
is_nan = op.IsNaN(attn_weight)
165+
zero = op.Constant(value_float=0.0)
166+
adj_attn_weight = op.Where(is_nan, zero, attn_weight)
167+
attn_output = op.MatMul(adj_attn_weight, value)
141168
return attn_output
142169

143170

@@ -150,7 +177,10 @@ def _masked_pre_div_sdpa_script(query, key, value, mask):
150177
attn_score = op.MatMul(scaled_query, scaled_key)
151178
masked_attn_score = op.Add(attn_score, mask)
152179
attn_weight = op.Softmax(masked_attn_score, axis=-1)
153-
attn_output = op.MatMul(attn_weight, value)
180+
is_nan = op.IsNaN(attn_weight)
181+
zero = op.Constant(value_float=0.0)
182+
adj_attn_weight = op.Where(is_nan, zero, attn_weight)
183+
attn_output = op.MatMul(adj_attn_weight, value)
154184
return attn_output
155185

156186

@@ -163,7 +193,10 @@ def _masked_pre_mul_sdpa_script(query, key, value, mask):
163193
attn_score = op.MatMul(scaled_query, scaled_key)
164194
masked_attn_score = op.Add(attn_score, mask)
165195
attn_weight = op.Softmax(masked_attn_score, axis=-1)
166-
attn_output = op.MatMul(attn_weight, value)
196+
is_nan = op.IsNaN(attn_weight)
197+
zero = op.Constant(value_float=0.0)
198+
adj_attn_weight = op.Where(is_nan, zero, attn_weight)
199+
attn_output = op.MatMul(adj_attn_weight, value)
167200
return attn_output
168201

169202

@@ -175,7 +208,10 @@ def _masked_post_div_sdpa_script(query, key, value, mask):
175208
scaled_attn_score = op.Div(attn_score, divisor)
176209
masked_attn_score = op.Add(scaled_attn_score, mask)
177210
attn_weight = op.Softmax(masked_attn_score, axis=-1)
178-
attn_output = op.MatMul(attn_weight, value)
211+
is_nan = op.IsNaN(attn_weight)
212+
zero = op.Constant(value_float=0.0)
213+
adj_attn_weight = op.Where(is_nan, zero, attn_weight)
214+
attn_output = op.MatMul(adj_attn_weight, value)
179215
return attn_output
180216

181217

@@ -187,7 +223,10 @@ def _masked_post_mul_sdpa_script(query, key, value, mask):
187223
scaled_attn_score = op.Mul(attn_score, multiplier)
188224
masked_attn_score = op.Add(scaled_attn_score, mask)
189225
attn_weight = op.Softmax(masked_attn_score, axis=-1)
190-
attn_output = op.MatMul(attn_weight, value)
226+
is_nan = op.IsNaN(attn_weight)
227+
zero = op.Constant(value_float=0.0)
228+
adj_attn_weight = op.Where(is_nan, zero, attn_weight)
229+
attn_output = op.MatMul(adj_attn_weight, value)
191230
return attn_output
192231

193232

@@ -200,7 +239,10 @@ def _masked_custom_scale_pre_div_sdpa_script(query, key, value, mask):
200239
attn_score = op.MatMul(scaled_query, scaled_key)
201240
masked_attn_score = op.Add(attn_score, mask)
202241
attn_weight = op.Softmax(masked_attn_score, axis=-1)
203-
attn_output = op.MatMul(attn_weight, value)
242+
is_nan = op.IsNaN(attn_weight)
243+
zero = op.Constant(value_float=0.0)
244+
adj_attn_weight = op.Where(is_nan, zero, attn_weight)
245+
attn_output = op.MatMul(adj_attn_weight, value)
204246
return attn_output
205247

206248

@@ -213,7 +255,10 @@ def _masked_custom_scale_pre_mul_sdpa_script(query, key, value, mask):
213255
attn_score = op.MatMul(scaled_query, scaled_key)
214256
masked_attn_score = op.Add(attn_score, mask)
215257
attn_weight = op.Softmax(masked_attn_score, axis=-1)
216-
attn_output = op.MatMul(attn_weight, value)
258+
is_nan = op.IsNaN(attn_weight)
259+
zero = op.Constant(value_float=0.0)
260+
adj_attn_weight = op.Where(is_nan, zero, attn_weight)
261+
attn_output = op.MatMul(adj_attn_weight, value)
217262
return attn_output
218263

219264

@@ -225,7 +270,10 @@ def _masked_custom_scale_post_div_sdpa_script(query, key, value, mask):
225270
scaled_attn_score = op.Div(attn_score, divisor)
226271
masked_attn_score = op.Add(scaled_attn_score, mask)
227272
attn_weight = op.Softmax(masked_attn_score, axis=-1)
228-
attn_output = op.MatMul(attn_weight, value)
273+
is_nan = op.IsNaN(attn_weight)
274+
zero = op.Constant(value_float=0.0)
275+
adj_attn_weight = op.Where(is_nan, zero, attn_weight)
276+
attn_output = op.MatMul(adj_attn_weight, value)
229277
return attn_output
230278

231279

@@ -237,7 +285,10 @@ def _masked_custom_scale_post_mul_sdpa_script(query, key, value, mask):
237285
scaled_attn_score = op.Mul(attn_score, multiplier)
238286
masked_attn_score = op.Add(scaled_attn_score, mask)
239287
attn_weight = op.Softmax(masked_attn_score, axis=-1)
240-
attn_output = op.MatMul(attn_weight, value)
288+
is_nan = op.IsNaN(attn_weight)
289+
zero = op.Constant(value_float=0.0)
290+
adj_attn_weight = op.Where(is_nan, zero, attn_weight)
291+
attn_output = op.MatMul(adj_attn_weight, value)
241292
return attn_output
242293

243294

0 commit comments

Comments
 (0)