@@ -126,15 +126,18 @@ def __init__(
         hidden_features: int = None,
         act_layer: Type[nn.Module] = QuickGELU,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.fc1 = ColumnParallelLinear(in_features,
                                         hidden_features,
-                                        quant_config=quant_config)
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.fc1")
         self.act = act_layer()
         self.fc2 = RowParallelLinear(hidden_features,
                                      in_features,
-                                     quant_config=quant_config)
+                                     quant_config=quant_config,
+                                     prefix=f"{prefix}.fc2")

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x_parallel, _ = self.fc1(x)
@@ -196,6 +199,7 @@ def __init__(
         num_heads: Optional[int] = None,
         projection_size: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
@@ -207,10 +211,12 @@ def __init__(

         self.qkv = ColumnParallelLinear(input_size=embed_dim,
                                         output_size=3 * projection_size,
-                                        quant_config=quant_config)
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.qkv")
         self.proj = RowParallelLinear(input_size=projection_size,
                                       output_size=embed_dim,
-                                      quant_config=quant_config)
+                                      quant_config=quant_config,
+                                      prefix=f"{prefix}.proj")

         # Detect attention implementation.
         self.attn_backend: _Backend = get_vit_attn_backend()
@@ -310,6 +316,7 @@ def __init__(
         act_layer: Type[nn.Module] = QuickGELU,
         norm_layer: Type[nn.Module] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -321,11 +328,13 @@ def __init__(
         self.attn = Qwen2VisionAttention(embed_dim=dim,
                                          num_heads=num_heads,
                                          projection_size=dim,
-                                         quant_config=quant_config)
+                                         quant_config=quant_config,
+                                         prefix=f"{prefix}.attn")
         self.mlp = Qwen2VisionMLP(dim,
                                   mlp_hidden_dim,
                                   act_layer=act_layer,
-                                  quant_config=quant_config)
+                                  quant_config=quant_config,
+                                  prefix=f"{prefix}.mlp")

     def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
                 rotary_pos_emb: torch.Tensor) -> torch.Tensor:
@@ -374,6 +383,7 @@ def __init__(
         norm_layer: Type[nn.Module] = None,
         spatial_merge_size: int = 2,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = context_dim * (spatial_merge_size**2)
@@ -384,12 +394,14 @@ def __init__(
             ColumnParallelLinear(self.hidden_size,
                                  self.hidden_size,
                                  bias=True,
-                                 quant_config=quant_config),
+                                 quant_config=quant_config,
+                                 prefix=f"{prefix}.mlp.0"),
             nn.GELU(),
             RowParallelLinear(self.hidden_size,
                               d_model,
                               bias=True,
-                              quant_config=quant_config),
+                              quant_config=quant_config,
+                              prefix=f"{prefix}.mlp.2"),
         ])

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -440,6 +452,7 @@ def __init__(
         vision_config: Qwen2VLVisionConfig,
         norm_eps: float = 1e-6,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()

@@ -467,28 +480,29 @@ def __init__(
         self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)

         self.blocks = nn.ModuleList([
-            Qwen2VisionBlock(
-                dim=embed_dim,
-                num_heads=num_heads,
-                mlp_ratio=mlp_ratio,
-                norm_layer=norm_layer,
-                quant_config=quant_config,
-            ) for _ in range(depth)
+            Qwen2VisionBlock(dim=embed_dim,
+                             num_heads=num_heads,
+                             mlp_ratio=mlp_ratio,
+                             norm_layer=norm_layer,
+                             quant_config=quant_config,
+                             prefix=f"{prefix}.blocks.{layer_idx}")
+            for layer_idx in range(depth)
         ])
         self.merger = Qwen2VisionPatchMerger(
             d_model=hidden_size,
             context_dim=embed_dim,
             norm_layer=norm_layer,
             quant_config=quant_config,
+            prefix=f"{prefix}.merger",
         )

     @property
     def dtype(self) -> torch.dtype:
-        return self.blocks[0].mlp.fc2.weight.dtype
+        return self.patch_embed.proj.weight.dtype

     @property
     def device(self) -> torch.device:
-        return self.blocks[0].mlp.fc2.weight.device
+        return self.patch_embed.proj.weight.device

     def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
         pos_ids = []
@@ -932,10 +946,8 @@ def __init__(self,
         self.visual = Qwen2VisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-
-            # NOTE: Qwen2-VL vision encoder does not support any
-            # quantization method now.
-            quant_config=None,
+            quant_config=quant_config,
+            prefix="visual",
         )

         self.model = Qwen2Model(config,
@@ -1175,7 +1187,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                if "visual" in name and "qkv.weight" in name:
+                if "visual" in name and name.endswith("qkv.weight"):
                     visual_num_heads = self.config.vision_config.num_heads
                     visual_embed_dim = self.config.vision_config.embed_dim
                     head_size = visual_embed_dim // visual_num_heads
@@ -1184,7 +1196,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                                                        visual_embed_dim)
                     loaded_weight = loaded_weight.transpose(0, 1)
                     loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
-                elif "visual" in name and "qkv.bias" in name:
+                elif "visual" in name and name.endswith("qkv.bias"):
                     visual_num_heads = self.config.vision_config.num_heads
                     visual_embed_dim = self.config.vision_config.embed_dim
                     head_size = visual_embed_dim // visual_num_heads
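For reference, a minimal sketch (not part of the diff) of how the new prefix strings compose once the top-level module passes prefix="visual"; the helper name below is hypothetical, and only illustrates the dotted names each parallel linear layer receives so a quantization config can decide per layer how to handle it.

# Hedged sketch: enumerate the module names produced by the prefix plumbing above,
# e.g. "visual.blocks.0.attn.qkv" for the first block's fused QKV projection.
def example_vision_prefixes(depth: int = 2) -> list:
    names = []
    for layer_idx in range(depth):
        block = f"visual.blocks.{layer_idx}"
        names += [
            f"{block}.attn.qkv", f"{block}.attn.proj",
            f"{block}.mlp.fc1", f"{block}.mlp.fc2",
        ]
    names += ["visual.merger.mlp.0", "visual.merger.mlp.2"]
    return names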