Skip to content

Commit c4e242c

Browse files
author
Jeremy Teboul
committed
feat: Support multiple image/audio embeddings per request in Chat Completions API
- Remove limitation that only allowed one message with image_embeds/audio_embeds
- Update MultiModalItemTracker and AsyncMultiModalItemTracker to treat embeddings as lists
- Add unit tests for multiple image embeddings support:
  * test_parse_chat_messages_multiple_image_embeds
  * test_parse_chat_messages_multiple_image_embeds_with_uuids
  * test_parse_chat_messages_multiple_image_embeds_async
- Embeddings now behave consistently with regular images/audios
- Validation via existing validate_num_items() against --limit-mm-per-prompt
- Backward compatible with single embeddings

This enables the Chat Completions API to leverage the model's existing capability for multiple embeddings, previously only accessible through the direct LLM inference API.
1 parent d7284a2 commit c4e242c

File tree

2 files changed

+204
-18
lines changed

2 files changed

+204
-18
lines changed

tests/entrypoints/test_chat_utils.py

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

4+
import base64
5+
import io
46
import warnings
57
from collections.abc import Mapping
68
from typing import Literal
79

810
import pytest
11+
import torch
912
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
1013

1114
from vllm.assets.audio import AudioAsset
@@ -987,6 +990,203 @@ async def test_parse_chat_messages_audio_embeds_async(
987990
_assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[None])
988991

989992

993+
def test_parse_chat_messages_multiple_image_embeds(
    phi3v_model_config_image_embeds,
    phi3v_tokenizer,
):
    """Verify that one message may carry several image_embeds parts.

    Previously only a single {'type': 'image_embeds'} part was accepted
    per request; this exercises the relaxed behavior where multiple
    image embeddings are handled like ordinary images.
    """
    # Two sample embedding tensors of different lengths.
    embeddings = [torch.randn(256, 1024), torch.randn(128, 1024)]

    def to_base64(tensor):
        # Serialize with torch.save into an in-memory buffer and wrap the
        # raw bytes in base64, matching how clients ship embeddings.
        sink = io.BytesIO()
        torch.save(tensor, sink)
        return base64.b64encode(sink.getvalue()).decode("utf-8")

    embed_parts = [
        {"type": "image_embeds", "image_embeds": to_base64(emb)}
        for emb in embeddings
    ]

    conversation, mm_data, mm_uuids = parse_chat_messages(
        [
            {
                "role": "user",
                "content": embed_parts
                + [{"type": "text", "text": "Describe these two images."}],
            }
        ],
        phi3v_model_config_image_embeds,
        phi3v_tokenizer,
        content_format="string",
    )

    # Both embeds get their own image placeholder in the rendered prompt.
    assert conversation == [
        {
            "role": "user",
            "content": "<|image_1|>\n<|image_2|>\nDescribe these two images.",
        }
    ]

    # The tracker should surface a list of tensors, one per embed part
    # (not a single bare tensor as in the old single-embed behavior).
    assert mm_data is not None
    assert "image" in mm_data
    assert isinstance(mm_data["image"], list)
    assert len(mm_data["image"]) == 2

    # Round-tripped tensors keep their original shapes, in order.
    for parsed, original in zip(mm_data["image"], embeddings):
        assert isinstance(parsed, torch.Tensor)
        assert parsed.shape == original.shape

    # No UUIDs were supplied, so both entries are None.
    _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None])
1062+
1063+
1064+
def test_parse_chat_messages_multiple_image_embeds_with_uuids(
    phi3v_model_config_image_embeds,
    phi3v_tokenizer,
):
    """Check UUID tracking across several image_embeds parts.

    Each part carries a None payload plus a uuid, so parsing should
    yield placeholder None entries in mm_data and the supplied UUIDs
    in the order the parts appeared.
    """
    expected_uuids = ["image-uuid-1", "image-uuid-2"]

    content = [
        {"type": "image_embeds", "image_embeds": None, "uuid": item_uuid}
        for item_uuid in expected_uuids
    ]
    content.append({"type": "text", "text": "Compare these images."})

    conversation, mm_data, mm_uuids = parse_chat_messages(
        [{"role": "user", "content": content}],
        phi3v_model_config_image_embeds,
        phi3v_tokenizer,
        content_format="string",
    )

    # Both embeds get their own image placeholder in the rendered prompt.
    assert conversation == [
        {
            "role": "user",
            "content": "<|image_1|>\n<|image_2|>\nCompare these images.",
        }
    ]

    # Payloads were omitted, so mm_data holds None placeholders that are
    # resolved elsewhere via the UUIDs.
    assert mm_data is not None
    assert "image" in mm_data
    assert isinstance(mm_data["image"], list)
    assert len(mm_data["image"]) == 2
    assert all(entry is None for entry in mm_data["image"])

    # UUIDs must come back in part order.
    _assert_mm_uuids(mm_uuids, 2, expected_uuids=expected_uuids)
1117+
1118+
1119+
@pytest.mark.asyncio
async def test_parse_chat_messages_multiple_image_embeds_async(
    phi3v_model_config_image_embeds,
    phi3v_tokenizer,
):
    """Ensure the async tracker also accepts multiple image_embeds.

    Mirrors the sync test but goes through parse_chat_messages_futures,
    so mm_data arrives via an awaitable.
    """
    tensors = [torch.randn(200, 768), torch.randn(150, 768)]

    def b64_of(tensor):
        # torch.save into an in-memory buffer, then base64-encode the bytes.
        sink = io.BytesIO()
        torch.save(tensor, sink)
        return base64.b64encode(sink.getvalue()).decode("utf-8")

    message = {
        "role": "user",
        "content": [
            {"type": "image_embeds", "image_embeds": b64_of(t)} for t in tensors
        ]
        + [{"type": "text", "text": "What do these images show?"}],
    }

    conversation, mm_future, mm_uuids = parse_chat_messages_futures(
        [message],
        phi3v_model_config_image_embeds,
        phi3v_tokenizer,
        content_format="string",
    )

    # Both embeds get their own image placeholder in the rendered prompt.
    assert conversation == [
        {
            "role": "user",
            "content": "<|image_1|>\n<|image_2|>\nWhat do these images show?",
        }
    ]

    # The async tracker hands back a future; resolving it should give a
    # list with one tensor per embed part.
    mm_data = await mm_future
    assert mm_data is not None
    assert "image" in mm_data
    assert isinstance(mm_data["image"], list)
    assert len(mm_data["image"]) == 2

    # Round-tripped tensors keep their original shapes, in order.
    for got, want in zip(mm_data["image"], tensors):
        assert isinstance(got, torch.Tensor)
        assert got.shape == want.shape

    # No UUIDs were supplied.
    _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None])
1188+
1189+
9901190
@pytest.mark.asyncio
9911191
async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
9921192
phi3v_model_config_image_embeds,

vllm/entrypoints/chat_utils.py

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -695,16 +695,10 @@ def all_mm_uuids(self) -> MultiModalUUIDDict | None:
695695
raise ValueError("Mixing raw image and embedding inputs is not allowed")
696696

697697
if "image_embeds" in uuids_by_modality:
698-
image_embeds_uuids = uuids_by_modality["image_embeds"]
699-
if len(image_embeds_uuids) > 1:
700-
raise ValueError("Only one message can have {'type': 'image_embeds'}")
701698
mm_uuids["image"] = uuids_by_modality["image_embeds"]
702699
if "image" in uuids_by_modality:
703700
mm_uuids["image"] = uuids_by_modality["image"] # UUIDs of images
704701
if "audio_embeds" in uuids_by_modality:
705-
audio_embeds_uuids = uuids_by_modality["audio_embeds"]
706-
if len(audio_embeds_uuids) > 1:
707-
raise ValueError("Only one message can have {'type': 'audio_embeds'}")
708702
mm_uuids["audio"] = uuids_by_modality["audio_embeds"]
709703
if "audio" in uuids_by_modality:
710704
mm_uuids["audio"] = uuids_by_modality["audio"] # UUIDs of audios
@@ -730,16 +724,12 @@ def all_mm_data(self) -> MultiModalDataDict | None:
730724

731725
if "image_embeds" in items_by_modality:
732726
image_embeds_lst = items_by_modality["image_embeds"]
733-
if len(image_embeds_lst) > 1:
734-
raise ValueError("Only one message can have {'type': 'image_embeds'}")
735-
mm_inputs["image"] = image_embeds_lst[0]
727+
mm_inputs["image"] = image_embeds_lst
736728
if "image" in items_by_modality:
737729
mm_inputs["image"] = items_by_modality["image"] # A list of images
738730
if "audio_embeds" in items_by_modality:
739731
audio_embeds_lst = items_by_modality["audio_embeds"]
740-
if len(audio_embeds_lst) > 1:
741-
raise ValueError("Only one message can have {'type': 'audio_embeds'}")
742-
mm_inputs["audio"] = audio_embeds_lst[0]
732+
mm_inputs["audio"] = audio_embeds_lst
743733
if "audio" in items_by_modality:
744734
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
745735
if "video" in items_by_modality:
@@ -772,16 +762,12 @@ async def all_mm_data(self) -> MultiModalDataDict | None:
772762

773763
if "image_embeds" in items_by_modality:
774764
image_embeds_lst = items_by_modality["image_embeds"]
775-
if len(image_embeds_lst) > 1:
776-
raise ValueError("Only one message can have {'type': 'image_embeds'}")
777-
mm_inputs["image"] = image_embeds_lst[0]
765+
mm_inputs["image"] = image_embeds_lst
778766
if "image" in items_by_modality:
779767
mm_inputs["image"] = items_by_modality["image"] # A list of images
780768
if "audio_embeds" in items_by_modality:
781769
audio_embeds_lst = items_by_modality["audio_embeds"]
782-
if len(audio_embeds_lst) > 1:
783-
raise ValueError("Only one message can have {'type': 'audio_embeds'}")
784-
mm_inputs["audio"] = audio_embeds_lst[0]
770+
mm_inputs["audio"] = audio_embeds_lst
785771
if "audio" in items_by_modality:
786772
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
787773
if "video" in items_by_modality:

0 commit comments

Comments
 (0)