     SAGE_ATTN_AVAILABLE,
     SPARGE_ATTN_AVAILABLE,
 )
+from diffsynth_engine.utils.platform import DTYPE_FP8
 
 FA3_MAX_HEADDIM = 256
 
@@ -125,12 +126,13 @@ def attention(
         None,
         "auto",
         "eager",
-        "flash_attn_2",
-        "flash_attn_3",
+        "fa2",
+        "fa3",
+        "fa3_fp8",
         "xformers",
         "sdpa",
-        "sage_attn",
-        "sparge_attn",
+        "sage",
+        "sparge",
     ]
     flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
     if attn_impl is None or attn_impl == "auto":
@@ -139,9 +141,13 @@ def attention(
                 return flash_attn3(q, k, v, softmax_scale=scale)
             else:
                 if not flash_attn3_compatible:
-                    logger.warning(f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation")
+                    logger.warning(
+                        f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
+                    )
                 else:
-                    logger.debug("flash_attn_3 does not support attention mask, will use fallback attention implementation")
+                    logger.debug(
+                        "flash_attn_3 does not support attention mask, will use fallback attention implementation"
+                    )
         if XFORMERS_AVAILABLE:
             return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
         if SDPA_AVAILABLE:
@@ -152,23 +158,31 @@ def attention(
     else:
         if attn_impl == "eager":
             return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
-        if attn_impl == "flash_attn_3":
+        if attn_impl == "fa3" or attn_impl == "fa3_fp8":
             if not flash_attn3_compatible:
                 raise RuntimeError(
                     f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}"
                 )
             if attn_mask is not None:
                 raise RuntimeError("flash_attn_3 does not support attention mask")
-            return flash_attn3(q, k, v, softmax_scale=scale)
-        if attn_impl == "flash_attn_2":
+            if attn_impl == "fa3":
+                return flash_attn3(q, k, v, softmax_scale=scale)
+            else:
+                origin_dtype = q.dtype
+                q = q.to(dtype=DTYPE_FP8)
+                k = k.to(dtype=DTYPE_FP8)
+                v = v.to(dtype=DTYPE_FP8)
+                out = flash_attn3(q, k, v, softmax_scale=scale)
+                return out.to(dtype=origin_dtype)
+        if attn_impl == "fa2":
             return flash_attn2(q, k, v, softmax_scale=scale)
         if attn_impl == "xformers":
             return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
         if attn_impl == "sdpa":
             return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
-        if attn_impl == "sage_attn":
+        if attn_impl == "sage":
             return sage_attn(q, k, v, attn_mask=attn_mask, scale=scale)
-        if attn_impl == "sparge_attn":
+        if attn_impl == "sparge":
             return sparge_attn(
                 q,
                 k,
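
For context, a minimal caller-side sketch of the renamed attn_impl values introduced above. The import path for attention() and the tensor layout/dtype below are assumptions for illustration, not confirmed by this diff; "fa3_fp8" additionally requires flash_attn_3 and an fp8-capable GPU.

import torch
from diffsynth_engine.models.basic.attention import attention  # assumed module path

# placeholder shapes: (batch, seq_len, num_heads, head_dim)
q = torch.randn(1, 1024, 16, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn(1, 1024, 16, 128, dtype=torch.bfloat16, device="cuda")
v = torch.randn(1, 1024, 16, 128, dtype=torch.bfloat16, device="cuda")

out = attention(q, k, v, attn_impl="fa3")      # previously "flash_attn_3"
out = attention(q, k, v, attn_impl="fa3_fp8")  # new: casts q/k/v to fp8, returns output in the original dtype
out = attention(q, k, v, attn_impl="sage")     # previously "sage_attn"
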
@@ -247,12 +261,14 @@ def long_context_attention(
     assert attn_impl in [
         None,
         "auto",
-        "flash_attn_2",
-        "flash_attn_3",
+        "fa2",
+        "fa3",
+        "fa3_fp8",
         "sdpa",
-        "sage_attn",
-        "sparge_attn",
+        "sage",
+        "sparge",
     ]
+    assert attn_mask is None, "long context attention does not support attention mask"
     flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
     if attn_impl is None or attn_impl == "auto":
         if FLASH_ATTN_3_AVAILABLE:
@@ -268,20 +284,27 @@ def long_context_attention(
             return LongContextAttention(attn_type=AttnType.FA)(q, k, v, softmax_scale=scale)
         raise ValueError("No available long context attention implementation")
     else:
-        if attn_impl == "flash_attn_3":
-            if flash_attn3_compatible:
-                return LongContextAttention(attn_type=AttnType.FA3)(q, k, v, softmax_scale=scale)
-            else:
+        if attn_impl == "fa3" or attn_impl == "fa3_fp8":
+            if not flash_attn3_compatible:
                 raise RuntimeError(
                     f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}"
                 )
-        if attn_impl == "flash_attn_2":
+            if attn_impl == "fa3":
+                return LongContextAttention(attn_type=AttnType.FA3)(q, k, v, softmax_scale=scale)
+
+            origin_dtype = q.dtype
+            q = q.to(dtype=DTYPE_FP8)
+            k = k.to(dtype=DTYPE_FP8)
+            v = v.to(dtype=DTYPE_FP8)
+            out = LongContextAttention(attn_type=AttnType.FA3)(q, k, v, softmax_scale=scale)
+            return out.to(dtype=origin_dtype)
+        if attn_impl == "fa2":
             return LongContextAttention(attn_type=AttnType.FA)(q, k, v, softmax_scale=scale)
         if attn_impl == "sdpa":
             return LongContextAttention(attn_type=AttnType.TORCH)(q, k, v, softmax_scale=scale)
-        if attn_impl == "sage_attn":
-            return LongContextAttention(attn_type=AttnType.SAGE_FP8)(q, k, v, softmax_scale=scale)
-        if attn_impl == "sparge_attn":
+        if attn_impl == "sage":
+            return LongContextAttention(attn_type=AttnType.SAGE_AUTO)(q, k, v, softmax_scale=scale)
+        if attn_impl == "sparge":
             attn_processor = SparseAttentionMeansim()
             # default args from spas_sage2_attn_meansim_cuda
             attn_processor.smooth_k = torch.tensor(kwargs.get("sparge_smooth_k", True))
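
Both functions now share the same fp8 round-trip for the "fa3_fp8" path. Below is a minimal standalone sketch of that technique; DTYPE_FP8 is assumed to resolve to torch.float8_e4m3fn, and attn_fn is a stand-in for flash_attn3 or a LongContextAttention instance (both names here are illustrative, not the repository's API).

import torch

DTYPE_FP8 = torch.float8_e4m3fn  # assumed value of diffsynth_engine.utils.platform.DTYPE_FP8

def fp8_roundtrip(q, k, v, attn_fn, softmax_scale=None):
    # Cast inputs down to fp8, run the fp8-capable kernel, then restore the caller's dtype.
    origin_dtype = q.dtype
    q, k, v = (t.to(dtype=DTYPE_FP8) for t in (q, k, v))
    out = attn_fn(q, k, v, softmax_scale=softmax_scale)
    return out.to(dtype=origin_dtype)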