Commit d087bf8

[Model] Support quantization of Qwen2VisionTransformer (vllm-project#9817)
Signed-off-by: mgoin <[email protected]>
1 parent: 890ca36

1 file changed (+35, −23)

vllm/model_executor/models/qwen2_vl.py

@@ -126,15 +126,18 @@ def __init__(
         hidden_features: int = None,
         act_layer: Type[nn.Module] = QuickGELU,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.fc1 = ColumnParallelLinear(in_features,
                                         hidden_features,
-                                        quant_config=quant_config)
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.fc1")
         self.act = act_layer()
         self.fc2 = RowParallelLinear(hidden_features,
                                      in_features,
-                                     quant_config=quant_config)
+                                     quant_config=quant_config,
+                                     prefix=f"{prefix}.fc2")
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x_parallel, _ = self.fc1(x)
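
Note: the new prefix argument is what makes per-layer quantization possible here. Quantization configs generally decide layer by layer whether (and how) to quantize, keyed on the layer's fully qualified name. A minimal sketch of that matching idea, using hypothetical names rather than vLLM's actual QuantizationConfig interface:

# Hypothetical sketch of prefix-based layer matching; the names below are
# illustrative, not vLLM's real API.
from typing import Optional

# Layers a (hypothetical) quantized checkpoint keeps in full precision.
IGNORED_PREFIXES = ("visual.patch_embed", "visual.merger")

def quant_method_for(prefix: str) -> Optional[str]:
    """Pick a quantization scheme from a layer's fully qualified name."""
    if prefix.startswith(IGNORED_PREFIXES):
        return None  # leave this layer unquantized
    return "w8a8"    # placeholder for a real scheme

print(quant_method_for("visual.blocks.0.mlp.fc1"))  # -> w8a8
print(quant_method_for("visual.merger.mlp.0"))      # -> None
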
@@ -196,6 +199,7 @@ def __init__(
         num_heads: Optional[int] = None,
         projection_size: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
@@ -207,10 +211,12 @@ def __init__(
 
         self.qkv = ColumnParallelLinear(input_size=embed_dim,
                                         output_size=3 * projection_size,
-                                        quant_config=quant_config)
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.qkv")
         self.proj = RowParallelLinear(input_size=projection_size,
                                       output_size=embed_dim,
-                                      quant_config=quant_config)
+                                      quant_config=quant_config,
+                                      prefix=f"{prefix}.proj")
 
         # Detect attention implementation.
         self.attn_backend: _Backend = get_vit_attn_backend()
@@ -310,6 +316,7 @@ def __init__(
         act_layer: Type[nn.Module] = QuickGELU,
         norm_layer: Type[nn.Module] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -321,11 +328,13 @@ def __init__(
         self.attn = Qwen2VisionAttention(embed_dim=dim,
                                          num_heads=num_heads,
                                          projection_size=dim,
-                                         quant_config=quant_config)
+                                         quant_config=quant_config,
+                                         prefix=f"{prefix}.attn")
         self.mlp = Qwen2VisionMLP(dim,
                                   mlp_hidden_dim,
                                   act_layer=act_layer,
-                                  quant_config=quant_config)
+                                  quant_config=quant_config,
+                                  prefix=f"{prefix}.mlp")
 
     def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
                 rotary_pos_emb: torch.Tensor) -> torch.Tensor:
@@ -374,6 +383,7 @@ def __init__(
         norm_layer: Type[nn.Module] = None,
         spatial_merge_size: int = 2,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = context_dim * (spatial_merge_size**2)
@@ -384,12 +394,14 @@ def __init__(
             ColumnParallelLinear(self.hidden_size,
                                  self.hidden_size,
                                  bias=True,
-                                 quant_config=quant_config),
+                                 quant_config=quant_config,
+                                 prefix=f"{prefix}.mlp.0"),
             nn.GELU(),
             RowParallelLinear(self.hidden_size,
                               d_model,
                               bias=True,
-                              quant_config=quant_config),
+                              quant_config=quant_config,
+                              prefix=f"{prefix}.mlp.2"),
         ])
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
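
Note on the ".mlp.0" / ".mlp.2" suffixes: the indices mirror positions inside the nn.ModuleList, so the prefixes line up with the checkpoint tensor names; index 1 is the parameter-free nn.GELU and needs no prefix. A runnable illustration:

import torch.nn as nn

mlp = nn.ModuleList([
    nn.Linear(8, 8),  # parameters named "0.weight", "0.bias"
    nn.GELU(),        # no parameters, so no "1.*" entries
    nn.Linear(8, 4),  # parameters named "2.weight", "2.bias"
])
for name, _ in mlp.named_parameters():
    print(name)  # prints 0.weight, 0.bias, 2.weight, 2.bias
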
@@ -440,6 +452,7 @@ def __init__(
         vision_config: Qwen2VLVisionConfig,
         norm_eps: float = 1e-6,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
 
@@ -467,28 +480,29 @@ def __init__(
         self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)
 
         self.blocks = nn.ModuleList([
-            Qwen2VisionBlock(
-                dim=embed_dim,
-                num_heads=num_heads,
-                mlp_ratio=mlp_ratio,
-                norm_layer=norm_layer,
-                quant_config=quant_config,
-            ) for _ in range(depth)
+            Qwen2VisionBlock(dim=embed_dim,
+                             num_heads=num_heads,
+                             mlp_ratio=mlp_ratio,
+                             norm_layer=norm_layer,
+                             quant_config=quant_config,
+                             prefix=f"{prefix}.blocks.{layer_idx}")
+            for layer_idx in range(depth)
         ])
         self.merger = Qwen2VisionPatchMerger(
             d_model=hidden_size,
             context_dim=embed_dim,
             norm_layer=norm_layer,
             quant_config=quant_config,
+            prefix=f"{prefix}.merger",
         )
 
     @property
     def dtype(self) -> torch.dtype:
-        return self.blocks[0].mlp.fc2.weight.dtype
+        return self.patch_embed.proj.weight.dtype
 
     @property
     def device(self) -> torch.device:
-        return self.blocks[0].mlp.fc2.weight.device
+        return self.patch_embed.proj.weight.device
 
     def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
         pos_ids = []
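
Two details in this hunk are worth spelling out. First, the prefixes compose as the modules nest, so every linear layer ends up with exactly the dotted name its tensors carry in the checkpoint:

# Plain string composition of the nested prefixes (top-level prefix is "visual"):
prefix = "visual"
block = f"{prefix}.blocks.{0}"  # "visual.blocks.0"
attn = f"{block}.attn"          # "visual.blocks.0.attn"
qkv = f"{attn}.qkv"             # "visual.blocks.0.attn.qkv"
assert f"{qkv}.weight" == "visual.blocks.0.attn.qkv.weight"

Second, dtype and device now read from patch_embed.proj instead of blocks[0].mlp.fc2, likely because once the blocks may be quantized, fc2's weight can be stored in a packed quantized dtype and no longer reflects the model's compute dtype, while the patch embedding stays unquantized.
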
@@ -932,10 +946,8 @@ def __init__(self,
         self.visual = Qwen2VisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-
-            # NOTE: Qwen2-VL vision encoder does not support any
-            # quantization method now.
-            quant_config=None,
+            quant_config=quant_config,
+            prefix="visual",
         )
 
         self.model = Qwen2Model(config,
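
With quant_config finally passed through (and prefix="visual" as the root of all the dotted names above), a quantized checkpoint can now cover the vision tower as well as the language model. A hedged usage sketch; the model ID is illustrative, and any quantization method vLLM supports for a given checkpoint would exercise this path:

# Illustrative only: load Qwen2-VL with a quantization method enabled so the
# "visual.*" layers are built through the same quant_config as the LLM.
from vllm import LLM

llm = LLM(model="Qwen/Qwen2-VL-7B-Instruct", quantization="fp8")
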
@@ -1175,7 +1187,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                if "visual" in name and "qkv.weight" in name:
+                if "visual" in name and name.endswith("qkv.weight"):
                     visual_num_heads = self.config.vision_config.num_heads
                     visual_embed_dim = self.config.vision_config.embed_dim
                     head_size = visual_embed_dim // visual_num_heads
@@ -1184,7 +1196,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                                                        visual_embed_dim)
                     loaded_weight = loaded_weight.transpose(0, 1)
                     loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
-                elif "visual" in name and "qkv.bias" in name:
+                elif "visual" in name and name.endswith("qkv.bias"):
                     visual_num_heads = self.config.vision_config.num_heads
                     visual_embed_dim = self.config.vision_config.embed_dim
                     head_size = visual_embed_dim // visual_num_heads
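
The move from substring checks to endswith matters precisely because quantized vision checkpoints now exist: auxiliary tensors such as per-channel scales (e.g. a hypothetical qkv.weight_scale entry, a common naming in quantized checkpoints) contain "qkv.weight" as a substring, so the old check would have applied the head-interleaving reshape to them. A quick demonstration:

# Why endswith() beats substring matching once scale tensors are present.
name = "visual.blocks.0.attn.qkv.weight_scale"  # hypothetical quantized tensor

print("qkv.weight" in name)         # True  -> old check would reshape this scale
print(name.endswith("qkv.weight"))  # False -> new check correctly skips it
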
