@@ -24,19 +24,19 @@
 
 class Gating(nn.Module):
 
-    def __init__(self, hidden_size=2048, expansion_factor=4):
+    def __init__(self, hidden_size=2048, expansion_factor=4, dtype=None, device=None):
         super().__init__()
 
         mid_dim = hidden_size * expansion_factor
 
         def mlp_block(in_dim, out_dim):
             return nn.Sequential(
-                nn.Linear(in_dim, out_dim, bias=True),
+                nn.Linear(in_dim, out_dim, bias=True, dtype=dtype, device=device),
                 nn.GELU(),
                 nn.Identity(),
-                nn.Linear(out_dim, in_dim, bias=True),
+                nn.Linear(out_dim, in_dim, bias=True, dtype=dtype, device=device),
                 nn.Identity(),
-                nn.LayerNorm(in_dim),
+                nn.LayerNorm(in_dim, dtype=dtype, device=device),
             )
 
         self.block1 = mlp_block(hidden_size, mid_dim)
@@ -45,8 +45,8 @@ def mlp_block(in_dim, out_dim):
         self.block4 = mlp_block(hidden_size, mid_dim)
 
         self.gate = nn.Sequential(
-            nn.LayerNorm(hidden_size),
-            nn.Linear(hidden_size, 2, bias=True)  # 2 experts
+            nn.LayerNorm(hidden_size, dtype=dtype, device=device),
+            nn.Linear(hidden_size, 2, bias=True, dtype=dtype, device=device)  # 2 experts
         )
 
     def forward(self, x):
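
The `dtype=None, device=None` threading in this hunk follows PyTorch's factory-kwargs convention: passing them at construction time materializes the parameters directly in the target precision and placement, instead of allocating fp32 weights on the CPU and casting with `.to()` afterwards. A minimal usage sketch (sizes and device are illustrative, not from this PR):

```python
import torch

# Build the gating module straight into bf16 on the GPU; with the old
# signature the weights would first be created as fp32 on the CPU.
gating = Gating(hidden_size=2048, expansion_factor=4,
                dtype=torch.bfloat16, device='cuda')

x = torch.randn(4, 2048, dtype=torch.bfloat16, device='cuda')
logits = gating.gate(x)  # [4, 2] logits, one per expert
```
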
@@ -62,21 +62,37 @@ def forward(self, x):
 
 class CrossAttentionPooling(nn.Module):
 
-    def __init__(self, dim, num_heads=16):
+    def __init__(self, dim, num_heads=16, dtype=None, device=None):
         super().__init__()
-        self.query_token = nn.Parameter(torch.randn(1, dim))  # [1, D]
-
-        self.attn1 = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)
-        self.norm1 = nn.LayerNorm(dim)
-
-        self.attn2 = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)
-        self.norm2 = nn.LayerNorm(dim)
-
-        self.attn3 = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)
-        self.norm3 = nn.LayerNorm(dim)
-
-        self.attn4 = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)
-        self.norm4 = nn.LayerNorm(dim)
+        self.query_token = nn.Parameter(torch.randn(1, dim, dtype=dtype, device=device))  # [1, D]
+
+        self.attn1 = nn.MultiheadAttention(embed_dim=dim,
+                                           num_heads=num_heads,
+                                           batch_first=True,
+                                           dtype=dtype,
+                                           device=device)
+        self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device)
+
+        self.attn2 = nn.MultiheadAttention(embed_dim=dim,
+                                           num_heads=num_heads,
+                                           batch_first=True,
+                                           dtype=dtype,
+                                           device=device)
+        self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)
+
+        self.attn3 = nn.MultiheadAttention(embed_dim=dim,
+                                           num_heads=num_heads,
+                                           batch_first=True,
+                                           dtype=dtype,
+                                           device=device)
+        self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device)
+
+        self.attn4 = nn.MultiheadAttention(embed_dim=dim,
+                                           num_heads=num_heads,
+                                           batch_first=True,
+                                           dtype=dtype,
+                                           device=device)
+        self.norm4 = nn.LayerNorm(dim, dtype=dtype, device=device)
 
     def forward(self, batched_tokens: list[torch.Tensor]):
         """
@@ -493,8 +509,10 @@ def __init__(self,
             nn.Linear(llm_hidden_size * 2, llm_hidden_size * 2, bias=True, dtype=dtype, device=device), nn.GELU(),
             nn.Identity(), nn.Linear(llm_hidden_size * 2, llm_hidden_size, bias=True, dtype=dtype, device=device))
 
-        self.pooling_before_gating = CrossAttentionPooling(dim=vit_hidden_size)
-        self.gating = Gating(hidden_size=vit_hidden_size)
+        self.pooling_before_gating = CrossAttentionPooling(dim=vit_hidden_size, dtype=dtype, device=device)
+        self.gating = Gating(hidden_size=vit_hidden_size, dtype=dtype, device=device)
+
+        self.model_metas = None
 
     def compile_model(self):
         torch_version = version.parse(torch.__version__)
@@ -688,46 +706,44 @@ def extract_and_compress(self, pixel_values: torch.Tensor, input_ids: torch.Tens
         vit_embeds = torch.cat(selected_embeds, dim=0)
 
         # compress visual tokens in sentence
-        lang_embeds, input_ids, image_mask, seq_lens = self.compress_visual_tokens_in_sentence(
+        new_lang_embeds, new_input_ids, new_image_mask, new_seq_lengths = self.compress_visual_tokens_in_sentence(
             input_embeds=lang_embeds,
             input_ids=input_ids,
             img_context_token_id=img_context_token_id,
             gate_result=gate_result,
         )
 
-        return vit_embeds, lang_embeds, input_ids, image_mask, seq_lens
+        return vit_embeds, new_lang_embeds, new_input_ids, new_image_mask, new_seq_lengths
 
-    def update_context(self, new_input_ids: torch.Tensor, new_seqlens: List[int]) -> StepContext:
+    def update_forward_inputs(self, input_ids: torch.Tensor, new_seqlens: List[int]):
         """Update the current context with new input_ids."""
         from lmdeploy.pytorch.model_inputs import ModelInputs
 
         crt_ctx = self.ctx_mgr.current_context()
-        if crt_ctx is None:
-            raise RuntimeError('Cannot update a non-existent context.')
+        assert crt_ctx is not None, 'Current context cannot be None.'
 
-        device = new_input_ids.device
-        new_seqlens = torch.tensor(new_seqlens, device=device, dtype=torch.long)
+        # fill model metas
+        self.model_metas = [dict(new_seqlen=seqlen) for seqlen in new_seqlens]
 
         # create new model inputs
-        new_model_inputs = ModelInputs(input_ids=new_input_ids,
+        device = input_ids.device
+        total_msgs = len(new_seqlens)
+        new_seqlens = torch.tensor(new_seqlens, device=device, dtype=torch.long)
+        new_model_inputs = ModelInputs(input_ids=input_ids,
                                        seq_length=new_seqlens,
-                                       history_lengths=torch.tensor([0], device=device, dtype=torch.long),
+                                       history_lengths=torch.zeros(total_msgs, device=device, dtype=torch.long),
                                        block_offsets=crt_ctx.block_offsets,
                                        is_decoding=False,
-                                       num_ignored_history=torch.tensor([0], device=device, dtype=torch.long),
+                                       num_ignored_history=torch.zeros(total_msgs, device=device, dtype=torch.long),
                                        max_q_seqlen=new_seqlens.max().item(),
                                        max_kv_seqlen=new_seqlens.max().item(),
                                        sum_kv_seqlen=new_seqlens.sum().item(),
-                                       model_metas=[None])
+                                       model_metas=[None for _ in range(total_msgs)])
 
-        # build and set new context
-        # NOTE: we keep original block_offsets, vision_inputs and kv_caches
+        # build new context, to get new position_ids and attn_metadata
         new_ctx = self.ctx_mgr.build_context(new_model_inputs, crt_ctx.model_config)
-        new_ctx.vision_inputs = crt_ctx.vision_inputs
-        new_ctx.kv_caches = crt_ctx.kv_caches
-        self.ctx_mgr.set_context(new_ctx)
 
-        return new_ctx
+        return new_ctx.position_ids, new_ctx.attn_metadata
 
     def forward(
         self,
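
Why the inputs must be rebuilt at all: after visual-token compression every sequence in the batch is shorter, so position ids, history lengths, and cumulative KV lengths computed for the original prompt no longer line up with the embeddings. A toy sketch of the per-sequence bookkeeping (hypothetical lengths, not lmdeploy's actual API):

```python
import torch

# Three sequences whose lengths changed after compression.
new_seqlens = [7, 4, 5]
total_msgs = len(new_seqlens)

# One entry per sequence rather than a single [0]: the old
# torch.tensor([0]) was only correct for batch size 1.
history_lengths = torch.zeros(total_msgs, dtype=torch.long)

# Position ids restart at 0 per sequence for the re-prefilled prompt.
position_ids = torch.cat([torch.arange(n) for n in new_seqlens])
```
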
@@ -751,15 +767,11 @@ def forward(
             lang_embeds = self.language_model.get_input_embeddings()(input_ids)
         else:
             # extract feature and compress visual tokens
-            vit_embeds, lang_embeds, new_input_ids, new_image_mask, new_seqlens = self.extract_and_compress(
+            vit_embeds, lang_embeds, input_ids, image_mask, new_seqlens = self.extract_and_compress(
                 pixel_values, input_ids, image_token_id)
-            input_ids = new_input_ids
-            image_mask = new_image_mask
 
-            # update context and relevant attributes
-            ctx = self.update_context(new_input_ids, new_seqlens)
-            position_ids = ctx.position_ids
-            attn_metadata = ctx.attn_metadata
+            # update forward inputs
+            position_ids, attn_metadata = self.update_forward_inputs(input_ids, new_seqlens)
 
         lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds)
 
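
The `masked_scatter_` line that closes this hunk is where the compressed visual embeddings are spliced back into the text embedding sequence. A toy illustration with made-up shapes:

```python
import torch

lang_embeds = torch.zeros(1, 5, 4)      # [B, N, D] text embeddings
image_mask = torch.tensor([[False, True, True, False, False]])
vit_embeds = torch.ones(2, 4)           # one D-dim row per True position

# Rows flagged by the mask are overwritten, in order, by vit_embeds.
lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds)
```
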
@@ -801,6 +813,20 @@ def prepare_inputs_for_generation(
         vision_embeddings = context.input_embeddings
         vision_embedding_indexing = None
 
+        if context.is_decoding and context.model_metas is not None and context.model_metas[0] is not None:
+            # model meta from the previous step, therefore +1 for the current decoding step
+            new_kv_seqlens = [(meta['new_seqlen'] + 1) for meta in context.model_metas]
+
+            # update model metas for the next step
+            self.model_metas = [dict(new_seqlen=seqlen) for seqlen in new_kv_seqlens]
+
+            # update position ids, attn_metadata
+            new_kv_seqlens = torch.tensor(new_kv_seqlens, device=input_ids.device, dtype=torch.long)
+            position_ids = new_kv_seqlens
+            attn_metadata.kv_seqlens = new_kv_seqlens
+            attn_metadata.cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(new_kv_seqlens, dim=0, dtype=torch.int32),
+                                                                 (1, 0))
+
         # vision inputs
         pixel_values = None
         image_mask = None
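
The `cu_seqlens_k` update above uses the usual packed-attention layout: a cumulative sum of per-sequence KV lengths with a leading zero, so that sequence `i` occupies the slice `cu_seqlens_k[i]:cu_seqlens_k[i + 1]`. A quick standalone check of the pad-then-cumsum expression:

```python
import torch
import torch.nn.functional as F

kv_seqlens = torch.tensor([5, 3, 7])
cu_seqlens_k = F.pad(torch.cumsum(kv_seqlens, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens_k)  # tensor([ 0,  5,  8, 15], dtype=torch.int32)
```
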
@@ -890,6 +916,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
 
         self.language_model.load_weights(new_weights.items())
 
+    def post_update_model_metas(self, model_metas):
+        """Post update model meta."""
+        new_model_metas = self.model_metas if self.model_metas is not None else model_metas
+        return new_model_metas
+
     def get_input_processor(self) -> BaseModelInputProcessor:
         """Get input processor."""
         return self.input_processor
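
Taken together, the metas added in this PR make a round trip on every step: `update_forward_inputs` records the compressed length of each sequence, the engine collects it via `post_update_model_metas`, and the next decode step's `prepare_inputs_for_generation` reads it back and adds one for the newly generated token. A hedged trace of that loop (engine plumbing elided, values invented):

```python
# Step 0 (prefill): compression shrinks a prompt to 128 tokens.
model_metas = [dict(new_seqlen=128)]          # stored in self.model_metas

# Steps 1..n (decode): each step extends the KV cache by one token.
for _ in range(3):
    model_metas = [dict(new_seqlen=m['new_seqlen'] + 1) for m in model_metas]

assert model_metas[0]['new_seqlen'] == 131
```
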