
Commit b9489f5

[Model][Perf] Use cos and sin cache in QwenVL (vllm-project#28798)
Signed-off-by: gcanlin <[email protected]>
1 parent 285eaa4 commit b9489f5

6 files changed, +218 -217 lines changed

vllm/model_executor/layers/rotary_embedding/base.py

Lines changed: 5 additions & 0 deletions
@@ -83,6 +83,11 @@ def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None:
         ):
             self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
 
+    def get_cos_sin(self, seqlen: int) -> tuple[torch.Tensor, torch.Tensor]:
+        cos_sin = self.cos_sin_cache[:seqlen]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        return cos, sin
+
 
 class RotaryEmbedding(RotaryEmbeddingBase):
     def __init__(
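For orientation, here is a minimal standalone sketch (plain PyTorch, not vLLM code) of the cache layout this helper assumes: one row per position, with the cos half and the sin half concatenated along the last dimension, so `chunk(2, dim=-1)` recovers the two tables. The sizes below are illustrative, not taken from any model config.

```python
import torch

# Illustrative sizes; vLLM derives these from the model config.
max_position, rotary_dim, base = 8192, 64, 10000.0

# Build a cache laid out as [cos | sin] along the last dim, matching what
# get_cos_sin() expects when it calls chunk(2, dim=-1).
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
positions = torch.arange(max_position, dtype=torch.float)
freqs = torch.outer(positions, inv_freq)                       # (max_position, rotary_dim // 2)
cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1)  # (max_position, rotary_dim)

def get_cos_sin(seqlen: int) -> tuple[torch.Tensor, torch.Tensor]:
    # Same logic as the new RotaryEmbeddingBase.get_cos_sin: slice, then split the halves.
    cos, sin = cos_sin_cache[:seqlen].chunk(2, dim=-1)
    return cos, sin

cos, sin = get_cos_sin(16)
print(cos.shape, sin.shape)  # torch.Size([16, 32]) torch.Size([16, 32])
```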

vllm/model_executor/models/glm4_1v.py

Lines changed: 38 additions & 50 deletions
@@ -65,6 +65,7 @@
     RowParallelLinear,
 )
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -341,7 +342,8 @@ def forward(
         self,
         x: torch.Tensor,
         cu_seqlens: torch.Tensor,
-        rotary_pos_emb: torch.Tensor,
+        rotary_pos_emb_cos: torch.Tensor,
+        rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: int | None = None,  # Only used for Flash Attention
         seqlens: list[int] | None = None,  # Only used for xFormers
     ) -> torch.Tensor:
@@ -353,10 +355,12 @@ def forward(
         batch_size = q.shape[1]
 
         q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
-        if rotary_pos_emb is not None:
+        if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
             # [2 * b, s, heads, head_dim]
             qk_concat = torch.cat([q, k], dim=0)
-            qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
+            qk_rotated = apply_rotary_pos_emb_vision(
+                qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin
+            )
             q, k = torch.chunk(qk_rotated, 2, dim=0)
 
         if self.is_flash_attn_backend:
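The call site now receives precomputed cos/sin tables instead of raw angles, so the attention hot path no longer recomputes `cos()`/`sin()` on every forward. As a rough reference for what applying rotary embeddings with explicit cos/sin looks like, here is a generic rotate-half sketch; it is not vLLM's `apply_rotary_pos_emb_vision`, whose exact tensor layout and kernel may differ, and all names and shapes below are illustrative.

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Standard rotate-half used by neox-style RoPE.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)

def apply_rope_with_cos_sin(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x:   (batch, seq, heads, head_dim)
    # cos: (seq, head_dim // 2), sin: (seq, head_dim // 2) -- precomputed once, reused per step.
    cos = torch.cat([cos, cos], dim=-1)[None, :, None, :]  # broadcast to (1, seq, 1, head_dim)
    sin = torch.cat([sin, sin], dim=-1)[None, :, None, :]
    return x * cos + rotate_half(x) * sin

x = torch.randn(2, 16, 4, 64)
angles = torch.outer(torch.arange(16.0), torch.ones(32))  # dummy angles for the demo
out = apply_rope_with_cos_sin(x, angles.cos(), angles.sin())
print(out.shape)  # torch.Size([2, 16, 4, 64])
```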
@@ -454,14 +458,16 @@ def forward(
         self,
         x: torch.Tensor,
         cu_seqlens: torch.Tensor,
-        rotary_pos_emb: torch.Tensor,
+        rotary_pos_emb_cos: torch.Tensor,
+        rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: int | None = None,  # Only used for Flash Attention
         seqlens: list[int] | None = None,  # Only used for xFormers
     ) -> torch.Tensor:
         x_attn = self.attn(
             self.norm1(x),
             cu_seqlens=cu_seqlens,
-            rotary_pos_emb=rotary_pos_emb,
+            rotary_pos_emb_cos=rotary_pos_emb_cos,
+            rotary_pos_emb_sin=rotary_pos_emb_sin,
             max_seqlen=max_seqlen,
             seqlens=seqlens,
         )
@@ -660,44 +666,6 @@ def forward(
         return embeddings
 
 
-class Glm4vVisionRotaryEmbedding(nn.Module):
-    def __init__(self, dim: int, theta: float = 10000.0) -> None:
-        super().__init__()
-        self.dim = dim
-        self.theta = theta
-        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self._seq_len_cached = 0
-        self._freqs_cached = None
-
-    def update_freqs_cache(self, seqlen: int) -> None:
-        if seqlen > self._seq_len_cached:
-            seqlen *= 2
-            self._seq_len_cached = seqlen
-            self.inv_freq = 1.0 / (
-                self.theta
-                ** (
-                    torch.arange(
-                        0,
-                        self.dim,
-                        2,
-                        dtype=torch.float,
-                        device=self.inv_freq.device,
-                    )
-                    / self.dim
-                )
-            )
-            seq = torch.arange(
-                seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
-            )
-            freqs = torch.outer(seq, self.inv_freq)
-            self._freqs_cached = freqs
-
-    def forward(self, seqlen: int) -> torch.Tensor:
-        self.update_freqs_cache(seqlen)
-        return self._freqs_cached[:seqlen]
-
-
 class Glm4vVisionTransformer(nn.Module):
     def __init__(
         self,
@@ -731,7 +699,13 @@ def __init__(
 
         norm_layer = partial(RMSNorm, eps=norm_eps)
         head_dim = self.hidden_size // self.num_heads
-        self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)
+        self.rotary_pos_emb = get_rope(
+            head_size=head_dim,
+            rotary_dim=head_dim // 2,
+            max_position=8192,
+            base=10000.0,
+            is_neox_style=True,
+        )
         self.blocks = nn.ModuleList(
             [
                 Glm4vVisionBlock(
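The dedicated Glm4vVisionRotaryEmbedding module is dropped in favor of vLLM's shared get_rope, whose cos/sin cache is computed once at construction. A quick sketch of how the cached values line up with the frequencies the removed module used to produce, assuming a vLLM build that includes this commit; head_dim here is an arbitrary illustrative value, not taken from the GLM-4.1V config.

```python
import torch
from vllm.model_executor.layers.rotary_embedding import get_rope

head_dim = 80  # illustrative; the vision tower derives this from its config
rope = get_rope(
    head_size=head_dim,
    rotary_dim=head_dim // 2,
    max_position=8192,
    base=10000.0,
    is_neox_style=True,
)

# Frequencies the removed Glm4vVisionRotaryEmbedding(head_dim // 2) would have built.
dim = head_dim // 2
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
freqs = torch.outer(torch.arange(64, dtype=torch.float), inv_freq)

# The new helper returns the cached cos/sin of those same angles.
cos, sin = rope.get_cos_sin(64)
print(torch.allclose(cos.float().cpu(), freqs.cos(), atol=1e-3))  # expected: True
print(torch.allclose(sin.float().cpu(), freqs.sin(), atol=1e-3))  # expected: True
```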
@@ -789,7 +763,9 @@ def dtype(self) -> torch.dtype:
     def device(self) -> torch.device:
         return self.patch_embed.proj.weight.device
 
-    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
+    def rot_pos_emb(
+        self, grid_thw: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         pos_ids = []
         for t, h, w in grid_thw:
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
@@ -817,9 +793,18 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
             pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
         pos_ids = torch.cat(pos_ids, dim=0)
         max_grid_size = grid_thw[:, 1:].max()
-        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
-        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
-        return rotary_pos_emb, pos_ids
+
+        # Use pre-computed cos_sin_cache from RotaryEmbedding
+        cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
+
+        cos_h = cos[pos_ids[:, 0]]  # (num_tokens, rotary_dim // 2)
+        cos_w = cos[pos_ids[:, 1]]
+        sin_h = sin[pos_ids[:, 0]]
+        sin_w = sin[pos_ids[:, 1]]
+
+        cos_combined = torch.cat([cos_h, cos_w], dim=-1)
+        sin_combined = torch.cat([sin_h, sin_w], dim=-1)
+        return cos_combined, sin_combined, pos_ids
 
     def compute_attn_mask_seqlen(
         self,
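The rewritten rot_pos_emb gathers cos/sin for the height and width position ids separately and concatenates them, which is the cos/sin counterpart of the old `freqs[pos_ids].flatten(1)`. A small standalone check of that identity in plain PyTorch, with illustrative sizes and random angles standing in for the real frequency table:

```python
import torch

max_grid, half_dim = 32, 20                      # illustrative sizes
freqs = torch.randn(max_grid, half_dim)          # stands in for the old per-position angles
pos_ids = torch.randint(0, max_grid, (100, 2))   # (num_tokens, 2): [h, w] per token

# Old path: gather both axes at once, then flatten -> angles of shape (num_tokens, 2 * half_dim).
old_angles = freqs[pos_ids].flatten(1)

# New path: gather cos/sin per axis from precomputed tables and concatenate.
cos, sin = freqs.cos(), freqs.sin()
cos_combined = torch.cat([cos[pos_ids[:, 0]], cos[pos_ids[:, 1]]], dim=-1)
sin_combined = torch.cat([sin[pos_ids[:, 0]], sin[pos_ids[:, 1]]], dim=-1)

# Same values as taking cos/sin of the old flattened angles, without recomputing them each forward.
print(torch.allclose(cos_combined, old_angles.cos()))  # True
print(torch.allclose(sin_combined, old_angles.sin()))  # True
```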
@@ -848,7 +833,9 @@ def forward(
         x = self.post_conv_layernorm(x)
 
         # compute position embedding
-        rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw)
+        rotary_pos_emb_cos, rotary_pos_emb_sin, image_type_ids = self.rot_pos_emb(
+            grid_thw
+        )
         # compute cu_seqlens
         cu_seqlens = torch.repeat_interleave(
             grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
@@ -867,7 +854,8 @@ def forward(
             x = blk(
                 x,
                 cu_seqlens=cu_seqlens,
-                rotary_pos_emb=rotary_pos_emb,
+                rotary_pos_emb_cos=rotary_pos_emb_cos,
+                rotary_pos_emb_sin=rotary_pos_emb_sin,
                 max_seqlen=max_seqlen,
                 seqlens=seqlens,
             )
