
Commit dbc789a
Author: Jeremy Teboul
Support multiple image/audio embeddings per request
This enables the Chat Completions API to leverage the model's existing capability for multiple embeddings:

- Remove the limitation that only allowed one message with image_embeds/audio_embeds
- Update MultiModalItemTracker and AsyncMultiModalItemTracker to treat embeddings as lists
- Add unit tests for multiple image embeddings support:
  * test_parse_chat_messages_multiple_image_embeds
  * test_parse_chat_messages_multiple_image_embeds_with_uuids
  * test_parse_chat_messages_multiple_image_embeds_async
- Embeddings now behave consistently with regular images/audios
- Validation is handled by the existing validate_num_items() check against --limit-mm-per-prompt
Parent: 48a5fff

File tree: 3 files changed, +204 −20 lines

docs/features/multimodal_inputs.md

Lines changed: 8 additions & 2 deletions

````diff
@@ -445,7 +445,7 @@ For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embeddings
 
 For Qwen3-VL, the `image_embeds` should contain both the base image embedding and deepstack features.
 
-#### Audio Embeddings
+#### Audio Embedding Inputs
 
 You can pass pre-computed audio embeddings similar to image embeddings:
 
@@ -892,5 +892,11 @@ For Online Serving, you can also skip sending media if you expect cache hits with
 ```
 
 !!! note
-    Only one message can contain `{"type": "image_embeds"}`.
+    Multiple messages can now contain `{"type": "image_embeds"}`, enabling you to pass multiple image embeddings in a single request (similar to regular images). The number of embeddings is limited by `--limit-mm-per-prompt`.
+
+    **Important**: The embedding shape format differs based on the number of embeddings:
+
+    - **Single embedding**: 3D tensor of shape `(1, feature_size, hidden_size)`
+    - **Multiple embeddings**: List of 2D tensors, each of shape `(feature_size, hidden_size)`
+
     If used with a model that requires additional parameters, you must also provide a tensor for each of them, e.g. `image_grid_thw`, `image_sizes`, etc.
````
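To make the documented shape contract concrete, here is a minimal sketch of a multi-embedding message payload. The `tensor2base64` helper below mirrors the torch.save-plus-base64 encoding pattern used elsewhere in the docs and tests; its name, the tensor sizes, and the exact serialization are illustrative assumptions, and the per-request embedding count is still capped by the server's `--limit-mm-per-prompt`.

```python
import base64
import io

import torch


def tensor2base64(tensor: torch.Tensor) -> str:
    """Serialize a tensor with torch.save and base64-encode the bytes (assumed encoding)."""
    buffer = io.BytesIO()
    torch.save(tensor, buffer)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


# Multiple embeddings: one 2D (feature_size, hidden_size) tensor per content part.
emb_1, emb_2 = torch.randn(256, 1024), torch.randn(128, 1024)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image_embeds", "image_embeds": tensor2base64(emb_1)},
            {"type": "image_embeds", "image_embeds": tensor2base64(emb_2)},
            {"type": "text", "text": "Describe these two images."},
        ],
    }
]

# A single embedding instead keeps the legacy 3D (1, feature_size, hidden_size) shape.
single_part = {
    "type": "image_embeds",
    "image_embeds": tensor2base64(torch.randn(1, 256, 1024)),
}
```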

tests/entrypoints/test_chat_utils.py

Lines changed: 184 additions & 0 deletions

```diff
@@ -6,6 +6,7 @@
 from typing import Literal
 
 import pytest
+import torch
 from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
 
 from vllm.assets.audio import AudioAsset
@@ -915,6 +916,189 @@ async def test_parse_chat_messages_audio_embeds_async(
     _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[None])
 
 
+def test_parse_chat_messages_multiple_image_embeds(
+    phi3v_model_config_image_embeds,
+    phi3v_tokenizer,
+):
+    """Test that multiple image_embeds in a single message are now supported.
+
+    This test validates the fix for the limitation that previously only allowed
+    one message with {'type': 'image_embeds'}. Now multiple image embeddings
+    can be provided in a single request, similar to regular images.
+    """
+    # Create two sample image embedding tensors
+    image_embedding_1 = torch.randn(256, 1024)
+    image_embedding_2 = torch.randn(128, 1024)
+
+    # Encode them as base64 using the convenience function
+    base64_image_embedding_1 = tensor2base64(image_embedding_1)
+    base64_image_embedding_2 = tensor2base64(image_embedding_2)
+
+    conversation, mm_data, mm_uuids = parse_chat_messages(
+        [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_embeds",
+                        "image_embeds": base64_image_embedding_1,
+                    },
+                    {
+                        "type": "image_embeds",
+                        "image_embeds": base64_image_embedding_2,
+                    },
+                    {"type": "text", "text": "Describe these two images."},
+                ],
+            }
+        ],
+        phi3v_model_config_image_embeds,
+        phi3v_tokenizer,
+        content_format="string",
+    )
+
+    # Verify conversation structure
+    assert conversation == [
+        {
+            "role": "user",
+            "content": "<|image_1|>\n<|image_2|>\nDescribe these two images.",
+        }
+    ]
+
+    # Verify mm_data contains a list of embeddings (not a single embedding)
+    assert mm_data is not None
+    assert "image" in mm_data
+    assert isinstance(mm_data["image"], list)
+    assert len(mm_data["image"]) == 2
+
+    # Verify each embedding has the correct shape
+    assert isinstance(mm_data["image"][0], torch.Tensor)
+    assert mm_data["image"][0].shape == image_embedding_1.shape
+    assert isinstance(mm_data["image"][1], torch.Tensor)
+    assert mm_data["image"][1].shape == image_embedding_2.shape
+
+    # Verify UUIDs (None since we didn't provide any)
+    _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None])
+
+
+def test_parse_chat_messages_multiple_image_embeds_with_uuids(
+    phi3v_model_config_image_embeds,
+    phi3v_tokenizer,
+):
+    """Test multiple image_embeds with UUIDs.
+
+    This validates that UUIDs are properly tracked for multiple embeddings.
+    """
+    uuid1 = "image-uuid-1"
+    uuid2 = "image-uuid-2"
+
+    conversation, mm_data, mm_uuids = parse_chat_messages(
+        [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_embeds",
+                        "image_embeds": None,
+                        "uuid": uuid1,
+                    },
+                    {
+                        "type": "image_embeds",
+                        "image_embeds": None,
+                        "uuid": uuid2,
+                    },
+                    {"type": "text", "text": "Compare these images."},
+                ],
+            }
+        ],
+        phi3v_model_config_image_embeds,
+        phi3v_tokenizer,
+        content_format="string",
+    )
+
+    # Verify conversation structure
+    assert conversation == [
+        {
+            "role": "user",
+            "content": "<|image_1|>\n<|image_2|>\nCompare these images.",
+        }
+    ]
+
+    # Verify mm_data contains a list with None values (UUID references)
+    assert mm_data is not None
+    assert "image" in mm_data
+    assert isinstance(mm_data["image"], list)
+    assert len(mm_data["image"]) == 2
+    assert mm_data["image"][0] is None
+    assert mm_data["image"][1] is None
+
+    # Verify UUIDs are correctly tracked
+    _assert_mm_uuids(mm_uuids, 2, expected_uuids=[uuid1, uuid2])
+
+
+@pytest.mark.asyncio
+async def test_parse_chat_messages_multiple_image_embeds_async(
+    phi3v_model_config_image_embeds,
+    phi3v_tokenizer,
+):
+    """Test multiple image_embeds with async parsing.
+
+    This validates the AsyncMultiModalItemTracker also supports multiple embeddings.
+    """
+    # Create two sample image embedding tensors
+    image_embedding_1 = torch.randn(200, 768)
+    image_embedding_2 = torch.randn(150, 768)
+
+    # Encode them as base64 using the convenience function
+    base64_image_embedding_1 = tensor2base64(image_embedding_1)
+    base64_image_embedding_2 = tensor2base64(image_embedding_2)
+
+    conversation, mm_future, mm_uuids = parse_chat_messages_futures(
+        [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_embeds",
+                        "image_embeds": base64_image_embedding_1,
+                    },
+                    {
+                        "type": "image_embeds",
+                        "image_embeds": base64_image_embedding_2,
+                    },
+                    {"type": "text", "text": "What do these images show?"},
+                ],
+            }
+        ],
+        phi3v_model_config_image_embeds,
+        phi3v_tokenizer,
+        content_format="string",
+    )
+
+    # Verify conversation structure
+    assert conversation == [
+        {
+            "role": "user",
+            "content": "<|image_1|>\n<|image_2|>\nWhat do these images show?",
+        }
+    ]
+
+    # Await the future and verify mm_data
+    mm_data = await mm_future
+    assert mm_data is not None
+    assert "image" in mm_data
+    assert isinstance(mm_data["image"], list)
+    assert len(mm_data["image"]) == 2
+
+    # Verify each embedding has the correct shape
+    assert isinstance(mm_data["image"][0], torch.Tensor)
+    assert mm_data["image"][0].shape == image_embedding_1.shape
+    assert isinstance(mm_data["image"][1], torch.Tensor)
+    assert mm_data["image"][1].shape == image_embedding_2.shape
+
+    # Verify UUIDs
+    _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None])
+
+
 @pytest.mark.asyncio
 async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
     phi3v_model_config_image_embeds,
```
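As a usage note, the second test above exercises the media-free path: each content part supplies only a `uuid` with `"image_embeds": None`, letting the server resolve previously cached embeddings instead of receiving tensor bytes again. A minimal sketch of such a payload, mirroring that test (the UUID values are illustrative):

```python
# Reference cached embeddings by UUID instead of re-sending tensor data.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image_embeds", "image_embeds": None, "uuid": "image-uuid-1"},
            {"type": "image_embeds", "image_embeds": None, "uuid": "image-uuid-2"},
            {"type": "text", "text": "Compare these images."},
        ],
    }
]
```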

vllm/entrypoints/chat_utils.py

Lines changed: 12 additions & 18 deletions

```diff
@@ -694,16 +694,10 @@ def all_mm_uuids(self) -> MultiModalUUIDDict | None:
             raise ValueError("Mixing raw image and embedding inputs is not allowed")
 
         if "image_embeds" in uuids_by_modality:
-            image_embeds_uuids = uuids_by_modality["image_embeds"]
-            if len(image_embeds_uuids) > 1:
-                raise ValueError("Only one message can have {'type': 'image_embeds'}")
             mm_uuids["image"] = uuids_by_modality["image_embeds"]
         if "image" in uuids_by_modality:
             mm_uuids["image"] = uuids_by_modality["image"]  # UUIDs of images
         if "audio_embeds" in uuids_by_modality:
-            audio_embeds_uuids = uuids_by_modality["audio_embeds"]
-            if len(audio_embeds_uuids) > 1:
-                raise ValueError("Only one message can have {'type': 'audio_embeds'}")
             mm_uuids["audio"] = uuids_by_modality["audio_embeds"]
         if "audio" in uuids_by_modality:
             mm_uuids["audio"] = uuids_by_modality["audio"]  # UUIDs of audios
@@ -729,16 +723,16 @@ def all_mm_data(self) -> MultiModalDataDict | None:
 
         if "image_embeds" in items_by_modality:
             image_embeds_lst = items_by_modality["image_embeds"]
-            if len(image_embeds_lst) > 1:
-                raise ValueError("Only one message can have {'type': 'image_embeds'}")
-            mm_inputs["image"] = image_embeds_lst[0]
+            mm_inputs["image"] = (
+                image_embeds_lst if len(image_embeds_lst) != 1 else image_embeds_lst[0]
+            )
         if "image" in items_by_modality:
             mm_inputs["image"] = items_by_modality["image"]  # A list of images
         if "audio_embeds" in items_by_modality:
             audio_embeds_lst = items_by_modality["audio_embeds"]
-            if len(audio_embeds_lst) > 1:
-                raise ValueError("Only one message can have {'type': 'audio_embeds'}")
-            mm_inputs["audio"] = audio_embeds_lst[0]
+            mm_inputs["audio"] = (
+                audio_embeds_lst if len(audio_embeds_lst) != 1 else audio_embeds_lst[0]
+            )
         if "audio" in items_by_modality:
             mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
         if "video" in items_by_modality:
@@ -771,16 +765,16 @@ async def all_mm_data(self) -> MultiModalDataDict | None:
 
         if "image_embeds" in items_by_modality:
             image_embeds_lst = items_by_modality["image_embeds"]
-            if len(image_embeds_lst) > 1:
-                raise ValueError("Only one message can have {'type': 'image_embeds'}")
-            mm_inputs["image"] = image_embeds_lst[0]
+            mm_inputs["image"] = (
+                image_embeds_lst if len(image_embeds_lst) != 1 else image_embeds_lst[0]
+            )
         if "image" in items_by_modality:
             mm_inputs["image"] = items_by_modality["image"]  # A list of images
         if "audio_embeds" in items_by_modality:
             audio_embeds_lst = items_by_modality["audio_embeds"]
-            if len(audio_embeds_lst) > 1:
-                raise ValueError("Only one message can have {'type': 'audio_embeds'}")
-            mm_inputs["audio"] = audio_embeds_lst[0]
+            mm_inputs["audio"] = (
+                audio_embeds_lst if len(audio_embeds_lst) != 1 else audio_embeds_lst[0]
+            )
         if "audio" in items_by_modality:
             mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
         if "video" in items_by_modality:
```
