diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 68dd07820189..5a67261d827a 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -249,7 +249,7 @@ def __init__(
         # Detect attention implementation.
         self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
         if self.attn_backend not in {
-                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
+                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.FLASH_ATTN_VLLM_V1,
         }:
             raise RuntimeError(
                 f"Qwen2.5-VL does not support {self.attn_backend} backend now.")
@@ -307,6 +307,24 @@ def forward(
 
             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
 
+            output = flash_attn_varlen_func(q,
+                                            k,
+                                            v,
+                                            cu_seqlens_q=cu_seqlens,
+                                            cu_seqlens_k=cu_seqlens,
+                                            max_seqlen_q=max_seqlen,
+                                            max_seqlen_k=max_seqlen,
+                                            dropout_p=0,
+                                            causal=False)
+
+            context_layer = rearrange(output,
+                                      "(b s) ... -> b s ...",
+                                      b=batch_size)
+        elif self.attn_backend == _Backend.FLASH_ATTN_VLLM_V1:
+            from vllm.vllm_flash_attn.flash_attn_interface import flash_attn_varlen_func
+
+            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
+
             output = flash_attn_varlen_func(q,
                                             k,
                                             v,
@@ -638,7 +656,7 @@ def compute_attn_mask_seqlen(
         cu_seqlens: torch.Tensor,
     ) -> tuple[Optional[int], Optional[list[int]]]:
         max_seqlen, seqlens = None, None
-        if self.attn_backend == _Backend.FLASH_ATTN:
+        if self.attn_backend in [_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1]:
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         elif self.attn_backend == _Backend.XFORMERS:
             seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 0ff0836b0897..4db924b09975 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -329,6 +329,24 @@ def forward(
 
             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
 
+            output = flash_attn_varlen_func(q,
+                                            k,
+                                            v,
+                                            cu_seqlens_q=cu_seqlens,
+                                            cu_seqlens_k=cu_seqlens,
+                                            max_seqlen_q=max_seqlen,
+                                            max_seqlen_k=max_seqlen,
+                                            dropout_p=0,
+                                            causal=False)
+
+            context_layer = rearrange(output,
+                                      "(b s) ... -> b s ...",
+                                      b=batch_size)
+        elif self.attn_backend == _Backend.FLASH_ATTN_VLLM_V1:
+            from vllm.vllm_flash_attn.flash_attn_interface import flash_attn_varlen_func
+
+            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
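+            # Same varlen attention call as the FLASH_ATTN branch above, but with
+            # flash_attn_varlen_func imported from the vllm_flash_attn build that
+            # ships with vLLM, so the existing call below is reused unchanged.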
+
             output = flash_attn_varlen_func(q,
                                             k,
                                             v,
@@ -618,7 +636,7 @@ def compute_attn_mask_seqlen(
             self, cu_seqlens: torch.Tensor
     ) -> tuple[Optional[int], Optional[list[int]]]:
         max_seqlen, seqlens = None, None
-        if self.attn_backend == _Backend.FLASH_ATTN:
+        if self.attn_backend in [_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1]:
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         elif self.attn_backend == _Backend.XFORMERS:
             seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py
index 901d83ec5b9e..edba12276627 100644
--- a/vllm/model_executor/models/vision.py
+++ b/vllm/model_executor/models/vision.py
@@ -83,16 +83,7 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
     if current_platform.is_cuda():
         device_available = current_platform.has_device_capability(80)
         if device_available and support_fa:
-            from transformers.utils import is_flash_attn_2_available
-            if is_flash_attn_2_available():
-                selected_backend = _Backend.FLASH_ATTN
-            else:
-                logger.warning_once(
-                    "Current `vllm-flash-attn` has a bug inside vision "
-                    "module, so we use xformers backend instead. You can "
-                    "run `pip install flash-attn` to use flash-attention "
-                    "backend.")
-                selected_backend = _Backend.XFORMERS
+            selected_backend = _Backend.FLASH_ATTN
         else:
             # For Volta and Turing GPUs, use xformers instead.
             selected_backend = _Backend.XFORMERS