
Commit de968e7

fix acc, explicit dtype, optimize

1 parent 9740225 commit de968e7

3 files changed: +93 -52 lines changed

lmdeploy/pytorch/backends/graph_runner.py

Lines changed: 7 additions & 0 deletions

@@ -82,6 +82,13 @@ def update_model_metas(
 
         return None
 
+    def post_update_model_metas(self, model_metas):
+        """Post update model meta."""
+        if hasattr(self.model, 'post_update_model_metas'):
+            return self.model.post_update_model_metas(model_metas)
+
+        return None
+
     def get_input_processor(self):
         """Get input processor."""
         if hasattr(self.model, 'get_input_processor'):
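The new hook mirrors the opt-in delegation already used by `update_model_metas` and `get_input_processor`: the runner forwards the call only when the wrapped model implements it. A minimal runnable sketch of that pattern, with hypothetical `DummyModel`/`DummyRunner` stand-ins rather than the real lmdeploy classes:

```python
class DummyModel:
    """Hypothetical model that rewrites its metas, e.g. to report compressed lengths."""

    def post_update_model_metas(self, model_metas):
        return [dict(new_seqlen=7) for _ in model_metas]


class DummyRunner:
    """Hypothetical runner showing the hasattr-based delegation."""

    def __init__(self, model):
        self.model = model

    def post_update_model_metas(self, model_metas):
        if hasattr(self.model, 'post_update_model_metas'):
            return self.model.post_update_model_metas(model_metas)
        return None


print(DummyRunner(DummyModel()).post_update_model_metas([None]))  # [{'new_seqlen': 7}]
print(DummyRunner(object()).post_update_model_metas([None]))      # None
```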

lmdeploy/pytorch/engine/model_agent.py

Lines changed: 9 additions & 6 deletions

@@ -238,9 +238,8 @@ def model_forward(
             context=context,
         )
         output = model(**input_dict)
-        seq_length = ctx_mgr.current_context().q_seqlens
-
-        return dict(hidden_states=output, model_metas=model_metas, seq_length=seq_length)
+        model_metas = model.post_update_model_metas(model_metas)
+        return dict(hidden_states=output, model_metas=model_metas)
 
 
 @record_function('stopping_criteria')
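`model_forward` no longer plumbs `seq_length` through its return dict; instead the model gets one chance to rewrite `model_metas` after the forward pass. A condensed runnable sketch of the new tail of the function (hypothetical `FakeModel`, toy shapes):

```python
import torch


class FakeModel:
    """Hypothetical model: fake hidden states, metas passed through unchanged."""

    def __call__(self, **input_dict):
        return torch.zeros(1, 3, 4)

    def post_update_model_metas(self, model_metas):
        return model_metas


def forward_tail(model, input_dict, model_metas):
    # mirrors the new return path: hidden states plus (possibly rewritten) metas
    output = model(**input_dict)
    model_metas = model.post_update_model_metas(model_metas)
    return dict(hidden_states=output, model_metas=model_metas)


ret = forward_tail(FakeModel(), {}, [None])
print(ret['hidden_states'].shape, ret['model_metas'])  # torch.Size([1, 3, 4]) [None]
```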
@@ -504,9 +503,13 @@ async def __long_context_single_forward(new_inputs, max_seqlen: int):
         if not is_long_context:
             ret = await __forward(inputs)
             if not return_logits and not inputs.is_decoding:
-                # fetch seq_length from the returned context, since models may change it (e.g. InternVL-Flash)
-                seq_length = ret['seq_length']
-                assert seq_length is not None, 'seq_length cannot be None'
+                seq_length = inputs.seq_length
+
+                # for InternVL-3.5-Flash, update seq_length if model_metas contain 'new_seqlen'
+                model_metas = ret.get('model_metas', None)
+                if model_metas is not None and 'new_seqlen' in model_metas[0]:
+                    seq_length = torch.tensor([meta['new_seqlen'] for meta in model_metas], device='cuda')
+
                 last_token_loc = seq_length.cumsum(0) - 1
 
                 ret['hidden_states'] = ret['hidden_states'][:, last_token_loc]
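With the `ret['seq_length']` channel gone, the caller starts from `inputs.seq_length` and overrides it only when the metas carry a `new_seqlen`, i.e. when the model compressed its input. A CPU-runnable sketch with made-up lengths (the real code allocates on CUDA):

```python
import torch

seq_length = torch.tensor([5, 3])                       # inputs.seq_length
model_metas = [dict(new_seqlen=2), dict(new_seqlen=1)]  # e.g. after token compression

if model_metas is not None and 'new_seqlen' in model_metas[0]:
    seq_length = torch.tensor([meta['new_seqlen'] for meta in model_metas])

# index of the last token of each sequence in the packed hidden states
last_token_loc = seq_length.cumsum(0) - 1
print(last_token_loc.tolist())  # [1, 2]
```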

lmdeploy/pytorch/models/internvl.py

Lines changed: 77 additions & 46 deletions
@@ -24,19 +24,19 @@
 
 class Gating(nn.Module):
 
-    def __init__(self, hidden_size=2048, expansion_factor=4):
+    def __init__(self, hidden_size=2048, expansion_factor=4, dtype=None, device=None):
         super().__init__()
 
         mid_dim = hidden_size * expansion_factor
 
         def mlp_block(in_dim, out_dim):
             return nn.Sequential(
-                nn.Linear(in_dim, out_dim, bias=True),
+                nn.Linear(in_dim, out_dim, bias=True, dtype=dtype, device=device),
                 nn.GELU(),
                 nn.Identity(),
-                nn.Linear(out_dim, in_dim, bias=True),
+                nn.Linear(out_dim, in_dim, bias=True, dtype=dtype, device=device),
                 nn.Identity(),
-                nn.LayerNorm(in_dim),
+                nn.LayerNorm(in_dim, dtype=dtype, device=device),
             )
 
         self.block1 = mlp_block(hidden_size, mid_dim)
@@ -45,8 +45,8 @@ def mlp_block(in_dim, out_dim):
         self.block4 = mlp_block(hidden_size, mid_dim)
 
         self.gate = nn.Sequential(
-            nn.LayerNorm(hidden_size),
-            nn.Linear(hidden_size, 2, bias=True)  # 2 experts
+            nn.LayerNorm(hidden_size, dtype=dtype, device=device),
+            nn.Linear(hidden_size, 2, bias=True, dtype=dtype, device=device)  # 2 experts
         )
 
     def forward(self, x):
@@ -62,21 +62,37 @@ def forward(self, x):
 
 class CrossAttentionPooling(nn.Module):
 
-    def __init__(self, dim, num_heads=16):
+    def __init__(self, dim, num_heads=16, dtype=None, device=None):
         super().__init__()
-        self.query_token = nn.Parameter(torch.randn(1, dim))  # [1, D]
-
-        self.attn1 = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)
-        self.norm1 = nn.LayerNorm(dim)
-
-        self.attn2 = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)
-        self.norm2 = nn.LayerNorm(dim)
-
-        self.attn3 = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)
-        self.norm3 = nn.LayerNorm(dim)
-
-        self.attn4 = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)
-        self.norm4 = nn.LayerNorm(dim)
+        self.query_token = nn.Parameter(torch.randn(1, dim, dtype=dtype, device=device))  # [1, D]
+
+        self.attn1 = nn.MultiheadAttention(embed_dim=dim,
+                                           num_heads=num_heads,
+                                           batch_first=True,
+                                           dtype=dtype,
+                                           device=device)
+        self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device)
+
+        self.attn2 = nn.MultiheadAttention(embed_dim=dim,
+                                           num_heads=num_heads,
+                                           batch_first=True,
+                                           dtype=dtype,
+                                           device=device)
+        self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)
+
+        self.attn3 = nn.MultiheadAttention(embed_dim=dim,
+                                           num_heads=num_heads,
+                                           batch_first=True,
+                                           dtype=dtype,
+                                           device=device)
+        self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device)
+
+        self.attn4 = nn.MultiheadAttention(embed_dim=dim,
+                                           num_heads=num_heads,
+                                           batch_first=True,
+                                           dtype=dtype,
+                                           device=device)
+        self.norm4 = nn.LayerNorm(dim, dtype=dtype, device=device)
 
     def forward(self, batched_tokens: list[torch.Tensor]):
         """
@@ -493,8 +509,10 @@ def __init__(self,
             nn.Linear(llm_hidden_size * 2, llm_hidden_size * 2, bias=True, dtype=dtype, device=device), nn.GELU(),
             nn.Identity(), nn.Linear(llm_hidden_size * 2, llm_hidden_size, bias=True, dtype=dtype, device=device))
 
-        self.pooling_before_gating = CrossAttentionPooling(dim=vit_hidden_size)
-        self.gating = Gating(hidden_size=vit_hidden_size)
+        self.pooling_before_gating = CrossAttentionPooling(dim=vit_hidden_size, dtype=dtype, device=device)
+        self.gating = Gating(hidden_size=vit_hidden_size, dtype=dtype, device=device)
+
+        self.model_metas = None
 
     def compile_model(self):
         torch_version = version.parse(torch.__version__)
@@ -688,46 +706,44 @@ def extract_and_compress(self, pixel_values: torch.Tensor, input_ids: torch.Tens
         vit_embeds = torch.cat(selected_embeds, dim=0)
 
         # compress visual tokens in sentence
-        lang_embeds, input_ids, image_mask, seq_lens = self.compress_visual_tokens_in_sentence(
+        new_lang_embeds, new_input_ids, new_image_mask, new_seq_lengths = self.compress_visual_tokens_in_sentence(
             input_embeds=lang_embeds,
             input_ids=input_ids,
             img_context_token_id=img_context_token_id,
             gate_result=gate_result,
         )
 
-        return vit_embeds, lang_embeds, input_ids, image_mask, seq_lens
+        return vit_embeds, new_lang_embeds, new_input_ids, new_image_mask, new_seq_lengths
 
-    def update_context(self, new_input_ids: torch.Tensor, new_seqlens: List[int]) -> StepContext:
+    def update_forward_inputs(self, input_ids: torch.Tensor, new_seqlens: List[int]) -> StepContext:
         """Update the current context with new input_ids."""
         from lmdeploy.pytorch.model_inputs import ModelInputs
 
         crt_ctx = self.ctx_mgr.current_context()
-        if crt_ctx is None:
-            raise RuntimeError('Cannot update a non-existent context.')
+        assert crt_ctx is not None, 'Current context cannot be None.'
 
-        device = new_input_ids.device
-        new_seqlens = torch.tensor(new_seqlens, device=device, dtype=torch.long)
+        # fill model metas
+        self.model_metas = [dict(new_seqlen=seqlen) for seqlen in new_seqlens]
 
         # create new model inputs
-        new_model_inputs = ModelInputs(input_ids=new_input_ids,
+        device = input_ids.device
+        total_msgs = len(new_seqlens)
+        new_seqlens = torch.tensor(new_seqlens, device=device, dtype=torch.long)
+        new_model_inputs = ModelInputs(input_ids=input_ids,
                                        seq_length=new_seqlens,
-                                       history_lengths=torch.tensor([0], device=device, dtype=torch.long),
+                                       history_lengths=torch.zeros(total_msgs, device=device, dtype=torch.long),
                                        block_offsets=crt_ctx.block_offsets,
                                        is_decoding=False,
-                                       num_ignored_history=torch.tensor([0], device=device, dtype=torch.long),
+                                       num_ignored_history=torch.zeros(total_msgs, device=device, dtype=torch.long),
                                        max_q_seqlen=new_seqlens.max().item(),
                                        max_kv_seqlen=new_seqlens.max().item(),
                                        sum_kv_seqlen=new_seqlens.sum().item(),
-                                       model_metas=[None])
+                                       model_metas=[None for _ in range(total_msgs)])
 
-        # build and set new context
-        # NOTE: we keep original block_offsets, vision_inputs and kv_caches
+        # build new context, to get new position_ids and attn_metadata
         new_ctx = self.ctx_mgr.build_context(new_model_inputs, crt_ctx.model_config)
-        new_ctx.vision_inputs = crt_ctx.vision_inputs
-        new_ctx.kv_caches = crt_ctx.kv_caches
-        self.ctx_mgr.set_context(new_ctx)
 
-        return new_ctx
+        return new_ctx.position_ids, new_ctx.attn_metadata
 
     def forward(
         self,
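Besides returning `(position_ids, attn_metadata)` instead of mutating the context, the rewrite sizes `history_lengths`, `num_ignored_history`, and `model_metas` per message rather than hard-coding a single-element batch, which is presumably the accuracy fix named in the commit title. A toy comparison of the two shapes:

```python
import torch

new_seqlens = [4, 2, 3]  # compressed lengths for a 3-message batch
total_msgs = len(new_seqlens)

# new: one entry per message, matching seq_length
history_lengths = torch.zeros(total_msgs, dtype=torch.long)  # tensor([0, 0, 0])
model_metas = [None for _ in range(total_msgs)]              # [None, None, None]

# old: single-element fields regardless of batch size
old_history_lengths = torch.tensor([0], dtype=torch.long)    # tensor([0])
old_model_metas = [None]

print(history_lengths.tolist(), len(model_metas))  # [0, 0, 0] 3
```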
@@ -751,15 +767,11 @@ def forward(
             lang_embeds = self.language_model.get_input_embeddings()(input_ids)
         else:
             # extract feature and compress visual tokens
-            vit_embeds, lang_embeds, new_input_ids, new_image_mask, new_seqlens = self.extract_and_compress(
+            vit_embeds, lang_embeds, input_ids, image_mask, new_seqlens = self.extract_and_compress(
                 pixel_values, input_ids, image_token_id)
-            input_ids = new_input_ids
-            image_mask = new_image_mask
 
-            # update context and relevant attributes
-            ctx = self.update_context(new_input_ids, new_seqlens)
-            position_ids = ctx.position_ids
-            attn_metadata = ctx.attn_metadata
+            # update forward inputs
+            position_ids, attn_metadata = self.update_forward_inputs(input_ids, new_seqlens)
 
             lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds)
 
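`extract_and_compress` now writes straight back into `input_ids`/`image_mask`, and the merge step is unchanged: visual embeddings are scattered into the language embedding at image-token positions. A toy runnable illustration of that `masked_scatter_` call:

```python
import torch

lang_embeds = torch.zeros(1, 4, 2)                       # [batch, tokens, dim]
image_mask = torch.tensor([[False, True, True, False]])  # image-token positions
vit_embeds = torch.ones(2, 2)                            # two visual tokens

# broadcast the mask over the hidden dim and scatter the visual embeddings in
lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds)
print(lang_embeds[0, :, 0].tolist())  # [0.0, 1.0, 1.0, 0.0]
```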

@@ -801,6 +813,20 @@ def prepare_inputs_for_generation(
         vision_embeddings = context.input_embeddings
         vision_embedding_indexing = None
 
+        if context.is_decoding and context.model_metas is not None and context.model_metas[0] is not None:
+            # model meta from the previous step, therefore +1 for the current decoding step
+            new_kv_seqlens = [(meta['new_seqlen'] + 1) for meta in context.model_metas]
+
+            # update model metas for the next step
+            self.model_metas = [dict(new_seqlen=seqlen) for seqlen in new_kv_seqlens]
+
+            # update position ids, attn_metadata
+            new_kv_seqlens = torch.tensor(new_kv_seqlens, device=input_ids.device, dtype=torch.long)
+            position_ids = new_kv_seqlens
+            attn_metadata.kv_seqlens = new_kv_seqlens
+            attn_metadata.cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(new_kv_seqlens, dim=0, dtype=torch.int32),
+                                                                 (1, 0))
+
         # vision inputs
         pixel_values = None
         image_mask = None
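During decoding each meta still holds the previous step's length, so the current step adds one token; the cumulative form feeds attention kernels that expect `cu_seqlens_k` offsets. A CPU-runnable sketch with made-up lengths:

```python
import torch

model_metas = [dict(new_seqlen=5), dict(new_seqlen=8)]  # from the previous step
new_kv_seqlens = [(meta['new_seqlen'] + 1) for meta in model_metas]  # +1 for this step

new_kv_seqlens = torch.tensor(new_kv_seqlens, dtype=torch.long)  # tensor([6, 9])
cu_seqlens_k = torch.nn.functional.pad(
    torch.cumsum(new_kv_seqlens, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens_k.tolist())  # [0, 6, 15]
```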
@@ -890,6 +916,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
 
         self.language_model.load_weights(new_weights.items())
 
+    def post_update_model_metas(self, model_metas):
+        """Post update model meta."""
+        new_model_metas = self.model_metas if self.model_metas is not None else model_metas
+        return new_model_metas
+
     def get_input_processor(self) -> BaseModelInputProcessor:
         """Get input processor."""
         return self.input_processor
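Putting the pieces together: the model records compressed lengths in `self.model_metas` during prefill (`update_forward_inputs`) and decode (`prepare_inputs_for_generation`), and `model_forward` picks them up through this hook. A hedged end-to-end sketch with a hypothetical `Model` class:

```python
class Model:
    """Hypothetical condensation of the meta handoff; not the real InternVL class."""

    def __init__(self):
        self.model_metas = None

    def forward(self, new_seqlens):
        # prefill with token compression: record the new lengths
        self.model_metas = [dict(new_seqlen=n) for n in new_seqlens]

    def post_update_model_metas(self, model_metas):
        # called after each forward step; prefer the model's own metas
        return self.model_metas if self.model_metas is not None else model_metas


m = Model()
print(m.post_update_model_metas([None, None]))  # [None, None], nothing recorded yet
m.forward([3, 5])
print(m.post_update_model_metas([None, None]))  # [{'new_seqlen': 3}, {'new_seqlen': 5}]
```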
