
Commit 28338b9

supports flash attn 3 fp8 (#174)
1 parent 39cc5c5 commit 28338b9

File tree: 8 files changed (+89, -33 lines)

diffsynth_engine/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,7 @@
     FluxStateDicts,
     WanStateDicts,
     QwenImageStateDicts,
+    AttnImpl,
     ControlNetParams,
     ControlType,
 )
@@ -54,6 +55,7 @@
     "FluxStateDicts",
     "WanStateDicts",
     "QwenImageStateDicts",
+    "AttnImpl",
     "ControlNetParams",
     "ControlType",
     "SDImagePipeline",

diffsynth_engine/configs/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,7 @@
     WanStateDicts,
     WanS2VStateDicts,
     QwenImageStateDicts,
+    AttnImpl,
 )
 from .controlnet import ControlType, ControlNetParams
 
@@ -39,6 +40,7 @@
     "WanStateDicts",
     "WanS2VStateDicts",
     "QwenImageStateDicts",
+    "AttnImpl",
     "ControlType",
     "ControlNetParams",
 ]

diffsynth_engine/configs/pipeline.py

Lines changed: 14 additions & 1 deletion
@@ -1,5 +1,6 @@
 import os
 import torch
+from enum import Enum
 from dataclasses import dataclass, field
 from typing import List, Dict, Tuple, Optional
 
@@ -19,9 +20,21 @@ class BaseConfig:
     offload_to_disk: bool = False
 
 
+class AttnImpl(Enum):
+    AUTO = "auto"
+    EAGER = "eager"  # Native Attention
+    FA2 = "fa2"  # Flash Attention 2
+    FA3 = "fa3"  # Flash Attention 3
+    FA3_FP8 = "fa3_fp8"  # Flash Attention 3 with FP8
+    XFORMERS = "xformers"  # XFormers
+    SDPA = "sdpa"  # Scaled Dot Product Attention
+    SAGE = "sage"  # Sage Attention
+    SPARGE = "sparge"  # Sparge Attention
+
+
 @dataclass
 class AttentionConfig:
-    dit_attn_impl: str = "auto"
+    dit_attn_impl: AttnImpl = AttnImpl.AUTO
     # Sparge Attention
     sparge_smooth_k: bool = True
     sparge_cdfthreshd: float = 0.6
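
With this commit the DiT attention backend is selected through the `AttnImpl` enum rather than a raw string on `AttentionConfig`. A minimal sketch of the new selection, assuming `AttentionConfig`'s remaining fields keep their defaults; only `AttnImpl`, `dit_attn_impl`, and the `fa3_fp8` value come from this diff:

```python
from diffsynth_engine.configs.pipeline import AttentionConfig, AttnImpl

# Pick the Flash Attention 3 FP8 backend added by this commit.
attn_cfg = AttentionConfig(dit_attn_impl=AttnImpl.FA3_FP8)

# Pipelines read the underlying string via .value and forward it
# to the attention dispatch in models/basic/attention.py.
print(attn_cfg.dit_attn_impl.value)  # fa3_fp8
```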

diffsynth_engine/models/basic/attention.py

Lines changed: 46 additions & 23 deletions
@@ -13,6 +13,7 @@
     SAGE_ATTN_AVAILABLE,
     SPARGE_ATTN_AVAILABLE,
 )
+from diffsynth_engine.utils.platform import DTYPE_FP8
 
 FA3_MAX_HEADDIM = 256
 
@@ -125,12 +126,13 @@ def attention(
         None,
         "auto",
         "eager",
-        "flash_attn_2",
-        "flash_attn_3",
+        "fa2",
+        "fa3",
+        "fa3_fp8",
         "xformers",
         "sdpa",
-        "sage_attn",
-        "sparge_attn",
+        "sage",
+        "sparge",
     ]
     flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
     if attn_impl is None or attn_impl == "auto":
@@ -139,9 +141,13 @@
             return flash_attn3(q, k, v, softmax_scale=scale)
         else:
             if not flash_attn3_compatible:
-                logger.warning(f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation")
+                logger.warning(
+                    f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
+                )
             else:
-                logger.debug("flash_attn_3 does not support attention mask, will use fallback attention implementation")
+                logger.debug(
+                    "flash_attn_3 does not support attention mask, will use fallback attention implementation"
+                )
         if XFORMERS_AVAILABLE:
             return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
         if SDPA_AVAILABLE:
@@ -152,23 +158,31 @@
     else:
         if attn_impl == "eager":
             return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
-        if attn_impl == "flash_attn_3":
+        if attn_impl == "fa3" or attn_impl == "fa3_fp8":
             if not flash_attn3_compatible:
                 raise RuntimeError(
                     f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}"
                 )
             if attn_mask is not None:
                 raise RuntimeError("flash_attn_3 does not support attention mask")
-            return flash_attn3(q, k, v, softmax_scale=scale)
-        if attn_impl == "flash_attn_2":
+            if attn_impl == "fa3":
+                return flash_attn3(q, k, v, softmax_scale=scale)
+            else:
+                origin_dtype = q.dtype
+                q = q.to(dtype=DTYPE_FP8)
+                k = k.to(dtype=DTYPE_FP8)
+                v = v.to(dtype=DTYPE_FP8)
+                out = flash_attn3(q, k, v, softmax_scale=scale)
+                return out.to(dtype=origin_dtype)
+        if attn_impl == "fa2":
             return flash_attn2(q, k, v, softmax_scale=scale)
         if attn_impl == "xformers":
             return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
         if attn_impl == "sdpa":
             return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
-        if attn_impl == "sage_attn":
+        if attn_impl == "sage":
             return sage_attn(q, k, v, attn_mask=attn_mask, scale=scale)
-        if attn_impl == "sparge_attn":
+        if attn_impl == "sparge":
             return sparge_attn(
                 q,
                 k,
@@ -247,12 +261,14 @@ def long_context_attention(
     assert attn_impl in [
         None,
         "auto",
-        "flash_attn_2",
-        "flash_attn_3",
+        "fa2",
+        "fa3",
+        "fa3_fp8",
         "sdpa",
-        "sage_attn",
-        "sparge_attn",
+        "sage",
+        "sparge",
     ]
+    assert attn_mask is None, "long context attention does not support attention mask"
     flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
     if attn_impl is None or attn_impl == "auto":
         if FLASH_ATTN_3_AVAILABLE:
@@ -268,20 +284,27 @@
             return LongContextAttention(attn_type=AttnType.FA)(q, k, v, softmax_scale=scale)
         raise ValueError("No available long context attention implementation")
     else:
-        if attn_impl == "flash_attn_3":
-            if flash_attn3_compatible:
-                return LongContextAttention(attn_type=AttnType.FA3)(q, k, v, softmax_scale=scale)
-            else:
+        if attn_impl == "fa3" or attn_impl == "fa3_fp8":
+            if not flash_attn3_compatible:
                 raise RuntimeError(
                     f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}"
                 )
-        if attn_impl == "flash_attn_2":
+            if attn_impl == "fa3":
+                return LongContextAttention(attn_type=AttnType.FA3)(q, k, v, softmax_scale=scale)
+
+            origin_dtype = q.dtype
+            q = q.to(dtype=DTYPE_FP8)
+            k = k.to(dtype=DTYPE_FP8)
+            v = v.to(dtype=DTYPE_FP8)
+            out = LongContextAttention(attn_type=AttnType.FA3)(q, k, v, softmax_scale=scale)
+            return out.to(dtype=origin_dtype)
+        if attn_impl == "fa2":
             return LongContextAttention(attn_type=AttnType.FA)(q, k, v, softmax_scale=scale)
         if attn_impl == "sdpa":
             return LongContextAttention(attn_type=AttnType.TORCH)(q, k, v, softmax_scale=scale)
-        if attn_impl == "sage_attn":
-            return LongContextAttention(attn_type=AttnType.SAGE_FP8)(q, k, v, softmax_scale=scale)
-        if attn_impl == "sparge_attn":
+        if attn_impl == "sage":
+            return LongContextAttention(attn_type=AttnType.SAGE_AUTO)(q, k, v, softmax_scale=scale)
+        if attn_impl == "sparge":
            attn_processor = SparseAttentionMeansim()
            # default args from spas_sage2_attn_meansim_cuda
            attn_processor.smooth_k = torch.tensor(kwargs.get("sparge_smooth_k", True))
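
The new `fa3_fp8` branch is a dtype round-trip around the Flash Attention 3 kernel: cast `q`, `k`, `v` to the platform FP8 dtype, run the kernel, and cast the output back to the caller's dtype. Below is a self-contained sketch of that pattern, not the library code: `fake_flash_attn3` is a hypothetical stand-in (a real FA3 kernel consumes FP8 directly, while PyTorch SDPA does not, so the stand-in upcasts first), and `torch.float8_e4m3fn` is assumed here in place of the library's `DTYPE_FP8` constant.

```python
import torch
import torch.nn.functional as F


def fake_flash_attn3(q, k, v, softmax_scale=None):
    # Hypothetical stand-in for flash_attn3: upcast so the sketch runs anywhere.
    q, k, v = (t.to(torch.bfloat16) for t in (q, k, v))
    return F.scaled_dot_product_attention(q, k, v, scale=softmax_scale)


def fa3_fp8_attention(q, k, v, scale=None, fp8_dtype=torch.float8_e4m3fn):
    """Mirror of the fa3_fp8 branch above: FP8 in, caller's dtype out."""
    origin_dtype = q.dtype
    q, k, v = q.to(fp8_dtype), k.to(fp8_dtype), v.to(fp8_dtype)
    out = fake_flash_attn3(q, k, v, softmax_scale=scale)
    return out.to(dtype=origin_dtype)


q = torch.randn(1, 8, 128, 64, dtype=torch.bfloat16)  # (batch, heads, seq, head_dim)
out = fa3_fp8_attention(q, q, q)
print(out.dtype)  # torch.bfloat16, the dtype the caller passed in
```

The cast is lossy (FP8 e4m3 keeps only 3 mantissa bits), which is the usual speed/memory versus precision trade-off this option exposes.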

diffsynth_engine/pipelines/flux_image.py

Lines changed: 1 addition & 1 deletion
@@ -516,7 +516,7 @@ def _from_state_dict(cls, state_dicts: FluxStateDicts, config: FluxPipelineConfi
 
         with LoRAContext():
             attn_kwargs = {
-                "attn_impl": config.dit_attn_impl,
+                "attn_impl": config.dit_attn_impl.value,
                 "sparge_smooth_k": config.sparge_smooth_k,
                 "sparge_cdfthreshd": config.sparge_cdfthreshd,
                 "sparge_simthreshd1": config.sparge_simthreshd1,

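On the consumer side, each pipeline now unwraps the enum with `.value` before building `attn_kwargs`, so the string comparisons inside `attention()` and `long_context_attention()` (e.g. `attn_impl == "fa3_fp8"`) stay unchanged; the same one-line change appears in the Qwen-Image, Wan S2V, and Wan video pipelines below. A hedged sketch of that hand-off, with a hypothetical `DummyPipelineConfig` and `build_attn_kwargs` standing in for the real pipeline config and the inline dict:

```python
from dataclasses import dataclass

from diffsynth_engine.configs.pipeline import AttnImpl


@dataclass
class DummyPipelineConfig:
    # Trimmed-down stand-in; only the field relevant to this change.
    dit_attn_impl: AttnImpl = AttnImpl.AUTO


def build_attn_kwargs(config: DummyPipelineConfig) -> dict:
    # Mirrors the change above: pass the enum's string value, because the
    # attention dispatch compares attn_impl against literals like "fa3_fp8".
    return {"attn_impl": config.dit_attn_impl.value}


print(build_attn_kwargs(DummyPipelineConfig(dit_attn_impl=AttnImpl.FA3_FP8)))
# {'attn_impl': 'fa3_fp8'}
```
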
diffsynth_engine/pipelines/qwen_image.py

Lines changed: 1 addition & 1 deletion
@@ -201,7 +201,7 @@ def _from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePip
 
         with LoRAContext():
             attn_kwargs = {
-                "attn_impl": config.dit_attn_impl,
+                "attn_impl": config.dit_attn_impl.value,
                 "sparge_smooth_k": config.sparge_smooth_k,
                 "sparge_cdfthreshd": config.sparge_cdfthreshd,
                 "sparge_simthreshd1": config.sparge_simthreshd1,

diffsynth_engine/pipelines/wan_s2v.py

Lines changed: 22 additions & 6 deletions
@@ -239,7 +239,15 @@ def encode_ref_and_motion(
 
         return ref_latents, motion_latents, motion_frames
 
-    def encode_pose(self, pose_video: List[Image.Image], pose_video_fps: int, num_clips: int, num_frames_per_clip: int, height: int, width: int):
+    def encode_pose(
+        self,
+        pose_video: List[Image.Image],
+        pose_video_fps: int,
+        num_clips: int,
+        num_frames_per_clip: int,
+        height: int,
+        width: int,
+    ):
         self.load_models_to_device(["vae"])
         max_num_pose_frames = num_frames_per_clip * num_clips
         pose_video = read_n_frames(pose_video, pose_video_fps, max_num_pose_frames, target_fps=self.config.fps)
@@ -466,7 +474,9 @@ def __call__(
             dtype=self.dtype,
         ).to(self.device)
         if pose_video is not None:
-            pose_latents_all_clips = self.encode_pose(pose_video, pose_video_fps, num_clips, num_frames_per_clip, height, width)
+            pose_latents_all_clips = self.encode_pose(
+                pose_video, pose_video_fps, num_clips, num_frames_per_clip, height, width
+            )
 
         output_frames_all_clips = []
         for clip_idx in range(num_clips):
@@ -602,7 +612,9 @@ def from_pretrained(cls, model_path_or_config: WanSpeech2VideoPipelineConfig) ->
         return cls.from_state_dict(state_dicts, config)
 
     @classmethod
-    def from_state_dict(cls, state_dicts: WanS2VStateDicts, config: WanSpeech2VideoPipelineConfig) -> "WanSpeech2VideoPipeline":
+    def from_state_dict(
+        cls, state_dicts: WanS2VStateDicts, config: WanSpeech2VideoPipelineConfig
+    ) -> "WanSpeech2VideoPipeline":
         if config.parallelism > 1:
             pipe = ParallelWrapper(
                 cfg_degree=config.cfg_degree,
@@ -617,7 +629,9 @@ def from_state_dict(cls, state_dicts: WanS2VStateDicts, config: WanSpeech2VideoP
         return pipe
 
     @classmethod
-    def _from_state_dict(cls, state_dicts: WanS2VStateDicts, config: WanSpeech2VideoPipelineConfig) -> "WanSpeech2VideoPipeline":
+    def _from_state_dict(
+        cls, state_dicts: WanS2VStateDicts, config: WanSpeech2VideoPipelineConfig
+    ) -> "WanSpeech2VideoPipeline":
         # default params from model config
         vae_type = "wan2.1-vae"
         dit_type = "wan2.2-s2v-14b"
@@ -632,14 +646,16 @@ def _from_state_dict(cls, state_dicts: WanS2VStateDicts, config: WanSpeech2Video
         init_device = "cpu" if config.offload_mode is not None else config.device
         tokenizer = WanT5Tokenizer(WAN_TOKENIZER_CONF_PATH, seq_len=512, clean="whitespace")
         text_encoder = WanTextEncoder.from_state_dict(state_dicts.t5, device=init_device, dtype=config.t5_dtype)
-        vae = WanVideoVAE.from_state_dict(state_dicts.vae, config=vae_config, device=init_device, dtype=config.vae_dtype)
+        vae = WanVideoVAE.from_state_dict(
+            state_dicts.vae, config=vae_config, device=init_device, dtype=config.vae_dtype
+        )
         audio_encoder = Wav2Vec2Model.from_state_dict(
             state_dicts.audio_encoder, config=Wav2Vec2Config(), device=init_device, dtype=config.audio_encoder_dtype
         )
 
         with LoRAContext():
             attn_kwargs = {
-                "attn_impl": config.dit_attn_impl,
+                "attn_impl": config.dit_attn_impl.value,
                 "sparge_smooth_k": config.sparge_smooth_k,
                 "sparge_cdfthreshd": config.sparge_cdfthreshd,
                 "sparge_simthreshd1": config.sparge_simthreshd1,

diffsynth_engine/pipelines/wan_video.py

Lines changed: 1 addition & 1 deletion
@@ -557,7 +557,7 @@ def _from_state_dict(cls, state_dicts: WanStateDicts, config: WanPipelineConfig)
 
         with LoRAContext():
             attn_kwargs = {
-                "attn_impl": config.dit_attn_impl,
+                "attn_impl": config.dit_attn_impl.value,
                 "sparge_smooth_k": config.sparge_smooth_k,
                 "sparge_cdfthreshd": config.sparge_cdfthreshd,
                 "sparge_simthreshd1": config.sparge_simthreshd1,
