@@ -44,7 +44,10 @@ def _unmasked_pre_div_sdpa_script(query, key, value):
44
44
scaled_key = op .Div (key_transposed , divisor )
45
45
attn_score = op .MatMul (scaled_query , scaled_key )
46
46
attn_weight = op .Softmax (attn_score , axis = - 1 )
47
- attn_output = op .MatMul (attn_weight , value )
47
+ is_nan = op .IsNaN (attn_weight )
48
+ zero = op .Constant (value_float = 0.0 )
49
+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
50
+ attn_output = op .MatMul (adj_attn_weight , value )
48
51
return attn_output
49
52
50
53
@@ -56,7 +59,10 @@ def _unmasked_pre_mul_sdpa_script(query, key, value):
56
59
scaled_key = op .Mul (key_transposed , multiplier )
57
60
attn_score = op .MatMul (scaled_query , scaled_key )
58
61
attn_weight = op .Softmax (attn_score , axis = - 1 )
59
- attn_output = op .MatMul (attn_weight , value )
62
+ is_nan = op .IsNaN (attn_weight )
63
+ zero = op .Constant (value_float = 0.0 )
64
+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
65
+ attn_output = op .MatMul (adj_attn_weight , value )
60
66
return attn_output
61
67
62
68
@@ -67,7 +73,10 @@ def _unmasked_post_div_sdpa_script(query, key, value):
67
73
attn_score = op .MatMul (query , key_transposed )
68
74
scaled_attn_score = op .Div (attn_score , divisor )
69
75
attn_weight = op .Softmax (scaled_attn_score , axis = - 1 )
70
- attn_output = op .MatMul (attn_weight , value )
76
+ is_nan = op .IsNaN (attn_weight )
77
+ zero = op .Constant (value_float = 0.0 )
78
+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
79
+ attn_output = op .MatMul (adj_attn_weight , value )
71
80
return attn_output
72
81
73
82
@@ -78,7 +87,10 @@ def _unmasked_post_mul_sdpa_script(query, key, value):
78
87
attn_score = op .MatMul (query , key_transposed )
79
88
scaled_attn_score = op .Mul (attn_score , multiplier )
80
89
attn_weight = op .Softmax (scaled_attn_score , axis = - 1 )
81
- attn_output = op .MatMul (attn_weight , value )
90
+ is_nan = op .IsNaN (attn_weight )
91
+ zero = op .Constant (value_float = 0.0 )
92
+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
93
+ attn_output = op .MatMul (adj_attn_weight , value )
82
94
return attn_output
83
95
84
96
@@ -90,7 +102,10 @@ def _custom_scale_pre_div_sdpa_script(query, key, value):
90
102
scaled_key = op .Div (key_transposed , divisor )
91
103
attn_score = op .MatMul (scaled_query , scaled_key )
92
104
attn_weight = op .Softmax (attn_score , axis = - 1 )
93
- attn_output = op .MatMul (attn_weight , value )
105
+ is_nan = op .IsNaN (attn_weight )
106
+ zero = op .Constant (value_float = 0.0 )
107
+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
108
+ attn_output = op .MatMul (adj_attn_weight , value )
94
109
return attn_output
95
110
96
111
@@ -102,7 +117,10 @@ def _custom_scale_pre_mul_sdpa_script(query, key, value):
102
117
scaled_key = op .Mul (key_transposed , multiplier )
103
118
attn_score = op .MatMul (scaled_query , scaled_key )
104
119
attn_weight = op .Softmax (attn_score , axis = - 1 )
105
- attn_output = op .MatMul (attn_weight , value )
120
+ is_nan = op .IsNaN (attn_weight )
121
+ zero = op .Constant (value_float = 0.0 )
122
+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
123
+ attn_output = op .MatMul (adj_attn_weight , value )
106
124
return attn_output
107
125
108
126
@@ -115,7 +133,10 @@ def _custom_multi_scale_pre_mul_sdpa_script(query, key, value):
115
133
scaled_key = op .Mul (key_transposed , multiplier_k )
116
134
attn_score = op .MatMul (scaled_query , scaled_key )
117
135
attn_weight = op .Softmax (attn_score , axis = - 1 )
118
- attn_output = op .MatMul (attn_weight , value )
136
+ is_nan = op .IsNaN (attn_weight )
137
+ zero = op .Constant (value_float = 0.0 )
138
+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
139
+ attn_output = op .MatMul (adj_attn_weight , value )
119
140
return attn_output
120
141
121
142
@@ -126,7 +147,10 @@ def _custom_scale_post_div_sdpa_script(query, key, value):
126
147
attn_score = op .MatMul (query , key_transposed )
127
148
scaled_attn_score = op .Div (attn_score , divisor )
128
149
attn_weight = op .Softmax (scaled_attn_score , axis = - 1 )
129
- attn_output = op .MatMul (attn_weight , value )
150
+ is_nan = op .IsNaN (attn_weight )
151
+ zero = op .Constant (value_float = 0.0 )
152
+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
153
+ attn_output = op .MatMul (adj_attn_weight , value )
130
154
return attn_output
131
155
132
156
@@ -137,7 +161,10 @@ def _custom_scale_post_mul_sdpa_script(query, key, value):
137
161
attn_score = op .MatMul (query , key_transposed )
138
162
scaled_attn_score = op .Mul (attn_score , multiplier )
139
163
attn_weight = op .Softmax (scaled_attn_score , axis = - 1 )
140
- attn_output = op .MatMul (attn_weight , value )
164
+ is_nan = op .IsNaN (attn_weight )
165
+ zero = op .Constant (value_float = 0.0 )
166
+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
167
+ attn_output = op .MatMul (adj_attn_weight , value )
141
168
return attn_output
142
169
143
170
@@ -150,7 +177,10 @@ def _masked_pre_div_sdpa_script(query, key, value, mask):
150
177
attn_score = op .MatMul (scaled_query , scaled_key )
151
178
masked_attn_score = op .Add (attn_score , mask )
152
179
attn_weight = op .Softmax (masked_attn_score , axis = - 1 )
153
- attn_output = op .MatMul (attn_weight , value )
180
+ is_nan = op .IsNaN (attn_weight )
181
+ zero = op .Constant (value_float = 0.0 )
182
+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
183
+ attn_output = op .MatMul (adj_attn_weight , value )
154
184
return attn_output
155
185
156
186
@@ -163,7 +193,10 @@ def _masked_pre_mul_sdpa_script(query, key, value, mask):
163
193
attn_score = op .MatMul (scaled_query , scaled_key )
164
194
masked_attn_score = op .Add (attn_score , mask )
165
195
attn_weight = op .Softmax (masked_attn_score , axis = - 1 )
166
- attn_output = op .MatMul (attn_weight , value )
196
+ is_nan = op .IsNaN (attn_weight )
197
+ zero = op .Constant (value_float = 0.0 )
198
+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
199
+ attn_output = op .MatMul (adj_attn_weight , value )
167
200
return attn_output
168
201
169
202
@@ -175,7 +208,10 @@ def _masked_post_div_sdpa_script(query, key, value, mask):
175
208
scaled_attn_score = op .Div (attn_score , divisor )
176
209
masked_attn_score = op .Add (scaled_attn_score , mask )
177
210
attn_weight = op .Softmax (masked_attn_score , axis = - 1 )
178
- attn_output = op .MatMul (attn_weight , value )
211
+ is_nan = op .IsNaN (attn_weight )
212
+ zero = op .Constant (value_float = 0.0 )
213
+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
214
+ attn_output = op .MatMul (adj_attn_weight , value )
179
215
return attn_output
180
216
181
217
@@ -187,7 +223,10 @@ def _masked_post_mul_sdpa_script(query, key, value, mask):
187
223
scaled_attn_score = op .Mul (attn_score , multiplier )
188
224
masked_attn_score = op .Add (scaled_attn_score , mask )
189
225
attn_weight = op .Softmax (masked_attn_score , axis = - 1 )
190
- attn_output = op .MatMul (attn_weight , value )
226
+ is_nan = op .IsNaN (attn_weight )
227
+ zero = op .Constant (value_float = 0.0 )
228
+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
229
+ attn_output = op .MatMul (adj_attn_weight , value )
191
230
return attn_output
192
231
193
232
@@ -200,7 +239,10 @@ def _masked_custom_scale_pre_div_sdpa_script(query, key, value, mask):
200
239
attn_score = op .MatMul (scaled_query , scaled_key )
201
240
masked_attn_score = op .Add (attn_score , mask )
202
241
attn_weight = op .Softmax (masked_attn_score , axis = - 1 )
203
- attn_output = op .MatMul (attn_weight , value )
242
+ is_nan = op .IsNaN (attn_weight )
243
+ zero = op .Constant (value_float = 0.0 )
244
+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
245
+ attn_output = op .MatMul (adj_attn_weight , value )
204
246
return attn_output
205
247
206
248
@@ -213,7 +255,10 @@ def _masked_custom_scale_pre_mul_sdpa_script(query, key, value, mask):
213
255
attn_score = op .MatMul (scaled_query , scaled_key )
214
256
masked_attn_score = op .Add (attn_score , mask )
215
257
attn_weight = op .Softmax (masked_attn_score , axis = - 1 )
216
- attn_output = op .MatMul (attn_weight , value )
258
+ is_nan = op .IsNaN (attn_weight )
259
+ zero = op .Constant (value_float = 0.0 )
260
+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
261
+ attn_output = op .MatMul (adj_attn_weight , value )
217
262
return attn_output
218
263
219
264
@@ -225,7 +270,10 @@ def _masked_custom_scale_post_div_sdpa_script(query, key, value, mask):
225
270
scaled_attn_score = op .Div (attn_score , divisor )
226
271
masked_attn_score = op .Add (scaled_attn_score , mask )
227
272
attn_weight = op .Softmax (masked_attn_score , axis = - 1 )
228
- attn_output = op .MatMul (attn_weight , value )
273
+ is_nan = op .IsNaN (attn_weight )
274
+ zero = op .Constant (value_float = 0.0 )
275
+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
276
+ attn_output = op .MatMul (adj_attn_weight , value )
229
277
return attn_output
230
278
231
279
@@ -237,7 +285,10 @@ def _masked_custom_scale_post_mul_sdpa_script(query, key, value, mask):
237
285
scaled_attn_score = op .Mul (attn_score , multiplier )
238
286
masked_attn_score = op .Add (scaled_attn_score , mask )
239
287
attn_weight = op .Softmax (masked_attn_score , axis = - 1 )
240
- attn_output = op .MatMul (attn_weight , value )
288
+ is_nan = op .IsNaN (attn_weight )
289
+ zero = op .Constant (value_float = 0.0 )
290
+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
291
+ attn_output = op .MatMul (adj_attn_weight , value )
241
292
return attn_output
242
293
243
294
0 commit comments