
Commit 96e82eb

lvhan028 and grimoire authored
Refactor VLM modules (#2810)
* refactor VL modules for internvl and qwen2-vl (#2764) * qwen2-vl * internvl * qwen2
* Refactor VL modules for glm4v, deepseek-vl, llava-hf, cogvlm (#2772) * qwen2-vl * internvl * qwen2 * get image_tokens_per_patch for internvl2 * deepseek-vl * cogvlm * glm4v * update internvl * internvl_llava * llava * glm4v * upate internvl * cogvlm * deepseek * llava_hf * rollback llava, internvl-llava
* Refactor VL modules for qwen-vl, llava and llava_next (#2773) * qwen2-vl * internvl * qwen2 * get image_tokens_per_patch for internvl2 * deepseek-vl * cogvlm * glm4v * update internvl * internvl_llava * llava * glm4v * upate internvl * cogvlm * deepseek * llava_hf * rollback llava, internvl-llava * refactor qwen * update internvl * update llava_hf * update qwen2-vl * llava_next * update llava_next * update llava * update llava * update llava
* Refactor VL modules for qwen2-vl (#2777) * qwen2-vl * internvl * qwen2 * get image_tokens_per_patch for internvl2 * deepseek-vl * cogvlm * glm4v * update internvl * internvl_llava * llava * glm4v * upate internvl * cogvlm * deepseek * llava_hf * rollback llava, internvl-llava * refactor qwen * update internvl * update llava_hf * update qwen2-vl * llava_next * update llava_next * update llava * update llava * update llava * qwen2
* Fix side-effect to internvl (#2778) * qwen2-vl * internvl * qwen2 * get image_tokens_per_patch for internvl2 * deepseek-vl * cogvlm * glm4v * update internvl * internvl_llava * llava * glm4v * upate internvl * cogvlm * deepseek * llava_hf * rollback llava, internvl-llava * refactor qwen * update internvl * update llava_hf * update qwen2-vl * llava_next * update llava_next * update llava * update llava * update llava * qwen2 * fix internvl
* Refactor VL modules for phi3-vision (#2779) * qwen2-vl * internvl * qwen2 * get image_tokens_per_patch for internvl2 * deepseek-vl * cogvlm * glm4v * update internvl * internvl_llava * llava * glm4v * upate internvl * cogvlm * deepseek * llava_hf * rollback llava, internvl-llava * refactor qwen * update internvl * update llava_hf * update qwen2-vl * llava_next * update llava_next * update llava * update llava * update llava * qwen2 * fix internvl * phi3-vision
* Refactor VL modules for mllama and yi-vl (#2781) * qwen2-vl * internvl * qwen2 * get image_tokens_per_patch for internvl2 * deepseek-vl * cogvlm * glm4v * update internvl * internvl_llava * llava * glm4v * upate internvl * cogvlm * deepseek * llava_hf * rollback llava, internvl-llava * refactor qwen * update internvl * update llava_hf * update qwen2-vl * llava_next * update llava_next * update llava * update llava * update llava * qwen2 * fix internvl * phi3-vision * refactor yi-vl * refactor mllama
* Refactor VLM module for minicpm and molmo (#2794)
* Refactor VLM modules for xcomposer series (#2796)
* Refactor VLM modules for internvl-llava (#2797)
* Refactor VLM modules v2 (#2806) * internvl2 v2 * cogvlm * deepseek-vl * glm-4v * llava-hf * llava-next * llava * internvl-llava * mllama * phi3-vision * qwen * qwen2 * yi-vl * xcomposer * minicpm * molmo * update * update
* Remove vl template (#2809)
* Resolve conflicts (#2811)
* feature: support qwen2.5 fuction_call (#2737) * feat: support qwen2.5 tools_call * fix: npe bug * fix: 模版不一致 * fix: adopting review suggestions * fix: adopting review suggestions * fix: adopting review suggestions * fix: adopting review suggestions * feat: Support multi tools calling * feat: Support multi tools calling * fix: Add '\n' between each tool * fix: Add ensure_ascii=False * bugfix: rfind * bugfix: tools_call -> tool_calls * bugfix: add toolName in tool_response * fix: some '\n' error * fix: remove toolname * fix: replace '\n' to self.separator * feat: add doc with multiple tool calling * fix:update doc * feat: add qwen2.5 prompt template test * feat: add qwen2.5 no tool call prompt test --------- Co-authored-by: gaozixiang <[email protected]>
* Update supported models & Ascend doc (#2765) * update ascend supported model list * fix markdown * fix markdown * fix lint * Update get_started.md * Update get_started.md
* [CI] Split vl testcases into turbomind and pytorch backend (#2751) * updaet * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update
* [Feature] support minicpm-v_2_6 for pytorch engine. (#2767) * support minicpmv_2_6. * update supported_models. * update supported_models.
* Support qwen2-vl AWQ quantization (#2787) * Support qwen2-vl AWQ quantization * Update config.yaml --------- Co-authored-by: zhulinJulia24 <[email protected]>
* [dlinfer] Fix qwenvl rope error for dlinfer backend (#2795)
* Optimize update_step_ctx on Ascend (#2804) * opt update_ctx for ascend * fix lint --------- Co-authored-by: 逝夜长歌 <[email protected]> Co-authored-by: gaozixiang <[email protected]> Co-authored-by: jinminxi104 <[email protected]> Co-authored-by: zhulinJulia24 <[email protected]> Co-authored-by: zhoushenglong <[email protected]> Co-authored-by: AllentDan <[email protected]> Co-authored-by: Wei Tao <[email protected]>
* PytorchEngine refactor multimodal (#2742) * WIP * support mrope * support long context * support causal=false * fix mask * flash attn bound * optimize * Moskau, Moskau, wirf die Gläser an die Wand * YMCA * optimize mllama * update processor * support cogvlm * all work and no play make jack a dull boy * upgrade triton * support qwen2vl * support internvl * phi3-v WIP * glm4v WIP * support chatglm and cogvlm * use image tokens * support llava * support internvl-mono * phi3v, mllama * add llavanext * use img token ids * support multiimage chatglm cogvlm * fix ut * minor-fix
* minor-fix (#2813) * fix * fix mono * fix docs * read norm_type * super().collect_images->self.collect_images * add note in supported models * define the parameters clearly * better streaming * fix molmo
* Fix vision model batch inference (#2868) * remove forward from vl models that are not supported by tm * support max_batch_size * fix * warn glm4v does not support multi images * unconst * fix deepseek-vl * fix internvl * fix llava * fix minicpm 2.6 * fix callback * fix minicpm v2.5 * fix minicpm v2.6 * update llava_next.py * remove hardcode from xcomposer2.py * rollback supported_models * change to staticmethod * fix vlm quantization * update doc * update --------- Co-authored-by: q yao <[email protected]>
1 parent 422b9f2 commit 96e82eb


68 files changed (+7597 −3251 lines)

docs/en/multi_modal/llava.md (8 additions, 2 deletions)

@@ -6,11 +6,17 @@ LMDeploy supports the following llava series of models, which are detailed in th
 | :----------------------------------: | :--: | :------------------------: |
 | llava-hf/Llava-interleave-qwen-7b-hf | 7B | TurboMind, PyTorch |
 | llava-hf/llava-1.5-7b-hf | 7B | TurboMind, PyTorch |
-| liuhaotian/llava-v1.6-vicuna-7b | 7B | TurboMind, PyTorch |
-| liuhaotian/llava-v1.6-mistral-7b | 7B | TurboMind, PyTorch |
+| llava-hf/llava-v1.6-mistral-7b-hf | 7B | PyTorch |
+| llava-hf/llava-v1.6-vicuna-7b-hf | 7B | PyTorch |
+| liuhaotian/llava-v1.6-mistral-7b | 7B | TurboMind |
+| liuhaotian/llava-v1.6-vicuna-7b | 7B | TurboMind |
 
 The next chapter demonstrates how to deploy an Llava model using LMDeploy, with [llava-hf/llava-interleave](https://huggingface.co/llava-hf/llava-interleave-qwen-7b-hf) as an example.
 
+```{note}
+PyTorch engine removes the support of original llava models after v0.6.4. Please use their corresponding transformers models instead, which can be found in https://huggingface.co/llava-hf
+```
+
 ## Installation
 
 Please install LMDeploy by following the [installation guide](../get_started/installation.md).
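For reference, a minimal sketch of serving one of the llava-hf checkpoints from the updated table with LMDeploy's pipeline API; the PyTorch backend is selected explicitly to match the table, and the image URL is only a placeholder:

```python
from lmdeploy import pipeline, PytorchEngineConfig
from lmdeploy.vl import load_image

# llava-hf/llava-v1.6-mistral-7b-hf is listed as PyTorch-only in the table above
pipe = pipeline('llava-hf/llava-v1.6-mistral-7b-hf',
                backend_config=PytorchEngineConfig())

# placeholder image; any local path or URL accepted by load_image works
image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
print(pipe(('describe this image', image)))
```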

docs/en/multi_modal/qwen2_vl.md (1 addition, 1 deletion)

@@ -4,7 +4,7 @@ LMDeploy supports the following Qwen-VL series of models, which are detailed in
 
 | Model | Size | Supported Inference Engine |
 | :----------: | :----: | :------------------------: |
-| Qwen-VL-Chat | - | TurboMind, Pytorch |
+| Qwen-VL-Chat | - | TurboMind |
 | Qwen2-VL | 2B, 7B | PyTorch |
 
 The next chapter demonstrates how to deploy an Qwen-VL model using LMDeploy, with [Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) as an example.

docs/zh_cn/multi_modal/llava.md (8 additions, 2 deletions)

@@ -6,11 +6,17 @@ LMDeploy 支持以下 LLaVA 系列模型,具体如下表所示:
 | :----------------------------------: | :--: | :----------------: |
 | llava-hf/Llava-interleave-qwen-7b-hf | 7B | TurboMind, PyTorch |
 | llava-hf/llava-1.5-7b-hf | 7B | TurboMind, PyTorch |
-| liuhaotian/llava-v1.6-vicuna-7b | 7B | TurboMind, PyTorch |
-| liuhaotian/llava-v1.6-mistral-7b | 7B | TurboMind, PyTorch |
+| llava-hf/llava-v1.6-mistral-7b-hf | 7B | PyTorch |
+| llava-hf/llava-v1.6-vicuna-7b-hf | 7B | PyTorch |
+| liuhaotian/llava-v1.6-vicuna-7b | 7B | TurboMind |
+| liuhaotian/llava-v1.6-mistral-7b | 7B | TurboMind |
 
 接下来的章节将演示如何使用 LMDeploy 部署 LLaVA 模型,并以 [llava-hf/llava-interleave](https://huggingface.co/llava-hf/llava-interleave-qwen-7b-hf) 为例。
 
+```{note}
+自 0.6.4 之后,PyTorch 引擎移除了对 llava 原始模型的支持。我们建议使用它们对应的 transformers 格式的模型。这些模型可以在 https://huggingface.co/llava-hf 中找到
+```
+
 ## 安装
 
 请按照[安装指南](../get_started/installation.md)安装 LMDeploy。

docs/zh_cn/multi_modal/qwen2_vl.md (1 addition, 1 deletion)

@@ -4,7 +4,7 @@ LMDeploy 支持 Qwen-VL 系列模型,具体如下:
 
 | Model | Size | Supported Inference Engine |
 | :----------: | :----: | :------------------------: |
-| Qwen-VL-Chat | - | TurboMind, Pytorch |
+| Qwen-VL-Chat | - | TurboMind |
 | Qwen2-VL | 2B, 7B | PyTorch |
 
 本文将以[Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct)为例,演示使用 LMDeploy 部署 Qwen2-VL 系列模型的方法

lmdeploy/lite/apis/calibrate.py (17 additions, 14 deletions)

@@ -239,20 +239,23 @@ def calibrate(model: str,
 
     model_type, _ = get_task(model)
     make_compatible_internvl_config(model)
-    if model_type == 'llm':
-        # Load tokenizer and configuration
-        tokenizer = AutoTokenizer.from_pretrained(model,
-                                                  trust_remote_code=True)
-
-        model = load_hf_from_pretrained(model,
-                                        torch_dtype=torch.float16,
-                                        trust_remote_code=True)
-        vl_model = None
-    elif model_type == 'vlm':
-        from lmdeploy.vl.model.builder import vl_model_with_tokenizer
-        vl_model, model, tokenizer = vl_model_with_tokenizer(model_path=model)
-
-    model.config.use_cache = False
+
+    # Load tokenizer and configuration
+    tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
+
+    model = load_hf_from_pretrained(model,
+                                    torch_dtype=torch.float16,
+                                    trust_remote_code=True)
+    vl_model = None
+    if model_type == 'vlm':
+        vl_model = model
+        if hasattr(model, 'language_model'):
+            model = model.language_model
+        if hasattr(model, 'llm'):
+            model = model.llm
+    model.config.use_cache = False
+    model = model.half().eval()
 
     model_type = type(model).__name__
     if model_type not in LAYER_TYPE_MAP or model_type not in NORM_TYPE_MAP:
         raise RuntimeError(
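The calibration path now always loads the Hugging Face checkpoint directly and, for VLMs, keeps the full model as `vl_model` while drilling into the wrapped language model. A self-contained sketch of that unwrapping pattern is below; `FakeVLM` and `FakeLLM` are illustrative stand-ins, not lmdeploy classes:

```python
class FakeLLM:
    """Stand-in for the wrapped language model."""


class FakeVLM:
    """Stand-in for a VLM checkpoint exposing its LLM as `.language_model`."""

    def __init__(self):
        self.language_model = FakeLLM()


def unwrap_language_model(model):
    """Mirror the unwrapping in calibrate(): keep the full VLM around,
    calibrate only the inner language model."""
    vl_model = model
    if hasattr(model, 'language_model'):  # e.g. llava-style wrappers
        model = model.language_model
    if hasattr(model, 'llm'):             # e.g. wrappers exposing `.llm`
        model = model.llm
    return vl_model, model


vl_model, llm = unwrap_language_model(FakeVLM())
assert isinstance(vl_model, FakeVLM) and isinstance(llm, FakeLLM)
```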

lmdeploy/pytorch/backends/attention.py (3 additions, 0 deletions)

@@ -34,6 +34,7 @@ def __init__(
         alibi: bool = None,
         sliding_window: int = None,
         logit_softcapping: float = None,
+        causal: bool = True,
         **kwargs,
     ) -> None:
         if scale is None:
@@ -53,6 +54,7 @@ def __init__(
         self.alibi = alibi
         self.sliding_window = sliding_window
         self.logit_softcapping = logit_softcapping
+        self.causal = causal
 
     @abstractmethod
     def forward(
@@ -82,6 +84,7 @@ def build(
         alibi: bool = False,
         sliding_window: int = None,
         logical_softcapping: float = None,
+        causal: bool = True,
         **kwargs,
     ) -> AttentionImpl[T]:
         """build."""

lmdeploy/pytorch/backends/base.py (2 additions, 1 deletion)

@@ -12,7 +12,8 @@
 
 class OpType(Enum):
     """Layer type enumerate."""
-    Attention = auto()
+    PagedAttention = auto()
+    FlashAttention = auto()
     Linear = auto()
     RotaryEmbedding = auto()
     ApplyRotaryEmb = auto()

lmdeploy/pytorch/backends/cuda/attention.py (6 additions, 0 deletions)

@@ -41,6 +41,7 @@ def __init__(
         alibi: bool = False,
         sliding_window: int = None,
         logit_softcapping: float = None,
+        causal: bool = True,
         **kwargs,
     ):
         super().__init__(
@@ -52,8 +53,10 @@ def __init__(
             alibi=alibi,
             sliding_window=sliding_window,
             logit_softcapping=logit_softcapping,
+            causal=causal,
             **kwargs,
         )
+        assert not (alibi and not causal)
 
         from lmdeploy.pytorch.kernels.cuda import (alibi_paged_attention_fwd,
                                                    fill_kv_cache,
@@ -172,6 +175,7 @@ def forward(
                 window_size=self.sliding_window,
                 sm_scale=self.scale,
                 logit_softcapping=self.logit_softcapping,
+                causal=self.causal,
             )
         else:
             self.alibi_paged_attention_fwd(
@@ -207,6 +211,7 @@ def build(
         alibi: bool = False,
         sliding_window: int = None,
         logical_softcapping: float = None,
+        causal: bool = True,
         **kwargs,
     ) -> TritonAttentionImpl:
         """build."""
@@ -218,4 +223,5 @@ def build(
             alibi=alibi,
             sliding_window=sliding_window,
             logical_softcapping=logical_softcapping,
+            causal=causal,
             **kwargs)
lmdeploy/pytorch/backends/cuda/flash_attention.py (new file, 101 additions; the file header is hidden in this view — the path is inferred from the `.flash_attention` import in op_backend.py below)

@@ -0,0 +1,101 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from torch import Tensor
+
+from ..flash_attention import FlashAttentionBuilder, FlashAttentionImpl
+
+
+class TritonFlashAttentionImpl(FlashAttentionImpl):
+    """triton flash attention implementation."""
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_dim: int,
+        scale: float = None,
+        num_kv_heads: int = None,
+        v_head_dim: int = None,
+        causal: bool = True,
+        sliding_window: int = None,
+        logical_softcapping: float = None,
+    ):
+        if scale is None:
+            scale = 1.0 / (head_dim**0.5)
+
+        if num_kv_heads is None:
+            num_kv_heads = num_heads
+
+        if v_head_dim is None:
+            v_head_dim = head_dim
+
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.scale = scale
+        self.num_kv_heads = num_kv_heads
+        self.v_head_dim = v_head_dim
+        self.causal = causal
+        self.sliding_window = sliding_window
+        self.logical_softcapping = logical_softcapping
+
+        from lmdeploy.pytorch.kernels.cuda import flash_attention_fwd
+        self.flash_attention_fwd = flash_attention_fwd
+
+    def forward(self,
+                query: Tensor,
+                key: Tensor,
+                value: Tensor,
+                q_start_loc: Tensor,
+                q_seqlens: Tensor,
+                kv_start_loc: Tensor,
+                kv_seqlens: Tensor,
+                max_q_seqlen: int = None):
+        """forward."""
+
+        q_shape = query.shape
+        o_shape = q_shape[:-1] + (self.v_head_dim, )
+        out = query.new_empty(o_shape)
+        self.flash_attention_fwd(
+            query,
+            key,
+            value,
+            out,
+            q_start_loc=q_start_loc,
+            q_seqlens=q_seqlens,
+            kv_start_loc=kv_start_loc,
+            kv_seqlens=kv_seqlens,
+            max_seqlen=max_q_seqlen,
+            window_size=self.sliding_window,
+            sm_scale=self.scale,
+            logit_softcapping=self.logical_softcapping,
+            causal=self.causal,
+            kv_layout='shd',
+        )
+
+        return out
+
+
+class TritonFlashAttentionBuilder(FlashAttentionBuilder):
+    """triton attention builder."""
+
+    @staticmethod
+    def build(
+        num_heads: int,
+        head_dim: int,
+        scale: float = None,
+        num_kv_heads: int = None,
+        v_head_dim: int = None,
+        causal: bool = True,
+        sliding_window: int = None,
+        logical_softcapping: float = None,
+        **kwargs,
+    ) -> FlashAttentionImpl:
+        """build."""
+        return TritonFlashAttentionImpl(
+            num_heads=num_heads,
+            head_dim=head_dim,
+            scale=scale,
+            num_kv_heads=num_kv_heads,
+            v_head_dim=v_head_dim,
+            causal=causal,
+            sliding_window=sliding_window,
+            logical_softcapping=logical_softcapping,
+        )
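A rough usage sketch of the new builder follows. It assumes a CUDA device, assumes the module path lmdeploy/pytorch/backends/cuda/flash_attention.py noted above, and runs plain self-attention over two packed variable-length sequences, so the kv offsets simply reuse the query offsets:

```python
import torch
# module path assumed, see the note on the hidden file header above
from lmdeploy.pytorch.backends.cuda.flash_attention import TritonFlashAttentionBuilder

num_heads, head_dim = 8, 64
attn = TritonFlashAttentionBuilder.build(num_heads=num_heads, head_dim=head_dim)

# two sequences of length 3 and 5, packed along the token axis ([tokens, heads, dim])
q_seqlens = torch.tensor([3, 5], device='cuda')
q_start_loc = torch.tensor([0, 3], device='cuda')
total = int(q_seqlens.sum())
q = torch.randn(total, num_heads, head_dim, device='cuda', dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = attn.forward(q, k, v,
                   q_start_loc=q_start_loc, q_seqlens=q_seqlens,
                   kv_start_loc=q_start_loc, kv_seqlens=q_seqlens,
                   max_q_seqlen=int(q_seqlens.max()))
print(out.shape)  # (total, num_heads, head_dim): v_head_dim defaults to head_dim
```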

lmdeploy/pytorch/backends/cuda/op_backend.py (27 additions, 24 deletions)

@@ -23,9 +23,12 @@ def get_name() -> str:
     @classmethod
     def get_layer_impl_builder(cls, layer_type: OpType):
         """get cuda layer builder."""
-        if layer_type == OpType.Attention:
+        if layer_type == OpType.PagedAttention:
             from .attention import TritonAttentionBuilder
             return TritonAttentionBuilder
+        elif layer_type == OpType.FlashAttention:
+            from .flash_attention import TritonFlashAttentionBuilder
+            return TritonFlashAttentionBuilder
         elif layer_type == OpType.ApplyRotaryEmb:
             from .apply_rotary_emb import TritonApplyRotaryEmbBuilder
             return TritonApplyRotaryEmbBuilder
@@ -125,30 +128,30 @@ def update_step_context(cls, step_context):
             quant_policy=step_context.kv_quant_policy,
         )
 
-        cross_attn_metadata = None
-        fill_seqlens = None
-        if step_context.cross_attention_states is not None:
-            fill_seqlens = torch.zeros_like(q_seqlens)
-            for idx, state in enumerate(step_context.cross_attention_states):
-                if state is not None:
-                    fill_seqlens[idx] = state.shape[-2]
+        cross_seqlens = step_context.cross_seqlens
         cross_kv_seqlens = step_context.cross_kv_seqlens
-        cross_kv_start_loc = None
-        cross_kv_flatten_size = None
-        if not step_context.is_decoding and cross_kv_seqlens is not None:
-            cross_kv_start_loc = cross_kv_seqlens.cumsum(0) - cross_kv_seqlens
-            cross_kv_flatten_size = cross_kv_seqlens.sum().item()
-        cross_attn_metadata = attn_meta_cls(
-            step_context.is_decoding,
-            step_context.block_offsets,
-            q_start_loc=q_start_loc,
-            q_seqlens=q_seqlens,
-            kv_start_loc=cross_kv_start_loc,
-            kv_seqlens=cross_kv_seqlens,
-            kv_flatten_size=cross_kv_flatten_size,
-            fill_seqlens=fill_seqlens,
-            quant_policy=step_context.kv_quant_policy,
-        )
+        cross_attn_metadata = None
+        if cross_seqlens is not None:
+            fill_seqlens = cross_seqlens
+            if fill_seqlens.sum().item() == 0:
+                fill_seqlens = None
+            cross_kv_start_loc = None
+            cross_kv_flatten_size = None
+            if not step_context.is_decoding and cross_kv_seqlens is not None:
+                cross_kv_start_loc = cross_kv_seqlens.cumsum(
+                    0) - cross_kv_seqlens
+                cross_kv_flatten_size = cross_kv_seqlens.sum().item()
+            cross_attn_metadata = attn_meta_cls(
+                step_context.is_decoding,
+                step_context.block_offsets,
+                q_start_loc=q_start_loc,
+                q_seqlens=q_seqlens,
+                kv_start_loc=cross_kv_start_loc,
+                kv_seqlens=cross_kv_seqlens,
+                kv_flatten_size=cross_kv_flatten_size,
+                fill_seqlens=fill_seqlens,
+                quant_policy=step_context.kv_quant_policy,
+            )
 
         step_context.attn_metadata = attn_metadata
         step_context.cross_attn_metadata = cross_attn_metadata
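With OpType.Attention split into PagedAttention and FlashAttention, callers obtain an implementation through the backend's builder lookup. A minimal sketch, assuming the CUDA backend class in op_backend.py is named `CudaOpsBackend` (the class name is not visible in these hunks):

```python
from lmdeploy.pytorch.backends.base import OpType
# Class name assumed; get_layer_impl_builder is the classmethod shown above.
from lmdeploy.pytorch.backends.cuda.op_backend import CudaOpsBackend

paged_builder = CudaOpsBackend.get_layer_impl_builder(OpType.PagedAttention)
flash_builder = CudaOpsBackend.get_layer_impl_builder(OpType.FlashAttention)
print(paged_builder.__name__, flash_builder.__name__)
# expected: TritonAttentionBuilder TritonFlashAttentionBuilder
```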
