
Commit 28b62ea: optimize mllama
1 parent: bd677f5

5 files changed: 38 additions, 21 deletions


lmdeploy/pytorch/backends/attention.py

Lines changed: 3 additions & 0 deletions
@@ -34,6 +34,7 @@ def __init__(
         alibi: bool = None,
         sliding_window: int = None,
         logit_softcapping: float = None,
+        causal: bool = True,
         **kwargs,
     ) -> None:
         if scale is None:
@@ -53,6 +54,7 @@ def __init__(
         self.alibi = alibi
         self.sliding_window = sliding_window
         self.logit_softcapping = logit_softcapping
+        self.causal = causal
 
     @abstractmethod
     def forward(
@@ -82,6 +84,7 @@ def build(
         alibi: bool = False,
         sliding_window: int = None,
         logical_softcapping: float = None,
+        causal: bool = True,
         **kwargs,
     ) -> AttentionImpl[T]:
         """build."""

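The hunks above only thread a new `causal` flag through the abstract attention interface; the flag selects between decoder-style masked attention and full attention over all keys. A minimal plain-PyTorch illustration of the distinction (using `torch.nn.functional.scaled_dot_product_attention`, not lmdeploy's paged kernels):

```python
import torch
import torch.nn.functional as F

# Toy shapes: (batch, heads, seq_len, head_dim)
q = torch.randn(1, 8, 6, 64)
k = torch.randn(1, 8, 6, 64)
v = torch.randn(1, 8, 6, 64)

# causal=True: each query attends only to keys at earlier or equal positions.
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# causal=False: every query attends to every key, as cross-attention needs.
out_full = F.scaled_dot_product_attention(q, k, v, is_causal=False)

print(out_causal.shape, out_full.shape)  # both torch.Size([1, 8, 6, 64])
```
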
lmdeploy/pytorch/backends/cuda/attention.py

Lines changed: 6 additions & 0 deletions
@@ -41,6 +41,7 @@ def __init__(
         alibi: bool = False,
         sliding_window: int = None,
         logit_softcapping: float = None,
+        causal: bool = True,
         **kwargs,
     ):
         super().__init__(
@@ -52,8 +53,10 @@ def __init__(
             alibi=alibi,
             sliding_window=sliding_window,
             logit_softcapping=logit_softcapping,
+            causal=causal,
             **kwargs,
         )
+        assert not (alibi and not causal)
 
         from lmdeploy.pytorch.kernels.cuda import (alibi_paged_attention_fwd,
                                                    fill_kv_cache,
@@ -169,6 +172,7 @@ def forward(
                 window_size=self.sliding_window,
                 sm_scale=self.scale,
                 logit_softcapping=self.logit_softcapping,
+                causal=self.causal,
             )
         else:
             self.alibi_paged_attention_fwd(
@@ -204,6 +208,7 @@ def build(
         alibi: bool = False,
         sliding_window: int = None,
         logical_softcapping: float = None,
+        causal: bool = True,
         **kwargs,
     ) -> TritonAttentionImpl:
         """build."""
@@ -215,4 +220,5 @@ def build(
             alibi=alibi,
             sliding_window=sliding_window,
             logical_softcapping=logical_softcapping,
+            causal=causal,
             **kwargs)

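The CUDA backend forwards the flag into its Triton kernel call and adds a guard: ALiBi biases assume decoder-style attention, so `alibi=True` combined with `causal=False` is rejected up front. A standalone sketch of that guard (the function name is hypothetical; the check mirrors the assertion added above):

```python
def check_attention_flags(alibi: bool, causal: bool) -> None:
    """Reject flag combinations the backend cannot serve."""
    # Mirrors `assert not (alibi and not causal)`: ALiBi positional biases
    # only make sense for causal (decoder-style) attention.
    if alibi and not causal:
        raise ValueError('ALiBi attention requires causal=True')


check_attention_flags(alibi=False, causal=False)  # ok: plain cross-attention
check_attention_flags(alibi=True, causal=True)    # ok: ALiBi decoder attention
# check_attention_flags(alibi=True, causal=False)  # would raise ValueError
```
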
lmdeploy/pytorch/backends/dlinfer/attention.py

Lines changed: 5 additions & 0 deletions
@@ -30,8 +30,10 @@ def __init__(
         alibi: bool = None,
         sliding_window: int = None,
         logit_softcapping: float = None,
+        causal: bool = True,
         **kwargs,
     ):
+        assert causal
         super().__init__(
             num_heads,
             head_size,
@@ -41,6 +43,7 @@ def __init__(
             alibi,
             sliding_window,
             logit_softcapping,
+            causal=causal,
             **kwargs,
         )
 
@@ -121,6 +124,7 @@ def build(
         alibi_scale: float = None,
         sliding_window: int = None,
         logical_softcapping: float = None,
+        causal: bool = True,
         **kwargs,
     ) -> DlinferAttentionImpl:
         """build."""
@@ -132,4 +136,5 @@ def build(
             alibi_scale=alibi_scale,
             sliding_window=sliding_window,
             logical_softcapping=logical_softcapping,
+            causal=causal,
             **kwargs)

lmdeploy/pytorch/models/mllama.py

Lines changed: 22 additions & 21 deletions
@@ -195,6 +195,7 @@ def __init__(self,
             self.head_dim,
             num_kv_heads=self.num_key_value_heads,
             v_head_size=self.head_dim,
+            causal=False,
         )
 
         self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
@@ -1128,10 +1129,12 @@ def forward(
 
         # Collect intermediate layer outputs from encoder output
         all_intermediate_hidden_states = output[1]
+        all_intermediate_hidden_states = [
+            all_intermediate_hidden_states[i]
+            for i in self.intermediate_layers_indices
+        ]
         intermediate_hidden_states = torch.stack(
             all_intermediate_hidden_states, dim=-1)
-        intermediate_hidden_states = intermediate_hidden_states[
-            ..., self.intermediate_layers_indices]
 
         # Remove padding from intermediate hidden states
         intermediate_hidden_states = intermediate_hidden_states.reshape(
@@ -1196,8 +1199,8 @@ def __init__(self,
         # preprocessor
         self.input_processor = MLlamaInputProcessor(self.config, dtype)
 
-    def flat_encoder_result(self, cross_attention_states: torch.Tensor,
-                            attn_metadata: Any, input_ids: torch.LongTensor):
+    def flat_encoder_result(self, attn_metadata: Any,
+                            input_ids: torch.LongTensor):
         # since every state share the same shape
         full_text_row_masked_out_mask = torch.ones(
             (attn_metadata.q_seqlens.sum(), 1), dtype=torch.bool)
@@ -1208,9 +1211,9 @@ def flat_encoder_result(self, cross_attention_states: torch.Tensor,
             full_text_row_masked_out_mask[start_pos:img_id] = False
             start_pos += q_seq_len
         full_text_row_masked_out_mask = full_text_row_masked_out_mask.to(
-            cross_attention_states.device)
+            input_ids.device)
 
-        return cross_attention_states, full_text_row_masked_out_mask
+        return full_text_row_masked_out_mask
 
     def forward(
         self,
@@ -1227,6 +1230,19 @@ def forward(
     ):
         """model forward, return logits."""
 
+        if cross_attn_metadata is None:
+            full_text_row_masked_out_mask = None
+        # FIXME basically, we want to inference
+        # text requests and image requests separately
+        elif pixel_values is None and (cross_attn_metadata.kv_seqlens is None):
+            full_text_row_masked_out_mask = None
+        elif cross_attn_metadata.is_decoding:
+            full_text_row_masked_out_mask = input_ids.new_ones(
+                input_ids.size(-1), 1)
+        else:
+            full_text_row_masked_out_mask = self.flat_encoder_result(
+                cross_attn_metadata, input_ids)  # noqa
+
         cross_attention_states = None
         if pixel_values is not None:
             cross_attention_states = self.vision_model(
@@ -1240,21 +1256,6 @@ def forward(
             cross_attention_states = cross_attention_states.view(
                 bsz, -1, image_token_dim)
 
-        if cross_attn_metadata is None:
-            full_text_row_masked_out_mask = None
-        # FIXME basically, we want to inference
-        # text requests and image requests separately
-        elif cross_attention_states is None and (cross_attn_metadata.kv_seqlens
-                                                 is None):
-            full_text_row_masked_out_mask = None
-        elif cross_attn_metadata.is_decoding:
-            full_text_row_masked_out_mask = input_ids.new_ones(
-                input_ids.size(-1), 1)
-        else:
-            (cross_attention_states,
-             full_text_row_masked_out_mask) = self.flat_encoder_result(
-                 cross_attention_states, cross_attn_metadata,
-                 input_ids)  # noqa
         hidden_states = self.language_model(
             input_ids=input_ids,
             position_ids=position_ids,

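The mllama changes carry the actual optimization: the vision cross-attention layers are built with `causal=False`, the `full_text_row_masked_out_mask` is now derived from `cross_attn_metadata` and `input_ids` alone so it can be computed before (and independently of) the vision tower, and the intermediate encoder layers are selected before stacking rather than stacking every layer and indexing afterwards. A small sketch of that last point, with illustrative shapes and indices (the real values come from the vision encoder output and `self.intermediate_layers_indices`):

```python
import torch

hidden_states = [torch.randn(2, 16, 1280) for _ in range(32)]  # one tensor per encoder layer
indices = [3, 7, 15, 23, 30]

# Old path: stack all 32 layers, then index the stacked tensor.
before = torch.stack(hidden_states, dim=-1)[..., indices]

# New path: pick only the needed layers, then stack 5 tensors instead of 32.
after = torch.stack([hidden_states[i] for i in indices], dim=-1)

assert torch.equal(before, after)  # same result, less memory traffic
```
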
lmdeploy/pytorch/nn/attention.py

Lines changed: 2 additions & 0 deletions
@@ -32,6 +32,7 @@ def __init__(
         sliding_window: int = None,
         logit_softcapping: float = None,
         replicate_kv: bool = False,
+        causal: bool = True,
         **kwargs,
     ):
         super().__init__()
@@ -55,6 +56,7 @@ def __init__(
             alibi=alibi,
             sliding_window=sliding_window,
             logit_softcapping=logit_softcapping,
+            causal=causal,
             **kwargs,
         )
6062

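Taken together, the commit threads one keyword from the public `nn.Attention` wrapper down through the backend builders to the kernel-level implementations. A toy, self-contained sketch of that plumbing (the class names here are illustrative stand-ins, not lmdeploy's):

```python
class ToyAttentionImpl:
    """Stand-in for a backend attention implementation."""

    def __init__(self, num_heads: int, head_size: int, causal: bool = True, **kwargs):
        self.causal = causal  # kernels branch on this flag at forward time


class ToyAttentionBuilder:
    """Stand-in for a backend attention builder."""

    @staticmethod
    def build(num_heads: int, head_size: int, causal: bool = True, **kwargs):
        return ToyAttentionImpl(num_heads, head_size, causal=causal, **kwargs)


class ToyAttention:
    """Stand-in for the public nn.Attention module."""

    def __init__(self, num_heads: int, head_size: int, causal: bool = True, **kwargs):
        self.impl = ToyAttentionBuilder.build(num_heads, head_size, causal=causal, **kwargs)


# mllama's vision cross-attention would request a non-causal layer:
cross_attn = ToyAttention(num_heads=16, head_size=128, causal=False)
assert cross_attn.impl.causal is False
```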