     RowParallelLinear,
 )
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -341,7 +342,8 @@ def forward(
         self,
         x: torch.Tensor,
         cu_seqlens: torch.Tensor,
-        rotary_pos_emb: torch.Tensor,
+        rotary_pos_emb_cos: torch.Tensor,
+        rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: int | None = None,  # Only used for Flash Attention
         seqlens: list[int] | None = None,  # Only used for xFormers
     ) -> torch.Tensor:
@@ -353,10 +355,12 @@ def forward(
         batch_size = q.shape[1]

         q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
-        if rotary_pos_emb is not None:
+        if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
             # [2 * b, s, heads, head_dim]
             qk_concat = torch.cat([q, k], dim=0)
-            qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
+            qk_rotated = apply_rotary_pos_emb_vision(
+                qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin
+            )
             q, k = torch.chunk(qk_rotated, 2, dim=0)

         if self.is_flash_attn_backend:
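
For orientation, the attention path now threads precomputed cos/sin tables instead of a raw angle tensor. A minimal sketch of what a cos/sin-based rotary application can look like for the `[2 * b, s, heads, head_dim]` layout used above (an illustrative stand-in assuming neox-style rotate-half, not vLLM's actual `apply_rotary_pos_emb_vision`):

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_cos_sin(
    qk: torch.Tensor,   # [2 * b, s, heads, head_dim]
    cos: torch.Tensor,  # [s, head_dim // 2]
    sin: torch.Tensor,  # [s, head_dim // 2]
) -> torch.Tensor:
    # Duplicate cos/sin across both halves of head_dim and broadcast over
    # the batch and head dimensions before rotating.
    cos = torch.cat([cos, cos], dim=-1)[None, :, None, :]
    sin = torch.cat([sin, sin], dim=-1)[None, :, None, :]
    return qk * cos + rotate_half(qk) * sin
```
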
@@ -454,14 +458,16 @@ def forward(
         self,
         x: torch.Tensor,
         cu_seqlens: torch.Tensor,
-        rotary_pos_emb: torch.Tensor,
+        rotary_pos_emb_cos: torch.Tensor,
+        rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: int | None = None,  # Only used for Flash Attention
         seqlens: list[int] | None = None,  # Only used for xFormers
     ) -> torch.Tensor:
         x_attn = self.attn(
             self.norm1(x),
             cu_seqlens=cu_seqlens,
-            rotary_pos_emb=rotary_pos_emb,
+            rotary_pos_emb_cos=rotary_pos_emb_cos,
+            rotary_pos_emb_sin=rotary_pos_emb_sin,
             max_seqlen=max_seqlen,
             seqlens=seqlens,
         )
@@ -660,44 +666,6 @@ def forward(
         return embeddings


-class Glm4vVisionRotaryEmbedding(nn.Module):
-    def __init__(self, dim: int, theta: float = 10000.0) -> None:
-        super().__init__()
-        self.dim = dim
-        self.theta = theta
-        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self._seq_len_cached = 0
-        self._freqs_cached = None
-
-    def update_freqs_cache(self, seqlen: int) -> None:
-        if seqlen > self._seq_len_cached:
-            seqlen *= 2
-            self._seq_len_cached = seqlen
-            self.inv_freq = 1.0 / (
-                self.theta
-                ** (
-                    torch.arange(
-                        0,
-                        self.dim,
-                        2,
-                        dtype=torch.float,
-                        device=self.inv_freq.device,
-                    )
-                    / self.dim
-                )
-            )
-            seq = torch.arange(
-                seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
-            )
-            freqs = torch.outer(seq, self.inv_freq)
-            self._freqs_cached = freqs
-
-    def forward(self, seqlen: int) -> torch.Tensor:
-        self.update_freqs_cache(seqlen)
-        return self._freqs_cached[:seqlen]
-
-
 class Glm4vVisionTransformer(nn.Module):
     def __init__(
         self,
@@ -731,7 +699,13 @@ def __init__(

         norm_layer = partial(RMSNorm, eps=norm_eps)
         head_dim = self.hidden_size // self.num_heads
-        self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)
+        self.rotary_pos_emb = get_rope(
+            head_size=head_dim,
+            rotary_dim=head_dim // 2,
+            max_position=8192,
+            base=10000.0,
+            is_neox_style=True,
+        )
         self.blocks = nn.ModuleList(
             [
                 Glm4vVisionBlock(
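
The bespoke Glm4vVisionRotaryEmbedding removed above is replaced by vLLM's shared `get_rope` helper, configured with `rotary_dim = head_dim // 2` so the cached tables cover the same frequency band the old module computed on the fly. A small standalone sanity-check sketch of that equivalence (not part of the PR; `head_dim = 80` is just an example value and the standard RotaryEmbedding inverse-frequency formula is assumed):

```python
import torch

head_dim = 80
theta = 10000.0

# Deleted Glm4vVisionRotaryEmbedding(dim=head_dim // 2)
dim = head_dim // 2
old_inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))

# Standard rotary formulation with rotary_dim = head_dim // 2
rotary_dim = head_dim // 2
new_inv_freq = 1.0 / (
    theta ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim)
)

# Both yield head_dim // 4 frequencies of the form theta ** (-2i / (head_dim // 2)).
assert torch.allclose(old_inv_freq, new_inv_freq)
```
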
@@ -789,7 +763,9 @@ def dtype(self) -> torch.dtype:
     def device(self) -> torch.device:
         return self.patch_embed.proj.weight.device

-    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
+    def rot_pos_emb(
+        self, grid_thw: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         pos_ids = []
         for t, h, w in grid_thw:
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
@@ -817,9 +793,18 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
             pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
         pos_ids = torch.cat(pos_ids, dim=0)
         max_grid_size = grid_thw[:, 1:].max()
-        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
-        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
-        return rotary_pos_emb, pos_ids
+
+        # Use pre-computed cos_sin_cache from RotaryEmbedding
+        cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
+
+        cos_h = cos[pos_ids[:, 0]]  # (num_tokens, rotary_dim // 2)
+        cos_w = cos[pos_ids[:, 1]]
+        sin_h = sin[pos_ids[:, 0]]
+        sin_w = sin[pos_ids[:, 1]]
+
+        cos_combined = torch.cat([cos_h, cos_w], dim=-1)
+        sin_combined = torch.cat([sin_h, sin_w], dim=-1)
+        return cos_combined, sin_combined, pos_ids

     def compute_attn_mask_seqlen(
         self,
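
Conceptually, `rot_pos_emb` now returns per-token cos/sin rows rather than raw angles: each patch gathers one row of the cached table for its height index and one for its width index, and the two are concatenated so the first half of the rotary dimensions encodes the row and the second half the column. A self-contained sketch of that gather in plain torch (it mirrors the shape bookkeeping only; the per-window reordering of `pos_ids` in the elided lines and the `get_cos_sin` helper are not reproduced, and all names below are illustrative):

```python
import torch


def vision_cos_sin(h: int, w: int, head_dim: int, theta: float = 10000.0):
    # Rotary span per token is head_dim // 2, split evenly between the two axes.
    dim = head_dim // 2
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))

    # (row, col) index for every patch in row-major order.
    hpos = torch.arange(h).unsqueeze(1).expand(-1, w).flatten()
    wpos = torch.arange(w).unsqueeze(0).expand(h, -1).flatten()

    # One shared angle table sized for the larger grid side.
    freqs = torch.outer(torch.arange(max(h, w), dtype=torch.float), inv_freq)
    cos, sin = freqs.cos(), freqs.sin()

    # Gather per-axis rows and concatenate: [h * w, head_dim // 2].
    cos_combined = torch.cat([cos[hpos], cos[wpos]], dim=-1)
    sin_combined = torch.cat([sin[hpos], sin[wpos]], dim=-1)
    return cos_combined, sin_combined
```
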
@@ -848,7 +833,9 @@ def forward(
         x = self.post_conv_layernorm(x)

         # compute position embedding
-        rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw)
+        rotary_pos_emb_cos, rotary_pos_emb_sin, image_type_ids = self.rot_pos_emb(
+            grid_thw
+        )
         # compute cu_seqlens
         cu_seqlens = torch.repeat_interleave(
             grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
@@ -867,7 +854,8 @@ def forward(
             x = blk(
                 x,
                 cu_seqlens=cu_seqlens,
-                rotary_pos_emb=rotary_pos_emb,
+                rotary_pos_emb_cos=rotary_pos_emb_cos,
+                rotary_pos_emb_sin=rotary_pos_emb_sin,
                 max_seqlen=max_seqlen,
                 seqlens=seqlens,
             )