
Commit ee06b3c

Jeremy Teboul committed
Support multiple image/audio embeddings per request
This enables the Chat Completions API to leverage the model's existing capability for multiple embeddings:

- Remove limitation that only allowed one message with image_embeds/audio_embeds
- Update MultiModalItemTracker and AsyncMultiModalItemTracker to treat embeddings as lists
- Add unit tests for multiple image embeddings support:
  * test_parse_chat_messages_multiple_image_embeds
  * test_parse_chat_messages_multiple_image_embeds_with_uuids
  * test_parse_chat_messages_multiple_image_embeds_async
- Embeddings now behave consistently with regular images/audios
- Validation via existing validate_num_items() against --limit-mm-per-prompt
1 parent 6dcb07f commit ee06b3c

3 files changed (+220, -20 lines)

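Before the per-file diffs, here is a minimal sketch of a request that exercises the new behavior through the OpenAI-compatible Chat Completions endpoint. This sketch is not part of the commit: the server URL, model name, and tensor sizes are placeholder assumptions, and the server is assumed to be running with an image limit of at least 2 configured via `--limit-mm-per-prompt`.

```python
import base64
import io

import torch
from openai import OpenAI

# Placeholder endpoint and model; any vLLM deployment whose model accepts
# image embeddings would work the same way.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")


def encode_embedding(embedding: torch.Tensor) -> str:
    # Serialize with torch.save and base64-encode the bytes, the same
    # transport format used by the tests added in this commit.
    buffer = io.BytesIO()
    torch.save(embedding, buffer)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


# With multiple embeddings, each one is a 2D (feature_size, hidden_size)
# tensor sent as its own {"type": "image_embeds"} content part.
parts = [
    {"type": "image_embeds", "image_embeds": encode_embedding(torch.randn(256, 1024))},
    {"type": "image_embeds", "image_embeds": encode_embedding(torch.randn(128, 1024))},
    {"type": "text", "text": "Describe these two images."},
]

response = client.chat.completions.create(
    model="microsoft/Phi-3.5-vision-instruct",  # placeholder model name
    messages=[{"role": "user", "content": parts}],
)
print(response.choices[0].message.content)
```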

docs/features/multimodal_inputs.md

Lines changed: 8 additions & 2 deletions
@@ -445,7 +445,7 @@ For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embedd
 
 For Qwen3-VL, the `image_embeds` should contain both the base image embedding and deepstack features.
 
-#### Audio Embeddings
+#### Audio Embedding Inputs
 
 You can pass pre-computed audio embeddings similar to image embeddings:
 
@@ -892,5 +892,11 @@ For Online Serving, you can also skip sending media if you expect cache hits wit
 ```
 
 !!! note
-    Only one message can contain `{"type": "image_embeds"}`.
+    Multiple messages can now contain `{"type": "image_embeds"}`, enabling you to pass multiple image embeddings in a single request (similar to regular images). The number of embeddings is limited by `--limit-mm-per-prompt`.
+
+    **Important**: The embedding shape format differs based on the number of embeddings:
+
+    - **Single embedding**: 3D tensor of shape `(1, feature_size, hidden_size)`
+    - **Multiple embeddings**: List of 2D tensors, each of shape `(feature_size, hidden_size)`
+
     If used with a model that requires additional parameters, you must also provide a tensor for each of them, e.g. `image_grid_thw`, `image_sizes`, etc.
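To make the shape convention in that note concrete, a small sketch; the feature and hidden sizes here are arbitrary placeholders, not model-specific values.

```python
import torch

hidden_size = 1024  # placeholder; in practice this matches the model's hidden size

# Single embedding: one 3D tensor of shape (1, feature_size, hidden_size).
single = torch.randn(1, 256, hidden_size)

# Multiple embeddings: a list of 2D (feature_size, hidden_size) tensors,
# one per image, each of which may have a different feature_size.
multiple = [
    torch.randn(256, hidden_size),
    torch.randn(128, hidden_size),
]

assert single.ndim == 3 and single.shape[0] == 1
assert all(t.ndim == 2 for t in multiple)
```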

tests/entrypoints/test_chat_utils.py

Lines changed: 200 additions & 0 deletions
@@ -1,11 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import base64
+import io
 import warnings
 from collections.abc import Mapping
 from typing import Literal
 
 import pytest
+import torch
 from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
 
 from vllm.assets.audio import AudioAsset
@@ -915,6 +918,203 @@ async def test_parse_chat_messages_audio_embeds_async(
     _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[None])
 
 
+def test_parse_chat_messages_multiple_image_embeds(
+    phi3v_model_config_image_embeds,
+    phi3v_tokenizer,
+):
+    """Test that multiple image_embeds in a single message are now supported.
+
+    This test validates the fix for the limitation that previously only allowed
+    one message with {'type': 'image_embeds'}. Now multiple image embeddings
+    can be provided in a single request, similar to regular images.
+    """
+    # Create two sample image embedding tensors
+    image_embedding_1 = torch.randn(256, 1024)
+    image_embedding_2 = torch.randn(128, 1024)
+
+    # Encode them as base64
+    def encode_embedding(embedding):
+        buffer = io.BytesIO()
+        torch.save(embedding, buffer)
+        buffer.seek(0)
+        binary_data = buffer.read()
+        return base64.b64encode(binary_data).decode("utf-8")
+
+    base64_image_embedding_1 = encode_embedding(image_embedding_1)
+    base64_image_embedding_2 = encode_embedding(image_embedding_2)
+
+    conversation, mm_data, mm_uuids = parse_chat_messages(
+        [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_embeds",
+                        "image_embeds": base64_image_embedding_1,
+                    },
+                    {
+                        "type": "image_embeds",
+                        "image_embeds": base64_image_embedding_2,
+                    },
+                    {"type": "text", "text": "Describe these two images."},
+                ],
+            }
+        ],
+        phi3v_model_config_image_embeds,
+        phi3v_tokenizer,
+        content_format="string",
+    )
+
+    # Verify conversation structure
+    assert conversation == [
+        {
+            "role": "user",
+            "content": "<|image_1|>\n<|image_2|>\nDescribe these two images.",
+        }
+    ]
+
+    # Verify mm_data contains a list of embeddings (not a single embedding)
+    assert mm_data is not None
+    assert "image" in mm_data
+    assert isinstance(mm_data["image"], list)
+    assert len(mm_data["image"]) == 2
+
+    # Verify each embedding has the correct shape
+    assert isinstance(mm_data["image"][0], torch.Tensor)
+    assert mm_data["image"][0].shape == image_embedding_1.shape
+    assert isinstance(mm_data["image"][1], torch.Tensor)
+    assert mm_data["image"][1].shape == image_embedding_2.shape
+
+    # Verify UUIDs (None since we didn't provide any)
+    _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None])
+
+
+def test_parse_chat_messages_multiple_image_embeds_with_uuids(
+    phi3v_model_config_image_embeds,
+    phi3v_tokenizer,
+):
+    """Test multiple image_embeds with UUIDs.
+
+    This validates that UUIDs are properly tracked for multiple embeddings.
+    """
+    uuid1 = "image-uuid-1"
+    uuid2 = "image-uuid-2"
+
+    conversation, mm_data, mm_uuids = parse_chat_messages(
+        [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_embeds",
+                        "image_embeds": None,
+                        "uuid": uuid1,
+                    },
+                    {
+                        "type": "image_embeds",
+                        "image_embeds": None,
+                        "uuid": uuid2,
+                    },
+                    {"type": "text", "text": "Compare these images."},
+                ],
+            }
+        ],
+        phi3v_model_config_image_embeds,
+        phi3v_tokenizer,
+        content_format="string",
+    )
+
+    # Verify conversation structure
+    assert conversation == [
+        {
+            "role": "user",
+            "content": "<|image_1|>\n<|image_2|>\nCompare these images.",
+        }
+    ]
+
+    # Verify mm_data contains a list with None values (UUID references)
+    assert mm_data is not None
+    assert "image" in mm_data
+    assert isinstance(mm_data["image"], list)
+    assert len(mm_data["image"]) == 2
+    assert mm_data["image"][0] is None
+    assert mm_data["image"][1] is None
+
+    # Verify UUIDs are correctly tracked
+    _assert_mm_uuids(mm_uuids, 2, expected_uuids=[uuid1, uuid2])
+
+
+@pytest.mark.asyncio
+async def test_parse_chat_messages_multiple_image_embeds_async(
+    phi3v_model_config_image_embeds,
+    phi3v_tokenizer,
+):
+    """Test multiple image_embeds with async parsing.
+
+    This validates the AsyncMultiModalItemTracker also supports multiple embeddings.
+    """
+    # Create two sample image embedding tensors
+    image_embedding_1 = torch.randn(200, 768)
+    image_embedding_2 = torch.randn(150, 768)
+
+    # Encode them as base64
+    def encode_embedding(embedding):
+        buffer = io.BytesIO()
+        torch.save(embedding, buffer)
+        buffer.seek(0)
+        binary_data = buffer.read()
+        return base64.b64encode(binary_data).decode("utf-8")
+
+    base64_image_embedding_1 = encode_embedding(image_embedding_1)
+    base64_image_embedding_2 = encode_embedding(image_embedding_2)
+
+    conversation, mm_future, mm_uuids = parse_chat_messages_futures(
+        [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_embeds",
+                        "image_embeds": base64_image_embedding_1,
+                    },
+                    {
+                        "type": "image_embeds",
+                        "image_embeds": base64_image_embedding_2,
+                    },
+                    {"type": "text", "text": "What do these images show?"},
+                ],
+            }
+        ],
+        phi3v_model_config_image_embeds,
+        phi3v_tokenizer,
+        content_format="string",
+    )
+
+    # Verify conversation structure
+    assert conversation == [
+        {
+            "role": "user",
+            "content": "<|image_1|>\n<|image_2|>\nWhat do these images show?",
+        }
+    ]
+
+    # Await the future and verify mm_data
+    mm_data = await mm_future
+    assert mm_data is not None
+    assert "image" in mm_data
+    assert isinstance(mm_data["image"], list)
+    assert len(mm_data["image"]) == 2
+
+    # Verify each embedding has the correct shape
+    assert isinstance(mm_data["image"][0], torch.Tensor)
+    assert mm_data["image"][0].shape == image_embedding_1.shape
+    assert isinstance(mm_data["image"][1], torch.Tensor)
+    assert mm_data["image"][1].shape == image_embedding_2.shape
+
+    # Verify UUIDs
+    _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None])
+
+
 @pytest.mark.asyncio
 async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
     phi3v_model_config_image_embeds,

vllm/entrypoints/chat_utils.py

Lines changed: 12 additions & 18 deletions
@@ -694,16 +694,10 @@ def all_mm_uuids(self) -> MultiModalUUIDDict | None:
             raise ValueError("Mixing raw image and embedding inputs is not allowed")
 
         if "image_embeds" in uuids_by_modality:
-            image_embeds_uuids = uuids_by_modality["image_embeds"]
-            if len(image_embeds_uuids) > 1:
-                raise ValueError("Only one message can have {'type': 'image_embeds'}")
             mm_uuids["image"] = uuids_by_modality["image_embeds"]
         if "image" in uuids_by_modality:
             mm_uuids["image"] = uuids_by_modality["image"]  # UUIDs of images
         if "audio_embeds" in uuids_by_modality:
-            audio_embeds_uuids = uuids_by_modality["audio_embeds"]
-            if len(audio_embeds_uuids) > 1:
-                raise ValueError("Only one message can have {'type': 'audio_embeds'}")
             mm_uuids["audio"] = uuids_by_modality["audio_embeds"]
         if "audio" in uuids_by_modality:
             mm_uuids["audio"] = uuids_by_modality["audio"]  # UUIDs of audios
@@ -729,16 +723,16 @@ def all_mm_data(self) -> MultiModalDataDict | None:
 
         if "image_embeds" in items_by_modality:
             image_embeds_lst = items_by_modality["image_embeds"]
-            if len(image_embeds_lst) > 1:
-                raise ValueError("Only one message can have {'type': 'image_embeds'}")
-            mm_inputs["image"] = image_embeds_lst[0]
+            mm_inputs["image"] = (
+                image_embeds_lst if len(image_embeds_lst) != 1 else image_embeds_lst[0]
+            )
         if "image" in items_by_modality:
             mm_inputs["image"] = items_by_modality["image"]  # A list of images
         if "audio_embeds" in items_by_modality:
             audio_embeds_lst = items_by_modality["audio_embeds"]
-            if len(audio_embeds_lst) > 1:
-                raise ValueError("Only one message can have {'type': 'audio_embeds'}")
-            mm_inputs["audio"] = audio_embeds_lst[0]
+            mm_inputs["audio"] = (
+                audio_embeds_lst if len(audio_embeds_lst) != 1 else audio_embeds_lst[0]
+            )
         if "audio" in items_by_modality:
             mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
         if "video" in items_by_modality:
@@ -771,16 +765,16 @@ async def all_mm_data(self) -> MultiModalDataDict | None:
 
         if "image_embeds" in items_by_modality:
             image_embeds_lst = items_by_modality["image_embeds"]
-            if len(image_embeds_lst) > 1:
-                raise ValueError("Only one message can have {'type': 'image_embeds'}")
-            mm_inputs["image"] = image_embeds_lst[0]
+            mm_inputs["image"] = (
+                image_embeds_lst if len(image_embeds_lst) != 1 else image_embeds_lst[0]
+            )
         if "image" in items_by_modality:
             mm_inputs["image"] = items_by_modality["image"]  # A list of images
         if "audio_embeds" in items_by_modality:
             audio_embeds_lst = items_by_modality["audio_embeds"]
-            if len(audio_embeds_lst) > 1:
-                raise ValueError("Only one message can have {'type': 'audio_embeds'}")
-            mm_inputs["audio"] = audio_embeds_lst[0]
+            mm_inputs["audio"] = (
+                audio_embeds_lst if len(audio_embeds_lst) != 1 else audio_embeds_lst[0]
+            )
         if "audio" in items_by_modality:
             mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
         if "video" in items_by_modality:
