
Commit 524b402

update processor

1 parent 28b62ea

File tree

4 files changed: 49 additions & 80 deletions

lmdeploy/pytorch/engine/engine_instance.py

Lines changed: 15 additions & 16 deletions
@@ -2,7 +2,7 @@
 from typing import List
 
 from lmdeploy.messages import EngineOutput, GenerationConfig
-from lmdeploy.pytorch.multimodal import MultiModalData
+from lmdeploy.pytorch.multimodal.data_type import MultiModalInputs
 from lmdeploy.utils import get_logger
 
 from ..messages import SamplingParam
@@ -125,14 +125,13 @@ def _try_add_session(self, session_id: int):
         """
         return try_add_session(self.req_sender, session_id)
 
-    async def async_stream_infer(
-            self,
-            session_id: int,
-            input_ids: List[int],
-            gen_config: GenerationConfig = None,
-            input_multimodals: List[MultiModalData] = None,
-            adapter_name: str = None,
-            **kwargs):
+    async def async_stream_infer(self,
+                                 session_id: int,
+                                 input_ids: List[int],
+                                 gen_config: GenerationConfig = None,
+                                 input_multimodals: MultiModalInputs = None,
+                                 adapter_name: str = None,
+                                 **kwargs):
         """Send stream inference request.
 
         Args:
@@ -184,7 +183,7 @@ async def async_stream_infer(
     async def async_infer(self,
                           session_id: int,
                           input_ids: List[int] = None,
-                          input_multimodals: List[MultiModalData] = None,
+                          input_multimodals: MultiModalInputs = None,
                           gen_config: GenerationConfig = None,
                           **kwargs):
         """Send inference request.
@@ -216,7 +215,7 @@ async def async_infer(self,
     def stream_infer(self,
                      session_id: int,
                      input_ids: List[int],
-                     input_multimodals: List[MultiModalData] = None,
+                     input_multimodals: MultiModalInputs = None,
                      gen_config: GenerationConfig = None,
                      adapter_name: str = None,
                      **kwargs):
@@ -286,7 +285,7 @@ def __call_async():
     def infer(self,
               session_id: int,
               input_ids: List[int] = None,
-              input_multimodals: List[MultiModalData] = None,
+              input_multimodals: MultiModalInputs = None,
              gen_config: GenerationConfig = None,
              **kwargs):
         """Send inference request.
@@ -318,7 +317,7 @@ async def async_batched_infer(
            self,
            session_ids: List[int],
            token_ids: List[List[int]] = None,
-           input_multimodals: List[List[MultiModalData]] = None,
+           input_multimodals: List[MultiModalInputs] = None,
            gen_config: GenerationConfig = None,
            adapter_names: List[str] = None,
            keep_cache: bool = False,
@@ -407,7 +406,7 @@ def batched_infer(
            self,
            session_ids: List[int],
            token_ids: List[List[int]] = None,
-           input_multimodals: List[List[MultiModalData]] = None,
+           input_multimodals: List[MultiModalInputs] = None,
            gen_config: GenerationConfig = None,
            adapter_names: List[str] = None,
            keep_cache: bool = False,
@@ -439,7 +438,7 @@ def cancel(self, session_id: int):
 
     def decode(self,
                input_ids,
-               input_multimodals: List[List[MultiModalData]] = None,
+               input_multimodals: List[MultiModalInputs] = None,
                steps: List[int] = None,
                sequence_start: bool = True,
                sequence_end: bool = True,
@@ -449,7 +448,7 @@ def decode(self,
         Args:
             input_ids (numpy.ndarray): the batch of input token ids
            steps (List[int]): the offset of the k/v cache
-            input_multimodals (List[List[MultiModalData]]):
+            input_multimodals (List[MultiModalInputs]):
                multimodals inputs.
            sequence_start (bool): indicator for starting a sequence
            sequence_end (bool): indicator for ending a sequence
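
Taken together, these signature changes replace the flat List[MultiModalData] argument with a MultiModalInputs mapping. Below is a minimal sketch of the new argument shape, assuming MultiModalInputs is a dict from modality name to a list of MultiModalTensor items (which is how the model diffs in this commit consume it); the tensor shape and the commented call sites are illustrative, not taken from this commit.

import torch

from lmdeploy.pytorch.multimodal.data_type import (MultiModalInputs,
                                                   MultiModalTensor)

# One image anchored at token position 5 of the prompt; the pixel tensor is a
# stand-in, real values come from the model's image preprocessing.
pixel_values = torch.rand(4, 3, 336, 336)
mm_inputs: MultiModalInputs = dict(
    image=[MultiModalTensor(data=pixel_values, start=5)])

# Per-session APIs now take one MultiModalInputs; the batched APIs take one
# per request (hence List[MultiModalInputs] in their signatures), e.g.:
#   engine_instance.infer(session_id, input_ids, input_multimodals=mm_inputs)
#   engine_instance.batched_infer(session_ids, token_ids,
#                                 input_multimodals=[mm_inputs])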

lmdeploy/pytorch/models/mllama.py

Lines changed: 14 additions & 27 deletions
@@ -1500,35 +1500,22 @@ def __init__(self, config: LlamaConfig, dtype: torch.dtype) -> None:
     def preprocess_input(self, input_ids, input_multimodals: MultiModalInputs,
                          **kwargs):
         """prepare multimodal input."""
-        from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor
-
-        multimodals_dict = dict()
-        multimodals_dict['image'] = []
-
-        input_multimodals = sorted(input_multimodals, key=lambda mm: mm.loc)
-
-        for input_mm in input_multimodals:
-            image = input_mm.data
-            start = input_mm.loc
-            size = image.size
-            if any([s < 3 for s in size]):
-                image = image.resize([s * 3 for s in size])
-            image_inputs = self.processor.image_processor(images=image,
-                                                          return_tensors='pt')
-            pixel_values = image_inputs['pixel_values'].to(self.dtype)
-            aspect_ratio_ids = image_inputs['aspect_ratio_ids']
-            aspect_ratio_mask = image_inputs['aspect_ratio_mask']
-            mm_tensor = MultiModalTensor(
-                data=pixel_values,
-                start=start,
-                end=start + 1,
-                encoder_len=self.encoder_len,
-                meta=dict(aspect_ratio_ids=aspect_ratio_ids,
-                          aspect_ratio_mask=aspect_ratio_mask))
-            multimodals_dict['image'].append(mm_tensor)
+        if input_multimodals is None:
+            return input_ids, input_multimodals
+
+        input_imgs = input_multimodals.get('image', None)
+        if input_imgs is None:
+            return input_ids, input_multimodals
+
+        input_imgs = sorted(input_imgs, key=lambda mm: mm.start)
+
+        for img in input_imgs:
+            img.data = img.data.to(self.dtype)
+            img.end = img.start + 1
+            img.encoder_len = self.encoder_len
 
         result = PreprocessInputResult(
             input_ids=input_ids,
-            input_multimodals=multimodals_dict,
+            input_multimodals=dict(image=input_imgs),
         )
         return result
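
This change moves the Hugging Face image preprocessing out of the model: preprocess_input now assumes it receives ready MultiModalTensor items and only casts their dtype and stamps end/encoder_len. A hedged sketch of the work that moves upstream, reusing the calls from the deleted lines; build_image_input and its caller are hypothetical names for illustration, not part of this commit.

from PIL import Image

from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor


def build_image_input(processor, image: Image.Image,
                      loc: int) -> MultiModalTensor:
    # Run the HF image processor once, outside the model; this mirrors the
    # deleted mllama code, including its guard against tiny images.
    size = image.size
    if any(s < 3 for s in size):
        image = image.resize([s * 3 for s in size])
    image_inputs = processor.image_processor(images=image, return_tensors='pt')
    return MultiModalTensor(
        data=image_inputs['pixel_values'],
        start=loc,  # token position of the image placeholder
        meta=dict(aspect_ratio_ids=image_inputs['aspect_ratio_ids'],
                  aspect_ratio_mask=image_inputs['aspect_ratio_mask']))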

lmdeploy/pytorch/models/qwen2_vl.py

Lines changed: 15 additions & 36 deletions
@@ -672,9 +672,6 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
         return self.merger(hidden_states)
 
 
-OPTIONAL_KEYS = ['resized_height', 'resized_width', 'min_pixels', 'max_pixels']
-
-
 class Qwen2VLForConditionalGeneration(nn.Module, DeployModelMixin,
                                       CudaGraphMixin):
     """ModelForCausalLM."""
@@ -1004,52 +1001,34 @@ def __init__(self, config: PretrainedConfig) -> None:
 
     def preprocess_input(self,
                          input_ids: List[int],
-                         input_mms: MultiModalInputs = None,
+                         input_multimodals: MultiModalInputs = None,
                          **kwargs) -> PreprocessInputResult:
         """prepare multimodal input."""
-        from qwen_vl_utils import process_vision_info
-
-        from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor
-        global OPTIONAL_KEYS
+        if input_multimodals is None:
+            return input_ids, input_multimodals
 
-        multimodals_dict = dict()
-        multimodals_dict['image'] = []
+        input_imgs = input_multimodals.get('image', None)
+        if input_imgs is None:
+            return input_ids, input_multimodals
 
-        input_mms = sorted(input_mms, key=lambda mm: mm.loc)
+        input_imgs = sorted(input_imgs, key=lambda mm: mm.start)
 
         cum_pad = 0
         image_token_id = self.config.image_token_id
 
-        # image
-        for in_mm in input_mms:
-            image = in_mm.data
-            param = in_mm.meta
-            param = dict() if param is None else param
-            item = dict(type='image', image=image)
-            item.update({k: param[k] for k in OPTIONAL_KEYS if k in param})
-            messages = [dict(content=[item])]
-            image_inputs, _ = process_vision_info(messages)
-            image_inputs = self.processor.image_processor(images=image_inputs,
-                                                          videos=None,
-                                                          return_tensors='pt')
-            pixel_values = image_inputs['pixel_values']
-            image_grid_thw = image_inputs['image_grid_thw']
+        for img in input_imgs:
+            pixel_values = img.data
             pad_size = pixel_values.size(0) // 4
-            loc = in_mm.loc
-            start = loc + cum_pad
+            start = img.start + cum_pad
             end = start + pad_size
             cum_pad += pad_size
-            input_ids = input_ids[:start] + [image_token_id
-                                             ] * pad_size + input_ids[start:]
-
-            mm_tensor = MultiModalTensor(data=pixel_values,
-                                         start=start,
-                                         end=end,
-                                         meta=dict(grid_thw=image_grid_thw))
-            multimodals_dict['image'].append(mm_tensor)
+            input_ids = (input_ids[:start] + [image_token_id] * pad_size +
+                         input_ids[start:])
+            img.start = start
+            img.end = end
 
         result = PreprocessInputResult(
             input_ids=input_ids,
-            input_multimodals=multimodals_dict,
+            input_multimodals=dict(image=input_imgs),
        )
        return result
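
The rewritten qwen2_vl loop keeps the placeholder-insertion arithmetic but drops the in-model process_vision_info call. The sketch below isolates that arithmetic with invented numbers: each image expands to pixel_values.size(0) // 4 image tokens, and cum_pad shifts every later image right by the tokens already inserted.

# Standalone illustration of the cum_pad bookkeeping; token ids, positions and
# patch counts are made up for the example.
input_ids = [101, 102, 103, 104]
image_token_id = 99
images = [dict(start=1, n_patches=8), dict(start=3, n_patches=8)]

cum_pad = 0
for img in images:
    pad_size = img['n_patches'] // 4     # 2 placeholder tokens per image here
    start = img['start'] + cum_pad       # shift by tokens inserted so far
    input_ids = (input_ids[:start] + [image_token_id] * pad_size +
                 input_ids[start:])
    img['start'], img['end'] = start, start + pad_size
    cum_pad += pad_size

print(input_ids)  # [101, 99, 99, 102, 103, 99, 99, 104]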

lmdeploy/pytorch/multimodal/data_type.py

Lines changed: 5 additions & 1 deletion
@@ -18,10 +18,14 @@ class MultiModalData:
 class MultiModalTensor:
     data: NestedTensor
     start: int
-    end: int
+    end: int = None
     encoder_len: int = None
     meta: Dict[str, Any] = None
 
+    def __post_init__(self):
+        if self.end is None:
+            self.end = self.start
+
     def to_device(self, device: str, non_blocking: bool = False):
         """to device."""
         if isinstance(self.data, Tensor):
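
This dataclass tweak makes end optional, so callers such as the rewritten preprocess_input hooks above can construct a tensor from just start and patch end up later. A self-contained replica of the behaviour, with data loosened to Any so the snippet runs without the real NestedTensor type:

from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class MultiModalTensor:
    data: Any                   # NestedTensor in the real module
    start: int
    end: int = None
    encoder_len: int = None
    meta: Dict[str, Any] = None

    def __post_init__(self):
        # An omitted `end` collapses to a zero-length span at `start`.
        if self.end is None:
            self.end = self.start


mm = MultiModalTensor(data=None, start=7)
assert mm.end == 7  # defaulted by __post_init__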
